"""2D visualization primitives based on Matplotlib.

1) Plot images with `plot_images`.
2) Call functions to plot heatmaps, vector fields, and horizon lines.
3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
"""

import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import torch

from geocalib.perspective_fields import get_perspective_field
from geocalib.utils import rad2deg

# mypy: ignore-errors


def _to_numpy(x):
    """Convert a torch tensor to a NumPy array; return any other input unchanged.

    Tensors are detached first so that tensors which require grad can be plotted.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return x


def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True):
    """Plot a list of images side by side in a single row.

    Args:
        imgs (List[np.ndarray]): List of images to plot.
        titles (List[str], optional): Titles. Defaults to None.
        cmaps (str | List[str], optional): Colormap, global or one per image.
            Defaults to "gray".
        dpi (int, optional): Dots per inch. Defaults to 200.
        pad (float, optional): Padding. Defaults to 0.5.
        adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.

    Returns:
        plt.Figure: Figure of the images.
    """
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    # Width ratios follow each image's W / H so that images are not distorted.
    ratios = [i.shape[1] / i.shape[0] for i in imgs] if adaptive else [4 / 3] * n
    figsize = [sum(ratios) * 4.5, 4.5]
    fig, axs = plt.subplots(1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios})
    if n == 1:
        axs = [axs]  # plt.subplots returns a bare Axes for a single subplot
    for i, (img, ax) in enumerate(zip(imgs, axs)):
        ax.imshow(img, cmap=plt.get_cmap(cmaps[i]))
        ax.set_axis_off()
        if titles:
            ax.set_title(titles[i])
    fig.tight_layout(pad=pad)
    return fig


def plot_image_grid(
    imgs,
    titles=None,
    cmaps="gray",
    dpi=100,
    pad=0.5,
    fig=None,
    adaptive=True,
    figs=3.0,
    return_fig=False,
    set_lim=False,
):
    """Plot a grid of images.

    Args:
        imgs (List[List[np.ndarray]]): Rows of images to plot; every row is expected
            to hold as many images as the first row.
        titles (List[List[str]], optional): Titles, indexed as titles[row][col].
            Defaults to None.
        cmaps (str | List[str], optional): Colormap, global or one per column.
            Defaults to "gray".
        dpi (int, optional): Dots per inch. Defaults to 100.
        pad (float, optional): Padding. Defaults to 0.5.
        fig (plt.Figure, optional): Figure (or subfigure) to plot on. Defaults to None.
        adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
        figs (float, optional): Figure size per grid cell. Defaults to 3.0.
        return_fig (bool, optional): Whether to return the figure. Defaults to False.
        set_lim (bool, optional): Whether to set the limits. Defaults to False.

    Returns:
        (fig, axs) if `return_fig` else `axs`, where `axs[row][col]` is an Axes.
    """
    nr, n = len(imgs), len(imgs[0])
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    if adaptive:
        ratios = [i.shape[1] / i.shape[0] for i in imgs[0]]  # W / H
    else:
        ratios = [4 / 3] * n

    figsize = [sum(ratios) * figs, nr * figs]
    if fig is None:
        fig, axs = plt.subplots(
            nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
        )
    else:
        axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios})
        fig.figure.set_size_inches(figsize)

    # Normalize the squeezed return of subplots() so that axs[j][i] always works.
    if nr == 1 and n == 1:
        axs = [[axs]]
    elif n == 1:
        axs = axs[:, None]
    elif nr == 1:
        axs = [axs]

    for j in range(nr):
        for i in range(n):
            ax = axs[j][i]
            ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
            ax.set_axis_off()
            if set_lim:
                ax.set_xlim([0, imgs[j][i].shape[1]])
                ax.set_ylim([imgs[j][i].shape[0], 0])
            if titles:
                ax.set_title(titles[j][i])
    if isinstance(fig, plt.Figure):
        fig.tight_layout(pad=pad)
    return (fig, axs) if return_fig else axs


def add_text(
    idx,
    text,
    pos=(0.01, 0.99),
    fs=15,
    color="w",
    lcolor="k",
    lwidth=4,
    ha="left",
    va="top",
    axes=None,
    **kwargs,
):
    """Add text to a plot.

    Args:
        idx (int): Index of the axes.
        text (str): Text to add.
        pos (tuple, optional): Text position in axes coordinates. Defaults to (0.01, 0.99).
        fs (int, optional): Font size. Defaults to 15.
        color (str, optional): Text color. Defaults to "w".
        lcolor (str, optional): Outline color; None disables the outline. Defaults to "k".
        lwidth (int, optional): Outline width. Defaults to 4.
        ha (str, optional): Horizontal alignment. Defaults to "left".
        va (str, optional): Vertical alignment. Defaults to "top".
        axes (List[plt.Axes], optional): Axes to put text on. Defaults to the axes
            of the current figure.
        **kwargs: Forwarded to `Axes.text`.

    Returns:
        plt.Text: Text object.
    """
    if axes is None:
        axes = plt.gcf().axes

    ax = axes[idx]
    t = ax.text(
        *pos,
        text,
        fontsize=fs,
        ha=ha,
        va=va,
        color=color,
        transform=ax.transAxes,
        zorder=5,
        **kwargs,
    )
    if lcolor is not None:
        t.set_path_effects(
            [
                path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
                path_effects.Normal(),
            ]
        )
    return t


def _add_labeled_contours(ax, heatmap, vmin, vmax, cmap, every, style):
    """Draw contour lines every `every` units, colored by `cmap`, with white degree labels."""
    # Fall back to the data range when a bound is not given (np.arange cannot take None).
    lo = vmin if vmin is not None else float(np.nanmin(heatmap))
    hi = vmax if vmax is not None else float(np.nanmax(heatmap))
    levels = np.arange(lo, hi + every, every)

    # Color each contour line the way the colormap colors its level.
    colormap = plt.get_cmap(cmap)
    norm = plt.Normalize(vmin=lo, vmax=hi)
    colors = [colormap(norm(level)) for level in levels]

    contours = ax.contour(heatmap, levels=levels, linewidths=2, colors=colors, linestyles=style)
    contours.set_clim(lo, hi)

    # Round rather than truncate: the default vmin=-1e-6 yields levels such as 14.999999.
    labels = ax.clabel(
        contours,
        inline=True,
        fmt=lambda level: f"{int(np.round(level))}°",
        fontsize=16,
        colors="white",
    )
    for label in labels:
        label.set_path_effects(
            [
                path_effects.Stroke(linewidth=1, foreground="k"),
                path_effects.Normal(),
            ]
        )
    return contours


def plot_heatmaps(
    heatmaps,
    vmin=-1e-6,  # include negative zero
    vmax=None,
    cmap="Spectral",
    a=0.5,
    axes=None,
    contours_every=None,
    contour_style="solid",
    colorbar=False,
):
    """Plot heatmaps with optional contours.

    To plot latitude field, set vmin=-90, vmax=90 and contours_every=15.

    Args:
        heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps, one per axes.
            The list itself is not modified.
        vmin (float, optional): Min value. Defaults to -1e-6.
        vmax (float, optional): Max value. Defaults to None.
        cmap (str, optional): Colormap. Defaults to "Spectral".
        a (float | List[float], optional): Alpha value, global or one per heatmap.
            Defaults to 0.5.
        axes (List[plt.Axes], optional): Axes to plot on. Defaults to the axes of
            the current figure.
        contours_every (float, optional): If not None, draw labeled contours at this
            spacing between vmin and vmax (the heatmap's min/max is used for a bound
            that is None). Defaults to None.
        contour_style (str, optional): Style of the contours. Defaults to "solid".
        colorbar (bool, optional): Whether to show colorbar. Defaults to False.

    Returns:
        List[plt.Artist]: List of artists.
    """
    if axes is None:
        axes = plt.gcf().axes

    artists = []
    for i, ax in enumerate(axes):
        # Convert into a local so the caller's list is left untouched.
        heatmap = _to_numpy(heatmaps[i])
        alpha = a if a is None or np.isscalar(a) else a[i]

        art = ax.imshow(heatmap, alpha=alpha, vmin=vmin, vmax=vmax, cmap=cmap)
        if colorbar:
            # `vmax or ...` would wrongly discard an explicit vmax of 0.
            cmax = vmax if vmax is not None else np.percentile(heatmap, 99)
            art.set_clim(vmin, cmax)
            artists.append(plt.colorbar(art, ax=ax))
        artists.append(art)

        if contours_every is not None:
            artists.append(
                _add_labeled_contours(ax, heatmap, vmin, vmax, cmap, contours_every, contour_style)
            )
    return artists


def plot_horizon_lines(
    cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None
):
    """Plot horizon lines on the perspective field.

    The horizon is drawn as the zero level of the latitude field returned by
    `get_perspective_field`.

    Args:
        cameras (List[Camera]): List of cameras.
        gravities (List[Gravity]): Gravities.
        line_colors (str | List, optional): Line color, global or one per camera.
            Defaults to "orange".
        lw (int, optional): Line width. Defaults to 2.
        styles (str | List, optional): Line style, global or one per camera.
            Defaults to "solid".
        alpha (float | List[float], optional): Alpha, global or one per camera.
            Defaults to 1.0.
        ax (plt.Axes | List[plt.Axes], optional): Axes to draw horizon line on.
            Defaults to the current axes.

    Returns:
        List: One contour set per camera.
    """
    n = len(cameras)
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * n
    if not isinstance(styles, list):
        styles = [styles] * n
    if not isinstance(alpha, (list, tuple, np.ndarray)):
        alpha = [alpha] * n

    if ax is None:
        ax = plt.gcf().gca()
    if isinstance(ax, plt.Axes):
        ax = [ax] * n
    assert len(ax) == n, f"{len(ax)}, {n}"

    artists = []
    for i in range(n):
        _, lat = get_perspective_field(cameras[i], gravities[i])
        # horizon line is zero level of the latitude field
        lat = _to_numpy(lat[0, 0])
        # Pass the style to contour() directly: ContourSet.collections was removed
        # in matplotlib 3.10. Wrapping in a list also allows dash-tuple styles.
        contours = ax[i].contour(
            lat,
            levels=[0],
            linewidths=lw,
            colors=line_colors[i],
            linestyles=[styles[i]],
            alpha=alpha[i],
        )
        artists.append(contours)
    return artists


def plot_vector_fields(
    vector_fields,
    cmap="lime",
    subsample=15,
    scale=None,
    lw=None,
    alphas=0.8,
    axes=None,
):
    """Plot vector fields.

    Args:
        vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W);
            all fields are sampled on the grid of the first one.
        cmap (str | List[np.ndarray], optional): Color of the vectors, or one
            per-pixel color array per field. Defaults to "lime".
        subsample (int, optional): Approximate number of arrows along the shorter
            image side. Defaults to 15.
        scale (float, optional): Scale of the vectors. Defaults to None.
        lw (float, optional): Line width of the vectors. Defaults to None.
        alphas (float | np.ndarray, optional): Alpha per vector or global. Defaults to 0.8.
        axes (List[plt.Axes], optional): List of axes to draw on. Defaults to the
            axes of the current figure.

    Returns:
        List[plt.Artist]: List of artists.
    """
    if axes is None:
        axes = plt.gcf().axes

    vector_fields = [_to_numpy(v) for v in vector_fields]
    H, W = vector_fields[0].shape[-2:]
    short_side = min(H, W)

    if scale is None:
        scale = subsample / short_side
    if lw is None:
        lw = 0.1 / subsample

    alphas = 1.0 if alphas is None else _to_numpy(alphas)
    if np.ndim(alphas) == 0:
        # Broadcast a global alpha to one (H, W) map per vector field.
        alphas = np.full((len(vector_fields), H, W), float(alphas))
    else:
        alphas = np.asarray(alphas)

    # Pixel stride between arrows; at least 1 so small fields do not yield a zero step.
    step = max(short_side // subsample, 1)
    offset_x = ((W % step) + step) // 2
    samples_x = np.arange(offset_x, W, step)
    samples_y = np.arange(int(step * 0.9), H, step)
    x_grid, y_grid = np.meshgrid(samples_x, samples_y)

    units = "width" if H > W else "height"
    artists = []
    for i, ax in enumerate(axes):
        # vector field of shape (2, H, W) with vectors of norm == 1
        u, v = vector_fields[i][:, samples_y][:, :, samples_x]
        alpha = alphas[i][samples_y][:, samples_x]

        color = cmap
        if not isinstance(cmap, str):
            color = cmap[i][samples_y][:, samples_x].reshape(-1, 3)

        arrows = ax.quiver(
            x_grid,
            y_grid,
            u,
            v,
            scale=scale * short_side,
            scale_units=units,
            units=units,
            alpha=alpha,
            color=color,
            angles="xy",
            antialiased=True,
            width=lw,
            headaxislength=3.5,
            zorder=5,
        )
        artists.append(arrows)
    return artists


def plot_latitudes(
    latitude,
    is_radians=True,
    vmin=-90,
    vmax=90,
    cmap="seismic",
    contours_every=15,
    alpha=0.4,
    axes=None,
    **kwargs,
):
    """Plot latitudes as a heatmap with labeled contours.

    Args:
        latitude (List[torch.Tensor]): List of latitudes.
        is_radians (bool, optional): Whether the latitudes are in radians (they are
            converted with `rad2deg`). Defaults to True.
        vmin (int, optional): Min value to clip to. Defaults to -90.
        vmax (int, optional): Max value to clip to. Defaults to 90.
        cmap (str, optional): Colormap. Defaults to "seismic".
        contours_every (int, optional): Contours every. Defaults to 15.
        alpha (float, optional): Alpha value. Defaults to 0.4.
        axes (List[plt.Axes], optional): Axes to plot on. Defaults to the axes of
            the current figure.
        **kwargs: Forwarded to `plot_heatmaps`.

    Returns:
        List[plt.Artist]: List of artists.
    """
    if axes is None:
        axes = plt.gcf().axes

    assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}"
    lat = [rad2deg(lat) for lat in latitude] if is_radians else latitude
    return plot_heatmaps(
        lat,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        a=alpha,
        axes=axes,
        contours_every=contours_every,
        **kwargs,
    )


def plot_perspective_fields(cameras, gravities, axes=None, **kwargs):
    """Plot perspective fields (up-vector field and latitude field) per camera.

    Args:
        cameras (List[Camera]): List of cameras.
        gravities (List[Gravity]): List of gravities.
        axes (List[plt.Axes], optional): Axes to plot on. Defaults to the axes of
            the current figure.
        **kwargs: `cmap`, `subsample`, `scale`, `lw` and `alphas` are forwarded to
            `plot_vector_fields`; all other keywords go to `plot_latitudes`.

    Returns:
        List[plt.Artist]: List of artists.
    """
    if axes is None:
        axes = plt.gcf().axes

    assert len(axes) == len(cameras), f"{len(axes)}, {len(cameras)}"

    # The two plotting functions take different keywords; forwarding the same
    # kwargs to both would raise TypeError for any keyword.
    vector_keys = {"cmap", "subsample", "scale", "lw", "alphas"}
    vector_kwargs = {k: v for k, v in kwargs.items() if k in vector_keys}
    latitude_kwargs = {k: v for k, v in kwargs.items() if k not in vector_keys}

    artists = []
    for i in range(len(axes)):
        up, lat = get_perspective_field(cameras[i], gravities[i])
        artists += plot_vector_fields([up[0]], axes=[axes[i]], **vector_kwargs)
        artists += plot_latitudes([lat[0, 0]], axes=[axes[i]], **latitude_kwargs)
    return artists


def plot_confidences(
    confidence,
    as_log=True,
    vmin=-4,
    vmax=0,
    cmap="turbo",
    alpha=0.4,
    axes=None,
    **kwargs,
):
    """Plot confidences, min-max normalized to [0, 1] per map.

    Args:
        confidence (List[torch.Tensor]): Confidence maps.
        as_log (bool, optional): Whether to plot in log10 scale. Defaults to True.
        vmin (int, optional): Min log value to clip to (used when as_log). Defaults to -4.
        vmax (int, optional): Max log value to clip to (used when as_log). Defaults to 0.
        cmap (str, optional): Colormap. Defaults to "turbo".
        alpha (float, optional): Alpha value. Defaults to 0.4.
        axes (List[plt.Axes], optional): Axes to plot on. Defaults to the axes of
            the current figure.
        **kwargs: Forwarded to `plot_heatmaps`.

    Returns:
        List[plt.Artist]: List of artists.
    """
    if axes is None:
        axes = plt.gcf().axes

    assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}"

    if as_log:
        # torch.as_tensor also accepts NumPy arrays; tensors are passed through as-is.
        confidence = [
            torch.log10(torch.as_tensor(c).clip(1e-5)).clip(vmin, vmax) for c in confidence
        ]

    # normalize to [0, 1]; a constant map has no range and becomes 0 instead of NaN
    normalized = []
    for c in confidence:
        c_min, c_max = c.min(), c.max()
        span = c_max - c_min
        normalized.append((c - c_min) / span if span > 0 else c * 0)
    return plot_heatmaps(normalized, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs)


def save_plot(path, **kw):
    """Save the current figure without any white margin."""
    plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)