# Source code for brutus.utils.math

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Mathematical utility functions for brutus.

This module contains mathematical utility functions including matrix operations,
statistical distributions, and numerical utilities. Many functions are JIT-compiled
with numba for performance.

Functions
---------
galactic_to_galactocentric_cyl : Coordinate transform
    Convert Galactic (l, b, distance) to galactocentric cylindrical (R, Z)
inverse3 : Matrix inversion
    Fast 3x3 matrix inversion with optional diagonal preconditioning
isPSD : Matrix check
    Check if matrix is positive semi-definite
chisquare_logpdf : Chi-square log-PDF
    Log-probability density for chi-square distribution
truncnorm_pdf : Truncated normal PDF
    Probability density for truncated normal
truncnorm_logpdf : Truncated normal log-PDF
    Log-probability density for truncated normal

See Also
--------
brutus.utils.sampling : Sampling utilities
brutus.priors.galactic : Uses truncated normal distributions

Notes
-----
The numba-compiled functions provide significant speedups for tight loops
in Bayesian inference. The `inverse3` function is specifically optimized
for the (scale, A_V, R_V) covariance matrices used throughout brutus.

Matrix regularization in `inverse3` (enabled with ``regularize=True``)
prevents numerical issues when matrices are near-singular by adding a small
value to the diagonal when eigenvalues are too small.

Examples
--------
>>> import numpy as np
>>> from brutus.utils.math import inverse3, isPSD
>>>
>>> # Create a 3x3 covariance matrix
>>> cov = np.array([[1.0, 0.1, 0.05],
...                 [0.1, 0.5, 0.02],
...                 [0.05, 0.02, 0.3]])
>>>
>>> # Check if positive semi-definite
>>> is_valid = isPSD(cov)
>>>
>>> # Invert with diagonal preconditioning / regularization
>>> icov = inverse3(cov, regularize=True)
"""

from math import lgamma, log

import numpy as np
from numba import jit
from scipy.special import erf, erfc

# Public API of this module; the leading-underscore helpers are internal.
__all__ = [
    # Coordinate transforms
    "galactic_to_galactocentric_cyl",
    # 3x3 linear algebra
    "inverse3",
    "isPSD",
    # Probability densities
    "chisquare_logpdf",
    "truncnorm_pdf",
    "truncnorm_logpdf",
]


class _function_wrapper:
    """
    Bind extra positional and keyword arguments to a function, picklably.

    A lambda or closure that captures ``args``/``kwargs`` cannot be pickled.
    An instance of this module-level class can be, provided ``func``,
    ``args`` and ``kwargs`` are themselves picklable. Based on the
    implementation in `emcee <https://emcee.readthedocs.io/>`_.

    Parameters
    ----------
    func : callable
        The function to wrap. It is called as ``func(x, *args, **kwargs)``.
    args : tuple
        Additional positional arguments passed after ``x``.
    kwargs : dict
        Additional keyword arguments passed to ``func``.
    name : str, optional
        Label for the function, used in the diagnostic message printed when
        a call raises. Default: ``"input"``.
    """

    def __init__(self, func, args, kwargs, name="input"):
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.name = name

    def __call__(self, x):
        """
        Evaluate ``func(x, *args, **kwargs)``.

        If the call raises, the offending ``x``, ``args`` and ``kwargs`` are
        printed to stdout and the traceback to stderr. The original exception
        is then re-raised unchanged.
        """
        try:
            return self.func(x, *self.args, **self.kwargs)
        except Exception:
            import traceback

            print(f"Exception while calling {self.name} function:")
            print("  params:", x)
            print("  args:", self.args)
            print("  kwargs:", self.kwargs)
            print("  exception:")
            traceback.print_exc()
            raise


@jit(nopython=True, cache=True)
def _matrix_det_3x3(A):
    """
    Return the determinant of a 3x3 matrix.

    Uses cofactor expansion along the first row, written with plain scalar
    arithmetic so it compiles under numba's nopython mode.
    """
    # 2x2 minors from deleting row 0 and, respectively, column 0, 1, 2.
    minor_0 = A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1]
    minor_1 = A[1, 0] * A[2, 2] - A[1, 2] * A[2, 0]
    minor_2 = A[1, 0] * A[2, 1] - A[1, 1] * A[2, 0]
    return A[0, 0] * minor_0 - A[0, 1] * minor_1 + A[0, 2] * minor_2


@jit(nopython=True, cache=True)
def _invert_3x3_analytical(A):
    """
    Invert a 3x3 matrix using the closed-form cofactor (adjugate) formula.

    Numba-compatible. Accurate for well-conditioned matrices (no pivoting).

    Parameters
    ----------
    A : ndarray of shape (3, 3)
        Matrix to invert.

    Returns
    -------
    inv : ndarray of shape (3, 3)
        ``adj(A) / det(A)``, allocated with ``np.empty_like(A)``. If `A` is
        numerically singular, every entry is ``inf``.

    Notes
    -----
    Singularity is judged relative to the magnitude of `A`. The matrix is
    treated as singular when ``|det(A)| <= 1e-15 * s**3`` with
    ``s = max |A_ij|``. A 3x3 determinant scales as the cube of the entries,
    so (up to rounding) the test is invariant to rescaling `A`. An absolute
    cutoff would, for example, reject the perfectly conditioned ``1e-6 * I``,
    whose determinant is ``1e-18``. Using ``<=`` also catches an exactly
    zero determinant even if ``s**3`` underflows, so the divisions below
    never divide by zero.
    """
    # Extract matrix elements
    a11, a12, a13 = A[0, 0], A[0, 1], A[0, 2]
    a21, a22, a23 = A[1, 0], A[1, 1], A[1, 2]
    a31, a32, a33 = A[2, 0], A[2, 1], A[2, 2]

    det = _matrix_det_3x3(A)

    # Largest absolute entry sets the scale of the singularity test
    # (manual loop: avoids allocating a temporary per call inside batches).
    s = 0.0
    for i in range(3):
        for j in range(3):
            v = abs(A[i, j])
            if v > s:
                s = v

    if abs(det) <= 1e-15 * s * s * s:
        # Numerically singular: signal failure with an all-inf matrix
        inv = np.empty_like(A)
        inv.fill(np.inf)
        return inv

    # Inverse = adjugate (transposed cofactor matrix) / determinant
    inv = np.empty_like(A)

    inv[0, 0] = (a22 * a33 - a23 * a32) / det
    inv[0, 1] = (a13 * a32 - a12 * a33) / det
    inv[0, 2] = (a12 * a23 - a13 * a22) / det

    inv[1, 0] = (a23 * a31 - a21 * a33) / det
    inv[1, 1] = (a11 * a33 - a13 * a31) / det
    inv[1, 2] = (a13 * a21 - a11 * a23) / det

    inv[2, 0] = (a21 * a32 - a22 * a31) / det
    inv[2, 1] = (a12 * a31 - a11 * a32) / det
    inv[2, 2] = (a11 * a22 - a12 * a21) / det

    return inv


@jit(nopython=True, cache=True)
def _batch_invert_3x3(A_batch):
    """
    Invert a stack of 3x3 matrices, one at a time, in compiled code.

    Every matrix is passed through ``_invert_3x3_analytical``, so a
    numerically singular member comes back filled with ``inf`` rather than
    raising.

    Parameters
    ----------
    A_batch : ndarray of shape (N, 3, 3)
        Stack of matrices to invert.

    Returns
    -------
    inv_batch : ndarray of shape (N, 3, 3)
        The inverses, allocated with ``np.empty_like(A_batch)``.
    """
    inverses = np.empty_like(A_batch)
    count = A_batch.shape[0]
    for k in range(count):
        inverses[k, :, :] = _invert_3x3_analytical(A_batch[k])
    return inverses


def _invert_3x3_preconditioned(P, min_eigenval_threshold=1e-12):
    """
    Invert a single 3x3 symmetric positive-definite matrix using diagonal
    preconditioning for numerical stability.

    The precision matrix is normalized to a correlation-like matrix (1s on
    the diagonal) before inversion, then transformed back. This reduces the
    condition number from O(max_diag/min_diag) to O(1/(1-rho_max)), which
    depends only on correlations, not on parameter scales.

    This is the single-matrix entry point of
    ``_batch_invert_3x3_preconditioned``: `P` is processed as a batch of one.
    The single and batched paths therefore give identical results. In
    particular, the output is always built from the symmetrized normalized
    inverse, not only when eigenvalue regularization is triggered.

    Parameters
    ----------
    P : array_like of shape (3, 3)
        Symmetric positive-definite precision matrix.
    min_eigenval_threshold : float
        Minimum eigenvalue threshold for the normalized covariance.

    Returns
    -------
    C : ndarray of shape (3, 3)
        The inverse (covariance) matrix, as float64.
    """
    # float64 matches the original per-element implementation's output dtype
    P_batch = np.asarray(P, dtype=np.float64).reshape(1, 3, 3)
    return _batch_invert_3x3_preconditioned(P_batch, min_eigenval_threshold)[0]


def _batch_invert_3x3_preconditioned(P_batch, min_eigenval_threshold=1e-12):
    """
    Batch-invert N 3x3 matrices with diagonal preconditioning.

    Operates on the entire (N, 3, 3) array at once using numpy broadcasting.
    The only per-matrix loop (the analytical inversion) runs inside the
    numba-compiled ``_batch_invert_3x3``, so there are no Python loops over N.

    Parameters
    ----------
    P_batch : ndarray of shape (N, 3, 3)
        Stack of (ideally symmetric positive-definite) precision matrices.
        Only the symmetric part ``0.5 * (P + P^T)`` is inverted. The input
        array is read but not modified.
    min_eigenval_threshold : float, optional
        Lower bound imposed on the eigenvalues of each inverse in the
        normalized (unit-diagonal) space. Default: 1e-12.

    Returns
    -------
    C : ndarray of shape (N, 3, 3)
        The inverse (covariance) matrices. A matrix whose normalized inverse
        is non-finite falls back to ``diag(1 / max(P_ii, 1e-30))``.
    """

    # Step 1: Diagonal scaling d_i = 1/sqrt(P_ii); the 1e-30 floor keeps
    # zero/negative diagonals from producing inf/nan scale factors.
    diag = np.sqrt(np.maximum(np.diagonal(P_batch, axis1=1, axis2=2), 1e-30))  # (N, 3)
    d_inv = 1.0 / diag  # (N, 3)

    # Step 2: Normalize to correlation form, D^{-1} P_sym D^{-1}
    P_sym = 0.5 * (P_batch + np.swapaxes(P_batch, 1, 2))  # symmetrize
    P_norm = P_sym * d_inv[:, :, None] * d_inv[:, None, :]  # (N, 3, 3)

    # Step 3: Pre-regularize singular matrices. The determinant is the same
    # first-row cofactor expansion as _matrix_det_3x3, vectorized over N.
    # On a unit diagonal, a 1e-3 shift is well scaled for every parameter.
    dets = (
        P_norm[:, 0, 0]
        * (P_norm[:, 1, 1] * P_norm[:, 2, 2] - P_norm[:, 1, 2] * P_norm[:, 2, 1])
        - P_norm[:, 0, 1]
        * (P_norm[:, 1, 0] * P_norm[:, 2, 2] - P_norm[:, 1, 2] * P_norm[:, 2, 0])
        + P_norm[:, 0, 2]
        * (P_norm[:, 1, 0] * P_norm[:, 2, 1] - P_norm[:, 1, 1] * P_norm[:, 2, 0])
    )  # (N,)
    singular = np.abs(dets) < 1e-10
    if np.any(singular):
        P_norm[singular, 0, 0] += 1e-3
        P_norm[singular, 1, 1] += 1e-3
        P_norm[singular, 2, 2] += 1e-3

    # Step 4: Batch analytical inversion (singular members come back as inf)
    C_norm = _batch_invert_3x3(P_norm)

    # Step 5: Handle inversion failures
    bad = ~np.all(np.isfinite(C_norm), axis=(1, 2))
    if np.any(bad):
        # Fallback: diagonal covariance built directly in ORIGINAL space
        C_norm[bad] = 0.0
        for i in range(3):
            C_norm[bad, i, i] = 1.0 / np.maximum(P_batch[bad, i, i], 1e-30)
        # For bad matrices, d_inv should be identity-like, so that Step 8
        # leaves the fallback diagonal unscaled.
        d_inv[bad] = 1.0

    # Step 6: Symmetrize
    C_sym = 0.5 * (C_norm + np.swapaxes(C_norm, 1, 2))

    # Step 7: Batch eigenvalue check and regularization
    # Use np.linalg.eigvalsh on the entire batch at once. Fallback matrices
    # (bad) are excluded: they are not in normalized space.
    eigvals = np.linalg.eigvalsh(C_sym)  # (N, 3), sorted ascending
    min_eigs = eigvals[:, 0]  # (N,)
    needs_reg = (min_eigs < min_eigenval_threshold) & (~bad)
    if np.any(needs_reg):
        # Shift the diagonal so the smallest eigenvalue equals the threshold
        shifts = min_eigenval_threshold - min_eigs[needs_reg]
        C_sym[needs_reg, 0, 0] += shifts
        C_sym[needs_reg, 1, 1] += shifts
        C_sym[needs_reg, 2, 2] += shifts

    # Step 8: Transform back to original parameter space,
    # C = D^{-1} C_norm D^{-1}
    C = C_sym * d_inv[:, :, None] * d_inv[:, None, :]  # (N, 3, 3)

    return C


def inverse3(A, regularize=False, min_eigenval_threshold=1e-12):
    """
    Compute the inverse of a series of 3x3 matrices.

    When ``regularize=True``, uses diagonal preconditioning: the precision
    matrix is normalized to a correlation-like matrix (1s on diagonal)
    before inversion, then transformed back. This reduces condition numbers
    from O(max_diag/min_diag) to O(1/(1-rho_max)), making the inversion
    numerically stable even when parameters have very different scales
    (e.g. scale ~ 10^5 vs R(V) precision ~ 30).

    Parameters
    ----------
    A : array_like of shape `(..., 3, 3)`
        Array of 3x3 matrices. Any number of leading (batch) dimensions is
        accepted. Non-floating input (e.g. integers) is converted to float64.
    regularize : bool, optional
        Whether to apply diagonal preconditioning and regularization to
        ensure positive semi-definiteness. Default: False.
    min_eigenval_threshold : float, optional
        Minimum acceptable eigenvalue for OUTPUT matrices in the normalized
        space. Only used when ``regularize=True``. Default: 1e-12.

    Returns
    -------
    A_inv : `~numpy.ndarray` of shape `(..., 3, 3)`
        Inverse matrices, guaranteed to be positive semi-definite if
        regularize=True. Without regularization, a numerically singular
        matrix yields an inverse filled with ``inf``.

    Raises
    ------
    ValueError
        If `A` does not have trailing dimensions ``(3, 3)``.
    """
    A = np.asarray(A)
    if A.ndim < 2 or A.shape[-2:] != (3, 3):
        raise ValueError(
            f"inverse3 expects an array of shape (..., 3, 3); got shape {A.shape}"
        )
    # The analytical kernel allocates its output like its input, so integer
    # input would silently truncate the inverse; promote to float first.
    if not np.issubdtype(A.dtype, np.floating):
        A = A.astype(np.float64)

    if A.ndim == 2:
        if regularize:
            return _invert_3x3_preconditioned(A, min_eigenval_threshold)
        return _invert_3x3_analytical(A)

    # The batch kernels expect exactly (N, 3, 3): flatten any leading
    # dimensions, invert, then restore the caller's shape.
    batch = A.reshape(-1, 3, 3)
    if regularize:
        inv = _batch_invert_3x3_preconditioned(batch, min_eigenval_threshold)
    else:
        inv = _batch_invert_3x3(batch)
    return inv.reshape(A.shape)
def isPSD(A):
    """
    Check if `A` is a (symmetric) positive semidefinite matrix.

    A symmetric matrix is positive semidefinite if all its eigenvalues are
    non-negative.

    Parameters
    ----------
    A : array_like of shape `(N, N)`
        Square matrix to test.

    Returns
    -------
    is_psd : bool
        True if the matrix is symmetric (within ``rtol=atol=1e-10``) and
        positive semidefinite, False otherwise. Non-square or non-2-D input
        returns False, and an empty ``(0, 0)`` matrix returns True.

    Notes
    -----
    Eigenvalues down to ``-1e-10 * max(1, max|lambda|)`` are accepted as
    round-off. The tolerance scales with the spectrum, because eigenvalues
    from ``eigvalsh`` carry absolute errors of order ``eps * ||A||``. A fixed
    ``-1e-10`` cutoff would reject genuinely PSD matrices whose entries are
    large (e.g. ~1e12). For matrices with all ``|lambda| <= 1`` the cutoff is
    exactly ``-1e-10``.
    """
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        return False
    if A.size == 0:
        return True

    # Check if matrix is symmetric (within numerical precision)
    if not np.allclose(A, A.T, rtol=1e-10, atol=1e-10):
        return False

    # eigvalsh is faster for symmetric matrices and returns real eigenvalues
    eigenvals = np.linalg.eigvalsh(A)
    tol = 1e-10 * max(1.0, float(np.max(np.abs(eigenvals))))
    return bool(np.all(eigenvals >= -tol))
def chisquare_logpdf(x, df, loc=0, scale=1):
    """
    Compute log-PDF of a chi-square distribution.

    `chisquare_logpdf(x, df, loc, scale)` is equal to
    `chisquare_logpdf(y, df) - ln(scale)`, where `y = (x-loc)/scale`.

    NOTE: This function replicates `~scipy.stats.chi2.logpdf`.

    Parameters
    ----------
    x : `~numpy.ndarray` of shape `(N)`, list/tuple, or float
        Input values.
    df : float
        Degrees of freedom. Uses ``math.lgamma`` internally, so there is no
        overflow risk for large ``df`` (unlike ``math.gamma`` which
        overflows for ``df`` > ~340).
    loc : float, optional
        Offset of distribution. Default is 0.
    scale : float, optional
        Scaling of distribution. Default is 1.

    Returns
    -------
    ans : `~numpy.ndarray` of shape `(N)` or float
        The natural log of the PDF values. Points with ``y < 0`` give
        ``-inf``. At ``y == 0`` the result is the limiting value of the
        density: ``+inf`` for ``df < 2``, ``-ln(2 * scale)`` for
        ``df == 2``, and ``-inf`` for ``df > 2``.
    """
    if isinstance(x, (list, tuple)):
        x = np.asarray(x)
    y = (x - loc) / scale

    half_df = df / 2.0
    # log of the normalizing constant 1 / (2**(df/2) * Gamma(df/2))
    log_norm = -half_df * log(2) - lgamma(half_df)
    power = half_df - 1.0

    # Limit of power * log(y) as y -> 0+: -inf, +inf, or exactly 0 when
    # df == 2. In the last case the density at 0 is 1 / (2 * scale).
    if power > 0:
        at_zero = -np.inf
    elif power < 0:
        at_zero = np.inf
    else:
        at_zero = log_norm - log(scale)

    if isinstance(y, (float, int)):
        if y < 0:
            return -np.inf
        if y == 0:
            return at_zero
        return log_norm + power * np.log(y) - y / 2.0 - log(scale)

    below = y < 0
    at_origin = y == 0
    # Placeholder keeps np.log finite on masked entries; they're overwritten.
    y_safe = np.where(below | at_origin, 0.1, y)
    ans = log_norm + power * np.log(y_safe) - y_safe / 2.0 - log(scale)
    ans = np.where(below, -np.inf, ans)
    return np.where(at_origin, at_zero, ans)
def truncnorm_pdf(x, a, b, loc=0.0, scale=1.0):
    """
    Compute PDF of a truncated normal distribution.

    The parent normal distribution has a mean of `loc` and standard
    deviation of `scale`. The distribution is cut off at ``scale * a + loc``
    and ``scale * b + loc``.

    NOTE: This function replicates `~scipy.stats.truncnorm.pdf`.

    Parameters
    ----------
    x : `~numpy.ndarray` of shape `(N)`, list/tuple, or float
        Input values (numpy scalars and 0-d arrays are treated as scalars).
    a : float
        Lower bound in standardized units. The actual lower cutoff in data
        space is ``scale * a + loc``.
    b : float
        Upper bound in standardized units. The actual upper cutoff in data
        space is ``scale * b + loc``.
    loc : float, optional
        Mean of normal distribution. Default is 0.0.
    scale : float, optional
        Standard deviation of normal distribution. Default is 1.0.

    Returns
    -------
    ans : `~numpy.ndarray` of shape `(N)` or float
        The PDF values (0 outside the truncation interval).

    Notes
    -----
    The normalizing mass ``Phi(b) - Phi(a)`` is evaluated with complementary
    error functions when the interval lies entirely in one tail (``a >= 0``
    or ``b <= 0``). Subtracting two CDF values that are both close to 1 (or
    both close to 0) cancels catastrophically. For intervals such as
    ``a=10, b=inf`` that gives a zero denominator and an ``inf``/``nan`` PDF.
    """
    if isinstance(x, (list, tuple)):
        x = np.asarray(x)

    # Truncation bounds in data space
    _a = scale * a + loc
    _b = scale * b + loc

    # Standard-normal density at the standardized input
    xi = (x - loc) / scale
    phix = np.exp(-0.5 * xi**2) / np.sqrt(2.0 * np.pi)

    # Probability mass of the parent normal inside [a, b], tail-accurate
    sqrt2 = np.sqrt(2.0)
    mass = 0.5 * np.where(
        a >= 0,
        erfc(a / sqrt2) - erfc(b / sqrt2),  # interval in the upper tail
        np.where(
            b <= 0,
            erfc(-b / sqrt2) - erfc(-a / sqrt2),  # interval in the lower tail
            erf(b / sqrt2) - erf(a / sqrt2),  # interval straddles the mean
        ),
    )
    ans = phix / (scale * mass)

    outside = np.logical_or(x < _a, x > _b)
    if np.ndim(ans) == 0:
        return 0.0 if outside else ans
    ans[outside] = 0.0
    return ans
def truncnorm_logpdf(x, a, b, loc=0.0, scale=1.0):
    """
    Compute log-PDF of a truncated normal distribution.

    The parent normal distribution has a mean of `loc` and standard
    deviation of `scale`. The distribution is cut off at ``scale * a + loc``
    and ``scale * b + loc``.

    NOTE: This function replicates `~scipy.stats.truncnorm.logpdf`.

    Parameters
    ----------
    x : `~numpy.ndarray` of shape `(N)`, list/tuple, or float
        Input values (numpy scalars and 0-d arrays are treated as scalars).
    a : float
        Lower bound in standardized units. The actual lower cutoff in data
        space is ``scale * a + loc``.
    b : float
        Upper bound in standardized units. The actual upper cutoff in data
        space is ``scale * b + loc``.
    loc : float, optional
        Mean of normal distribution. Default is 0.0.
    scale : float, optional
        Standard deviation of normal distribution. Default is 1.0.

    Returns
    -------
    ans : `~numpy.ndarray` of shape `(N)` or float
        The natural log PDF values (``-inf`` outside the truncation interval).

    Notes
    -----
    Twice the normalizing mass, ``2 * (Phi(b) - Phi(a))``, is evaluated with
    complementary error functions when the interval lies entirely in one
    tail (``a >= 0`` or ``b <= 0``). Subtracting two ``erf`` values that are
    both close to +-1 cancels to zero there. The ``1e-300`` floor would then
    silently replace the true mass and give a grossly wrong log-density.
    The floor is kept only for masses that underflow even in the tail form.
    """
    if isinstance(x, (list, tuple)):
        x = np.asarray(x)

    # Truncation bounds in data space
    _a = scale * a + loc
    _b = scale * b + loc

    # Log of the standard-normal density at the standardized input
    xi = (x - loc) / scale
    lnphi = -np.log(np.sqrt(2 * np.pi)) - 0.5 * np.square(xi)

    # 2 * (Phi(b) - Phi(a)), tail-accurate
    sqrt2 = np.sqrt(2.0)
    twice_mass = np.where(
        a >= 0,
        erfc(a / sqrt2) - erfc(b / sqrt2),  # interval in the upper tail
        np.where(
            b <= 0,
            erfc(-b / sqrt2) - erfc(-a / sqrt2),  # interval in the lower tail
            erf(b / sqrt2) - erf(a / sqrt2),  # interval straddles the mean
        ),
    )
    lndenom = np.log(scale / 2.0) + np.log(np.maximum(twice_mass, 1e-300))
    ans = np.subtract(lnphi, lndenom)

    outside = np.logical_or(x < _a, x > _b)
    if np.ndim(ans) == 0:
        return -np.inf if outside else ans
    ans[outside] = -np.inf
    return ans
def galactic_to_galactocentric_cyl(dists, ell, b, R_solar=8.2, Z_solar=0.025):
    r"""
    Convert Galactic coordinates to Galactocentric cylindrical coordinates.

    Maps heliocentric Galactic coordinates (l, b, distance) to Galactocentric
    cylindrical coordinates (R, Z) with a simple rotation and translation.
    This is a fast NumPy-based replacement for astropy SkyCoord coordinate
    transformations.

    Parameters
    ----------
    dists : array_like
        Heliocentric distances in kpc.
    ell : float or array_like
        Galactic longitude in degrees.
    b : float or array_like
        Galactic latitude in degrees.
    R_solar : float, optional
        Solar Galactocentric radius in kpc. Default is 8.2.
    Z_solar : float, optional
        Solar height above the Galactic midplane in kpc. Default is 0.025.

    Returns
    -------
    R : ndarray
        Galactocentric cylindrical radius in kpc.
    Z : ndarray
        Height above/below the Galactic midplane in kpc.

    Notes
    -----
    The conversion assumes:

    - The Sun is located at ``(x, y, z) = (R_solar, 0, Z_solar)`` in
      Galactocentric Cartesian coordinates.
    - Galactic longitude ``l=0`` points toward the Galactic center.
    - The Galactic midplane is at ``Z=0``.

    The Cartesian positions relative to the Sun are:

    .. math::

        x = d \cos(b) \cos(l)

        y = d \cos(b) \sin(l)

    And the Galactocentric cylindrical coordinates are:

    .. math::

        R = \sqrt{(x - R_\odot)^2 + y^2}

        Z = d \sin(b) + Z_\odot

    See Also
    --------
    brutus.priors.galactic.logp_galactic_structure : Uses this for coordinate
        conversion instead of astropy SkyCoord.

    Examples
    --------
    >>> import numpy as np
    >>> from brutus.utils.math import galactic_to_galactocentric_cyl
    >>> R, Z = galactic_to_galactocentric_cyl(
    ...     dists=np.array([1.0, 5.0]), ell=90.0, b=0.0
    ... )
    """
    lon = np.deg2rad(ell)
    lat = np.deg2rad(b)

    # Projection of the line of sight onto the Galactic plane
    in_plane = dists * np.cos(lat)
    x_helio = in_plane * np.cos(lon)
    y_helio = in_plane * np.sin(lon)

    # Offset from the Galactic center along the Sun-center axis
    dx = x_helio - R_solar
    radius = np.sqrt(dx**2 + y_helio**2)
    height = dists * np.sin(lat) + Z_solar
    return radius, height