#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Distance and reddening visualization functions.
This module provides functions for plotting distance vs reddening
posterior distributions.
"""
import warnings

import numpy as np
from matplotlib import pyplot as plt

from .binning import bin_pdfs_distred
__all__ = ["dist_vs_red"]


def dist_vs_red(
    data,
    ebv=None,
    dist_type="distance_modulus",
    lndistprior=None,
    coord=None,
    avlim=(0.0, 6.0),
    rvlim=(1.0, 8.0),
    weights=None,
    parallax=None,
    parallax_err=None,
    Nr=300,
    cmap="Blues",
    bins=300,
    span=None,
    smooth=0.015,
    plot_kwargs=None,
    truths=None,
    truth_color="red",
    truth_kwargs=None,
    rstate=None,
):
    """
    Generate a 2-D plot of distance vs reddening.

    The samples are binned (and smoothed) by `bin_pdfs_distred` and the
    resulting 2-D PDF is drawn on the current axes with `plt.imshow`,
    with distance on the x-axis and reddening on the y-axis.

    Parameters
    ----------
    data : 3-tuple or 4-tuple containing `~numpy.ndarray`s of shape `(Nsamps)`
        The data that will be plotted. Either a collection of
        `(dists, reds, dreds)` that were saved, or a collection of
        `(scales, avs, rvs, covs_sar)` that will be used to regenerate
        `(dists, reds)` in conjunction with any applied distance
        and/or parallax priors. If the arrays already carry a leading
        object axis (e.g. shape `(Nobjs, Nsamps)`), they are passed through
        unchanged and only the first object's binned PDF is plotted.

    ebv : bool, optional
        If truthy, reddening is shown as E(B-V) instead of Av (converted by
        `bin_pdfs_distred` using the Rv values). Default is `None` (Av).

    dist_type : str, optional
        The distance format to be plotted. Options include `'parallax'`,
        `'scale'`, `'distance'`, and `'distance_modulus'`.
        Default is `'distance_modulus'`.

    lndistprior : func, optional
        The log-distance prior function used. If not provided, the galactic
        model from Green et al. (2014) will be assumed.

    coord : 2-tuple, optional
        The galactic `(l, b)` coordinates for the object, which is passed to
        `lndistprior`.

    avlim : 2-tuple, optional
        The Av limits used to truncate results. Default is `(0., 6.)`.

    rvlim : 2-tuple, optional
        The Rv limits used to truncate results. Default is `(1., 8.)`.

    weights : `~numpy.ndarray` of shape `(Nsamps)`, optional
        Currently **ignored**: it is not forwarded to `bin_pdfs_distred`.
        A `UserWarning` is emitted if it is provided.

    parallax : float, optional
        The parallax estimate for the source.

    parallax_err : float, optional
        The parallax error.

    Nr : int, optional
        The number of Monte Carlo realizations used when sampling using the
        provided parallax prior. Default is `300`.

    cmap : str, optional
        The colormap used when plotting. Default is `'Blues'`.

    bins : int or list of ints with length `(ndim,)`, optional
        The number of bins to be used in each dimension. Default is `300`.

    span : iterable with shape `(2, 2)`, optional
        A list where each element is a length-2 tuple containing
        lower and upper bounds, passed to `bin_pdfs_distred`. If not
        provided, the reddening axis uses the provided Av bounds while the
        distance axis spans `(4., 19.)` in distance modulus (both
        appropriately transformed).

    smooth : int/float or list of ints/floats with shape `(ndim,)`, optional
        The standard deviation (either a single value or a different value for
        each axis) for the Gaussian kernel used to smooth the 2-D
        marginalized posteriors. If an int is passed, the smoothing will
        be applied in units of the binning in that dimension. If a float
        is passed, it is expressed as a fraction of the span.
        Default is `0.015` (1.5% smoothing).
        **Cannot smooth by more than the provided parallax will allow.**

    plot_kwargs : dict, optional
        Extra keyword arguments passed to `plt.imshow`. Keys given here
        override this function's defaults (`cmap`, `aspect`,
        `interpolation`, `origin`, `extent`). The dict is not modified.

    truths : iterable of length 2, optional
        Reference values `(x_truth, y_truth)` for distance (x-axis) and
        reddening (y-axis), drawn as vertical/horizontal lines. Each entry
        may be a scalar, an iterable of values, or `None` to skip that axis.
        Default is `None`.

    truth_color : str, optional
        A `~matplotlib`-style color used when plotting `truths`.
        Default is `'red'`.

    truth_kwargs : dict, optional
        Extra keyword arguments that will be used for plotting the vertical
        and horizontal lines with `truths`. Defaults of `linestyle='solid'`,
        `linewidth=2` and `alpha=0.7` are filled in when absent. The dict is
        not modified.

    rstate : `~numpy.random.RandomState`, optional
        `~numpy.random.RandomState` instance.

    Returns
    -------
    hist2d : (H, xedges, yedges, `~matplotlib.image.AxesImage`)
        The plotted 2-D binned PDF, its bin edges along the x (distance) and
        y (reddening) axes, and the image returned by `plt.imshow`.

    Raises
    ------
    ValueError
        If `dist_type` is not one of the supported formats.
    """
    # Validate before binning so a bad `dist_type` fails fast instead of
    # surfacing as an unbound label after the expensive binning step.
    xlabel = _distance_label(dist_type)
    ylabel = r"$E(B-V)$ [mag]" if ebv else r"$A_v$ [mag]"

    if weights is not None:
        warnings.warn(
            "`weights` is not used by `dist_vs_red` and will be ignored.",
            stacklevel=2,
        )

    # Work on a copy so the caller's dict is never mutated.
    truth_kwargs = dict(truth_kwargs or {})
    truth_kwargs.setdefault("linestyle", "solid")
    truth_kwargs.setdefault("linewidth", 2)
    truth_kwargs.setdefault("alpha", 0.7)

    # `bin_pdfs_distred` expects a leading object axis, i.e.
    # (n_objects, n_samples, ...). Promote a single object's 1-D samples.
    if np.ndim(data[0]) == 1:
        data = tuple(np.asarray(arr)[None, ...] for arr in data)
        if coord is not None:
            coord = [coord]
        if parallax is not None:
            parallax = np.array([parallax])
            if parallax_err is not None:
                parallax_err = np.array([parallax_err])

    # Delegate sample regeneration, binning and smoothing.
    binned_vals, xedges, yedges = bin_pdfs_distred(
        data,
        cdf=False,
        ebv=ebv,
        dist_type=dist_type,
        lndistprior=lndistprior,
        coord=coord,
        avlim=avlim,
        rvlim=rvlim,
        parallaxes=parallax,
        parallax_errors=parallax_err,
        Nr=Nr,
        bins=bins,
        span=span,
        smooth=smooth,
        rstate=rstate,
        verbose=False,
    )

    # Only the first object is plotted; for multi-object input the remaining
    # binned PDFs are ignored.
    H = binned_vals[0]

    imshow_kwargs = {
        "cmap": cmap,
        "aspect": "auto",
        "interpolation": "none",
        "origin": "lower",
        "extent": [xedges[0], xedges[-1], yedges[0], yedges[-1]],
    }
    # User options override the defaults rather than raising a
    # duplicate-keyword TypeError.
    imshow_kwargs.update(plot_kwargs or {})

    # `imshow` draws rows along y. `H` is presumably indexed as
    # (x-bin, y-bin), hence the transpose.
    img = plt.imshow(H.T, **imshow_kwargs)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    _overplot_truths(truths, truth_color, truth_kwargs)

    return H, xedges, yedges, img


def _distance_label(dist_type):
    """Return the x-axis label for `dist_type`, rejecting unknown formats."""
    labels = {
        "scale": r"$s$",
        "parallax": r"$\pi$ [mas]",
        "distance": r"$d$ [kpc]",
        "distance_modulus": r"$\mu$",
    }
    try:
        return labels[dist_type]
    except KeyError:
        raise ValueError(
            "Unknown dist_type {!r}; expected one of {}.".format(
                dist_type, sorted(labels)
            )
        ) from None


def _overplot_truths(truths, truth_color, truth_kwargs):
    """Draw truths[0] as vertical lines and truths[1] as horizontal lines."""
    if truths is None:
        return
    # `zip` stops at the shorter input, so a 1-element `truths` only draws x.
    for value, draw in zip(truths, (plt.axvline, plt.axhline)):
        if value is None:
            continue
        # Accept either a scalar or an iterable of reference values.
        for t in np.atleast_1d(value):
            draw(t, color=truth_color, **truth_kwargs)