"""EM algorithm for variance estimation in linear-Gaussian SSMs."""

from __future__ import annotations

from dataclasses import dataclass

import jax.numpy as jnp
from jax import Array

from dynaris.core.results import SmootherResult
from dynaris.core.state_space import StateSpaceModel
from dynaris.filters.kalman import kalman_filter
from dynaris.smoothers.rts import rts_smooth


@dataclass(frozen=True)
class EMResult:
    """Result of EM estimation.

    Attributes:
        model: Fitted StateSpaceModel at convergence.
        log_likelihood: Final log-likelihood value.
        n_iterations: Number of EM iterations performed.
        converged: Whether the algorithm converged.
        log_likelihood_history: Log-likelihood at each iteration.
    """

    model: StateSpaceModel
    log_likelihood: float
    n_iterations: int
    converged: bool
    log_likelihood_history: list[float]


def _e_step(
    model: StateSpaceModel, observations: Array
) -> tuple[SmootherResult, float]:
    """E-step: run the Kalman filter, then the RTS smoother."""
    fr = kalman_filter(model, observations)
    sr = rts_smooth(model, fr)
    return sr, float(fr.log_likelihood)


def _m_step(
    sr: SmootherResult, model: StateSpaceModel
) -> StateSpaceModel:
    """M-step: update Q and R from smoothed sufficient statistics.

    For a general linear-Gaussian SSM:

        Q_new = (1/T) * sum_t [P_{t|T} + x_{t|T} x_{t|T}^T
                - (P_{t,t-1|T} + x_{t|T} x_{t-1|T}^T) F^T
                - F (P_{t,t-1|T} + x_{t|T} x_{t-1|T}^T)^T
                + F (P_{t-1|T} + x_{t-1|T} x_{t-1|T}^T) F^T]

        R_new = (1/T) * sum_t [(y_t - H x_{t|T})(y_t - H x_{t|T})^T
                + H P_{t|T} H^T]

    We use a simplified Q update that drops the lag-one cross-covariance
    terms P_{t,t-1|T} (see the comments below), a common approximation
    for DLM variance estimation.
    """
    obs = sr.observations  # (T, m)
    x_smooth = sr.smoothed_states  # (T, n)
    p_smooth = sr.smoothed_covariances  # (T, n, n)
    n_time = obs.shape[0]

    # --- Estimate R (observation noise covariance) ---
    # residual_t = y_t - H @ x_{t|T}
    residuals = obs - (x_smooth @ model.H.T)  # (T, m)
    # R = (1/T) * sum_t [r_t r_t^T + H P_{t|T} H^T]
    outer_sum = jnp.einsum("ti,tj->ij", residuals, residuals)  # (m, m)
    hp_ht_sum = jnp.sum(model.H @ p_smooth @ model.H.T, axis=0)  # sum over T -> (m, m)
    new_r = (outer_sum + hp_ht_sum) / n_time

    # --- Estimate Q (state noise covariance) ---
    # The exact update (all quantities smoothed) is:
    #   Q = (1/(T-1)) sum_t [(x_t - F x_{t-1})(x_t - F x_{t-1})^T + P_{t|T}
    #       - P_{t,t-1|T} F^T - F P_{t,t-1|T}^T + F P_{t-1|T} F^T]
    # We drop the lag-one cross-covariance terms P_{t,t-1|T} rather than
    # reconstructing them from the smoother gain; this typically biases Q
    # upward but is a standard simplification in DLM practice.
    x_pred = x_smooth[:-1] @ model.F.T  # F @ x_{t-1|T}, shape (T-1, n)
    state_resids = x_smooth[1:] - x_pred  # (T-1, n)
    outer_q = jnp.einsum("ti,tj->ij", state_resids, state_resids)  # (n, n)
    # Smoothed covariance terms: sum_t P_{t|T} and sum_t F P_{t-1|T} F^T.
    p_curr = jnp.sum(p_smooth[1:], axis=0)
    fp_ft = jnp.sum(
        model.F @ p_smooth[:-1] @ model.F.T, axis=0
    )
    new_q = (outer_q + p_curr + fp_ft) / (n_time - 1)
    # Symmetrize to guard against accumulated floating-point asymmetry.
    new_q = (new_q + new_q.T) / 2.0
    new_r = (new_r + new_r.T) / 2.0

    return StateSpaceModel(
        transition_matrix=model.transition_matrix,
        observation_matrix=model.observation_matrix,
        state_noise_cov=new_q,
        obs_noise_cov=new_r,
        input_matrix=model.input_matrix,
    )


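# For reference, the exact Q update including the lag-one terms is sketched
# below. It assumes a hypothetical `smoothed_cross_covariances` field on
# SmootherResult holding P_{t,t-1|T} with shape (T-1, n, n) -- dynaris may
# not expose this, so treat it as an illustration rather than working API.
def _m_step_q_exact(sr: SmootherResult, model: StateSpaceModel) -> Array:
    """Exact EM update for Q using lag-one smoothed cross-covariances."""
    x = sr.smoothed_states  # (T, n)
    p = sr.smoothed_covariances  # (T, n, n)
    p_cross = sr.smoothed_cross_covariances  # hypothetical field, (T-1, n, n)
    n_time = x.shape[0]
    f = model.F
    # Smoothed second moments, summed over t = 1..T-1:
    #   A = sum E[x_{t-1} x_{t-1}^T], B = sum E[x_t x_{t-1}^T],
    #   C = sum E[x_t x_t^T]
    a = jnp.sum(p[:-1], axis=0) + jnp.einsum("ti,tj->ij", x[:-1], x[:-1])
    b = jnp.sum(p_cross, axis=0) + jnp.einsum("ti,tj->ij", x[1:], x[:-1])
    c = jnp.sum(p[1:], axis=0) + jnp.einsum("ti,tj->ij", x[1:], x[1:])
    q = (c - b @ f.T - f @ b.T + f @ a @ f.T) / (n_time - 1)
    return (q + q.T) / 2.0  # enforce symmetry

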
def fit_em(
    observations: Array,
    initial_model: StateSpaceModel,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> EMResult:
    """Fit a state-space model via the EM algorithm.

    Iteratively updates the Q (state noise) and R (observation noise)
    covariance matrices while keeping F, H, and B fixed.

    Args:
        observations: Observation sequence, shape (T, obs_dim).
        initial_model: Starting model with initial variance guesses.
        max_iter: Maximum number of EM iterations.
        tol: Convergence tolerance on the log-likelihood change.

    Returns:
        EMResult with the fitted model and convergence details. If
        max_iter is reached without converging, the returned model has
        one more M-step applied than the reported log-likelihood reflects.
    """
    observations = jnp.asarray(observations)
    model = initial_model
    ll_history: list[float] = []
    converged = False

    for i in range(max_iter):
        # The E-step scores the *current* model, so ll_history[i] is the
        # log-likelihood before the i-th M-step update.
        sr, ll = _e_step(model, observations)
        ll_history.append(ll)

        # EM increases the log-likelihood monotonically, so a small change
        # between successive iterations signals convergence.
        if i > 0 and abs(ll - ll_history[-2]) < tol:
            converged = True
            break

        model = _m_step(sr, model)

    return EMResult(
        model=model,
        log_likelihood=ll_history[-1] if ll_history else float("-inf"),
        n_iterations=len(ll_history),
        converged=converged,
        log_likelihood_history=ll_history,
    )
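

# Minimal usage sketch: fit the noise variances of a local-level model
# (random-walk state observed with noise). The constructor fields follow
# the call in `_m_step` above; the observations are placeholder random
# data, and `input_matrix=None` assumes the model accepts a missing input.
if __name__ == "__main__":
    import jax

    key = jax.random.PRNGKey(0)
    y = jax.random.normal(key, (200, 1))  # placeholder observations, (T, 1)

    init = StateSpaceModel(
        transition_matrix=jnp.eye(1),  # F: random-walk state
        observation_matrix=jnp.eye(1),  # H: observe the state directly
        state_noise_cov=jnp.eye(1) * 0.1,  # initial guess for Q
        obs_noise_cov=jnp.eye(1) * 1.0,  # initial guess for R
        input_matrix=None,  # B: no exogenous input (assumed allowed)
    )

    result = fit_em(y, init, max_iter=50)
    print(f"converged={result.converged} in {result.n_iterations} iterations")
    print(f"final log-likelihood: {result.log_likelihood:.4f}")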