Source code for brutus.analysis.individual

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Individual star analysis using grid-based Bayesian inference.

This module provides the BruteForce class for fitting stellar parameters,
distances, and extinction using pre-computed model grids. It performs
brute-force Bayesian inference over the entire grid to derive posterior
distributions for stellar properties.

The fitting procedure uses gradient-based optimization to find the maximum
likelihood extinction and distance for each grid point, then computes
Bayesian posteriors incorporating priors on stellar parameters, Galactic
structure, dust maps, and astrometry.

Classes
-------
BruteForce : Grid-based stellar parameter estimation
    Performs Bayesian inference over pre-computed stellar model grids
    to estimate stellar parameters, distances, and extinction for
    individual stars. Provides methods for computing log-likelihoods
    and log-posteriors over the grid.

Functions
---------
_optimize_fit_mag : Optimize extinction in magnitude space
_optimize_fit_flux : Optimize extinction in flux space
_get_sed_mle : Compute maximum likelihood SED parameters

See Also
--------
brutus.core.StarGrid : Pre-computed stellar model grids
brutus.priors : Prior probability functions
brutus.utils.photometry : Photometry utilities

Notes
-----
The module uses numba JIT compilation for performance-critical functions.
The fitting algorithm alternates between optimizing in magnitude and flux
space for numerical stability.

Examples
--------
Basic usage with grid-based fitting:

>>> from brutus.data import load_models
>>> from brutus.core import StarGrid
>>> from brutus.analysis import BruteForce
>>>
>>> # Load pre-computed grid
>>> models, labels, params = load_models('grid_mist_v9.h5')
>>> grid = StarGrid(models, labels, params)
>>>
>>> # Initialize fitter
>>> fitter = BruteForce(grid)
>>>
>>> # Fit photometry with parallax
>>> results = fitter.fit(
...     phot_data, phot_err, phot_mask,
...     data_labels, save_file='results.h5',
...     parallax=parallax, parallax_err=parallax_err,
...     data_coords=coords
... )
"""

import sys
import time
import warnings
from functools import partial
from math import log

import h5py
import numpy as np
from numba import jit
from scipy.special import logsumexp

# Import StarGrid and SED utilities
from ..core import StarGrid
from ..core.sed_utils import _get_seds
from ..priors.astrometric import logp_parallax, logp_parallax_scale
from ..priors.extinction import logp_extinction
from ..priors.galactic import logp_galactic_structure

# Import refactored prior functions
from ..priors.stellar import logp_imf, logp_ps1_luminosity_function

# Import utility functions
from ..utils.math import inverse3 as _inverse3
from ..utils.photometry import magnitude
from ..utils.sampling import sample_multivariate_normal

# Public API of this module: only the fitter class is exported. The
# underscore-prefixed jitted helpers defined below are internal.
__all__ = ["BruteForce"]


# ============================================================================
# Numerical constants
# ============================================================================

# Finite stand-in for log(0), i.e. negative infinity, in log-probability
# calculations. A very large negative finite value is used instead of -np.inf
# so that arithmetic on it stays finite (e.g. inf - inf would give NaN).
LOG_ZERO = -1e300

# Smallest scale factor treated as valid; flooring at this value prevents
# division by zero or log(0) on the scale factor.
MIN_SCALE = 1e-20


# ============================================================================
# Grid-based optimization functions
# ============================================================================


@jit(nopython=True, cache=True)
def _optimize_fit_mag(
    data,
    tot_var,
    models,
    rvecs,
    drvecs,
    av,
    rv,
    mag_coeffs,
    resid,
    stepsize,
    mags,
    mags_var,
    avlim=(0.0, 20.0),
    av_gauss=(0.0, 1e6),
    rvlim=(1.0, 8.0),
    rv_gauss=(3.32, 0.18),
    tol=0.05,
    init_thresh=5e-3,
):
    """
    Iterate `(Av, Rv)` updates in **magnitude** space until convergence.

    For every model, each iteration takes a damped step in Av (at fixed Rv)
    and then in Rv (at fixed Av), clamps both to their bounds and updates
    the residuals and reddening vectors in place. Once the steps of all
    well-fitting models are small enough, the final flux-space quantities
    are recomputed with `_get_sed_mle`.

    Parameters
    ----------
    data : `~numpy.ndarray` of shape `(Nfilt)`
        Observed data values. Not read by the iteration itself; only
        forwarded to `_get_sed_mle`.

    tot_var : `~numpy.ndarray` of shape `(Nfilt,)`
        Associated (Normal) variance on the observed flux values.
        1D since the variance is a property of the data, not the models.
        Only forwarded to `_get_sed_mle`.

    models : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Model predictions. Only the shape is used here; the returned
        models are recomputed by `_get_sed_mle`.

    rvecs : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Associated model reddening vectors. **Updated in place** as Rv
        changes.

    drvecs : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Associated differential model reddening vectors (read only).

    av : `~numpy.ndarray` of shape `(Nmodel,)`
        Av values of the models. **Updated in place.**

    rv : `~numpy.ndarray` of shape `(Nmodel,)`
        Rv values of the models. **Updated in place.**

    mag_coeffs : `~numpy.ndarray` of shape `(Nmodel, Nfilt, 3)`
        Magnitude coefficients. Only forwarded to `_get_sed_mle`.

    resid : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Initial residuals. They are weighted by `1 / mags_var`, so they are
        presumably magnitude residuals -- confirm against the caller.
        **Updated in place** during the iteration and finally overwritten
        by `_get_sed_mle`.

    stepsize : `~numpy.ndarray` of shape `(Nmodel)`
        Per-model multiplier applied to each computed Av and Rv step.

    mags : `~numpy.ndarray` of shape `(Nfilt)`
        Data in magnitudes. Accepted for interface reasons; not read by
        this function.

    mags_var : `~numpy.ndarray` of shape `(Nfilt,)`
        Data variance in magnitudes.

    avlim : 2-tuple, optional
        Bounds on A(V). Default is `(0., 20.)`.

    av_gauss : 2-tuple, optional
        The mean and standard deviation of a Gaussian prior on A(V).
        Default is `(0., 1e6)` (i.e. flat).

    rvlim : 2-tuple, optional
        Bounds on R(V). Default is `(1., 8.)`.

    rv_gauss : 2-tuple, optional
        The mean and standard deviation of a Gaussian prior on R(V).
        Default is `(3.32, 0.18)`.

    tol : float, optional
        Convergence threshold. Iteration stops once the largest absolute
        step `|dAv|` or `|dRv|` among the well-fitting models is below
        `tol`. Default is `0.05`.

    init_thresh : float, optional
        Likelihood-ratio cut that selects the models taking part in the
        convergence test: a model counts when its log-weight exceeds
        `max(logwt) + log(init_thresh)`. Default is `5e-3`.

    Returns
    -------
    models : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Model flux densities for each model and filter, as returned by
        `_get_sed_mle`.

    rvecs : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Reddening vectors at the final Av and Rv, as returned by
        `_get_sed_mle`.

    drvecs : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Differential reddening vectors at the final Av and Rv, as returned
        by `_get_sed_mle`.

    scale : `~numpy.ndarray` of shape `(Nmodel)`
        Scale-factors (related to distance as s = 1/d^2).

    av : `~numpy.ndarray` of shape `(Nmodel)`
        Optimized A(V) values (the input array, modified in place).

    rv : `~numpy.ndarray` of shape `(Nmodel)`
        Optimized R(V) values (the input array, modified in place).

    icov_sar : `~numpy.ndarray` of shape `(Nmodel, 3, 3)`
        Inverse covariance matrices over `(scale, av, rv)`.

    resid : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Final residuals between data and optimized models.

    See Also
    --------
    _optimize_fit_flux : Flux-space optimization (single iteration)
    _get_sed_mle : MLE computation for SED parameters

    Notes
    -----
    Each Av (or Rv) step comes from a 2x2 linear solve in which the scale
    term is eliminated jointly with the extinction parameter; the Gaussian
    priors enter both the curvature and the residual terms.

    Convergence is determined by the largest *absolute* change in Av and
    Rv over all well-fitting models (see `init_thresh`), not by a
    fractional change.

    The `while` loop has no iteration cap: it only terminates when the
    convergence criterion is met.
    """
    Nmodel, Nfilt = models.shape

    avmin, avmax = avlim
    rvmin, rvmax = rvlim

    Av_mean, Av_std = av_gauss
    Rv_mean, Rv_std = rv_gauss
    Av_varinv, Rv_varinv = 1.0 / Av_std**2, 1.0 / Rv_std**2

    log_init_thresh = log(init_thresh)

    # Precompute reciprocal of magnitude variance (1D, same for all models).
    inv_mags_var = 1.0 / mags_var

    # In magnitude space, we can solve a linear system
    # explicitly for `(s_ML, Av_ML, r_ML=Av_ML*Rv_ML)`. We opt to
    # solve for Av and Rv in turn so we can impose priors and bounds
    # on both quantities as well as additional regularization.

    # Compute constants (they depend only on `drvecs` and the variances,
    # neither of which changes inside the main loop).
    s_den, rp_den = np.zeros(Nmodel), np.zeros(Nmodel)
    srp_mix = np.zeros(Nmodel)
    for i in range(Nmodel):
        for j in range(Nfilt):
            s_den[i] += inv_mags_var[j]
            rp_den[i] += drvecs[i][j] * drvecs[i][j] * inv_mags_var[j]
            srp_mix[i] += drvecs[i][j] * inv_mags_var[j]

    # Main loop.
    # Uses loop fusion to reduce from 5 to 3 inner loops per model,
    # cutting memory traffic by ~40% (verified numerically identical).
    logwt = np.zeros(Nmodel)
    dav, drv = np.zeros(Nmodel), np.zeros(Nmodel)
    while True:
        for i in range(Nmodel):
            # --- Fused pass 1: Av solve ---
            # Read resid[i][j] and rvecs[i][j] once to compute all Av terms.
            # The prior contributes Av_varinv to the curvature and
            # (Av_mean - av) * Av_varinv to the residual term.
            a_den_i = Av_varinv
            sa_mix_i = 0.0
            resid_s_i = 0.0
            resid_a_i = (Av_mean - av[i]) * Av_varinv
            for j in range(Nfilt):
                iv = inv_mags_var[j]
                rv_j = rvecs[i][j]
                res_j = resid[i][j]
                riv = rv_j * iv
                a_den_i += rv_j * riv
                sa_mix_i += riv
                resid_s_i += res_j * iv
                resid_a_i += res_j * riv
            # Compute ML solution for Delta_Av (2x2 solve with the scale
            # term eliminated), then damp it by the per-model stepsize.
            sa_idet = 1.0 / (s_den[i] * a_den_i - sa_mix_i * sa_mix_i)
            dav[i] = sa_idet * (s_den[i] * resid_a_i - sa_mix_i * resid_s_i)
            dav[i] = dav[i] * stepsize[i]

            # Prevent Av from sliding off the provided bounds.
            if dav[i] < avmin - av[i]:
                dav[i] = avmin - av[i]
            if dav[i] > avmax - av[i]:
                dav[i] = avmax - av[i]

            # Increment to new Av.
            av[i] = av[i] + dav[i]

            # --- Fused pass 2: Av residual update + Rv solve ---
            # Update residuals from Av change and immediately compute Rv terms.
            resid_s_i = 0.0
            resid_r_i = 0.0
            for j in range(Nfilt):
                resid[i][j] = resid[i][j] - dav[i] * rvecs[i][j]
                iv = inv_mags_var[j]
                resid_s_i += resid[i][j] * iv
                resid_r_i += resid[i][j] * drvecs[i][j] * iv
            # Derive Rv partial derivatives (the Rv derivative carries a
            # factor of the current Av).
            r_den_i = rp_den[i] * av[i] * av[i] + Rv_varinv
            sr_mix_i = srp_mix[i] * av[i]
            resid_r_i = resid_r_i * av[i] + (Rv_mean - rv[i]) * Rv_varinv
            # Compute ML solution for Delta_Rv.
            sr_idet = 1.0 / (s_den[i] * r_den_i - sr_mix_i * sr_mix_i)
            drv[i] = sr_idet * (s_den[i] * resid_r_i - sr_mix_i * resid_s_i)
            drv[i] = drv[i] * stepsize[i]

            # Prevent Rv from sliding off the provided bounds.
            if drv[i] < rvmin - rv[i]:
                drv[i] = rvmin - rv[i]
            if drv[i] > rvmax - rv[i]:
                drv[i] = rvmax - rv[i]

            # Increment to new Rv.
            rv[i] = rv[i] + drv[i]

            # --- Fused pass 3: Rv residual/rvecs update + chi2 ---
            av_drv = av[i] * drv[i]
            chi2_i = 0.0
            for j in range(Nfilt):
                resid[i][j] = resid[i][j] - av_drv * drvecs[i][j]
                rvecs[i][j] = rvecs[i][j] + drv[i] * drvecs[i][j]
                chi2_i += resid[i][j] * resid[i][j] * inv_mags_var[j]
            logwt[i] = -0.5 * chi2_i

        # Find current best-fit model.
        max_logwt = -1e300
        for i in range(Nmodel):
            if logwt[i] > max_logwt:
                max_logwt = logwt[i]

        # Find the largest absolute (Av, Rv) step to determine convergence.
        err = -1e300
        for i in range(Nmodel):
            # Only include models that are "reasonably good" fits.
            if logwt[i] > max_logwt + log_init_thresh:
                dav_err, drv_err = abs(dav[i]), abs(drv[i])
                if dav_err > err:
                    err = dav_err
                if drv_err > err:
                    err = drv_err

        # Check convergence.
        if err < tol:
            break

    # Get MLE models and associated quantities.
    (models, rvecs, drvecs, scale, icov_sar, resid) = _get_sed_mle(
        data, tot_var, resid, mag_coeffs, av, rv, av_gauss=av_gauss, rv_gauss=rv_gauss
    )

    return models, rvecs, drvecs, scale, av, rv, icov_sar, resid


@jit(nopython=True, cache=True)
def _optimize_fit_flux(
    data,
    tot_var,
    models,
    rvecs,
    drvecs,
    av,
    rv,
    mag_coeffs,
    resid,
    stepsize,
    avlim=(0.0, 20.0),
    av_gauss=(0.0, 1e6),
    rvlim=(1.0, 8.0),
    rv_gauss=(3.32, 0.18),
):
    """
    Optimize distance and reddening using flux densities gradient (single update).

    This executes **only one** `(Av, Rv)` update using the gradient
    in **flux densities**, then recomputes the models, scale factors,
    residuals and precision matrices with `_get_sed_mle`.

    Parameters
    ----------
    data : `~numpy.ndarray` of shape `(Nfilt)`
        Observed data values. Only forwarded to `_get_sed_mle`.

    tot_var : `~numpy.ndarray` of shape `(Nfilt,)`
        Associated (Normal) variance on the observed flux values.
        1D since the variance is a property of the data, not the models.

    models : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Model predictions. Only the shape is used here; the returned
        models are recomputed by `_get_sed_mle`.

    rvecs : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Associated model reddening vectors (read only).

    drvecs : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Associated differential model reddening vectors (read only).

    av : `~numpy.ndarray` of shape `(Nmodel,)`
        Av values of the models. **Updated in place.**

    rv : `~numpy.ndarray` of shape `(Nmodel,)`
        Rv values of the models. **Updated in place.**

    mag_coeffs : `~numpy.ndarray` of shape `(Nmodel, Nfilt, 3)`
        Magnitude coefficients used to compute reddened photometry for a given
        model.

    resid : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Residuals between the data and models. Read to compute the step,
        then overwritten by `_get_sed_mle`.

    stepsize : `~numpy.ndarray`
        The stepsize (in units of the computed gradient). Indexed per
        model, so it needs at least `Nmodel` entries.

    avlim : 2-tuple, optional
        The lower and upper bound where the reddened photometry is reliable.
        Default is `(0., 20.)`.

    av_gauss : 2-tuple, optional
        The mean and standard deviation of the Gaussian prior that is placed
        on A(V). The default is `(0., 1e6)`, which is designed to be
        essentially flat over `avlim`.

    rvlim : 2-tuple, optional
        The lower and upper bound where the reddening vector shape changes
        are reliable. Default is `(1., 8.)`.

    rv_gauss : 2-tuple, optional
        The mean and standard deviation of the Gaussian prior that is placed
        on R(V). The default is `(3.32, 0.18)` based on the results from
        Schlafly et al. (2016).

    Returns
    -------
    models_new : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        New model predictions. Always returned in flux densities.

    rvecs_new : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        New reddening vectors. Always returned in flux densities.

    drvecs_new : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        New differential reddening vectors. Always returned in flux densities.

    scale : `~numpy.ndarray` of shape `(Nmodel)`
        The best-fit scale factor.

    Av : `~numpy.ndarray` of shape `(Nmodel)`
        The updated reddening (the input `av` array, modified in place).

    Rv : `~numpy.ndarray` of shape `(Nmodel)`
        The updated reddening shapes (the input `rv` array, modified in
        place).

    icov_sar : `~numpy.ndarray` of shape `(Nmodel, 3, 3)`
        The precision (inverse covariance) matrices expanded around
        `(s_ML, Av_ML, Rv_ML)`.

    resid : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Residuals between the data and models.

    See Also
    --------
    _optimize_fit_mag : Magnitude-space optimization (iterative)
    _get_sed_mle : MLE computation for SED parameters

    Notes
    -----
    This function performs a single update step in flux space. All eight
    return values are always returned.

    The Av and Rv steps are both computed from the *same* incoming
    residuals and reddening vectors (one fused pass per model), i.e. the
    Rv step does not see the Av update made in this call.

    Unlike magnitude-space fitting, flux-space optimization uses a
    Taylor expansion which can be less numerically stable for large
    extinctions but is faster per iteration.
    """
    Nmodel, Nfilt = models.shape

    avmin, avmax = avlim
    rvmin, rvmax = rvlim

    Av_mean, Av_std = av_gauss
    Rv_mean, Rv_std = rv_gauss
    Av_varinv, Rv_varinv = 1.0 / Av_std**2, 1.0 / Rv_std**2

    # Precompute reciprocal of flux variance (1D, same for all models).
    inv_tot_var = 1.0 / tot_var

    # In flux density space, we can solve the linear system
    # implicitly for `(s_ML, Av_ML, Rv_ML)`. However, the solution
    # is not necessarily as numerically stable as one might hope
    # due to the nature of our Taylor expansion in flux.
    # Instead, it is easier to iterate in `(dAv, dRv)` from
    # a good guess for `(s_ML, Av_ML, Rv_ML)`. We opt to solve both
    # independently at fixed `(Av, Rv)` to avoid recomputing models.

    # Fused: compute Av and Rv sums in a single pass over filters,
    # since both read the same resid[i][j], rvecs[i][j], drvecs[i][j].
    for i in range(Nmodel):
        # Start each accumulator from its Gaussian-prior contribution.
        a_num_i = (Av_mean - av[i]) * Av_varinv
        a_den_i = Av_varinv
        r_num_i = (Rv_mean - rv[i]) * Rv_varinv
        r_den_i = Rv_varinv
        for j in range(Nfilt):
            iv = inv_tot_var[j]
            rv_j = rvecs[i][j]
            drv_j = drvecs[i][j]
            res_j = resid[i][j]
            a_num_i += rv_j * res_j * iv
            a_den_i += rv_j * rv_j * iv
            r_num_i += drv_j * res_j * iv
            r_den_i += drv_j * drv_j * iv
        # Step = numerator / denominator, damped by the per-model stepsize.
        dav_i = a_num_i / a_den_i * stepsize[i]
        drv_i = r_num_i / r_den_i * stepsize[i]

        # Prevent Av from sliding off the provided bounds.
        if dav_i < avmin - av[i]:
            dav_i = avmin - av[i]
        if dav_i > avmax - av[i]:
            dav_i = avmax - av[i]

        # Increment to new Av.
        av[i] += dav_i

        # Prevent Rv from sliding off the provided bounds.
        if drv_i < rvmin - rv[i]:
            drv_i = rvmin - rv[i]
        if drv_i > rvmax - rv[i]:
            drv_i = rvmax - rv[i]

        # Increment to new Rv.
        rv[i] += drv_i

    # Get MLE models and associated quantities.
    (models, rvecs, drvecs, scale, icov_sar, resid) = _get_sed_mle(
        data, tot_var, resid, mag_coeffs, av, rv, av_gauss=av_gauss, rv_gauss=rv_gauss
    )

    return models, rvecs, drvecs, scale, av, rv, icov_sar, resid


@jit(nopython=True, cache=True)
def _get_sed_mle(
    data, tot_var, resid, mag_coeffs, av, rv, av_gauss=(0.0, 1e6), rv_gauss=(3.32, 0.18)
):
    """
    Recompute model SEDs, MLE scale factors, residuals and precision matrices.

    Given current `(Av, Rv)` values for each model, this function
    regenerates the reddened SEDs from the magnitude polynomial
    coefficients, computes the maximum-likelihood scale factor
    (equivalently distance) for each model, forms the residuals
    between the scaled model and data, and assembles the 3x3 Fisher
    information (precision) matrix in `(scale, Av, Rv)` space.

    Parameters
    ----------
    data : `~numpy.ndarray` of shape `(Nfilt)`
        Observed data values.

    tot_var : `~numpy.ndarray` of shape `(Nfilt,)`
        Associated (Normal) variance on the observed flux values.
        1D since the variance is a property of the data, not the models.

    resid : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Output buffer for the residuals. Its incoming values are not read;
        every element is **overwritten in place** with
        `data - scale * model`.

    mag_coeffs : `~numpy.ndarray` of shape `(Nmodel, Nfilt, 3)`
        Magnitude coefficients used to compute reddened photometry for a given
        model.

    av : `~numpy.ndarray` of shape `(Nmodel,)`
        Av values of the models.

    rv : `~numpy.ndarray` of shape `(Nmodel,)`
        Rv values of the models.

    av_gauss : 2-tuple, optional
        The mean and standard deviation of the Gaussian prior that is placed
        on A(V). Only the standard deviation is used here (it adds
        `1 / std**2` to the Av precision). The default is `(0., 1e6)`,
        which is designed to be essentially flat over `avlim`.

    rv_gauss : 2-tuple, optional
        The mean and standard deviation of the Gaussian prior that is placed
        on R(V). Only the standard deviation is used here. The default is
        `(3.32, 0.18)` based on the results from Schlafly et al. (2016).

    Returns
    -------
    models_new : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        New model predictions in flux densities, multiplied by the
        best-fit scale factor.

    rvecs_new : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        New reddening vectors in flux densities, multiplied by the
        best-fit scale factor.

    drvecs_new : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        New differential reddening vectors in flux densities, multiplied
        by the best-fit scale factor.

    scale : `~numpy.ndarray` of shape `(Nmodel)`
        The best-fit scale factor, floored at `1e-20`.

    icov_sar : `~numpy.ndarray` of shape `(Nmodel, 3, 3)`
        The symmetric precision (inverse covariance) matrices expanded
        around `(s_ML, Av_ML, Rv_ML)`.

    resid : `~numpy.ndarray` of shape `(Nmodel, Nfilt)`
        Residuals between the data and the scaled models (the input
        buffer).

    """
    # Only the prior widths enter the precision matrix; the prior means
    # play no role in this function.
    Av_varinv = 1.0 / av_gauss[1] ** 2
    Rv_varinv = 1.0 / rv_gauss[1] ** 2

    # Recompute models with new Rv.
    models, rvecs, drvecs = _get_seds(mag_coeffs, av, rv, return_flux=True)
    Nmodel, Nfilt = models.shape

    # Precompute reciprocal of flux variance (1D, same for all models).
    inv_tot_var = 1.0 / tot_var

    # Derive scale-factors (`scale`) between data and models.
    s_num, s_den, scale = np.zeros(Nmodel), np.zeros(Nmodel), np.zeros(Nmodel)
    for i in range(Nmodel):
        for j in range(Nfilt):
            s_num[i] += models[i][j] * data[j] * inv_tot_var[j]
            s_den[i] += models[i][j] * models[i][j] * inv_tot_var[j]
        scale[i] = s_num[i] / s_den[i]  # MLE scalefactor
        # Floor at 1e-20 (the same value as the module-level MIN_SCALE) so
        # the scale stays strictly positive.
        if scale[i] <= 1e-20:
            scale[i] = 1e-20

    # Derive reddening terms.
    sr_mix, sa_mix = np.zeros(Nmodel), np.zeros(Nmodel)
    a_den, r_den = np.zeros(Nmodel), np.zeros(Nmodel)
    ar_mix = np.zeros(Nmodel)
    for i in range(Nmodel):
        for j in range(Nfilt):
            # Rescale models.
            models[i][j] = models[i][j] * scale[i]

            # Compute residuals.
            resid[i][j] = data[j] - models[i][j]

            # Derive scale cross-terms. These deliberately use the scaled
            # model together with the not-yet-scaled reddening vectors.
            # Note: drvecs = dR/dRv * fac * flux, so the true
            # ∂f/∂Rv = Av * drvecs. We must include Av for correct
            # Fisher information in R(V)-related terms.
            sr_mix[i] += models[i][j] * drvecs[i][j] * av[i] * inv_tot_var[j]
            sa_mix[i] += models[i][j] * rvecs[i][j] * inv_tot_var[j]

            # Rescale reddening quantities.
            rvecs[i][j] = rvecs[i][j] * scale[i]
            drvecs[i][j] = drvecs[i][j] * scale[i]

            # Derive reddening (cross-)terms
            # ar_mix: Gauss-Newton cross-term (∂f/∂Av)(∂f/∂Rv) / var
            # where ∂f/∂Av = rvecs (already scaled) and
            # ∂f/∂Rv = Av * drvecs (already scaled).
            ar_mix[i] += rvecs[i][j] * drvecs[i][j] * av[i] * inv_tot_var[j]
            a_den[i] += rvecs[i][j] * rvecs[i][j] * inv_tot_var[j]
            r_den[i] += (drvecs[i][j] * av[i]) ** 2 * inv_tot_var[j]

        # Add in priors.
        a_den[i] += Av_varinv
        r_den[i] += Rv_varinv

    # Construct precision matrices (inverse covariances).
    icov_sar = np.zeros((Nmodel, 3, 3))
    for i in range(Nmodel):
        icov_sar[i][0][0] = s_den[i]  # scale
        icov_sar[i][1][1] = a_den[i]  # Av
        icov_sar[i][2][2] = r_den[i]  # Rv
        icov_sar[i][0][1] = sa_mix[i]  # scale-Av cross-term
        icov_sar[i][1][0] = sa_mix[i]  # scale-Av cross-term
        icov_sar[i][0][2] = sr_mix[i]  # scale-Rv cross-term
        icov_sar[i][2][0] = sr_mix[i]  # scale-Rv cross-term
        icov_sar[i][1][2] = ar_mix[i]  # Av-Rv cross-term
        icov_sar[i][2][1] = ar_mix[i]  # Av-Rv cross-term

    return models, rvecs, drvecs, scale, icov_sar, resid


[docs] class BruteForce: """ Bayesian parameter estimation for individual stars using grid-based models. This class performs brute-force fitting over a pre-computed stellar model grid to estimate stellar parameters, distances, and extinction. It uses the StarGrid infrastructure for model management and applies Bayesian priors for robust inference. Parameters ---------- star_grid : StarGrid Pre-loaded stellar model grid for SED generation. verbose : bool, optional Whether to print initialization information. Default is True. Attributes ---------- star_grid : StarGrid The underlying stellar model grid. models : numpy.ndarray Magnitude coefficients from the grid. models_labels : structured numpy.ndarray Labels for each model in the grid. labels_mask : dict Mask indicating which labels are grid parameters (True) vs predictions (False). See Also -------- brutus.core.StarGrid : Stellar model grid infrastructure brutus.data.load_models : Load pre-computed grids loglike_grid : Compute log-likelihoods logpost_grid : Compute log-posteriors Notes ----- The BruteForce fitter uses a two-stage approach: 1. **Likelihood computation** (`loglike_grid`): Optimizes distance and extinction for each grid point to find maximum likelihood 2. **Posterior computation** (`logpost_grid`): Integrates over distance and extinction uncertainty using Monte Carlo, applying priors for Galactic structure, dust maps, and astrometry The fitter automatically handles: - Grid parameter vs. 
prediction distinction - Age weighting for proper sampling - Grid spacing corrections - Parallax constraints - Galactic structure priors - Dust map priors Examples -------- Basic usage with a pre-loaded grid: >>> from brutus.core import StarGrid >>> from brutus.analysis.individual import BruteForce >>> from brutus.data import load_models >>> >>> # Load grid >>> models, labels, params = load_models('grid_mist_v9.h5') >>> grid = StarGrid(models, labels, params) >>> >>> # Initialize fitter >>> fitter = BruteForce(grid) >>> >>> # Fit data >>> results = fitter.fit( ... data, data_err, data_mask, ... data_labels, save_file='results.h5', ... parallax=parallax_data, ... data_coords=coordinates ... ) """
[docs] def __init__(self, star_grid, verbose=True): """Initialize BruteForce with a StarGrid instance.""" if not isinstance(star_grid, StarGrid): raise TypeError("star_grid must be a StarGrid instance") self.star_grid = star_grid self.models = star_grid.models self.models_labels = star_grid.labels # Generate labels mask automatically self.labels_mask = self._generate_labels_mask() if verbose: n_grid = sum(1 for m in self.labels_mask.values() if m) n_pred = sum(1 for m in self.labels_mask.values() if not m) grid_params = [lbl for lbl, m in self.labels_mask.items() if m] pred_params = [lbl for lbl, m in self.labels_mask.items() if not m] print(f"BruteForce initialized with {self.nmodels:,} models") print(f" Grid parameters ({n_grid}): {', '.join(grid_params)}") if n_pred > 0: preview = ", ".join(pred_params[:3]) if len(pred_params) > 3: preview += f", ... ({len(pred_params)-3} more)" print(f" Predictions ({n_pred}): {preview}")
@property def nmodels(self): """Number of models in the grid.""" return self.models.shape[0] @property def nfilters(self): """Number of filters in the grid.""" return self.models.shape[1] def _generate_labels_mask(self): """ Generate labels mask from StarGrid structure. Creates a dictionary mapping label names to boolean values indicating whether each label is a grid parameter (True) or a derived prediction (False). Grid parameters are the dimensions used to construct the grid, while predictions are stellar properties interpolated from the grid. Returns ------- labels_mask : dict Dictionary where keys are label names and values are True for grid parameters (e.g., mini, eep, feh) or False for predictions (e.g., loga, logt, logg). Notes ----- This distinction is important for applying grid spacing corrections (only to grid parameters) and for understanding which parameters define the grid structure vs. which are interpolated outputs. """ labels_mask = {} # Grid parameters (used to compute the grid) for label in self.star_grid.label_names: labels_mask[label] = True # Predictions (derived from grid) if self.star_grid.param_names: for param in self.star_grid.param_names: labels_mask[param] = False return labels_mask
[docs] def get_sed_grid(self, indices=None, av=None, rv=None, return_flux=False): r""" Compute SEDs for multiple grid points simultaneously. This is the grid-based batch computation method, distinct from StarGrid.get_seds() which handles single star synthesis. Parameters ---------- indices : array-like, optional Grid indices to compute SEDs for. If None, uses all models. av : array-like or float, optional A(V) values for each model. If None, defaults to 0. rv : array-like or float, optional R(V) values for each model. If None, defaults to 3.3. return_flux : bool, optional If True, return fluxes instead of magnitudes. Default is False. Returns ------- seds : numpy.ndarray of shape (Nmodels, Nbands) Computed SEDs. rvecs : numpy.ndarray of shape (Nmodels, Nbands) Reddening vectors. drvecs : numpy.ndarray of shape (Nmodels, Nbands) Differential reddening vectors with respect to Rv. See Also -------- brutus.core.sed_utils._get_seds : Underlying SED computation StarGrid.get_seds : Single star SED generation Notes ----- This method performs batch SED computation for multiple grid points simultaneously, which is more efficient than calling `StarGrid.get_seds()` repeatedly. The scale factor relates to distance as :math:`s = 1/d^2` where d is distance in parsecs. The reddening is applied as: .. math:: m(\\lambda) = m_0(\\lambda) + A_V \\cdot [r_0(\\lambda) + R_V \\cdot dr(\\lambda)] where :math:`r_0` and :math:`dr` are the reddening vector components. """ if indices is not None: mag_coeffs = self.models[indices] else: mag_coeffs = self.models # Ensure av and rv are arrays n_models = len(mag_coeffs) if av is None: av = np.zeros(n_models) elif np.isscalar(av): av = np.full(n_models, av) else: av = np.asarray(av) if rv is None: rv = np.full(n_models, 3.3) elif np.isscalar(rv): rv = np.full(n_models, rv) else: rv = np.asarray(rv) return _get_seds(mag_coeffs, av, rv, return_flux=return_flux)
def _setup(
    self,
    data,
    data_err,
    data_mask,
    data_labels=None,
    phot_offsets=None,
    parallax=None,
    parallax_err=None,
    av_gauss=None,
    lnprior=None,
    wt_thresh=1e-3,
    cdf_thresh=2e-3,
    apply_agewt=True,
    apply_grad=True,
    lngalprior=None,
    lndustprior=None,
    dustfile=None,
    data_coords=None,
    ltol_subthresh=1e-2,
    logl_initthresh=5e-3,
    mag_max=50.0,
    merr_max=0.25,
    rstate=None,
    R_solar=8.2,
    Z_solar=0.025,
):
    """
    Pre-process data and initialize priors for fitting.

    This internal method prepares photometric data for fitting by applying
    quality cuts, photometric offsets, and initializing appropriate prior
    distributions. None of the inputs are modified in place.

    Parameters
    ----------
    data : numpy.ndarray
        Photometric flux densities
    data_err : numpy.ndarray
        Photometric errors
    data_mask : numpy.ndarray
        Initial data quality mask. Any dtype is accepted; it is
        interpreted as boolean (non-zero means valid).
    phot_offsets : numpy.ndarray, optional
        Multiplicative offsets applied to both `data` and `data_err`.
    lnprior : numpy.ndarray, optional
        Log-prior per model. When omitted, an IMF prior is used if the
        model labels contain ``"mini"``, a PS1 luminosity-function prior
        if they contain ``"Mr"``, and a flat prior otherwise. A supplied
        array is copied before the age and grid-spacing weights are added.
    apply_agewt : bool
        Whether to apply age weighting to priors
    apply_grad : bool
        Whether to apply grid spacing corrections
    lngalprior : callable, optional
        Galactic prior. When omitted, `data_coords` must be given and
        ``logp_galactic_structure`` bound to `R_solar`/`Z_solar` is used.
    lndustprior : callable, optional
        Dust prior. Defaults to ``logp_extinction`` when `dustfile` is set.
    mag_max, merr_max : float
        Bands with magnitude >= `mag_max` or magnitude error >= `merr_max`
        are masked out.
    Other parameters
        Accepted for signature parity with fit(); not used in this method.

    Returns
    -------
    tuple
        Processed (data, data_err, data_mask, lnprior, lngalprior, lndustprior)

    Raises
    ------
    ValueError
        If neither `lngalprior` nor `data_coords` is provided.

    Notes
    -----
    This method performs several important setup tasks:

    1. Applies photometric offsets if provided
    2. Filters data based on magnitude and error limits
    3. Initializes stellar priors (IMF or luminosity function)
    4. Applies age gradient weighting for proper sampling
    5. Applies grid spacing corrections
    6. Sets up Galactic structure and dust priors
    """
    # Apply photometric offsets if provided (out of place: new arrays)
    if phot_offsets is not None:
        data = data * phot_offsets
        data_err = data_err * phot_offsets

    # Apply magnitude cuts; non-positive fluxes produce NaN magnitudes,
    # whose comparisons below evaluate to False and mask the band.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        mags, merr = magnitude(data, data_err)
        mask_update = (mags < mag_max) & (merr < merr_max)

    # Force a boolean mask. With a 0/1 integer mask, `int & bool` stays
    # integer, and indexing an array with it later selects elements 0 and 1
    # (fancy indexing) instead of acting as a mask.
    data_mask = np.asarray(data_mask, dtype=bool) & mask_update

    # Initialize prior if not provided
    if lnprior is None:
        label_names = self.models_labels.dtype.names
        if "mini" in label_names:
            # Initial-mass prior
            lnprior = logp_imf(self.models_labels["mini"])
        elif "Mr" in label_names:
            # Luminosity-function prior
            lnprior = logp_ps1_luminosity_function(self.models_labels["Mr"])
        else:
            lnprior = np.zeros(self.nmodels)

    # Work on a private float copy: the in-place `+=` updates below would
    # otherwise modify a caller-supplied array, so reusing the same prior
    # for a second fit would apply the weights twice.
    lnprior = np.array(lnprior, dtype=float)

    # Apply age weighting if requested (best effort: skipped when the
    # labels carry no usable "agewt" field)
    if apply_agewt:
        try:
            lnprior += np.log(np.abs(self.models_labels["agewt"]))
        except (KeyError, ValueError):
            pass

    # Reweight based on grid spacing
    if apply_grad:
        for lbl in self.models_labels.dtype.names:
            if not self.labels_mask[lbl]:
                continue  # only grid parameters are reweighted
            label = self.models_labels[lbl]
            ulabel = np.unique(label)
            if len(ulabel) > 1:
                # Log of the local spacing between unique grid values,
                # interpolated back onto every model.
                lngrad_label = np.log(np.gradient(ulabel))
                lnprior += np.interp(label, ulabel, lngrad_label)

    # Initialize Galactic prior
    if lngalprior is None and data_coords is None:
        raise ValueError(
            "`data_coords` must be provided if using the "
            "default Galactic model prior."
        )
    if lngalprior is None:
        lngalprior = partial(
            logp_galactic_structure, R_solar=R_solar, Z_solar=Z_solar
        )

    # Initialize dust prior
    if lndustprior is None and dustfile is not None:
        lndustprior = logp_extinction

    return (data, data_err, data_mask, lnprior, lngalprior, lndustprior)
def loglike_grid(
    self,
    data,
    data_err,
    data_mask,
    avlim=(0.0, 20.0),
    av_gauss=(0.0, 1e6),
    rvlim=(1.0, 8.0),
    rv_gauss=(3.32, 0.18),
    av_init=None,
    rv_init=None,
    dim_prior=True,
    ltol=3e-2,
    ltol_subthresh=1e-2,
    init_thresh=5e-3,
    parallax=None,
    parallax_err=None,
    return_vals=False,
    indices=None,
    **kwargs,
):
    """
    Compute log-likelihood over the stellar model grid.

    For every model in the instance's grid (or the subset given by
    `indices`), the best-fit scale, A(V) and R(V) are found and the
    resulting log-likelihood is returned. The inputs are not modified.

    Parameters
    ----------
    data : `~numpy.ndarray` of shape `(Nfilt)`
        Measured flux densities.
    data_err : `~numpy.ndarray` of shape `(Nfilt)`
        Measurement errors.
    data_mask : `~numpy.ndarray` of shape `(Nfilt)`
        Binary mask for valid data. Interpreted as boolean (non-zero means
        valid); bands with non-finite flux or error, or non-positive
        error, are additionally excluded.
    indices : array-like, optional
        Subset of model indices to use. If None, uses all models.
    avlim : tuple, optional
        (min, max) bounds on A(V). Default is (0.0, 20.0).
    av_gauss : tuple, optional
        (mean, std) for Gaussian prior on A(V). Default is (0.0, 1e6).
        None is replaced by (0.0, 1e6).
    rvlim : tuple, optional
        (min, max) bounds on R(V). Default is (1.0, 8.0).
    rv_gauss : tuple, optional
        (mean, std) for Gaussian prior on R(V). Default is (3.32, 0.18).
        None is replaced by (3.32, 1e6), i.e. effectively no prior.
    av_init, rv_init : numpy.ndarray, optional
        Starting A(V) / R(V) per model. Default to the respective prior
        mean for every model.
    dim_prior : bool, optional
        If True, subtract ``0.5 * (3 - Ndim) * log(Ndim)`` from the
        log-likelihoods. Default is True.
    ltol : float, optional
        The flux-space iteration stops once the largest change in
        log-likelihood among the tracked models is at most `ltol`. The
        initial magnitude-space fit uses ``2.5 * ltol``. Default is 3e-2.
    ltol_subthresh : float, optional
        Only models within ``log(ltol_subthresh)`` of the current best
        log-likelihood are tracked for convergence. Default is 1e-2.
    init_thresh : float or None, optional
        After the magnitude-space fit, only models within
        ``log(init_thresh)`` of the best (parallax-informed)
        log-likelihood are refined in flux space. None keeps all models.
        Default is 5e-3.
    parallax : float, optional
        Parallax measurement in mas.
    parallax_err : float, optional
        Parallax error in mas.
    return_vals : bool, optional
        If True, return full results including covariances. Default is False.

    Other Parameters
    ----------------
    **kwargs
        Accepted for call compatibility; not used by this method.

    Returns
    -------
    If return_vals=False:
        lnl : numpy.ndarray
            Log-likelihoods for each grid point
        Ndim : int
            Number of dimensions (filters)
        chi2 : numpy.ndarray
            Chi-squared values
    If return_vals=True:
        Also includes scale, av, rv, icov_sar arrays

    Raises
    ------
    ValueError
        If no photometric band survives masking.

    Warns
    -----
    RuntimeWarning
        If fewer than 4 bands are valid.

    See Also
    --------
    logpost_grid : Compute log-posteriors from likelihoods
    _optimize_fit_mag : Magnitude-space optimization
    _optimize_fit_flux : Flux-space optimization

    Notes
    -----
    This method optimizes (distance, Av, Rv) for each grid point by:

    1. Initial magnitude-space fit for numerical stability
    2. Iterative flux-space refinement until convergence
    3. Optional parallax constraint when culling poor initial fits

    The optimization uses priors on Av and Rv but not on distance/scale
    (distance priors are applied in logpost_grid).
    """
    # Select models
    mag_coeffs = self.models if indices is None else self.models[indices]
    Nmodels, Nfilt, Ncoef = mag_coeffs.shape

    # Clean data. Build a fresh boolean mask rather than writing into the
    # caller's array: this leaves the input untouched and guarantees that
    # the indexing below is boolean masking even for a 0/1 integer mask.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        clean = np.isfinite(data) & np.isfinite(data_err) & (data_err > 0.0)
    data_mask = np.array(data_mask, dtype=bool) & clean
    Ndim = int(np.count_nonzero(data_mask))

    if Ndim == 0:
        # Fail fast with a clear message. With zero valid bands the
        # optimizer operates on empty arrays and divides by zero (an opaque
        # ZeroDivisionError under numba), so this would otherwise abort a
        # whole batch with an unhelpful error.
        raise ValueError(
            "No valid photometric bands for this object (all bands are "
            "masked or non-finite). Fitting requires at least 4 valid bands "
            "(3 free parameters scale/A(V)/R(V) plus >=1 DOF); filter such "
            "objects out of the input before fitting."
        )
    if Ndim < 4:
        warnings.warn(
            f"Only {Ndim} valid photometric bands. Minimum 4 recommended "
            f"for reliable fits (3 free parameters: scale, Av, Rv).",
            RuntimeWarning,
            stacklevel=2,
        )

    # Subselect only clean observations
    flux, fluxerr = data[data_mask], data_err[data_mask]
    mcoeffs = mag_coeffs[:, data_mask, :]
    tot_var = np.square(fluxerr)  # 1D: (Nfilt,)

    # Get started by fitting in magnitudes
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        mags = -2.5 * np.log10(flux)
        mags_var = np.square(2.5 / np.log(10.0)) * tot_var / np.square(flux)
    # Non-positive fluxes have no magnitude: give them zero weight.
    mclean = np.isfinite(mags)
    mags[~mclean], mags_var[~mclean] = 0.0, 1e50

    # Set default Gaussian priors if not provided.
    # The large std (1e6) makes the prior effectively flat over the
    # allowed range.
    if av_gauss is None:
        av_gauss = (0.0, 1e6)
    if rv_gauss is None:
        # When the user explicitly passes None, it means "no R(V) prior
        # in the likelihood computation; priors are applied separately
        # in logpost_grid." A huge sigma effectively removes the prior.
        rv_gauss = (3.32, 1e6)

    # Initialize values at the prior means
    if av_init is None:
        av_init = np.zeros(Nmodels) + av_gauss[0]
    if rv_init is None:
        rv_init = np.zeros(Nmodels) + rv_gauss[0]

    # Compute model photometry at the initial reddening
    models, rvecs, drvecs = _get_seds(mcoeffs, av_init, rv_init, return_flux=False)

    # Compute initial magnitude fit
    mtol = 2.5 * ltol
    resid = mags - models
    stepsize = np.ones(Nmodels)
    results = _optimize_fit_mag(
        flux,
        tot_var,
        models,
        rvecs,
        drvecs,
        av_init,
        rv_init,
        mcoeffs,
        resid,
        stepsize,
        mags,
        mags_var,
        tol=mtol,
        init_thresh=init_thresh,
        avlim=avlim,
        av_gauss=av_gauss,
        rvlim=rvlim,
        rv_gauss=rv_gauss,
    )
    models, rvecs, drvecs, scale, av, rv, icov_sar, resid = results

    if init_thresh is not None:
        # Cull initial bad fits before moving on
        chi2 = np.sum(np.square(resid) / tot_var, axis=1)
        lnl = -0.5 * chi2

        # Add parallax to log-likelihood (for the culling decision only)
        lnl_p = lnl
        if parallax is not None and parallax_err is not None:
            if np.isfinite(parallax) and np.isfinite(parallax_err):
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    par = np.sqrt(scale)  # sqrt(scale) = 1/d_kpc = parallax_mas
                    chi2_p = (par - parallax) ** 2 / parallax_err**2
                    lnl_p = lnl - 0.5 * chi2_p

        # Subselect models using log-likelihood thresholding
        lnl_sel = lnl_p > np.max(lnl_p) + np.log(init_thresh)
        init_sel = np.where(lnl_sel)[0]

        # Subselect models (tot_var is 1D and data-only, no subselection needed)
        models = models[init_sel]
        rvecs = rvecs[init_sel]
        drvecs = drvecs[init_sel]
        av_new = av[init_sel]
        rv_new = rv[init_sel]
        mcoeffs = mcoeffs[init_sel]
        resid = resid[init_sel]
    else:
        # Keep all models
        init_sel = np.arange(Nmodels)
        chi2 = np.ones(Nmodels) - LOG_ZERO  # Large positive value
        lnl = np.ones(Nmodels) + LOG_ZERO  # Large negative value
        av_new = np.array(av, order="F")
        rv_new = np.array(rv, order="F")

    # Iterate until convergence
    lnl_old, lerr = LOG_ZERO, -LOG_ZERO  # -inf, +inf proxies
    stepsize, rescaling = np.ones(Nmodels)[init_sel], 1.2
    ln_ltol_subthresh = np.log(ltol_subthresh)
    while lerr > ltol:
        # Re-compute models
        results = _optimize_fit_flux(
            flux,
            tot_var,
            models,
            rvecs,
            drvecs,
            av_new,
            rv_new,
            mcoeffs,
            resid,
            stepsize,
            avlim=avlim,
            av_gauss=av_gauss,
            rvlim=rvlim,
            rv_gauss=rv_gauss,
        )
        (models, rvecs, drvecs, scale_new, av_new, rv_new, icov_sar_new, resid) = (
            results
        )

        # Compute chi2
        chi2_new = np.sum(np.square(resid) / tot_var, axis=1)

        # Compute multivariate normal logpdf (up to the constant added below)
        lnl_new = -0.5 * chi2_new

        # Compute stopping criterion from the best-fitting models only
        lnl_sel = np.where(lnl_new > np.max(lnl_new) + ln_ltol_subthresh)[0]
        lerr = np.max(np.abs(lnl_new - lnl_old)[lnl_sel])

        # Shrink the step for models whose fit got worse
        stepsize[lnl_new < lnl_old] /= rescaling
        lnl_old = lnl_new

    # Insert optimized models into initial array of results.
    # NOTE(review): the Gaussian normalisation constant is added only to the
    # re-optimized models; models culled by `init_thresh` keep the bare
    # -0.5 * chi2 from the magnitude fit. Confirm this asymmetry is intended.
    lnl_new += -0.5 * (Ndim * np.log(2.0 * np.pi) + np.sum(np.log(tot_var)))
    lnl[init_sel], chi2[init_sel] = lnl_new, chi2_new
    scale[init_sel], av[init_sel], rv[init_sel] = scale_new, av_new, rv_new
    icov_sar[init_sel] = icov_sar_new

    # Apply dimensional prior
    if dim_prior and Ndim > 0:
        # Guard against Ndim == 0 (fully-masked object): np.log(0) = -inf
        # would inject +inf into every model's log-likelihood.
        lnl -= 0.5 * (3.0 - Ndim) * np.log(Ndim)

    if return_vals:
        return lnl, Ndim, chi2, scale, av, rv, icov_sar
    else:
        return lnl, Ndim, chi2
def logpost_grid(
    self,
    results,
    parallax=None,
    parallax_err=None,
    coord=None,
    Nmc_prior=100,
    lnprior=None,
    wt_thresh=1e-3,
    cdf_thresh=2e-3,
    max_models=50000,
    precision_shrinkage=0.0,
    subsample_mode="representative",
    lngalprior=None,
    lndustprior=None,
    dustfile=None,
    dlabels=None,
    avlim=(0.0, 20.0),
    rvlim=(1.0, 8.0),
    mem_lim=8000.0,
    rstate=None,
    apply_av_prior=True,
    R_solar=8.2,
    Z_solar=0.025,
    **kwargs,
):
    """
    Compute log-posterior over the stellar model grid.

    Selects the models that carry non-negligible weight, draws Monte Carlo
    samples of (ln distance, A(V), R(V)) around each selected model's
    best fit, evaluates the priors on those samples, and averages them to
    obtain one integrated log-posterior per selected model.

    Parameters
    ----------
    results : tuple
        Results from loglike_grid with return_vals=True, unpacked as
        (lnlike, Ndim, chi2, scales, avs, rvs, icovs_sar).
    parallax, parallax_err : float, optional
        Parallax measurement and error. Used only when both are given:
        approximately (via scale) for model selection, and via
        ``logp_parallax`` on the Monte Carlo samples.
    coord : optional
        Sky coordinate passed to the Galactic and dust priors. The
        Galactic prior is skipped when None.
    Nmc_prior : int, optional
        Requested number of Monte Carlo samples per model; may be reduced
        by `mem_lim` but never below 1.
    lnprior : numpy.ndarray, optional
        Log-prior per model. Defaults to zeros.
    wt_thresh : float, optional
        Keep models whose selection log-weight exceeds
        ``max + log(wt_thresh)``. Takes precedence over `cdf_thresh`.
    cdf_thresh : float, optional
        Used only when `wt_thresh` is None: keep the highest-weight models
        up to a cumulative weight of ``1 - cdf_thresh``.
    max_models : int, optional
        If more models than this are selected, they are subsampled
        according to `subsample_mode`. None disables subsampling.
    precision_shrinkage : float, optional
        If positive, off-diagonal terms of the (ln d, Av, Rv) precision
        matrix are multiplied by ``1 - precision_shrinkage`` before
        inversion.
    subsample_mode : {"representative", "topk"}, optional
        ``"representative"`` perturbs the log-weights with Gumbel noise
        and keeps the top `max_models`; ``"topk"`` keeps the
        highest-weight models.
    lngalprior : callable, optional
        Galactic prior, called as ``f(dist, coord, labels=...)``. Defaults
        to ``logp_galactic_structure`` bound to `R_solar`/`Z_solar`.
    lndustprior : callable, optional
        Dust prior, called as ``f(av, dustfile, coord, distance=...)``.
        Defaults to ``logp_extinction``.
    dustfile : str or object, optional
        Dust map. A string is loaded with ``Bayestar(dustfile=...)``.
    dlabels : numpy.ndarray, optional
        Model labels passed to the Galactic prior. None is replaced by
        ``self.models_labels``.
    avlim, rvlim : tuple, optional
        (min, max) bounds used to clip the A(V) and R(V) samples.
    mem_lim : float, optional
        Caps the number of samples at ``int(mem_lim * 1e6 / (8 * Nsel * 4))``.
    rstate : numpy.random.RandomState, optional
        Random state; a fresh one is created when None.
    apply_av_prior : bool, optional
        Not referenced in this method body.
    R_solar, Z_solar : float, optional
        Forwarded to the default Galactic prior.
    **kwargs
        Accepted for call compatibility; not used by this method.

    Returns
    -------
    sel : numpy.ndarray
        Indices of the selected models.
    cov_sar : numpy.ndarray
        Inverse of the selected (scale, Av, Rv) precision matrices.
    lnp : numpy.ndarray
        Integrated log-posterior per selected model; non-finite values
        are replaced by ``LOG_ZERO``.
    dist_mc, a_mc, r_mc : numpy.ndarray of shape (Nsel, Nmc)
        Distance, A(V) and R(V) samples (after clipping).
    lnp_mc : numpy.ndarray of shape (Nsel, Nmc)
        Log-posterior of each sample.
    mc_ess : numpy.ndarray of shape (Nsel,)
        Effective sample size of each model's Monte Carlo weights.

    Raises
    ------
    ValueError
        If subsampling is needed and `subsample_mode` is not recognised.
    """
    # Use instance's labels if not provided.
    # NOTE(review): because None is replaced here, the `dlabels is None`
    # branch further down is never taken, and a caller cannot switch the
    # labels off by passing None. Confirm whether that is intended.
    if dlabels is None:
        dlabels = self.models_labels

    # Unpack results (plural names: one entry per model)
    lnlike, Ndim, chi2, scales, avs, rvs, icovs_sar = results
    Nmodels = len(lnlike)

    # Initialize random state
    if rstate is None:
        rstate = np.random.RandomState()

    # Apply prior
    if lnprior is None:
        lnprior = np.zeros(Nmodels)

    # Compute initial posterior (without parallax — parallax is applied
    # exactly on MC samples below, not via the approximate scale-based
    # form which would double-count the constraint).
    lnprob_base = lnlike + lnprior

    # Add parallax prior for MODEL SELECTION only.
    # logp_parallax_scale provides a Gaussian approximation on scale
    # to help select relevant models. This is NOT propagated to MC
    # sample weights where the exact logp_parallax is used instead.
    lnprob_sel = lnprob_base.copy()
    if parallax is not None and parallax_err is not None:
        # Convert parallax to scale (VECTORIZED)
        scales_err = np.full(Nmodels, 1e10)  # Large error = uninformative
        valid_mask = icovs_sar[:, 0, 0] > 0
        scales_err[valid_mask] = 1.0 / np.sqrt(icovs_sar[valid_mask, 0, 0])
        lnprob_sel += logp_parallax_scale(
            scales, scales_err, parallax, parallax_err
        )

    # Select models above threshold (using parallax-informed weights)
    if wt_thresh is not None:
        sel = np.where(lnprob_sel > np.max(lnprob_sel) + np.log(wt_thresh))[0]
    elif cdf_thresh is not None:
        idx_sort = np.argsort(lnprob_sel)[::-1]
        cdf = np.cumsum(np.exp(lnprob_sel[idx_sort] - np.max(lnprob_sel)))
        cdf /= cdf[-1]
        Nsel = np.searchsorted(cdf, 1.0 - cdf_thresh) + 1
        sel = idx_sort[:Nsel]
    else:
        sel = np.arange(Nmodels)
    Nsel = len(sel)

    # Subsample if Nsel exceeds max_models
    if max_models is not None and Nsel > max_models:
        if subsample_mode == "representative":
            # Gumbel-max trick: add Gumbel noise to log-weights,
            # take top-k. This gives exact weighted sampling without
            # replacement in O(N) via argpartition.
            gumbel_noise = -np.log(
                -np.log(rstate.uniform(size=Nsel) + 1e-300) + 1e-300
            )
            perturbed = lnprob_sel[sel] + gumbel_noise
            top_k = np.argpartition(perturbed, -max_models)[-max_models:]
            sel = sel[top_k]
        elif subsample_mode == "topk":
            # Deterministic: keep highest-weight models
            top_k = np.argpartition(lnprob_sel[sel], -max_models)[-max_models:]
            sel = sel[top_k]
        else:
            raise ValueError(
                f"Unknown subsample_mode '{subsample_mode}'. "
                f"Must be 'representative' or 'topk'."
            )
        Nsel = len(sel)

    # Select precision matrices for the chosen models.
    icovs_selected = icovs_sar[sel]  # Shape: (Nsel, 3, 3)

    # Monte Carlo integration over distance and extinction (VECTORIZED).
    # Floor at 1: a very small mem_lim relative to Nsel can drive the
    # memory-derived cap to 0, which would otherwise yield empty MC draws,
    # log(Nmc) = -inf, and a crash in the downstream reduction.
    # NOTE(review): an empty selection (Nsel == 0, e.g. if every selection
    # weight is NaN) would divide by zero here — confirm it cannot occur.
    Nmc = max(1, min(Nmc_prior, int(mem_lim * 1e6 / (8.0 * Nsel * 4))))

    # Transform PRECISION matrix from (scale, Av, Rv) to (ln d, Av, Rv).
    # This reparameterization avoids the Jacobian bias that arises when
    # sampling in scale-space but evaluating priors in distance-space.
    # The transformation uses eta = ln(d) = -0.5*ln(s), so
    # d(eta)/d(s) = -1/(2s), i.e. J = diag(-1/(2s), 1, 1).
    #
    # The PRECISION matrix transforms as:
    #   icov_lnd = J^{-T} * icov_sar * J^{-1}
    # where J^{-1} = diag(-2s, 1, 1). This gives:
    #   icov_lnd[0,0] = 4*s^2 * icov_sar[0,0]
    #   icov_lnd[0,j] = -2*s * icov_sar[0,j]  for j=1,2
    #   icov_lnd[j,0] = -2*s * icov_sar[j,0]  for j=1,2
    #   icov_lnd[j,k] = icov_sar[j,k]         for j,k in {1,2}
    #
    # This is well-behaved: for small s (distant/unconstrained stars),
    # 4*s^2 * icov_sar[0,0] gets SMALLER (less precision in ln d),
    # which is physically correct. Transforming the covariance instead
    # would give cov_lnd[0,0] = cov_sar[0,0] / (4*s^2), which explodes
    # for small s, causing MC samples to span many orders of magnitude
    # in distance.
    s_sel = scales[sel]  # Shape: (Nsel,)
    s_sel = np.maximum(s_sel, MIN_SCALE)  # Guard against zero/negative
    two_s = 2.0 * s_sel
    four_s2 = 4.0 * s_sel**2

    icov_lnd = icovs_selected.copy()
    icov_lnd[:, 0, 0] = four_s2 * icovs_selected[:, 0, 0]
    icov_lnd[:, 0, 1] = -two_s * icovs_selected[:, 0, 1]
    icov_lnd[:, 1, 0] = -two_s * icovs_selected[:, 1, 0]
    icov_lnd[:, 0, 2] = -two_s * icovs_selected[:, 0, 2]
    icov_lnd[:, 2, 0] = -two_s * icovs_selected[:, 2, 0]
    # (1,1), (1,2), (2,1), (2,2) are unchanged

    # Optional: apply fixed shrinkage to the off-diagonal terms of the
    # precision matrix before inversion for numerical stability.
    # Default: no shrinkage (alpha=0), relying on the regularization
    # requested from _inverse3.
    if precision_shrinkage > 0:
        off_diag_mask = ~np.eye(3, dtype=bool)
        icov_lnd[:, off_diag_mask] *= 1.0 - precision_shrinkage

    # Invert precision to get covariance in (ln d, Av, Rv)
    cov_lnd = _inverse3(icov_lnd, regularize=True)

    # Also compute cov_sar for backward-compatible output.
    cov_sar = _inverse3(icovs_selected, regularize=True)

    # Prepare means in (ln d, Av, Rv) space
    eta_sel = -0.5 * np.log(s_sel)  # ln(d) = -0.5*ln(s)
    means = np.column_stack([eta_sel, avs[sel], rvs[sel]])  # Shape: (Nsel, 3)

    # BATCH SAMPLING in (ln d, Av, Rv) space
    samples_all = sample_multivariate_normal(
        means, cov_lnd, size=Nmc, rstate=rstate
    )
    # samples_all shape: (3, Nmc, Nsel)

    # Extract and transform samples (VECTORIZED)
    eta_samples = samples_all[0]  # Shape: (Nmc, Nsel)
    a_mc = samples_all[1]  # Shape: (Nmc, Nsel)
    r_mc = samples_all[2]  # Shape: (Nmc, Nsel)

    # Convert ln(d) to distance and clip everything to its allowed range
    dist_mc = np.exp(eta_samples)
    dist_mc = np.clip(dist_mc, 0.001, 1e6)
    a_mc = np.clip(a_mc, avlim[0], avlim[1])
    r_mc = np.clip(r_mc, rvlim[0], rvlim[1])

    # Initialize log-posterior from base (without parallax scale approx).
    # The exact parallax likelihood is added below via logp_parallax.
    lnp_mc = np.tile(lnprob_base[sel], (Nmc, 1))  # Shape: (Nmc, Nsel)

    # Jacobian correction for ln(d) -> d transformation.
    # The prior pi_d(d) is defined in distance-space. In ln(d)-space,
    # the prior becomes pi_eta(eta) = pi_d(d) * |dd/d(eta)| = pi_d(d) * d.
    # Since logp_galactic_structure already includes the d^2 volume element,
    # the net Jacobian is just d (not d^3 as it would be in scale-space).
    lnp_mc += np.log(dist_mc + 1e-300)

    # Prior evaluations - coordinate is fixed, so we can still optimize
    if coord is not None:
        if lngalprior is None:
            lngalprior = partial(
                logp_galactic_structure, R_solar=R_solar, Z_solar=Z_solar
            )

        # Galactic prior evaluation (FULLY VECTORIZED)
        # dist_mc has shape (Nmc, Nsel); each model has one label and Nmc
        # distances, so labels are tiled to match the flattened distances
        # and everything is evaluated in a single call.

        # Flatten all distances: shape (Nmc * Nsel,)
        dist_flat = dist_mc.ravel()

        if dlabels is None:
            # No labels - evaluate once for all distances
            lnp_gal_flat = lngalprior(dist_flat, coord, labels=None)
        else:
            # Extract labels for selected models: shape (Nsel,)
            labels_selected = dlabels[sel]

            # Repeat the label array Nmc times: shape (Nmc * Nsel,)
            # [label0, ..., labelN, label0, ..., labelN, ...]
            #  |-- MC sample 0 --|  |-- MC sample 1 --|
            # which matches dist_mc.ravel() layout (Nmc, Nsel) in
            # row-major order.
            labels_flat = np.tile(labels_selected, Nmc)

            # Evaluate prior for all distance-label pairs at once
            lnp_gal_flat = lngalprior(dist_flat, coord, labels=labels_flat)

        # Reshape back to (Nmc, Nsel)
        lnp_gal_reshaped = lnp_gal_flat.reshape(Nmc, Nsel)
        lnp_mc += lnp_gal_reshaped

        # NOTE(review): this block is nested under the `coord` guard here
        # because the dust prior is called with `coord`; the original
        # nesting could not be recovered from the flattened source —
        # confirm against the canonical file.
        if dustfile is not None:
            # Load dust map from file path if needed
            if isinstance(dustfile, str):
                from ..dust import Bayestar

                dustfile = Bayestar(dustfile=dustfile)

            if lndustprior is None:
                lndustprior = logp_extinction

            # Dust prior evaluation (VECTORIZED)
            # Distances are passed alongside A(V) so the prior can use
            # them together with the dust map.
            av_flat = a_mc.ravel()  # Shape: (Nmc * Nsel,)
            dist_flat = dist_mc.ravel()  # Shape: (Nmc * Nsel,)

            lnp_dust_flat = lndustprior(av_flat, dustfile, coord, distance=dist_flat)
            lnp_dust_reshaped = lnp_dust_flat.reshape(Nmc, Nsel)  # Shape: (Nmc, Nsel)
            lnp_mc += lnp_dust_reshaped

    # Parallax prior (FULLY VECTORIZED)
    if parallax is not None and parallax_err is not None:
        par_mc = 1.0 / dist_mc  # Shape: (Nmc, Nsel)
        lnp_parallax_all = logp_parallax(par_mc, parallax, parallax_err)
        lnp_mc += lnp_parallax_all

    # Compute integrated posterior (VECTORIZED): log of the sample mean
    lnp = logsumexp(lnp_mc, axis=0) - np.log(Nmc)

    # Safety check - replace non-finite values with log(0) proxy
    lnp_mask = np.where(~np.isfinite(lnp))[0]
    if len(lnp_mask) > 0:
        lnp[lnp_mask] = LOG_ZERO

    # Compute effective sample size (ESS) for each selected model's
    # MC samples. ESS = 1 / sum(w_i^2) where w_i are normalized weights.
    # VECTORIZED across all Nsel models.
    lw_max = np.max(lnp_mc, axis=0, keepdims=True)  # (1, Nsel)
    w = np.exp(lnp_mc - lw_max)  # (Nmc, Nsel)
    w_sum = w.sum(axis=0, keepdims=True)  # (1, Nsel)
    w_normed = w / np.maximum(w_sum, 1e-300)  # (Nmc, Nsel)
    mc_ess = 1.0 / np.sum(w_normed**2, axis=0)  # (Nsel,)
    # Models whose weights do not sum to a positive number get ESS 0.
    mc_ess = np.where(w_sum.ravel() > 0, mc_ess, 0.0)

    return sel, cov_sar, lnp, dist_mc.T, a_mc.T, r_mc.T, lnp_mc.T, mc_ess
def fit(
    self,
    data,
    data_err,
    data_mask,
    data_labels,
    save_file,
    phot_offsets=None,
    parallax=None,
    parallax_err=None,
    Nmc_prior=50,
    avlim=(0.0, 20.0),
    av_gauss=None,
    rvlim=(1.0, 8.0),
    rv_gauss=(3.32, 0.18),
    lnprior=None,
    wt_thresh=1e-3,
    cdf_thresh=2e-3,
    Ndraws=250,
    apply_agewt=True,
    apply_grad=True,
    lngalprior=None,
    lndustprior=None,
    dustfile=None,
    apply_dlabels=True,
    data_coords=None,
    logl_dim_prior=True,
    ltol=3e-2,
    ltol_subthresh=1e-2,
    logl_initthresh=5e-3,
    mag_max=50.0,
    merr_max=0.25,
    rstate=None,
    save_dar_draws=True,
    running_io=True,
    mem_lim=8000.0,
    max_models=50000,
    precision_shrinkage=0.0,
    subsample_mode="representative",
    verbose=True,
    R_solar=8.2,
    Z_solar=0.025,
):
    """
    Fit all input models to the input data to compute log-posteriors.

    This is the main interface for fitting stellar parameters using
    grid-based Bayesian inference. Results are saved to an HDF5 file,
    which is closed even if fitting an object raises.

    Parameters
    ----------
    data : numpy.ndarray of shape (Ndata, Nfilt)
        Observed flux densities for each object.
    data_err : numpy.ndarray of shape (Ndata, Nfilt)
        Associated errors on the flux densities.
    data_mask : numpy.ndarray of shape (Ndata, Nfilt)
        Binary mask (0/1) indicating whether each measurement is valid.
    data_labels : numpy.ndarray of shape (Ndata, Nlabels)
        Labels for each object to be stored in the output file.
    save_file : str
        Path to the output HDF5 file. The '.h5' extension will be added
        if not present.
    phot_offsets : numpy.ndarray of shape (Nfilt,), optional
        Multiplicative photometric offsets applied to data and errors.
    parallax : numpy.ndarray of shape (Ndata,), optional
        Parallax measurements in mas for each object.
    parallax_err : numpy.ndarray of shape (Ndata,), optional
        Errors on parallax measurements. Required if parallax is provided.
    Nmc_prior : int, optional
        Number of Monte Carlo samples for prior integration. Default is 50.
    avlim : tuple, optional
        (min, max) bounds on A(V). Default is (0.0, 20.0).
    av_gauss : tuple, optional
        (mean, std) for Gaussian prior on A(V). If provided, this is used
        instead of the distance-reddening prior during fitting.
    rvlim : tuple, optional
        (min, max) bounds on R(V). Default is (1.0, 8.0).
    rv_gauss : tuple, optional
        (mean, std) for Gaussian prior on R(V). Default is (3.32, 0.18)
        based on Schlafly et al. (2016).
    lnprior : numpy.ndarray of shape (Nmodel,), optional
        Log-prior for each model. If not provided, defaults to Kroupa IMF
        prior on initial mass (for MIST models) or PS1 luminosity function
        prior (for Bayestar models).
    wt_thresh : float, optional
        Threshold `wt_thresh * max(weight)` for model selection.
        Default is 1e-3.
    cdf_thresh : float, optional
        CDF threshold for model selection (used if wt_thresh is None).
        Default is 2e-3.
    Ndraws : int, optional
        Number of posterior draws to save per object. Default is 250.
    apply_agewt : bool, optional
        Whether to apply age weighting from MIST models. Default is True.
    apply_grad : bool, optional
        Whether to apply grid spacing corrections. Default is True.
    lngalprior : callable, optional
        Galactic structure prior function. If not provided, uses the
        default Galactic model from Green et al. (2014).
    lndustprior : callable, optional
        Dust prior function with signature
        ``f(avs, dustmap, coord, distance=None)``. If not provided and
        dustfile is given, uses ``logp_extinction``.
    dustfile : str or `~brutus.dust.Bayestar`, optional
        3D dust map for extinction priors. Can be a file path (string) to
        a Bayestar HDF5 file, which will be loaded automatically, or a
        pre-loaded ``Bayestar`` object.
    apply_dlabels : bool, optional
        Whether to pass model labels to Galactic prior. Default is True.
    data_coords : numpy.ndarray of shape (Ndata, 2), optional
        Galactic (l, b) coordinates for each object in degrees.
        Required if using default Galactic prior.
    logl_dim_prior : bool, optional
        Whether to apply dimensional correction to log-likelihood.
        Default is True.
    ltol : float, optional
        Convergence tolerance for likelihood optimization. Default is 3e-2.
    ltol_subthresh : float, optional
        Sub-threshold for convergence. Default is 1e-2.
    logl_initthresh : float, optional
        Initial likelihood threshold for model culling. Default is 5e-3.
    mag_max : float, optional
        Maximum allowed magnitude for valid data. Default is 50.0.
    merr_max : float, optional
        Maximum allowed magnitude error. Default is 0.25.
    rstate : numpy.random.RandomState, optional
        Random state for reproducibility.
    save_dar_draws : bool, optional
        Whether to save distance, A(V), and R(V) draws. Default is True.
    running_io : bool, optional
        If True, writes results to disk after each object (safer for long
        runs). If False, accumulates results in memory and writes at end
        (faster for slow filesystems). Default is True.
    mem_lim : float, optional
        Memory limit in MB for Monte Carlo sampling. Default is 8000.0.
    verbose : bool, optional
        Whether to print progress to stderr. Default is True.
    max_models : int, optional
        When the number of models selected for an object exceeds this, the
        models are subsampled (see ``subsample_mode``) to bound memory and
        runtime. Default is 50000.
    precision_shrinkage : float, optional
        Fractional shrinkage applied to the off-diagonal terms of the 3x3
        precision matrix before Monte Carlo sampling, which stabilizes
        strongly-correlated fits. 0 disables it; ~0.03 is a reasonable
        nonzero value. Default is 0.0.
    subsample_mode : str, optional
        Strategy used to thin models when the selection exceeds
        ``max_models``: ``'representative'`` (Gumbel-max sampling weighted
        by likelihood) or ``'topk'`` (the highest-likelihood models).
        Default is ``'representative'``.
    R_solar : float, optional
        Solar galactocentric radius in kpc, used by the Galactic structure
        prior. Default is 8.2.
    Z_solar : float, optional
        Solar height above the Galactic midplane in kpc, used by the
        Galactic structure prior. Default is 0.025.

    Returns
    -------
    str
        Path to the output HDF5 file.

    Raises
    ------
    ValueError
        If the default Galactic prior is requested without `data_coords`
        (raised during setup, before the output file is created).

    Notes
    -----
    The output HDF5 file contains the following datasets:

    - ``labels``: Object labels (Ndata, Nlabels)
    - ``model_idx``: Resampled model indices (Ndata, Ndraws)
    - ``ml_scale``: Maximum likelihood scale factors (Ndata, Ndraws)
    - ``ml_av``: Maximum likelihood A(V) values (Ndata, Ndraws)
    - ``ml_rv``: Maximum likelihood R(V) values (Ndata, Ndraws)
    - ``ml_cov_sar``: Covariance matrices (Ndata, Ndraws, 3, 3)
    - ``obj_log_post``: Log-posteriors (Ndata, Ndraws)
    - ``obj_log_evid``: Log-evidence per object (Ndata,)
    - ``obj_chi2min``: Minimum chi-squared per object (Ndata,)
    - ``obj_Nbands``: Number of bands used per object (Ndata,)
    - ``mc_ess``: Monte Carlo effective sample size (Ndata, Ndraws).
      For each posterior draw, the ESS of the MC integration over
      (distance, Av, Rv) for that draw's source model. Low ESS indicates
      poor overlap between the Gaussian proposal and the target posterior.

    If save_dar_draws=True, also includes:

    - ``samps_dist``: Distance draws in kpc (Ndata, Ndraws)
    - ``samps_red``: A(V) draws (Ndata, Ndraws)
    - ``samps_dred``: R(V) draws (Ndata, Ndraws)
    - ``samps_logp``: Log-weights for draws (Ndata, Ndraws)

    If an object fails while ``running_io=True``, the rows written so far
    are kept and the remaining rows hold their initial fill values.

    Examples
    --------
    >>> from brutus.analysis import BruteForce
    >>> from brutus.core import StarGrid
    >>> from brutus.data import load_models
    >>>
    >>> # Load model grid
    >>> models, labels, params = load_models('grid.h5')
    >>> grid = StarGrid(models, labels, params)
    >>> fitter = BruteForce(grid)
    >>>
    >>> # Fit photometry
    >>> results_file = fitter.fit(
    ...     phot, phot_err, phot_mask, obj_labels,
    ...     save_file='results.h5',
    ...     parallax=parallax, parallax_err=parallax_err,
    ...     data_coords=coords
    ... )
    """
    # Load dust map from file path if needed (once for all objects)
    if dustfile is not None and isinstance(dustfile, str):
        from ..dust import Bayestar

        dustfile = Bayestar(dustfile=dustfile)

    # Pre-process data and initialize priors
    setup_results = self._setup(
        data,
        data_err,
        data_mask,
        data_labels=data_labels,
        phot_offsets=phot_offsets,
        parallax=parallax,
        parallax_err=parallax_err,
        av_gauss=av_gauss,
        lnprior=lnprior,
        wt_thresh=wt_thresh,
        cdf_thresh=cdf_thresh,
        apply_agewt=apply_agewt,
        apply_grad=apply_grad,
        lngalprior=lngalprior,
        lndustprior=lndustprior,
        dustfile=dustfile,
        data_coords=data_coords,
        ltol_subthresh=ltol_subthresh,
        logl_initthresh=logl_initthresh,
        mag_max=mag_max,
        merr_max=merr_max,
        rstate=rstate,
        R_solar=R_solar,
        Z_solar=Z_solar,
    )
    (
        data_proc,
        data_err_proc,
        data_mask_proc,
        lnprior_proc,
        lngalprior_proc,
        lndustprior_proc,
    ) = setup_results

    Ndata, _ = data.shape  # also enforces 2-D input

    # Initialize random state
    if rstate is None:
        rstate = np.random.RandomState()

    # Ensure save_file has .h5 extension
    if not save_file.endswith(".h5"):
        save_file = f"{save_file}.h5"

    # Get model labels for priors
    dlabels = self.models_labels if apply_dlabels else None

    # Single specification of every per-object output:
    # (dataset name, shape, dtype, initial fill value). Used for both the
    # streaming and the batch write paths so they cannot drift apart.
    per_draw = (Ndata, Ndraws)
    specs = [
        ("model_idx", per_draw, "int32", -99),
        ("ml_scale", per_draw, "float32", 1.0),
        ("ml_av", per_draw, "float32", 0.0),
        ("ml_rv", per_draw, "float32", 0.0),
        ("ml_cov_sar", (Ndata, Ndraws, 3, 3), "float32", 0.0),
        ("obj_log_post", per_draw, "float32", 0.0),
        ("obj_log_evid", (Ndata,), "float32", 0.0),
        ("obj_chi2min", (Ndata,), "float32", 0.0),
        ("obj_Nbands", (Ndata,), "int16", 0),
        ("mc_ess", per_draw, "float32", 0.0),
    ]
    if save_dar_draws:
        specs += [
            (name, per_draw, "float32", 1.0)
            for name in ("samps_dist", "samps_red", "samps_dred", "samps_logp")
        ]

    # The context manager guarantees the file handle is released even when
    # fitting one of the objects raises.
    with h5py.File(save_file, "w") as out:
        out.create_dataset("labels", data=data_labels)

        # `sinks` maps each output name to something row-assignable:
        # an on-disk dataset (streaming) or an in-memory array (batch).
        sinks = {}
        for name, shape, dtype, fill in specs:
            buffer = np.full(shape, fill, dtype=dtype)
            if running_io:
                sinks[name] = out.create_dataset(name, data=buffer)
            else:
                sinks[name] = buffer

        # Fit data
        if verbose and Ndata > 0:
            t1 = time.time()
            t = 0.0
            sys.stderr.write(f"\rFitting object 1/{Ndata} ")
            sys.stderr.flush()

        for i in range(Ndata):
            # Get parallax and coordinates for this object
            par_i = parallax[i] if parallax is not None else None
            par_err_i = parallax_err[i] if parallax_err is not None else None
            coord_i = data_coords[i] if data_coords is not None else None

            # Fit individual object
            results = self._fit(
                data_proc[i],
                data_err_proc[i],
                data_mask_proc[i],
                parallax=par_i,
                parallax_err=par_err_i,
                coord=coord_i,
                Nmc_prior=Nmc_prior,
                avlim=avlim,
                av_gauss=av_gauss,
                rvlim=rvlim,
                rv_gauss=rv_gauss,
                lnprior=lnprior_proc,
                wt_thresh=wt_thresh,
                cdf_thresh=cdf_thresh,
                Ndraws=Ndraws,
                lngalprior=lngalprior_proc,
                lndustprior=lndustprior_proc,
                dustfile=dustfile,
                dlabels=dlabels,
                logl_dim_prior=logl_dim_prior,
                ltol=ltol,
                ltol_subthresh=ltol_subthresh,
                logl_initthresh=logl_initthresh,
                mem_lim=mem_lim,
                precision_shrinkage=precision_shrinkage,
                max_models=max_models,
                subsample_mode=subsample_mode,
                rstate=rstate,
                return_distreds=save_dar_draws,
                R_solar=R_solar,
                Z_solar=Z_solar,
            )

            # Unpack results
            if save_dar_draws:
                (
                    idxs,
                    scales,
                    avs,
                    rvs,
                    covs_sar,
                    Ndim,
                    lpost,
                    levid,
                    chi2min,
                    dists,
                    reds,
                    dreds,
                    logwt,
                    mc_ess_i,
                ) = results
            else:
                (
                    idxs,
                    scales,
                    avs,
                    rvs,
                    covs_sar,
                    Ndim,
                    lpost,
                    levid,
                    chi2min,
                    mc_ess_i,
                ) = results

            # Print progress
            if verbose and i < Ndata - 1:
                t2 = time.time()
                dt = t2 - t1
                t1 = t2
                t += dt
                t_avg = t / (i + 1)
                t_est = t_avg * (Ndata - i - 1)
                sys.stderr.write(
                    f"\rFitting object {i + 2}/{Ndata} "
                    f"[chi2/n: {chi2min:.1f}/{Ndim}] "
                    f"(mean time: {t_avg:.3f} s/obj, "
                    f"est. remaining: {t_est:.3f} s) "
                )
                sys.stderr.flush()

            # Save results (row i of every output)
            row = {
                "model_idx": idxs,
                "ml_scale": scales,
                "ml_av": avs,
                "ml_rv": rvs,
                "ml_cov_sar": covs_sar,
                "obj_log_post": lpost,
                "obj_log_evid": levid,
                "obj_chi2min": chi2min,
                "obj_Nbands": Ndim,
                "mc_ess": mc_ess_i,
            }
            if save_dar_draws:
                row["samps_dist"] = dists
                row["samps_red"] = reds
                row["samps_dred"] = dreds
                row["samps_logp"] = logwt
            for name, value in row.items():
                sinks[name][i] = value

        # Final progress update (skipped for empty input, where there is
        # no last object to report and the mean time is undefined)
        if verbose and Ndata > 0:
            t2 = time.time()
            dt = t2 - t1
            t += dt
            t_avg = t / Ndata
            sys.stderr.write(
                f"\rFitting object {Ndata}/{Ndata} "
                f"[chi2/n: {chi2min:.1f}/{Ndim}] "
                f"(mean time: {t_avg:.3f} s/obj, "
                f"total: {t:.3f} s) \n"
            )
            sys.stderr.flush()

        # Dump results to disk if using batch mode
        if not running_io:
            for name, buffer in sinks.items():
                out.create_dataset(name, data=buffer)

    return save_file
def _fit(
    self,
    data,
    data_err,
    data_mask,
    parallax=None,
    parallax_err=None,
    coord=None,
    Nmc_prior=100,
    avlim=(0.0, 20.0),
    av_gauss=(0.0, 1e6),
    rvlim=(1.0, 8.0),
    rv_gauss=(3.32, 0.18),
    lnprior=None,
    wt_thresh=1e-3,
    cdf_thresh=2e-3,
    Ndraws=250,
    lngalprior=None,
    lndustprior=None,
    dustfile=None,
    dlabels=None,
    logl_dim_prior=True,
    ltol=3e-2,
    ltol_subthresh=1e-2,
    logl_initthresh=5e-3,
    mem_lim=8000.0,
    max_models=50000,
    precision_shrinkage=0.0,
    subsample_mode="representative",
    rstate=None,
    return_distreds=True,
    R_solar=8.2,
    Z_solar=0.025,
):
    """
    Perform internal fitting for a single object.

    The method chains three steps: (1) evaluate the likelihood of every
    grid model via ``self.loglike_grid``, (2) turn those into posteriors
    for a selected subset of models via ``self.logpost_grid``, and
    (3) resample ``Ndraws`` models from that subset in proportion to
    their posterior weight, optionally attaching one Monte Carlo
    distance/reddening sample to each draw.

    Parameters
    ----------
    data : numpy.ndarray
        Photometric flux densities for a single object.
    data_err : numpy.ndarray
        Photometric errors.
    data_mask : numpy.ndarray
        Binary mask for valid data.
    parallax : float, optional
        Parallax measurement in mas.
    parallax_err : float, optional
        Parallax error in mas.
    coord : tuple, optional
        Galactic (l, b) coordinates in degrees.
    Nmc_prior : int, optional
        Number of Monte Carlo samples for prior integration.
    avlim : tuple, optional
        (min, max) bounds on A(V).
    av_gauss : tuple, optional
        (mean, std) for Gaussian prior on A(V).
    rvlim : tuple, optional
        (min, max) bounds on R(V).
    rv_gauss : tuple, optional
        (mean, std) for Gaussian prior on R(V).
    lnprior : numpy.ndarray, optional
        Log-prior for each model.
    wt_thresh : float, optional
        Threshold for weight-based model selection.
    cdf_thresh : float, optional
        CDF threshold for model selection.
    Ndraws : int, optional
        Number of posterior draws to return.
    lngalprior : callable, optional
        Galactic structure prior function.
    lndustprior : callable, optional
        Dust prior function.
    dustfile : str, optional
        Path to 3D dust map file.
    dlabels : numpy.ndarray, optional
        Model labels for prior evaluation.
    logl_dim_prior : bool, optional
        Whether to apply dimensional prior correction.
    ltol : float, optional
        Convergence tolerance.
    ltol_subthresh : float, optional
        Sub-threshold for convergence.
    logl_initthresh : float, optional
        Initial likelihood threshold.
    mem_lim : float, optional
        Memory limit in MB.
    max_models : int, optional
        Forwarded unchanged to ``logpost_grid``. Default is 50000.
    precision_shrinkage : float, optional
        Forwarded unchanged to ``logpost_grid``. Default is 0.0.
    subsample_mode : str, optional
        Forwarded unchanged to ``logpost_grid``. Default is
        ``"representative"``.
    rstate : numpy.random.RandomState, optional
        Random state for reproducibility. A fresh, unseeded
        ``RandomState`` is created when omitted.
    return_distreds : bool, optional
        Whether to return distance/reddening draws. Default is True.
    R_solar : float, optional
        Forwarded unchanged to ``logpost_grid``. Default is 8.2.
        NOTE(review): presumably the solar Galactocentric radius in
        kpc -- confirm against ``logpost_grid``.
    Z_solar : float, optional
        Forwarded unchanged to ``logpost_grid``. Default is 0.025.
        NOTE(review): presumably the solar height above the Galactic
        plane in kpc -- confirm against ``logpost_grid``.

    Returns
    -------
    tuple
        If return_distreds=True:
            (idxs, scales, avs, rvs, covs_sar, Ndim, lnprob, levid,
             chi2min, dists, reds, dreds, logwts, mc_ess)
        If return_distreds=False:
            (idxs, scales, avs, rvs, covs_sar, Ndim, lnprob, levid,
             chi2min, mc_ess)

        Where:
        - idxs: Resampled model indices (Ndraws,)
        - scales: Scale factors for resampled models (Ndraws,)
        - avs: A(V) values for resampled models (Ndraws,)
        - rvs: R(V) values for resampled models (Ndraws,)
        - covs_sar: Covariance matrices (Ndraws, 3, 3)
        - Ndim: Number of photometric bands used (plus one when a
          finite parallax and parallax error are supplied)
        - lnprob: Log-posteriors for resampled models (Ndraws,)
        - levid: Log-evidence (scalar)
        - chi2min: Minimum chi-squared (scalar)
        - dists: Distance draws in kpc (Ndraws,)
        - reds: A(V) draws (Ndraws,)
        - dreds: R(V) draws (Ndraws,)
        - logwts: Log-weights for draws (Ndraws,)
        - mc_ess: Effective sample size for each draw's source model
          (Ndraws,)

    """
    # Initialize random state
    if rstate is None:
        rstate = np.random.RandomState()

    # Compute grid likelihoods
    loglike_results = self.loglike_grid(
        data,
        data_err,
        data_mask,
        avlim=avlim,
        av_gauss=av_gauss,
        rvlim=rvlim,
        rv_gauss=rv_gauss,
        dim_prior=logl_dim_prior,
        ltol=ltol,
        ltol_subthresh=ltol_subthresh,
        init_thresh=logl_initthresh,
        parallax=parallax,
        parallax_err=parallax_err,
        return_vals=True,
    )

    # Unpack likelihood results. The first and last entries are not needed
    # here; the full tuple is handed to `logpost_grid` below.
    _, Ndim, chi2, scales_all, avs_all, rvs_all, _ = loglike_results

    # Compute grid posteriors
    logpost_results = self.logpost_grid(
        loglike_results,
        parallax=parallax,
        parallax_err=parallax_err,
        coord=coord,
        Nmc_prior=Nmc_prior,
        lnprior=lnprior,
        wt_thresh=wt_thresh,
        cdf_thresh=cdf_thresh,
        lngalprior=lngalprior,
        lndustprior=lndustprior,
        dustfile=dustfile,
        dlabels=dlabels,
        avlim=avlim,
        rvlim=rvlim,
        mem_lim=mem_lim,
        precision_shrinkage=precision_shrinkage,
        max_models=max_models,
        subsample_mode=subsample_mode,
        rstate=rstate,
        R_solar=R_solar,
        Z_solar=Z_solar,
    )

    # Unpack posterior results
    # sel: selected model indices, cov_sar_sel: covariances for selected
    # models, lnp: log-posteriors, dist_mc/a_mc/r_mc: MC samples,
    # lnp_mc: MC log-posteriors, mc_ess_sel: effective sample size per
    # selected model
    sel, cov_sar_sel, lnp, dist_mc, a_mc, r_mc, lnp_mc, mc_ess_sel = logpost_results
    Nsel = len(sel)

    # Add parallax contribution to chi2 and Ndim if provided. Work on a
    # copy so the array inside `loglike_results` is left untouched.
    Ndim_out = Ndim
    chi2_with_par = chi2.copy()
    if parallax is not None and parallax_err is not None:
        if np.isfinite(parallax) and np.isfinite(parallax_err):
            # Silence numpy RuntimeWarnings raised while forming the
            # predicted parallax and its chi2 term.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # sqrt(scale) = 1/d_kpc = parallax_mas
                par_pred = np.sqrt(scales_all)
                chi2_with_par += (par_pred - parallax) ** 2 / parallax_err**2
            Ndim_out += 1

    # Compute goodness-of-fit metrics (over ALL models, not just
    # posterior-selected)
    chi2min = np.min(chi2_with_par)
    levid = logsumexp(lnp)

    # Resample from posterior
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        wt = np.exp(lnp - levid)
        wt_sum = wt.sum()
        if wt_sum > 0:
            wt /= wt_sum
        else:
            # Degenerate weights (zero or NaN sum): fall back to uniform.
            wt = np.ones(Nsel) / Nsel
    idxs_local = rstate.choice(Nsel, size=Ndraws, p=wt)
    sidxs = sel[idxs_local]

    # Extract values for resampled models. `idxs_local` indexes the
    # selected subset, `sidxs` indexes the full grid.
    scales = scales_all[sidxs]
    avs = avs_all[sidxs]
    rvs = rvs_all[sidxs]
    covs_sar = cov_sar_sel[idxs_local]
    lnprob = lnp[idxs_local]
    mc_ess = mc_ess_sel[idxs_local]

    if not return_distreds:
        return (
            sidxs,
            scales,
            avs,
            rvs,
            covs_sar,
            Ndim_out,
            lnprob,
            levid,
            chi2min,
            mc_ess,
        )

    # Draw distance and reddening samples from MC integration.
    # For each resampled model, pick one MC sample weighted by its
    # contribution, using the Gumbel-max trick: adding i.i.d. Gumbel noise
    # to the log-weights and taking the argmax selects index k with
    # probability proportional to exp(log-weight_k). This is equivalent to
    # rstate.choice(p=weights) but vectorized across all Ndraws at once.
    all_mc_logwts = lnp_mc[idxs_local]  # (Ndraws, Nmc)
    Nmc_actual = all_mc_logwts.shape[1]
    # The 1e-300 offsets keep both logarithms finite when a uniform
    # deviate is exactly 0 or rounds so that -log(u) is 0.
    gumbel_noise = -np.log(
        -np.log(rstate.uniform(size=(Ndraws, Nmc_actual)) + 1e-300) + 1e-300
    )
    imc_all = np.argmax(all_mc_logwts + gumbel_noise, axis=1)

    dists = dist_mc[idxs_local, imc_all].astype("float32")
    reds = a_mc[idxs_local, imc_all].astype("float32")
    dreds = r_mc[idxs_local, imc_all].astype("float32")
    logwts = all_mc_logwts[np.arange(Ndraws), imc_all].astype("float32")

    return (
        sidxs,
        scales,
        avs,
        rvs,
        covs_sar,
        Ndim_out,
        lnprob,
        levid,
        chi2min,
        dists,
        reds,
        dreds,
        logwts,
        mc_ess,
    )
def __repr__(self):
    """
    Return string representation of BruteForce object.

    Returns
    -------
    str
        A summary of the form
        ``BruteForce(nmodels=<N>, nfilters=<F>, labels=<L>)`` where
        ``<N>`` is ``self.nmodels`` rendered with thousands separators,
        ``<F>`` is ``self.nfilters`` and ``<L>`` is
        ``len(self.labels_mask)``.

    """
    return (
        f"BruteForce(nmodels={self.nmodels:,}, "
        f"nfilters={self.nfilters}, "
        f"labels={len(self.labels_mask)})"
    )