# Source code for aerosandbox.tools.pretty_plots.plots.contour

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from typing import Tuple, Dict, Union, Callable, List
from scipy import interpolate
from aerosandbox.tools.string_formatting import eng_string


def _nondimensional_edge_components(
        values: np.ndarray,
        triangles: np.ndarray,
        log_scale: bool,
) -> np.ndarray:
    """
    Computes, for one coordinate (X or Y), the difference between consecutive vertices of each triangle,
    normalized by the span of that coordinate over all data points.

    Args:
        values: 1D array of one coordinate of the data points.
        triangles: Integer array of vertex indices with one row per triangle (e.g., `Triangulation.triangles`).
        log_scale: If True, the differences and the span are computed on the natural log of `values`, matching a
            log-scaled plot axis.

    Returns: An array shaped like `triangles`, where entry [i, j] is the normalized difference between vertex j and
        vertex j-1 (cyclically) of triangle i.
    """
    if log_scale:
        values = np.log(values)
    vertex_values = values[triangles]
    span = np.nanmax(values) - np.nanmin(values)
    return (vertex_values - np.roll(vertex_values, 1, axis=1)) / span


def _format_log_colorbar(cbar, Z_ratio: float) -> None:
    """
    Sets tick locators and formatters on a colorbar drawn for log-scaled Z data, choosing a tick density based on
    how many decades the data spans.

    Args:
        cbar: The colorbar to format (modified in place).
        Z_ratio: Ratio of the largest to the smallest Z value.
    """
    # Format whichever axis runs along the colorbar, so horizontal colorbars are handled too.
    axis = cbar.ax.yaxis if cbar.orientation == "vertical" else cbar.ax.xaxis
    cbar.ax.tick_params(which="minor", labelsize=8)

    if Z_ratio >= 10 ** 2.05:
        # Wide range: major (labeled) ticks at powers of ten, unlabeled minor ticks between them.
        axis.set_major_locator(mpl.ticker.LogLocator())
        axis.set_minor_locator(mpl.ticker.LogLocator(subs=np.arange(1, 10)))
        axis.set_major_formatter(mpl.ticker.LogFormatterSciNotation())
        axis.set_minor_formatter(mpl.ticker.NullFormatter())
    elif Z_ratio >= 10 ** 1.5:
        # Intermediate range: same ticks as above, but the minor ticks are labeled too.
        axis.set_major_locator(mpl.ticker.LogLocator())
        axis.set_minor_locator(mpl.ticker.LogLocator(subs=np.arange(1, 10)))
        axis.set_major_formatter(mpl.ticker.LogFormatterSciNotation())
        axis.set_minor_formatter(mpl.ticker.LogFormatterSciNotation(
            minor_thresholds=(np.inf, np.inf)
        ))
    else:
        # Narrow range: major ticks at 1-9 times each power of ten, finer unlabeled minor ticks.
        axis.set_major_locator(mpl.ticker.LogLocator(subs=np.arange(1, 10)))
        axis.set_minor_locator(mpl.ticker.LogLocator(subs=np.arange(10, 100) / 10))
        axis.set_major_formatter(mpl.ticker.ScalarFormatter())
        axis.set_minor_formatter(mpl.ticker.NullFormatter())


def contour(
        *args,
        levels: Union[int, List, np.ndarray] = 31,
        colorbar: bool = True,
        linelabels: bool = True,
        cmap=None,
        alpha: float = 0.7,
        extend: str = "neither",
        linecolor="k",
        linewidths: float = 0.5,
        extendrect: bool = True,
        linelabels_format: Union[str, Callable[[float], str]] = eng_string,
        linelabels_fontsize: float = 8,
        max_side_length_nondim: float = np.inf,
        colorbar_label: str = None,
        x_log_scale: bool = False,
        y_log_scale: bool = False,
        z_log_scale: bool = False,
        mask: np.ndarray = None,
        drop_nans: bool = None,
        # smooth: Union[bool, int] = False, # TODO implement
        contour_kwargs: Dict = None,
        contourf_kwargs: Dict = None,
        colorbar_kwargs: Dict = None,
        linelabels_kwargs: Dict = None,
        **kwargs,
):
    """
    An analogue for plt.contour and plt.tricontour and friends that produces a much prettier default graph.

    Can take inputs with either contour or tricontour syntax.

    See syntax here:
        https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.contour.html
        https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.contourf.html
        https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tricontour.html
        https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tricontourf.html

    Positional arguments (`*args`) may take one of these forms:
        * `contour(Z)`: `Z` must be 2D; X and Y default to the column and row indices of `Z`.
        * `contour(X, Y, Z)`
        * `contour(X, Y, Z, levels)`: equivalent to passing `levels` as a keyword argument.

    Args:
        X: If dataset is gridded, follow `contour` syntax. Otherwise (X, Y, and Z all 1D), follow `tricontour` syntax.
        Y: If dataset is gridded, follow `contour` syntax. Otherwise (X, Y, and Z all 1D), follow `tricontour` syntax.
        Z: If dataset is gridded, follow `contour` syntax. Otherwise (X, Y, and Z all 1D), follow `tricontour` syntax.
        levels: See contour docs. If `z_log_scale` is True and `levels` is an int, it sets the number of log-spaced
            subdivisions per decade so that there are roughly `levels` levels across the range of Z.
        colorbar: Should we draw a colorbar?
        linelabels: Should we add line labels?
        cmap: What colormap should we use? Defaults to viridis.
        alpha: Transparency of both the contour lines and the filled contours.
        extend: See contour docs.
        linecolor: Color of the contour lines.
        linewidths: See contour docs.
        extendrect: See colorbar docs.
        linelabels_format: See ax.clabel docs (`fmt`).
        linelabels_fontsize: See ax.clabel docs (`fontsize`).
        max_side_length_nondim: Unstructured data only. Triangles whose longest side, nondimensionalized by the span
            of the data in X and Y (computed in log space for log-scaled axes), exceeds this value are masked out.
            Useful for concave or sparse point clouds.
        colorbar_label: Label for the colorbar. Overrides any "label" entry in `colorbar_kwargs`.
        x_log_scale: Use a log-scaled x-axis. This also affects the triangle-size filter for unstructured data.
        y_log_scale: Use a log-scaled y-axis. This also affects the triangle-size filter for unstructured data.
        z_log_scale: Use log-spaced levels and a logarithmic color normalization. Requires all Z values to be positive.
        mask: Optional boolean array selecting which data points to plot. If given, X, Y, and Z are indexed by it
            (which flattens gridded data), and the data is then plotted as unstructured (tricontour-style).
        drop_nans: If True, points where any of X, Y, or Z is NaN are removed, and the data is plotted as
            unstructured. If None (default), this is True for unstructured data and False for gridded data.
        contour_kwargs: Additional keyword arguments for contour.
        contourf_kwargs: Additional keyword arguments for contourf.
        colorbar_kwargs: Additional keyword arguments for colorbar.
        linelabels_kwargs: Additional keyword arguments for the line labels (ax.clabel).
        **kwargs: Additional keywords, which are passed to both contour and contourf.

    Returns: A tuple of (contour, contourf, colorbar) objects. `colorbar` is None if `colorbar` is False.

    Raises:
        ValueError: If the positional arguments don't match a supported signature, if any of X, Y, or Z is entirely
            NaN, if `z_log_scale` is True and some Z value is not positive, or if every triangle of unstructured data
            exceeds `max_side_length_nondim`.
    """
    bad_signature_error = ValueError("Call signature should be one of:\n"
                                     " * `contour(Z, **kwargs)`\n"
                                     " * `contour(X, Y, Z, **kwargs)`\n"
                                     " * `contour(X, Y, Z, levels, **kwargs)`"
                                     )

    ### Parse *args
    if len(args) == 1:
        Z = np.asanyarray(args[0])
        if Z.ndim != 2:
            raise ValueError(
                "When called as `contour(Z)`, `Z` must be a 2D array."
            )
        X = np.arange(Z.shape[1])
        Y = np.arange(Z.shape[0])
    elif len(args) in (3, 4):
        # `asanyarray` lets lists work while passing ndarrays (and subclasses) through unchanged.
        X, Y, Z = (np.asanyarray(arg) for arg in args[:3])
        if len(args) == 4:  # The `contour(X, Y, Z, levels)` form.
            levels = args[3]
    else:
        raise bad_signature_error

    # The data is unstructured (tricontour-style) only if X, Y, and Z are all 1D; otherwise, it's gridded.
    is_gridded = not (X.ndim == 1 and Y.ndim == 1 and Z.ndim == 1)

    ### Check inputs for sanity
    for name, values in (("X", X), ("Y", Y), ("Z", Z)):
        if np.all(np.isnan(values)):
            raise ValueError(
                f"All values of '{name}' are NaN!"
            )

    ### Set defaults
    if cmap is None:
        cmap = mpl.colormaps.get_cmap('viridis')

    # Copy user-supplied dicts so that nothing below mutates the caller's objects.
    contour_kwargs = {} if contour_kwargs is None else dict(contour_kwargs)
    contourf_kwargs = {} if contourf_kwargs is None else dict(contourf_kwargs)
    colorbar_kwargs = {} if colorbar_kwargs is None else dict(colorbar_kwargs)
    linelabels_kwargs = {} if linelabels_kwargs is None else dict(linelabels_kwargs)

    shared_kwargs = dict(kwargs)
    if levels is not None:
        shared_kwargs["levels"] = levels
    if alpha is not None:
        shared_kwargs["alpha"] = alpha
    if extend is not None:
        shared_kwargs["extend"] = extend

    if z_log_scale:
        if np.any(Z <= 0):
            raise ValueError(
                "All values of the `Z` input to `contour()` should be positive if `z_log_scale` is True!"
            )

        Z_ratio = np.nanmax(Z) / np.nanmin(Z)

        if Z_ratio > 1e8:
            # Very wide range: let matplotlib's default LogLocator pick the levels.
            locator = mpl.ticker.LogLocator()
        else:
            try:
                default_levels = int(levels)
            except TypeError:  # `levels` is a sequence of explicit values (or None).
                default_levels = 31

            decades = np.log10(Z_ratio)
            if decades > 0:
                # Spread roughly `default_levels` levels evenly (in log space) over the decades Z spans.
                divisions_per_decade = int(np.ceil(default_levels / decades))
            else:
                # Constant Z (Z_ratio == 1): avoid dividing by zero decades.
                divisions_per_decade = default_levels

            locator = mpl.ticker.LogLocator(
                subs=np.geomspace(1, 10, divisions_per_decade + 1)[:-1]
            )

        shared_kwargs = {
            "norm"   : mpl.colors.LogNorm(),
            "locator": locator,
            **shared_kwargs
        }
        colorbar_kwargs = {
            "norm": mpl.colors.LogNorm(),
            **colorbar_kwargs
        }

    if colorbar_label is not None:
        colorbar_kwargs = {**colorbar_kwargs, "label": colorbar_label}

    # User-supplied per-function kwargs take precedence over the defaults and shared kwargs.
    contour_kwargs = {
        "colors"    : linecolor,
        "linewidths": linewidths,
        **shared_kwargs,
        **contour_kwargs
    }
    contourf_kwargs = {
        "cmap": cmap,
        **shared_kwargs,
        **contourf_kwargs
    }
    colorbar_kwargs = {
        "extendrect": extendrect,
        **colorbar_kwargs
    }
    linelabels_kwargs = {
        "inline"  : 1,
        "fontsize": linelabels_fontsize,
        "fmt"     : linelabels_format,
        **linelabels_kwargs
    }

    if drop_nans is None:
        drop_nans = not is_gridded

    ### Now, with all the kwargs merged, prep for the actual plotting.
    if (mask is not None or drop_nans) and X.ndim == 1 and Y.ndim == 1 and Z.ndim == 2:
        # Gridded data given with 1D axis vectors: expand them to Z's shape so they can be filtered elementwise.
        X, Y = np.meshgrid(X, Y)

    if mask is not None:
        X, Y, Z = X[mask], Y[mask], Z[mask]
        is_gridded = False

    if drop_nans:
        keep = ~(np.isnan(X) | np.isnan(Y) | np.isnan(Z))
        X, Y, Z = X[keep], Y[keep], Z[keep]
        is_gridded = False

    ### Do the actual plotting
    if is_gridded:
        cont = plt.contour(X, Y, Z, **contour_kwargs)
        contf = plt.contourf(X, Y, Z, **contourf_kwargs)
    else:
        ### Unstructured data: triangulate, then filter out triangles that would extrapolate too far.
        # See also: https://stackoverflow.com/questions/42426095/matplotlib-contour-contourf-of-concave-non-gridded-data
        tri = mpl.tri.Triangulation(X, Y)

        X_nondim = _nondimensional_edge_components(X, tri.triangles, x_log_scale)
        Y_nondim = _nondimensional_edge_components(Y, tri.triangles, y_log_scale)
        side_length_nondim = np.max(np.hypot(X_nondim, Y_nondim), axis=1)

        is_too_large = side_length_nondim > max_side_length_nondim
        if np.all(is_too_large):
            raise ValueError(
                "All triangles in the triangulation are too large to be plotted!\n"
                "Try increasing `max_side_length_nondim`!"
            )
        tri.set_mask(is_too_large)

        cont = plt.tricontour(tri, Z, **contour_kwargs)
        contf = plt.tricontourf(tri, Z, **contourf_kwargs)

    if x_log_scale:
        plt.xscale("log")
    if y_log_scale:
        plt.yscale("log")

    if colorbar:
        from matplotlib import cm

        # A fresh ScalarMappable is used as the colorbar's mappable, built from contourf's norm and colormap.
        cbar = plt.colorbar(
            ax=contf.axes,
            mappable=cm.ScalarMappable(
                norm=contf.norm,
                cmap=contf.cmap,
            ),
            **colorbar_kwargs
        )
        if z_log_scale:
            _format_log_colorbar(cbar, Z_ratio)
    else:
        cbar = None

    if linelabels:
        cont.axes.clabel(cont, **linelabels_kwargs)

    return cont, contf, cbar
if __name__ == '__main__':
    import aerosandbox.tools.pretty_plots as p

    ### Demo: a smooth, strictly positive field on a regular grid, plotted with log-scaled Z.
    x_grid_1d = np.linspace(0, 1, 100)
    y_grid_1d = np.linspace(0, 1, 100)
    X, Y = np.meshgrid(x_grid_1d, y_grid_1d)

    # The exponent below oscillates within +/- log10_Z_ratio / 2, so max(Z) / min(Z) is about 10 ** log10_Z_ratio.
    log10_Z_ratio = 1
    radial_phase = 2 * np.pi * (X ** 4 + Y ** 4)
    Z = 10 ** (log10_Z_ratio / 2 * np.cos(radial_phase))

    plt.subplots(figsize=(6, 6))

    demo_cmap = p.mpl.colormaps.get_cmap("rainbow")
    cont, contf, cbar = contour(
        X,
        Y,
        np.abs(Z),
        drop_nans=True,
        z_log_scale=True,
        cmap=demo_cmap,
        levels=20,
        colorbar_label="Colorbar label",
    )

    p.show_plot("Title", "X label", "Y label")