"""
Waveform plotting functionality
Produces waveform plots from stream objects and plots misfit windows
output by Pyflex as well as adjoint sources from Pyadjoint.
Flexible enough to produce waveform-only plots, or to add these objects
based on the inputs provided.
"""
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from pyatoa.utils.calculate import normalize_a_to_b, abs_max
# Hardcoded colors that represent rejected misfit windows. Each key is a
# rejection tag; the trailing comment names the pyflex.WindowSelector (or
# Pyatoa) function that rejects windows under that tag.
# NOTE(review): Matplotlib 3.x resolves "CN" colors modulo the length of the
# property cycle, so with the default 10-color cycle "C10" renders the same as
# "C0" ("water_level"). "water_level" windows are skipped by default when
# plotting rejected windows, so the clash is usually not visible.
rejected_window_colors = {
    "water_level": "C0",  # reject_on_minima_water_level()
    "prominence": "C1",  # reject_on_prominence_of_central_peak()
    "phase_sep": "C2",  # reject_on_phase_separation()
    "curtail": "C3",  # curtail_length_of_windows()
    "min_length": "C4",  # reject_windows_based_on_minimum_length()
    "cc": "C5",  # reject_based_on_data_fit_criteria()
    "tshift": "C6",  # reject_based_on_data_fit_criteria()
    "dlna": "C7",  # reject_based_on_data_fit_criteria()
    "s2n": "C8",  # reject_based_on_signal_to_noise_ratio()
    "traveltimes": "C9",  # reject_on_traveltimes()
    "amplitude": "C10"  # pyatoa reject_on_global_amplitude_ratio()
}
class WaveMaker:
    """
    Standardized waveform figures featuring observed and synthetic traces,
    STA/LTA waveforms, misfit windows, rejected windows, adjoint sources and
    auxiliary information collected within the workflow.

    WAVEFORM KEYWORD ARGUMENTS:

    FIGURE:
        figsize (tuple): size of the figure in inches, defaults to 800 pixels
            wide and 200 pixels per channel (plus a little padding) tall,
            converted to inches using `dpi`
        dpi (float): dots per inch of the figure, default 100
        legend_fontsize (int): font size of the legend, default 6
        linewidth (float): line width for all lines on plot, default 1.6
        axes_linewidth (float): line width for axis spines, default 1
        xlim_s (list): time-axis bounds of the plot in seconds, def full trace
        percent_over (float): passed to format_axis() when formatting the
            waveform axes, default 0.125
        figure (pyplot.Figure): an existing figure object to plot to, rather
            than generating a new figure
        subplot_spec (gridspec.SubplotSpec): an overlying grid cell that
            waveforms will be plotted into. Useful for combining waveform
            plots
    FONTSIZE:
        fontsize (int): font size of the title, axis labels, def 8
        axes_fontsize (int): font size of the tick labels, def 8
        rejected_window_fontsize (int): fontsize for the annotations that
            describe the rejected windows, default 6
        window_anno_fontsize (int): fontsize for window annotation, def 6
        stalta_wl_fontsize (int): fontsize of the STA/LTA waterlevel
            annotation, default 6
    COLORS:
        obs_color (str): color for observed waveform, defaults 'k'
        syn_color (str): color for synthetic waveform, defaults 'r'
        stalta_color (str): color of stalta waveform, default 'gray'
        window_color (str): color for misfit windows, default 'orange'
        adj_src_color (str): color for adjoint sources, default 'g'
    ADJOINT SOURCE:
        adj_src_linestyle (str, tuple): adjoint source style, default tight
            dash (0, (5, 1))
        adj_src_alpha (float): opacity of adjoint source, default 0.4
    WINDOW ANNOTATIONS:
        window_anno (str): a custom format string which can contain the
            optional format arguments: [i, max_cc, cc_shift, dlnA, left,
            length]. None defaults to
            "cc={max_cc:.2f} / dT={cc_shift:.2f} / dA={dlnA:.2f}"
        window_anno_alternate (str): custom format string used for every
            window except the first, useful for dropping the labels for
            parameters, allows for cleaner annotations without compromising
            readability. None uses `window_anno` for all windows
        window_anno_height (float): annotation height as a fraction of the
            visible y axis, default 1
        alternate_anno_height (float): optional; if given, every other window
            (2nd, 4th, ...) is annotated at this fraction of the y axis
            instead, so that adjacent annotations do not overlap
        window_anno_rotation (float): rotation of annotation (deg), def 90
        window_anno_fontcolor (str): color of annotation text, def 'k'
        window_anno_fontweight (str): weight of font, default 'normal'
        window_anno_ha (str): horizontal alignment, default 'left'
        window_anno_va (str): vertical alignment, default 'top'
        window_anno_bbox (dict): bbox dict for window annotations, None means
            no bounding box (the legacy key 'window_anno_box' is also
            accepted)
    TOGGLES:
        plot_xaxis (bool): toggle the labels and ticks on the x axis, def True
        plot_yaxis (bool): toggle the labels and ticks on the y axis, def True
        plot_windows (bool): toggle window plotting, default True
        plot_rejected_windows (bool): toggle rejected window plot, default T
        plot_window_annos (bool): toggle window annotations, default True
            (legacy key 'plot_window_anno' also accepted)
        plot_staltas (bool): toggle stalta plotting, default True
            (legacy key 'plot_stalta' also accepted)
        plot_adjsrcs (bool): toggle adjoint source plotting, default True
            (legacy key 'plot_adjsrc' also accepted)
        plot_waterlevel (bool): toggle stalta waterlevel plotting, def True
        plot_arrivals (bool): toggle phase arrival plotting, default True
        plot_legend (bool): toggle legend, default True
            (legacy key 'legend' also accepted)
    MISC:
        normalize (bool): normalize waveform data before plotting
        set_title (bool or str): create a default title using workflow
            parameters, if str given, overwrites all title
        append_title (str): User appended string to the end of the title.
            useful to get extra information on top of the default title
    """

    def __init__(self, mgmt, **kwargs):
        """
        Copy the data needed for plotting out of a Manager-like object.

        :type mgmt: object
        :param mgmt: object exposing `st_obs`, `st_syn`, `config`, `windows`,
            `staltas`, `adjsrcs`, `rejwins` and `stats.time_offset_sec`
            (presumably a pyatoa Manager)
        :param kwargs: plotting keyword arguments, see the class docstring
        """
        self.st_obs = mgmt.st_obs.copy()
        self.st_syn = mgmt.st_syn.copy()
        self.config = mgmt.config

        # If auxiliary data is None, initialize as empty dictionary so that
        # waveforms can still be plotted
        self.windows = mgmt.windows or {}
        self.staltas = mgmt.staltas or {}
        self.adjsrcs = mgmt.adjsrcs or {}
        self.rejwins = mgmt.rejwins or {}

        self.fig = None
        self.axes = None
        self.twaxes = None
        self.kwargs = kwargs

        # One shared time axis (seconds) for all traces, shifted by the
        # Manager's time offset
        self.time_axis = self.st_obs[0].times() + mgmt.stats.time_offset_sec

    def _get_kwarg(self, *keys, default=None):
        """
        Return the value of the first of `keys` present in self.kwargs.

        Several keyword arguments were historically read under a name that
        differs from the documented one (e.g. 'plot_stalta' vs.
        'plot_staltas'); looking up both keeps old and documented spellings
        working.

        :type keys: str
        :param keys: candidate keyword names, in order of preference
        :param default: value returned if none of the keys are present
        :return: the first matching value, or `default`
        """
        for key in keys:
            if key in self.kwargs:
                return self.kwargs[key]
        return default

    def setup_plot(self, dpi, figsize, twax_off=False):
        """
        Create (or reuse) the figure and one row of axes per observed trace,
        each with a twin-x axis. Rows share the x-axis of the first row.
        Results are stored in `self.fig`, `self.axes` and `self.twaxes`.

        :type dpi: float
        :param dpi: dots per inch, to be set by plot()
        :type figsize: tuple
        :param figsize: size of the figure in inches, set by plot()
        :type twax_off: bool
        :param twax_off: if True, hide the ticks/labels of the twin-x axes
        """
        # Optional kwargs to override the figure and axis instantiation
        figure = self.kwargs.get("figure", None)
        subplot_spec = self.kwargs.get("subplot_spec", None)

        # Axis related kwargs
        fontsize = self.kwargs.get("axes_fontsize", 8)
        axes_linewidth = self.kwargs.get("axes_linewidth", 1)

        if figure is None:
            self.fig = plt.figure(figsize=figsize, dpi=dpi)
        else:
            self.fig = figure

        nrows, ncols = len(self.st_obs), 1
        heights = [1] * nrows
        if subplot_spec is None:
            gs = mpl.gridspec.GridSpec(nrows, ncols, height_ratios=heights,
                                       hspace=0)
        else:
            # Nest the waveform rows inside the user-provided grid cell
            gs = mpl.gridspec.GridSpecFromSubplotSpec(
                nrows, ncols, subplot_spec=subplot_spec,
                height_ratios=heights, hspace=0)

        axes, twaxes = [], []
        for i in range(nrows):
            # Add to self.fig explicitly (not pyplot's "current" figure) so a
            # user-provided `figure` actually receives the axes
            sharex = axes[0] if axes else None
            ax = self.fig.add_subplot(gs[i], sharex=sharex)
            twinax = ax.twinx()
            pretty_grids(twinax, twax=True, fontsize=fontsize,
                         linewidth=axes_linewidth)
            pretty_grids(ax, fontsize=fontsize, linewidth=axes_linewidth)
            twaxes.append(twinax)
            axes.append(ax)

        # Remove x-tick labels except for last axis
        for ax in axes[:-1]:
            plt.setp(ax.get_xticklabels(), visible=False)

        # Option to turn off twin axis
        if twax_off:
            for twax in twaxes:
                twax.axes.get_xaxis().set_visible(False)
                twax.axes.get_yaxis().set_visible(False)

        self.axes = axes
        self.twaxes = twaxes

    def plot_stalta(self, ax, stalta, plot_waterlevel=True):
        """
        Plot the Short-term-average/long-term-average waveform to help visually
        identify the peaks/troughs used to determine windows

        :type ax: matplotlib.axes.Axes
        :param ax: axis object on which to plot
        :type stalta: numpy.ndarray
        :param stalta: data array containing the sta/lta waveform
        :type plot_waterlevel: bool
        :param plot_waterlevel: plot a horizontal line showing the relative
            waterlevel of the sta/lta which is used in determining windows
        :rtype: list of matplotlib.lines.Line2D objects
        :return: list containing the lines plotted on the axis
        """
        linewidth = self.kwargs.get("linewidth", 1.6)
        stalta_color = self.kwargs.get("stalta_color", "gray")
        fontsize = self.kwargs.get("stalta_wl_fontsize", 6)

        # Get waterlevel from the Pyflex config object
        stalta_wl = self.config.pyflex_config.stalta_waterlevel

        # Bounds for use in setting positions
        xmin, xmax = ax.get_xlim()
        ymin, ymax = ax.get_ylim()

        # Rescale the sta/lta onto the current y-range because we only
        # care about the phase and waterlevel, not the absolute values
        stalta = normalize_a_to_b(stalta, ymin, ymax)
        waterlevel = (ymax - ymin) * stalta_wl + ymin

        b2, = ax.plot(self.time_axis, stalta, stalta_color, alpha=0.4,
                      linewidth=linewidth, zorder=9, label="STA/LTA")

        if plot_waterlevel:
            # axhline's xmin/xmax are axes fractions (0-1), not data
            # coordinates, so span the full axis width explicitly
            ax.axhline(y=waterlevel, xmin=0, xmax=1, alpha=0.4, zorder=8,
                       linewidth=linewidth, c=stalta_color, linestyle='--')
            ax.annotate(f"stalta_waterlevel = {stalta_wl}", alpha=0.7,
                        fontsize=fontsize,
                        xy=(0.75 * (xmax - xmin) + xmin, waterlevel)
                        )

        return [b2]

    def plot_windows(self, ax, windows, plot_window_annos=True,
                     plot_phase_arrivals=True):
        """
        Plot misfit windows, add annotations to each window related to
        information contained in the Window object.

        .. note::
            The keyword arguments 'window_anno_height' and
            'alternate_anno_height' should be given as a fraction of the
            visible y-axis, e.g. 0.25 means 25% of the y-axis

        :type ax: matplotlib.axes.Axes
        :param ax: axis object on which to plot
        :type windows: list of pyflex.Window objects
        :param windows: list of windows to plot
        :type plot_window_annos: bool
        :param plot_window_annos: annotate window information onto windows
        :type plot_phase_arrivals: bool
        :param plot_phase_arrivals: make small tick mark for arrivals named
            'p' or 's' listed in the window's phase arrivals
        """
        window_anno = self.kwargs.get("window_anno", None)
        window_anno_alternate = self.kwargs.get("window_anno_alternate", None)
        window_color = self.kwargs.get("window_color", "orange")
        window_anno_fontsize = self.kwargs.get("window_anno_fontsize", 6)
        window_anno_height = self.kwargs.get("window_anno_height", 1)
        alternate_anno_height = self.kwargs.get("alternate_anno_height", None)
        window_anno_rotation = self.kwargs.get("window_anno_rotation", 90)
        window_anno_fontcolor = self.kwargs.get("window_anno_fontcolor", "k")
        window_anno_fontweight = self.kwargs.get("window_anno_fontweight",
                                                 "normal")
        window_anno_ha = self.kwargs.get("window_anno_ha", "left")
        window_anno_va = self.kwargs.get("window_anno_va", "top")
        # Documented key is 'window_anno_bbox'; the legacy 'window_anno_box'
        # spelling is still honored
        window_anno_bbox = self._get_kwarg("window_anno_bbox",
                                           "window_anno_box", default=None)

        # Determine heights for the annotations, allow alternating heights so
        # that adjacent windows don't write over one another
        ymin, ymax = ax.get_ylim()
        y_anno = window_anno_height * (ymax - ymin) + ymin
        if alternate_anno_height is not None:
            y_anno_alt = alternate_anno_height * (ymax - ymin) + ymin
        else:
            y_anno_alt = y_anno

        # Default window annotation string
        if window_anno is None:
            window_anno = "cc={max_cc:.2f} / dT={cc_shift:.2f} / dA={dlnA:.2f}"
        if window_anno_alternate is None:
            window_anno_alternate = window_anno

        for j, window in enumerate(windows):
            tleft = window.left * window.dt + self.time_axis[0]
            tright = window.right * window.dt + self.time_axis[0]

            # Misfit windows as rectangle spanning the full visible y-range;
            # taken from Pyflex. Better fitting windows are more opaque
            ax.add_patch(Rectangle(xy=(tleft, ymin), width=tright - tleft,
                                   height=ymax - ymin,
                                   fc=window_color, ec="k",
                                   alpha=(window.max_cc_value ** 2) * 0.25,
                                   zorder=10
                                   )
                         )
            # Outline the rectangle with solid black lines
            for x_ in (tleft, tright):
                ax.axvline(x=x_, ymin=0, ymax=1, color="k", alpha=1., zorder=11)

            if plot_window_annos:
                # Annotate window information into each window
                t_anno = (tright - tleft) * 0.025 + tleft
                anno_fmt = window_anno if j == 0 else window_anno_alternate
                s_anno = anno_fmt.format(
                    i=j + 1,
                    max_cc=window.max_cc_value,
                    cc_shift=window.cc_shift * window.dt,
                    dlnA=window.dlnA,
                    left=tleft,
                    length=tright - tleft)
                ax.annotate(s_anno, ha=window_anno_ha, va=window_anno_va,
                            xy=(t_anno, y_anno_alt if j % 2 else y_anno),
                            zorder=12, fontsize=window_anno_fontsize,
                            rotation=window_anno_rotation,
                            color=window_anno_fontcolor,
                            fontweight=window_anno_fontweight,
                            bbox=window_anno_bbox
                            )

            if plot_phase_arrivals:
                # NOTE(review): only arrivals named exactly 'p' or 's' are
                # marked; confirm against the arrival naming convention
                for arrival in getattr(window, "phase_arrivals", None) or []:
                    if arrival["name"] in ["p", "s"]:
                        ax.axvline(x=arrival["time"], ymin=0, ymax=0.05,
                                   color='b', alpha=0.5
                                   )
                        ax.annotate(arrival["name"],
                                    xy=(0.975 * arrival["time"],
                                        0.05 * (ymax - ymin) + ymin),
                                    fontsize=8
                                    )

    @staticmethod
    def _collapse_windows(windows):
        """
        Merge touching or overlapping windows into contiguous time spans.

        :type windows: list of pyflex.Window objects
        :param windows: windows to merge (uses `left`, `right` and `dt`)
        :rtype: numpy.ndarray
        :return: (N, 2) float array of [start, end] times in seconds, not
            shifted by any time offset; shape (0, 2) if `windows` is empty
        """
        bounds = sorted((w.left * w.dt, w.right * w.dt) for w in windows)
        merged = []
        for start, end in bounds:
            if merged and start <= merged[-1][1]:
                # Shares a boundary with (or overlaps) the previous span
                merged[-1][1] = max(merged[-1][1], end)
            else:
                merged.append([start, end])
        return np.array(merged, dtype=float).reshape(-1, 2)

    @staticmethod
    def _exclude_contained(rwin_arr, win_arr):
        """
        Drop rejected windows that lie entirely inside any chosen window.

        :type rwin_arr: numpy.ndarray
        :param rwin_arr: (M, 2) array of rejected [start, end] times
        :type win_arr: numpy.ndarray or None
        :param win_arr: (N, 2) array of chosen [start, end] times, or None
        :rtype: numpy.ndarray
        :return: rows of `rwin_arr` not contained in any chosen window
        """
        if win_arr is None or len(win_arr) == 0 or len(rwin_arr) == 0:
            return rwin_arr
        inside = np.zeros(len(rwin_arr), dtype=bool)
        for start, end in win_arr:
            # A rejected window is redundant if it fits in ANY chosen window
            inside |= (rwin_arr[:, 0] >= start) & (rwin_arr[:, 1] <= end)
        return rwin_arr[~inside]

    def plot_rejected_windows(self, ax, rejwin, windows=None, skip_tags=None):
        """
        Plot rejected windows as transparent bars stacked below the axis
        data, one row per rejection tag. Colors come from the module-level
        `rejected_window_colors` dictionary so the reason for rejection can
        be identified visually.

        Rejected windows that fall entirely within an already chosen
        (merged) time window are skipped to avoid redundant plotting.

        :type ax: matplotlib.axes.Axes
        :param ax: axis object on which to plot
        :type rejwin: dict of lists of pyflex.Window objects
        :param rejwin: rejected windows keyed by rejection tag
        :type windows: list of pyflex.Window objects
        :param windows: list of windows to use for exclusion
        :type skip_tags: list of str
        :param skip_tags: an optional list of tags that can be used to skip
            specific rejected window tags, defaults to ['water_level']
        """
        fontsize = self.kwargs.get("rejected_window_fontsize", 6)

        # By default skip the water_level tag, from experience those windows
        # usually just cover the entire window search range, i.e. uninformative
        if skip_tags is None:
            skip_tags = ["water_level"]

        ymin, ymax = ax.get_ylim()
        dy = 0.075 * (ymax - ymin)  # height of each bar row

        # Chosen windows merged into contiguous spans for the exclusion test
        win_arr = None
        if windows is not None:
            win_arr = self._collapse_windows(windows)

        for tag, rejected in rejwin.items():
            if tag in skip_tags or len(rejected) == 0:
                continue

            rwin_arr = np.array([[_.left * _.dt, _.right * _.dt]
                                 for _ in rejected], dtype=float)
            rwin_arr = self._exclude_contained(rwin_arr, win_arr)
            if len(rwin_arr) == 0:
                continue

            # Shift rejected windows by the proper time offset
            rwin_arr += self.time_axis[0]
            for start, end in rwin_arr:
                length = end - start
                # Shorter windows get larger zorder so they stay visible
                zorder = 15 + (1 / length if length > 0 else 1)
                ax.add_patch(Rectangle(xy=(start, ymin), width=length,
                                       height=dy, ec="k", alpha=0.25,
                                       fc=rejected_window_colors.get(tag, "k"),
                                       zorder=zorder
                                       )
                             )
            # Annotate the leftmost rejected window point with the tag
            ax.annotate(tag.replace("_", " "),
                        xy=(rwin_arr[:, 0].min(), ymin), fontsize=fontsize,
                        zorder=14)
            ymin -= dy

        # Reset ylimits based on the extent of the rejected windows
        ax.set_ylim([-abs(ymin), ymax])

    def plot_adjsrcs(self, ax, adjsrc):
        """
        Plot adjoint sources behind streams, time reverse the adjoint source
        that is provided by Pyadjoint so that it lines up with waveforms
        and windows.

        .. note::
            The unit of adjoint source is based on Eq. 57 of Tromp et al. 2005,
            which is a traveltime adjoint source, and includes the units of
            the volumetric delta function since it's assumed this is happening
            in a 3D volume.

        :type ax: matplotlib.axes.Axes
        :param ax: axis object on which to plot
        :type adjsrc: pyadjoint.adjoint_source.AdjointSource objects
        :param adjsrc: adjsrc object containing data to plot
        :rtype: list of matplotlib.lines.Line2D objects
        :return: list containing the lines plotted on the axis
        """
        linewidth = self.kwargs.get("linewidth", 1.6)
        linestyle = self.kwargs.get("adj_src_linestyle", (0, (5, 1)))
        color = self.kwargs.get("adj_src_color", "g")
        alpha = self.kwargs.get("adj_src_alpha", 0.4)

        # Time reverse adjoint source; line up with waveforms
        b1, = ax.plot(self.time_axis, adjsrc.adjoint_source[::-1], color,
                      alpha=alpha, linewidth=linewidth, linestyle=linestyle,
                      zorder=9,
                      label=fr"Adjoint Source ($\chi$={adjsrc.misfit:.2f})"
                      )
        return [b1]

    def plot_amplitude_threshold(self, ax, obs):
        """
        Plot dotted lines at +/- the amplitude threshold used by Pyatoa,
        i.e. `config.win_amp_ratio` times the peak absolute amplitude of
        the observed trace.

        :type ax: matplotlib.axes.Axes
        :param ax: axis object on which to plot
        :type obs: obspy.core.trace.Trace
        :param obs: observed trace plotted on the current axis, used to
            determine the peak amplitude value
        """
        xmin, xmax = ax.get_xlim()
        threshold_amp = abs(self.config.win_amp_ratio * abs_max(obs.data))

        # Plot both negative and positive bounds. axhline's xmin/xmax are
        # axes fractions (0-1), so span the full axis width explicitly
        for sign in (-1, 1):
            ax.axhline(y=sign * threshold_amp, xmin=0, xmax=1, alpha=0.35,
                       zorder=6, linewidth=1.25, c='k', linestyle=':')

        # Annotate window amplitude ratio
        ax.annotate(f"{self.config.win_amp_ratio * 100:.0f}% peak amp. obs.",
                    alpha=0.7, xy=(0.85 * (xmax - xmin) + xmin, threshold_amp),
                    fontsize=8
                    )

    def create_title(self, normalized=False, append_title=None):
        """
        Create the title based on information provided to the class

        :type normalized: bool
        :param normalized: whether or not the data was normalized
        :type append_title: str
        :param append_title: append extra information to title
        :rtype: str
        :return: title string composed of configuration parameters and source
            receiver information
        """
        stats = self.st_obs[0].stats
        title = f"{stats.network}.{stats.station}"

        # Event id may not be given, if it is, append to title
        if self.config.event_id is not None:
            title += f" {self.config.event_id}"

        # Filter bounds to plot title
        title += f" [{self.config.min_period}-{self.config.max_period}]s"

        # Tell the User if the data has been normalized
        if normalized:
            title += " (normalized) "

        # Add information about the iteration, windowing and misfit measurement
        title += "\n"
        if self.config.iter_tag is not None:
            title += self.config.iter_tag
        if self.config.step_tag is not None:
            title += self.config.step_tag

        # Add information about the Pyadjoint parameters used
        if self._get_kwarg("plot_adjsrcs", "plot_adjsrc", default=True):
            title += f" pyadjoint={self.config.adj_src_type}, "

        # User appended title information
        if append_title is not None:
            title = " ".join([title, append_title])

        return title

    def plot(self, show=True, save=False, **kwargs):
        """
        High level plotting function that plots all parts of the class and
        formats the axes nicely

        :type show: bool
        :param show: call pyplot.show() once the figure is complete
        :type save: str or bool
        :param save: filename to save the figure to; False does not save
        :param kwargs: additional keyword arguments (see class docstring),
            merged into and overriding those given at instantiation
        """
        # Allow additional kwargs to be passed in to the plot argument
        self.kwargs.update(kwargs)

        dpi = self.kwargs.get("dpi", 100)
        # Make the figure just slightly larger than the size of all traces
        figsize = self.kwargs.get(
            "figsize", (800 / dpi, 200 * (len(self.st_obs) + .3) / dpi))
        fontsize = self.kwargs.get("fontsize", 8)
        legend_fontsize = self.kwargs.get("legend_fontsize", 6)
        append_title = self.kwargs.get("append_title", None)
        normalize = self.kwargs.get("normalize", False)
        xlim_s = self.kwargs.get("xlim_s", None)
        percent_over = self.kwargs.get("percent_over", 0.125)
        set_title = self.kwargs.get("set_title", True)
        plot_xaxis = self.kwargs.get("plot_xaxis", True)
        plot_yaxis = self.kwargs.get("plot_yaxis", True)
        plot_windows = self.kwargs.get("plot_windows", True)
        plot_rejected_windows = self.kwargs.get("plot_rejected_windows", True)
        plot_waterlevel = self.kwargs.get("plot_waterlevel", True)
        plot_arrivals = self.kwargs.get("plot_arrivals", True)
        # Documented names first, legacy spellings as fallback
        plot_window_annos = self._get_kwarg("plot_window_annos",
                                            "plot_window_anno", default=True)
        plot_staltas = self._get_kwarg("plot_staltas", "plot_stalta",
                                       default=True)
        plot_adjsrcs = self._get_kwarg("plot_adjsrcs", "plot_adjsrc",
                                       default=True)
        plot_legend = self._get_kwarg("plot_legend", "legend", default=True)

        # Only if nothing goes on the twin axis, turn off its tick marks
        twax_off = not (plot_staltas or plot_adjsrcs)

        self.setup_plot(dpi=dpi, figsize=figsize, twax_off=twax_off)

        middle = len(self.st_obs) // 2
        for i, obs in enumerate(self.st_obs):
            comp = obs.stats.channel[-1]

            # Try to retrieve auxiliary data by component name
            syn = self.st_syn.select(component=comp)[0]
            windows = self.windows.get(comp, None)
            adjsrc = self.adjsrcs.get(comp, None)
            stalta = self.staltas.get(comp, None)
            rejwin = self.rejwins.get(comp, None)

            ax = self.axes[i]
            twax = self.twaxes[i]
            lines = []  # List of lines for making the legend

            # NOTE(review): plot_waveforms() and format_axis() are not defined
            # in the visible source; assumed to be provided elsewhere
            lines += self.plot_waveforms(obs=obs, syn=syn, ax=ax,
                                         normalize=normalize)
            if rejwin is not None and plot_rejected_windows:
                self.plot_rejected_windows(ax=ax, rejwin=rejwin,
                                           windows=windows)

            # Format now as windows use the y-limits for height of windows
            format_axis(ax, percent_over)

            if windows is not None and plot_windows:
                self.plot_windows(ax=ax, windows=windows,
                                  plot_window_annos=plot_window_annos,
                                  plot_phase_arrivals=plot_arrivals)

            if adjsrc is not None and plot_adjsrcs:
                lines += self.plot_adjsrcs(ax=twax, adjsrc=adjsrc)
                if i == middle:
                    # Middle trace: append units of the adjoint source
                    twax.set_ylabel("adjoint source [m$^{-4}$ s]",
                                    rotation=270, labelpad=20,
                                    fontsize=fontsize)
            else:
                twax.set_yticklabels([])  # turn off yticks if no adjsrc

            # Format twax because stalta will use y-limits for its waveforms
            format_axis(twax)

            if stalta is not None and plot_staltas:
                lines += self.plot_stalta(ax=twax, stalta=stalta,
                                          plot_waterlevel=plot_waterlevel)

            if self.config.win_amp_ratio > 0:
                self.plot_amplitude_threshold(ax=ax, obs=obs)

            # Finish with axes formatting
            ax.set_ylabel(comp.upper(), fontsize=fontsize)
            if i == middle:
                # Middle trace will carry the units of the waveforms
                units = {"DISP": "displacement [m]",
                         "VEL": "velocity [m/s]",
                         "ACC": "acceleration [m/s^2]",
                         "none": ""}.get(self.config.unit_output, "")
                ax.set_ylabel(f"{units}\n{ax.get_ylabel()}",
                              fontsize=fontsize)

            if plot_legend:
                labels = [line.get_label() for line in lines]
                ax.legend(lines, labels, prop={"size": legend_fontsize},
                          loc="upper right")

            # NOTE(review): align_yaxes compares y-limit midpoints, which
            # already coincide for twin axes; confirm whether aligning the
            # zero levels was intended
            align_yaxes(ax, twax)

        # Final touch ups for the figure
        if isinstance(set_title, bool):
            if set_title:
                self.axes[0].set_title(self.create_title(
                    append_title=append_title, normalized=normalize),
                    fontsize=fontsize
                )
        else:
            self.axes[0].set_title(set_title, fontsize=fontsize)

        # All rows share x with the first axis, so one call sets them all
        if xlim_s is not None:
            self.axes[0].set_xlim(xlim_s)
        else:
            self.axes[0].set_xlim([self.time_axis[0], self.time_axis[-1]])
        self.axes[-1].set_xlabel("time [s]", fontsize=fontsize)

        # Option to turn off tick labels and axis labels
        if not plot_xaxis or not plot_yaxis:
            for axis in [self.axes, self.twaxes]:
                for ax in axis:
                    if not plot_xaxis:
                        ax.axes.xaxis.set_ticklabels([])
                        ax.set_xlabel("")
                    if not plot_yaxis:
                        ax.axes.yaxis.set_ticklabels([])
                        ax.set_ylabel("")

        # Save this instance's figure, which may not be pyplot's current one
        if save:
            self.fig.savefig(save)
        if show:
            plt.show()
def align_yaxes(ax1, ax2, v1=None, v2=None):
    """
    Shift the y-limits of ax2 so that value `v2` on ax2 sits at the same
    vertical display position as value `v1` on ax1. Only ax2 is modified;
    the span of its y-limits is preserved.

    By default (v1 = v2 = None) the midpoints of each axis' current y-limits
    are used. For linear-scaled twin axes occupying the same bounding box,
    those midpoints already coincide, so the default call leaves ax2
    unchanged; pass e.g. v1=0, v2=0 to line up the zero levels instead.

    :type ax1: matplotlib axis
    :param ax1: reference axis, not modified
    :type ax2: matplotlib axis
    :param ax2: axis whose y-limits are shifted
    :type v1: float
    :param v1: data value on ax1 to align, defaults to its y-limit midpoint
    :type v2: float
    :param v2: data value on ax2 to align, defaults to its y-limit midpoint
    """
    ymin_a1, ymax_a1 = ax1.get_ylim()
    ymin_a2, ymax_a2 = ax2.get_ylim()
    if v1 is None:
        v1 = (ymax_a1 + ymin_a1) / 2
    if v2 is None:
        v2 = (ymax_a2 + ymin_a2) / 2

    # Display-space (pixel) heights of the two reference values
    _, y1 = ax1.transData.transform((0, v1))
    _, y2 = ax2.transData.transform((0, v2))

    # Convert the pixel offset back into ax2 data units and shift ax2
    inv = ax2.transData.inverted()
    _, dy = inv.transform((0, 0)) - inv.transform((0, y1 - y2))
    ax2.set_ylim(ymin_a2 + dy, ymax_a2 + dy)
def pretty_grids(input_ax, twax=False, grid=False, fontsize=8, linewidth=1,
                 sci_format=True):
    """
    Standard plot skeleton formatting: inward tick marks, configurable spine
    thickness, optional scientific y-tick labels and optional dotted grids.

    :type input_ax: matplotlib axis
    :param input_ax: axis to prettify
    :type twax: bool
    :param twax: if True (twin axis), put ticks on the right instead of the
        left, and skip minor ticks and grids
    :type grid: bool
    :param grid: turn on grids of the axes (main axis only), default off
    :type fontsize: float
    :param fontsize: fontsize of the axis tick labels
    :type linewidth: float
    :param linewidth: line width of the axis spines or bounding box
    :type sci_format: bool
    :param sci_format: turn on/off scientific formatting of y tick labels,
        default scientific format on/True.
    """
    input_ax.set_axisbelow(True)

    if sci_format:
        input_ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
        # Make sure the scientific notation also has the same fontsize
        input_ax.yaxis.get_offset_text().set_fontsize(fontsize)

    # Ensure the twin axis doesn't clash with ticks of the main axis
    left, right = (False, True) if twax else (True, False)

    tick_width = 2 * linewidth / 3
    for which, length in (("major", 8), ("minor", 4)):
        input_ax.tick_params(which=which, direction='in', top=True,
                             right=right, left=left, length=length,
                             labelsize=fontsize, width=tick_width)

    for spine in ("top", "bottom", "left", "right"):
        input_ax.spines[spine].set_linewidth(linewidth)

    # Minor ticks and grids only on the main axis
    if twax:
        return
    input_ax.minorticks_on()
    if grid:
        for which in ("major", "minor"):
            input_ax.grid(which=which, linestyle=':', linewidth=0.5,
                          color='k', alpha=0.25)