# Source code for brutus.data.loader

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Data loading utilities for brutus.

This module contains functions for loading stellar evolution models,
photometric offsets, and other data files into memory for use in
stellar fitting and analysis.
"""

import os
import sys
from pathlib import Path

import h5py
import numpy as np

# Import filter definitions
from .filters import FILTERS

# Names exported by ``from brutus.data.loader import *``.
__all__ = [
    "find_nn_file",
    "load_models",
    "load_offsets",
]


def find_nn_file(possible_names=("nn_c3k.h5", "nnMIST_BC.h5")):
    """
    Find the neural network model file in standard locations.

    Searches for the neural network HDF5 file used for bolometric
    correction predictions. Candidate names are tried in order. For each
    name, the local ``data/DATAFILES`` directory is checked first, then
    the pooch cache directory.

    Parameters
    ----------
    possible_names : str, path-like, or tuple of str, optional
        Filename(s) to search for, tried in order. A single filename may be
        passed directly. Default is ``("nn_c3k.h5", "nnMIST_BC.h5")``.
        The first file found is returned.

    Returns
    -------
    path : pathlib.Path
        Path to the neural network file.

    Raises
    ------
    FileNotFoundError
        If none of the candidate files can be found in any searched
        location. This is also raised when ``pooch`` is not installed and
        no candidate exists locally.
    """
    # Without this, a bare filename would be iterated character by character.
    if isinstance(possible_names, (str, os.PathLike)):
        possible_names = (possible_names,)

    # Package root: src/brutus/data/loader.py -> four levels up to repo root
    package_root = Path(__file__).parent.parent.parent.parent
    local_dir = package_root / "data" / "DATAFILES"

    # Resolve the pooch cache directory at most once, and only when some
    # candidate is missing locally. If pooch is unavailable, skip the cache
    # so callers still get the documented FileNotFoundError.
    cache_dir = None
    cache_checked = False

    for nn_name in possible_names:
        # Check local data directory first
        local_path = local_dir / nn_name
        if os.path.exists(str(local_path)):
            return local_path

        if not cache_checked:
            cache_checked = True
            try:
                import pooch
            except ImportError:
                pass  # best effort: no cache to search without pooch
            else:
                cache_dir = Path(pooch.os_cache("astro-brutus"))

        if cache_dir is not None:
            cache_path = cache_dir / nn_name
            if os.path.exists(str(cache_path)):
                return cache_path

    raise FileNotFoundError(
        f"Could not find neural network file. "
        f"Searched for {possible_names} in "
        f"{local_dir} and pooch cache. "
        f"Run brutus.data.fetch_nns() to download the file."
    )
def load_models(
    filepath,
    filters=None,
    labels=None,
    include_ms=True,
    include_postms=True,
    include_binaries=False,
    verbose=True,
):
    """
    Load a pre-computed stellar model grid and its labels.

    Loads pre-computed stellar model grids with photometric coefficients
    for multiple filters and stellar parameters. Models can be filtered by
    evolutionary phase and binary status.

    Parameters
    ----------
    filepath : str
        The filepath of the stellar model file (typically .h5 format).

    filters : iterable of str with length `Nfilt`, optional
        List of filters that will be loaded. If not provided, will default
        to all available filters. See the internally-defined `FILTERS`
        variable for more details on filter names. Any filters that are not
        present in the file's ``mag_coeffs`` dataset will be skipped over.

    labels : iterable of str with length `Nlabel`, optional
        List of labels associated with the set of imported stellar models.
        Labels are read from the file's ``labels`` and ``parameters``
        datasets. Any label that is missing from both, or that is NaN for
        every model, will be skipped over. The default set is
        `['mini', 'feh', 'eep', 'smf', 'loga', 'logl', 'logt', 'logg',
        'Mr', 'agewt']`.

    include_ms : bool, optional
        Whether to include objects on the Main Sequence. Applied as a cut
        on `eep <= 454` when `'eep'` is among the loaded labels.
        Default is `True`.

    include_postms : bool, optional
        Whether to include objects evolved off the Main Sequence. Applied as
        a cut on `eep > 454` when `'eep'` is among the loaded labels.
        Default is `True`.

    include_binaries : bool, optional
        Whether to include unresolved binaries. When `False` and the
        secondary mass fraction (`'smf'`) has been loaded, only models with
        `smf == 0` are kept and `'smf'` is not returned as a label.
        Default is `False`.

    verbose : bool, optional
        Whether to write progress messages to stderr. Default is `True`.

    Returns
    -------
    models : `~numpy.ndarray` of shape `(Nmodel, Nfilt, 3)`, dtype float32
        Array of models comprised of coefficients in each band used to
        describe the photometry as a function of reddening, parameterized
        in terms of A_V. Each model contains coefficients for:
        - Unreddened magnitude
        - Reddening vector for R_V = 0
        - Change in reddening vector as function of R_V

    labels : structured `~numpy.ndarray` of shape `(Nmodel,)`
        A structured array with one float64 field per returned label
        (e.g. initial mass, metallicity, age), one entry per model.

    label_mask : structured `~numpy.ndarray` of shape `(1,)`
        A boolean structured array with the same fields as `labels`. A
        field is `True` when that label appears in the file's ``labels``
        dataset (inputs used to compute the grid). It is `False` when the
        label comes only from the ``parameters`` dataset (quantities
        predicted from those inputs).

    Raises
    ------
    ValueError
        If neither main sequence nor post-main sequence models are
        included. This is checked before the file is opened.

    Notes
    -----
    The `label_mask` return value distinguishes grid inputs from ancillary
    predicted labels. For example, if luminosity is predicted from
    mass/age/metallicity, this mask would be False for luminosity. Used
    internally by StarGrid to determine which parameters to marginalize
    over during fitting.

    Examples
    --------
    Basic usage with default settings:

    >>> from brutus.data import load_models, fetch_grids
    >>> fetch_grids()  # Download data (first time only)
    >>> models, labels, label_mask = load_models('grid_mist_v9.h5')
    >>> print(f"Loaded {len(labels)} stellar models")
    >>> print(f"Available labels: {labels.dtype.names}")

    Loading specific filters only:

    >>> models, labels, mask = load_models(
    ...     'grid_mist_v9.h5',
    ...     filters=['g', 'r', 'i', 'z', 'y']
    ... )

    Loading only main sequence stars:

    >>> models, labels, mask = load_models(
    ...     'grid_mist_v9.h5',
    ...     include_ms=True,
    ...     include_postms=False
    ... )

    Using with StarGrid for fitting:

    >>> from brutus.core import StarGrid
    >>> from brutus.analysis import BruteForce
    >>> models, labels, mask = load_models('grid_mist_v9.h5')
    >>> grid = StarGrid(models, labels, mask)
    >>> fitter = BruteForce(grid)
    """
    # Validate the phase selection before doing any (potentially large) I/O.
    if not include_ms and not include_postms:
        raise ValueError(
            "If you don't include the Main Sequence and "
            "Post-Main Sequence models you have nothing left!"
        )

    # Initialize values.
    if filters is None:
        filters = FILTERS
    if labels is None:
        labels = [
            "mini",
            "feh",
            "eep",
            "smf",
            "loga",
            "logl",
            "logt",
            "logg",
            "Mr",
            "agewt",
        ]
    else:
        # Materialize so generic iterables can be traversed repeatedly.
        labels = list(labels)
    requested_labels = set(labels)

    # Read in models. Prefer SWMR mode; fall back to a plain read-only open
    # when the file or library does not support it.
    try:
        f = h5py.File(filepath, "r", libver="latest", swmr=True)
    except (OSError, ValueError):
        f = h5py.File(filepath, "r")

    with f:
        mag_coeffs_dataset = f["mag_coeffs"]

        # Keep only the requested filters that actually exist in the file.
        available_filters = set(mag_coeffs_dataset.dtype.names)
        valid_filters = [filt for filt in filters if filt in available_filters]

        if verbose:
            sys.stderr.write(
                f"Reading entire dataset ({len(available_filters)} filters) once...\n"
            )

        # A single bulk read is much cheaper than one HDF5 read per field.
        mag_coeffs = mag_coeffs_dataset[:]

        if verbose:
            sys.stderr.write(
                f"Extracting {len(valid_filters)} requested filters from memory...\n"
            )

        nmodels = len(mag_coeffs)
        models = np.zeros((nmodels, len(valid_filters), 3), dtype="float32")
        for i, filt in enumerate(valid_filters):
            models[:, i] = mag_coeffs[filt]

        # Read in labels. Every requested label starts as NaN and is filled
        # from whichever dataset provides it.
        combined_labels = np.full(
            nmodels, np.nan, dtype=np.dtype([(n, np.float64) for n in labels])
        )
        label_mask = np.zeros(1, dtype=np.dtype([(n, np.bool_) for n in labels]))
        found_labels = set()

        try:
            # Grab "labels" (inputs used to build the grid).
            flabels = f["labels"][:]
            for n in flabels.dtype.names:
                if n in requested_labels:
                    combined_labels[n] = flabels[n]
                    label_mask[n] = True
                    found_labels.add(n)
        except KeyError:
            pass  # the file has no "labels" dataset

        try:
            # Grab "parameters" (predictions from labels).
            fparams = f["parameters"][:]
            for n in fparams.dtype.names:
                if n in requested_labels:
                    combined_labels[n] = fparams[n]
                    found_labels.add(n)
        except KeyError:
            pass  # the file has no "parameters" dataset

    # Remove extraneous/undefined labels: keep a label if the file provides
    # it and it is not NaN for every model. Checking the whole column (not
    # just the first model) keeps labels whose first entry is NaN and
    # avoids an IndexError on an empty grid.
    labels2 = [
        n
        for n in labels
        if n in found_labels
        and not (nmodels > 0 and np.isnan(combined_labels[n]).all())
    ]

    # Apply cuts. Phase cuts only apply when "eep" was actually loaded;
    # otherwise comparisons against an all-NaN column would drop every model.
    sel = np.ones(nmodels, dtype=bool)
    if "eep" in labels2:
        if include_ms and not include_postms:
            sel &= combined_labels["eep"] <= 454.0
        elif include_postms and not include_ms:
            sel &= combined_labels["eep"] > 454.0
        # else: both phases requested -- keep everything.

    if not include_binaries and "smf" in labels2:
        sel &= combined_labels["smf"] == 0.0
        labels2 = [n for n in labels2 if n != "smf"]

    # Compile results.
    combined_labels = combined_labels[labels2]
    label_mask = label_mask[labels2]

    return models[sel], combined_labels[sel], label_mask
def load_offsets(filepath, filters=None, verbose=True):
    """
    Load multiplicative photometric offsets for a set of filters.

    Loads multiplicative photometric offsets used to calibrate systematic
    differences between observed and synthetic photometry.

    Parameters
    ----------
    filepath : str
        The filepath of the photometric offsets file (typically .txt format).

    filters : iterable of str with length `Nfilt`, optional
        List of filters that will be loaded. If not provided, will default
        to all available filters. See the internally-defined `FILTERS`
        variable for more details on filter names. Any filters that are
        not available will be skipped over (assigned an offset of 1.0).

    verbose : bool, optional
        Whether to write a summary of the offsets (in percent) to stderr.
        Default is `True`.

    Returns
    -------
    offsets : `~numpy.ndarray` of shape `(Nfilt)`
        Array of constants that will be *multiplied* to the *data* to
        account for offsets (i.e. multiplicative flux offsets). Values are
        typically close to 1.0, with deviations indicating systematic
        differences.

    Raises
    ------
    ValueError
        If a requested filter appears more than once in the offsets file.

    Notes
    -----
    The offset file should contain two columns: filter names and offset
    values. Filters not found in the file will be assigned an offset of
    1.0 (no correction). An empty file yields an offset of 1.0 for every
    filter.

    Examples
    --------
    >>> from brutus.data import load_offsets
    >>> offsets = load_offsets('./data/DATAFILES/offsets_mist_v9.txt')
    >>> print(f"Loaded offsets for {len(offsets)} filters")

    >>> # Load specific filters
    >>> gri_offsets = load_offsets('./data/DATAFILES/offsets_mist_v9.txt',
    ...                            filters=['g', 'r', 'i'])

    >>> # Check which filters have significant offsets
    >>> significant = np.abs(offsets - 1.0) > 0.01
    >>> print(f"Filters with >1% offsets: {np.sum(significant)}")
    """
    # Initialize values. Materialize the filters so generic iterables
    # (e.g. generators) can be sized and iterated more than once.
    if filters is None:
        filters = FILTERS
    filters = list(filters)
    Nfilters = len(filters)

    # Read in offsets. numpy.loadtxt may return a 2D array (rows x cols),
    # or tests may mock it to return a tuple of (filts, vals). Handle both.
    _tmp = np.loadtxt(filepath, dtype="str")
    if isinstance(_tmp, tuple):
        filts, vals = _tmp
    else:
        arr = np.asarray(_tmp)
        if arr.size == 0:
            # Empty file: no filter is calibrated.
            arr = arr.reshape(0, 2)
        else:
            # np.loadtxt collapses a single data row to shape (2,);
            # atleast_2d promotes it back to (Nrows, 2).
            arr = np.atleast_2d(arr)
        filts, vals = arr[:, 0], arr[:, 1]

    # Coerce both columns to arrays so the elementwise comparison and the
    # float conversion below also work when the pair arrives as lists.
    filts = np.asarray(filts, dtype=str)
    vals = np.asarray(vals, dtype=float)

    # Fill in offsets where appropriate.
    offsets = np.full(Nfilters, np.nan)
    for i, filt in enumerate(filters):
        filt_idx = np.where(filts == filt)[0]  # get filter location
        if len(filt_idx) == 1:
            offsets[i] = vals[filt_idx[0]]  # insert offset
        elif len(filt_idx) == 0:
            offsets[i] = 1.0  # assume no offset if not calibrated
        else:
            raise ValueError(
                "Something went wrong when extracting "
                "offsets for filter {}.".format(filt)
            )

    if verbose:
        for filt, zp in zip(filters, offsets):
            sys.stderr.write("{0} ({1:3.2}%)\n".format(filt, 100 * (zp - 1.0)))

    return offsets