Skip to content

Fixes in the JAX engine, parallel support for scipy engine, and background weight fixes#156

Merged
rcjackson merged 6 commits intoopenradar:mainfrom
rcjackson:parallel_scipy
Apr 8, 2026
Merged

Fixes in the JAX engine, parallel support for scipy engine, and background weight fixes#156
rcjackson merged 6 commits intoopenradar:mainfrom
rcjackson:parallel_scipy

Conversation

@rcjackson
Copy link
Copy Markdown
Collaborator

This PR has three additions:
* The JAX engine now does the JIT compilation of the cost function before the optimization loop. In addition, the output is now verbose, making it easier to use the JAX engine and diagnose issues in the retrieval.
* The SciPy engine now has the option to run the individual cost functions in parallel during calculation.
* The background weights were still set to 1, even in regions where the sounding data are unavailable. This has been fixed.

Robert Jackson and others added 6 commits April 8, 2026 09:40
…dients

Adds a `parallel=True` flag to `get_dd_wind_field` (scipy engine) that
enables two optimizations per cost function iteration:

1. Vectorized radar loop: stacks all radar arrays into (n_radars, nz, ny, nx)
   tensors and computes the radial velocity cost and gradient in a single
   NumPy broadcast instead of a sequential Python loop.

2. Threaded constraint gradients: submits each active constraint gradient
   (radial velocity, mass continuity, smoothness, background, vorticity,
   model, point) to a ThreadPoolExecutor concurrently, then sums the results.
   NumPy releases the GIL for array ops so threads provide real parallelism.

Unit tests extended to cover both serial and parallel paths:
- `test_calculate_rad_velocity_cost_parallel`: asserts numerical equivalence
  of cost and gradient for serial vs. vectorized implementations.
- `test_twpice_case_parallel`: end-to-end TWP-ICE retrieval with parallel=True,
  checking physical sanity and close numerical match with serial results.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
With jit=True, jaxopt compiled the cost function silently during the
first solver.run() call, causing a long unexplained pause. Fix by
manually jit-compiling loss_and_gradient with jax.jit and triggering
a warm-up call before the solver starts. This surfaces the delay with
clear print messages and passes jit=False to jaxopt since the function
is already compiled.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…round

When the sounding doesn't cover the full vertical extent of the grid,
interp1d (bounds_error=False) returns NaN for out-of-range levels.
Previously bg_weights was never zeroed at those levels, so the background
constraint was applied using NaN wind values, corrupting the cost function.

Fix by masking bg_weights wherever u_back or v_back is NaN after
interpolation, and replacing NaN values in u_back/v_back with 0 so
they don't propagate through cost function arithmetic.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@codecov
Copy link
Copy Markdown

codecov bot commented Apr 8, 2026

Codecov Report

❌ Patch coverage is 91.81818% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.21%. Comparing base (7685760) to head (fad740f).
⚠️ Report is 8 commits behind head on main.

Files with missing lines Patch % Lines
pydda/cost_functions/cost_functions.py 74.19% 8 Missing ⚠️
setup.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #156      +/-   ##
==========================================
+ Coverage   68.85%   69.21%   +0.35%     
==========================================
  Files          32       32              
  Lines        5141     5217      +76     
==========================================
+ Hits         3540     3611      +71     
- Misses       1601     1606       +5     
Flag Coverage Δ
unittests 69.21% <91.81%> (+0.35%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@rcjackson rcjackson merged commit e28bda8 into openradar:main Apr 8, 2026
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant