Source code for aerosandbox.tools.pretty_plots.labellines.core

import warnings
from math import atan2, degrees

import matplotlib.patheffects as path_effects
import numpy as np
from matplotlib.container import ErrorbarContainer
from matplotlib.dates import DateConverter, num2date

from aerosandbox.tools.pretty_plots.labellines.utils import ensure_float, maximum_bipartite_matching, always_iterable


# Label line with line2D label data
[docs]def labelLine( line, x, label=None, align=True, drop_label=False, yoffset=0, yoffset_logspace=False, outline_color="auto", outline_width=8, **kwargs, ): """Label a single matplotlib line at position x Parameters ---------- line : matplotlib.lines.Line The line holding the label x : number The location in data unit of the label label : string, optional The label to set. This is inferred from the line by default drop_label : bool, optional If True, the label is consumed by the function so that subsequent calls to e.g. legend do not use it anymore. yoffset : double, optional Space to add to label's y position yoffset_logspace : bool, optional If True, then yoffset will be added to the label's y position in log10 space outline_color : None | "auto" | color Colour of the outline. If set to "auto", use the background color. If set to None, do not draw an outline. outline_width : number Width of the outline kwargs : dict, optional Optional arguments passed to ax.text """ ax = line.axes xdata = ensure_float(line.get_xdata()) ydata = line.get_ydata() mask = np.isfinite(ydata) if mask.sum() == 0: raise Exception(f"The line {line} only contains nan!") # Find first segment of xdata containing x if isinstance(xdata, tuple) and len(xdata) == 2: i = 0 xa = min(xdata) xb = max(xdata) else: for imatch, (xa, xb) in enumerate(zip(xdata[:-1], xdata[1:])): if min(xa, xb) <= ensure_float(x) <= max(xa, xb): i = imatch break else: raise Exception("x label location is outside data range!") xfa = ensure_float(xa) xfb = ensure_float(xb) ya = ydata[i] yb = ydata[i + 1] # Handle vertical case if xfb == xfa: fraction = 0.5 else: fraction = (ensure_float(x) - xfa) / (xfb - xfa) if yoffset_logspace: y = ya + (yb - ya) * fraction y *= 10 ** yoffset else: y = ya + (yb - ya) * fraction + yoffset if not (np.isfinite(ya) and np.isfinite(yb)): warnings.warn( ( "%s could not be annotated due to `nans` values. " "Consider using another location via the `x` argument." ) % line, UserWarning, ) return if not label: label = line.get_label() if drop_label: line.set_label(None) if align: if ax.get_aspect() == "auto": # Compute the slope and label rotation screen_dx, screen_dy = ( ax.transData.transform((xfb, yb)) - ax.transData.transform((xfa, ya)) ) elif isinstance(ax.get_aspect(), (float, int)): screen_dx = xfb - xfa screen_dy = (yb - ya) * ax.get_aspect() rotation = (degrees(atan2(screen_dy, screen_dx)) + 90) % 180 - 90 else: rotation = 0 # Set a bunch of keyword arguments if "color" not in kwargs: kwargs["color"] = line.get_color() if ("horizontalalignment" not in kwargs) and ("ha" not in kwargs): kwargs["ha"] = "center" if ("verticalalignment" not in kwargs) and ("va" not in kwargs): kwargs["va"] = "center" if "clip_on" not in kwargs: kwargs["clip_on"] = True if "zorder" not in kwargs: kwargs["zorder"] = 2.5 if outline_color == "auto": outline_color = ax.get_facecolor() txt = ax.text(x, y, label, rotation=rotation, **kwargs) if outline_color is None: effects = [path_effects.Normal()] else: effects = [ path_effects.Stroke(linewidth=outline_width, foreground=outline_color), path_effects.Normal(), ] txt.set_path_effects(effects) return txt
[docs]def labelLines( lines, align=True, xvals=None, drop_label=False, shrink_factor=0.05, yoffsets=0, outline_color="auto", outline_width=5, **kwargs, ): """Label all lines with their respective legends. Parameters ---------- lines : list of matplotlib lines The lines to label align : boolean, optional If True, the label will be aligned with the slope of the line at the location of the label. If False, they will be horizontal. xvals : (xfirst, xlast) or array of float, optional The location of the labels. If a tuple, the labels will be evenly spaced between xfirst and xlast (in the axis units). drop_label : bool, optional If True, the label is consumed by the function so that subsequent calls to e.g. legend do not use it anymore. shrink_factor : double, optional Relative distance from the edges to place closest labels. Defaults to 0.05. yoffsets : number or list, optional. Distance relative to the line when positioning the labels. If given a number, the same value is used for all lines. outline_color : None | "auto" | color Colour of the outline. If set to "auto", use the background color. If set to None, do not draw an outline. outline_width : number Width of the outline kwargs : dict, optional Optional arguments passed to ax.text """ ax = lines[0].axes handles, allLabels = ax.get_legend_handles_labels() all_lines = [] for h in handles: if isinstance(h, ErrorbarContainer): all_lines.append(h.lines[0]) else: all_lines.append(h) # In case no x location was provided, we need to use some heuristics # to generate them. if xvals is None: xvals = ax.get_xlim() xvals_rng = xvals[1] - xvals[0] shrinkage = xvals_rng * shrink_factor xvals = (xvals[0] + shrinkage, xvals[1] - shrinkage) if isinstance(xvals, tuple) and len(xvals) == 2: xmin, xmax = xvals xscale = ax.get_xscale() if xscale == "log": xvals = np.logspace(np.log10(xmin), np.log10(xmax), len(all_lines) + 2)[ 1:-1 ] else: xvals = np.linspace(xmin, xmax, len(all_lines) + 2)[1:-1] # Build matrix line -> xvalue ok_matrix = np.zeros((len(all_lines), len(all_lines)), dtype=bool) for i, line in enumerate(all_lines): xdata = ensure_float(line.get_xdata()) minx, maxx = min(xdata), max(xdata) for j, xv in enumerate(xvals): ok_matrix[i, j] = minx < xv < maxx # If some xvals do not fall in their corresponding line, # find a better matching using maximum bipartite matching. if not np.all(np.diag(ok_matrix)): order = maximum_bipartite_matching(ok_matrix) # The maximum match may miss a few points, let's add them back imax = order.max() order[order < 0] = np.arange(imax + 1, len(order)) # Now reorder the xvalues old_xvals = xvals.copy() xvals[order] = old_xvals else: xvals = list(always_iterable(xvals)) # force the creation of a copy labLines, labels = [], [] # Take only the lines which have labels other than the default ones for i, (line, xv) in enumerate(zip(all_lines, xvals)): label = allLabels[all_lines.index(line)] labLines.append(line) labels.append(label) # Move xlabel if it is outside valid range xdata = ensure_float(line.get_xdata()) if not (min(xdata) <= xv <= max(xdata)): warnings.warn( ( "The value at position %s in `xvals` is outside the range of its " "associated line (xmin=%s, xmax=%s, xval=%s). Clipping it " "into the allowed range." ) % (i, min(xdata), max(xdata), xv), UserWarning, ) new_xv = min(xdata) + (max(xdata) - min(xdata)) * 0.9 xvals[i] = new_xv # Convert float values back to datetime in case of datetime axis if isinstance(ax.xaxis.converter, DateConverter): xvals = [num2date(x).replace(tzinfo=ax.xaxis.get_units()) for x in xvals] txts = [] try: yoffsets = [float(yoffsets)] * len(all_lines) except TypeError: pass for line, x, yoffset, label in zip(labLines, xvals, yoffsets, labels): txts.append( labelLine( line, x, label=label, align=align, drop_label=drop_label, yoffset=yoffset, outline_color=outline_color, outline_width=outline_width, **kwargs, ) ) return txts
if __name__ == '__main__': import matplotlib.pyplot as plt fig, ax = plt.subplots()
[docs] x_plt = np.linspace(0, 2, 1000)
y_plt = 0.1 * np.sin(2 * np.pi * x_plt) line, = plt.plot(x_plt, y_plt, label="hi") # plt.axis("equal") # print(ax.get_aspect()) labelLine(line, x=1) plt.show() x = 1 label = None align = True drop_label = False yoffset = 0 yoffset_logspace = False outline_color = "auto" outline_width = 8