#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Galactic structure priors for Bayesian stellar parameter estimation.
This module provides log-prior functions for Galactic structure modeling
including disk and halo number densities, metallicity distributions,
and age-metallicity relations. These priors encode our knowledge of the
Milky Way's stellar populations and are essential for distance estimation.
Functions
---------
logn_disk : Disk number density
Exponential disk model
logn_halo : Halo number density
Flattened power-law halo
logp_feh : Metallicity distribution
Gaussian metallicity distribution
logp_age_from_feh : Age-metallicity relation
Age distribution conditional on metallicity
logp_galactic_structure : Combined prior
Full Galactic structure prior combining disk and halo
See Also
--------
brutus.analysis.individual.BruteForce : Uses Galactic priors for distance fitting
brutus.priors.stellar : Stellar population priors
brutus.priors.astrometric : Parallax and proper motion priors
Notes
-----
These priors provide critical constraints on stellar distances by
incorporating knowledge of Galactic structure:
- **Number density priors** (disk and halo) weight distance based on
expected stellar distributions in the Galaxy
- **Metallicity priors** provide realistic [Fe/H] distributions for
disk and halo populations
- **Age-metallicity relations** link stellar age and composition
The combined prior `logp_galactic_structure` combines the thin-disk,
thick-disk, and halo models, weighted by their relative normalizations.
The priors are evaluated in Galactocentric cylindrical coordinates (R, Z),
so the input Galactic coordinates (l, b) and distances are first transformed
into that frame.
Examples
--------
Evaluate disk number density:
>>> import numpy as np
>>> from brutus.priors.galactic import logn_disk
>>>
>>> # Position in Galactic disk
>>> R = np.array([8.0]) # kpc from Galactic center
>>> Z = np.array([0.1]) # kpc above midplane
>>> log_density = logn_disk(R, Z)
>>> print(f"Log-density: {log_density[0]:.3f}")
Log-density: -0.268
Combined Galactic structure prior:
>>> from brutus.priors.galactic import logp_galactic_structure
>>>
>>> # Galactic coordinates (l, b) in degrees and distances in kpc
>>> coord = (90.0, 30.0)
>>> distances = np.array([1.0, 2.0, 5.0])
>>>
>>> # Evaluate the combined thin disk + thick disk + halo prior
>>> log_prior = logp_galactic_structure(distances, coord)
"""
import warnings
from math import erf, exp, log, pi, sqrt
import numpy as np
from numba import jit, prange
# Import utility functions from brutus.utils
from brutus.utils import truncnorm_logpdf
from brutus.utils.math import galactic_to_galactocentric_cyl
# log(2*pi): additive constant of the Gaussian log-densities evaluated in the
# fused numba kernel below.
_LOG_2PI = log(2.0 * pi)
# sqrt(2): converts standard-normal z-scores into erf arguments, since
# Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
_SQRT2 = sqrt(2.0)
def _logsumexp3(a, b, c):
"""Fast logsumexp of 3 arrays using numpy's logaddexp ufunc."""
return np.logaddexp(np.logaddexp(a, b), c)
@jit(nopython=True, parallel=True, cache=True)
def _galactic_prior_fused(
    dists,
    R_arr,
    Z_arr,
    feh_arr,
    loga_arr,
    has_feh,
    has_loga,
    R_solar,
    Z_solar,
    # Thin disk
    R_thin,
    Z_thin,
    Rs_thin,
    f_thin,
    # Thick disk
    R_thick,
    Z_thick,
    Rs_thick,
    f_thick,
    # Halo
    Rs_halo,
    eta_halo,
    q_halo_ctr,
    q_halo_inf,
    r_q_halo,
    f_halo,
    # Metallicity
    feh_thin_mean,
    feh_thin_sigma,
    feh_thick_mean,
    feh_thick_sigma,
    feh_halo_mean,
    feh_halo_sigma,
    # Age
    max_age,
    min_age,
    feh_age_ctr,
    feh_age_scale,
    nsigma_from_max_age,
    max_sigma_age,
    min_sigma_age,
):
    """Fused galactic prior: density + volume + [Fe/H] + age in one pass.

    Numba-compiled, per-element version of the NumPy code path of
    ``logp_galactic_structure`` (without component tracking).

    Parameters
    ----------
    dists, R_arr, Z_arr : 1-D arrays
        Observer distance and Galactocentric cylindrical R, Z (kpc).
        Must all have the same length.
    feh_arr, loga_arr : 1-D arrays
        Per-element [Fe/H] and log10(age / yr). Each is only read when the
        matching ``has_feh`` / ``has_loga`` flag is True. In that case it must
        be at least as long as ``dists``, because no bounds checking is done.
    has_feh, has_loga : bool
        Whether to add the metallicity / age prior terms.
    R_solar, Z_solar : float
        Solar Galactocentric radius and height (kpc).
    R_thin, Z_thin, Rs_thin, f_thin : float
        Thin-disk scale length, scale height, smoothing radius, and relative
        normalization.
    R_thick, Z_thick, Rs_thick, f_thick : float
        Same for the thick disk.
    Rs_halo, eta_halo, q_halo_ctr, q_halo_inf, r_q_halo, f_halo : float
        Halo smoothing radius, power-law index, central/asymptotic
        oblateness, oblateness transition radius, and relative normalization.
    feh_*_mean, feh_*_sigma : float
        Gaussian [Fe/H] mean and dispersion per component.
    max_age, min_age, feh_age_ctr, feh_age_scale, nsigma_from_max_age,
    max_sigma_age, min_sigma_age : float
        Age-metallicity relation parameters (ages in Gyr).

    Returns
    -------
    logp_out : ndarray
        Total log-prior per element.
    """
    N = len(dists)
    logp_out = np.empty(N)

    # Solar halo normalization (scalar, computed once).
    r_solar = sqrt(R_solar**2 + Z_solar**2)
    r_prime_solar = sqrt(r_solar**2 + r_q_halo**2)
    q_solar = q_halo_inf - (q_halo_inf - q_halo_ctr) * exp(
        1.0 - r_prime_solar / r_q_halo
    )
    R_eff_solar_halo = sqrt(R_solar**2 + (Z_solar / q_solar) ** 2 + Rs_halo**2)

    # Component log-normalizations (loop-invariant).
    ln_f_thin = log(f_thin)
    ln_f_thick = log(f_thick)
    ln_f_halo = log(f_halo)

    # Age-metallicity relation per component. These depend only on the
    # component's mean [Fe/H] and the scalar age parameters, so compute the
    # truncated-normal mean, width and log-normalization once, not per star.
    comp_feh = np.empty(3)
    comp_feh[0] = feh_thin_mean
    comp_feh[1] = feh_thick_mean
    comp_feh[2] = feh_halo_mean
    age_mean_c = np.empty(3)
    age_sigma_c = np.empty(3)
    age_lnnorm_c = np.empty(3)
    for k in range(3):
        age_mean = (max_age - min_age) / (
            1.0 + exp((comp_feh[k] - feh_age_ctr) / feh_age_scale)
        ) + min_age
        age_sigma = (max_age - age_mean) / nsigma_from_max_age
        if age_sigma < min_sigma_age:
            age_sigma = min_sigma_age
        if age_sigma > max_sigma_age:
            age_sigma = max_sigma_age
        alpha = (min_age - age_mean) / age_sigma
        beta = (max_age - age_mean) / age_sigma
        # Phi(beta) - Phi(alpha) = 0.5 * (erf(beta/sqrt2) - erf(alpha/sqrt2));
        # the 0.5 is folded into log(age_sigma / 2).
        denom = max(erf(beta / _SQRT2) - erf(alpha / _SQRT2), 1e-300)
        age_mean_c[k] = age_mean
        age_sigma_c[k] = age_sigma
        age_lnnorm_c[k] = log(age_sigma / 2.0) + log(denom)

    for i in prange(N):
        d = dists[i]
        R = R_arr[i]
        Z = Z_arr[i]

        # Volume factor (dV ∝ d² dd); 1e-300 guards log(0).
        vol = 2.0 * log(d + 1e-300)

        # Thin disk
        R_eff_thin = sqrt(R**2 + Rs_thin**2)
        lnp_thin = (
            -(R_eff_thin - R_solar) / R_thin
            - (abs(Z) - abs(Z_solar)) / Z_thin
            + vol
            + ln_f_thin
        )

        # Thick disk
        R_eff_thick = sqrt(R**2 + Rs_thick**2)
        lnp_thick = (
            -(R_eff_thick - R_solar) / R_thick
            - (abs(Z) - abs(Z_solar)) / Z_thick
            + vol
            + ln_f_thick
        )

        # Halo (flattened power law with radius-dependent oblateness)
        r = sqrt(R**2 + Z**2)
        r_prime = sqrt(r**2 + r_q_halo**2)
        q = q_halo_inf - (q_halo_inf - q_halo_ctr) * exp(1.0 - r_prime / r_q_halo)
        R_eff_halo = sqrt(R**2 + (Z / q) ** 2 + Rs_halo**2)
        lnp_halo = -eta_halo * log(R_eff_halo / R_eff_solar_halo) + vol + ln_f_halo

        # logsumexp of the three components = total number-density prior.
        mx = max(lnp_thin, max(lnp_thick, lnp_halo))
        ln_density = mx + log(
            exp(lnp_thin - mx) + exp(lnp_thick - mx) + exp(lnp_halo - mx)
        )

        # Component membership log-probabilities, shared by [Fe/H] and age.
        ln_w_thin = lnp_thin - ln_density
        ln_w_thick = lnp_thick - ln_density
        ln_w_halo = lnp_halo - ln_density

        logp_total = ln_density

        # Metallicity prior: membership-weighted mixture of Gaussians.
        if has_feh:
            feh_val = feh_arr[i]
            feh_lnp_thin = (
                -0.5
                * (
                    (feh_val - feh_thin_mean) ** 2 / feh_thin_sigma**2
                    + _LOG_2PI
                    + 2 * log(feh_thin_sigma)
                )
                + ln_w_thin
            )
            feh_lnp_thick = (
                -0.5
                * (
                    (feh_val - feh_thick_mean) ** 2 / feh_thick_sigma**2
                    + _LOG_2PI
                    + 2 * log(feh_thick_sigma)
                )
                + ln_w_thick
            )
            feh_lnp_halo = (
                -0.5
                * (
                    (feh_val - feh_halo_mean) ** 2 / feh_halo_sigma**2
                    + _LOG_2PI
                    + 2 * log(feh_halo_sigma)
                )
                + ln_w_halo
            )
            mx2 = max(feh_lnp_thin, max(feh_lnp_thick, feh_lnp_halo))
            feh_lnp = mx2 + log(
                exp(feh_lnp_thin - mx2)
                + exp(feh_lnp_thick - mx2)
                + exp(feh_lnp_halo - mx2)
            )
            logp_total += feh_lnp

        # Age prior: membership-weighted mixture of truncated normals.
        if has_loga:
            age_val = 10.0 ** loga_arr[i] / 1e9  # Gyr
            # -1e300 acts as log(0); it is kept when the age lies outside
            # [min_age, max_age] (or is NaN).
            age_lnp_total = -1e300
            if age_val >= min_age and age_val <= max_age:
                for k in range(3):
                    if k == 0:
                        ln_w = ln_w_thin
                    elif k == 1:
                        ln_w = ln_w_thick
                    else:
                        ln_w = ln_w_halo
                    xi = (age_val - age_mean_c[k]) / age_sigma_c[k]
                    lnphi = -0.5 * (_LOG_2PI + xi * xi)
                    age_comp = lnphi - age_lnnorm_c[k] + ln_w
                    # Running logsumexp; values <= -1e200 are treated as log(0).
                    if age_comp > age_lnp_total:
                        age_lnp_total = (
                            age_comp + log(1.0 + exp(age_lnp_total - age_comp))
                            if age_lnp_total > -1e200
                            else age_comp
                        )
                    else:
                        age_lnp_total = (
                            age_lnp_total + log(1.0 + exp(age_comp - age_lnp_total))
                            if age_comp > -1e200
                            else age_lnp_total
                        )
            logp_total += age_lnp_total

        logp_out[i] = logp_total
    return logp_out
# Public API of this module: the names exported by ``from <module> import *``.
# Underscore-prefixed helpers (the fused numba kernel, _logsumexp3) are private.
__all__ = [
    "logn_disk",
    "logn_halo",
    "logp_feh",
    "logp_age_from_feh",
    "logp_galactic_structure",
]
def logn_disk(R, Z, R_solar=8.2, Z_solar=0.025, R_scale=2.6, Z_scale=0.3, R_smooth=2.0):
    r"""
    Log-number density of an exponential Galactic disk, relative to the Sun.

    The radial coordinate is softened near the Galactic center, so the
    profile stays finite at R = 0.

    Parameters
    ----------
    R : array_like
        Galactocentric cylindrical radius in kpc.
    Z : array_like
        Height above the Galactic midplane in kpc.
    R_solar : float, optional
        Solar Galactocentric radius in kpc. Default is 8.2.
    Z_solar : float, optional
        Solar height above the midplane in kpc. Default is 0.025.
    R_scale : float, optional
        Radial scale length in kpc. Default is 2.6.
    Z_scale : float, optional
        Vertical scale height in kpc. Default is 0.3.
    R_smooth : float, optional
        Central smoothing radius in kpc. Default is 2.0.

    Returns
    -------
    logn : array_like
        Log of the number density normalized to the Solar neighborhood.

    See Also
    --------
    logn_halo : Halo number density
    logp_galactic_structure : Combined disk+halo model

    Notes
    -----
    .. math::

        \ln n_{\rm disk}(R, Z) = -\frac{R_{\rm eff} - R_\odot}{R_{\rm scale}}
                                 - \frac{|Z| - |Z_\odot|}{Z_{\rm scale}},
        \qquad R_{\rm eff} = \sqrt{R^2 + R_{\rm smooth}^2}

    References
    ----------
    Bland-Hawthorn & Gerhard (2016) - The Galaxy in Context
    """
    radius = np.asarray(R)
    height = np.asarray(Z)

    # Softened cylindrical radius (removes the cusp at the center).
    softened_radius = np.sqrt(radius**2 + R_smooth**2)

    radial_decay = (softened_radius - R_solar) / R_scale
    vertical_decay = (np.abs(height) - np.abs(Z_solar)) / Z_scale
    return -radial_decay - vertical_decay
def logn_halo(
    R,
    Z,
    R_solar=8.2,
    Z_solar=0.025,
    R_smooth=2.0,
    eta=4.2,
    q_ctr=0.2,
    q_inf=0.8,
    r_q=6.0,
):
    r"""
    Log-number density of a flattened power-law stellar halo, relative to the Sun.

    The halo axis ratio q varies with Galactocentric radius, from ``q_ctr``
    near the center to ``q_inf`` far out.

    Parameters
    ----------
    R : array_like
        Galactocentric cylindrical radius in kpc.
    Z : array_like
        Height above the Galactic midplane in kpc.
    R_solar : float, optional
        Solar Galactocentric radius in kpc. Default is 8.2.
    Z_solar : float, optional
        Solar height above the midplane in kpc. Default is 0.025.
    R_smooth : float, optional
        Central smoothing radius in kpc. Default is 2.0.
    eta : float, optional
        Power-law index of the density profile. Default is 4.2.
    q_ctr : float, optional
        Axis ratio at the Galactic center. Default is 0.2.
    q_inf : float, optional
        Axis ratio at large radii. Default is 0.8.
    r_q : float, optional
        Transition radius of the axis ratio in kpc. Default is 6.0.

    Returns
    -------
    logn : array_like
        Log of the number density normalized to the Solar neighborhood.

    See Also
    --------
    logn_disk : Disk number density
    logp_galactic_structure : Combined disk+halo model

    Notes
    -----
    .. math::

        n_{\rm halo} \propto R_{\rm eff}^{-\eta}, \qquad
        R_{\rm eff} = \sqrt{R^2 + (Z/q)^2 + R_{\rm smooth}^2}

        q(r) = q_\infty - (q_\infty - q_{\rm ctr})\, e^{1 - r'/r_q}, \qquad
        r' = \sqrt{r^2 + r_q^2}, \quad r = \sqrt{R^2 + Z^2}

    References
    ----------
    Bland-Hawthorn & Gerhard (2016) - The Galaxy in Context
    Bell et al. (2008) - Stellar Halo Properties from SDSS
    """

    def _flattened_radius(cyl_r, cyl_z):
        # Effective radius with radius-dependent flattening and smoothing.
        sph_r = np.sqrt(cyl_r**2 + cyl_z**2)
        softened = np.sqrt(sph_r**2 + r_q**2)
        axis_ratio = q_inf - (q_inf - q_ctr) * np.exp(1.0 - softened / r_q)
        return np.sqrt(cyl_r**2 + (cyl_z / axis_ratio) ** 2 + R_smooth**2)

    star_radius = _flattened_radius(np.asarray(R), np.asarray(Z))
    solar_radius = _flattened_radius(R_solar, Z_solar)
    return -eta * np.log(star_radius / solar_radius)
def logp_feh(feh, feh_mean=-0.2, feh_sigma=0.3):
    r"""
    Gaussian log-prior on stellar metallicity [Fe/H].

    Parameters
    ----------
    feh : array_like
        Stellar metallicity [Fe/H] in dex.
    feh_mean : float, optional
        Population mean metallicity in dex. Default is -0.2 (thin disk).
    feh_sigma : float, optional
        Population metallicity dispersion in dex. Default is 0.3.

    Returns
    -------
    logp : array_like
        Normalized Gaussian log-density evaluated at ``feh``.

    Notes
    -----
    .. math::

        \ln p([{\rm Fe/H}]) = -\frac{1}{2}\left[
            \frac{([{\rm Fe/H}] - \mu)^2}{\sigma^2} + \ln(2\pi\sigma^2)\right]

    Typical values for different Galactic components:

    - Thin disk: feh_mean = -0.2, feh_sigma = 0.3
    - Thick disk: feh_mean = -0.7, feh_sigma = 0.4
    - Halo: feh_mean = -1.6, feh_sigma = 0.5

    References
    ----------
    Bland-Hawthorn & Gerhard (2016) - The Galaxy in Context
    """
    metallicity = np.asarray(feh)
    variance = feh_sigma**2
    return -0.5 * (
        (metallicity - feh_mean) ** 2 / variance + np.log(2.0 * np.pi * variance)
    )
def logp_age_from_feh(
    age,
    feh_mean=-0.2,
    max_age=13.8,
    min_age=0.0,
    feh_age_ctr=-0.5,
    feh_age_scale=0.5,
    nsigma_from_max_age=2.0,
    max_sigma=4.0,
    min_sigma=1.0,
):
    r"""
    Truncated-normal log-prior on stellar age given a population metallicity.

    The mean age follows a logistic age-metallicity relation, so more
    metal-poor populations are older. The dispersion is tied to how far the
    mean sits below ``max_age`` and is clipped to [min_sigma, max_sigma].

    Parameters
    ----------
    age : array_like
        Stellar ages in Gyr.
    feh_mean : float, optional
        Population mean metallicity in dex. Default is -0.2.
    max_age : float, optional
        Upper truncation age in Gyr. Default is 13.8.
    min_age : float, optional
        Lower truncation age in Gyr. Default is 0.0.
    feh_age_ctr : float, optional
        Metallicity at which the mean age is halfway between min/max.
        Default is -0.5.
    feh_age_scale : float, optional
        Width of the logistic relation in dex. Default is 0.5.
    nsigma_from_max_age : float, optional
        Number of sigma the mean age lies below ``max_age``. Default is 2.0.
    max_sigma : float, optional
        Maximum age dispersion in Gyr. Default is 4.0.
    min_sigma : float, optional
        Minimum age dispersion in Gyr. Default is 1.0.

    Returns
    -------
    logp : array_like
        Truncated-normal log-density evaluated at ``age``.

    Notes
    -----
    .. math::

        \langle t \rangle = \frac{t_{\max} - t_{\min}}
            {1 + \exp\left(([{\rm Fe/H}] - c)/s\right)} + t_{\min}

        \sigma_t = \min\left(\max\left(\frac{t_{\max} - \langle t \rangle}
            {n_\sigma}, \sigma_{\min}\right), \sigma_{\max}\right)

    The density is normalized on [min_age, max_age].

    References
    ----------
    Bland-Hawthorn & Gerhard (2016) - The Galaxy in Context
    Nordström et al. (2004) - Age-metallicity relation in Solar neighborhood
    """
    ages = np.asarray(age)

    # Logistic age-metallicity relation for the population mean age.
    logistic_denom = 1.0 + np.exp((feh_mean - feh_age_ctr) / feh_age_scale)
    mean_age = (max_age - min_age) / logistic_denom + min_age

    # Dispersion: gap to max_age in units of nsigma, clipped to the bounds.
    spread = np.clip((max_age - mean_age) / nsigma_from_max_age, min_sigma, max_sigma)

    # Standardized truncation limits.
    lower = (min_age - mean_age) / spread
    upper = (max_age - mean_age) / spread

    return truncnorm_logpdf(ages, lower, upper, loc=mean_age, scale=spread)
def logp_galactic_structure(
    dists,
    coord,
    labels=None,
    R_solar=8.2,
    Z_solar=0.025,
    R_thin=2.6,
    Z_thin=0.3,
    Rs_thin=2.0,
    R_thick=2.0,
    Z_thick=0.9,
    f_thick=0.04,
    Rs_thick=2.0,
    Rs_halo=2.0,
    q_halo_ctr=0.2,
    q_halo_inf=0.8,
    r_q_halo=6.0,
    eta_halo=4.2,
    f_halo=0.005,
    feh_thin=-0.2,
    feh_thin_sigma=0.3,
    feh_thick=-0.7,
    feh_thick_sigma=0.4,
    feh_halo=-1.6,
    feh_halo_sigma=0.5,
    max_age=13.8,
    min_age=0.0,
    feh_age_ctr=-0.5,
    feh_age_scale=0.5,
    nsigma_from_max_age=2.0,
    max_sigma=4.0,
    min_sigma=1.0,
    return_components=False,
):
    """
    Complete Galactic structure log-prior with thin disk, thick disk, and halo.

    Implements a three-component Galactic model based on Bland-Hawthorn &
    Gerhard (2016). It combines spatial number-density priors with optional
    metallicity and age priors.

    Parameters
    ----------
    dists : array_like
        Distance from observer in kpc.
    coord : tuple of floats or SkyCoord-like
        Galactic coordinates (l, b) in degrees. Alternatively, any object with
        a ``.galactic`` attribute exposing ``l.deg`` and ``b.deg`` (e.g. an
        astropy ``SkyCoord``).
    labels : structured array, optional
        Stellar labels containing 'feh' and/or 'loga' for metallicity/age priors.
    R_solar : float, optional
        Solar Galactocentric radius in kpc. Default is 8.2.
    Z_solar : float, optional
        Solar height above midplane in kpc. Default is 0.025.
    R_thin : float, optional
        Thin disk radial scale length in kpc. Default is 2.6.
    Z_thin : float, optional
        Thin disk vertical scale height in kpc. Default is 0.3.
    Rs_thin : float, optional
        Thin disk smoothing radius in kpc. Default is 2.0.
    R_thick : float, optional
        Thick disk radial scale length in kpc. Default is 2.0.
    Z_thick : float, optional
        Thick disk vertical scale height in kpc. Default is 0.9.
    f_thick : float, optional
        Thick disk relative normalization. Default is 0.04.
    Rs_thick : float, optional
        Thick disk smoothing radius in kpc. Default is 2.0.
    Rs_halo : float, optional
        Halo smoothing radius in kpc. Default is 2.0.
    q_halo_ctr : float, optional
        Halo central oblateness. Default is 0.2.
    q_halo_inf : float, optional
        Halo asymptotic oblateness. Default is 0.8.
    r_q_halo : float, optional
        Halo oblateness transition radius in kpc. Default is 6.0.
    eta_halo : float, optional
        Halo power-law index. Default is 4.2.
    f_halo : float, optional
        Halo relative normalization. Default is 0.005.
    feh_thin : float, optional
        Thin disk mean metallicity in dex. Default is -0.2.
    feh_thin_sigma : float, optional
        Thin disk metallicity dispersion in dex. Default is 0.3.
    feh_thick : float, optional
        Thick disk mean metallicity in dex. Default is -0.7.
    feh_thick_sigma : float, optional
        Thick disk metallicity dispersion in dex. Default is 0.4.
    feh_halo : float, optional
        Halo mean metallicity in dex. Default is -1.6.
    feh_halo_sigma : float, optional
        Halo metallicity dispersion in dex. Default is 0.5.
    max_age : float, optional
        Maximum stellar age in Gyr. Default is 13.8.
    min_age : float, optional
        Minimum stellar age in Gyr. Default is 0.0.
    feh_age_ctr : float, optional
        Central metallicity for age-metallicity relation. Default is -0.5.
    feh_age_scale : float, optional
        Scale length for age-metallicity relation. Default is 0.5.
    nsigma_from_max_age : float, optional
        Age dispersion parameter. Default is 2.0.
    max_sigma : float, optional
        Maximum age dispersion in Gyr. Default is 4.0.
    min_sigma : float, optional
        Minimum age dispersion in Gyr. Default is 1.0.
    return_components : bool, optional
        Whether to return individual component contributions. Default is False.

    Returns
    -------
    logp : array_like
        Total log-prior probability density.
    components : dict, optional
        Individual component contributions (if return_components=True), with
        keys 'number_density' and, when computed, 'feh' and 'age'. Each value
        is a list of [thin, thick, halo] arrays.

    Notes
    -----
    The Galactic model combines three stellar populations:

    1. **Thin Disk**: Young, metal-rich stars with small scale height
    2. **Thick Disk**: Intermediate-age, metal-poor stars with larger scale height
    3. **Halo**: Old, very metal-poor stars with flattened power-law profile

    The model accounts for:

    - Coordinate transformations from observer to Galactocentric frame
    - Volume correction factors (dV ∝ distance²)
    - Component membership probabilities
    - Conditional metallicity and age priors

    When stellar labels are provided, the population-specific priors are
    applied as membership-weighted mixtures:

    - Metallicity: Gaussian distributions with different means/dispersions
    - Age: Age-metallicity relation with truncated normal distributions

    For large 1-D inputs with labels (and ``return_components=False``), a
    fused numba kernel is used when every label array has the same shape as
    ``dists``. Otherwise, or if the kernel fails, the NumPy path below is used.

    References
    ----------
    Bland-Hawthorn & Gerhard (2016) - The Galaxy in Context
    """
    dists = np.asarray(dists)

    # Resolve Galactic (l, b) in degrees.
    if hasattr(coord, "galactic"):
        # coord is a SkyCoord-like object
        ell = coord.galactic.l.deg
        b = coord.galactic.b.deg
    else:
        # coord is a tuple of (l, b) in degrees
        ell, b = coord[0], coord[1]

    # Convert to Galactocentric cylindrical coordinates using fast NumPy math
    R, Z = galactic_to_galactocentric_cyl(
        dists, ell, b, R_solar=R_solar, Z_solar=Z_solar
    )

    # Fast path: fused numba kernel for large 1-D inputs when labels are
    # provided and components are not requested (avoids ~15 temporaries).
    # ``ndim``/``size`` are used instead of ``len`` so scalar distances do not
    # raise TypeError here.
    if (
        labels is not None
        and not return_components
        and dists.ndim == 1
        and dists.size > 1000
    ):
        has_feh = "feh" in labels.dtype.names
        has_loga = "loga" in labels.dtype.names
        feh_arr = labels["feh"] if has_feh else np.empty(0)
        loga_arr = labels["loga"] if has_loga else np.empty(0)
        # The kernel reads R/Z/feh/loga at the distance index without bounds
        # checking, so only use it when every array lines up element-for-
        # element with ``dists``. Broadcastable or mismatched labels fall
        # through to the NumPy path, which broadcasts or reports errors.
        aligned = (
            np.shape(R) == dists.shape
            and np.shape(Z) == dists.shape
            and (not has_feh or np.shape(feh_arr) == dists.shape)
            and (not has_loga or np.shape(loga_arr) == dists.shape)
        )
        if aligned:
            try:
                return _galactic_prior_fused(
                    dists,
                    R,
                    Z,
                    feh_arr,
                    loga_arr,
                    has_feh,
                    has_loga,
                    R_solar,
                    Z_solar,
                    R_thin,
                    Z_thin,
                    Rs_thin,
                    1.0,  # thin disk is the reference component (f_thin = 1)
                    R_thick,
                    Z_thick,
                    Rs_thick,
                    f_thick,
                    Rs_halo,
                    eta_halo,
                    q_halo_ctr,
                    q_halo_inf,
                    r_q_halo,
                    f_halo,
                    feh_thin,
                    feh_thin_sigma,
                    feh_thick,
                    feh_thick_sigma,
                    feh_halo,
                    feh_halo_sigma,
                    max_age,
                    min_age,
                    feh_age_ctr,
                    feh_age_scale,
                    nsigma_from_max_age,
                    max_sigma,
                    min_sigma,
                )
            except Exception as e:
                # Best-effort acceleration: warn and fall back to NumPy.
                warnings.warn(
                    f"Numba fused galactic prior failed, falling back to numpy: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )

    # Volume correction factor (dV ∝ r² dr); 1e-300 guards log(0).
    vol_factor = 2.0 * np.log(dists + 1e-300)

    # Thin disk component
    logp_thin = logn_disk(
        R,
        Z,
        R_solar=R_solar,
        Z_solar=Z_solar,
        R_scale=R_thin,
        Z_scale=Z_thin,
        R_smooth=Rs_thin,
    )
    logp_thin += vol_factor

    # Thick disk component
    logp_thick = logn_disk(
        R,
        Z,
        R_solar=R_solar,
        Z_solar=Z_solar,
        R_scale=R_thick,
        Z_scale=Z_thick,
        R_smooth=Rs_thick,
    )
    logp_thick += vol_factor + np.log(f_thick)

    # Halo component
    logp_halo = logn_halo(
        R,
        Z,
        R_solar=R_solar,
        Z_solar=Z_solar,
        R_smooth=Rs_halo,
        eta=eta_halo,
        q_ctr=q_halo_ctr,
        q_inf=q_halo_inf,
        r_q=r_q_halo,
    )
    logp_halo += vol_factor + np.log(f_halo)

    # Combined number density prior
    logp = _logsumexp3(logp_thin, logp_thick, logp_halo)

    # Component tracking
    components = {"number_density": [logp_thin, logp_thick, logp_halo]}

    # Apply metallicity and age priors if labels provided
    if labels is not None:
        # Component membership log-probabilities (from the density prior only).
        lnprior_thin = logp_thin - logp
        lnprior_thick = logp_thick - logp
        lnprior_halo = logp_halo - logp

        # Metallicity prior: Gaussian logpdf per component, inlined.
        # logp_feh(feh, mu, sigma) = -0.5*((feh-mu)^2/sigma^2 + log(2*pi*sigma^2))
        if "feh" in labels.dtype.names:
            try:
                feh = labels["feh"]
                feh_lnp_thin = (
                    -0.5
                    * (
                        (feh - feh_thin) ** 2 / feh_thin_sigma**2
                        + np.log(2.0 * np.pi * feh_thin_sigma**2)
                    )
                    + lnprior_thin
                )
                feh_lnp_thick = (
                    -0.5
                    * (
                        (feh - feh_thick) ** 2 / feh_thick_sigma**2
                        + np.log(2.0 * np.pi * feh_thick_sigma**2)
                    )
                    + lnprior_thick
                )
                feh_lnp_halo = (
                    -0.5
                    * (
                        (feh - feh_halo) ** 2 / feh_halo_sigma**2
                        + np.log(2.0 * np.pi * feh_halo_sigma**2)
                    )
                    + lnprior_halo
                )
                feh_lnp = _logsumexp3(feh_lnp_thin, feh_lnp_thick, feh_lnp_halo)
                logp += feh_lnp
                components["feh"] = [feh_lnp_thin, feh_lnp_thick, feh_lnp_halo]
            except (KeyError, IndexError, ValueError) as e:
                warnings.warn(
                    f"Metallicity prior computation failed: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )

        # Age prior: truncated normal per component from the age-metallicity
        # relation (same computation as logp_age_from_feh, inlined).
        if "loga" in labels.dtype.names:
            try:
                age = 10.0 ** labels["loga"] / 1e9  # Convert log(age) to Gyr
                age_params = []
                for fm, lnp_comp in [
                    (feh_thin, lnprior_thin),
                    (feh_thick, lnprior_thick),
                    (feh_halo, lnprior_halo),
                ]:
                    age_mean = (max_age - min_age) / (
                        1.0 + np.exp((fm - feh_age_ctr) / feh_age_scale)
                    ) + min_age
                    age_sigma = np.clip(
                        (max_age - age_mean) / nsigma_from_max_age,
                        min_sigma,
                        max_sigma,
                    )
                    age_params.append(
                        truncnorm_logpdf(
                            age,
                            (min_age - age_mean) / age_sigma,
                            (max_age - age_mean) / age_sigma,
                            loc=age_mean,
                            scale=age_sigma,
                        )
                        + lnp_comp
                    )
                age_lnp_thin, age_lnp_thick, age_lnp_halo = age_params
                age_lnp = _logsumexp3(age_lnp_thin, age_lnp_thick, age_lnp_halo)
                logp += age_lnp
                components["age"] = [age_lnp_thin, age_lnp_thick, age_lnp_halo]
            except (KeyError, IndexError, ValueError) as e:
                warnings.warn(
                    f"Age prior computation failed: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )

    if return_components:
        return logp, components
    return logp