Source code for pyatoa.visuals.improve_wave

#!/usr/bin/env python3
"""
Functions to create a figure showing progressive waveform improvement over
the course of a seismic inversion.

Show the changes in synthetic waveforms with progressive model updates. 
Each individual model gets its own row in the plot.
"""
import os
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pyasdf
from pyasdf import ASDFDataSet as asdf 
from pyatoa import Manager, logger
from pyatoa.utils.form import format_event_name
from pyatoa.visuals.wave_maker import format_axis
from pyflex import logger as pflogger

# Set log levels at import time: the pyflex logger is set to DEBUG and the
# pyatoa logger to INFO.
pflogger.setLevel("DEBUG")
logger.setLevel("INFO")


class ImproveWave:
    """
    A class to plot waveform improvement for a given ASDFDataSet

    .. code:: python

        ds = pyasdf.ASDFDataSet("dataset.h5")
        wi = ImproveWave(ds)
        wi.gather("NZ.BFZ", 10, 30)
        wi.plot()
        wi.plot("NZ.KNZ", 8, 30)
    """
    def __init__(self, ds=None):
        """
        Initiate empty objects and keep dataset as an internal attribute

        :type ds: pyasdf.ASDFDataSet
        :param ds: dataset containing waveform data and windows. Optional,
            because `gather_simple` opens its own datasets; `get_models` and
            `gather` read from this dataset and need it to be set.
        """
        self.ds = ds
        self.st_obs = None
        self.synthetics = None
        self.windows = None
        self.time_axis = None

    def get_models(self):
        """
        Figure out which step goes to which iteration to get model numbers

        Reads the iteration and step names stored under
        `ds.auxiliary_data.MisfitWindows`.

        :rtype: dict
        :return: model name (e.g. 'm01') mapped to an 'iteration/step' path
            (e.g. 'i01/s00')
        :raises ValueError: if the first iteration that needs a model has no
            's00' step and there is no earlier iteration to fall back on
        """
        models = {"m00": "i01/s00"}
        misfit_windows = self.ds.auxiliary_data.MisfitWindows

        prev_iter, prev_steps = None, None
        for iter_ in misfit_windows.list():
            steps = misfit_windows[iter_]
            model = iter_.replace("i", "m")
            if model in models:
                continue
            # NOTE(review): when 'i01' contains 's00', 'm01' is given the
            # same path as 'm00' ('i01/s00') -- confirm this is intended.
            if "s00" in steps.list():
                models[model] = f"{iter_}/s00"
            else:
                if prev_iter is None:
                    raise ValueError(
                        f"iteration '{iter_}' has no 's00' step and there "
                        f"is no previous iteration to take a model from"
                    )
                # Fall back on the last step of the previous iteration
                models[model] = f"{prev_iter}/{prev_steps.list()[-1]}"
            prev_iter = iter_
            prev_steps = steps

        # Get the last step: if the final iteration went past 's00', its
        # last step is labelled as one more model. Skipped when no iteration
        # was processed at all.
        if prev_steps is not None and prev_steps.list()[-1] != "s00":
            final_model = f"m{int(prev_iter.split('i')[-1]) + 1:0>2}"
            models[final_model] = f"{prev_iter}/{prev_steps.list()[-1]}"

        return models

    def gather(self, sta, min_period, max_period, rotate_to_rtz=False,
               fix_windows=False, pyflex_preset=False):
        """
        Parse dataset for given station, gather observed and synthetic data,
        preprocess data and store the results as internal attributes
        (`st_obs`, `synthetics`, `windows`, `time_axis`).

        :type sta: str
        :param sta: station to gather data for
        :type min_period: float
        :param min_period: minimum filter period in seconds
        :type max_period: float
        :param max_period: maximum filter period in seconds
        :type rotate_to_rtz: bool
        :param rotate_to_rtz: if True, set `rotate_to_rtz` on the Manager
            config and use the component list Z, R, T
        :type fix_windows: bool
        :param fix_windows: dont recalculate windows when gathering
        :type pyflex_preset: str
        :param pyflex_preset: overwrite the pyflex preset provided in the
            Config object
        :raises TypeError: if `min_period` or `max_period` is None
        :raises AssertionError: if `sta` is not in the dataset
        """
        if min_period is None or max_period is None:
            raise TypeError("must specify 'min_period' and 'max_period'")
        # Kept as an assert so the exception type seen by callers is the same
        assert(sta in self.ds.waveforms.list()), f"{sta} not in ASDFDataSet"

        models = self.get_models()

        # Preprocess all traces using Pyatoa and store in dict
        st_obs, synthetics, windows = None, {}, {}
        for model, path in models.items():
            # Gather synthetic data
            mgmt = Manager(ds=self.ds)
            logger.info(path)
            mgmt.load(sta, path)

            # Overwrite some config parameters
            mgmt.config.min_period = min_period
            mgmt.config.max_period = max_period
            if rotate_to_rtz:
                mgmt.config.rotate_to_rtz = rotate_to_rtz
                mgmt.config.component_list = ["Z", "R", "T"]
            if pyflex_preset:
                mgmt.config.pyflex_preset = pyflex_preset
            mgmt.config._check()

            mgmt.standardize()
            mgmt.preprocess()

            iter_, step_ = path.split("/")
            mgmt.window(fix_windows=fix_windows, iteration=iter_,
                        step_count=step_)

            windows[model] = mgmt.windows
            synthetics[model] = mgmt.st_syn.copy()

            # Observed waveform will be the same for every model
            if st_obs is None:
                st_obs = mgmt.st_obs.copy()

        # Internally used by plotting function
        self.st_obs = st_obs
        self.synthetics = synthetics
        self.windows = windows
        # The time offset is taken from the Manager of the last model loaded
        self.time_axis = self.st_obs[0].times(
            reftime=st_obs[0].stats.starttime - mgmt.stats.time_offset_sec
        )

    def gather_simple(self, event, sta, min_period, max_period,
                      path_dict=None, component=None, models=None):
        """
        Manually set the model values based on inspection of the Inspector
        Don't return windows or anything, keep it simple. Results are stored
        in `st_obs` and `synthetics`.

        :type event: str
        :param event: event id, used to build the dataset file name
        :type sta: str
        :param sta: station to gather data for
        :type min_period: float
        :param min_period: minimum filter period
        :type max_period: float
        :param max_period: maximum filter period
        :type path_dict: dict
        :param path_dict: dataset tag mapped to the directory containing
            '{event}.h5'. If not given, '{event}{tag}.h5' is opened instead
        :type component: str
        :param component: if given, only keep this component of the observed
            and synthetic streams
        :type models: dict of tuples
        :param models: model name mapped to ('iteration/step', tag). If not
            given, the table below is used
        """
        if models is None:
            models = {"m00": ("i01/s00", "a"),
                      "m03": ("i03/s03", "a"),
                      "m09": ("i09/s02", "a"),
                      "m12": ("i12/s04", "a"),
                      "m17": ("i17/s01", "a"),
                      "m24": ("i07/s01", "b"),
                      "m28": ("i11/s03", "c"),
                      }

        st_obs, synthetics = None, {}
        for model, (path, tag) in models.items():
            if path_dict:
                ds_fid = os.path.join(path_dict[tag], f"{event}.h5")
            else:
                ds_fid = f"{event}{tag}.h5"

            with asdf(ds_fid, mode="r") as ds:
                mgmt = Manager(ds=ds)
                mgmt.load(sta, path)
                mgmt.config.min_period = min_period
                mgmt.config.max_period = max_period
                mgmt.standardize().preprocess()

                if component:
                    synthetics[model] = mgmt.st_syn.select(
                        component=component).copy()
                else:
                    synthetics[model] = mgmt.st_syn.copy()

                # Observed data is taken once, from the first model loaded
                if st_obs is None:
                    if component:
                        st_obs = mgmt.st_obs.select(
                            component=component).copy()
                    else:
                        st_obs = mgmt.st_obs.copy()

        self.st_obs = st_obs
        self.synthetics = synthetics

    def setup_plot(self, nrows, ncols, **kwargs):
        """
        Dynamically set up a grid of axes.

        Returns the figure and a list of lists of axes objects, where
        axes[i][j] gives the ith row and the jth column.

        :type nrows: int
        :param nrows: number of rows in the gridspec
        :type ncols: int
        :param ncols: number of columns in the gridspec
        :keyword dpi: figure dpi, default 150
        :keyword figsize: figure size, default (500/dpi, 800/dpi)
        :keyword fontsize: tick label size, default 10
        :keyword axis_linewidth: spine line width, default 2
        :rtype: tuple
        :return: (figure, axes)
        """
        dpi = kwargs.get("dpi", 150)
        figsize = kwargs.get("figsize", (500/dpi, 800/dpi))
        fontsize = kwargs.get("fontsize", 10)
        axis_linewidth = kwargs.get("axis_linewidth", 2)

        f = plt.figure(figsize=figsize, dpi=dpi)
        gs = mpl.gridspec.GridSpec(nrows, ncols, hspace=0, wspace=0.025,
                                   width_ratios=[1] * ncols,
                                   height_ratios=[3] * nrows
                                   )

        grid_rows, grid_cols = gs.get_geometry()[:2]
        axes = [[] for _ in range(nrows)]
        for row in range(grid_rows):
            for col in range(grid_cols):
                # Ensure axis sharing: y is shared along a row, x is shared
                # with the top-left axis
                sharey = None if col == 0 else axes[row][0]
                sharex = None if (row == 0 and col == 0) else axes[0][0]

                ax = plt.subplot(gs[row, col], sharey=sharey, sharex=sharex)

                ax.set_axisbelow(True)
                ax.minorticks_on()
                ax.tick_params(which='major', direction='in', top=True,
                               right=False, left=False, labelsize=fontsize,
                               length=3, width=2*axis_linewidth/3)
                ax.tick_params(which='minor', direction='in', length=1.5,
                               top=True, bottom=True, right=False,
                               left=False, width=2*axis_linewidth/3)
                for axis in ["top", "bottom", "left", "right"]:
                    ax.spines[axis].set_linewidth(axis_linewidth)

                # Turn off the y axes because we wont show units
                ax.get_yaxis().set_ticks([])

                axes[row].append(ax)

        # remove x-tick labels except for last axis
        for row in axes[:-1]:
            for col in row:
                plt.setp(col.get_xticklabels(), visible=False)

        return f, axes

    def plot(self, sta=None, event_id=None, min_period=None, max_period=None,
             plot_windows=False, trace_length=None, show=True, save=False,
             **kwargs):
        """
        Plot waveforms iterative based on model updates

        :type sta: str
        :param sta: station to gather data for, if None, skips gathering
            assuming data has already been gathered
        :type event_id: str
        :param event_id: event name, only used in the figure title
        :type min_period: float
        :param min_period: minimum filter period for waveforms
        :type max_period: float
        :param max_period: maximum filter period for waveforms
        :type plot_windows: bool
        :param plot_windows: plot misfit windows above waveforms. Ignored
            (with a warning) if no windows have been gathered
        :type trace_length: list of floats
        :param trace_length: [trace_start, trace_end] will be used to set
            the x limit on the waveform data. If none, no xlim will be set
        :type show: bool
        :param show: Show the plot or do not
        :type save: str
        :param save: if given, save the figure to this path
        :keyword linewidth: waveform line width, default 1.
        :keyword fontsize: label and title font size, default 10
        :keyword anno_fontsize: time shift annotation font size, default 8
        :keyword window_color: face color of the windows, default 'orange'
        :keyword percent_over: passed to `format_axis`, default 0.125
        :keyword anno_choice: 'all' annotates the time shift of every
            window, 'max' only the largest one. Default 'all'
        :keyword label_components: write the component letter into each
            subplot, default False
        :rtype: tuple
        :return: (figure, axes) as returned by `setup_plot`; remaining
            keyword arguments are forwarded to `setup_plot`
        """
        linewidth = kwargs.get("linewidth", 1.)
        fontsize = kwargs.get("fontsize", 10)
        anno_fontsize = kwargs.get("anno_fontsize", 8)
        window_color = kwargs.get("window_color", "orange")
        percent_over = kwargs.get("percent_over", 0.125)
        anno_choice = kwargs.get("anno_choice", "all")
        label_components = kwargs.get("label_components", False)
        syn_colors = ["r", "b", "g"]

        # Allows for skipping the gather call and including it directly in
        # plot
        if sta is not None:
            self.gather(sta, min_period, max_period)

        assert self.st_obs, "must collect data for a station before plotting"

        # Windows are only stored by `gather`; don't crash without them
        if plot_windows and self.windows is None:
            logger.warning("no windows have been gathered, skipping windows")
            plot_windows = False

        # Instantiate the plotting object
        f, axes = self.setup_plot(nrows=len(self.synthetics.keys()),
                                  ncols=len(self.st_obs), **kwargs)

        middle_col = len(self.st_obs) // 2
        component_list = [_.stats.channel[-1] for _ in self.st_obs]

        # Plot each model on a different row
        for row, syn_key in enumerate(sorted(self.synthetics.keys())):
            ylab = syn_key.split('_')[-1]  # e.g. 'm00'

            # Plot each component in a different column
            for col, comp in enumerate(component_list):
                ax = axes[row][col]
                obs = self.st_obs.select(component=comp)[0]
                syn = self.synthetics[syn_key].select(component=comp)[0]

                # Plot waveforms; colors cycle if there are more than three
                # components
                ax.plot(obs.times(), obs.data, 'k', zorder=10, label="Obs",
                        linewidth=linewidth)
                ax.plot(syn.times(), syn.data,
                        syn_colors[col % len(syn_colors)], zorder=10,
                        label="Syn", linewidth=linewidth)

                # Format the axes for a standardized look
                format_axis(ax, percent_over=percent_over)

                # Plot windows if available for this given component
                tshift_max = 0  # temporary
                if plot_windows:
                    # A model may have no windows stored at all
                    model_windows = self.windows.get(syn_key) or {}
                    for w, win in enumerate(model_windows.get(comp, [])):
                        ymin, ymax = ax.get_ylim()

                        tleft = win.left * win.dt + self.time_axis[0]
                        tright = win.right * win.dt + self.time_axis[0]
                        tshift = win.cc_shift * win.dt

                        ax.add_patch(mpl.patches.Rectangle(
                            xy=(tleft, ymin), width=tright-tleft, ec='k',
                            fc=window_color, height=(ymax + np.abs(ymin)),
                            alpha=(win.max_cc_value ** 2) / 4)
                        )

                        # Outline the rectangle with solid lines
                        for t_ in [tleft, tright]:
                            ax.axvline(x=t_, ymin=0, ymax=1, color="k",
                                       alpha=1., zorder=11)

                        # Annotate time shift value into window,
                        # Alternate height if multiple windows so no overlap
                        if anno_choice == "all":
                            ax.text(
                                s=f"{tshift:.2f}s", x=tleft,
                                y=(ymax-ymin)*[0.7, 0.06][w % 2]+ymin,
                                fontsize=anno_fontsize, zorder=11
                            )

                        # Keep track of the largest timeshift
                        if abs(tshift) > abs(tshift_max):
                            tshift_max = tshift
                            tleft_max = tleft

                    # Annotate only the largest time shift value.
                    # `tshift_max` is only non-zero if a window was seen, so
                    # `tleft_max`, `ymin` and `ymax` are defined here
                    if tshift_max and anno_choice == "max":
                        ax.text(
                            s=f"{tshift_max:.2f}s", x=tleft_max,
                            y=(ymax-ymin)*0.06 + ymin,
                            fontsize=anno_fontsize, zorder=11
                        )

                if row == 0:
                    # hardcode the trace length based on user params; x is
                    # shared between all axes
                    if isinstance(trace_length, (list, tuple)):
                        ax.set_xlim(trace_length)

                    # Set titles for the first row, middle column
                    if col == middle_col:
                        # NOTE(review): the trailing 'Z' is hard-coded and
                        # does not follow the plotted component -- confirm
                        title = (f"{self.st_obs[0].stats.network}."
                                 f"{self.st_obs[0].stats.station} "
                                 f"{event_id} Z")
                        ax.set_title(title, fontsize=fontsize)

                # Append component to bottom right of subplot
                if label_components:
                    ax.text(x=0.95, y=0.15, s=comp.upper(),
                            horizontalalignment="center",
                            verticalalignment="center",
                            transform=ax.transAxes)

            # y-label after all the processing has occurred
            axes[row][0].set_ylabel(ylab, rotation="horizontal", ha="right",
                                    fontsize=fontsize)

        # Label the time axis on the bottom row, middle column
        axes[-1][middle_col].set_xlabel("time [s]", fontsize=fontsize)

        # Save the generated figure
        if save:
            plt.savefig(save)
        if show:
            plt.show()

        return f, axes
# NOTE(review): this function takes `self`, so it is presumably meant to be a
# method of ImproveWave -- confirm its placement.
def gather_simple(self, models, event_id, sta, component, min_period,
                  max_period):
    """
    Gather waveforms from manually input model values, usually determined
    by using the Inspector class

    :type models: dict of tuples
    :param models: model values as keys, (iter/step, tag) as tuple value.
        Tags allow multiple datasets to be used, e.g. if an inversion spans
        over multiple legs and more than one dataset is used to store
        waveform data
    :type event_id: str
    :param event_id: name of the event, used to access the ASDFDataSet
    :type sta: str
    :param sta: station id to gather data for
    :type component: str
    :param component: NOTE(review): not used in the visible part of this
        function; presumably the waveform component to gather -- confirm
    :type min_period: float
    :param min_period: period to filter data at
    :type max_period: float
    :param max_period: period to filter data at
    """
    # NOTE(review): only this initialization is visible; the rest of the
    # body appears to be missing, so as written nothing is gathered or
    # stored -- confirm against the full source.
    st_obs, synthetics = None, {}
if __name__ == "__main__":
    # (event id, station, component) triplets to make figures for
    pairs = [
        ("2013p617227", "NZ.TOZ", "Z"),
        # ("2014p952799", "NZ.NTVZ", "N"),
        # ("2016p105478", "NZ.PUZ", "Z"),
        # ("2016p881118", "NZ.MWZ", "E"),
        # ("2018p465580", "NZ.KHEZ", "E"),
        # ("2019p738432", "NZ.KHZ", "Z"),
        # ("2019p754447", "NZ.HIZ", "Z"),
        # ("2019p927023", "NZ.VRZ", "Z"),
    ]

    # Dataset tag mapped to the directory that holds the event datasets
    path_dict = {
        "a": "../waveform_comparisons/aspen/",
        "b": "/home/chowbr/current/birch/scratch/preprocess/datasets/i07_corrupted/",
        "c": "../waveform_comparisons/birch/",
    }

    for event_id, sta, comp in pairs:
        wi = ImproveWave()
        wi.gather_simple(event_id, sta, 6, 30, path_dict=path_dict,
                         component=comp)
        wi.plot(show=False, save=f"{event_id}_{sta}.png", event_id=event_id,
                trace_length=[70, 290])

    # Deliberate stop after the figures above have been written. This used
    # to be a division by zero; exit explicitly instead, still with a
    # non-zero exit status.
    # NOTE(review): everything below this line is unreachable when the
    # module is run as a script -- confirm whether the stop is still wanted.
    raise SystemExit("stopping after the multi-dataset figures")

    # MAIN
    event_id = "2019p927023"
    with asdf(f"{event_id}a.h5") as ds:
        stations = ds.waveforms.list()


if __name__ == "__main__":
    event_id = "2013p507880"
    component = "Z"
    # Model name mapped to (iteration/step, dataset tag). Some models were
    # listed twice under the same key; only the later entry ever took
    # effect, so that is the one kept here and the discarded path is noted.
    models = {
        "m00": ("i01/s00", ""),
        "m01": ("i01/s04", ""),
        "m02": ("i03/s00", ""),  # earlier duplicate entry: "i02/s01"
        "m03": ("i04/s00", ""),  # earlier duplicate entry: "i03/s01"
        "m04": ("i05/s00", ""),  # earlier duplicate entry: "i04/s04"
        "m05": ("i06/s00", ""),  # earlier duplicate entry: "i05/s04"
        "m06": ("i06/s01", ""),
        "m07": ("i07/s01", ""),
        "m08": ("i08/s03", ""),
        "m09": ("i09/s01", ""),
        "m10": ("i11/s00", ""),  # earlier duplicate entry: "i10/s04"
        "m11": ("i11/s03", ""),
        "m12": ("i12/s01", ""),
        "m13": ("i13/s01", ""),
    }

    # with asdf(f"{event_id}.h5") as ds:
    #     stations = ds.waveforms.list()
    stations = ["NZ.TOZ"]
    for sta in stations:
        wi = ImproveWave()
        wi.gather_simple(models, event_id, sta, component, 4, 30)
        wi.plot()