
Commit af63f23

feat: add parameter estimation and tests
1 parent 8639781 commit af63f23

11 files changed

Lines changed: 844 additions & 0 deletions


pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@ dependencies = [
     "jaxlib>=0.4.35",
     "numpy>=1.26",
     "pandas>=2.2",
+    "scipy>=1.12",
 ]

 [project.urls]
@@ -78,6 +79,7 @@ mypy_path = "src"
 module = [
     "jax.*",
     "jaxlib.*",
+    "scipy.*",
 ]
 ignore_missing_imports = true
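The new scipy requirement backs the chi-squared tail probability used by the Ljung-Box diagnostic added later in this commit. A minimal, self-contained sketch of that use (the Q statistic value here is illustrative):

from scipy import stats

# Under H0, the Ljung-Box Q statistic is approximately chi-squared with
# n_lags degrees of freedom; the p-value is its upper tail probability.
q_stat, n_lags = 15.3, 10
p_value = float(1.0 - stats.chi2.cdf(q_stat, df=n_lags))
print(round(p_value, 3))  # ~0.12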

src/dynaris/estimation/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
"""Parameter estimation: MLE, EM, and model diagnostics."""

from dynaris.estimation.diagnostics import acf, ljung_box, pacf, standardized_residuals
from dynaris.estimation.em import EMResult, fit_em
from dynaris.estimation.mle import MLEResult, fit_mle
from dynaris.estimation.transforms import inverse_softplus, softplus

__all__ = [
    "EMResult",
    "MLEResult",
    "acf",
    "fit_em",
    "fit_mle",
    "inverse_softplus",
    "ljung_box",
    "pacf",
    "softplus",
    "standardized_residuals",
]
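Given these re-exports, downstream code can import the estimation API flat from dynaris.estimation (fit_mle and the transforms live in files of this commit not shown in this excerpt). A small hedged sketch, assuming softplus/inverse_softplus have their usual definitions and are meant to keep variance parameters positive during unconstrained optimization:

import jax.numpy as jnp

from dynaris.estimation import fit_em, fit_mle, inverse_softplus, ljung_box, softplus

# Round-trip: map a positive variance to unconstrained space and back.
# (Assumes the conventional softplus(x) = log(1 + exp(x)); see transforms.py.)
variance = jnp.array(0.5)
recovered = softplus(inverse_softplus(variance))  # ~0.5, always positive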
src/dynaris/estimation/diagnostics.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
"""Model diagnostics for fitted state-space models."""

from __future__ import annotations

import jax.numpy as jnp
from jax import Array
from scipy import stats

from dynaris.core.results import FilterResult
from dynaris.core.state_space import StateSpaceModel


def standardized_residuals(
    filter_result: FilterResult,
    model: StateSpaceModel,
) -> Array:
    """Compute standardized (one-step-ahead) prediction residuals.

    e_t = (y_t - H @ x_{t|t-1}) / sqrt(H @ P_{t|t-1} @ H^T + R)

    Returns:
        Standardized residuals, shape (T,) for univariate or (T, obs_dim).
    """
    obs = filter_result.observations  # (T, m)
    pred_states = filter_result.predicted_states  # (T, n)
    pred_covs = filter_result.predicted_covariances  # (T, n, n)

    # Innovation: y_t - H @ x_{t|t-1}
    innovations = obs - pred_states @ model.H.T  # (T, m)

    # Innovation covariance: H @ P_{t|t-1} @ H^T + R
    # Shape: (T, m, m)
    innovation_covs = jnp.einsum(
        "ij,tjk,lk->til", model.H, pred_covs, model.H
    ) + model.R[None, :, :]

    # For the univariate case, standardize directly.
    # For multivariate, use the diagonal elements.
    std_devs = jnp.sqrt(
        jnp.diagonal(innovation_covs, axis1=-2, axis2=-1)
    )  # (T, m)

    std_resids = innovations / std_devs

    # Squeeze if univariate
    if std_resids.shape[-1] == 1:
        return std_resids[:, 0]
    return std_resids


def acf(x: Array, n_lags: int = 20) -> Array:
    """Compute the sample autocorrelation function.

    Args:
        x: 1D array of residuals, shape (T,).
        n_lags: Number of lags to compute.

    Returns:
        Autocorrelations at lags 0, 1, ..., n_lags. Shape (n_lags + 1,).
    """
    x = jnp.asarray(x).ravel()
    n = x.shape[0]
    x_centered = x - jnp.mean(x)
    var = jnp.sum(x_centered**2) / n

    lags = jnp.arange(n_lags + 1)

    def _acf_at_lag(lag: Array) -> Array:
        # For lag 0, return 1.0
        shifted = jnp.roll(x_centered, lag)
        # Zero out the rolled-in values
        mask = jnp.arange(n) >= lag
        cov = jnp.sum(x_centered * shifted * mask) / n
        return jnp.where(lag == 0, 1.0, cov / var)

    return jnp.vectorize(_acf_at_lag)(lags)  # type: ignore[no-any-return]


def pacf(x: Array, n_lags: int = 20) -> Array:
    """Compute the sample partial autocorrelation function via Durbin-Levinson.

    Args:
        x: 1D array of residuals, shape (T,).
        n_lags: Number of lags to compute.

    Returns:
        Partial autocorrelations at lags 0, 1, ..., n_lags.
        Shape (n_lags + 1,). PACF at lag 0 is 1.0.
    """
    acf_vals = acf(x, n_lags)

    result = [1.0]  # lag 0

    # Durbin-Levinson algorithm
    phi = float(acf_vals[1])
    result.append(phi)

    phi_prev = [phi]

    for k in range(2, n_lags + 1):
        # phi_k,k = (r(k) - sum_{j=1}^{k-1} phi_{k-1,j} * r(k-j))
        #           / (1 - sum_{j=1}^{k-1} phi_{k-1,j} * r(j))
        numer = float(acf_vals[k])
        denom = 1.0
        for j in range(len(phi_prev)):
            numer -= phi_prev[j] * float(acf_vals[k - j - 1])
            denom -= phi_prev[j] * float(acf_vals[j + 1])

        if abs(denom) < 1e-12:
            result.append(0.0)
            phi_prev = [0.0] * k
            continue

        phi_kk = numer / denom
        result.append(phi_kk)

        # Update phi coefficients
        new_phi = []
        for j in range(len(phi_prev)):
            new_phi.append(phi_prev[j] - phi_kk * phi_prev[-(j + 1)])
        new_phi.append(phi_kk)
        phi_prev = new_phi

    return jnp.array(result)


def ljung_box(
    residuals: Array, n_lags: int = 10
) -> tuple[float, float]:
    """Ljung-Box test for autocorrelation in residuals.

    Tests H0: the residuals are independently distributed (no autocorrelation).

    Args:
        residuals: 1D array of (standardized) residuals, shape (T,).
        n_lags: Number of lags to include in the test.

    Returns:
        Tuple of (test_statistic, p_value).
    """
    residuals = jnp.asarray(residuals).ravel()
    n = residuals.shape[0]
    acf_vals = acf(residuals, n_lags)

    # Q = n(n+2) * sum_{k=1}^{h} r_k^2 / (n-k)
    q_stat = 0.0
    for k in range(1, n_lags + 1):
        rk = float(acf_vals[k])
        q_stat += rk**2 / (n - k)
    q_stat *= float(n * (n + 2))

    # Under H0, Q ~ chi-squared(n_lags)
    p_value = float(1.0 - stats.chi2.cdf(q_stat, df=n_lags))

    return q_stat, p_value
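A runnable sketch of how these diagnostics are meant to be chained, using synthetic white-noise residuals as a stand-in; in practice the residuals would come from standardized_residuals applied to a kalman_filter result on a fitted model:

import jax

from dynaris.estimation import acf, ljung_box, pacf

# White-noise residuals as a placeholder for standardized filter residuals.
key = jax.random.PRNGKey(0)
resids = jax.random.normal(key, (500,))

r = acf(resids, n_lags=20)    # r[0] == 1.0 by construction
p = pacf(resids, n_lags=20)   # p[0] == 1.0 by construction
q_stat, p_value = ljung_box(resids, n_lags=10)

# For genuine white noise, p_value should usually be well above 0.05,
# i.e. no evidence of leftover autocorrelation.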

src/dynaris/estimation/em.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
"""EM algorithm for variance estimation in linear-Gaussian SSMs."""

from __future__ import annotations

from dataclasses import dataclass

import jax.numpy as jnp
from jax import Array

from dynaris.core.results import SmootherResult
from dynaris.core.state_space import StateSpaceModel
from dynaris.filters.kalman import kalman_filter
from dynaris.smoothers.rts import rts_smooth


@dataclass(frozen=True)
class EMResult:
    """Result of EM estimation.

    Attributes:
        model: Fitted StateSpaceModel at convergence.
        log_likelihood: Final log-likelihood value.
        n_iterations: Number of EM iterations performed.
        converged: Whether the algorithm converged.
        log_likelihood_history: Log-likelihood at each iteration.
    """

    model: StateSpaceModel
    log_likelihood: float
    n_iterations: int
    converged: bool
    log_likelihood_history: list[float]


def _e_step(
    model: StateSpaceModel, observations: Array
) -> tuple[SmootherResult, float]:
    """E-step: run Kalman filter + RTS smoother."""
    fr = kalman_filter(model, observations)
    sr = rts_smooth(model, fr)
    return sr, float(fr.log_likelihood)


def _m_step(
    sr: SmootherResult, model: StateSpaceModel
) -> StateSpaceModel:
    """M-step: update Q and R from smoothed sufficient statistics.

    For a general linear-Gaussian SSM:

        Q_new = (1/T) * sum_t [P_{t|T} + x_{t|T} x_{t|T}^T
                               - (P_{t,t-1|T} + x_{t|T} x_{t-1|T}^T) F^T
                               - F (P_{t,t-1|T} + x_{t|T} x_{t-1|T}^T)^T
                               + F (P_{t-1|T} + x_{t-1|T} x_{t-1|T}^T) F^T]

        R_new = (1/T) * sum_t [(y_t - H x_{t|T})(y_t - H x_{t|T})^T
                               + H P_{t|T} H^T]

    We use a simplified update that drops the lag-one smoothed
    cross-covariance terms, which is adequate for typical DLM applications.
    """
    obs = sr.observations  # (T, m)
    x_smooth = sr.smoothed_states  # (T, n)
    p_smooth = sr.smoothed_covariances  # (T, n, n)
    n_time = obs.shape[0]

    # --- Estimate R (observation noise covariance) ---
    # residual_t = y_t - H @ x_{t|T}
    residuals = obs - (x_smooth @ model.H.T)  # (T, m)
    # R = (1/T) * sum_t [r_t r_t^T + H P_{t|T} H^T]
    outer_sum = jnp.einsum("ti,tj->ij", residuals, residuals)  # (m, m)
    hp_ht_sum = jnp.sum(model.H @ p_smooth @ model.H.T, axis=0)  # sum over T -> (m, m)
    new_r = (outer_sum + hp_ht_sum) / n_time

    # --- Estimate Q (state noise covariance) ---
    # The exact update (see docstring) involves the lag-one smoothed
    # cross-covariance P_{t,t-1|T}. We drop those cross terms and use the
    # simpler estimate:
    #   state_resid_t = x_{t|T} - F @ x_{t-1|T}
    #   Q ~ (1/(T-1)) sum_t [state_resid_t state_resid_t^T + P_{t|T} + F P_{t-1|T} F^T]
    x_pred = x_smooth[:-1] @ model.F.T  # F @ x_{t-1|T}, shape (T-1, n)
    state_resids = x_smooth[1:] - x_pred  # (T-1, n)
    outer_q = jnp.einsum("ti,tj->ij", state_resids, state_resids)  # (n, n)
    # Add smoothed covariance terms
    p_curr = jnp.sum(p_smooth[1:], axis=0)  # sum P_{t|T} for t=1..T-1
    fp_ft = jnp.sum(
        model.F @ p_smooth[:-1] @ model.F.T, axis=0
    )  # sum F P_{t-1|T} F^T
    new_q = (outer_q + p_curr + fp_ft) / (n_time - 1)
    # Ensure symmetry
    new_q = (new_q + new_q.T) / 2.0
    new_r = (new_r + new_r.T) / 2.0

    return StateSpaceModel(
        transition_matrix=model.transition_matrix,
        observation_matrix=model.observation_matrix,
        state_noise_cov=new_q,
        obs_noise_cov=new_r,
        input_matrix=model.input_matrix,
    )


def fit_em(
    observations: Array,
    initial_model: StateSpaceModel,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> EMResult:
    """Fit a state-space model via the EM algorithm.

    Iteratively updates Q (state noise) and R (observation noise)
    covariance matrices while keeping F, H, and B fixed.

    Args:
        observations: Observation sequence, shape (T, obs_dim).
        initial_model: Starting model with initial variance guesses.
        max_iter: Maximum number of EM iterations.
        tol: Convergence tolerance on log-likelihood change.

    Returns:
        EMResult with the fitted model and convergence details.
    """
    observations = jnp.asarray(observations)
    model = initial_model
    ll_history: list[float] = []
    converged = False

    for i in range(max_iter):
        sr, ll = _e_step(model, observations)
        ll_history.append(ll)

        if i > 0 and abs(ll - ll_history[-2]) < tol:
            converged = True
            break

        model = _m_step(sr, model)

    return EMResult(
        model=model,
        log_likelihood=ll_history[-1] if ll_history else float("-inf"),
        n_iterations=len(ll_history),
        converged=converged,
        log_likelihood_history=ll_history,
    )
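A minimal usage sketch for fit_em. Here init_model and observations are placeholders: a StateSpaceModel built elsewhere with rough initial guesses for Q and R, and a (T, obs_dim) observation array:

from dynaris.estimation import fit_em, ljung_box, standardized_residuals
from dynaris.filters.kalman import kalman_filter

# `init_model` and `observations` are placeholders (see note above).
result = fit_em(observations, init_model, max_iter=200, tol=1e-8)
print(result.converged, result.n_iterations, result.log_likelihood)

# Diagnostics on the model with the fitted Q and R.
fr = kalman_filter(result.model, observations)
resids = standardized_residuals(fr, result.model)
q_stat, p_value = ljung_box(resids, n_lags=10)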
