Source code for brutus.core.neural_nets

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Neural network utilities for fast SED prediction.

This module contains classes for neural network-based bolometric correction
predictions, enabling fast computation of stellar spectral energy distributions
from fundamental stellar parameters (Teff, log g, [Fe/H], [α/Fe], Av, Rv).

The neural networks are trained on synthetic stellar spectra and provide rapid
computation of bolometric corrections across multiple photometric bands,
significantly speeding up stellar parameter inference compared to full spectral
synthesis.

Classes
-------
FastNN : Base neural network class
    Provides core neural network functionality for bolometric correction
    prediction using pre-trained weights and biases.

FastNNPredictor : SED prediction class
    Extends FastNN to generate complete spectral energy distributions
    for multiple filters with distance modulus and reddening corrections.

Examples
--------
Basic usage for SED prediction:

>>> from brutus.core.neural_nets import FastNNPredictor
>>> from brutus.data.filters import FILTERS
>>>
>>> # Initialize predictor for specific filters
>>> predictor = FastNNPredictor(filters=['g', 'r', 'i', 'z'])
>>>
>>> # Predict SED for a solar-type star at 1 kpc
>>> sed = predictor.sed(logt=3.76, logg=4.44, feh_surf=0.0,
...                     logl=0.0, afe=0.0, av=0.1, rv=3.1, dist=1000)
>>> print(f"Predicted magnitudes: {sed}")

Notes
-----
The neural networks require pre-trained model files containing weights, biases,
and input scaling parameters. Default files are stored in the data directory
and are automatically downloaded when needed.
"""

import sys

import h5py
import numpy as np

# Import filter definitions from parent module
from ..data.filters import FILTERS
from ..data.loader import find_nn_file

# Public names re-exported by ``from brutus.core.neural_nets import *``.
__all__ = [
    "FastNN",
    "FastNNPredictor",
]


class FastNN:
    """
    Object that wraps the underlying neural networks used to interpolate
    between grid points on the bolometric correction tables.

    This class provides the core neural network functionality for
    predicting bolometric corrections from stellar parameters. It loads
    pre-trained neural network weights and biases (one network per filter),
    and provides methods for encoding input parameters and evaluating the
    networks.

    Parameters
    ----------
    filters : list of str, optional
        The names of filters that photometry should be computed for. If not
        provided, all filters in `brutus.data.filters.FILTERS` are used.

    nnfile : str, optional
        Path to the HDF5 file containing pre-trained weights and biases.
        If not provided, the path returned by
        `brutus.data.loader.find_nn_file()` is used.

    verbose : bool, optional
        Whether to print initialization progress messages to stderr.
        Default is `True`.

    Attributes
    ----------
    w1, w2, w3 : numpy.ndarray
        Weight matrices for each layer, stacked over filters along the
        leading axis.

    b1, b2, b3 : numpy.ndarray
        Bias arrays for each layer, stacked over filters along the leading
        axis.

    xmin, xmax : numpy.ndarray
        Minimum and maximum values of each input parameter (shared by all
        filters' networks), used for input scaling.

    xspan : numpy.ndarray
        Range of input parameters (xmax - xmin).

    Notes
    -----
    Each network is a 3-layer feedforward network: two hidden layers with
    sigmoid activations followed by a linear output layer. Input parameters
    are rescaled with (x - xmin) / (xmax - xmin) before evaluation.

    Expected input parameters (in order):

    - Teff : Effective temperature in Kelvin. This is the *linear*
      temperature, not log10(Teff): `FastNNPredictor.sed` passes
      ``10**logt`` and checks it against ``xmin[0]``/``xmax[0]``.
    - log g : Surface gravity in cgs units
    - [Fe/H] : Surface metallicity (log scale)
    - [α/Fe] : Alpha enhancement (log scale)
    - Av : V-band extinction in magnitudes
    - Rv : Reddening parameter R(V) = A(V)/E(B-V)
    """

    def __init__(self, filters=None, nnfile=None, verbose=True):
        # Fall back to every known filter / the default network file.
        if filters is None:
            filters = np.array(FILTERS)
        if nnfile is None:
            nnfile = find_nn_file()

        # Read in NN data.
        if verbose:
            sys.stderr.write("Initializing FastNN predictor...")
        self._load_NN(filters, nnfile)
        if verbose:
            sys.stderr.write("done!\n")

    def _load_NN(self, filters, nnfile):
        """
        Load neural network weights and biases from an HDF5 file.

        Each filter has its own trained network with identical architecture
        but different weights; the per-filter arrays are stacked along a new
        leading axis.

        Parameters
        ----------
        filters : array-like
            Filter names to load networks for.

        nnfile : str
            Path to the HDF5 file containing neural network data.

        Raises
        ------
        ValueError
            If no filters are given, or if the networks for different
            filters do not share identical input parameter bounds.

        Notes
        -----
        The HDF5 file is expected to have the following structure:

        - /{filter}/w1, w2, w3 : weight matrices for layers 1, 2, 3
        - /{filter}/b1, b2, b3 : biases for layers 1, 2, 3
        - /{filter}/xmin, xmax : input parameter scaling bounds
        """
        with h5py.File(nnfile, "r") as f:
            # Store weights and bias for each layer and filter.
            self.w1 = np.array([f[fltr]["w1"] for fltr in filters])
            self.b1 = np.array([f[fltr]["b1"] for fltr in filters])
            self.w2 = np.array([f[fltr]["w2"] for fltr in filters])
            self.b2 = np.array([f[fltr]["b2"] for fltr in filters])
            self.w3 = np.array([f[fltr]["w3"] for fltr in filters])
            self.b3 = np.array([f[fltr]["b3"] for fltr in filters])

            # Load input parameter scaling bounds (one row per filter).
            xmin = np.array([f[fltr]["xmin"] for fltr in filters])
            xmax = np.array([f[fltr]["xmax"] for fltr in filters])

        # A single encode() step is shared by all filters, so every
        # network must have been trained on the same parameter bounds.
        self.xmin, self.xmax = self._shared_bounds(xmin, xmax)
        self.xspan = self.xmax - self.xmin

    @staticmethod
    def _shared_bounds(xmin, xmax):
        """
        Return the common (xmin, xmax) bounds shared by all filters.

        Parameters
        ----------
        xmin, xmax : array-like of shape (Nfilt, Ninput)
            Per-filter lower and upper bounds of each input parameter.

        Returns
        -------
        xmin0, xmax0 : numpy.ndarray of shape (Ninput,)
            The bounds of the first filter, which equal every other
            filter's bounds.

        Raises
        ------
        ValueError
            If no filters are given or any filter's bounds differ from the
            first filter's bounds in any parameter.
        """
        xmin = np.asarray(xmin)
        xmax = np.asarray(xmax)
        if xmin.shape[0] == 0 or xmax.shape[0] == 0:
            raise ValueError("No filters were provided to load networks for.")
        # Compare parameter-by-parameter across filters. (Counting unique
        # values instead would wrongly reject identical bounds in which two
        # different parameters happen to share a bound value.)
        if not (np.all(xmin == xmin[0]) and np.all(xmax == xmax[0])):
            raise ValueError(
                "Some of the neural networks have different "
                "`xmin` and `xmax` ranges for parameters."
            )
        return xmin[0], xmax[0]

    def encode(self, x):
        """
        Rescale input parameters to [0,1] range for neural network evaluation.

        The neural networks are trained on scaled inputs where each parameter
        is normalized to the range [0,1] based on the training data bounds.

        Parameters
        ----------
        x : numpy.ndarray of shape (Ninput,) or (Ninput, Nsamples)
            Input stellar parameters, ordered
            [Teff, log g, [Fe/H], [alpha/Fe], Av, Rv].

        Returns
        -------
        xp : numpy.ndarray of shape (Ninput, 1) or (Ninput, Nsamples)
            Scaled input parameters ready for neural network evaluation.

        Notes
        -----
        The scaling is applied as:
        x_scaled = (x - xmin) / (xmax - xmin)
        """
        # Dispatch on dimensionality explicitly: trying the 1-D form and
        # falling back on a broadcasting error is unsound, because a (6, 6)
        # batch broadcasts against a (6,) bound vector without error and
        # would be normalized along the wrong axis.
        x = np.asarray(x)
        if x.ndim == 1:
            # (Ninput,) -> (Ninput, 1)
            xp = (x - self.xmin) / self.xspan
            return xp[:, None]

        # (Ninput, Nsamples) -> (Ninput, Nsamples)
        xp = (x - self.xmin[:, None]) / self.xspan[:, None]
        return xp

    def sigmoid(self, a):
        """
        Apply the logistic sigmoid activation f(a) = 1 / (1 + exp(-a)).

        Parameters
        ----------
        a : numpy.ndarray
            Input array to apply the sigmoid transformation to.

        Returns
        -------
        a_t : numpy.ndarray
            Output after applying sigmoid activation, same shape as input.

        Notes
        -----
        For very negative ``a``, ``exp(-a)`` overflows to ``inf`` and the
        result correctly becomes 0.0; the overflow warning is suppressed
        because it is expected and harmless here.
        """
        with np.errstate(over="ignore"):
            return 1.0 / (1.0 + np.exp(-a))

    def nneval(self, x):
        """
        Evaluate the neural networks for given input parameters.

        Performs forward propagation through every filter's 3-layer network
        at once to predict bolometric corrections.

        Parameters
        ----------
        x : numpy.ndarray of shape (Ninput,) or (Ninput, Nsamples)
            Stellar parameters [Teff, log g, [Fe/H], [alpha/Fe], Av, Rv]
            (unscaled; scaling is done by `encode`).

        Returns
        -------
        y : numpy.ndarray
            Predicted bolometric corrections. The result is passed through
            `numpy.squeeze`, so singleton axes (e.g. a single filter or a
            single sample) are dropped; callers needing a fixed shape
            should reshape.

        Notes
        -----
        The network architecture is:

        - Input layer: 6 parameters
        - Hidden layer 1: sigmoid activation
        - Hidden layer 2: sigmoid activation
        - Output layer: linear activation (bolometric corrections)
        """
        # Forward propagation; the leading (filter) axis of the weights
        # broadcasts over the encoded input columns.
        a1 = self.sigmoid(np.matmul(self.w1, self.encode(x)) + self.b1)
        a2 = self.sigmoid(np.matmul(self.w2, a1) + self.b2)
        y = np.matmul(self.w3, a2) + self.b3

        return np.squeeze(y)
class FastNNPredictor(FastNN):
    """
    Object that generates SED predictions for a provided set of filters
    using neural networks.

    This class extends FastNN to provide a complete interface for stellar
    SED prediction, including distance modulus calculation and conversion
    from bolometric corrections to apparent magnitudes.

    Parameters
    ----------
    filters : list of str, optional
        The names of filters that photometry should be computed for. If not
        provided, all filters in `brutus.data.filters.FILTERS` are used.
        Must be a subset of filters available in the neural network file.

    nnfile : str, optional
        Path to the neural network file containing pre-trained weights.
        If not provided, the file located by `FastNN` is used.

    verbose : bool, optional
        Whether to print initialization progress messages.
        Default is `True`.

    Attributes
    ----------
    filters : array-like
        Filter names for which predictions are made (as passed in, or
        ``numpy.array(FILTERS)`` by default).

    NFILT : int
        Number of filters for which predictions are made.

    Examples
    --------
    Predict SED for a solar analog:

    >>> predictor = FastNNPredictor(filters=['g', 'r', 'i'])
    >>> sed = predictor.sed(logt=3.76, logg=4.44, feh_surf=0.0,
    ...                     logl=0.0, dist=1000.)
    >>> print(f"g-r color: {sed[0] - sed[1]:.3f}")

    Predict for a red giant with extinction:

    >>> sed = predictor.sed(logt=3.60, logg=2.5, feh_surf=-0.5,
    ...                     logl=1.5, av=0.5, rv=3.1, dist=2000.)

    Notes
    -----
    The neural networks provide bolometric corrections, which are combined
    with luminosity and distance to produce apparent magnitudes:

    m = -2.5 * log10(L/L_sun) + 4.74 - BC + distance_modulus

    where BC is the bolometric correction predicted by the neural network.
    """

    def __init__(self, filters=None, nnfile=None, verbose=True):
        # Default to every known filter.
        if filters is None:
            filters = np.array(FILTERS)
        self.filters = filters
        self.NFILT = len(filters)

        # Load the networks (file discovery is handled by FastNN.__init__).
        super().__init__(filters=filters, nnfile=nnfile, verbose=verbose)

    def sed(
        self,
        logt=3.8,
        logg=4.4,
        feh_surf=0.0,
        logl=0.0,
        afe=0.0,
        av=0.0,
        rv=3.3,
        dist=1000.0,
        filt_idxs=slice(None),
    ):
        """
        Generate SED predictions for specified stellar parameters.

        Returns predicted apparent magnitudes in the specified filters for a
        star with the given physical parameters, distance, and extinction.

        Parameters
        ----------
        logt : float, optional
            Base-10 logarithm of effective temperature in Kelvin.
            Typical range: [3.3, 4.5] corresponding to ~2000-30000K.
            Default is 3.8 (~6300K).

        logg : float, optional
            Base-10 logarithm of surface gravity in cgs units (cm/s^2).
            Typical range: [0, 5] from supergiants to main-sequence dwarfs.
            Default is 4.4 (close to the solar value).

        feh_surf : float, optional
            Surface metallicity [Fe/H] in logarithmic units relative to
            solar. Typical range: [-2.5, 0.5]. Default is 0.0 (solar).

        logl : float, optional
            Base-10 logarithm of luminosity in solar luminosities.
            Default is 0.0 (solar luminosity).

        afe : float, optional
            Alpha element enhancement [alpha/Fe] in logarithmic units
            relative to solar abundance ratios. Default is 0.0.

        av : float, optional
            V-band extinction in magnitudes. Default is 0.0.

        rv : float, optional
            Reddening parameter R(V) = A(V)/E(B-V), describing the
            extinction curve shape. Default is 3.3.

        dist : float, optional
            Distance to the star in parsecs. Must be positive.
            Default is 1000 pc.

        filt_idxs : slice or array-like, optional
            Indices or slice selecting which filters to return.
            Default is slice(None) (all filters).

        Returns
        -------
        sed : numpy.ndarray of shape (Nfilt_subset,)
            Predicted apparent magnitudes in the specified filter subset,
            including the distance modulus; extinction enters through the
            `av`/`rv` network inputs. The photometric system is set by the
            network file (presumably AB -- confirm against training data).

        Notes
        -----
        The computation follows these steps:

        1. Compute distance modulus: mu = 5 * log10(dist) - 5
        2. Evaluate neural network for bolometric corrections: BC = NN(params)
        3. Convert to apparent magnitudes: m = -2.5 * logl + 4.74 - BC + mu

        If any network input is non-finite or outside the neural network
        training bounds, NaN values are returned.

        Examples
        --------
        Solar analog at various distances:

        >>> predictor = FastNNPredictor(['V', 'K'])
        >>> for d in [100, 1000, 10000]:  # pc
        ...     sed = predictor.sed(dist=d)
        ...     print(f"{d:5d} pc: V={sed[0]:.2f}, K={sed[1]:.2f}")

        Effect of extinction:

        >>> sed_clean = predictor.sed(av=0.0)
        >>> sed_dusty = predictor.sed(av=1.0, rv=3.1)
        >>> extinction = sed_dusty - sed_clean
        """
        # Compute distance modulus.
        mu = 5.0 * np.log10(dist) - 5.0

        # Network inputs; note temperature is passed linearly (10**logt).
        x = np.array([10.0**logt, logg, feh_surf, afe, av, rv])

        # Only evaluate the networks inside their training domain.
        if np.all(np.isfinite(x)) and np.all((x >= self.xmin) & (x <= self.xmax)):
            BC = self.nneval(x)

            # With BC = M_bol - M, the apparent magnitude is
            # m = M_bol - BC + mu, where M_bol = -2.5*log10(L/L_sun) + 4.74.
            m = -2.5 * logl + 4.74 - BC + mu
        else:
            m = np.full(self.NFILT, np.nan)

        # atleast_1d: nneval squeezes a single-filter result to a scalar.
        return np.atleast_1d(m)[filt_idxs]

    def sed_batch(
        self,
        logt,
        logg,
        feh_surf,
        logl,
        afe,
        av,
        rv,
        dist,
        filt_idxs=slice(None),
    ):
        """
        Generate SED predictions for an array of stellar parameters.

        Vectorized version of :meth:`sed` that evaluates the neural networks
        once for all stars.

        Parameters
        ----------
        logt, logg, feh_surf, logl, afe : array-like of shape (N,) or scalar
            Per-star log10(Teff), log g, [Fe/H], log10(L/L_sun) and
            [alpha/Fe]. Scalars are broadcast against the other arrays.

        av : float
            V-band extinction (same for all stars).

        rv : float
            Reddening parameter R(V) (same for all stars).

        dist : float
            Distance in parsecs (same for all stars).

        filt_idxs : slice or array-like, optional
            Filter subset to return. Default is all filters.

        Returns
        -------
        seds : numpy.ndarray of shape (N, Nfilt_subset)
            Predicted apparent magnitudes for each star and filter.
            Stars with non-finite or out-of-bounds parameters get NaN.

        Raises
        ------
        ValueError
            If the per-star arrays cannot be broadcast to a common length.
        """
        logt, logg, feh_surf, logl, afe = np.broadcast_arrays(
            *[np.atleast_1d(v) for v in (logt, logg, feh_surf, logl, afe)]
        )
        N = logt.shape[0]

        # Distance modulus (constant for all stars).
        mu = 5.0 * np.log10(dist) - 5.0

        # Network inputs, shape (6, N); temperature is passed linearly.
        x = np.array(
            [10.0**logt, logg, feh_surf, afe, np.full(N, av), np.full(N, rv)]
        )

        # Stars whose inputs are finite and inside the training bounds.
        valid = (
            np.all(np.isfinite(x), axis=0)
            & np.all(x >= self.xmin[:, None], axis=0)
            & np.all(x <= self.xmax[:, None], axis=0)
        )

        seds = np.full((N, self.NFILT), np.nan)
        n_valid = int(np.count_nonzero(valid))
        if n_valid > 0:
            # nneval squeezes singleton axes, which drops the filter axis
            # when NFILT == 1; restore an explicit (NFILT, n_valid) shape so
            # the broadcast against logl[valid, None] stays (n_valid, NFILT).
            BC = np.reshape(self.nneval(x[:, valid]), (self.NFILT, n_valid))

            seds[valid] = -2.5 * logl[valid, None] + 4.74 - BC.T + mu

        return seds[:, filt_idxs]