Source code for brutus.core.individual

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Individual stellar modeling and synthetic photometry generation.

This module provides classes for modeling individual stars using MIST
(MESA Isochrones and Stellar Tracks) evolutionary tracks and generating
synthetic photometry with neural network-based bolometric corrections.

The module follows a clean separation of concerns:
- EEPTracks: Stellar parameter predictions for individual stars
- StarEvolTrack: SED/photometry generation for individual stars

This design mirrors the stellar population modeling pattern:
- Isochrone: Stellar parameter predictions for populations
- StellarPop: SED/photometry generation for populations

Classes
-------
EEPTracks : Individual stellar parameter predictions
    Interpolates MIST evolutionary tracks to predict stellar parameters
    (age, luminosity, temperature, etc.) as a function of initial mass,
    evolutionary phase (EEP), metallicity, and alpha enhancement.

StarEvolTrack : Individual stellar photometry synthesis
    Generates synthetic photometry for individual stars using neural
    network bolometric corrections, with support for binary companions,
    dust extinction, and observational effects.

Examples
--------
Basic individual star modeling:

>>> from brutus.core.individual import EEPTracks, StarEvolTrack
>>>
>>> # Create tracks for parameter predictions
>>> tracks = EEPTracks()
>>> params = tracks.get_predictions([1.0, 350, 0.0, 0.0])  # mini, eep, feh, afe
>>>
>>> # Create star track for photometry
>>> star_track = StarEvolTrack(tracks=tracks)
>>> sed, params, params2 = star_track.get_seds(
>>>     mini=1.0, eep=350, feh=0.0, afe=0.0,
>>>     av=0.1, dist=1000.0
>>> )

Advanced usage with binary stars:

>>> # Model binary system with mass ratio 0.7
>>> sed, params, params2 = star_track.get_seds(
>>>     mini=1.2, eep=400, feh=-0.2, afe=0.1,
>>>     smf=0.7, av=0.15, dist=1500.0
>>> )

Notes
-----
This design provides several advantages:

1. **Separation of Concerns**: Parameter prediction vs. photometry synthesis
2. **Consistency**: Mirrors the stellar population modeling pattern
3. **Flexibility**: Can use different SED generators with same tracks
4. **Generality**: EEPTracks name allows for future non-MIST implementations
5. **Maintainability**: Cleaner, more focused class responsibilities

The StarEvolTrack class uses dependency injection, accepting an EEPTracks
instance rather than inheriting from it. This makes the code more modular
and allows for different track implementations.

This implementation is based on the MIST stellar evolution framework
(Choi et al. 2016; Dotter 2016).

References
----------
- Choi et al. 2016, "MESA Isochrones and Stellar Tracks (MIST) 0. Methods
  for the Construction of Stellar Isochrones", ApJ, 823, 102
- Dotter 2016, "MESA Isochrones and Stellar Tracks (MIST) I. Solar-scaled
  Models", ApJS, 222, 8
"""

import os
import pickle
import sys
import warnings
from pathlib import Path

import h5py
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy.optimize import minimize

# Import filter definitions
from ..data.filters import FILTERS

# Import utilities
from ..utils.photometry import add_mag

# Import neural network predictor
from .neural_nets import FastNNPredictor

__all__ = ["EEPTracks", "StarEvolTrack"]


# Rename parameters from MIST HDF5 file for easier use as keyword arguments
rename = {
    "mini": "initial_mass",  # input parameters
    "eep": "EEP",
    "feh": "initial_[Fe/H]",
    "afe": "initial_[a/Fe]",
    "mass": "star_mass",  # outputs
    "feh_surf": "[Fe/H]",
    "afe_surf": "[a/Fe]",
    "loga": "log_age",
    "logt": "log_Teff",
    "logg": "log_g",
    "logl": "log_L",
    "logr": "log_R",
}


[docs] class EEPTracks: """ Stellar parameter predictions for individual stars using evolutionary tracks. This class provides interpolation of stellar parameters along evolutionary tracks as a function of initial mass, equivalent evolutionary point (EEP), metallicity, and alpha enhancement. It focuses solely on stellar parameter prediction without photometry generation. The class name "EEPTracks" is intentionally general to allow for future implementations beyond MIST, while the current implementation uses MIST (MESA Isochrones and Stellar Tracks) evolutionary models. For photometry generation, use StarEvolTrack with this class. Parameters ---------- mistfile : str, optional Path to the HDF5 file containing the evolutionary tracks. If not provided, defaults to the standard MIST v1.2 EEP tracks file. predictions : iterable of str, optional The names of stellar parameters to predict. Default is: `["loga", "logl", "logt", "logg", "feh_surf", "afe_surf"]`. ageweight : bool, optional Whether to compute age weights d(age)/d(EEP) for age priors. Default is `True`. verbose : bool, optional Whether to output progress messages during initialization. Default is `True`. use_cache : bool, optional Whether to use pickle caching to speed up loading. If True, will save processed EEPTracks to a .pkl file for faster subsequent loads, and load from cache if available and newer than the original file. Default is `True`. Attributes ---------- labels : list of str Input parameter names: ['mini', 'eep', 'feh', 'afe'] predictions : list of str Output parameter names as specified in initialization ndim, npred : int Number of input dimensions and predicted parameters interpolator : scipy.interpolate.RegularGridInterpolator The main interpolation object for stellar parameter prediction gridpoints : dict Unique grid points for each input parameter Examples -------- Predict parameters for a solar-mass star on the main sequence: >>> tracks = EEPTracks() >>> params = tracks.get_predictions([1.0, 350, 0.0, 0.0]) # mini, eep, feh, afe >>> log_age = params[0] # log10(age in years) >>> log_teff = params[2] # log10(effective temperature) >>> print(f"Age: {10**log_age:.1e} yr, Teff: {10**log_teff:.0f} K") Batch prediction for multiple stars: >>> import numpy as np >>> labels = np.array([[0.8, 350, -0.5, 0.2], # Metal-poor dwarf ... [1.2, 454, 0.0, 0.0], # Solar at turnoff ... [2.0, 500, 0.3, 0.0]]) # Massive metal-rich >>> preds = tracks.get_predictions(labels) >>> ages = 10**preds[:, 0] # Convert to linear ages """
[docs] def __init__( self, mistfile=None, predictions=["loga", "logl", "logt", "logg", "feh_surf", "afe_surf"], ageweight=True, verbose=True, use_cache=True, ): # Define input parameter labels labels = ["mini", "eep", "feh", "afe"] # Initialize values self.labels = list(np.array(labels)) self.predictions = list(np.array(predictions)) self.ndim, self.npred = len(self.labels), len(self.predictions) self.null = np.zeros(self.npred) + np.nan # Set label references for fast indexing self.mini_idx = np.where(np.array(self.labels) == "mini")[0][0] self.eep_idx = np.where(np.array(self.labels) == "eep")[0][0] self.feh_idx = np.where(np.array(self.labels) == "feh")[0][0] self.logt_idx = np.where(np.array(self.predictions) == "logt")[0][0] self.logl_idx = np.where(np.array(self.predictions) == "logl")[0][0] self.logg_idx = np.where(np.array(self.predictions) == "logg")[0][0] # Set default file path if mistfile is None: package_root = Path(__file__).parent.parent.parent.parent mistfile = package_root / "data" / "DATAFILES" / "MIST_1.2_EEPtrk.h5" # If default path doesn't exist, try pooch cache directory # Convert to string for os.path.exists() - helps with mock compatibility if not os.path.exists(str(mistfile)): import pooch cache_dir = Path(pooch.os_cache("astro-brutus")) cache_path = cache_dir / "MIST_1.2_EEPtrk.h5" if os.path.exists(str(cache_path)): mistfile = cache_path self.mistfile = Path(mistfile) # Generate cache file path based on original file and configuration cache_key = ( f"{self.mistfile.stem}_ageweight{ageweight}_pred{''.join(predictions)}" ) cache_file = self.mistfile.parent / f"{cache_key}.pkl" # Try to load from cache first if use_cache and cache_file.exists(): try: # Check if cache is newer than original file cache_mtime = cache_file.stat().st_mtime orig_mtime = self.mistfile.stat().st_mtime if cache_mtime > orig_mtime: if verbose: sys.stderr.write( f"Loading cached EEPTracks from {cache_file}...\n" ) with open(cache_file, "rb") as f: cached_data = pickle.load(f) # Restore all cached attributes for attr, value in cached_data.items(): setattr(self, attr, value) if verbose: sys.stderr.write("Cached EEPTracks loaded successfully!\n") return else: if verbose: sys.stderr.write( "Cache is older than data file, regenerating...\n" ) except Exception as e: if verbose: sys.stderr.write( f"Cache loading failed ({e}), loading from original file...\n" ) # Load from original file if verbose: sys.stderr.write(f"Loading evolutionary tracks from {mistfile}...\n") # Load and process track data try: with h5py.File(self.mistfile, "r") as misth5: self._make_lib(misth5, verbose=verbose) self._lib_as_grid() # Construct age weights if requested self._ageidx = self.predictions.index("loga") if ageweight: self._add_age_weights(verbose=verbose) # Build interpolation grid self._build_interpolator() # Save to cache if enabled if use_cache: try: # Collect all relevant attributes for caching cache_data = {} cache_attrs = [ "labels", "predictions", "ndim", "npred", "null", "mini_idx", "eep_idx", "feh_idx", "logt_idx", "logl_idx", "logg_idx", "_ageidx", "mistfile", "grid_dims", "interpolator", ] # Add dynamic attributes that get created during processing for attr in dir(self): if not attr.startswith("__") and attr not in cache_attrs: if hasattr(self, attr): value = getattr(self, attr) # Only cache serializable objects if not callable(value): cache_attrs.append(attr) # Cache all attributes for attr in cache_attrs: if hasattr(self, attr): cache_data[attr] = getattr(self, attr) with open(cache_file, "wb") as f: pickle.dump(cache_data, f) if verbose: sys.stderr.write(f"EEPTracks cached to {cache_file}\n") except Exception as e: if verbose: sys.stderr.write(f"Warning: Failed to cache EEPTracks ({e})\n") except Exception as e: raise RuntimeError(f"Failed to initialize EEPTracks: {e}") if verbose: sys.stderr.write("done!\n")
def _make_lib(self, misth5, verbose=True): """ Convert HDF5 input to numpy arrays for labels and outputs. Reads evolutionary track data from MIST HDF5 file and organizes it into structured arrays for input parameters (mini, eep, feh, afe) and predicted outputs (loga, logl, logt, logg, etc.). Parameters ---------- misth5 : h5py.File Open HDF5 file containing MIST evolutionary track data. verbose : bool, optional Whether to print progress messages. Default is True. Notes ----- This method handles the case where alpha enhancement data ([α/Fe]) is not available in the file by setting it to zero. """ if verbose: sys.stderr.write(" Constructing track library...\n") # Extract input parameters (mini, eep, feh, afe) cols = [rename[p] for p in self.labels] self.libparams = np.concatenate( [np.array(misth5[z])[cols] for z in misth5["index"]] ) self.libparams.dtype.names = tuple(self.labels) # Handle alpha enhancement availability cols = [rename[p] for p in self.predictions] afe_col = rename["afe_surf"] afe_available = True afe_surf_idx = None try: # Test alpha enhancement column availability first_z = list(misth5["index"])[0] _ = misth5[first_z][afe_col] except (KeyError, ValueError): afe_available = False for i, pred in enumerate(self.predictions): if pred == "afe_surf": afe_surf_idx = i break if verbose: sys.stderr.write(" [alpha/Fe] column not found, will set to zero\n") # Read output parameters efficiently cols_to_read = [] read_to_pred_mapping = [] for pred_idx, col in enumerate(cols): if not afe_available and col == afe_col: continue else: cols_to_read.append(col) read_to_pred_mapping.append(pred_idx) if verbose: sys.stderr.write(f" Reading {len(cols_to_read)} parameter columns\n") output_data = [ np.concatenate([misth5[z][p] for z in misth5["index"]]) for p in cols_to_read ] # Create and fill output array self.output = np.empty((len(output_data[0]), len(self.predictions)), dtype="f8") for read_idx, pred_idx in enumerate(read_to_pred_mapping): self.output[:, pred_idx] = output_data[read_idx] # Handle missing alpha enhancement if not afe_available and afe_surf_idx is not None: self.output[:, afe_surf_idx] = 0.0 def _lib_as_grid(self): """ Convert library parameters to pixel indices for interpolation. Determines the unique grid points in each parameter dimension (mini, eep, feh, afe) and creates mappings from continuous parameter values to discrete grid indices for efficient interpolation. Notes ----- This method populates the following attributes: - gridpoints: dict of unique values for each parameter - binwidths: dict of grid spacings for each parameter - X: array of grid indices for each library point - mini_bound: minimum initial mass in the grid """ # Get unique grid points in each dimension self.gridpoints = {} self.binwidths = {} for p in self.labels: self.gridpoints[p] = np.unique(self.libparams[p]) self.binwidths[p] = np.diff(self.gridpoints[p]) # Digitize library parameters to grid indices X = np.array( [ np.digitize(self.libparams[p], bins=self.gridpoints[p], right=True) for p in self.labels ] ) self.X = X.T # Store minimum mass bound self.mini_bound = self.gridpoints["mini"].min() def _add_age_weights(self, verbose=True): """ Compute age gradient d(age)/d(EEP) for age priors. Calculates the derivative of age with respect to evolutionary point (EEP) for each track. This is used to properly weight age priors when sampling in EEP space, accounting for the non-uniform mapping between EEP and age. Parameters ---------- verbose : bool, optional Whether to print progress messages. Default is True. Notes ----- Uses pandas for vectorized computation if available, otherwise falls back to a slower loop-based implementation. The age weights are appended to the predictions array and "agewt" is added to the predictions list. The gradient is computed as d(age)/d(EEP) where age is in linear (not logarithmic) units, even though ages are stored as log(age). """ if verbose: sys.stderr.write(" Computing age weights...\n") # Use vectorized approach with pandas if available try: import pandas as pd age_ind = self._ageidx df_data = { "mini": self.libparams["mini"], "feh": self.libparams["feh"], "afe": self.libparams["afe"], "loga": self.output[:, age_ind], "index": np.arange(len(self.libparams)), } df = pd.DataFrame(df_data) ageweights = np.zeros(len(self.libparams)) for (m, z, a), group in df.groupby(["mini", "feh", "afe"]): indices = group["index"].values log_ages = group["loga"].values if len(log_ages) > 1: linear_ages = 10**log_ages age_gradients = np.gradient(linear_ages) ageweights[indices] = age_gradients except ImportError: # Fallback to original method if verbose: sys.stderr.write(" Using fallback method (pandas not available)\n") age_ind = self._ageidx ageweights = np.zeros(len(self.libparams)) for i, m in enumerate(self.gridpoints["mini"]): for j, z in enumerate(self.gridpoints["feh"]): for k, a in enumerate(self.gridpoints["afe"]): inds = ( (self.libparams["mini"] == m) & (self.libparams["feh"] == z) & (self.libparams["afe"] == a) ) try: agewts = np.gradient(10 ** self.output[inds, age_ind]) ageweights[inds] = agewts except (ValueError, IndexError): pass # Append to outputs self.output = np.hstack([self.output, ageweights[:, None]]) self.predictions += ["agewt"] def _build_interpolator(self): """ Build the RegularGridInterpolator for fast predictions. Creates a scipy RegularGridInterpolator object that enables fast multi-linear interpolation of stellar parameters across the 4-dimensional grid (mini, eep, feh, afe). Notes ----- Handles the special case where alpha enhancement has only one value by padding the grid dimension to enable interpolation. Uses linear interpolation with NaN fill values for out-of-bounds queries. The interpolator maps from (mini, eep, feh, afe) input coordinates to all predicted stellar parameters simultaneously. """ # Set up grid dimensions self.grid_dims = np.append( [len(self.gridpoints[p]) for p in self.labels], self.output.shape[-1] ) self.xgrid = tuple([self.gridpoints[lbl] for lbl in self.labels]) # Initialize output grid self.ygrid = np.zeros(self.grid_dims) + np.nan # Fill grid using optimized indexing if len(self.X) > 0: indices = tuple(self.X.T) self.ygrid[indices] = self.output # Handle singular alpha enhancement dimension if self.grid_dims[-2] == 1: afe_val = self.xgrid[-1][0] xgrid = list(self.xgrid) xgrid[-1] = np.array([afe_val - 1e-5, afe_val + 1e-5]) self.xgrid = tuple(xgrid) # Duplicate values in padded dimension self.grid_dims[-2] += 1 ygrid = np.empty(self.grid_dims) ygrid[:, :, :, 0, :] = self.ygrid[:, :, :, 0, :] ygrid[:, :, :, 1, :] = self.ygrid[:, :, :, 0, :] self.ygrid = ygrid # Initialize interpolator self.interpolator = RegularGridInterpolator( self.xgrid, self.ygrid, method="linear", bounds_error=False, fill_value=np.nan, )
[docs] def get_predictions(self, labels, apply_corr=True, corr_params=None): """ Generate stellar parameter predictions for given input parameters. Parameters ---------- labels : array-like of shape (4,) or (Nobj, 4) Input parameters [mini, eep, feh, afe] where: - mini: Initial mass in solar masses - eep: Equivalent evolutionary point - feh: Metallicity [Fe/H] in logarithmic solar units - afe: Alpha enhancement [alpha/Fe] in logarithmic solar units apply_corr : bool, optional Whether to apply empirical corrections. Default is True. corr_params : tuple, optional Correction parameters (dtdm, drdm, msto_smooth, feh_scale). Returns ------- preds : numpy.ndarray of shape (Npred,) or (Nobj, Npred) Predicted stellar parameters in the order specified by `self.predictions` attribute. See Also -------- get_corrections : Computes empirical corrections applied when apply_corr=True StarEvolTrack.get_seds : Uses these predictions to generate photometry Examples -------- Single star prediction: >>> tracks = EEPTracks() >>> params = tracks.get_predictions([1.0, 350, 0.0, 0.0]) >>> log_age, log_L, log_Teff, log_g = params[:4] Multiple star prediction: >>> import numpy as np >>> labels = np.array([[0.8, 350, -0.5, 0.2], [1.2, 454, 0.0, 0.0]]) >>> params = tracks.get_predictions(labels) """ labels = np.array(labels) ndim = labels.ndim # Perform interpolation if ndim == 1: preds = self.interpolator(labels)[0] elif ndim == 2: preds = self.interpolator(labels) else: raise ValueError("Input `labels` must be 1-D or 2-D array.") # Apply empirical corrections if requested if apply_corr: corrs = self.get_corrections(labels, corr_params=corr_params) if ndim == 1: dlogt, dlogr = corrs preds[self.logt_idx] += dlogt preds[self.logl_idx] += 2.0 * dlogr # L ∝ R^2 preds[self.logg_idx] -= 2.0 * dlogr # g ∝ M/R^2 elif ndim == 2: dlogt, dlogr = corrs.T preds[:, self.logt_idx] += dlogt preds[:, self.logl_idx] += 2.0 * dlogr preds[:, self.logg_idx] -= 2.0 * dlogr return preds
[docs] def get_corrections(self, labels, corr_params=None): r""" Compute empirical corrections to stellar parameters. Applies empirical corrections to effective temperature and radius based on stellar mass, evolutionary phase (EEP), and metallicity. These corrections account for systematic offsets between MIST models and observations, particularly for low-mass stars. Parameters ---------- labels : array-like of shape (4,) or (Nobj, 4) Input parameters [mini, eep, feh, afe] where: - mini: Initial mass in solar masses - eep: Equivalent evolutionary point - feh: Metallicity [Fe/H] - afe: Alpha enhancement [α/Fe] corr_params : tuple of float, optional Correction parameters (dtdm, drdm, msto_smooth, feh_scale) where: - dtdm: Temperature correction slope with mass - drdm: Radius correction slope with mass - msto_smooth: Smoothing scale for main sequence turnoff transition - feh_scale: Metallicity scaling factor Default is (0.09, -0.09, 30.0, 0.5). Returns ------- corrs : numpy.ndarray of shape (2,) or (Nobj, 2) Corrections to [log(Teff), log(R)]. These are added to the base predictions from the interpolator. See Also -------- get_predictions : Applies these corrections to stellar parameters Notes ----- Corrections are applied as: .. math:: \\Delta \\log T_{\\rm eff} = f_{\\rm EEP} \\cdot f_{\\rm [Fe/H]} \\cdot \\log(1 + \\Delta M \\cdot \\alpha_T) \\Delta \\log R = f_{\\rm EEP} \\cdot f_{\\rm [Fe/H]} \\cdot \\log(1 + \\Delta M \\cdot \\alpha_R) where :math:`\\Delta M = M_{\\rm ini} - 1.0` and the EEP factor smoothly transitions from 0 (pre-main sequence) to 1 (post-turnoff). Corrections are set to zero for stars with :math:`M_{\\rm ini} \\geq 1.0 M_\\odot`. """ labels = np.array(labels) ndim = labels.ndim # Extract parameters if ndim == 1: mini = labels[self.mini_idx] eep = labels[self.eep_idx] feh = labels[self.feh_idx] elif ndim == 2: mini = labels[:, self.mini_idx] eep = labels[:, self.eep_idx] feh = labels[:, self.feh_idx] else: raise ValueError("Input `labels` must be 1-D or 2-D array.") # Set correction parameters if corr_params is not None: dtdm, drdm, msto_smooth, feh_scale = corr_params else: dtdm, drdm, msto_smooth, feh_scale = 0.09, -0.09, 30.0, 0.5 # Compute corrections with safeguards mass_offset = mini - 1.0 eps = 1e-10 temp_arg = np.maximum(1.0 + mass_offset * dtdm, eps) radius_arg = np.maximum(1.0 + mass_offset * drdm, eps) dlogt = np.log10(temp_arg) dlogr = np.log10(radius_arg) # EEP and metallicity dependence ecorr = 1.0 - 1.0 / (1.0 + np.exp(-(eep - 454.0) / msto_smooth)) fcorr = np.exp(feh_scale * feh) dlogt *= ecorr * fcorr dlogr *= ecorr * fcorr # Zero corrections for solar mass and above if ndim == 1: if mini >= 1.0: dlogt, dlogr = 0.0, 0.0 elif ndim == 2: mask = mini >= 1.0 dlogt[mask] = 0.0 dlogr[mask] = 0.0 # Format output if ndim == 1: corrs = np.array([dlogt, dlogr]) elif ndim == 2: corrs = np.c_[dlogt, dlogr] return corrs
[docs] class StarEvolTrack: """ Synthetic photometry generation for individual stars. This class generates synthetic SEDs and photometry for individual stars using neural network-based bolometric corrections. It provides modeling of binary stars, dust extinction, and observational effects. This class mirrors StellarPop but for individual stars: - StarEvolTrack: Individual star photometry - StellarPop: Stellar population photometry The class uses dependency injection, accepting an EEPTracks instance for stellar parameter predictions rather than inheriting from it. Parameters ---------- tracks : EEPTracks EEPTracks instance for stellar parameter predictions. filters : list of str, optional Filter names for photometry computation. If None, uses all available. nnfile : str, optional Path to neural network file for bolometric corrections. verbose : bool, optional Whether to output progress messages. Default is True. Attributes ---------- tracks : EEPTracks The evolutionary track model for stellar parameter predictions filters : numpy.ndarray Array of filter names predictor : FastNNPredictor Neural network predictor for photometry See Also -------- EEPTracks : Stellar parameter predictions used by this class StarGrid : Alternative grid-based approach for photometry brutus.core.neural_nets.FastNNPredictor : Neural network used for SEDs Examples -------- Basic individual star photometry: >>> tracks = EEPTracks() >>> star_track = StarEvolTrack(tracks=tracks) >>> >>> # Generate SED for a solar-mass main sequence star >>> sed, params, params2 = star_track.get_seds( ... mini=1.0, eep=350, feh=0.0, afe=0.0, ... av=0.1, rv=3.1, dist=1000.0 ... ) Binary star modeling: >>> # Model binary with 70% mass ratio secondary >>> sed, params, params2 = star_track.get_seds( ... mini=1.2, eep=400, feh=-0.2, afe=0.1, ... smf=0.7, av=0.15, dist=1500.0 ... ) """
[docs] def __init__(self, tracks, filters=None, nnfile=None, verbose=True): # Store tracks reference self.tracks = tracks # Set up filters if filters is None: filters = np.array(FILTERS) self.filters = filters # Set default neural network file if nnfile is None: from ..data.loader import find_nn_file nnfile = find_nn_file() # Initialize neural network predictor try: self.predictor = FastNNPredictor( filters=filters, nnfile=nnfile, verbose=verbose ) except Exception as e: if verbose: sys.stderr.write( f"Warning: Neural network initialization failed: {e}\n" ) self.predictor = None
[docs] def get_seds( self, mini=1.0, eep=350, feh=0.0, afe=0.0, av=0.0, rv=3.3, smf=0.0, dist=1000.0, loga_max=10.15, eep2=None, mini_bound=0.08, eep_binary_max=480.0, apply_corr=True, corr_params=None, return_eep2=False, return_dict=True, combine_seds=True, tol=1e-2, **kwargs, ): r""" Generate synthetic SED for an individual star. Parameters ---------- mini : float, optional Initial stellar mass in solar masses. Default is 1.0. eep : float, optional Equivalent evolutionary point. Default is 350. feh : float, optional Metallicity [Fe/H]. Default is 0.0. afe : float, optional Alpha enhancement [α/Fe]. Default is 0.0. av : float, optional V-band extinction A(V). Default is 0.0. rv : float, optional Extinction law parameter R(V). Default is 3.3. smf : float, optional Secondary mass fraction for binary. Default is 0.0 (single star). dist : float, optional Distance in parsecs. Default is 1000.0. loga_max : float, optional Maximum log(age) for SED computation. Default is 10.15. eep2 : float, optional EEP of secondary component for binaries. mini_bound : float, optional Minimum mass for SED computation. Default is 0.08. eep_binary_max : float, optional Maximum EEP for binary modeling. Default is 480.0. apply_corr : bool, optional Apply empirical corrections. Default is True. corr_params : tuple, optional Correction parameters. return_eep2 : bool, optional Return secondary EEP. Default is False. return_dict : bool, optional Return parameters as dictionary. Default is True. combine_seds : bool, optional If True and binary star requested, combine primary and secondary SEDs by adding their magnitudes. If False, separate SEDs for primary and secondary will be returned. Default is True. tol : float, optional Convergence tolerance (in dex) on the log10(age) match when solving for the secondary's EEP in a binary. A companion whose age cannot be matched to within ``tol`` is treated as non-existent and a primary-only SED is returned (rather than discarding the model). Default is 1e-2. The previous default (1e-6) was far tighter than the age emulator's own precision and silently dropped most otherwise-valid binary models. Returns ------- sed : numpy.ndarray of shape (Nfilters,) or tuple Synthetic SED in magnitudes. params : dict or numpy.ndarray Primary component stellar parameters. params2 : dict or numpy.ndarray Secondary component parameters. eep2 : float, optional Secondary EEP (if return_eep2=True). See Also -------- EEPTracks.get_predictions : Stellar parameter predictions FastNNPredictor.sed : Neural network SED generation _get_eep_for_secondary : Binary companion EEP calculation Notes ----- Distance modulus is applied as: .. math:: m = M + 5 \\log_{10}(d/10\\,{\\rm pc}) where d is the distance in parsecs. Binary SEDs are combined using magnitude addition (flux summing): .. math:: m_{\\rm combined} = -2.5 \\log_{10}(10^{-0.4 m_1} + 10^{-0.4 m_2}) Examples -------- Single star: >>> sed, params, params2 = star_track.get_seds( ... mini=1.0, eep=350, feh=0.0, afe=0.0 ... ) Binary system: >>> sed, params, params2 = star_track.get_seds( ... mini=1.2, eep=400, feh=-0.2, afe=0.1, smf=0.7 ... ) """ if self.predictor is None: raise RuntimeError("Neural network predictor not available") # Grab input labels labels = {"mini": mini, "eep": eep, "feh": feh, "afe": afe} labels = np.array([labels[lbl] for lbl in self.tracks.labels]) # Generate primary component predictions try: params_arr = self.tracks.get_predictions( labels, apply_corr=apply_corr, corr_params=corr_params ) except Exception as e: raise RuntimeError(f"Failed to generate stellar parameters: {e}") # Convert to dictionary format params = dict(zip(self.tracks.predictions, params_arr)) sed = np.full(self.predictor.NFILT, np.nan) sed2 = np.full(self.predictor.NFILT, np.nan) # Initialize secondary parameters params_arr2 = np.full_like(params_arr, np.nan) params2 = dict(zip(self.tracks.predictions, params_arr2)) # Generate primary SED mini_min = max(getattr(self.tracks, "mini_bound", 0.08), mini_bound) loga = params["loga"] if loga <= loga_max: try: sed = self.predictor.sed( logl=params["logl"], logt=params["logt"], logg=params["logg"], feh_surf=params["feh_surf"], afe=params["afe_surf"], av=av, rv=rv, dist=dist, ) except Exception as e: warnings.warn( f"Primary SED generation failed for (mini={mini}, eep={eep}): {e}", RuntimeWarning, stacklevel=2, ) # Add binary companion if requested if smf > 0.0 and eep <= eep_binary_max and mini * smf >= mini_min: # Generate secondary parameters if eep2 is None: eep2 = self._get_eep_for_secondary( loga, mini, eep, feh, afe, smf, tol ) if not np.isfinite(eep2): # The companion's age could not be matched to the primary # within tolerance. Keep the (valid) primary SED rather than # poisoning it with a NaN secondary via add_mag, which would # turn the whole combined SED into NaN and silently discard # an otherwise-valid model during grid generation. warnings.warn( f"Secondary EEP did not converge for binary " f"(mini={mini}, smf={smf}); returning primary-only SED.", RuntimeWarning, stacklevel=2, ) else: labels2 = { "mini": mini * smf, "eep": eep2, "feh": feh, "afe": afe, } labels2 = np.array([labels2[lbl] for lbl in self.tracks.labels]) try: params_arr2 = self.tracks.get_predictions( labels2, apply_corr=apply_corr, corr_params=corr_params ) params2 = dict(zip(self.tracks.predictions, params_arr2)) # Generate secondary SED sed2 = self.predictor.sed( logl=params2["logl"], logt=params2["logt"], logg=params2["logg"], feh_surf=params2["feh_surf"], afe=params2["afe_surf"], av=av, rv=rv, dist=dist, ) if combine_seds: # Combine SEDs (magnitude addition) sed = add_mag(sed, sed2) except Exception as e: warnings.warn( f"Secondary SED generation failed for binary " f"(mini={mini}, smf={smf}, eep2={eep2}): {e}", RuntimeWarning, stacklevel=2, ) # Format output if not return_dict: params = params_arr params2 = params_arr2 if not combine_seds and smf > 0.0: # Return separate SEDs for primary and secondary as a list sed = [sed, sed2] if return_eep2: return sed, params, params2, eep2 else: return sed, params, params2
def _get_eep_for_secondary(self, loga, mini, eep, feh, afe, smf, tol): r""" Calculate EEP for secondary component that matches the age of the primary. This method solves the inverse problem: given a target age (from the primary), find the EEP that produces that age for the secondary star with mass mini*smf. Uses scipy.optimize.minimize with Nelder-Mead to find the best-fit EEP. Parameters ---------- loga : float Target log10(age in years) to match from primary star mini : float Primary star initial mass in solar masses eep : float Primary star EEP (used as initial guess for optimization) feh : float Metallicity [Fe/H] in logarithmic solar units afe : float Alpha enhancement [α/Fe] in logarithmic solar units smf : float Secondary mass fraction (secondary mass = mini * smf) tol : float Convergence tolerance (in dex) on the log10(age) match. The squared-age-difference loss is accepted when ``sqrt(loss) < tol`` (equivalently ``loss < tol**2``); otherwise NaN is returned. Returns ------- eep2 : float EEP for secondary star that produces the target age. Returns NaN if optimization fails or doesn't converge within tolerance. See Also -------- get_seds : Uses this method for binary star modeling EEPTracks.get_predictions : Called to evaluate ages at different EEPs Notes ----- The optimization minimizes the squared age difference: .. math:: L(EEP_2) = (\\log_{10} age(M_2, EEP_2) - \\log_{10} age_{\\rm target})^2 where :math:`M_2 = M_1 \\times smf` is the secondary mass. The alpha enhancement parameter is currently fixed to solar ([α/Fe] = 0.0) for secondary stars, regardless of the primary's value. This is intentional because the current MIST model grids do not include α-enhanced tracks. The `afe` parameter is accepted for API consistency but not used. """ # Get age index from tracks aidx = self.tracks.predictions.index("loga") # Define loss function: minimize difference between predicted and target age def loss(x): if isinstance(x, np.ndarray) and x.size == 1: x = x[0] # Get predicted age for secondary star with mass mini*smf at EEP x try: loga_pred = self.tracks.get_predictions([mini * smf, x, feh, 0.0])[aidx] return (loga_pred - loga) ** 2 except Exception: # Return large loss if prediction fails return 1e6 # Find best-fit EEP that minimizes age difference try: with warnings.catch_warnings(): warnings.simplefilter("ignore") # ignore bad values during optimization res = minimize(loss, eep, method="Nelder-Mead") # Check if solution meets tolerance. loss is the squared log-age # difference, so this accepts matches with |dloga| < tol (in dex). if res.fun < tol**2: eep2 = res.x[0] else: eep2 = np.nan except Exception as e: warnings.warn( f"EEP optimization failed for secondary star " f"(mini={mini}, smf={smf}): {e}", RuntimeWarning, stacklevel=2, ) eep2 = np.nan return eep2
[docs] class StarGrid: """ Grid-based stellar modeling and synthetic photometry generation. This class provides an interface for working with pre-computed stellar model grids, enabling parameter interpolation and SED generation similar to StarEvolTrack but using grid-based models rather than evolutionary tracks. The grid structure allows for irregular spacing in each dimension, with models indexed by their array location. Multi-linear interpolation is used to compute stellar parameters and photometry between grid points. Parameters ---------- models : numpy.ndarray of shape (Nmodel, Nfilt, Ncoef) or dict/h5py file Pre-computed model grid containing photometric coefficients. If dict or h5py file, should contain 'mag_coeffs' key. Each model contains 3 coefficients per filter: - Unreddened magnitude - Reddening vector for R_V = 0 - Change in reddening vector as function of R_V models_labels : structured numpy.ndarray of shape (Nmodel,) Labels for each model in the grid (e.g., mini, eep, feh, afe, smf). models_params : structured numpy.ndarray of shape (Nmodel,), optional Additional parameters for each model (e.g., loga, logl, logt, logg). If not provided, these won't be available in predictions. filters : list of str, optional Filter names for photometry. If None, uses all available filters from the models. verbose : bool, optional Whether to print progress messages. Default is True. Attributes ---------- nmodels : int Number of models in the grid nfilters : int Number of filters filters : numpy.ndarray Array of filter names labels : structured numpy.ndarray Grid labels (mini, eep, feh, etc.) params : structured numpy.ndarray or None Additional parameters if provided label_names : list Names of available labels param_names : list Names of available parameters Notes ----- Binary star support is currently limited. The `smf` parameter is accepted for API compatibility with StarEvolTrack, but full binary modeling requires a dedicated binary grid with pre-computed combined photometry. The current implementation returns empty placeholders for secondary parameters. Examples -------- Load a pre-computed grid and generate photometry: >>> from brutus.data import load_models >>> models, labels, label_mask = load_models('grid_mist_v9.h5') >>> grid = StarGrid(models, labels) >>> >>> # Get predictions for specific stellar parameters >>> predictions = grid.get_predictions(mini=1.0, eep=350, feh=0.0) >>> >>> # Generate SED with extinction >>> sed, params, params2 = grid.get_seds( ... mini=1.0, eep=350, feh=0.0, ... av=0.1, rv=3.3, dist=1000.0 ... ) """
[docs] def __init__( self, models, models_labels, models_params=None, filters=None, verbose=True ): """Initialize the StarGrid with model data.""" # Handle different input formats if isinstance(models, dict): # Dictionary input (e.g., from h5py file) if "mag_coeffs" in models: mag_coeffs = models["mag_coeffs"] else: raise ValueError("models dict must contain 'mag_coeffs' key") if "labels" in models and models_labels is None: models_labels = models["labels"] if "parameters" in models and models_params is None: models_params = models["parameters"] models = mag_coeffs # Store model data self.models = np.asarray(models) self.labels = np.asarray(models_labels) self.params = np.asarray(models_params) if models_params is not None else None # Get dimensions if self.models.ndim == 2: # Models are (Nmodel, Nfilt*3) - reshape to (Nmodel, Nfilt, 3) self.nmodels = self.models.shape[0] self.nfilters = self.models.shape[1] // 3 self.models = self.models.reshape(self.nmodels, self.nfilters, 3) elif self.models.ndim == 3: # Models are already (Nmodel, Nfilt, 3) self.nmodels, self.nfilters, ncoef = self.models.shape if ncoef != 3: raise ValueError(f"Expected 3 coefficients per filter, got {ncoef}") else: # Handle structured array from actual data files if models.dtype.names is not None: # Extract filter names and reshape filter_names = list(models.dtype.names) if filters is not None: # Filter to requested filters filter_names = [f for f in filter_names if f in filters] self.filters = np.array(filter_names) self.nfilters = len(self.filters) self.nmodels = len(models) # Extract coefficients for each filter model_array = np.zeros((self.nmodels, self.nfilters, 3)) for i, filt in enumerate(self.filters): model_array[:, i, :] = models[filt] self.models = model_array else: raise ValueError(f"Unexpected models shape: {self.models.shape}") # Set filter names if not already set if not hasattr(self, "filters"): if filters is not None: self.filters = np.array(filters) else: # Generate default filter names self.filters = np.array([f"filter_{i}" for i in range(self.nfilters)]) # Store label and parameter names if hasattr(self.labels, "dtype") and self.labels.dtype.names: self.label_names = list(self.labels.dtype.names) else: self.label_names = [] if ( self.params is not None and hasattr(self.params, "dtype") and self.params.dtype.names ): self.param_names = list(self.params.dtype.names) else: self.param_names = [] # Build lookup indices for efficient interpolation self._build_grid_indices() # Initialize KD-tree placeholder (built on first use) self.kdtree = None self._kdtree_labels = None # Track which labels are in KD-tree if verbose: print( f"Loaded StarGrid with {self.nmodels:,} models, " f"{self.nfilters} filters, {len(self.label_names)} labels" ) if self.param_names: print(f"Additional parameters: {', '.join(self.param_names)}")
def _build_grid_indices(self): """Build indices for efficient grid lookup and interpolation.""" # Create unique value arrays for each label dimension self.grid_axes = {} self.grid_shape = [] for label in self.label_names: if label in ["mini", "eep", "feh", "afe", "smf"]: unique_vals = np.unique(self.labels[label]) self.grid_axes[label] = unique_vals self.grid_shape.append(len(unique_vals)) # Create mapping from label values to grid indices self.label_to_idx = {} for label, values in self.grid_axes.items(): self.label_to_idx[label] = {val: idx for idx, val in enumerate(values)} def _build_kdtree(self, **kwargs): """ Build KD-tree for efficient nearest neighbor queries. Only built on first use to avoid overhead if only using via BruteForce. """ if self.kdtree is not None: return # Already built from scipy.spatial import cKDTree # Determine which labels to use for KD-tree active_labels = [] for label in ["mini", "eep", "feh", "afe", "smf"]: if ( label in self.label_names and label in kwargs and kwargs[label] is not None ): active_labels.append(label) if not active_labels: # Use all available labels active_labels = [ lbl for lbl in ["mini", "eep", "feh", "afe", "smf"] if lbl in self.label_names ] # Build normalized coordinates for KD-tree coords = [] for label in active_labels: vals = self.labels[label] # Normalize to [0, 1] for balanced distance metrics val_min, val_max = vals.min(), vals.max() if val_max > val_min: normalized = (vals - val_min) / (val_max - val_min) else: normalized = np.zeros_like(vals) coords.append(normalized) if coords: self.kdtree = cKDTree(np.column_stack(coords)) self._kdtree_labels = active_labels def _find_neighbors_multilinear(self, **kwargs): """ Find bracketing grid points and compute multi-linear interpolation weights. For each dimension, finds the two bracketing values and computes interpolation weights. This gives us 2^N neighbors for N dimensions. Parameters ---------- **kwargs : keyword arguments Stellar parameters (mini, eep, feh, afe, smf) Returns ------- indices : numpy.ndarray Indices of neighboring grid points weights : numpy.ndarray Interpolation weights for each neighbor Notes ----- Performs multi-linear interpolation by: 1. Finding bracketing grid points in each dimension 2. Computing linear interpolation weights 3. Generating all 2^N corner points for N dimensions 4. Weighting each corner by product of dimension weights Falls back to KD-tree method if grid structure is irregular or interpolation fails. """ import itertools # Get requested parameters req_params = {} for key in ["mini", "eep", "feh", "afe", "smf"]: if key in kwargs and kwargs[key] is not None: req_params[key] = kwargs[key] if not req_params: # No parameters specified, return first model with weight 1 return np.array([0]), np.array([1.0]) # For each parameter, find bracketing indices and weights bracket_info = {} for param, value in req_params.items(): if param in self.grid_axes: axis_values = self.grid_axes[param] # Find bracketing indices using searchsorted idx = np.searchsorted(axis_values, value) if idx == 0: # Before first point idx_low = idx_high = 0 weight_high = 1.0 elif idx >= len(axis_values): # After last point idx_low = idx_high = len(axis_values) - 1 weight_high = 1.0 else: # Between points idx_low = idx - 1 idx_high = idx # Linear interpolation weight val_low = axis_values[idx_low] val_high = axis_values[idx_high] if val_high > val_low: weight_high = (value - val_low) / (val_high - val_low) else: weight_high = 0.5 bracket_info[param] = { "indices": ( [idx_low, idx_high] if idx_low != idx_high else [idx_low] ), "weights": ( [1.0 - weight_high, weight_high] if idx_low != idx_high else [1.0] ), "values": ( axis_values[[idx_low, idx_high]] if idx_low != idx_high else axis_values[[idx_low]] ), } # Generate all combinations of bracket points param_names = list(bracket_info.keys()) index_combinations = itertools.product( *[bracket_info[p]["indices"] for p in param_names] ) weight_combinations = itertools.product( *[bracket_info[p]["weights"] for p in param_names] ) # Find actual grid indices for each combination indices = [] weights = [] for idx_combo, wt_combo in zip(index_combinations, weight_combinations): # Build selection criteria sel = np.ones(self.nmodels, dtype=bool) for param_name, param_idx in zip(param_names, idx_combo): param_val = bracket_info[param_name]["values"][ bracket_info[param_name]["indices"].index(param_idx) ] sel &= self.labels[param_name] == param_val # Handle other parameters not being interpolated for param in self.label_names: if param not in req_params and param in [ "mini", "eep", "feh", "afe", "smf", ]: # Use first available value for unspecified parameters if param in self.grid_axes: sel &= self.labels[param] == self.grid_axes[param][0] # Find matching grid point grid_idx = np.where(sel)[0] if len(grid_idx) > 0: indices.append(grid_idx[0]) # Weight is product of all dimension weights weights.append(np.prod(wt_combo)) if not indices: # Fallback to KD-tree nearest neighbor if multi-linear fails return self._find_neighbors_kdtree(**kwargs) # Normalize weights indices = np.array(indices) weights = np.array(weights) weights /= weights.sum() return indices, weights def _find_neighbors_kdtree(self, **kwargs): """ Find nearest neighbors using KD-tree (fallback method). Parameters ---------- **kwargs : keyword arguments Stellar parameters (mini, eep, feh, afe, smf) Returns ------- indices : numpy.ndarray Indices of neighboring grid points weights : numpy.ndarray Interpolation weights for each neighbor Notes ----- Uses inverse distance weighting with up to k=8 neighbors. Distances are computed in normalized parameter space where each dimension is scaled to [0, 1] for balanced metrics. The KD-tree is built lazily on first call and cached for subsequent queries. """ # Build KD-tree on first use self._build_kdtree(**kwargs) if self.kdtree is None: # KD-tree couldn't be built, use simple nearest neighbor distances = np.zeros(self.nmodels) for label in ["mini", "eep", "feh", "afe", "smf"]: if ( label in kwargs and kwargs[label] is not None and label in self.label_names ): label_vals = self.labels[label] val_range = label_vals.max() - label_vals.min() if val_range > 0: distances += ((label_vals - kwargs[label]) / val_range) ** 2 nearest_idx = np.argmin(distances) return np.array([nearest_idx]), np.array([1.0]) # Build query point query_point = [] for label in self._kdtree_labels: if label in kwargs and kwargs[label] is not None: val = kwargs[label] else: val = self.grid_axes[label][0] if label in self.grid_axes else 0 # Normalize vals = self.labels[label] val_min, val_max = vals.min(), vals.max() if val_max > val_min: normalized = (val - val_min) / (val_max - val_min) else: normalized = 0.0 query_point.append(normalized) # Query KD-tree for nearest neighbors k = min(8, self.nmodels) # Use up to 8 neighbors distances, indices = self.kdtree.query(query_point, k=k) # Convert distances to weights (inverse distance weighting) epsilon = 1e-10 weights = 1.0 / (distances + epsilon) weights /= weights.sum() return indices, weights
[docs] def get_predictions( self, mini=None, eep=None, feh=None, afe=None, smf=None, use_multilinear=True, **kwargs, ): """ Get stellar parameter predictions from the grid. Interpolates grid models to estimate stellar parameters at the requested input values using multi-linear interpolation. Parameters ---------- mini : float, optional Initial mass in solar masses eep : float, optional Equivalent evolutionary phase feh : float, optional Metallicity [Fe/H] afe : float, optional Alpha enhancement [α/Fe] smf : float, optional Secondary mass fraction for binaries use_multilinear : bool, optional Use multi-linear interpolation (True) or KD-tree nearest neighbor (False). Default is True. **kwargs : additional parameters Any additional selection criteria Returns ------- predictions : dict or numpy.ndarray Predicted stellar parameters. Returns dict with parameter names as keys if parameters are available, otherwise returns array. See Also -------- get_seds : Generate photometry along with parameter predictions _find_neighbors_multilinear : Multi-linear interpolation method _find_neighbors_kdtree : KD-tree nearest neighbor method Examples -------- >>> grid = StarGrid(models, labels, params) >>> preds = grid.get_predictions(mini=1.0, eep=350, feh=0.0) >>> print(f"log(age) = {preds['loga']:.2f}") >>> print(f"log(L) = {preds['logl']:.2f}") """ # Find neighboring grid points if use_multilinear: indices, weights = self._find_neighbors_multilinear( mini=mini, eep=eep, feh=feh, afe=afe, smf=smf, **kwargs ) else: indices, weights = self._find_neighbors_kdtree( mini=mini, eep=eep, feh=feh, afe=afe, smf=smf, **kwargs ) # Interpolate parameters if available if self.params is not None and self.param_names: predictions = {} # Add input labels to predictions for label in ["mini", "eep", "feh", "afe", "smf"]: value = locals()[label] if label in self.label_names and value is not None: predictions[label] = value # Interpolate each parameter for param in self.param_names: param_vals = self.params[param][indices] predictions[param] = np.sum(param_vals * weights) return predictions else: # Return weighted average of labels only predictions = {} for label in self.label_names: label_vals = self.labels[label][indices] predictions[label] = np.sum(label_vals * weights) return predictions
[docs] def get_seds( self, mini=None, eep=None, feh=None, afe=None, av=0.0, rv=3.3, smf=None, dist=1000.0, return_dict=True, return_flux=False, return_predictions=True, use_multilinear=True, **kwargs, ): r""" Generate synthetic SED from the grid. Interpolates grid models and applies extinction to generate synthetic photometry using multi-linear interpolation. Parameters ---------- mini : float, optional Initial mass in solar masses eep : float, optional Equivalent evolutionary phase feh : float, optional Metallicity [Fe/H] afe : float, optional Alpha enhancement [α/Fe] av : float, optional V-band extinction in magnitudes. Default is 0.0. rv : float, optional Reddening law parameter. Default is 3.3. smf : float, optional Secondary mass fraction for binaries. NOTE: Currently returns empty placeholder for params2. Full binary support requires a dedicated binary grid with pre-computed combined photometry. dist : float, optional Distance in parsecs. Default is 1000.0 (1 kpc), which corresponds to parallax = 1 mas for consistency with Gaia units. return_dict : bool, optional If True, return parameters as dict. Default is True. return_flux : bool, optional If True, return fluxes instead of magnitudes. Default is False. return_predictions : bool, optional If True, compute and return stellar parameters. Set to False to only get photometry (more efficient). Default is True. use_multilinear : bool, optional Use multi-linear interpolation (True) or KD-tree nearest neighbor (False). Default is True. **kwargs : additional parameters Passed to interpolation methods Returns ------- sed : numpy.ndarray of shape (Nfilt,) Synthetic photometry (magnitudes or fluxes) params : dict or numpy.ndarray or None Primary star parameters (None if return_predictions=False) params2 : dict or numpy.ndarray or None Secondary star parameters (empty placeholder - full binary implementation requires dedicated binary grid) See Also -------- get_predictions : Get stellar parameters without photometry StarEvolTrack.get_seds : Alternative track-based SED generation brutus.analysis.BruteForce : Fitting with StarGrid Notes ----- The SED is computed by interpolating magnitude coefficients and applying the extinction law: .. math:: m(\\lambda) = m_0(\\lambda) + A_V \\cdot [r_0(\\lambda) + R_V \\cdot dr(\\lambda)] where :math:`m_0` is the unreddened magnitude, :math:`r_0` and :math:`dr` are the reddening vector coefficients from the grid. Examples -------- >>> grid = StarGrid(models, labels) >>> sed, params, _ = grid.get_seds( ... mini=1.0, eep=350, feh=0.0, ... av=0.1, dist=500.0 ... ) >>> print(f"G magnitude: {sed[0]:.2f}") """ # Warning for binary support limitations if smf is not None and smf > 0: import warnings warnings.warn( "Binary star support in StarGrid is limited. Secondary parameters " "(params2) will be empty. For full binary modeling, use StarEvolTrack " "or wait for binary grid implementation.", UserWarning, stacklevel=2, ) # Find neighboring grid points using chosen interpolation method if use_multilinear: indices, weights = self._find_neighbors_multilinear( mini=mini, eep=eep, feh=feh, afe=afe, smf=smf, **kwargs ) else: indices, weights = self._find_neighbors_kdtree( mini=mini, eep=eep, feh=feh, afe=afe, smf=smf, **kwargs ) # Get weighted average of magnitude coefficients weighted_coeffs = np.zeros((self.nfilters, 3)) for idx, weight in zip(indices, weights): weighted_coeffs += self.models[idx] * weight # Apply extinction and distance modulus using sed_utils from .sed_utils import _get_seds # Get reddened magnitudes # _get_seds expects (Nmodels, Nbands, Ncoef) and arrays for av/rv weighted_coeffs_3d = weighted_coeffs[np.newaxis, :, :] av_array = np.array([av]) rv_array = np.array([rv]) seds_array, _, _ = _get_seds( weighted_coeffs_3d, av_array, rv_array, return_flux=False ) mags = seds_array[0] # Extract the single model result # Apply distance modulus (grid magnitudes stored at 1 kpc = 1000 pc reference) # Note: Grid files are generated at 1 kpc distance for consistency with # Gaia parallax measurements (1 mas = 1 kpc) if dist is not None and dist != 1000.0: # Grid reference is 1 kpc dist_mod = 5.0 * np.log10(dist / 1000.0) mags += dist_mod # Convert to flux if requested if return_flux: from ..utils.photometry import inv_magnitude # Use zero errors for now (could be improved with proper error propagation) sed, _ = inv_magnitude(mags, np.zeros_like(mags)) else: sed = mags # Get parameter predictions using same neighbors (avoid redundant computation) if return_predictions: # Interpolate parameters using already-found neighbors if self.params is not None and self.param_names: params = {} # Add input labels for label in ["mini", "eep", "feh", "afe", "smf"]: value = locals()[label] if label in self.label_names and value is not None: params[label] = value # Interpolate parameters for param in self.param_names: param_vals = self.params[param][indices] params[param] = np.sum(param_vals * weights) else: # Interpolate labels only params = {} for label in self.label_names: label_vals = self.labels[label][indices] params[label] = np.sum(label_vals * weights) else: params = None # Binary placeholder (full implementation requires dedicated binary grid) params2 = None # Format output for compatibility if not return_dict: if isinstance(params, dict): params = np.array(list(params.values())) if params else np.array([]) if params2 is None: params2 = np.array([]) elif params2 is None: params2 = {} return sed, params, params2
[docs] def __repr__(self): """Return string representation of the StarGrid object.""" rep = f"StarGrid(nmodels={self.nmodels:,}, nfilters={self.nfilters}" if self.label_names: rep += f", labels={self.label_names}" rep += ")" return rep
[docs] def __len__(self): """Return the number of models in the grid.""" return self.nmodels