Source code for pyatoa.visuals.insp_plot

#!/usr/bin/env python3
"""
The plotting functionality of the Inspector class. Used to generate statistics
and basemap like plots from the Inspector DataFrame objects.
"""
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from pyatoa import logger
from matplotlib.patches import Rectangle
from pyatoa.utils.calculate import normalize_a_to_b


# A map from the Pyflex parameter names into cleaner looking label strings
# Map from Pyflex/Pyadjoint parameter names to cleaner looking axis labels.
# The dlnA label is a raw string so the LaTeX backslashes are not parsed as
# (invalid) Python escape sequences, which warn on Python >= 3.12
common_labels = {"cc_shift_in_seconds": "Time Shift (s)",
                 "dlnA": r"$\Delta\ln$(A)",
                 "misfit": "Misfit",
                 "length_s": "Window Length (s)",
                 "max_cc_value": "Peak Cross Correlation",
                 "relative_starttime": "Relative Start Time (s)",
                 "relative_endtime": "Relative End Time (s)",
                 }
[docs] class InspectorPlotter: """ A class of methods for plotting statistics from an Inspector. Should not be called on its own, these functions will be inherited by the Inspector class automatically. """
def map(self, event=None, network=None, station=None, show=True, save=False,
        **kwargs):
    """
    Plot source and receiver locations with map view. Optional arguments
    for only plotting certain stations or events.

    :type event: str or list
    :param event: particular event or list of events to plot
    :type network: str or list
    :param network: particular network or list of networks to plot
    :type station: str or list
    :param station: particular station or list of stations to plot
    :type show: bool
    :param show: Show the plot (with hover annotations for names)
    :type save: str
    :param save: fid to save the given figure
    :rtype: tuple
    :return: (figure, axes) of the created plot

    Keyword Arguments
    ::
        float markersize: size of source and receiver markers, default 10
    """
    # For isolating parameters, ensure they are lists. quack
    if isinstance(event, str):
        event = [event]
    if isinstance(network, str):
        network = [network]
    if isinstance(station, str):
        station = [station]

    markersize = kwargs.get("markersize", 10)

    f, ax = plt.subplots()

    # Defined up front so that an empty sources or receivers table does not
    # lead to NameErrors in the aspect-ratio and hover logic further down
    sc_sources, src_names = None, None
    sc_receiver_names, sc_receiver_list = [], []
    plotted_lats = []

    if not self.sources.empty:
        # Allow for isolation of particular events
        if event is not None:
            srcs = self.sources.loc[event]
        else:
            srcs = self.sources
        src_lat = srcs.latitude
        src_lon = srcs.longitude
        src_names = srcs.index
        sc_sources = plt.scatter(src_lon, src_lat, marker="o", c="None",
                                 edgecolors="k", s=markersize, zorder=100,
                                 label="event(s)")
        plotted_lats.extend(np.atleast_1d(np.asarray(src_lat, dtype=float)))

    if not self.receivers.empty:
        # Allow for isolation of networks
        if network is not None:
            networks = self.receivers.loc[network]
        else:
            networks = self.receivers
        networks = networks.index.get_level_values("network").unique()

        for net in networks:
            # Allow for isolation of stations
            if station is not None:
                try:
                    net_rcvs = self.receivers.loc[net].loc[station]
                except KeyError:
                    # Requested station(s) not in this network
                    continue
                rcv_nam = station
            else:
                # else just plot all stations in a given network
                net_rcvs = self.receivers.loc[net]
                rcv_nam = net_rcvs.index.to_numpy()
            rcv_lat = net_rcvs.latitude
            rcv_lon = net_rcvs.longitude

            # Random color cycle for networks
            sc_receivers = plt.scatter(rcv_lon, rcv_lat, marker="v",
                                       s=markersize, zorder=100, label=net)
            sc_receiver_list.append(sc_receivers)
            sc_receiver_names.append(rcv_nam)
            plotted_lats.extend(
                np.atleast_1d(np.asarray(rcv_lat, dtype=float)))

    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.legend()
    plt.title(f"{len(self.events)} events; {len(self.receivers)} stations")

    # Calculate aspect ratio based on the mean latitude of plotted points.
    # Skipped when nothing was plotted (no latitude to base it on)
    if plotted_lats:
        w = 1 / np.cos(np.radians(np.mean(plotted_lats)))
        plt.gca().set_aspect(w)

    if save:
        plt.savefig(save)
    if show:
        if sc_sources is not None:
            hover_on_plot(f, ax, sc_sources, src_names)
        for sc_rcvs, rcv_names in zip(sc_receiver_list, sc_receiver_names):
            hover_on_plot(f, ax, sc_rcvs, rcv_names)
        plt.show()

    return f, ax
def scatter(self, x, y, iteration=None, step_count=None, save=None,
            show=True, **kwargs):
    """
    Create a scatter plot between two chosen keys in the windows attribute
    (merged with source-receiver information, e.g. distance, backazimuth)

    :type x: str
    :param x: key to choose for the x axis of the plot
    :type y: str
    :param y: key to choose for the y axis of the plot
    :type iteration: str
    :param iteration: the chosen iteration to plot for, if None will default
        to the iteration of `final_model` (latest available)
    :type step_count: str
    :param step_count: chosen step count. If None, defaults to the last step
        of the chosen iteration
    :type save: str
    :param save: fid to save the figure
    :type show: bool
    :param show: show the plot
    :rtype: tuple
    :return: (figure, axes) of the created plot

    Keyword Arguments
    ::
        event, network, station, component: forwarded to Inspector.isolate
            to select measurements
        any other keyword: forwarded to matplotlib.pyplot.scatter and
            default_axes
    """
    if iteration is None:
        iteration, _ = self.final_model
    if step_count is None:
        step_count = self.steps.loc[iteration][-1]

    # Selection kwargs go to `isolate`; the rest are plotting kwargs. Passing
    # everything to both calls made either one fail on the other's keywords
    isolate_kwargs = {key: kwargs.pop(key)
                      for key in ["event", "network", "station", "component"]
                      if key in kwargs}

    # Ensure we have distance and backazimuth values in the dataframe
    df = self.isolate(iteration=iteration, step_count=step_count,
                      **isolate_kwargs)
    df = df.merge(self.srcrcv, on=["event", "network", "station"])

    assert(x in df.keys()), f"X value {x} does not match keys {df.keys()}"
    assert(y in df.keys()), f"Y value {y} does not match keys {df.keys()}"

    f, ax = plt.subplots(figsize=(8, 6))
    plt.scatter(df[x].to_numpy(), df[y].to_numpy(), **kwargs)
    plt.xlabel(x)
    plt.ylabel(y)
    # N is the number of plotted measurements
    plt.title(f"{x} vs. {y}; N={len(df)}")
    default_axes(ax, **kwargs)

    if save:
        plt.savefig(save)
    if show:
        plt.show()

    return f, ax
def travel_times(self, iteration=None, step_count=None, component=None,
                 constants=None, t_offset=0, hist=False, hist_max=None,
                 plot_end=False, save=None, show=True, **kwargs):
    """
    Plot relative window starttime (proxy for phase arrival) against
    source-receiver distance, to try to convey which phases are included
    in the measurement. Similar to Figure 4.18 in Shearer's Intro to
    Seismology.

    :type iteration: str
    :param iteration: the chosen iteration to plot for, if None will default
        to the iteration of `final_model`
    :type step_count: str
    :param step_count: chosen step count. If None, defaults to the last step
        of the chosen iteration
    :type component: str
    :param component: optional specify a measurement component to isolate
        only e.g., 'Z' components to look at Rayleigh waves
    :type constants: list of floats
    :param constants: plot lines of constant velocity (km/s) to estimate the
        average wavespeed that leads to some of the linear trends
    :type t_offset: float
    :param t_offset: if the synthetic offset time in SPECFEM is set then the
        constant lines will need to be offset by the same amount to match
        the measurements.
    :type hist: bool
    :param hist: also create a histogram binning the apparent velocities
        (distance / shifted relative start time)
    :type hist_max: float
    :param hist_max: optional upper y-limit (count) of the histogram
    :type plot_end: bool
    :param plot_end: if True, plots the beginning and end of the misfit window
        as a vertical line. If False, plots only the beginning of the misfit
        window
    :type save: str
    :param save: fid to save the travel time figure. If `hist` is True, the
        histogram is saved in the same directory with its filename prefixed
        by 'hist_'
    :type show: bool
    :param show: show the plot(s)

    Keyword Arguments
    ::
        str hist_color: histogram bar color, default 'deepskyblue'
        str title_plot: custom title for the travel time plot
        str title_hist: custom title for the histogram
        float markersize: marker size, default 1
        str markertype: marker style, default 'x'
        float legend_fontsize: fontsize of the velocity legend, default 8
        list ylim: y-limits of the travel time plot
        list xlim: x-limits of the travel time plot
    """
    import os

    hist_color = kwargs.get("hist_color", "deepskyblue")
    title_plot = kwargs.get("title_plot", None)
    title_hist = kwargs.get("title_hist", None)
    markersize = kwargs.get("markersize", 1)
    markertype = kwargs.get("markertype", "x")
    legend_fontsize = kwargs.get("legend_fontsize", 8)
    ylim = kwargs.get("ylim", None)
    xlim = kwargs.get("xlim", None)

    if iteration is None:
        iteration, _ = self.final_model
    if step_count is None:
        step_count = self.steps.loc[iteration][-1]

    # Ensure we have distance and backazimuth values in the dataframe
    df = self.isolate(iteration=iteration, step_count=step_count,
                      component=component)
    df = df.merge(self.srcrcv, on=["event", "network", "station"])

    # Assuming that isolate has only picked values from a single iterstep
    iterstep = f"{df.iteration.iloc[0]}{df.step.iloc[0]}"

    dist, start, end = df[["distance_km", "relative_starttime",
                           "relative_endtime"]].to_numpy().T

    # Shift relative starttimes by the user-defined offset
    start = start - t_offset
    end = end - t_offset

    f, ax = plt.subplots(figsize=(8, 6))

    # Either plot the window start only, or plot the entire window
    if not plot_end:
        plt.scatter(dist, start, c="k", s=markersize, marker=markertype,
                    zorder=5, alpha=0.5)
    else:
        for d_, s_, e_ in zip(dist, start, end):
            plt.plot([d_, d_], [s_, e_], f"k{markertype}-", zorder=5,
                     alpha=0.1, markersize=markersize)

    if title_plot is not None:
        plt.title(title_plot)
    else:
        plt.title(f"Apparent travel times ({iterstep} N={len(dist)})")

    plt.xlabel("Source-receiver distance [km]")
    plt.ylabel("Relative start time [s]")

    # Plot apparent velocities as straight lines (t = x / v)
    if constants is not None:
        x = np.linspace(0, dist.max() + dist.max() / 3, len(dist))
        for i, c in enumerate(constants):
            y = x / c
            plt.plot(x, y, c=f"C{i}", lw=2, zorder=1, label=f"{c} km/s")
        plt.legend(fontsize=legend_fontsize)

    if xlim is not None:
        plt.xlim(xlim)
    else:
        plt.xlim([0, dist.max()])
    if ylim is not None:
        plt.ylim(ylim)
    else:
        plt.ylim([0, start.max()])

    default_axes(ax, **kwargs)
    if save:
        plt.savefig(save)
    if show:
        plt.show()
    plt.close()

    # Now make a separate histogram showing the apparent velocities
    if hist:
        f, ax = plt.subplots(figsize=(8, 6))
        # Zero start times would produce inf; those fall outside the bins
        with np.errstate(divide="ignore", invalid="ignore"):
            velocities = dist / start
        # Max velocity based on PREM highest (ish) Vp
        plt.hist(x=velocities, bins=np.arange(0, 12, .5),
                 color=hist_color, histtype="bar", edgecolor="black",
                 linewidth=2, zorder=11, alpha=1.)
        if hist_max:
            plt.ylim([0, hist_max])
        _, xmax = plt.gca().get_xlim()
        plt.xlim([0, xmax])

        if title_hist is not None:
            plt.title(title_hist)
        else:
            plt.title(f"Apparent velocities "
                      f"({iterstep} N={len(velocities)})")
        plt.xlabel("Velocity [km/s]")
        plt.ylabel("Count")

        default_axes(plt.gca(), **kwargs)
        if save:
            # Prefix the filename, not the directory, so that paths like
            # 'figures/tt.png' become 'figures/hist_tt.png'
            head, tail = os.path.split(save)
            plt.savefig(os.path.join(head, f"hist_{tail}"))
        if show:
            plt.show()
def event_depths(self, xaxis="longitude", show=True, save=None, **kwargs):
    """
    Create a scatter plot of events at depth. Compresses all events onto a
    single slice, optional choice of showing the x-axis or the y-axis.
    Marker sizes scale with event magnitude.

    :type xaxis: str
    :param xaxis: variable to use as the x-axis on the plot
        'latitude' or 'longitude'
    :type show: bool
    :param show: show the plot
    :type save: str
    :param save: fid to save the figure
    :rtype: tuple
    :return: (figure, axes) of the created plot
    :raises NotImplementedError: if `xaxis` is not 'latitude' or 'longitude'
    """
    if xaxis == "latitude":
        x_vals = self.sources.latitude.to_numpy()
    elif xaxis == "longitude":
        x_vals = self.sources.longitude.to_numpy()
    else:
        raise NotImplementedError(
            "'xaxis' must be 'latitude' or 'longitude'"
        )

    # Plot initializations
    f, ax = plt.subplots(figsize=(8, 6))
    depths = self.sources.depth_km.to_numpy()
    mags = self.sources.magnitude.to_numpy()
    mags = normalize_a_to_b(mags, 100, 500)
    names = self.sources.index

    # Invert the axis for positive (downward) depth values. Checking all
    # depths instead of only the first avoids a surface event (depth 0)
    # deciding the orientation, and avoids IndexError on empty input
    if np.any(depths > 0):
        plt.gca().invert_yaxis()

    # Scatter plot
    sc = plt.scatter(x_vals, depths, s=mags, c="None", marker="o",
                     edgecolors="k")

    plt.xlabel(xaxis.capitalize())
    plt.ylabel("Depth (km)")
    plt.title(f"N={len(depths)}")
    plt.grid(which="both", linestyle=":", alpha=0.5)

    hover_on_plot(f, ax, sc, names, dissapear=True)

    default_axes(ax, **kwargs)
    if save:
        plt.savefig(save)
    if show:
        # Previously `plt.show` without parentheses, which never showed
        plt.show()

    return f, ax
def raypaths(self, iteration=None, step_count=None, color_by=None,
             show=True, save=False, vmin=None, vmax=None, **kwargs):
    """
    Plot rays connecting sources and receivers based on the availability
    of measurements. Useful for getting an approximation of resolution.

    :type iteration: str
    :param iteration: iteration to retrieve data from
    :type step_count: str
    :param step_count: step count to retrieve data from
    :type color_by: str
    :param color_by: allow rays to be colored based on a normalized value
        taken from the station-level misfit dataframe, e.g.
        nwin: color rays by the number of windows available for that path
        misfit: color rays by total misfit
    :type show: bool
    :param show: show the plot
    :type save: str
    :param save: fid to save the figure
    :type vmin: float
    :param vmin: lower bound of the colorbar when using `color_by`. Defaults
        to the data minimum
    :type vmax: float
    :param vmax: upper bound of the colorbar when using `color_by`. Defaults
        to the data maximum
    :rtype: tuple
    :return: (figure, axes) of the created plot

    Keyword Arguments
    ::
        str cmap: colormap for `color_by`, default 'viridis'
        str ray_color: ray color if not `color_by`, default 'k'
        float ray_linewidth: ray line width, default 1
        float ray_alpha: ray opacity, default 0.1
        str station_color: station marker color, default 'c'
        str event_color: event marker color, default 'orange'
        tuple figsize: figure size, default (8, 8)
        float markersize: size of markers, default 25
    """
    cmap = kwargs.get("cmap", "viridis")
    ray_color = kwargs.get("ray_color", "k")
    ray_linewidth = kwargs.get("ray_linewidth", 1)
    ray_alpha = kwargs.get("ray_alpha", 0.1)
    station_color = kwargs.get("station_color", "c")
    event_color = kwargs.get("event_color", "orange")
    figsize = kwargs.get("figsize", (8, 8))
    markersize = kwargs.get("markersize", 25)

    f, ax = plt.subplots(figsize=figsize)

    iteration, step_count = self._parse_nonetype_eval(iteration, step_count)
    df = self.misfit(level="station").loc[iteration, step_count]

    # Get lat/lon information from sources and receivers
    stations = self.receivers.droplevel(0)  # remove network index
    events = self.sources.drop(["time", "magnitude", "depth_km"], axis=1)

    # Set up the normalized colorbar, extending it where the user-defined
    # bounds clip the data
    extend = None
    if color_by is not None:
        assert(color_by in df.keys()), f"{color_by} must be in {df.keys()}"
        if vmin is None:
            vmin = df[color_by].min()
        elif vmin > df[color_by].min():
            extend = "min"
        if vmax is None:
            vmax = df[color_by].max()
        elif vmax < df[color_by].max():
            if extend == "min":
                extend = "both"
            else:
                extend = "max"
        sm, norm, _ = colormap_colorbar(cmap, vmin=vmin, vmax=vmax,
                                        cbar_label=color_by.capitalize(),
                                        extend=extend)

    # Separate bookkeeping so each event and station marker is drawn once,
    # and an event id can never shadow a station name (or vice versa)
    plotted_events, plotted_stations = set(), set()
    lats = []
    for event, sta in df.index.to_numpy():
        elon, elat = events.loc[event].longitude, events.loc[event].latitude
        slon, slat = stations.loc[sta].longitude, stations.loc[sta].latitude

        # Plot a marker for each event and station
        if event not in plotted_events:
            plt.scatter(elon, elat, marker="o", c=event_color,
                        edgecolors="k", s=markersize, zorder=100)
            plotted_events.add(event)
        if sta not in plotted_stations:
            plt.scatter(slon, slat, marker="v", c=station_color,
                        edgecolors="k", s=markersize, zorder=100)
            plotted_stations.add(sta)

        if color_by is not None:
            ray_color = sm.cmap(norm(df.loc[event].loc[sta][color_by]))

        # Connect source and receiver with a line
        plt.plot([elon, slon], [elat, slat], color=ray_color,
                 linestyle="-", alpha=ray_alpha, zorder=50,
                 linewidth=ray_linewidth)
        lats.extend([elat, slat])

    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.title(f"{len(df)} raypaths")

    # Calculate aspect ratio based on the mean latitude of the rays; skipped
    # if there are no rays to base it on
    if lats:
        w = 1 / np.cos(np.radians(np.mean(lats)))
        plt.gca().set_aspect(w)

    if save:
        plt.savefig(save)
    if show:
        plt.show()

    return f, ax
def raypath_density(self, iteration=None, step_count=None,
                    point_spacing_km=.5, bin_spacing_km=8, cmap="viridis",
                    show=True, save=False, **kwargs):
    """
    Create a raypath density plot to provide a more detailed illustration of
    raypath gradients, which may be interpreted alongside tomographic
    inversion results as a preliminary resolution test.

    The idea behind this is to partition each individual raypath line into
    discrete points and then create a 2D histogram with all points

    :type iteration: str
    :param iteration: iteration to retrieve data from
    :type step_count: str
    :param step_count: step count to retrieve data from
    :type point_spacing_km: float
    :param point_spacing_km: approximate discretization interval for each
        raypath line. Smaller numbers will lead to higher resolution but also
        longer computation time.
    :type bin_spacing_km: float
    :param bin_spacing_km: the bin size in km of the 2d histogram. If the
        same as 'point_spacing_km' then you'll probably just see the lines.
        Should be larger than 'point_spacing_km' for a more contour plot
        looking feel.
    :type cmap: str
    :param cmap: colormap of the 2D histogram
    :type show: bool
    :param show: show the plot
    :type save: str
    :param save: fid to save the figure

    Keyword Arguments
    ::
        tuple figsize: figure size, default (8, 8)
        str event_color: event marker color, default 'orange'
        str station_color: station marker color, default 'cyan'
        float markersize: size of markers, default 26
    """
    figsize = kwargs.get("figsize", (8, 8))
    event_color = kwargs.get("event_color", "orange")
    station_color = kwargs.get("station_color", "cyan")
    markersize = kwargs.get("markersize", 26)

    # 111.11 VERY roughly converts degrees to km, not really geographically
    # correct though. Should be okay for this low-res application
    km_per_deg = 111.11

    iteration, step_count = self._parse_nonetype_eval(iteration, step_count)

    f, ax = plt.subplots(figsize=figsize)
    df = self.misfit(level="station").loc[iteration, step_count]

    # Get lat/lon information from sources and receivers
    stations = self.receivers.droplevel(0)  # remove network index
    events = self.sources.drop(["time", "magnitude", "depth_km"], axis=1)

    # Determine grid bounds and required number of bins for histograms
    x_min = min(stations.longitude.min(), events.longitude.min())
    x_max = max(stations.longitude.max(), events.longitude.max())
    y_min = min(stations.latitude.min(), events.latitude.min())
    y_max = max(stations.latitude.max(), events.latitude.max())

    # At least one bin per axis; hist2d rejects zero bins for small regions
    x_bins = max(1, int(abs(x_max - x_min) * km_per_deg / bin_spacing_km))
    y_bins = max(1, int(abs(y_max - y_min) * km_per_deg / bin_spacing_km))

    # Collect per-ray point arrays, concatenated once after the loop
    # (concatenating inside the loop is quadratic in the number of points)
    x_segments, y_segments = [], []
    plotted_events, plotted_stations = set(), set()
    for event, sta in df.index.to_numpy():
        elon, elat = events.loc[event].longitude, events.loc[event].latitude
        slon, slat = stations.loc[sta].longitude, stations.loc[sta].latitude

        # Plot a marker for each event and station, once each
        if event not in plotted_events:
            plt.scatter(elon, elat, marker="o", c=event_color,
                        edgecolors="k", s=markersize, zorder=100)
            plotted_events.add(event)
        if sta not in plotted_stations:
            plt.scatter(slon, slat, marker="v", c=station_color,
                        edgecolors="k", s=markersize, zorder=100)
            plotted_stations.add(sta)

        # Calculate the necessary number of discrete points to create line.
        # Keep at least the two end points so short rays still contribute
        nlon = int(abs(elon - slon) * km_per_deg / point_spacing_km)
        nlat = int(abs(elat - slat) * km_per_deg / point_spacing_km)
        nvals = max(nlon, nlat, 2)

        x_segments.append(np.linspace(elon, slon, nvals))
        y_segments.append(np.linspace(elat, slat, nvals))

    x = np.concatenate(x_segments) if x_segments else np.array([])
    y = np.concatenate(y_segments) if y_segments else np.array([])

    # Create the 2D histogram of raypath density
    plt.hist2d(x, y, bins=(x_bins, y_bins), cmap=plt.get_cmap(cmap),
               zorder=5)
    cbar = plt.colorbar(label="counts", shrink=0.9, pad=0.025)

    plt.title(f"Raypath Density {iteration}{step_count} "
              f"(N={len(df)} src-rcv pairs)")
    plt.xlabel("Longitude")
    plt.ylabel("Latitude")

    # Calculate aspect ratio based on the central latitude of the region
    w = 1 / np.cos(np.radians((y_min + y_max) / 2))
    plt.gca().set_aspect(w)

    default_axes(plt.gca(), cbar)

    if save:
        plt.savefig(save)
    if show:
        plt.show()
    plt.close()
def event_hist(self, choice, show=True, save=None, bins=None):
    """
    Make a histogram of one column of event (source) information

    :type choice: str
    :param choice: column of Inspector.sources to histogram, e.g. 'magnitude'
    :type show: bool
    :param show: show the plot
    :type save: str
    :param save: fid to save the figure
    :type bins: int, str or list
    :param bins: histogram bins passed to matplotlib. If None, 'magnitude'
        uses 0.1-unit bins from 4.5 to 6.0 (previous behavior) and any other
        choice uses matplotlib's default binning, since magnitude bins are
        meaningless for e.g. depth or latitude
    :rtype: tuple
    :return: (figure, axes) of the created (and closed) plot
    """
    assert choice in self.sources.keys(), \
        f"Choice must be in {self.sources.keys()}"

    if bins is None and choice == "magnitude":
        bins = list(np.arange(4.5, 6.1, .1))

    f, ax = plt.subplots()
    arr = self.sources[choice].to_numpy()

    plt.hist(x=arr, color="w", histtype="bar", bins=bins,
             edgecolor="black", linewidth=2., label=choice, alpha=1.,
             zorder=20)

    default_axes(ax)
    plt.xlabel(choice)
    plt.ylabel("count")

    if save:
        plt.savefig(save)
    if show:
        plt.show()
    plt.close()

    return f, ax
def measurement_hist(self, iteration=None, step_count=None, choice="event",
                     show=True, save=False):
    """
    Make histograms of measurements for stations or events to show the
    distribution of measurements (number of windows).

    :type iteration: str
    :param iteration: iteration number e.g. 'i00'
    :type step_count: str
    :param step_count: step count e.g. 's00'
    :type choice: str
    :param choice: choice of making hist by 'event' or 'station'
    :type show: bool
    :param show: Show the plot. If False the figure is closed
    :type save: str
    :param save: fid to save the given figure
    :rtype: tuple
    :return: (figure, axes) of the created plot
    """
    iteration, step_count = self._parse_nonetype_eval(iteration, step_count)
    arr = self.nwin(
        level=choice).loc[iteration, step_count].nwin.to_numpy()

    # Own figure, consistent with the other plotting methods, rather than
    # drawing into whatever figure happens to be current
    f, ax = plt.subplots()

    n, bins, _ = plt.hist(x=arr, color="orange", histtype="bar",
                          edgecolor="black", linewidth=4., label=choice,
                          alpha=1., zorder=20)

    # Find mean and standard deviation of measurement number
    mu, var, std = get_histogram_stats(n, bins)
    plt.axvline(x=mu, ymin=0, ymax=1, linewidth=2, c="k", linestyle="--",
                zorder=15, alpha=0.5)
    for sign in [-1, 1]:
        plt.axvline(x=mu + sign * std, ymin=0, ymax=1, linewidth=2, c="k",
                    linestyle=":", zorder=15, alpha=0.5)

    default_axes(ax)
    plt.xlabel(f"{choice} number of measurements")
    plt.ylabel("count")
    # Mean is drawn dashed ('--'), standard deviations dotted (':')
    plt.title(f"{iteration}{step_count}; N={len(arr)}\n"
              f"dashed line = mean; dotted lines = 1 std")

    if save:
        plt.savefig(save)
    if show:
        plt.show()
    else:
        plt.close()

    return f, ax
def station_event_misfit_map(self, station, iteration, step_count, choice,
                             show=True, save=False, **kwargs):
    """
    Plot a single station and all events that it has measurements for.
    Events will be colored by choice of value: misfit or nwin (num windows)

    :type station: str
    :param station: specific station to use for map
    :type iteration: str
    :param iteration: iteration number e.g. 'i00'
    :type step_count: str
    :param step_count: step count e.g. 's00'
    :type choice: str
    :param choice: choice of misfit value, either 'misfit' or 'nwin'
    :type show: bool
    :param show: Show the plot (with hover annotations of event names)
    :type save: str
    :param save: fid to save the given figure
    :rtype: tuple
    :return: (figure, axes) of the created plot

    Keyword Arguments
    ::
        str cmap: colormap for `choice`, default 'viridis'
    """
    assert (station in self.stations), "station name not found"

    cmap = kwargs.get("cmap", "viridis")

    sta = self.receivers.droplevel(0).loc[station]

    # Get misfit on a per-station basis
    df = self.misfit(level="station").loc[
        iteration, step_count].swaplevel(0, 1)
    df = df.sort_values(by="station").loc[station]

    # Get source lat/lon values as a single dataframe with same index name.
    # rename_axis returns a new index rather than renaming one that may be
    # shared with self.sources
    sources = self.sources.drop(["time", "magnitude", "depth_km"],
                                axis=1).rename_axis("event")

    # This is a dataframe of events corresponding to a single station
    df = df.merge(sources, on="event")
    values = df[choice].to_numpy()

    f, ax = plt.subplots()
    plt.scatter(sta.longitude, sta.latitude, marker="v", c="orange",
                edgecolors="k", s=25, zorder=100)
    sc_events = plt.scatter(df.longitude.to_numpy(), df.latitude.to_numpy(),
                            c=values, marker="o", s=25, zorder=99,
                            cmap=cmap)

    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.title(f"{station} {iteration}{step_count}; {len(df)} events")

    _, _, cbar = colormap_colorbar(cmap, vmin=values.min(),
                                   vmax=values.max(), cbar_label=choice)

    default_axes(ax, cbar, **kwargs)

    if save:
        plt.savefig(save)
    if show:
        # Hover over the event markers, which correspond to df's event index
        hover_on_plot(f, ax, sc_events, df.index.to_numpy())
        plt.show()

    return f, ax
def event_station_misfit_map(self, event, iteration, step_count, choice,
                             show=True, save=False, **kwargs):
    """
    Map one event together with every station that holds measurements for
    it, coloring each station marker by a station-level misfit quantity.

    :type event: str
    :param event: event id (index value of Inspector.sources) to map
    :type iteration: str
    :param iteration: iteration number e.g. 'i00'
    :type step_count: str
    :param step_count: step count e.g. 's00'
    :type choice: str
    :param choice: station-level misfit column to color by, e.g. 'misfit'
        or 'nwin'
    :type show: bool
    :param show: display the figure, with hover annotations on stations
    :type save: str
    :param save: fid to write the figure to
    :rtype: tuple
    :return: (figure, axes)

    Keyword Arguments
    ::
        str cmap: colormap used for the station colors, default 'viridis'
    """
    assert (event in self.sources.index), "event name not found"
    colormap = kwargs.get("cmap", "viridis")

    f, ax = plt.subplots()

    # Red circle for the single event
    evt = self.sources.loc[event]
    plt.scatter(evt.longitude, evt.latitude, marker="o", c="r",
                edgecolors="k", s=20, zorder=100)

    # Station-level misfit rows belonging to this event
    station_misfit = self.misfit(level="station").loc[iteration, step_count,
                                                      event]
    assert (choice in station_misfit.columns), \
        f"choice must be in {station_misfit.columns}"

    # Attach receiver coordinates to each station row
    station_misfit = station_misfit.merge(self.receivers, on="station")
    colors = station_misfit[choice].to_numpy()
    lons = station_misfit.longitude.to_numpy()
    lats = station_misfit.latitude.to_numpy()

    station_markers = plt.scatter(lons, lats, c=colors, marker="v", s=15,
                                  zorder=100, cmap=colormap)

    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    n_stations = len(station_misfit)
    plt.title(f"{event} {iteration}{step_count}; {n_stations} stations")

    cbar = colormap_colorbar(colormap, vmin=colors.min(), vmax=colors.max(),
                             cbar_label=choice)[2]
    default_axes(ax, cbar, **kwargs)

    if save:
        plt.savefig(save)

    if not show:
        return f, ax

    hover_on_plot(f, ax, station_markers, station_misfit.index.to_numpy())
    plt.show()
    return f, ax
def event_misfit_map(self, choice=None, iteration=None, step_count=None,
                     show=True, save=False, **kwargs):
    """
    Plot all events on a map and their corresponding scaled misfit value

    :type choice: str
    :param choice: choice of misfit value, either 'misfit' or 'nwin' or
        'unscaled_misfit'. Defaults to 'misfit'
    :type iteration: str
    :param iteration: iteration number e.g. 'i00'. If None, uses
        `final_model` (iteration and step)
    :type step_count: str
    :param step_count: step count e.g. 's00'. If None while `iteration` is
        given, defaults to the last step of that iteration
    :type show: bool
    :param show: Show the plot (with hover annotations of event names)
    :type save: str
    :param save: fid to save the given figure
    :rtype: tuple
    :return: (figure, axes) of the created plot

    Keyword Arguments
    ::
        str cmap: colormap, default 'viridis'
        float markersize: event marker size, default 20
        str marker: event marker style, default 'o'
    """
    cmap = kwargs.get("cmap", "viridis")
    markersize = kwargs.get("markersize", 20)
    marker = kwargs.get("marker", "o")

    if iteration is None:
        iteration, step_count = self.final_model
    elif step_count is None:
        # Otherwise `.loc[iteration, None]` below would raise a KeyError
        step_count = self.steps.loc[iteration][-1]
    if choice is None:
        choice = "misfit"

    f, ax = plt.subplots()

    sources = self.sources.drop(["time", "magnitude", "depth_km"], axis=1)
    # Rename from event_id to event to match the naming of the dataframe
    sources.rename_axis("event", inplace=True)

    df = self.misfit(level="event").loc[iteration, step_count]
    df = df.merge(sources, on="event")
    values = df[choice].to_numpy()

    srcs = plt.scatter(df.longitude.to_numpy(), df.latitude.to_numpy(),
                       c=values, marker=marker, s=markersize,
                       zorder=100, cmap=cmap)

    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.title(f"{iteration}{step_count}; {choice} {len(df)} events")

    _, _, cbar = colormap_colorbar(cmap, vmin=values.min(),
                                   vmax=values.max(), cbar_label=choice)

    default_axes(ax, cbar, **kwargs)

    if save:
        plt.savefig(save)
    if show:
        hover_on_plot(f, ax, srcs, df.index.to_numpy())
        plt.show()

    return f, ax
def hist(self, iteration=None, step_count=None, iteration_comp=None,
         step_count_comp=None, f=None, ax=None, event=None, station=None,
         choice="cc_shift_in_seconds", binsize=None, show=True, save=None,
         **kwargs):
    """
    Create a histogram of misfit information for either time shift or
    amplitude differences. Option to compare against different iterations,
    and to look at different choices.

    Choices are any column value in the Inspector.windows attribute

    :type iteration: str
    :param iteration: iteration to choose for misfit. If both `iteration`
        and `iteration_comp` are None, compares `initial_model` (iteration)
        against `final_model` (iteration_comp)
    :type step_count: str
    :param step_count: step count to query, e.g. 's00'. If None, defaults to
        the last step of `iteration`
    :type iteration_comp: str
    :param iteration_comp: iteration to compare with, will be plotted in
        front of `iteration`
    :type step_count_comp: str
    :param step_count_comp: step to compare with. If None, defaults to the
        last step of `iteration_comp`
    :type f: matplotlib.figure
    :param f: plot to an existing figure
    :type ax: matplotlib.axes._subplots.AxesSubplot
    :param ax: plot to an existing axis e.g. to string together histograms
    :type event: str
    :param event: filter for measurements for a given event
    :type station: str
    :param station: filter for measurements for a given station
    :type choice: str
    :param choice: column of Inspector.windows, e.g. 'cc_shift_in_seconds'
        for time shift, or 'dlnA' as amplitude
    :type binsize: float
    :param binsize: size of the histogram bins
    :type show: bool
    :param show: show the plot
    :type save: str
    :param save: fid to save the figure
    :rtype: tuple
    :return: (figure, axes) that were plotted on

    Keyword Arguments
    ::
        str title, xlabel, label, label_comp: text overrides
        list xlim: x-limits; float ymax: upper y-limit
        str color, color_comp: bar colors
        float fontsize, linewidth; tuple figsize
        bool legend; str legend_loc
        float label_range, xstep: symmetric x-tick range and spacing
        bool zeroline, meanline, stdline: reference lines
    """
    # Optional kwargs for fine tuning figure parameters
    title = kwargs.get("title", "")
    xlim = kwargs.get("xlim", None)
    color = kwargs.get("color", "darkorange")
    color_comp = kwargs.get("color_comp", "deepskyblue")
    fontsize = kwargs.get("fontsize", 12)
    figsize = kwargs.get("figsize", (8, 6))
    legend = kwargs.get("legend", True)
    legend_loc = kwargs.get("legend_loc", "best")
    label_range = kwargs.get("label_range", False)
    xstep = kwargs.get("xstep", 2)
    ymax = kwargs.get("ymax", None)
    xlabel = kwargs.get("xlabel", None)
    zeroline = kwargs.get("zeroline", False)
    meanline = kwargs.get("meanline", False)
    stdline = kwargs.get("stdline", False)
    linewidth = kwargs.get("linewidth", 2.5)
    label = kwargs.get("label", None)
    label_comp = kwargs.get("label_comp", None)

    # If no arguments are given, default to first and last evaluations
    if iteration is None and iteration_comp is None:
        iteration, step_count = self.initial_model
        iteration_comp, step_count_comp = self.final_model

    # Check that the provided values are available in the Inspector, and
    # default missing step counts to the latest step of their iteration
    assert iteration in self.iterations, \
        f"iteration must be in {self.iterations}"
    if step_count is None:
        step_count = self.steps.loc[iteration][-1]
    assert step_count in self.steps.loc[iteration], \
        f"step must be in {self.steps.loc[iteration]}"
    if iteration_comp is not None:
        assert iteration_comp in self.iterations, \
            f"iteration_comp must be in {self.iterations}"
        if step_count_comp is None:
            step_count_comp = self.steps.loc[iteration_comp][-1]
        assert step_count_comp in self.steps.loc[iteration_comp], \
            f"step_comp must be in {self.steps.loc[iteration_comp]}"

    # Try to set a default binsize that may or may not work
    if binsize is None:
        binsize = {"cc_shift_in_seconds": 1, "dlnA": 0.25,
                   "max_cc_value": 0.05, "misfit": 10,
                   "relative_starttime": 15,
                   "relative_endtime": 15}.get(choice, 1)

    def get_values(m, s, evt, sta):
        """Return the `choice` values for a selection, and their abs max"""
        df_a = self.isolate(iteration=m, step_count=s, event=evt,
                            station=sta)
        try:
            val_ = df_a.loc[:, choice].to_numpy()
        except KeyError as err:
            raise KeyError(f"Inspector.windows has no key {choice}") from err
        lim_ = max(abs(np.floor(min(val_))), abs(np.ceil(max(val_))))
        return val_, lim_

    # Resolve the figure/axes to draw on. A user-supplied `ax` is honored
    # (previously it was overwritten when `f` was None, and ignored by the
    # pyplot calls which draw on the *current* axes)
    if ax is None:
        if f is None:
            f, ax = plt.subplots(figsize=figsize)
        else:
            ax = f.gca()
    elif f is None:
        f = ax.figure
    plt.sca(ax)

    val, lim = get_values(iteration, step_count, event, station)
    if iteration_comp:
        val_comp, lim_comp = get_values(iteration_comp, step_count_comp,
                                        event, station)
        # Reset the limit to be the greater of the two
        lim = max(lim, lim_comp)

    # Symmetric bin edges of width `binsize`; going one bin past `lim`
    # guarantees the largest values are not dropped from the last bin
    bin_edges = np.arange(-1 * lim, lim + binsize, binsize)

    if iteration_comp:
        # Compare iterations, plot original iteration on top
        plt.hist(x=val, bins=bin_edges, color=color, histtype="bar",
                 edgecolor="black", linewidth=linewidth,
                 label=(label or f"{iteration}{step_count}") +
                       f"; N={len(val)}",
                 zorder=11, alpha=1.)
        mean, std, med = np.mean(val), np.std(val), np.median(val)

        # Plot comparison below
        plt.hist(x=val_comp, bins=bin_edges, color=color_comp,
                 histtype="bar", edgecolor="black", linewidth=linewidth + 1,
                 zorder=10,
                 label=(label_comp or f"{iteration_comp}{step_count_comp}") +
                       f"; N={len(val_comp)}")
        mean_comp = np.mean(val_comp)
        std_comp = np.std(val_comp)
        med_comp = np.median(val_comp)

        # Plot edges of comparison over top
        plt.hist(x=val_comp, bins=bin_edges, color="k", histtype="step",
                 edgecolor=color_comp, linewidth=linewidth, zorder=12)
    else:
        # No comparison iteration, plot single histogram. Uses the same
        # explicit edges so that `binsize` is honored here too
        plt.hist(x=val, bins=bin_edges, color=color, histtype="bar",
                 edgecolor="black", linewidth=linewidth,
                 label=f"N={len(val)}", zorder=10)
        mean, std, med = np.mean(val), np.std(val), np.median(val)

    # Plot reference lines
    if zeroline:
        plt.axvline(x=0, ymin=0, ymax=1, linewidth=2.5, c="k",
                    zorder=15, alpha=0.75, linestyle="-")
    # Plot the mean of the histogram
    if meanline:
        plt.axvline(x=mean, ymin=0, ymax=1, linewidth=2.5, c="k",
                    linestyle="--", zorder=15, alpha=0.75)
    # Plot one standard deviation
    if stdline:
        for sign in [-1, 1]:
            plt.axvline(x=mean + sign * std, ymin=0, ymax=1,
                        linewidth=2.5, c="k", linestyle=(0, (1, 1)),
                        zorder=15, alpha=0.75)

    # Set xlimits of the plot; key is 'dlnA' (matching the binsize table and
    # common_labels), the previous 'dlna' check never matched
    if xlim:
        plt.xlim(xlim)
    elif choice == "dlnA":
        plt.xlim([-1.75, 1.75])
    if ymax:
        plt.ylim([0, ymax])

    # Stats in the title by default
    if not title:
        tit_fmt = "mean: {mean:.2f} / std: {std:.2f} / med: {med:.2f}"
        title = tit_fmt.format(mean=mean, std=std, med=med)
        if iteration_comp:
            tit_comp = tit_fmt.format(mean=mean_comp, std=std_comp,
                                      med=med_comp)
            title = " ".join([f"[{label or iteration}]", title, "\n",
                              f"[{label_comp or iteration_comp}]", tit_comp])

    # Finalize plot details
    if xlabel:
        xlab_ = xlabel
    else:
        # For cleaner formatting of x-axis label
        xlab_ = common_labels.get(choice, choice)
    plt.xlabel(xlab_, fontsize=fontsize)
    plt.ylabel("Count", fontsize=fontsize)
    plt.title(title)
    if label_range:
        plt.xticks(np.arange(-1 * label_range, label_range + .1, step=xstep))
    if legend:
        leg = plt.legend(fontsize=fontsize / 1.25, loc=legend_loc)
        # `legendHandles` was renamed `legend_handles` in matplotlib 3.7 and
        # removed in 3.9; support both
        handles = getattr(leg, "legend_handles", None)
        if handles is None:
            handles = leg.legendHandles
        # Thin border around legend objects, unnecessarily thick bois
        for leg_ in handles:
            leg_.set_linewidth(1.5)

    default_axes(ax, **kwargs)
    plt.tight_layout()

    if save:
        plt.savefig(save)
    if show:
        plt.show()

    return f, ax
def plot_windows(self, iteration=None, step_count=None, iteration_comp=None,
                 step_count_comp=None, choice="cc_shift_in_seconds",
                 event=None, network=None, station=None, component=None,
                 no_overlap=True, distances=False, annotate=False,
                 bounds=False, show=True, save=False, **kwargs):
    """
    Show lengths of windows chosen based on source-receiver distance, akin
    to Tape's Thesis or to the LASIF plots. These are useful for showing
    which phases are chosen, and window choosing behavior as distance
    increases and (probably) misfit increases.

    :type iteration: str
    :param iteration: iteration to analyze
    :type step_count: str
    :param step_count: step count to query, e.g. 's00'
    :type iteration_comp: str
    :param iteration_comp: Optional, if provided, difference the 'choice'
        values with the chosen 'iteration/step'. Useful for easily checking
        for improvement. Only works if the windows are the same.
    :type step_count_comp: str
    :param step_count_comp: associated step count for 'iteration_comp'
    :type event: str
    :param event: filter for measurements for a given event
    :type network: str
    :param network: filter for measurements for a given network
    :type station: str
    :param station: filter for measurements for a given station
    :type component: str
    :param component: choose a specific component to analyze
    :type choice: str
    :param choice: choice of value to define the colorscale by. These
        relate to the keys of Inspector.windows. Default is
        'cc_shift_in_seconds'
    :type no_overlap: bool
    :param no_overlap: If real distances are used, many src-rcv pairs are
        at the same or very similar distances, leading to overlapping
        rectangles. If this is set to True, to minimize overlap, the
        function will try to shift the distance to a value that hasn't yet
        been plotted. It will alternate larger positive and negative values
        until something is found. Will lead to non-real distances.
    :type distances: bool
    :param distances: If set False, just plot one window atop the other,
        which makes for more concise, easier to view plots, but then real
        distance information is lost, only relative distance kept.
    :type annotate: bool
    :param annotate: If True, will annotate event and station information
        for each window. May get messy if `distances == True` and
        `no_overlap == False` because you will get many overlapping
        annotations. Works ideally if `distances == False`.
    :type bounds: bool or list of float
    :param bounds:
        * (bool) False: set default bounds based on the min and max of data
        * (bool) True: set default bounds equal, based on abs max of data
        * (list or tuple) Manually set the bounds of the colorbar
    :type show: bool
    :param show: show the plot after generating. If False, the figure is
        closed after (optionally) saving
    :type save: str
    :param save: save the plot to the given filename
    :rtype: tuple or None
    :return: (f, ax) the figure and axis, or None if no windows match the
        given filters

    Keyword Arguments
    ::
        float alpha:
            The opacity of the rectangles, defaults to 0.6
        str cmap:
            The colormap used to plot the values of `choice`, defaults to
            'viridis'
        str cbar_label:
            The label for the colorbar
        float rectangle_height:
            The vertical size of the rectangles, defaults to 1.
        float anno_shift:
            The distance in seconds to shift the plot to accommodate
            annotations, defaults to 50. This needs to be played with as
            it is based on the length of the annotation strings.
    """
    alpha = kwargs.get("alpha", 0.6)
    cmap = kwargs.get("cmap", "viridis")
    cbar_label = kwargs.get("cbar_label", None)
    rectangle_height = kwargs.get("rectangle_height", 1.0)
    anno_shift = kwargs.get("anno_shift", 50)

    iteration, step_count = self._parse_nonetype_eval(iteration, step_count)
    assert(iteration in self.iterations and
           step_count in self.steps[iteration]), \
        f"{iteration}{step_count} does not exist in Inspector"
    assert(choice in self.windows.keys()), \
        f"Color by choice {choice} not in list of available keys"

    # Filter out the specific windows that we're interested in
    df = self.isolate(iteration=iteration, step_count=step_count,
                      event=event, network=network, station=station,
                      component=component)

    # If a comparison iteration is given, isolate the 'choice' key, and
    # subtract it from the main dataframe. The new plotted values are diffs!
    if iteration_comp:
        df_comp = self.isolate(iteration=iteration_comp,
                               step_count=step_count_comp, event=event,
                               network=network, station=station,
                               component=component)
        # This is enough unique info to identify a specific window
        merge_keys = ["event", "network", "station", "channel",
                      "relative_starttime", choice]
        df_comp = df_comp.loc[:, merge_keys].rename(
            columns={choice: f"{choice}_comp"})

        # Crude check to see if the number of windows is comparable
        assert(len(df) == len(df_comp)), (
            f"Number of windows does not match between "
            f"{iteration}{step_count} and "
            f"{iteration_comp}{step_count_comp}")

        df = df.merge(df_comp, on=merge_keys[:-1])
        # Subtract the comparison iteration from the initial check
        df[choice] = df[choice] - df[f"{choice}_comp"]

    # Merge window information with source-receiver distances, not BAz
    df = df.merge(self.srcrcv.drop("backazimuth", axis=1),
                  on=["event", "network", "station"])

    # Drop unnecessary information except that needed to plot
    # IMPORTANT: Sort by distance so that when the dataframe is iterated on
    # it starts from the smallest distance and goes up
    df = df.loc[:, ["event", "station", "component", "relative_starttime",
                    "relative_endtime", "distance_km", choice]
                ].sort_values(by="distance_km")

    if df.empty:
        logger.warning("Filtered dataframe is empty, no windows to plot")
        return None

    # Plotting begins here
    f, ax = plt.subplots(figsize=(8, 6))

    # A user-provided colorbar label is used verbatim
    if cbar_label is None:
        # For cleaner formatting of colorbar label
        cbar_label = common_labels.get(choice, choice)
        if iteration_comp:
            cbar_label = f"DIFF {cbar_label}"

    # Set the bounds of the colorbar
    if isinstance(bounds, (list, tuple)):
        vmin, vmax = bounds
    elif bounds:
        vmax = max(abs(df[choice].min()), abs(df[choice].max()))
        vmin = -1 * vmax
    else:
        vmin = df[choice].min()
        vmax = df[choice].max()

    sm, norm, _ = colormap_colorbar(cmap, vmin=vmin, vmax=vmax,
                                    cbar_label=cbar_label)

    # Determine the global xmin and xmax which will be used more than once
    xmin = df.relative_starttime.min()
    if annotate:
        # Shift to accommodate annotations
        xmin -= anno_shift
    xmax = df.relative_endtime.max()

    dist_values = []  # y-values claimed so far (used for no_overlap)
    plotted_y = []  # every y-value actually drawn, used for axis limits
    y_value = 0
    for ev, sta, comp, start, end, dist, value in df.to_numpy():
        if not distances:
            # Ignore distances and simply plot linearly
            dist_ = y_value
            y_value += rectangle_height
        else:
            # Try not to overlap windows that are very close in distance
            dist_ = int(dist)
            if no_overlap:
                if dist_ in dist_values:
                    shift, sign = 1, -1
                    while dist_ in dist_values:
                        dist_ += shift
                        # Alternate shift so that we search
                        # 1, -1, 2, -2, 3, -3, etc...
                        shift = sign * (abs(shift) + rectangle_height)
                        sign *= -1
                    dist_values.append(dist_)
                    logger.warning(f"Shifted {ev} {sta}: {dist - dist_}km")
                else:
                    dist_values.append(dist_)
        plotted_y.append(dist_)

        # Plot the windows as rectangles to sort of match waveform plots
        ax.add_patch(Rectangle(xy=(start, dist_ - rectangle_height / 2),
                               width=end - start, ec="k", alpha=alpha,
                               height=rectangle_height,
                               fc=sm.cmap(norm(value)), zorder=12))
        # Black background line for frame of reference / gridding
        ax.hlines(y=dist_, xmin=xmin, xmax=xmax, colors="k", alpha=0.3,
                  linewidth=0.3, zorder=10)

        # Annotate event, station, component, distance and value for
        # easier identification. Can be messy with a lot of windows
        if annotate:
            ax.text(xmin, dist_,
                    f"{ev} {sta} {comp} {dist:.2f}km {value:.2f}",
                    fontsize=4.5, zorder=11)

    # Only mention the comparison iteration when one was actually given
    comp_tag = (f" [{iteration_comp}{step_count_comp}]"
                if iteration_comp else "")
    ax.set_title(f"Window Plot: N = {len(df)} "
                 f"[{iteration}{step_count}]{comp_tag}\n"
                 f"Event: {event} / Station: {station} / "
                 f"Network: {network} / Component: {component}")
    ax.set_xlabel("Time [s]")
    ax.set_xlim([xmin, xmax])
    if distances:
        ax.set_ylabel("Distance [km]")
        # Include shifted (no_overlap) positions so no window falls
        # outside of the visible axis
        ylo = min(df.distance_km.min(), min(plotted_y))
        yhi = max(df.distance_km.max(), max(plotted_y))
        ax.set_ylim([ylo - 10, yhi + 10])
    else:
        # Relative distances means the y-axis values are useless
        ax.set_ylabel("Relative Distance")
        ax.set_ylim([-rectangle_height, plotted_y[-1] + rectangle_height])
        ax.yaxis.set_ticks([])

    if save:
        f.savefig(save)
    if show:
        plt.show()
    else:
        plt.close(f)

    return f, ax
def convergence(self, windows="length_s", trials=False, show=True,
                save=None, normalize=False, float_precision=3,
                annotate=False, restarts="default", restart_annos=None,
                xvalues="model", **kwargs):
    """
    Plot the convergence rate over the course of an inversion.

    Scatter plot of total misfit against model number (or function
    evaluation number), optionally with window information on a twin axis
    and the discarded line search trial steps.

    TO DO: Separate the sorting functionality from the plotting
    functionality, this function is too confusing.

    .. note::
        Because misfits are floats, they wont be exactly equal, so we need
        to set some small tolerance in which they can differ

    :type windows: str or bool
    :param windows: parameter to use for Inspector.measurements() to
        determine how to illustrate measurement number, either by:

        * length_s: cumulative window length in seconds
        * nwin: number of misfit windows
        * None: will not plot window information
    :type trials: str
    :param trials: plot the discarded trial step function evaluations from
        the line searches. Useful for understanding optimization efficiency

        * marker: plot trial steps as red x's at their respective misfit val
        * text: annotate the number of trial steps but not their misfit val
    :type normalize: bool
    :param normalize: normalize the objective function values between
        [0, 1]. Each inversion leg (see `restarts`) is normalized by its own
        maximum
    :type float_precision: int
    :param float_precision: acceptable floating point precision for
        comparisons of misfits. Defaults to 3 values after decimal
    :type restarts: list of int
    :param restarts: If the inversion was restarted, e.g. for parameter
        changes, then the convergence figure should separate two line plots.
        This list allows the User to tell the function where to separate the
        convergence plot. The integers should correspond to indices of the
        Inspector.models attribute. 'default' uses Inspector.restarts
    :type annotate: bool
    :param annotate: annotate misfit values next to markers
    :type restart_annos: list of str
    :param restart_annos: annotate text at the start of each inversion leg,
        should be one longer than `restarts`. Useful for annotating e.g.
        parameter changes that accompany each restart
    :type xvalues: str
    :param xvalues: How the x-axis should be labelled (case-insensitive):

        * model: plot the model number under each point
        * eval: number sequentially from 1
    :type show: bool
    :param show: show the plot after making it
    :type save: str
    :param save: file id to save the figure to
    :rtype: tuple
    :return: (f, ax) the figure and the primary (misfit) axis

    Keyword Arguments
    ::
        f, ax: existing figure/axis to plot on
        int dpi: figure dpi, defaults to 100
        float fontsize: axis label/tick font size, defaults to 15
        float anno_fontsize: annotation font size, defaults to 15
        tuple figsize: figure size, defaults to (8, 6)
        bool legend: plot a legend, defaults to True
        str title: custom title, default built from Inspector information
        str misfit_label / trial_label / window_label: legend labels
        str trial_color / window_color: colors for trials and windows
        str legend_loc: legend location, defaults to 'best'
        float axis_linewidth: bounding axis line width, defaults to 2.
    """
    f = kwargs.get("f", None)
    ax = kwargs.get("ax", None)
    dpi = kwargs.get("dpi", 100)
    fontsize = kwargs.get("fontsize", 15)
    anno_fontsize = kwargs.get("anno_fontsize", 15)
    figsize = kwargs.get("figsize", (8, 6))
    legend = kwargs.get("legend", True)
    title = kwargs.get("title", None)
    misfit_label = kwargs.get("misfit_label", "misfit")
    trial_label = kwargs.get("trial_label", "trials")
    window_label = kwargs.get("window_label", "windows")
    trial_color = kwargs.get("trial_color", "r")
    window_color = kwargs.get("window_color", "orange")
    legend_loc = kwargs.get("legend_loc", "best")
    axis_linewidth = kwargs.get("axis_linewidth", 2.)

    # Set some default parameters based on user choices, check parameters
    if windows:
        assert (windows in ["nwin", "length_s"]), \
            "windows must be: 'nwin' or 'length_s'"

    # Default restart values are chosen automatically by the Inspector.
    # isinstance check avoids an element-wise comparison if an array is given
    if isinstance(restarts, str) and restarts == "default":
        restarts = self.restarts.index.values
    if restarts is not None and restart_annos is not None:
        assert(len(restarts) + 1 == len(restart_annos)), \
            "Length of restart anno must match length of `restarts` + 1"
    xvalues = xvalues.lower()
    assert(xvalues in ["model", "eval"]), \
        "xvalues must be 'model' or 'eval'"

    # It may take a while to calculate models so do it once here
    models = self.models
    misfit = models.misfit.round(decimals=float_precision)
    nwin = self.nwin() if windows else None

    # Set up the figure
    if f is None:
        f, ax = plt.subplots(figsize=figsize, dpi=dpi)
    if ax is None:
        ax = plt.gca()

    # First, we will sort the model values by accepted models, initial
    # evaluations, and discarded trials steps. Also need to check if
    # accepted models and initial evaluations are equal to one another.
    x = 0  # the x-position on the axis
    lines, xvals, yvals, xlabs = [], [], [], []  # main plot
    xdiscards, ydiscards, ywindows, xrestarts = [], [], [], []  # secondary
    for j in range(len(models)):
        i = j - 1  # we always need to compare to the previous misfit value
        # Status 0 means initial evaluation of iteration
        if models.state[j] == 0:
            # The very first function evaluation anchors the plot at x=0
            if j == 0:
                xlab = "m00"
            # If initial eval matches line search final, treat equally
            elif misfit[i] == misfit[j]:
                continue
            # If they differ, treat them as different points
            else:
                x += 1
                xlab = ""
        # Status 1 means final evaluation in line search
        elif models.state[j] == 1:
            x += 1
            xlab = f"{models.model[j]}"
        # Status -1 means discarded trial step, plot on the same X value
        elif models.state[j] == -1:
            xdiscards.append(x + 1)  # discards are related to next model
            ydiscards.append(misfit[j])
            continue

        xvals.append(x)
        yvals.append(misfit[j])
        xlabs.append(xlab)

        # Convert restart values from Inspector.models indices
        if restarts is not None and j in restarts:
            xrestarts.append(x)

        # Get the corresponding window number based on iter/step count
        if windows:
            i_ = models.iteration[j]
            s_ = models.step_count[j]
            ywindows.append(nwin.loc[i_].loc[s_][windows])

    # Define a re-usable plotting function that takes arguments from main fx
    def plot_vals(x_, y_, idx=None, c="k", label=misfit_label):
        """
        Re-used plotting commands plot a scatter plot with a certain color
        and label. Normalizes y-values, annotates text, if required

        :type x_: list
        :param x_: x values to plot
        :type y_: list
        :param y_: y values to plot
        :type idx: int
        :param idx: 1-indexed inversion leg for color and label, if None
            defaults to `c` and `label` for color and label
        :type c: str
        :param c: color for marker and line color
        :type label: str
        :param label: label for legend, defaults to `misfit_label` from
            kwargs of main function
        :rtype: list
        :return: the plotted Line2D objects (empty if nothing was plotted)
        """
        # Nothing to draw, e.g. a restart placed on the very first model
        if not len(x_):
            return []

        # Overwrite default values
        if idx is not None:
            c = f"C{idx}"
            label = f"{misfit_label}"

        if normalize:
            y_max = max(y_)
            y_ = [_ / y_max for _ in y_]

        line = ax.plot(x_, y_, "o-", linewidth=3, markersize=10, c=c,
                       label=label, zorder=10, markeredgecolor="k",
                       markeredgewidth=1.5)
        if annotate:
            for x_anno, y_anno in zip(x_, y_):
                ax.text(x_anno, y_anno, f"{y_anno:.3f}", zorder=11,
                        fontsize=anno_fontsize)
        if restart_annos:
            # Leg `idx` takes annotation `idx - 1`; a single, un-split
            # convergence line takes the first annotation
            anno_idx = 0 if idx is None else idx - 1
            ax.text(x_[0], y_[0], restart_annos[anno_idx], zorder=12,
                    fontsize=anno_fontsize, verticalalignment="bottom")
        return line

    # Primary: Two methods of plotting:
    if xrestarts:
        # 1) with user-defined restarts separating legs of the inversion.
        # x-positions equal indices into xvals, so they can slice directly
        first = 0  # first iteration in the current leg
        for i, last in enumerate(xrestarts):
            plot_vals(xvals[first:last], yvals[first:last], idx=i + 1)
            first = last
        # Plot the final leg; its restart annotation is drawn by plot_vals
        lines += plot_vals(xvals[first:], yvals[first:],
                           idx=len(xrestarts) + 1)
    else:
        # 2) plot the entire convergence in one line
        lines += plot_vals(xvals, yvals, idx=None)

    # Secondary: Plot number of windows/ window length in a separate axis
    if windows:
        ax2 = ax.twinx()
        # Set ax2 below ax1
        ax.set_zorder(ax2.get_zorder() + 1)
        ax.patch.set_visible(False)

        lines += ax2.plot(xvals, ywindows, "d--", linewidth=2, markersize=8,
                          c=window_color, label=window_label, zorder=5,
                          markeredgecolor="k", markeredgewidth=2)

        ydict = {"length_s": "Cumulative Window Length [s]",
                 "nwin": "Number of Measurements"}
        ax2.set_ylabel(f"{ydict[windows]} (dashed)", rotation=270,
                       labelpad=15., fontsize=fontsize)
        ax2.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
        ax2.yaxis.get_offset_text().set_fontsize(fontsize)
        ax2.tick_params(labelsize=fontsize)

    # Secondary: Plot the discarded trial steps
    if trials == "marker":
        # Scatterplot as red X's to show the misfit value. Not the best
        # because it throws off the scaling of the normal misfit values
        sc = ax.scatter(xdiscards, ydiscards, c=trial_color, marker="x",
                        s=10, zorder=9, label=trial_label)
        lines.append(sc)
    elif trials == "text":
        # Since yvalues are normalized elsewhere, just plot the text
        # near the bottom of the visible axis (text does not rescale axes,
        # so the limits only need to be read once)
        ymin, ymax = ax.get_ylim()
        yval = 0.25 * (ymax - ymin) + ymin
        # Annotate the number of trial steps next to the corresponding value
        for xdiscard in sorted(set(xdiscards)):
            num_discards = xdiscards.count(xdiscard)
            ax.text(xdiscard, yval, f"{num_discards} trial(s)")

    # Format the axes. Tick locations must be set before their labels
    ax.xaxis.set_ticks(xvals)
    if xvalues == "model":
        xlabel_ = "Model Number"
        ax.set_xticklabels(xlabs, rotation=60, ha="center")
    else:  # "eval", guaranteed by the assertion above
        xlabel_ = "Function Evaluation"
        ax.set_xticklabels(np.arange(1, len(xvals) + 1, 1))

    ax.set_xlabel(xlabel_, fontsize=fontsize)
    ax.set_ylabel("Total Normalized Misfit", fontsize=fontsize)
    ax.tick_params(axis="both", which="major", labelsize=fontsize)

    # Only set ticks on the x-axis
    ax.xaxis.grid(True, which="minor", linestyle=":")
    ax.xaxis.grid(True, which="major", linestyle="-")

    for axis in ["top", "bottom", "left", "right"]:
        ax.spines[axis].set_linewidth(axis_linewidth)

    if title is None:
        ax.set_title(f"{self.tag.title()} Convergence\n"
                     f"{len(self.events)} Events / "
                     f"{len(self.stations)} Stations")
    else:
        ax.set_title(title)

    if legend:
        labels = [line.get_label() for line in lines]
        ax.legend(lines, labels, prop={"size": 12}, loc=legend_loc)

    f.tight_layout()
    if save:
        plt.savefig(save)
    if show:
        plt.show()

    return f, ax
def default_axes(ax, cbar=None, **kwargs):
    """
    Ensure that all plots have the same default look. Should be more
    flexible than setting rcParams or having a style sheet. Also allows the
    same kwargs to be thrown by all functions so that the function calls
    have the same format.

    :type ax: matplotlib.axes.Axes
    :param ax: axis whose title, labels, ticks and spines are restyled
    :type cbar: matplotlib.colorbar.Colorbar
    :param cbar: optional colorbar whose ticks, label and outline are
        restyled

    Keyword Arguments
    ::
        float tick_fontsize: tick label font size, defaults to 10
        float tick_linewidth: tick mark width, defaults to 1.5
        float tick_length: tick mark length, defaults to 5
        str tick_direction: tick direction, defaults to 'in'
        float label_fontsize: x and y axis label font size, defaults to 12
        float axis_linewidth: bounding spine line width, defaults to 2.
        float title_fontsize: title font size, defaults to 14
        float cbar_tick_fontsize: colorbar tick font size, defaults to 10
        float cbar_label_fontsize: colorbar label font size, defaults to 12
        str cbar_outline_color: colorbar outline color, defaults to 'k'
        float cbar_linewidth: colorbar outline width, defaults to 2. The
            legacy misspelled key 'cbar_linewdith' is still accepted
    """
    tick_fontsize = kwargs.get("tick_fontsize", 10)
    tick_linewidth = kwargs.get("tick_linewidth", 1.5)
    tick_length = kwargs.get("tick_length", 5)
    tick_direction = kwargs.get("tick_direction", "in")
    label_fontsize = kwargs.get("label_fontsize", 12)
    axis_linewidth = kwargs.get("axis_linewidth", 2.)
    title_fontsize = kwargs.get("title_fontsize", 14)
    cbar_tick_fontsize = kwargs.get("cbar_tick_fontsize", 10)
    cbar_label_fontsize = kwargs.get("cbar_label_fontsize", 12)
    cbar_outline_color = kwargs.get("cbar_outline_color", "k")
    # The key used to be misspelled 'cbar_linewdith'; honor the correct
    # spelling first but keep the old key working for existing callers
    cbar_linewidth = kwargs.get("cbar_linewidth",
                                kwargs.get("cbar_linewdith", 2.))

    # Re-set font sizes for labels already created
    ax.title.set_fontsize(title_fontsize)
    ax.xaxis.label.set_fontsize(label_fontsize)
    ax.yaxis.label.set_fontsize(label_fontsize)
    ax.tick_params(axis="both", which="both", width=tick_linewidth,
                   direction=tick_direction, labelsize=tick_fontsize,
                   length=tick_length)

    # Thicken up the bounding axis lines
    for axis in ["top", "bottom", "left", "right"]:
        ax.spines[axis].set_linewidth(axis_linewidth)

    # Adjust font and bounding bar of colorbar if available
    if cbar is not None:
        cbar.ax.tick_params(labelsize=cbar_tick_fontsize)
        cbar.ax.yaxis.label.set_fontsize(cbar_label_fontsize)
        cbar.outline.set_edgecolor(cbar_outline_color)
        cbar.outline.set_linewidth(cbar_linewidth)
def colormap_colorbar(cmap, vmin=0., vmax=1., dv=None, cbar_label="",
                      extend="neither"):
    """
    Create a custom colormap and colorbar

    :type cmap: str or matplotlib.colors.Colormap
    :param cmap: colormap to use, either a registered name like 'viridis'
        or a colormap object like plt.cm.viridis
    :type vmin: float
    :param vmin: min value for colormap
    :type vmax: float
    :param vmax: max value for colormap
    :type dv: float
    :param dv: colormap boundary separations, if None, continuous colorbar.
        Boundaries start at `vmin` and extend until `vmax` is covered
    :type cbar_label: str
    :param cbar_label: label for colorbar
    :type extend: str
    :param extend: passed to plt.colorbar, e.g. 'neither', 'both', 'min',
        'max'
    :rtype: tuple
    :return: (ScalarMappable, Normalize, Colorbar). The mappable and norm
        can be combined, e.g. sm.cmap(norm(value)), to color custom
        artists consistently with the colorbar
    """
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    sm.set_clim(vmin, vmax)

    if dv:
        # np.arange(vmin, vmax, dv) excludes vmax and would drop the top
        # bin; build enough evenly spaced boundaries to reach vmax instead.
        # Rounding absorbs float noise such as 10.000000000000002 bins
        nbins = int(np.ceil(np.round((vmax - vmin) / dv, 9)))
        boundaries = vmin + dv * np.arange(nbins + 1)
    else:
        boundaries = None

    cbar = plt.colorbar(sm, boundaries=boundaries, shrink=0.9, pad=0.025,
                        extend=extend)
    if cbar_label:
        cbar.ax.set_ylabel(cbar_label, rotation=270, labelpad=15)

    return sm, norm, cbar
def hover_on_plot(f, ax, obj, values, dissapear=True):
    """
    Attach a hover tooltip to a plotted object, showing custom text for the
    point(s) under the mouse cursor.

    .. note::
        This functionality is copied from StackOverflow:
        https://stackoverflow.com/questions/7908636/possible-to-make-labels-\
        appear-when-hovering-over-a-point-in-matplotlib

    :type f: matplotlib.figure.Figure
    :param f: figure whose canvas receives the motion events
    :type ax: matplotlib.axes._subplot.AxesSubplot
    :param ax: axis that holds `obj`; only events inside it are handled
    :type obj: matplotlib.collections.PathCollection or
        matplotlib.lines.Line2D
    :param obj: scatter plot, returned from plt.scatter() or plt.plot()
    :type values: list of str
    :param values: tooltip text per data point, indexed like `obj`'s data
    :type dissapear: bool
    :param dissapear: hide the tooltip once the mouse leaves the points
    :rtype hover: function
    :return hover: the motion-notify callback registered on the canvas
    """
    # One shared, initially hidden tooltip that gets moved around
    tooltip = ax.annotate("", xy=(0, 0), xytext=(20, 20),
                          textcoords="offset points",
                          bbox=dict(boxstyle="round", fc="w"),
                          arrowprops=dict(arrowstyle="->"), zorder=5000)
    tooltip.set_visible(False)

    def update_anno(ind):
        """Move the tooltip to the first hovered point and fill its text
        with the values of every hovered point
        """
        hit_indices = ind["ind"]
        # Line2D exposes raw data arrays, scatter plots expose offsets
        if isinstance(obj, mpl.lines.Line2D):
            xdata, ydata = obj.get_data()
            tooltip.xy = (xdata[hit_indices[0]], ydata[hit_indices[0]])
        elif isinstance(obj, mpl.collections.PathCollection):
            tooltip.xy = obj.get_offsets()[hit_indices[0]]

        tooltip.set_text("\n".join(values[k] for k in hit_indices))
        bbox_patch = tooltip.get_bbox_patch()
        bbox_patch.set_facecolor("w")
        bbox_patch.set_alpha(0.5)

    def hover(event):
        """Show/refresh the tooltip while the cursor is over `obj`, and
        optionally hide it when the cursor moves away
        """
        was_visible = tooltip.get_visible()
        if event.inaxes != ax:
            return

        hit, ind = obj.contains(event)
        if hit:
            update_anno(ind)
            tooltip.set_visible(True)
            f.canvas.draw_idle()
        elif was_visible and dissapear:
            # The cursor left the points: hide the tooltip again
            tooltip.set_visible(False)
            f.canvas.draw_idle()

    f.canvas.mpl_connect("motion_notify_event", hover)

    return hover
def get_histogram_stats(n, bins):
    """
    Get mean, variance and standard deviation from a histogram

    Each bin is represented by its midpoint and weighted by its count.

    :type n: array or list
    :param n: values (counts) of the histogram bins, one value per bin
    :type bins: array or list
    :param bins: edges of the bins, one more element than `n`
    :rtype: tuple of float
    :return: (mean, variance, standard deviation)
    """
    # Coerce to an array so plain lists get element-wise arithmetic rather
    # than list concatenation
    bins = np.asarray(bins, dtype=float)
    mids = 0.5 * (bins[1:] + bins[:-1])
    mean = np.average(mids, weights=n)
    var = np.average((mids - mean) ** 2, weights=n)
    std = np.sqrt(var)

    return mean, var, std
def annotate_txt(ax, txt, anno_location="lower-right", **kwargs):
    """
    Convenience function to annotate some information at a fixed relative
    position within the current axis limits

    :type ax: matplot.axes._subplots.AxesSubplot
    :param ax: axis to annotate onto
    :type txt: str
    :param txt: text to annotate
    :type anno_location: str
    :param anno_location: location on the figure to annotate, available:
        'lower-right', 'upper-right', 'lower-left', 'upper-left'
    :raises AssertionError: if `anno_location` is not one of the above

    Any additional keyword arguments are passed to ax.annotate()
    """
    # (x fraction, y fraction, multialignment) of the axis extent for each
    # supported location
    placements = {
        "lower-right": (0.675, 0.025, "right"),
        "upper-right": (0.675, 0.745, "right"),
        "lower-left": (0.050, 0.025, "left"),
        "upper-left": (0.050, 0.745, "left"),
    }
    acceptable_locations = list(placements)
    assert(anno_location in acceptable_locations), \
        f"anno_location must be in {acceptable_locations}"

    xfrac, yfrac, multialignment = placements[anno_location]
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    x = xmin + (xmax - xmin) * xfrac
    y = ymin + (ymax - ymin) * yfrac

    ax.annotate(txt, xy=(x, y), multialignment=multialignment, **kwargs)