diff --git a/src/dynaris/__init__.py b/src/dynaris/__init__.py index cc67eba..de930e1 100644 --- a/src/dynaris/__init__.py +++ b/src/dynaris/__init__.py @@ -18,7 +18,14 @@ Regression, Seasonal, ) -from dynaris.filters import ExtendedKalmanFilter, KalmanFilter, ekf_filter, kalman_filter +from dynaris.filters import ( + ExtendedKalmanFilter, + KalmanFilter, + UnscentedKalmanFilter, + ekf_filter, + kalman_filter, + ukf_filter, +) from dynaris.smoothers import RTSSmoother, rts_smooth __version__ = "0.1.0" @@ -41,8 +48,10 @@ "SmootherProtocol", "SmootherResult", "StateSpaceModel", + "UnscentedKalmanFilter", "__version__", "ekf_filter", "kalman_filter", "rts_smooth", + "ukf_filter", ] diff --git a/src/dynaris/filters/__init__.py b/src/dynaris/filters/__init__.py index e0be3e9..e078d12 100644 --- a/src/dynaris/filters/__init__.py +++ b/src/dynaris/filters/__init__.py @@ -2,10 +2,13 @@ from dynaris.filters.ekf import ExtendedKalmanFilter, ekf_filter from dynaris.filters.kalman import KalmanFilter, kalman_filter +from dynaris.filters.ukf import UnscentedKalmanFilter, ukf_filter __all__ = [ "ExtendedKalmanFilter", "KalmanFilter", + "UnscentedKalmanFilter", "ekf_filter", "kalman_filter", + "ukf_filter", ] diff --git a/src/dynaris/filters/ukf.py b/src/dynaris/filters/ukf.py new file mode 100644 index 0000000..1cb9aca --- /dev/null +++ b/src/dynaris/filters/ukf.py @@ -0,0 +1,355 @@ +"""Unscented Kalman Filter for nonlinear state-space models. + +Propagates sigma points through nonlinear transition and observation functions +to capture the posterior mean and covariance without linearization. Uses the +scaled unscented transform with configurable alpha, beta, kappa parameters. + +References: + Julier, S.J. and Uhlmann, J.K. (2004). "Unscented Filtering and + Nonlinear Estimation." Proceedings of the IEEE, 92(3), 401-422. +""" + +from __future__ import annotations + +from typing import NamedTuple + +import jax +import jax.numpy as jnp +from jax import Array + +from dynaris.core.nonlinear import NonlinearSSM +from dynaris.core.results import FilterResult +from dynaris.core.types import GaussianState + +# --------------------------------------------------------------------------- +# Sigma-point weights +# --------------------------------------------------------------------------- + + +class SigmaWeights(NamedTuple): + """Weights for the unscented transform.""" + + wm: Array # mean weights, shape (2n+1,) + wc: Array # covariance weights, shape (2n+1,) + lam: Array # scaling parameter lambda (scalar array) + + +def compute_weights( + n: int, + alpha: float = 1e-3, + beta: float = 2.0, + kappa: float = 0.0, +) -> SigmaWeights: + """Compute sigma-point weights for the scaled unscented transform. + + Args: + n: State dimension. + alpha: Spread of sigma points around the mean (typically 1e-4 to 1). + beta: Prior knowledge of distribution (2.0 is optimal for Gaussian). + kappa: Secondary scaling parameter (typically 0 or 3-n). + + Returns: + SigmaWeights with mean weights, covariance weights, and lambda. + """ + lam = alpha**2 * (n + kappa) - n + + wm = jnp.full(2 * n + 1, 1.0 / (2.0 * (n + lam))) + wm = wm.at[0].set(lam / (n + lam)) + + wc = jnp.full(2 * n + 1, 1.0 / (2.0 * (n + lam))) + wc = wc.at[0].set(lam / (n + lam) + (1.0 - alpha**2 + beta)) + + return SigmaWeights(wm=wm, wc=wc, lam=jnp.array(lam)) + + +# --------------------------------------------------------------------------- +# Sigma-point generation +# --------------------------------------------------------------------------- + + +def sigma_points(state: GaussianState, lam: Array) -> Array: + """Generate 2n+1 sigma points from a Gaussian state. + + Args: + state: Gaussian belief with mean (n,) and covariance (n, n). + lam: Scaling parameter lambda (scalar). + + Returns: + Sigma points, shape (2n+1, n). + """ + n = state.mean.shape[0] + scaled_cov = (n + lam) * state.cov + L = jnp.linalg.cholesky(scaled_cov) # (n, n) + + # Build sigma points: [mean, mean + L_i, mean - L_i] + offsets = jnp.concatenate([ + jnp.zeros((1, n)), + L, # rows of L as positive offsets + -L, # rows of L as negative offsets + ], axis=0) # (2n+1, n) + + return state.mean[None, :] + offsets + + +# --------------------------------------------------------------------------- +# Internal scan carry +# --------------------------------------------------------------------------- + + +class _ScanCarry(NamedTuple): + filtered: GaussianState + log_likelihood: Array # scalar + + +class _ScanOutput(NamedTuple): + predicted_mean: Array + predicted_cov: Array + filtered_mean: Array + filtered_cov: Array + + +# --------------------------------------------------------------------------- +# Pure-function predict and update steps +# --------------------------------------------------------------------------- + + +def predict( + state: GaussianState, + model: NonlinearSSM, + weights: SigmaWeights, +) -> GaussianState: + """UKF predict step (time update). + + Generates sigma points, propagates them through the transition function, + and recovers the predicted mean and covariance. + """ + pts = sigma_points(state, weights.lam) # (2n+1, n) + + # Propagate sigma points through transition function + pts_pred = jax.vmap(model.f)(pts) # (2n+1, n) + + # Recover predicted mean + mean = jnp.sum(weights.wm[:, None] * pts_pred, axis=0) # (n,) + + # Recover predicted covariance + diff = pts_pred - mean[None, :] # (2n+1, n) + cov = jnp.sum(weights.wc[:, None, None] * (diff[:, :, None] * diff[:, None, :]), axis=0) + cov = cov + model.Q + + return GaussianState(mean=mean, cov=cov) + + +def update( + predicted: GaussianState, + observation: Array, + model: NonlinearSSM, + weights: SigmaWeights, +) -> tuple[GaussianState, Array]: + """UKF update step (measurement update). + + Generates sigma points from the predicted state, propagates through the + observation function, and computes the Kalman gain. + + Returns the filtered state and the log-likelihood contribution. + Handles missing observations (NaN) by skipping the update. + """ + y = observation + pts = sigma_points(predicted, weights.lam) # (2n+1, n) + + # Propagate through observation function + pts_obs = jax.vmap(model.h)(pts) # (2n+1, m) + + # Predicted observation mean + y_pred = jnp.sum(weights.wm[:, None] * pts_obs, axis=0) # (m,) + + # Innovation covariance S = sum wc * (y_diff)(y_diff)' + R + y_diff = pts_obs - y_pred[None, :] # (2n+1, m) + S = jnp.sum(weights.wc[:, None, None] * (y_diff[:, :, None] * y_diff[:, None, :]), axis=0) + S = S + model.R # (m, m) + + # Cross-covariance P_xy = sum wc * (x_diff)(y_diff)' + x_diff = pts - predicted.mean[None, :] # (2n+1, n) + P_xy = jnp.sum(weights.wc[:, None, None] * (x_diff[:, :, None] * y_diff[:, None, :]), axis=0) + # (n, m) + + # Kalman gain K = P_xy @ S^{-1} + K = jnp.linalg.solve(S.T, P_xy.T).T # (n, m) + + # Innovation + e = y - y_pred # (m,) + + filtered_mean = predicted.mean + K @ e + filtered_cov = predicted.cov - K @ S @ K.T + + # Log-likelihood: log N(e; 0, S) + m = observation.shape[-1] + log_det = jnp.linalg.slogdet(S)[1] + mahal = e @ jnp.linalg.solve(S, e) + ll = -0.5 * (m * jnp.log(2.0 * jnp.pi) + log_det + mahal) + + # Handle missing observations + obs_valid = ~jnp.any(jnp.isnan(y)) + filtered_mean = jnp.where(obs_valid, filtered_mean, predicted.mean) + filtered_cov = jnp.where(obs_valid, filtered_cov, predicted.cov) + ll = jnp.where(obs_valid, ll, 0.0) + + filtered = GaussianState(mean=filtered_mean, cov=filtered_cov) + return filtered, ll + + +# --------------------------------------------------------------------------- +# Full forward pass via lax.scan +# --------------------------------------------------------------------------- + + +class UnscentedKalmanFilter: + """Unscented Kalman Filter for nonlinear state-space models. + + Uses the scaled unscented transform to propagate sigma points through + nonlinear functions, avoiding the need for Jacobian computation. + + Args: + alpha: Spread of sigma points (default 1e-3). + beta: Prior distribution parameter (default 2.0, optimal for Gaussian). + kappa: Secondary scaling parameter (default 0.0). + """ + + def __init__( + self, + alpha: float = 1e-3, + beta: float = 2.0, + kappa: float = 0.0, + ) -> None: + self.alpha = alpha + self.beta = beta + self.kappa = kappa + + def predict(self, state: GaussianState, model: NonlinearSSM) -> GaussianState: + """UKF predict step (time update).""" + w = compute_weights(model.state_dim, self.alpha, self.beta, self.kappa) + return predict(state, model, w) + + def update( + self, + predicted: GaussianState, + observation: Array, + model: NonlinearSSM, + ) -> GaussianState: + """UKF update step (measurement update).""" + w = compute_weights(model.state_dim, self.alpha, self.beta, self.kappa) + filtered, _ll = update(predicted, observation, model, w) + return filtered + + def scan( + self, + model: NonlinearSSM, + observations: Array, + initial_state: GaussianState | None = None, + ) -> FilterResult: + """Run full forward UKF via jax.lax.scan.""" + return _ukf_filter_impl( + model, observations, initial_state, + self.alpha, self.beta, self.kappa, + ) + + +def ukf_filter( + model: NonlinearSSM, + observations: Array, + initial_state: GaussianState | None = None, + *, + alpha: float = 1e-3, + beta: float = 2.0, + kappa: float = 0.0, +) -> FilterResult: + """Unscented Kalman Filter forward pass. + + Uses the scaled unscented transform with configurable parameters to + propagate sigma points through nonlinear transition and observation + functions. + + Args: + model: Nonlinear state-space model with callable f and h. + observations: Observation sequence, shape (T, obs_dim). + initial_state: Initial state belief. Defaults to diffuse prior. + alpha: Spread of sigma points around the mean (default 1e-3). + beta: Prior distribution parameter (default 2.0, optimal for Gaussian). + kappa: Secondary scaling parameter (default 0.0). + + Returns: + FilterResult with filtered/predicted states and log-likelihood. + + Example:: + + import jax.numpy as jnp + from dynaris.core.nonlinear import NonlinearSSM + from dynaris.filters.ukf import ukf_filter + + model = NonlinearSSM( + transition_fn=lambda x: x, + observation_fn=lambda x: x, + transition_cov=jnp.eye(1), + observation_cov=jnp.eye(1), + state_dim=1, obs_dim=1, + ) + result = ukf_filter(model, observations) + """ + return _ukf_filter_impl(model, observations, initial_state, alpha, beta, kappa) + + +def _ukf_filter_impl( + model: NonlinearSSM, + observations: Array, + initial_state: GaussianState | None, + alpha: float, + beta: float, + kappa: float, +) -> FilterResult: + """Internal implementation — weights computed before JIT boundary.""" + if initial_state is None: + initial_state = model.initial_state() + + weights = compute_weights(model.state_dim, alpha, beta, kappa) + return _ukf_scan(model, observations, initial_state, weights) + + +@jax.jit +def _ukf_scan( + model: NonlinearSSM, + observations: Array, + initial_state: GaussianState, + weights: SigmaWeights, +) -> FilterResult: + """JIT-compiled scan loop for UKF.""" + init_carry = _ScanCarry( + filtered=initial_state, + log_likelihood=jnp.array(0.0), + ) + + def _scan_step( + carry: _ScanCarry, obs: Array + ) -> tuple[_ScanCarry, _ScanOutput]: + predicted = predict(carry.filtered, model, weights) + filtered, ll = update(predicted, obs, model, weights) + new_carry = _ScanCarry( + filtered=filtered, + log_likelihood=carry.log_likelihood + ll, + ) + output = _ScanOutput( + predicted_mean=predicted.mean, + predicted_cov=predicted.cov, + filtered_mean=filtered.mean, + filtered_cov=filtered.cov, + ) + return new_carry, output + + final_carry, outputs = jax.lax.scan(_scan_step, init_carry, observations) + + return FilterResult( + filtered_states=outputs.filtered_mean, + filtered_covariances=outputs.filtered_cov, + predicted_states=outputs.predicted_mean, + predicted_covariances=outputs.predicted_cov, + log_likelihood=final_carry.log_likelihood, + observations=observations, + ) diff --git a/tests/test_filters/test_ukf.py b/tests/test_filters/test_ukf.py new file mode 100644 index 0000000..af668e6 --- /dev/null +++ b/tests/test_filters/test_ukf.py @@ -0,0 +1,408 @@ +"""Tests for the Unscented Kalman Filter.""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +from dynaris.core.nonlinear import NonlinearSSM +from dynaris.core.results import FilterResult +from dynaris.core.state_space import StateSpaceModel +from dynaris.core.types import GaussianState +from dynaris.datasets import load_nile_jax +from dynaris.filters.kalman import kalman_filter +from dynaris.filters.ukf import ( + UnscentedKalmanFilter, + compute_weights, + sigma_points, + ukf_filter, + predict, + update, +) + +NILE = load_nile_jax() + + +# --------------------------------------------------------------------------- +# Helper: linear model as NonlinearSSM (for comparison with Kalman) +# --------------------------------------------------------------------------- + + +def _linear_nonlinear_model( + sigma_level: float = 1.0, sigma_obs: float = 1.0 +) -> NonlinearSSM: + """Local-level model as a NonlinearSSM.""" + return NonlinearSSM( + transition_fn=lambda x: x, + observation_fn=lambda x: x, + transition_cov=jnp.array([[sigma_level**2]]), + observation_cov=jnp.array([[sigma_obs**2]]), + state_dim=1, + obs_dim=1, + ) + + +def _linear_ssm(sigma_level: float = 1.0, sigma_obs: float = 1.0) -> StateSpaceModel: + """Equivalent linear model for Kalman filter comparison.""" + return StateSpaceModel( + system_matrix=jnp.array([[1.0]]), + observation_matrix=jnp.array([[1.0]]), + evolution_cov=jnp.array([[sigma_level**2]]), + obs_cov=jnp.array([[sigma_obs**2]]), + ) + + +# --------------------------------------------------------------------------- +# Sigma-point and weight tests +# --------------------------------------------------------------------------- + + +def test_compute_weights_shape() -> None: + w = compute_weights(n=3) + assert w.wm.shape == (7,) + assert w.wc.shape == (7,) + + +def test_compute_weights_sum_with_large_alpha() -> None: + """With alpha=1, weights sum to 1.""" + w = compute_weights(n=3, alpha=1.0, kappa=0.0) + np.testing.assert_allclose(jnp.sum(w.wm), 1.0, atol=1e-6) + + +def test_compute_weights_custom_params() -> None: + w = compute_weights(n=2, alpha=0.5, beta=2.0, kappa=1.0) + assert w.wm.shape == (5,) + + +def test_sigma_points_shape() -> None: + state = GaussianState(mean=jnp.zeros(3), cov=jnp.eye(3)) + w = compute_weights(n=3) + pts = sigma_points(state, w.lam) + assert pts.shape == (7, 3) + + +def test_sigma_points_center_is_mean() -> None: + mean = jnp.array([1.0, 2.0]) + state = GaussianState(mean=mean, cov=jnp.eye(2) * 0.5) + w = compute_weights(n=2) + pts = sigma_points(state, w.lam) + np.testing.assert_allclose(pts[0], mean, atol=1e-6) + + +def test_sigma_points_symmetric() -> None: + state = GaussianState(mean=jnp.array([3.0]), cov=jnp.array([[2.0]])) + w = compute_weights(n=1) + pts = sigma_points(state, w.lam) + # Points 1 and 2 should be equidistant from the mean + np.testing.assert_allclose( + pts[1] - state.mean, -(pts[2] - state.mean), atol=1e-6 + ) + + +def test_sigma_points_weighted_mean_recovers_mean() -> None: + """Weighted mean of sigma points should recover the original mean.""" + mean = jnp.array([2.0, -1.0]) + cov = jnp.array([[1.0, 0.3], [0.3, 0.5]]) + state = GaussianState(mean=mean, cov=cov) + w = compute_weights(n=2, alpha=1.0) + pts = sigma_points(state, w.lam) + recovered = jnp.sum(w.wm[:, None] * pts, axis=0) + np.testing.assert_allclose(recovered, mean, atol=1e-5) + + +# --------------------------------------------------------------------------- +# Predict step tests +# --------------------------------------------------------------------------- + + +def test_predict_identity_transition() -> None: + model = _linear_nonlinear_model() + state = GaussianState(mean=jnp.array([5.0]), cov=jnp.array([[2.0]])) + w = compute_weights(n=1, alpha=1.0) + pred = predict(state, model, w) + # Identity transition: mean unchanged, cov = P + Q + np.testing.assert_allclose(pred.mean, [5.0], atol=1e-4) + np.testing.assert_allclose(pred.cov, [[3.0]], atol=1e-3) + + +def test_predict_nonlinear_transition() -> None: + def f(x: Array) -> Array: + return x + 0.1 * jnp.sin(x) + + model = NonlinearSSM( + transition_fn=f, + observation_fn=lambda x: x, + transition_cov=jnp.array([[0.5]]), + observation_cov=jnp.array([[1.0]]), + state_dim=1, + obs_dim=1, + ) + state = GaussianState(mean=jnp.array([1.0]), cov=jnp.array([[0.001]])) + w = compute_weights(n=1, alpha=1.0) + pred = predict(state, model, w) + expected = 1.0 + 0.1 * float(jnp.sin(1.0)) + np.testing.assert_allclose(pred.mean, [expected], atol=0.01) + assert jnp.all(jnp.isfinite(pred.cov)) + + +# --------------------------------------------------------------------------- +# Update step tests +# --------------------------------------------------------------------------- + + +def test_update_reduces_uncertainty() -> None: + model = _linear_nonlinear_model(sigma_level=1.0, sigma_obs=1.0) + predicted = GaussianState(mean=jnp.array([0.0]), cov=jnp.array([[10.0]])) + obs = jnp.array([5.0]) + w = compute_weights(n=1) + filtered, ll = update(predicted, obs, model, w) + assert float(filtered.cov[0, 0]) < 10.0 + assert float(filtered.mean[0]) > 0.0 + assert jnp.isfinite(ll) + + +def test_update_nan_skips() -> None: + model = _linear_nonlinear_model() + predicted = GaussianState(mean=jnp.array([3.0]), cov=jnp.array([[2.0]])) + obs = jnp.array([jnp.nan]) + w = compute_weights(n=1) + filtered, ll = update(predicted, obs, model, w) + np.testing.assert_allclose(filtered.mean, predicted.mean) + np.testing.assert_allclose(filtered.cov, predicted.cov) + assert float(ll) == 0.0 + + +# --------------------------------------------------------------------------- +# UKF matches Kalman on linear models +# --------------------------------------------------------------------------- + + +def test_ukf_matches_kalman_on_linear_model() -> None: + """When the model is linear, UKF should produce similar results to Kalman. + + The UKF with default alpha=1e-3 and a diffuse prior can diverge slightly + due to the extreme scaling. With alpha=1 and a tighter prior, the match + is close. + """ + sigma_level, sigma_obs = 40.0, 120.0 + nl_model = _linear_nonlinear_model(sigma_level, sigma_obs) + lin_model = _linear_ssm(sigma_level, sigma_obs) + + observations = NILE.reshape(-1, 1) + + # Use a tighter (non-diffuse) initial state for better UKF-Kalman agreement + init = GaussianState(mean=jnp.array([1000.0]), cov=jnp.eye(1) * 1e4) + + ukf_result = ukf_filter(nl_model, observations, initial_state=init, alpha=1.0) + kf_result = kalman_filter(lin_model, observations, initial_state=init) + + # After initial transient, filtered states should be very close + np.testing.assert_allclose( + ukf_result.filtered_states[10:], kf_result.filtered_states[10:], atol=0.5 + ) + np.testing.assert_allclose( + ukf_result.log_likelihood, kf_result.log_likelihood, atol=5.0 + ) + + +# --------------------------------------------------------------------------- +# Full filter scan tests +# --------------------------------------------------------------------------- + + +def test_ukf_filter_shapes() -> None: + model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0) + observations = NILE.reshape(-1, 1) + result = ukf_filter(model, observations) + + assert isinstance(result, FilterResult) + assert result.filtered_states.shape == (100, 1) + assert result.filtered_covariances.shape == (100, 1, 1) + assert result.predicted_states.shape == (100, 1) + assert result.predicted_covariances.shape == (100, 1, 1) + assert result.log_likelihood.shape == () + + +def test_ukf_filter_finite() -> None: + model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0) + observations = NILE.reshape(-1, 1) + result = ukf_filter(model, observations) + + assert jnp.all(jnp.isfinite(result.filtered_states)) + assert jnp.all(jnp.isfinite(result.filtered_covariances)) + assert jnp.isfinite(result.log_likelihood) + + +def test_ukf_filter_negative_log_likelihood() -> None: + model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0) + observations = NILE.reshape(-1, 1) + result = ukf_filter(model, observations) + assert float(result.log_likelihood) < 0.0 + + +def test_ukf_filter_with_missing_obs() -> None: + model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0) + observations = NILE.reshape(-1, 1) + observations = observations.at[10, 0].set(jnp.nan) + observations = observations.at[20, 0].set(jnp.nan) + + result = ukf_filter(model, observations) + assert jnp.all(jnp.isfinite(result.filtered_states)) + assert jnp.isfinite(result.log_likelihood) + np.testing.assert_allclose( + result.filtered_states[10], result.predicted_states[10], atol=1e-5 + ) + + +# --------------------------------------------------------------------------- +# Nonlinear model tests +# --------------------------------------------------------------------------- + + +def test_ukf_nonlinear_tracking() -> None: + """UKF should track a nonlinear state through noisy observations.""" + key = jax.random.PRNGKey(42) + k1, k2 = jax.random.split(key) + n_steps = 200 + + def f(x: Array) -> Array: + return 0.95 * x + 0.1 * jnp.sin(x) + + def h(x: Array) -> Array: + return x + + sigma_q, sigma_r = 0.5, 1.0 + state_noise = jax.random.normal(k1, (n_steps,)) * sigma_q + obs_noise = jax.random.normal(k2, (n_steps,)) * sigma_r + + state = jnp.array([5.0]) + states_list = [] + for t in range(n_steps): + state = f(state) + state_noise[t : t + 1] + states_list.append(state) + true_states = jnp.concatenate(states_list) + observations = (true_states + obs_noise).reshape(-1, 1) + + model = NonlinearSSM( + transition_fn=f, + observation_fn=h, + transition_cov=jnp.array([[sigma_q**2]]), + observation_cov=jnp.array([[sigma_r**2]]), + state_dim=1, + obs_dim=1, + ) + + init = GaussianState(mean=jnp.array([5.0]), cov=jnp.array([[1.0]])) + result = ukf_filter(model, observations, initial_state=init) + + filtered = result.filtered_states[:, 0] + correlation = jnp.corrcoef(jnp.stack([filtered, true_states]))[0, 1] + assert float(correlation) > 0.7, f"Correlation {correlation} too low" + assert jnp.all(jnp.isfinite(result.filtered_states)) + + +def test_ukf_2d_model() -> None: + """Test UKF with a 2D state, 2D observation model.""" + + def f(x: Array) -> Array: + return x * 0.99 + + def h(x: Array) -> Array: + return x + + model = NonlinearSSM( + transition_fn=f, + observation_fn=h, + transition_cov=jnp.eye(2) * 0.1, + observation_cov=jnp.eye(2) * 1.0, + state_dim=2, + obs_dim=2, + ) + + key = jax.random.PRNGKey(7) + observations = jax.random.normal(key, (50, 2)) + + init = GaussianState(mean=jnp.array([0.0, 0.0]), cov=jnp.eye(2) * 10.0) + result = ukf_filter(model, observations, initial_state=init, alpha=1.0) + + assert result.filtered_states.shape == (50, 2) + assert jnp.all(jnp.isfinite(result.filtered_states)) + assert jnp.isfinite(result.log_likelihood) + + +# --------------------------------------------------------------------------- +# Configurable parameters +# --------------------------------------------------------------------------- + + +def test_ukf_custom_alpha_beta_kappa() -> None: + """UKF should work with custom sigma-point parameters.""" + model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0) + observations = NILE[:20].reshape(-1, 1) + + result = ukf_filter(model, observations, alpha=0.5, beta=2.0, kappa=1.0) + assert jnp.all(jnp.isfinite(result.filtered_states)) + assert jnp.isfinite(result.log_likelihood) + + +# --------------------------------------------------------------------------- +# JIT compatibility +# --------------------------------------------------------------------------- + + +def test_ukf_filter_jit() -> None: + """Verify ukf_filter works with JIT compilation.""" + model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0) + observations = NILE[:20].reshape(-1, 1) + r1 = ukf_filter(model, observations) + r2 = ukf_filter(model, observations) + np.testing.assert_allclose(r1.log_likelihood, r2.log_likelihood, atol=1e-5) + + +def test_grad_through_ukf() -> None: + """Verify autodiff works through the UKF log-likelihood.""" + observations = NILE[:20].reshape(-1, 1) + + def neg_ll(log_sigma_level: Array, log_sigma_obs: Array) -> Array: + Q = jnp.exp(log_sigma_level) * jnp.eye(1) + R = jnp.exp(log_sigma_obs) * jnp.eye(1) + model = NonlinearSSM( + transition_fn=lambda x: x, + observation_fn=lambda x: x, + transition_cov=Q, + observation_cov=R, + state_dim=1, + obs_dim=1, + ) + result = ukf_filter(model, observations) + return -result.log_likelihood + + grad_fn = jax.grad(neg_ll, argnums=(0, 1)) + g1, g2 = grad_fn(jnp.log(jnp.array(1600.0)), jnp.log(jnp.array(15000.0))) + assert jnp.isfinite(g1) + assert jnp.isfinite(g2) + + +# --------------------------------------------------------------------------- +# Class interface +# --------------------------------------------------------------------------- + + +def test_ukf_class_scan() -> None: + ukf = UnscentedKalmanFilter() + model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0) + observations = NILE[:10].reshape(-1, 1) + result = ukf.scan(model, observations) + assert isinstance(result, FilterResult) + assert result.filtered_states.shape == (10, 1) + + +def test_ukf_class_custom_params() -> None: + ukf = UnscentedKalmanFilter(alpha=0.5, beta=2.0, kappa=1.0) + model = _linear_nonlinear_model(sigma_level=40.0, sigma_obs=120.0) + observations = NILE[:10].reshape(-1, 1) + result = ukf.scan(model, observations) + assert isinstance(result, FilterResult) + assert jnp.isfinite(result.log_likelihood)