from __future__ import annotations
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import optimistix as optx
from numpyro import handlers
from numpyro.distributions import Gamma
from numpyro.distributions.util import validate_sample
from numpyro.infer.util import constrain_fn
from chronpy.psd.models import Component, Model
if TYPE_CHECKING:
NDArray = np.ndarray
[docs]
class Lightcurve:
    """Binned photon-count light curve.

    Parameters
    ----------
    t : array_like
        Time stamps of the bins, sorted in ascending order.
    counts : array_like
        Non-negative, integer-valued counts, same size as `t`.
    tbins : array_like, optional
        Bin edges of size ``t.size + 1`` enclosing `t`. When omitted and `t`
        is evenly sampled, edges are placed half a step around each `t`.
    exposure : array_like, optional
        Positive exposure of each bin, same size as `t`. When omitted it is
        the sampling step (evenly sampled `t`) or the widths of `tbins`.

    Raises
    ------
    ValueError
        If any input violates the constraints above, or if neither
        `exposure` nor `tbins` is given for data that are not evenly sampled.
    """

    def __init__(
        self,
        t: NDArray,
        counts: NDArray,
        tbins: NDArray | None = None,
        exposure: NDArray | None = None,
    ):
        t = np.array(t, dtype=float, order='C', ndmin=1)
        dt = t[1:] - t[:-1]
        if np.any(dt < 0):
            raise ValueError('t must be sorted in ascending order')
        self._t = t

        # With fewer than two time stamps there is no spacing to compare, so
        # the light curve is treated as not evenly sampled instead of
        # indexing into an empty ``dt``.
        self._evenly_sampled = bool(dt.size > 0 and np.allclose(dt, dt[0]))
        if self._evenly_sampled:
            self._dt = np.mean(dt)
        else:
            self._dt = None

        counts = np.array(counts, dtype=float, order='C', ndmin=1)
        if counts.size != t.size:
            raise ValueError('counts must have the same size as t')
        if np.any(counts < 0):
            raise ValueError('counts must be non-negative')
        if not np.allclose(np.round(counts), counts):
            raise ValueError('counts must be integers')
        self._counts = counts

        if tbins is not None:
            tbins = np.array(tbins, dtype=float, order='C', ndmin=1)
            if tbins.size != t.size + 1:
                raise ValueError('tbins must have size t.size + 1')
            if tbins[0] > t[0] or tbins[-1] < t[-1]:
                raise ValueError('t must be within the range of tbins')
            if np.any(tbins[1:] - tbins[:-1] < 0):
                raise ValueError('tbins must be sorted in ascending order')
        elif self._evenly_sampled:
            # Centre each time stamp inside a bin of width dt.
            tbins = np.hstack([t[0] - 0.5 * self._dt, t + 0.5 * self._dt])
        self._tbins = tbins

        if exposure is not None:
            exposure = np.array(exposure, dtype=float, order='C', ndmin=1)
            if exposure.size != t.size:
                raise ValueError('exposure must have the same size as t')
            if np.any(exposure <= 0):
                raise ValueError('exposure must be positive')
        elif self._evenly_sampled:
            exposure = np.full(self._t.size, self._dt)
        elif self._tbins is not None:
            exposure = np.diff(self._tbins)
        else:
            raise ValueError(
                'exposure or tbins must be provided if t is not evenly sampled'
            )
        self._exposure = exposure

    @property
    def t(self) -> NDArray:
        """Time stamps of the bins."""
        return self._t

    @property
    def counts(self) -> NDArray:
        """Counts in each bin."""
        return self._counts

    @property
    def rate(self) -> NDArray:
        """Count rate, i.e. counts divided by exposure."""
        return self._counts / self._exposure

    @property
    def tbins(self) -> NDArray | None:
        """Bin edges, or None when they were neither given nor derivable."""
        return self._tbins

    @property
    def exposure(self) -> NDArray:
        """Exposure of each bin."""
        return self._exposure

    @property
    def dt(self) -> float | None:
        """Sampling step, or None when `t` is not evenly sampled."""
        return self._dt

    @property
    def evenly_sampled(self) -> bool:
        """Whether consecutive time stamps are equally spaced."""
        return self._evenly_sampled
[docs]
class PSD(metaclass=ABCMeta):
    """Abstract base class of a binned power spectrum.

    Parameters
    ----------
    freq_bins : array_like
        Non-negative, strictly increasing frequency bin edges.
    power : array_like
        Non-negative power in each bin; one element fewer than `freq_bins`.
    dof : array_like
        Positive degrees of freedom of each power; same size as `power`.
    dt : float
        Sampling step of the underlying light curve.
    df : float
        Frequency resolution of the underlying periodogram.

    Raises
    ------
    ValueError
        If the array sizes do not match or a value constraint is violated.
    """

    def __init__(
        self,
        freq_bins: NDArray,
        power: NDArray,
        dof: NDArray,
        dt: float,
        df: float,
    ):
        def as_vector(values):
            # Always a fresh, C-ordered float array with at least one axis.
            return np.array(values, dtype=float, order='C', ndmin=1)

        edges = as_vector(freq_bins)
        binned_power = as_vector(power)
        n_dof = as_vector(dof)
        self._freq_bins = edges
        self._power = binned_power
        self._dof = n_dof

        if edges.size - 1 != binned_power.size or binned_power.size != n_dof.size:
            raise ValueError('freq_bins, power, and dof are not matched')
        if np.any(binned_power < 0):
            raise ValueError('power must be non-negative')
        if np.any(n_dof <= 0):
            raise ValueError('dof must be positive')
        if np.any(edges < 0):
            raise ValueError('freq_bins must be non-negative')

        widths = np.diff(edges)
        if np.any(widths <= 0):
            raise ValueError('freq_bins must be sorted in ascending order')

        self._bins_width = widths
        self._density = binned_power / widths
        self._freq = 0.5 * (edges[:-1] + edges[1:])
        self._df = float(df)
        self._dt = float(dt)

        # ``perr`` is supplied by the concrete subclass; ``_perr`` is the
        # slot such an implementation may use to cache its value.
        self._perr = None
        self._derr = self.perr / widths

    @classmethod
    @abstractmethod
    def from_lc(cls, lc: Lightcurve | list[Lightcurve], norm: str = 'leahy'):
        """Build the spectrum from one or several light curves."""

    @property
    def freq(self) -> NDArray:
        """Central frequency of each bin."""
        return self._freq

    @property
    def power(self) -> NDArray:
        """Power in each bin."""
        return self._power

    @property
    @abstractmethod
    def perr(self) -> NDArray:
        """Uncertainty of the power in each bin."""

    @property
    def density(self) -> NDArray:
        """Power divided by the bin width."""
        return self._density

    @property
    def derr(self) -> NDArray:
        """Uncertainty of the power divided by the bin width."""
        return self._derr

    @property
    def freq_bins(self) -> NDArray:
        """Frequency bin edges."""
        return self._freq_bins

    @property
    def bins_width(self) -> NDArray:
        """Width of each frequency bin."""
        return self._bins_width

    @property
    def df(self) -> float:
        """Frequency resolution passed at construction."""
        return self._df

    @property
    def dt(self) -> float:
        """Sampling step passed at construction."""
        return self._dt

    @property
    def dof(self) -> NDArray:
        """Degrees of freedom of the power in each bin."""
        return self._dof
[docs]
class Periodogram(PSD):
    """Periodogram estimate of the power spectrum of evenly sampled data.

    Powers produced by :meth:`from_lc` carry 2 degrees of freedom each. The
    rebinning methods sum the powers and the degrees of freedom of adjacent
    frequency bins.
    """

    @classmethod
    def from_lc(cls, lc: Lightcurve | list[Lightcurve], norm: str = 'leahy'):
        """Compute the periodogram of the summed rate of the light curves.

        Parameters
        ----------
        lc : Lightcurve or list of Lightcurve
            Evenly sampled light curves sharing the same time stamps.
        norm : str, optional
            Normalization of the power; only ``'leahy'`` is supported.

        Returns
        -------
        Periodogram
            Powers at the positive Fourier frequencies, excluding the zero
            and the Nyquist frequency.

        Raises
        ------
        ValueError
            If `lc` is empty or contains non-Lightcurve items, the light
            curves are not evenly sampled or have different time stamps,
            `norm` is unsupported, or no usable frequency remains.
        """
        if isinstance(lc, Lightcurve):
            lc = [lc]
        else:
            lc = list(lc)
            # An empty list would otherwise fail on ``lc[0]`` below.
            if not lc or not all(isinstance(i, Lightcurve) for i in lc):
                raise ValueError(
                    'lc must be a Lightcurve or a list of Lightcurves'
                )

        norm = str(norm).lower()

        t = lc[0].t
        for i in lc:
            if not i.evenly_sampled:
                raise ValueError('Lightcurves must be evenly sampled')
            # Compare shapes first so that arrays of different length are
            # reported with the intended message.
            if i.t.shape != t.shape or not np.allclose(i.t, t):
                raise ValueError('Lightcurves must have the same times')

        exposure = np.vstack([i.exposure for i in lc])
        rate = np.array([i.rate for i in lc])
        rate_total = np.sum(rate, axis=0)
        n_samples = rate_total.size
        dt = float(lc[0].dt)
        df = 1.0 / (n_samples * dt)

        # Drop the zero frequency from both the grid and the amplitudes.
        freq = np.fft.rfftfreq(n_samples, dt)[1:]
        mod = np.abs(np.fft.rfft(rate_total))[1:]

        if norm == 'leahy':
            # This norm makes the mean PSD of white noise equal to 2; the
            # stored power is the dimensionless PSD times the frequency
            # resolution, so it has the dimension of frequency.
            power = mod * mod * 2.0 / np.sum(rate / exposure) * df
        else:
            raise ValueError('norm must be "leahy"')

        # The last rfft bin is the Nyquist frequency exactly when the number
        # of samples is even; its power has only 1 dof, so it is excluded.
        # Parity is used instead of a floating-point frequency comparison,
        # which can miss by one ulp.
        if n_samples % 2 == 0:
            freq = freq[:-1]
            power = power[:-1]

        if freq.size == 0:
            raise ValueError(
                'Lightcurves are too short to compute a periodogram'
            )

        freq_bins = np.hstack([freq[0] - 0.5 * df, freq + 0.5 * df])
        return cls(freq_bins, power, np.full(freq.size, 2), dt, df)

    def rebin_log(self, f: float = 0.01) -> PSD:
        """Merge adjacent bins so that bin widths grow geometrically.

        Starting from the first bin, each new group is closed at the first
        edge that makes it at least ``1 + f`` times as wide as the previous
        group. A trailing incomplete group is merged into the last complete
        one.

        Parameters
        ----------
        f : float, optional
            Fractional growth of the bin width, by default 0.01.

        Returns
        -------
        PSD
            A new instance of the same class with summed power and dof.
        """
        bins = self.freq_bins
        last = len(bins) - 1
        idx = [0]
        next_edge = 0.0
        for i in range(1, len(bins)):
            if bins[i] >= next_edge:
                idx.append(i)
                width = bins[idx[-1]] - bins[idx[-2]]
                next_edge = bins[idx[-1]] + width * (1.0 + f)

        # Extend the last group up to the final edge so nothing is dropped.
        if idx[-1] != last:
            idx[-1] = last

        freq_bins = bins[idx]
        power = np.add.reduceat(self.power, idx[:-1])
        dof = np.add.reduceat(self.dof, idx[:-1])
        return type(self)(freq_bins, power, dof, self.dt, self.df)

    def rebin_significance(self, s: float = 3.0) -> PSD:
        """Merge adjacent bins until each group reaches a significance `s`.

        A group is closed as soon as its summed power ``p`` and summed dof
        ``v`` satisfy ``p - s * p / sqrt(v / 2) >= 0``. If the trailing group
        does not reach the threshold it is merged into the previous group.

        Parameters
        ----------
        s : float, optional
            Significance threshold, by default 3.0. Must be at least 1.

        Returns
        -------
        PSD
            A new instance of the same class with summed power and dof.
        """
        assert s >= 1, 's must be greater than 1'
        power = self.power
        dof = self.dof
        n = len(power)
        idx = np.empty(n, np.int64)  # start index of each group
        idx[0] = 0
        ng = 1  # number of groups opened so far
        imax = n - 1
        p_group = 0.0
        dof_group = 0.0
        for i, (pi, vi) in enumerate(zip(power, dof)):
            p_group += pi
            dof_group += vi
            x = p_group * (1 - s / np.sqrt(0.5 * dof_group))
            if i == imax:
                if x < 0 and ng > 1:
                    # The last group is not significant: drop its start index
                    # so it is combined with the previous group and every
                    # group meets the significance requirement.
                    ng -= 1
                break
            if x >= 0:
                idx[ng] = i + 1
                ng += 1
                p_group = 0.0
                dof_group = 0.0

        idx = idx[:ng]
        dof = np.add.reduceat(dof, idx)
        power = np.add.reduceat(power, idx)
        bins = self.freq_bins
        freq_bins = bins[np.append(idx, len(bins) - 1)]
        return type(self)(freq_bins, power, dof, self.dt, self.df)

    @property
    def perr(self) -> NDArray:
        """Standard deviation of the power in each bin.

        A power with ``v`` degrees of freedom follows a Gamma distribution
        with shape ``v / 2``, whose standard deviation is the mean divided by
        ``sqrt(v / 2)``. The value is computed on first access and cached.
        """
        if self._perr is None:
            self._perr = self._power / np.sqrt(0.5 * self._dof)
        return self._perr
[docs]
class PowerDist(Gamma):
    """Sampling distribution of an observed power, with a deviance-style
    log-probability.

    If X ~ Chi_v^2, then cX ~ Gamma(alpha=v/2, scale=2c).
    For a power spectrum, v*I/S ~ Chi_v^2, where I is the observed power,
    v is the dof of I, and S is the true power. Hence the distribution of I
    is Gamma(alpha=v/2, scale=2S/v).

    If I is the sum of the powers of m adjacent frequency bins,
    I = sum_{m} I_m, there is no analytical expression for its distribution,
    because the Gamma scale of each I_m is different. Assuming that the
    expected powers of all I_m are equal, the distribution of I is
    Gamma(alpha=sum_{m} v_m/2, scale=2S/(sum_{m} v_m)). This assumption
    usually does not hold for observed powers and may underestimate the
    variance of I, leading to a deviance that is systematically larger than
    the fit dof, i.e. the number of data points minus the number of
    parameters.

    Averaging the powers of l different power spectra is valid, and the
    distribution of the averaged power I_j is
    Gamma(alpha=l*v_j/2, scale=2S/(l*v_j)).

    Parameters
    ----------
    dof
        Degrees of freedom of the observed power.
    s
        Expected power; the mean of the distribution.
    """

    def __init__(self, dof, s):
        shape = 0.5 * dof
        super().__init__(concentration=shape, rate=shape / s)

    @validate_sample
    def log_prob(self, value):
        """Log-likelihood of `value` minus that of the saturated model.

        Only the terms that do not cancel in this difference are evaluated.
        The saturated model is the one whose rate equals
        ``concentration / value``.
        """
        alpha = self.concentration
        beta = self.rate
        # Value-dependent part of the Gamma log-density that survives the
        # subtraction, for the fitted model and the saturated model.
        saturated = alpha * (jnp.log(alpha / value) - 1.0)
        fitted = alpha * jnp.log(beta) - beta * value
        return fitted - saturated
[docs]
class Fit:
    """Fit a power spectrum model to an observed power spectrum.

    Parameters
    ----------
    psd : PSD
        The observed power spectrum.
    model : Model or Component
        The model to fit; a single Component is wrapped into a Model.

    Raises
    ------
    ValueError
        If `psd` is not a PSD, or `model` is neither a Model nor a Component.
    """

    def __init__(self, psd: PSD, model: Model):
        if not isinstance(psd, PSD):
            raise ValueError('psd must be an instance of PSD')
        if isinstance(model, Component):
            model = Model(model)
        elif not isinstance(model, Model):
            raise ValueError('model must be an instance of Model or Component')

        self.psd = psd
        self.model = model
        self._loss = None
        self._transform = None
        self._ndata = len(psd.freq)
        self._numpyro_model = None

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _merge_init(self, init):
        """Return the model defaults overridden by the values in `init`."""
        if init is None:
            return self.model.default
        return self.model.default | dict(init)

    def _to_unconstrained(self, init):
        """Stack `init` values, mapped to unconstrained space, in
        ``params_names`` order."""
        t = self.transform
        return jnp.array([t[k].inv(init[k]) for k in self.params_names])

    def _to_constrained(self, unconstr_arr):
        """Map an unconstrained parameter array back to a dict of
        constrained values keyed by parameter name."""
        t = self.transform
        return {k: t[k](v) for k, v in zip(self.params_names, unconstr_arr)}

    # ------------------------------------------------------------------
    # basic information
    # ------------------------------------------------------------------
    @property
    def ndata(self):
        """Number of frequency bins of the power spectrum."""
        return self._ndata

    @property
    def nparam(self):
        """Number of model parameters that have a prior."""
        return len(self.params_names)

    @property
    def params_names(self):
        """Names of the model parameters, in the order of the model prior."""
        return list(self.model.prior.keys())

    @property
    def transform(self):
        """Bijections from unconstrained space to each prior's support."""
        from numpyro.distributions.transforms import biject_to

        if self._transform is None:
            self._transform = {
                k: biject_to(v.support) for k, v in self.model.prior.items()
            }
        return self._transform

    @property
    def numpyro_model(self):
        """The NumPyro model function, built on first access.

        The model samples every parameter from its prior, records the model
        power as deterministic site ``'S'``, holds the observed power in the
        mutable site ``'I_data'``, conditions ``'I_obs'`` on it through
        :class:`PowerDist`, and records the point-wise log-likelihood as
        deterministic site ``'loglike'``.
        """
        if self._numpyro_model is None:

            def _():
                params = {
                    p: numpyro.sample(p, d)
                    for p, d in self.model.prior.items()
                }
                power_model = jax.jit(self.model.power)(
                    params, self.psd.freq_bins
                )
                numpyro.deterministic('S', power_model)
                pdist = PowerDist(self.psd.dof, power_model)
                # Mutable so that the data can be substituted later on.
                I_data = numpyro.primitives.mutable('I_data', self.psd.power)
                with numpyro.plate('freq', len(self.psd.freq)):
                    numpyro.sample(name='I_obs', fn=pdist, obs=I_data)
                numpyro.deterministic(
                    name='loglike', value=pdist.log_prob(I_data)
                )

            self._numpyro_model = _
        return self._numpyro_model

    @property
    def loss(self):
        """Jitted loss functions of the unconstrained parameter array.

        Returns a dict with ``'residual'`` (point-wise square root of the
        deviance) and ``'deviance'`` (total deviance).
        """
        if self._loss is None:
            params_names = self.params_names

            def get_loglike(unconstr_arr):
                sites = constrain_fn(
                    model=self.numpyro_model,
                    model_args=(),
                    model_kwargs={},
                    params=dict(zip(params_names, unconstr_arr)),
                    return_deterministic=True,
                )
                return sites['loglike']

            residual = jax.jit(lambda x: jnp.sqrt(-2.0 * get_loglike(x)))
            deviance = jax.jit(lambda x: jnp.sum(-2.0 * get_loglike(x)))
            self._loss = {'residual': residual, 'deviance': deviance}
        return self._loss

    # ------------------------------------------------------------------
    # inference
    # ------------------------------------------------------------------
    def run_nuts(
        self,
        num_warmup=2000,
        num_samples=2000,
        num_chains=4,
        init=None,
        chain_method='parallel',
        progress=True,
        seed=42,
    ):
        """Sample the posterior with NUTS.

        Parameters
        ----------
        num_warmup, num_samples, num_chains, chain_method
            Forwarded to ``numpyro.infer.MCMC``.
        init : dict, optional
            Initial parameter values; missing ones use the model defaults.
        progress : bool, optional
            Whether to show the progress bar.
        seed : int, optional
            Seed of the random number generator.

        Returns
        -------
        The result of ``arviz.from_numpyro`` for the finished run.
        """
        init = self._merge_init(init)
        mcmc = numpyro.infer.MCMC(
            numpyro.infer.NUTS(
                model=self.numpyro_model,
                init_strategy=numpyro.infer.util.init_to_value(values=init),
            ),
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            chain_method=chain_method,
            progress_bar=progress,
        )
        mcmc.run(
            jax.random.PRNGKey(seed), extra_fields=('energy', 'num_steps')
        )
        return az.from_numpyro(mcmc)

    def mle_lm(self, init=None):
        """Minimize the deviance with the Levenberg-Marquardt algorithm.

        Parameters
        ----------
        init : dict, optional
            Initial parameter values; missing ones use the model defaults.

        Returns
        -------
        params : dict
            Best-fit parameters in constrained space.
        deviance
            Sum of the squared residuals at the solution.
        grad_norm
            Norm of the gradient at the solution.
        """
        init = self._merge_init(init)
        lm_solver = optx.LevenbergMarquardt(rtol=0.0, atol=1e-8)
        residual = jax.jit(lambda x, aux: self.loss['residual'](x))

        def lm(init):
            res = optx.least_squares(
                fn=residual,
                solver=lm_solver,
                y0=init,
                max_steps=4096,
                throw=True,
            )
            grad_norm = jnp.linalg.norm(res.state.f_info.compute_grad())
            deviance = jnp.square(res.state.f_info.residual).sum()
            return res.value, deviance, grad_norm

        popt, f, g = jax.jit(lm)(self._to_unconstrained(init))
        return self._to_constrained(popt), f, g

    def mle(self, init=None):
        """Minimize the deviance with the BFGS algorithm.

        Parameters
        ----------
        init : dict, optional
            Initial parameter values; missing ones use the model defaults.

        Returns
        -------
        params : dict
            Best-fit parameters in constrained space.
        deviance
            Total deviance at the solution.
        grad_norm
            Norm of the gradient at the solution.
        """
        init = self._merge_init(init)

        @jax.jit
        def deviance(x, _):
            return self.loss['deviance'](x)

        bfgs = optx.BFGS(rtol=0.0, atol=1e-8)

        def mle(init):
            res = optx.minimise(
                fn=deviance,
                solver=bfgs,
                y0=init,
                max_steps=4096,
                throw=True,
            )
            grad_norm = jnp.linalg.norm(res.state.f_info.grad)
            return res.value, res.state.f_info.f, grad_norm

        popt, f, g = jax.jit(mle)(self._to_unconstrained(init))
        return self._to_constrained(popt), f, g

    # ------------------------------------------------------------------
    # simulation
    # ------------------------------------------------------------------
    def simulate(self, params, n=1, seed=42):
        """Draw simulated powers from the model.

        The model power is multiplied by chi-square variates with the dof of
        the observed spectrum, divided by that dof.

        Parameters
        ----------
        params : dict
            Parameter values; must contain every name in ``params_names``.
        n : int, optional
            Number of simulations. With ``n == 1`` the output has the shape
            of the model power, otherwise a leading axis of length `n` is
            added.
        seed : int, optional
            Seed of the NumPy random generator.

        Raises
        ------
        ValueError
            If the parameter values are 2-D and `n` is not 1.
        """
        n = int(n)
        params = dict(params)
        params = np.array([params[k] for k in self.params_names], float)
        if params.ndim == 2 and n != 1:
            raise ValueError('params must be 1D if n > 1')
        params = dict(zip(self.params_names, params))
        rng = np.random.default_rng(seed)
        power = self.model.power(params, self.psd.freq_bins)
        dof = self.psd.dof
        sample_shape = (n,) + power.shape if n != 1 else power.shape
        sim_data = power * rng.chisquare(dof, size=sample_shape) / dof
        return sim_data

    def batch_fit(
        self,
        data,
        init=None,
        parallel: bool = True,
        n_parallel: int | None = None,
        progress: bool = True,
        update_rate: int = 50,
        run_str: str = 'Fitting',
        seed=42,
    ):
        """Fit the model to every data set in `data`.

        Parameters
        ----------
        data
            Data sets to fit; ``len(data)`` is the number of fits.
        init : dict, optional
            Initial parameter values. When omitted, one value per data set
            is drawn from each prior. Keys that are not model parameters are
            ignored.
        parallel, n_parallel, progress, update_rate, run_str
            Forwarded to the batch-fit runner.
        seed : int, optional
            Seed used to draw initial values when `init` is omitted.

        Returns
        -------
        The result of the batch-fit runner.
        """
        rng_key = jax.random.PRNGKey(seed)
        if init is None:
            keys = jax.random.split(rng_key, num=self.nparam)
            keys = dict(zip(self.params_names, keys))
            init = {
                k: v.sample(keys[k], (len(data),))
                for k, v in self.model.prior.items()
            }
        else:
            init = dict(init)
            assert set(self.params_names).issubset(init), (
                'init must contain all params'
            )
            # The runner requires exactly the model parameters, so drop any
            # extra keys that the check above tolerates.
            init = {k: init[k] for k in self.params_names}

        return init_bacth_fit(self)(
            init, data, parallel, n_parallel, progress, update_rate, run_str
        )
[docs]
def init_bacth_fit(fit):
    """Build a runner that fits the model of `fit` to many data sets.

    Parameters
    ----------
    fit : Fit
        Provides the NumPyro model, the parameter names, the transforms to
        constrained space, and the number of data points.

    Returns
    -------
    run : callable
        ``run(init_params, data, parallel, n_parallel, progress,
        update_rate, run_str)``; see its docstring.
    """
    lm_solver = optx.LevenbergMarquardt(rtol=0.0, atol=1e-6)
    numpyro_model = fit.numpyro_model
    params_names = fit.params_names
    transform = fit.transform

    def get_sites_(unconstr_arr):
        """Evaluate parameters, model and log-likelihood at a point given
        in unconstrained space."""
        sites = constrain_fn(
            model=numpyro_model,
            model_args=(),
            model_kwargs={},
            params=dict(zip(params_names, unconstr_arr)),
            return_deterministic=True,
        )
        params = {k: sites[k] for k in params_names}
        models = sites['S']
        loglike = sites['loglike']
        return {'params': params, 'models': models, 'loglike': loglike}

    @jax.jit
    def fit_once(i: int, args: tuple) -> tuple:
        """Loop core, fit the i-th data set and store the outcome."""
        result, init = args

        # substitute observation data with the i-th data set
        new_data = {'I_data': result['data'][i]}
        get_sites = jax.jit(handlers.substitute(fn=get_sites_, data=new_data))

        # Least squares minimizes the sum of squared residuals, so the
        # residual is the square root of the point-wise deviance; the
        # minimized quantity is then the total deviance.
        residual = lambda p: jnp.sqrt(-2.0 * get_sites(p)['loglike'])

        # fit the data set
        res = optx.least_squares(
            fn=lambda p, _: residual(p),
            solver=lm_solver,
            y0=init[i],
            max_steps=1024,
            throw=False,
        )
        fitted_params = res.value
        grad_norm = jnp.linalg.norm(res.state.f_info.compute_grad())
        sites = get_sites(fitted_params)

        # update best fit params to result
        result['params'] = jax.tree.map(
            lambda x, y: x.at[i].set(y),
            result['params'],
            sites['params'],
        )

        # update the best fit model to result
        result['models'] = result['models'].at[i].set(sites['models'])

        # update the deviance information to result
        dev = {
            'total': -2.0 * sites['loglike'].sum(),
            'point': -2.0 * sites['loglike'],
        }
        res_dev = result['deviance']
        res_dev['total'] = res_dev['total'].at[i].set(dev['total'])
        res_dev['point'] = res_dev['point'].at[i].set(dev['point'])

        # a fit is valid when nothing is NaN and the gradient is small
        valid = jnp.bitwise_not(
            jnp.isnan(dev['total'])
            | jnp.isnan(grad_norm)
            | jnp.greater(grad_norm, 1e-3)
        )
        result['valid'] = result['valid'].at[i].set(valid)

        return result, init

    def sequence_fit(
        result: dict,
        init: NDArray,
        run_str: str,
        progress: bool,
        update_rate: int,
    ):
        """Fit the data sets in sequence."""
        n = len(result['valid'])

        if progress:
            # Imported lazily so that the dependency is only needed when a
            # progress bar is requested.
            from elisa.util.misc import progress_bar_factory

            pbar_factory = progress_bar_factory(
                n, 1, run_str=run_str, update_rate=update_rate
            )
            fn = pbar_factory(fit_once)
        else:
            fn = fit_once

        fit_jit = jax.jit(lambda *args: jax.lax.fori_loop(0, n, fn, args)[0])
        return fit_jit(result, init)

    def parallel_fit(
        result: dict,
        init: NDArray,
        run_str: str,
        progress: bool,
        update_rate: int,
        n_parallel: int,
    ) -> dict:
        """Fit the data sets in parallel, one equal share per device."""
        n = len(result['valid'])
        n_parallel = int(n_parallel)
        batch = n // n_parallel

        if progress:
            from elisa.util.misc import progress_bar_factory

            pbar_factory = progress_bar_factory(
                n, n_parallel, run_str=run_str, update_rate=update_rate
            )
            fn = pbar_factory(fit_once)
        else:
            fn = fit_once

        fit_pmap = jax.pmap(
            lambda *args: jax.lax.fori_loop(0, batch, fn, args)[0]
        )
        # split the leading axis into (n_parallel, batch)
        reshape = lambda x: x.reshape((n_parallel, -1) + x.shape[1:])
        result = fit_pmap(
            jax.tree.map(reshape, result),
            reshape(init),
        )
        return jax.tree.map(jnp.concatenate, result)

    def run(
        init_params: dict,
        data: NDArray,
        parallel: bool = True,
        n_parallel: int | None = None,
        progress: bool = True,
        update_rate: int = 50,
        run_str: str = 'Fitting',
    ) -> dict:
        """Fit the model to every data set in `data`.

        Parameters
        ----------
        init_params : dict
            Initial parameter values in constrained space, keyed by exactly
            the model parameter names. Each value is either a scalar, used
            for all fits, or an array with one value per data set.
        data : array_like
            The data sets to fit, stacked along the first axis.
        parallel : bool, optional
            Whether to fit in parallel, by default True.
        n_parallel : int, optional
            The number of parallel processes when `parallel` is ``True``.
            It is resolved by ``elisa.util.config.get_parallel_number``.
        progress : bool, optional
            Whether to show progress bar, by default True.
        update_rate : int, optional
            The update rate of the progress bar, by default 50.
        run_str : str, optional
            The string shown ahead of the progress bar when `progress` is
            True. The default is 'Fitting'.

        Returns
        -------
        result : dict
            Keys ``'data'``, ``'params'``, ``'models'``, ``'deviance'``
            (with ``'total'`` and ``'point'``) and ``'valid'``, each with
            one entry per data set.

        Raises
        ------
        ValueError
            If `parallel` is True and the number of data sets is not a
            multiple of the number of parallel processes.
        """
        init_params = jax.tree.map(jnp.array, init_params)
        # Array form is required for indexing inside the compiled loop and
        # for reshaping in the parallel branch.
        data = jnp.asarray(data)
        n = len(data)

        # The divisibility constraint only concerns the parallel branch.
        if parallel:
            from elisa.util.config import get_parallel_number

            n_parallel = get_parallel_number(n_parallel)
            if n % n_parallel != 0:
                raise ValueError(
                    f'n ({n}) must be a multiple of n_parallel ({n_parallel})'
                )

        assert set(init_params) == set(params_names)
        assert n > 0

        # check if all params shapes are the same
        shapes = list(jax.tree.map(jnp.shape, init_params).values())
        assert all(i == shapes[0] for i in shapes)

        # get initial parameters arrays in unconstrained space,
        # one row per data set
        t = transform
        init = jnp.array([t[k].inv(init_params[k]) for k in params_names]).T
        assert init.ndim <= 2
        if init.ndim == 2:
            assert init.shape[0] == n
        if init.ndim == 1:
            init = jnp.full((n, len(init)), init)

        # fit result container
        result = {
            'data': data,
            'params': {k: jnp.empty(n) for k in params_names},
            'models': jnp.empty((n, fit.ndata)),
            'deviance': {
                'total': jnp.empty(n),
                'point': jnp.empty((n, fit.ndata)),
            },
            'valid': jnp.full(n, True, bool),
        }

        # fit the data sets
        if parallel:
            return parallel_fit(
                result,
                init,
                run_str,
                progress,
                update_rate,
                n_parallel,
            )
        return sequence_fit(result, init, run_str, progress, update_rate)

    return run