Source code for brutus.plotting.offsets

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Photometric offset visualization functions.

This module provides functions for plotting photometric offsets
in 1D and 2D formats.
"""

import warnings

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import logsumexp

from ..core.sed_utils import get_seds
from ..utils.photometry import magnitude, phot_loglike
from ..utils.sampling import quantile


[docs] def photometric_offsets( phot, err, mask, models, idxs, reds, dreds, dists, x=None, flux=True, weights=None, bins=100, offset=None, dim_prior=True, plot_thresh=0.0, cmap="viridis", xspan=None, yspan=None, titles=None, xlabel=None, plot_kwargs=None, fig=None, ): """ Plot photometric offsets (`mag_pred - mag_obs`). Parameters ---------- phot : `~numpy.ndarray` of shape `(Nobj, Nfilt)`, optional Observed data values (fluxes). If provided, these will be overplotted. err : `~numpy.ndarray` of shape `(Nobj, Nfilt)` Associated errors on the data values. If provided, these will be overplotted as error bars. mask : `~numpy.ndarray` of shape `(Nobj, Nfilt)` Binary mask (0/1) indicating whether the data value was observed. If provided, these will be used to mask missing/bad data values. models : `~numpy.ndarray` of shape `(Nmodels, Nfilts, Ncoeffs)` Array of magnitude polynomial coefficients used to generate reddened photometry. idxs : `~numpy.ndarray` of shape `(Nobj, Nsamps)` An array of resampled indices corresponding to the set of models used to fit the data. reds : `~numpy.ndarray` of shape `(Nobj, Nsamps)` Reddening samples (in Av) associated with the model indices. dreds : `~numpy.ndarray` of shape `(Nsamps)` "Differential" reddening samples (in Rv) associated with the model indices. dists : `~numpy.ndarray` of shape `(Nobj, Nsamps)` Distance samples (in kpc) associated with the model indices. x : `~numpy.ndarray` with shape `(Nobj)` or `(Nobj, Nsamps)`, optional Corresponding values to be plotted on the `x` axis. If not provided, the default behavior is to plot as a function of observed magnitude. flux : bool, optional Whether the photometry provided is in fluxes (instead of magnitudes). Default is `True`. weights : `~numpy.ndarray` of shape `(Nobj)` or `(Nobj, Nsamps)`, optional An optional set of importance weights used to reweight the samples. bins : single value or iterable of length `Nfilt`, optional The number of bins to use. Passed to `~matplotlib.pyplot.hist2d`. Default is `100`. offset : `~numpy.ndarray` of shape `(Nfilt)`, optional Multiplicative photometric offsets that will be applied to the data (i.e. `data_new = data * phot_offsets`) and errors when provided. dim_prior : bool, optional Whether to apply a dimensional-based correction (prior) to the log-likelihood when reweighting the data while cycling through each band. Transforms the likelihood to a chi2 distribution with `Nfilt - 3` degrees of freedom. Default is `True`. plot_thresh : float, optional The threshold used to threshold the colormap when plotting. Default is `0.`. cmap : colormap, optional The colormap used when plotting results. Default is `'viridis'`. xspan : iterable with shape `(nfilt, 2)`, optional A list where each element is a length-2 tuple containing lower and upper bounds for the x-axis for each plot. yspan : iterable with shape `(nfilt, 2)`, optional A list where each element is a length-2 tuple containing lower and upper bounds for the y-axis for each plot. titles : iterable of str of length `Nfilt`, optional Titles for each of the subplots corresponding to each band. If not provided `Band #` will be used. xlabel : str, optional Labels for the x-axis of each subplot. If not provided, these will default to the titles. plot_kwargs : kwargs, optional Keyword arguments to be passed to `~matplotlib.pyplot.imshow`. fig : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`), optional If provided, overplot the traces and marginalized 1-D posteriors onto the provided figure. Otherwise, by default an internal figure is generated. Returns ------- postpredplot : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`) The associated figure and axes for the photometric offsets. """ # Initialize values. nmodels, nfilt, ncoeff = models.shape nobj, nsamps = idxs.shape if plot_kwargs is None: plot_kwargs = dict() if weights is None: weights = np.ones((nobj, nsamps)) elif weights.shape != (nobj, nsamps): weights = np.repeat(weights, nsamps).reshape(nobj, nsamps) try: nbins = len(bins) if nbins != 2: bins = [b for b in bins] else: bins = [bins for i in range(nfilt)] except TypeError: bins = [bins for i in range(nfilt)] if titles is None: titles = ["Band {0}".format(i) for i in range(nfilt)] if xlabel is None: if x is None: xlabel = titles else: xlabel = ["Label" for i in range(nfilt)] else: xlabel = [xlabel for i in range(nfilt)] if offset is None: offset = np.ones(nfilt) # Compute posterior predictive SED magnitudes. mpred = get_seds(models[idxs.flatten()], av=reds.flatten(), rv=dreds.flatten()) mpred += 5.0 * np.log10(dists.flatten())[:, None] mpred = mpred.reshape(nobj, nsamps, nfilt) # Convert observed data to magnitudes. with warnings.catch_warnings(): warnings.simplefilter("ignore") # ignore bad values if flux: magobs, mageobs = magnitude(phot * offset, err * offset) else: magobs, mageobs = phot + offset, err # Generate figure. if fig is None: ncols = 5 nrows = (nfilt - 1) // ncols + 1 fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 6, nrows * 5)) else: fig, axes = fig nrows, ncols = axes.shape ax = axes.flatten() # Plot offsets. for i in range(nfilt): # Compute selection ignoring current band. mtemp = np.array(mask) mtemp[:, i] = False s = ( mask[:, i] & (np.sum(mtemp, axis=1) > 3) & (np.all(np.isfinite(magobs), axis=1)) ) # Compute weights from ignoring current band. # TODO: vectorize this loop over objects for performance lnl = np.array( [ phot_loglike( mo[None, :], me[None, :], mp[None, :, :], mask=mt[None, :], dim_prior=dim_prior, ).squeeze(0) for mo, me, mt, mp in zip(magobs[s], mageobs[s], mtemp[s], mpred[s]) ] ) levid = logsumexp(lnl, axis=1) logwt = lnl - levid[:, None] wt = np.exp(logwt) wt /= wt.sum(axis=1)[:, None] # Repeat to match up with `nsamps`. mobs = np.repeat(magobs[s, i], nsamps) if x is None: xp = mobs else: if x.shape == (nobj, nsamps): xp = x[s].flatten() else: xp = np.repeat(x[s], nsamps) # Plot 2-D histogram. mp = mpred[s, :, i].flatten() w = weights[s].flatten() * wt.flatten() if xspan is None: xlow, xhigh = quantile(xp, [0.02, 0.98], weights=w) bx = np.linspace(xlow, xhigh, bins[i] + 1) else: bx = np.linspace(xspan[i][0], xspan[i][1], bins[i] + 1) if yspan is None: ylow, yhigh = quantile(mp - mobs, [0.02, 0.98], weights=w) by = np.linspace(ylow, yhigh, bins[i] + 1) else: by = np.linspace(yspan[i][0], yspan[i][1], bins[i] + 1) ax[i].hist2d( xp, mp - mobs, bins=(bx, by), weights=w, cmin=plot_thresh, cmap=cmap, **plot_kwargs, ) ax[i].set_xlabel(xlabel[i]) ax[i].set_title(titles[i]) ax[i].set_ylabel(r"$\Delta\,$mag") # Clear other axes. for i in range(nfilt, nrows * ncols): ax[i].set_frame_on(False) ax[i].set_xticks([]) ax[i].set_yticks([]) plt.tight_layout() return fig, axes
[docs] def photometric_offsets_2d( phot, err, mask, models, idxs, reds, dreds, dists, x, y, flux=True, weights=None, bins=100, offset=None, dim_prior=True, plot_thresh=10.0, cmap="coolwarm", clims=(-0.05, 0.05), xspan=None, yspan=None, titles=None, show_off=True, xlabel=None, ylabel=None, plot_kwargs=None, fig=None, ): """ Plot photometric offsets (`mag_pred - mag_obs`). Parameters ---------- phot : `~numpy.ndarray` of shape `(Nobj, Nfilt)`, optional Observed data values (fluxes). If provided, these will be overplotted. err : `~numpy.ndarray` of shape `(Nobj, Nfilt)` Associated errors on the data values. If provided, these will be overplotted as error bars. mask : `~numpy.ndarray` of shape `(Nobj, Nfilt)` Binary mask (0/1) indicating whether the data value was observed. If provided, these will be used to mask missing/bad data values. models : `~numpy.ndarray` of shape `(Nmodels, Nfilts, Ncoeffs)` Array of magnitude polynomial coefficients used to generate reddened photometry. idxs : `~numpy.ndarray` of shape `(Nobj, Nsamps)` An array of resampled indices corresponding to the set of models used to fit the data. reds : `~numpy.ndarray` of shape `(Nobj, Nsamps)` Reddening samples (in Av) associated with the model indices. dreds : `~numpy.ndarray` of shape `(Nsamps)` "Differential" reddening samples (in Rv) associated with the model indices. dists : `~numpy.ndarray` of shape `(Nobj, Nsamps)` Distance samples (in kpc) associated with the model indices. x : `~numpy.ndarray` with shape `(Nobj)` or `(Nobj, Nsamps)` Values to be plotted on the `x` axis. Required. y : `~numpy.ndarray` with shape `(Nobj)` or `(Nobj, Nsamps)` Values to be plotted on the `y` axis. Required. flux : bool, optional Whether the photometry provided is in fluxes (instead of magnitudes). Default is `True`. weights : `~numpy.ndarray` of shape `(Nobj)` or `(Nobj, Nsamps)`, optional An optional set of importance weights used to reweight the samples. bins : single value or iterable of length `Nfilt`, optional The number of bins to use. Passed to `~matplotlib.pyplot.hist2d`. Default is `100`. offset : `~numpy.ndarray` of shape `(Nfilt)`, optional Multiplicative photometric offsets that will be applied to the data (i.e. `data_new = data * phot_offsets`) and errors when provided. dim_prior : bool, optional Whether to apply a dimensional-based correction (prior) to the log-likelihood when reweighting the data while cycling through each band. Transforms the likelihood to a chi2 distribution with `Nfilt - 3` degrees of freedom. Default is `True`. plot_thresh : float, optional The threshold used to threshold the colormap when plotting. Default is `10.`. cmap : colormap, optional The colormap used when plotting results. Default is `'coolwarm'`. clims : 2-tuple, optional Plotting bounds for the colorbar. Default is `(-0.05, 0.05)`. xspan : iterable with shape `(nfilt, 2)`, optional A list where each element is a length-2 tuple containing lower and upper bounds for the x-axis for each plot. yspan : iterable with shape `(nfilt, 2)`, optional A list where each element is a length-2 tuple containing lower and upper bounds for the y-axis for each plot. titles : iterable of str of length `Nfilt`, optional Titles for each of the subplots corresponding to each band. If not provided `Band #` will be used. show_off : bool, optional Whether to include the offsets in the titles. Default is `True`. xlabel : str, optional Label for the x-axis of each subplot. If not provided, this will default to `X`. ylabel : str, optional Label for the y-axis of each subplot. If not provided, this will default to `Y`. plot_kwargs : kwargs, optional Keyword arguments to be passed to `~matplotlib.pyplot.imshow`. fig : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`), optional If provided, overplot the traces and marginalized 1-D posteriors onto the provided figure. Otherwise, by default an internal figure is generated. Returns ------- postpredplot : (`~matplotlib.figure.Figure`, `~matplotlib.axes.Axes`) The associated figure and axes for the photometric offsets. """ # Initialize values. nmodels, nfilt, ncoeff = models.shape nobj, nsamps = idxs.shape if plot_kwargs is None: plot_kwargs = dict() if weights is None: weights = np.ones((nobj, nsamps)) elif weights.shape != (nobj, nsamps): weights = np.repeat(weights, nsamps).reshape(nobj, nsamps) try: nbins = len(bins) if nbins != 2: bins = [b for b in bins] else: bins = [bins for i in range(nfilt)] except TypeError: bins = [bins for i in range(nfilt)] if titles is None: titles = ["Band {0}".format(i) for i in range(nfilt)] if show_off and offset is not None: titles = [ t + " ({:2.2}% offset)".format(100.0 * (off - 1.0)) for t, off in zip(titles, offset) ] if xlabel is None: xlabel = "X" if ylabel is None: ylabel = "Y" if offset is None: offset = np.ones(nfilt) # Compute posterior predictive SED magnitudes. mpred = get_seds(models[idxs.flatten()], av=reds.flatten(), rv=dreds.flatten()) mpred += 5.0 * np.log10(dists.flatten())[:, None] mpred = mpred.reshape(nobj, nsamps, nfilt) # Convert observed data to magnitudes. with warnings.catch_warnings(): warnings.simplefilter("ignore") # ignore bad values if flux: magobs, mageobs = magnitude(phot * offset, err * offset) else: magobs, mageobs = phot + offset, err # Magnitude offsets. dm = mpred - magobs[:, None] for i in range(nfilt): dm[~mask[:, i], :, i] = np.nan # Generate figure. if fig is None: ncols = 5 nrows = (nfilt - 1) // ncols + 1 fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 15, nrows * 12)) else: fig, axes = fig nrows, ncols = axes.shape ax = axes.flatten() # Plot offsets. for i in range(nfilt): # Bin in 2-D. n, xbins, ybins = np.histogram2d(x, y, bins=bins[i]) xcent = 0.5 * (xbins[1:] + xbins[:-1]) ycent = 0.5 * (ybins[1:] + ybins[:-1]) bounds = [xcent[0], xcent[-1], ycent[0], ycent[-1]] # default size # Digitize values. xloc, yloc = np.digitize(x, xbins) - 1, np.digitize(y, ybins) - 1 # Compute selection ignoring current band. mtemp = np.array(mask) mtemp[:, i] = False s = ( mask[:, i] & (np.sum(mtemp, axis=1) > 3) & (np.all(np.isfinite(magobs), axis=1)) ) # Compute weights from ignoring current band. # TODO: vectorize this loop over objects for performance with warnings.catch_warnings(): warnings.simplefilter("ignore") # ignore bad values lnl = np.array( [ phot_loglike( mo[None, :], me[None, :], mp[None, :, :], mask=mt[None, :], dim_prior=dim_prior, ).squeeze(0) for mo, me, mt, mp in zip(magobs, mageobs, mtemp, mpred) ] ) levid = logsumexp(lnl, axis=1) logwt = lnl - levid[:, None] wt = np.exp(logwt) wt /= wt.sum(axis=1)[:, None] # Compute weighted median offsets. offset2d = np.zeros((len(xbins) - 1, len(ybins) - 1)) for xidx in range(len(xbins) - 1): for yidx in range(len(ybins) - 1): bsel = np.where((xloc == xidx) & (yloc == yidx) & s)[0] if len(bsel) >= plot_thresh: # If we have enough objects, compute weighted median. off, w = dm[bsel, :, i], wt[bsel] * weights[bsel] off_med = quantile(off.flatten(), [0.5], w.flatten())[0] offset2d[xidx, yidx] = off_med else: # If we don't have enough objects, mask bin. offset2d[xidx, yidx] = np.nan # Plot offsets over 2-D histogram. if xspan is not None: bounds[:2] = xspan[i] if yspan is not None: bounds[2:] = yspan[i] img = ax[i].imshow( offset2d.T, origin="lower", extent=bounds, vmin=clims[0], vmax=clims[1], aspect="auto", cmap=cmap, **plot_kwargs, ) ax[i].set_xlabel(xlabel) ax[i].set_ylabel(ylabel) ax[i].set_title(titles[i]) plt.colorbar(img, ax=ax[i], label=r"$\Delta\,$mag") # Clear other axes. for i in range(nfilt, nrows * ncols): ax[i].set_frame_on(False) ax[i].set_xticks([]) ax[i].set_yticks([]) plt.tight_layout() return fig, axes
__all__ = ["photometric_offsets", "photometric_offsets_2d"]