#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
3D dust map implementations.
This module provides classes for querying 3D dust maps, particularly the
Bayestar maps from Green et al. (2019).
"""
import astropy.coordinates as coordinates
import astropy.units as units
import h5py
import numpy as np
from .extinction import lb2pix
# Public names exported by ``from <this module> import *``.
__all__ = ["DustMap", "Bayestar"]
class DustMap:
    """
    Base class for querying 3D dust maps.

    This abstract base class defines the interface that all dust map
    implementations should follow. Subclasses must implement `query`;
    `__call__`, `query_gal` and `query_equ` are convenience front-ends
    that forward to it.
    """

    def __init__(self):
        """Initialize the dust map (the base class stores no state)."""
        pass

    def __call__(self, coords, **kwargs):
        """
        Convenience method for querying the map.

        This is an alias for the `query` method.

        Parameters
        ----------
        coords : astropy.coordinates.SkyCoord
            Coordinates to query.
        **kwargs
            Additional keyword arguments passed to query.

        Returns
        -------
        Query results as implemented by subclasses.
        """
        return self.query(coords, **kwargs)

    def query(self, coords, **kwargs):
        """
        Query the map at a set of coordinates.

        Parameters
        ----------
        coords : astropy.coordinates.SkyCoord
            Coordinates to query.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        Query results as implemented by subclasses.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
        """
        raise NotImplementedError("DustMap.query must be implemented by subclasses.")

    def query_gal(self, ell, b, d=None, **kwargs):
        """
        Query the map using Galactic coordinates.

        Parameters
        ----------
        ell : float, array_like or astropy.units.Quantity
            Galactic longitude in degrees.
        b : float, array_like or astropy.units.Quantity
            Galactic latitude in degrees. Broadcast against ``ell``, so a
            scalar may be paired with an array.
        d : float or astropy.units.Quantity, optional
            Distance from the Solar System in kpc. Not used by HEALPix-based
            maps but accepted for API compatibility (it is ignored here).
        **kwargs
            Additional keyword arguments passed to query.

        Returns
        -------
        Query results as implemented by subclasses.

        Notes
        -----
        The coordinates are passed to `query` as an ``(N, 2)`` float array of
        ``[l, b]`` pairs in degrees, where ``N`` is the size of the broadcast
        inputs, flattened in C order. A scalar pair gives ``N == 1``.
        """
        # Extract numeric values from astropy Quantities if needed.
        if isinstance(ell, units.Quantity):
            ell = ell.to(units.deg).value
        if isinstance(b, units.Quantity):
            b = b.to(units.deg).value
        # Broadcast so a scalar can pair with an array, then flatten so that
        # multi-dimensional inputs still produce one [l, b] row per sightline.
        # (np.column_stack on 2-D inputs would concatenate columns instead.)
        ell, b = np.broadcast_arrays(
            np.asarray(ell, dtype=float), np.asarray(b, dtype=float)
        )
        coords = np.column_stack([ell.ravel(), b.ravel()])
        # Pass (l, b) pairs directly to query, avoiding SkyCoord overhead.
        return self.query(coords, **kwargs)

    def query_equ(self, ra, dec, d=None, frame="icrs", **kwargs):
        """
        Query the map using Equatorial coordinates.

        Parameters
        ----------
        ra : float or astropy.units.Quantity
            Right ascension in degrees.
        dec : float or astropy.units.Quantity
            Declination in degrees.
        d : float or astropy.units.Quantity, optional
            Distance from the Solar System in kpc.
        frame : str, optional
            Coordinate frame. Options: 'icrs', 'fk4', 'fk5', 'fk4noeterms'.
            Default is 'icrs'.
        **kwargs
            Additional keyword arguments passed to query.

        Returns
        -------
        Query results as implemented by subclasses.

        Raises
        ------
        ValueError
            If frame is not one of the supported coordinate frames.
        """
        valid_frames = ["icrs", "fk4", "fk5", "fk4noeterms"]
        if frame not in valid_frames:
            raise ValueError(
                f"Frame '{frame}' not supported. Must be one of {valid_frames}."
            )
        # Bare numbers are interpreted as degrees.
        if not isinstance(ra, units.Quantity):
            ra = ra * units.deg
        if not isinstance(dec, units.Quantity):
            dec = dec * units.deg
        # Build the SkyCoord, attaching a distance (bare numbers -> kpc) if given.
        if d is None:
            coords = coordinates.SkyCoord(ra, dec, frame=frame)
        else:
            if not isinstance(d, units.Quantity):
                d = d * units.kpc
            coords = coordinates.SkyCoord(ra, dec, distance=d, frame=frame)
        return self.query(coords, **kwargs)
class Bayestar(DustMap):
    """
    Query the Bayestar 3D dust maps from Green et al. (2019).

    The Bayestar maps cover the Pan-STARRS 1 footprint (dec > -30°) over
    approximately 3/4 of the sky, providing 3D extinction information.

    Parameters
    ----------
    dustfile : str, optional
        Path to the Bayestar HDF5 data file. Default is 'bayestar2019_v1.h5'.

    Attributes
    ----------
    _distances : ndarray
        Distance grid points (kpc).
    _av_mean : ndarray
        Mean A(V) extinction values; first axis indexed by map pixel.
    _av_std : ndarray
        Standard deviation of A(V) extinction values; first axis indexed by
        map pixel.
    """

    def __init__(self, dustfile="bayestar2019_v1.h5"):
        """
        Initialize the Bayestar dust map.

        All required datasets are read fully into memory and the file is
        closed before returning.

        Parameters
        ----------
        dustfile : str, optional
            Path to the Bayestar HDF5 data file.
        """
        super().__init__()
        try:
            # Try SWMR mode first (for concurrent access).
            h5file = h5py.File(dustfile, "r", libver="latest", swmr=True)
        except (OSError, ValueError):
            # Fall back to regular read-only mode.
            h5file = h5py.File(dustfile, "r")
        # The context manager closes the file even if a dataset is missing.
        with h5file:
            # Pixel information (structured array with 'nside' and
            # 'healpix_index' fields, as used by the index preparation).
            self._pixel_info = h5file["pixel_info"][:]
            self._n_pix = self._pixel_info.size
            # Extinction data, copied into memory with [:].
            self._distances = h5file["dists"][:]
            self._av_mean = h5file["av_mean"][:]
            self._av_std = h5file["av_std"][:]
        self._n_distances = len(self._distances)
        # Prepare HEALPix index lookup structures (uses in-memory data only).
        self._prepare_index_structures()

    def _prepare_index_structures(self):
        """Prepare per-nside sorted HEALPix index tables for fast lookup."""
        # Sort pixels by nside and healpix_index for efficient searching.
        sort_idx = np.argsort(self._pixel_info, order=["nside", "healpix_index"])
        self._nside_levels = np.unique(self._pixel_info["nside"])
        self._hp_idx_sorted = []
        self._data_idx = []
        start_idx = 0
        for nside in self._nside_levels:
            # Because sort_idx orders primarily by nside, the pixels at this
            # level form the contiguous run sort_idx[start_idx:end_idx].
            end_idx = np.searchsorted(
                self._pixel_info["nside"], nside, side="right", sorter=sort_idx
            )
            idx = sort_idx[start_idx:end_idx]
            # Sorted HEALPix indices and the matching rows in the data arrays.
            self._hp_idx_sorted.append(self._pixel_info["healpix_index"][idx])
            self._data_idx.append(idx)
            start_idx = end_idx

    def _find_data_idx(self, gal_l, b):
        """
        Find data indices corresponding to Galactic coordinates.

        Parameters
        ----------
        gal_l : array_like
            Galactic longitude(s) in degrees.
        b : array_like
            Galactic latitude(s) in degrees.

        Returns
        -------
        pix_idx : ndarray
            Data indices with the same shape as ``gal_l``. Coordinates not
            covered by any stored pixel get -1.
        """
        l_arr = np.asarray(gal_l)
        b_arr = np.asarray(b)
        pix_idx = np.full(l_arr.shape, -1, dtype="i8")
        # Search at each nside level (coarse to fine: np.unique sorts ascending).
        for k, nside in enumerate(self._nside_levels):
            hp_sorted = self._hp_idx_sorted[k]
            n_sorted = len(hp_sorted)
            # Convert coordinates to nested HEALPix pixel indices.
            ipix = lb2pix(nside, l_arr, b_arr, nest=True)
            # Find insertion points in the sorted pixel list.
            idx = np.searchsorted(hp_sorted, ipix, side="left")
            if np.isscalar(idx):
                # Scalar query: accept only an exact hit.
                if idx < n_sorted and hp_sorted[idx] == ipix:
                    pix_idx[...] = self._data_idx[k][idx]
            else:
                in_bounds = idx < n_sorted
                if not np.any(in_bounds):
                    continue
                # Clamp out-of-range positions so the comparison is a valid
                # index; those entries are masked out by in_bounds anyway.
                idx = np.where(in_bounds, idx, -1)
                safe_idx = np.clip(idx, 0, None)
                match_idx = in_bounds & (hp_sorted[safe_idx] == ipix)
                if np.any(match_idx):
                    pix_idx[match_idx] = self._data_idx[k][idx[match_idx]]
        return pix_idx

    @staticmethod
    def _galactic_lb(coords):
        """
        Return Galactic ``(l, b)`` in degrees for ``coords``.

        Objects exposing ``.galactic`` (e.g. SkyCoord) are converted to
        Galactic first; objects exposing ``.l``/``.b`` are used as-is. Anything
        else is treated as array-like ``[l, b]`` pairs in degrees, one pair
        per row (a single 1-D pair becomes one row).
        """
        try:
            gal_coords = coords.galactic if hasattr(coords, "galactic") else coords
            return gal_coords.l.deg, gal_coords.b.deg
        except AttributeError:
            coords_arr = np.atleast_2d(coords)
            return coords_arr[:, 0], coords_arr[:, 1]

    def get_query_size(self, coords):
        """
        Estimate the total size of a query result.

        Parameters
        ----------
        coords : astropy.coordinates.SkyCoord or array_like
            Coordinates that would be queried, in any form accepted by
            `query`.

        Returns
        -------
        int
            Estimated total number of data points that would be returned
            (number of sightlines times number of distance bins).
        """
        if hasattr(coords, "galactic") or hasattr(coords, "l"):
            # SkyCoord-like: one sightline per element.
            n_coords = int(np.prod(coords.shape, dtype=int))
        else:
            # Raw [l, b] pairs: one sightline per row, exactly as query()
            # interprets them (counting both columns would double the size).
            n_coords = np.atleast_2d(coords).shape[0]
        return n_coords * self._n_distances

    def query(self, coords):
        """
        Query extinction at the specified coordinates.

        Parameters
        ----------
        coords : astropy.coordinates.SkyCoord or array_like
            Coordinates to query: a SkyCoord (single or array) or array-like
            ``[l, b]`` pairs in Galactic degrees.

        Returns
        -------
        distances : ndarray
            Distance grid points (kpc).
        av_mean : ndarray
            Mean A(V) extinction values along each line of sight. The leading
            dimensions follow the input's shape (``(N,)`` for N ``[l, b]``
            rows; none for a scalar SkyCoord), followed by the trailing
            dimensions of the stored table.
        av_std : ndarray
            Standard deviation of A(V) extinction values, shaped like
            ``av_mean``.

        Notes
        -----
        For coordinates outside the map coverage, NaN values are returned.
        """
        l_deg, b_deg = self._galactic_lb(coords)
        # Find corresponding data indices (-1 where not covered).
        pix_idx = self._find_data_idx(l_deg, b_deg)
        in_bounds = pix_idx != -1
        safe_idx = np.clip(pix_idx, 0, None)
        # .copy() is required: a scalar index is basic indexing and returns a
        # view, so writing NaN below would otherwise corrupt the map data.
        av_mean = self._av_mean[safe_idx].copy()
        av_std = self._av_std[safe_idx].copy()
        # Set out-of-coverage values to NaN.
        av_mean[~in_bounds] = np.nan
        av_std[~in_bounds] = np.nan
        # A scalar SkyCoord already yields a 1-D profile here. The former
        # "squeeze" step only fired when there was a single distance bin and
        # then wrongly collapsed that profile to a bare scalar.
        return self._distances, av_mean, av_std