Source code for brutus.plotting.corner

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Corner plot visualization functions.

This module provides functions for creating corner plots of multi-dimensional
posterior distributions.
"""

import copy
import warnings
from functools import partial

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator, NullLocator, ScalarFormatter
from scipy.ndimage import gaussian_filter as norm_kde
from scipy.special import logsumexp

from ..priors import logp_galactic_structure as gal_lnprior
from ..priors import logp_parallax
from ..utils.sampling import draw_sar, quantile
from .utils import hist2d

# Public API: the only name exported by a star-import of this module.
__all__ = ["cornerplot"]


def _draw_truth_lines(draw_line, truth, color, line_kwargs):
    """Draw one reference line per value in `truth` (scalar or iterable)."""
    try:
        for value in truth:
            draw_line(value, color=color, **line_kwargs)
    except TypeError:
        # `truth` is a scalar (not iterable): draw a single line.
        draw_line(truth, color=color, **line_kwargs)


def cornerplot(
    idxs,
    data,
    params,
    lndistprior=None,
    coord=None,
    avlim=(0.0, 6.0),
    rvlim=(1.0, 8.0),
    weights=None,
    parallax=None,
    parallax_err=None,
    Nr=500,
    applied_parallax=True,
    pcolor="blue",
    parallax_kwargs=None,
    span=None,
    quantiles=[0.025, 0.5, 0.975],
    color="black",
    smooth=10,
    hist_kwargs=None,
    hist2d_kwargs=None,
    labels=None,
    label_kwargs=None,
    show_titles=False,
    title_fmt=".2f",
    title_kwargs=None,
    title_quantiles=[0.025, 0.5, 0.975],
    truths=None,
    truth_color="red",
    truth_kwargs=None,
    max_n_ticks=5,
    top_ticks=False,
    use_math_text=False,
    verbose=False,
    fig=None,
    rstate=None,
    R_solar=8.2,
    Z_solar=0.025,
):
    """
    Generate a corner plot of the 1-D and 2-D marginalized posteriors.

    The plotted dimensions are the fields of `params` selected by `labels`,
    followed by four extra dimensions that are always appended in this
    order: Av, Rv, Parallax, Distance.

    Parameters
    ----------
    idxs : `~numpy.ndarray` of shape `(Nsamps)`
        An array of resampled indices corresponding to the set of models used
        to fit the data.

    data : 3-tuple or 4-tuple containing `~numpy.ndarray`s of shape `(Nsamps)`
        The data that will be plotted. Either a collection of
        `(dists, reds, dreds)` that were saved, or a collection of
        `(scales, avs, rvs, covs_sar)` that will be used to regenerate
        `(dists, reds, dreds)` in conjunction with any applied distance
        and/or parallax priors.

    params : structured `~numpy.ndarray` with shape `(Nmodels,)`
        Set of parameters corresponding to the input set of models. Note that
        `'agewt'` will always be ignored.

    lndistprior : func, optional
        The log-distance prior function used. It is called as
        `lndistprior(dists, coord)`. If not provided, the galactic model
        from Green et al. (2014) will be assumed.

    coord : 2-tuple, optional
        The galactic `(l, b)` coordinates for the object, which is passed to
        `lndistprior`. Required when `data` is a 4-tuple and the default
        distance prior is used.

    avlim : 2-tuple, optional
        The Av limits used to truncate results. Default is `(0., 6.)`.

    rvlim : 2-tuple, optional
        The Rv limits used to truncate results. Default is `(1., 8.)`.

    weights : `~numpy.ndarray` of shape `(Nsamps)`, optional
        An optional set of importance weights used to reweight the samples.
        Any 1-D sequence is accepted. Defaults to uniform weights.

    parallax : float, optional
        The parallax estimate for the source.

    parallax_err : float, optional
        The parallax error.

    Nr : int, optional
        The number of Monte Carlo realizations used when sampling using the
        provided parallax prior. Default is `500`.

    applied_parallax : bool, optional
        Whether the parallax was applied when initially computing the fits.
        When `True`, both `parallax` and `parallax_err` must be provided.
        Default is `True`.

    pcolor : str, optional
        Color used when plotting the parallax prior. Default is `'blue'`.

    parallax_kwargs : kwargs, optional
        Keyword arguments used when plotting the parallax prior passed to
        `fill_between`.

    span : iterable with shape `(ndim,)`, optional
        A list where each element is either a length-2 tuple containing
        lower and upper bounds or a float from `(0., 1.]` giving the
        fraction of (weighted) samples to include. If a fraction is provided,
        the bounds are chosen to be equal-tailed. An example would be::

            span = [(0., 10.), 0.95, (5., 6.)]

        Default is `0.99` (99% credible interval).

    quantiles : iterable, optional
        A list of fractional quantiles to overplot on the 1-D marginalized
        posteriors as vertical dashed lines. Default is `[0.025, 0.5, 0.975]`
        (spanning the 95%/2-sigma credible interval).

    color : str, optional
        A `~matplotlib`-style color used when plotting the histograms. The
        value is forwarded unchanged to every panel (it is not indexed per
        dimension). Default is `'black'`.

    smooth : float or iterable with shape `(ndim,)`, optional
        The standard deviation (either a single value or a different value for
        each subplot) for the Gaussian kernel used to smooth the 1-D and 2-D
        marginalized posteriors, expressed as a fraction of the span.
        If an integer is provided instead, a simple (weighted) histogram
        with `bins=smooth` is plotted. Default is `10` (10 bins).

    hist_kwargs : dict, optional
        Extra keyword arguments to send to the 1-D (smoothed) histograms.

    hist2d_kwargs : dict, optional
        Extra keyword arguments to send to the 2-D (smoothed) histograms.

    labels : iterable with shape `(ndim,)`, optional
        A list of names for each parameter. The names are also used to select
        the corresponding fields of `params`. If not provided, the names will
        be taken from `params.dtype.names`.

    label_kwargs : dict, optional
        Extra keyword arguments that will be sent to the
        `~matplotlib.axes.Axes.set_xlabel` and
        `~matplotlib.axes.Axes.set_ylabel` methods.

    show_titles : bool, optional
        Whether to display a title above each 1-D marginalized posterior
        showing the quantiles specified by `title_quantiles`. By default,
        this will show the median (0.5 quantile) along with the upper/lower
        bounds associated with the 0.025 and 0.975 (95%/2-sigma credible
        interval) quantiles. Default is `False`.

    title_fmt : str, optional
        The format string for the quantiles provided in the title. Default is
        `'.2f'`.

    title_kwargs : dict, optional
        Extra keyword arguments that will be sent to the
        `~matplotlib.axes.Axes.set_title` command.

    title_quantiles : iterable, optional
        A list of 3 fractional quantiles displayed in the title, ordered
        from lowest to highest. Default is `[0.025, 0.5, 0.975]`
        (spanning the 95%/2-sigma credible interval).

    truths : iterable with shape `(ndim,)`, optional
        A list of reference values that will be overplotted on the traces and
        marginalized 1-D posteriors as solid horizontal/vertical lines.
        Each entry may be a scalar or an iterable of values. Individual
        values can be exempt using `None`. Default is `None`.

    truth_color : str, optional
        A `~matplotlib`-style color used when plotting `truths`. The value is
        forwarded unchanged to every line. Default is `'red'`.

    truth_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting the vertical
        and horizontal lines with `truths`.

    max_n_ticks : int, optional
        Maximum number of ticks allowed. Default is `5`.

    top_ticks : bool, optional
        Whether to label the top (rather than bottom) ticks. Default is
        `False`.

    use_math_text : bool, optional
        Whether the axis tick labels for very large/small exponents should be
        displayed as powers of 10 rather than using `e`. Default is `False`.

    verbose : bool, optional
        Whether to print the values of the computed quantiles associated with
        each parameter. Default is `False`.

    fig : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`), optional
        If provided, overplot the traces and marginalized 1-D posteriors
        onto the provided figure. Otherwise, by default an
        internal figure is generated.

    rstate : `~numpy.random.RandomState`, optional
        `~numpy.random.RandomState` instance. Defaults to `numpy.random`.

    R_solar : float, optional
        Forwarded as the `R_solar` keyword of the default galactic distance
        prior. Only used when `lndistprior` is not provided.
        Default is `8.2`.

    Z_solar : float, optional
        Forwarded as the `Z_solar` keyword of the default galactic distance
        prior. Only used when `lndistprior` is not provided.
        Default is `0.025`.

    Returns
    -------
    cornerplot : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`)
        Output corner plot.

    Raises
    ------
    ValueError
        If `applied_parallax` is set but `parallax` or `parallax_err` is
        missing, if `coord` is missing while the default distance prior must
        be re-applied, if `weights` is not 1-D or disagrees with the number
        of samples, if `span` does not have one entry per dimension, or if
        the axes in `fig` cannot be reshaped to `(ndim, ndim)`.

    Notes
    -----
    The keyword dictionaries passed in (`hist_kwargs`, `hist2d_kwargs`,
    `truth_kwargs`, `label_kwargs`, `title_kwargs`, `parallax_kwargs`) and
    `labels` are copied before defaults are filled in, so the caller's
    objects are never modified.

    """

    # Initialize values. The list defaults in the signature are only read,
    # never mutated, so sharing them between calls is harmless.
    if quantiles is None:
        quantiles = []

    # Shallow-copy every kwargs dict: defaults are injected below and the
    # caller's dicts must not change between calls.
    truth_kwargs = {} if truth_kwargs is None else dict(truth_kwargs)
    label_kwargs = {} if label_kwargs is None else dict(label_kwargs)
    title_kwargs = {} if title_kwargs is None else dict(title_kwargs)
    hist_kwargs = {} if hist_kwargs is None else dict(hist_kwargs)
    hist2d_kwargs = {} if hist2d_kwargs is None else dict(hist2d_kwargs)
    parallax_kwargs = {} if parallax_kwargs is None else dict(parallax_kwargs)

    if weights is None:
        weights = np.ones_like(idxs, dtype="float")
    else:
        # Accept plain sequences; the checks below rely on array attributes.
        weights = np.asarray(weights)
    if rstate is None:
        rstate = np.random
    if applied_parallax and (parallax is None or parallax_err is None):
        raise ValueError(
            "`parallax` and `parallax_err` must be provided " "together."
        )
    if lndistprior is None:
        lndistprior = partial(gal_lnprior, R_solar=R_solar, Z_solar=Z_solar)

    # Set defaults.
    hist_kwargs.setdefault("alpha", 0.6)
    hist2d_kwargs.setdefault("alpha", 0.6)
    truth_kwargs.setdefault("linestyle", "solid")
    truth_kwargs.setdefault("linewidth", 2)
    truth_kwargs.setdefault("alpha", 0.7)
    parallax_kwargs.setdefault("alpha", 0.3)

    # Ignore age weights.
    if labels is None:
        labels = [name for name in params.dtype.names if name != "agewt"]
    else:
        # Copy any caller-supplied labels: the Av/Rv/Parallax/Distance appends
        # below mutate this list in place, which would otherwise corrupt the
        # caller's list and break reuse across repeated calls.
        labels = list(labels)

    # Deal with 1D results.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # ignore bad values
        samples = params[idxs]
        samples = np.array([samples[lbl] for lbl in labels]).T
    samples = np.atleast_1d(samples)
    if len(samples.shape) == 1:
        samples = np.atleast_2d(samples)
    else:
        assert len(samples.shape) == 2, "Samples must be 1- or 2-D."
        samples = samples.T
    assert samples.shape[0] <= samples.shape[1], (
        "There are more " "dimensions than samples!"
    )

    try:
        # Grab distance and reddening samples.
        ddraws, adraws, rdraws = copy.deepcopy(data)
        pdraws = 1.0 / ddraws
    except (ValueError, TypeError):
        # Regenerate distance and reddening samples from inputs.
        scales, avs, rvs, covs_sar = copy.deepcopy(data)

        # The default prior may arrive bare or wrapped in a `partial`.
        _is_default_prior = lndistprior is gal_lnprior or (
            hasattr(lndistprior, "func") and lndistprior.func is gal_lnprior
        )
        if _is_default_prior and coord is None:
            raise ValueError(
                "`coord` must be passed if the default distance "
                "prior was used."
            )

        # Add in scale/parallax/distance, Av, and Rv realizations.
        nsamps = len(idxs)
        sdraws, adraws, rdraws = draw_sar(
            scales,
            avs,
            rvs,
            covs_sar,
            ndraws=Nr,
            avlim=avlim,
            rvlim=rvlim,
            rstate=rstate,
        )
        pdraws = np.sqrt(sdraws)
        ddraws = 1.0 / pdraws

        # Re-apply distance and parallax priors to realizations.
        lnp_draws = lndistprior(ddraws, coord)
        if applied_parallax:
            lnp_draws += logp_parallax(pdraws, parallax, parallax_err)

        # Resample draws: pick one realization per sample, with probability
        # proportional to its prior weight.
        lnp = logsumexp(lnp_draws, axis=1)
        pwt = np.exp(lnp_draws - lnp[:, None])
        pwt /= pwt.sum(axis=1)[:, None]
        ridx = [rstate.choice(Nr, p=pwt[i]) for i in range(nsamps)]
        rows = np.arange(nsamps)
        pdraws = pdraws[rows, ridx]
        ddraws = ddraws[rows, ridx]
        adraws = adraws[rows, ridx]
        rdraws = rdraws[rows, ridx]

    # Append to samples.
    samples = np.c_[samples.T, adraws, rdraws, pdraws, ddraws].T
    ndim, nsamps = samples.shape

    # Check weights.
    if weights.ndim != 1:
        raise ValueError("Weights must be 1-D.")
    if nsamps != weights.shape[0]:
        raise ValueError("The number of weights and samples disagree!")

    # Determine plotting bounds.
    if span is None:
        span = [0.99 for _ in range(ndim)]
    span = list(span)
    if len(span) != ndim:
        raise ValueError("Dimension mismatch between samples and span.")
    for i, bounds in enumerate(span):
        try:
            # Explicit (lower, upper) bounds are kept untouched.
            _lower, _upper = bounds
        except (TypeError, ValueError):
            # A fraction: convert to equal-tailed weighted quantile bounds.
            q = [0.5 - 0.5 * bounds, 0.5 + 0.5 * bounds]
            span[i] = quantile(samples[i], q, weights=weights)

    # Append additional labels for extra dimensions
    labels.extend(["Av", "Rv", "Parallax", "Distance"])

    # Setting up smoothing.
    if isinstance(smooth, (int, float)):
        smooth = [smooth for _ in range(ndim)]

    # Setup axis layout (from `corner.py`).
    factor = 2.0  # size of side of one panel
    lbdim = 0.5 * factor  # size of left/bottom margin
    trdim = 0.2 * factor  # size of top/right margin
    whspace = 0.05  # size of width/height margin
    plotdim = factor * ndim + factor * (ndim - 1.0) * whspace  # plot size
    dim = lbdim + plotdim + trdim  # total size

    # Initialize figure.
    if fig is None:
        fig, axes = plt.subplots(ndim, ndim, figsize=(dim, dim))
    else:
        try:
            fig, axes = fig
            axes = np.array(axes).reshape((ndim, ndim))
        except (ValueError, TypeError) as err:
            raise ValueError("Mismatch between axes and dimension.") from err

    # Format figure.
    lb = lbdim / dim
    tr = (lbdim + plotdim) / dim
    fig.subplots_adjust(
        left=lb, bottom=lb, right=tr, top=tr, wspace=whspace, hspace=whspace
    )

    single_panel = np.shape(samples)[0] == 1

    # Plotting.
    for i, x in enumerate(samples):
        ax = axes if single_panel else axes[i, i]

        # Plot the 1-D marginalized posteriors.

        # Setup axes
        ax.set_xlim(span[i])
        if max_n_ticks == 0:
            ax.xaxis.set_major_locator(NullLocator())
            ax.yaxis.set_major_locator(NullLocator())
        else:
            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="lower"))
            ax.yaxis.set_major_locator(NullLocator())

        # Label axes.
        ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=use_math_text))
        if i < ndim - 1:
            if top_ticks:
                ax.xaxis.set_ticks_position("top")
                for tick in ax.get_xticklabels():
                    tick.set_rotation(45)
            else:
                ax.set_xticklabels([])
        else:
            for tick in ax.get_xticklabels():
                tick.set_rotation(45)
            ax.set_xlabel(labels[i], **label_kwargs)
            ax.xaxis.set_label_coords(0.5, -0.3)

        # Generate distribution.
        sx = smooth[i]
        if isinstance(sx, int):
            # If `sx` is an integer, plot a weighted histogram with
            # `sx` bins within the provided bounds.
            n, b, _ = ax.hist(
                x,
                bins=sx,
                weights=weights,
                color=color,
                range=np.sort(span[i]),
                **hist_kwargs,
            )
        else:
            # If `sx` is a float, oversample the data relative to the
            # smoothing filter by a factor of 10, then use a Gaussian
            # filter to smooth the results.
            bins = int(round(10.0 / sx))
            n, b = np.histogram(
                x, bins=bins, weights=weights, range=np.sort(span[i])
            )
            n = norm_kde(n, 10.0)
            b0 = 0.5 * (b[1:] + b[:-1])
            n, b, _ = ax.hist(
                b0,
                bins=b,
                weights=n,
                range=np.sort(span[i]),
                color=color,
                **hist_kwargs,
            )
        ax.set_ylim([0.0, max(n) * 1.05])

        # Plot quantiles.
        if quantiles is not None and len(quantiles) > 0:
            qs = quantile(x, quantiles, weights=weights)
            for q in qs:
                ax.axvline(q, lw=2, ls="dashed", color=color)
            if verbose:
                print("Quantiles:")
                print(labels[i], list(zip(quantiles, qs)))

        # Add truth value(s).
        if truths is not None and truths[i] is not None:
            _draw_truth_lines(ax.axvline, truths[i], truth_color, truth_kwargs)

        # Set titles.
        if show_titles and title_fmt is not None:
            ql, qm, qh = quantile(x, title_quantiles, weights=weights)
            q_minus, q_plus = qm - ql, qh - qm
            fmt = f"{{0:{title_fmt}}}".format
            title = rf"${{{fmt(qm)}}}_{{-{fmt(q_minus)}}}^{{+{fmt(q_plus)}}}$"
            title = f"{labels[i]} = {title}"
            ax.set_title(title, **title_kwargs)

        # Add parallax prior. Index `ndim - 2` is the appended "Parallax"
        # dimension; the prior is rescaled to the histogram's peak height.
        if i == ndim - 2 and parallax is not None and parallax_err is not None:
            parallax_logpdf = logp_parallax(b, parallax, parallax_err)
            parallax_pdf = np.exp(parallax_logpdf - max(parallax_logpdf))
            parallax_pdf *= max(n) / max(parallax_pdf)
            ax.fill_between(b, parallax_pdf, color=pcolor, **parallax_kwargs)

        for j, y in enumerate(samples):
            ax = axes if single_panel else axes[i, j]

            # Plot the 2-D marginalized posteriors.

            # Setup axes.
            if j > i:
                # Upper triangle is left blank.
                ax.set_frame_on(False)
                ax.set_xticks([])
                ax.set_yticks([])
                continue
            if j == i:
                # Diagonal panels were handled above.
                continue

            if max_n_ticks == 0:
                ax.xaxis.set_major_locator(NullLocator())
                ax.yaxis.set_major_locator(NullLocator())
            else:
                ax.xaxis.set_major_locator(
                    MaxNLocator(max_n_ticks, prune="lower")
                )
                ax.yaxis.set_major_locator(
                    MaxNLocator(max_n_ticks, prune="lower")
                )

            # Label axes. Each axis gets its own formatter instance, since a
            # formatter keeps a reference to the single axis it is bound to.
            ax.xaxis.set_major_formatter(
                ScalarFormatter(useMathText=use_math_text)
            )
            ax.yaxis.set_major_formatter(
                ScalarFormatter(useMathText=use_math_text)
            )
            if i < ndim - 1:
                ax.set_xticklabels([])
            else:
                for tick in ax.get_xticklabels():
                    tick.set_rotation(45)
                ax.set_xlabel(labels[j], **label_kwargs)
                ax.xaxis.set_label_coords(0.5, -0.3)
            if j > 0:
                ax.set_yticklabels([])
            else:
                for tick in ax.get_yticklabels():
                    tick.set_rotation(45)
                ax.set_ylabel(labels[i], **label_kwargs)
                ax.yaxis.set_label_coords(-0.3, 0.5)

            # Generate distribution.
            sy = smooth[j]
            # Contours are only drawn by default when at least one of the
            # two dimensions is smoothed (non-integer `smooth`).
            use_contours = not (isinstance(sx, int) and isinstance(sy, int))
            # Resolve the contour defaults on a per-panel copy so that the
            # decision for one panel does not leak into later panels.
            panel_kwargs = dict(hist2d_kwargs)
            panel_kwargs.setdefault("fill_contours", use_contours)
            panel_kwargs.setdefault("plot_contours", use_contours)
            hist2d(
                y,
                x,
                ax=ax,
                span=[span[j], span[i]],
                weights=weights,
                color=color,
                smooth=[sy, sx],
                **panel_kwargs,
            )

            # Add truth values
            if truths is not None:
                if truths[j] is not None:
                    _draw_truth_lines(
                        ax.axvline, truths[j], truth_color, truth_kwargs
                    )
                if truths[i] is not None:
                    _draw_truth_lines(
                        ax.axhline, truths[i], truth_color, truth_kwargs
                    )

    return (fig, axes)