diff --git a/examples/plot_examples.py b/examples/plot_examples.py index 6bd16b7d..3f5b0446 100644 --- a/examples/plot_examples.py +++ b/examples/plot_examples.py @@ -36,6 +36,7 @@ upper_bc=1, wind_tol=0.5, engine="scipy", + parallel=False, ) # Plot a horizontal cross section diff --git a/pydda/__init__.py b/pydda/__init__.py index 1b6ca089..d96cd4c3 100644 --- a/pydda/__init__.py +++ b/pydda/__init__.py @@ -12,7 +12,7 @@ from . import constraints from . import io -__version__ = "2.3.0" +__version__ = "2.4.0" print("Welcome to PyDDA %s" % __version__) print("If you are using PyDDA in your publications, please cite:") diff --git a/pydda/cost_functions/_cost_functions_jax.py b/pydda/cost_functions/_cost_functions_jax.py index f8f3cc3c..ec1cef57 100644 --- a/pydda/cost_functions/_cost_functions_jax.py +++ b/pydda/cost_functions/_cost_functions_jax.py @@ -548,7 +548,7 @@ def calculate_background_gradient(u, v, w, weights, u_back, v_back, Cb=0.01): calculate_background_cost, u, v, w, weights, u_back, v_back, Cb ) u_grad, v_grad, w_grad, _, _, _, _ = fun_vjp(1.0) - y = np.stack([u_grad, v_grad, w_grad], axis=0) + y = jnp.stack([u_grad, v_grad, w_grad], axis=0) return y.flatten().copy() diff --git a/pydda/cost_functions/_cost_functions_numpy.py b/pydda/cost_functions/_cost_functions_numpy.py index 754f9316..fc578511 100644 --- a/pydda/cost_functions/_cost_functions_numpy.py +++ b/pydda/cost_functions/_cost_functions_numpy.py @@ -8,7 +8,7 @@ def calculate_radial_vel_cost_function( - vrs, azs, els, u, v, w, wts, rmsVr, weights, coeff=1.0 + vrs, azs, els, u, v, w, wts, rmsVr, weights, coeff=1.0, parallel=False ): """ Calculates the cost function due to difference of the wind field from @@ -54,22 +54,44 @@ def calculate_radial_vel_cost_function( Technol., 26, 2089–2106, https://doi.org/10.1175/2009JTECHA1256.1 """ - J_o = 0 lambda_o = coeff / (rmsVr * rmsVr) + if parallel: + vrs_arr = np.stack(vrs) + els_arr = np.stack(els) + azs_arr = np.stack(azs) + wts_arr = np.stack(wts) + v_ar = ( + np.cos(els_arr) * np.sin(azs_arr) * u[np.newaxis] + + np.cos(els_arr) * np.cos(azs_arr) * v[np.newaxis] + + np.sin(els_arr) * (w[np.newaxis] - np.abs(wts_arr)) + ) + return lambda_o * np.sum(np.square(vrs_arr - v_ar) * weights) + + J_o = 0 for i in range(len(vrs)): v_ar = ( np.cos(els[i]) * np.sin(azs[i]) * u + np.cos(els[i]) * np.cos(azs[i]) * v + np.sin(els[i]) * (w - np.abs(wts[i])) ) - the_weight = weights[i] - J_o += lambda_o * np.sum(np.square(vrs[i] - v_ar) * the_weight) + J_o += lambda_o * np.sum(np.square(vrs[i] - v_ar) * weights[i]) return J_o def calculate_grad_radial_vel( - vrs, els, azs, u, v, w, wts, weights, rmsVr, coeff=1.0, upper_bc=True + vrs, + els, + azs, + u, + v, + w, + wts, + weights, + rmsVr, + coeff=1.0, + upper_bc=True, + parallel=False, ): """ Calculates the gradient of the cost function due to difference of wind @@ -112,29 +134,45 @@ def calculate_grad_radial_vel( # Use zero for all masked values since we don't want to add them into # the cost function - p_x1 = np.zeros(vrs[0].shape) - p_y1 = np.zeros(vrs[0].shape) - p_z1 = np.zeros(vrs[0].shape) lambda_o = coeff / (rmsVr * rmsVr) - for i in range(len(vrs)): + if parallel: + vrs_arr = np.stack(vrs) + els_arr = np.stack(els) + azs_arr = np.stack(azs) + wts_arr = np.stack(wts) v_ar = ( - np.cos(els[i]) * np.sin(azs[i]) * u - + np.cos(els[i]) * np.cos(azs[i]) * v - + np.sin(els[i]) * (w - np.abs(wts[i])) + np.cos(els_arr) * np.sin(azs_arr) * u[np.newaxis] + + np.cos(els_arr) * np.cos(azs_arr) * v[np.newaxis] + + np.sin(els_arr) * (w[np.newaxis] - np.abs(wts_arr)) ) + residual = 2 * (v_ar - vrs_arr) * lambda_o + p_x1 = np.sum(residual * np.cos(els_arr) * np.sin(azs_arr) * weights, axis=0) + p_y1 = np.sum(residual * np.cos(els_arr) * np.cos(azs_arr) * weights, axis=0) + p_z1 = np.sum(residual * np.sin(els_arr) * weights, axis=0) + else: + p_x1 = np.zeros(vrs[0].shape) + p_y1 = np.zeros(vrs[0].shape) + p_z1 = np.zeros(vrs[0].shape) + + for i in range(len(vrs)): + v_ar = ( + np.cos(els[i]) * np.sin(azs[i]) * u + + np.cos(els[i]) * np.cos(azs[i]) * v + + np.sin(els[i]) * (w - np.abs(wts[i])) + ) - x_grad = ( - 2 * (v_ar - vrs[i]) * np.cos(els[i]) * np.sin(azs[i]) * weights[i] - ) * lambda_o - y_grad = ( - 2 * (v_ar - vrs[i]) * np.cos(els[i]) * np.cos(azs[i]) * weights[i] - ) * lambda_o - z_grad = (2 * (v_ar - vrs[i]) * np.sin(els[i]) * weights[i]) * lambda_o - - p_x1 += x_grad - p_y1 += y_grad - p_z1 += z_grad + x_grad = ( + 2 * (v_ar - vrs[i]) * np.cos(els[i]) * np.sin(azs[i]) * weights[i] + ) * lambda_o + y_grad = ( + 2 * (v_ar - vrs[i]) * np.cos(els[i]) * np.cos(azs[i]) * weights[i] + ) * lambda_o + z_grad = (2 * (v_ar - vrs[i]) * np.sin(els[i]) * weights[i]) * lambda_o + + p_x1 += x_grad + p_y1 += y_grad + p_z1 += z_grad # Impermeability condition p_z1[0, :, :] = 0 diff --git a/pydda/cost_functions/cost_functions.py b/pydda/cost_functions/cost_functions.py index 1ab64ff1..198e09b7 100644 --- a/pydda/cost_functions/cost_functions.py +++ b/pydda/cost_functions/cost_functions.py @@ -1,4 +1,5 @@ import numpy as np +from concurrent.futures import ThreadPoolExecutor # Adding jax import statements try: @@ -183,6 +184,7 @@ def J_function(winds, parameters): rmsVr=parameters.rmsVr, weights=parameters.weights, coeff=parameters.Co, + parallel=parameters.parallel, ) # print("apples Jvel", Jvel) @@ -510,96 +512,208 @@ def grad_J(winds, parameters): parameters.grid_shape[2], ), ) - grad = _cost_functions_numpy.calculate_grad_radial_vel( - parameters.vrs, - parameters.els, - parameters.azs, - winds[0], - winds[1], - winds[2], - parameters.wts, - parameters.weights, - parameters.rmsVr, - coeff=parameters.Co, - upper_bc=parameters.upper_bc, - ) - - if parameters.Cm > 0: - grad += _cost_functions_numpy.calculate_mass_continuity_gradient( - winds[0], - winds[1], - winds[2], - parameters.z, - parameters.dx, - parameters.dy, - parameters.dz, - coeff=parameters.Cm, - upper_bc=parameters.upper_bc, - ) - - if parameters.Cx > 0 or parameters.Cy > 0 or parameters.Cz > 0: - grad += _cost_functions_numpy.calculate_smoothness_gradient( - winds[0], - winds[1], - winds[2], - parameters.dx, - parameters.dy, - parameters.dz, - Cx=parameters.Cx, - Cy=parameters.Cy, - Cz=parameters.Cz, - upper_bc=parameters.upper_bc, - ) - - if parameters.Cb > 0: - grad += _cost_functions_numpy.calculate_background_gradient( - winds[0], - winds[1], - winds[2], - parameters.bg_weights, - parameters.u_back, - parameters.v_back, - parameters.Cb, - ) - - if parameters.Cv > 0: - grad += _cost_functions_numpy.calculate_vertical_vorticity_gradient( + if parameters.parallel: + futures = [] + with ThreadPoolExecutor() as pool: + futures.append( + pool.submit( + _cost_functions_numpy.calculate_grad_radial_vel, + parameters.vrs, + parameters.els, + parameters.azs, + winds[0], + winds[1], + winds[2], + parameters.wts, + parameters.weights, + parameters.rmsVr, + parameters.Co, + parameters.upper_bc, + True, + ) + ) + if parameters.Cm > 0: + futures.append( + pool.submit( + _cost_functions_numpy.calculate_mass_continuity_gradient, + winds[0], + winds[1], + winds[2], + parameters.z, + parameters.dx, + parameters.dy, + parameters.dz, + parameters.Cm, + 1, + parameters.upper_bc, + ) + ) + if parameters.Cx > 0 or parameters.Cy > 0 or parameters.Cz > 0: + futures.append( + pool.submit( + _cost_functions_numpy.calculate_smoothness_gradient, + winds[0], + winds[1], + winds[2], + parameters.dx, + parameters.dy, + parameters.dz, + parameters.Cx, + parameters.Cy, + parameters.Cz, + parameters.upper_bc, + ) + ) + if parameters.Cb > 0: + futures.append( + pool.submit( + _cost_functions_numpy.calculate_background_gradient, + winds[0], + winds[1], + winds[2], + parameters.bg_weights, + parameters.u_back, + parameters.v_back, + parameters.Cb, + ) + ) + if parameters.Cv > 0: + futures.append( + pool.submit( + _cost_functions_numpy.calculate_vertical_vorticity_gradient, + winds[0], + winds[1], + winds[2], + parameters.dx, + parameters.dy, + parameters.dz, + parameters.Ut, + parameters.Vt, + parameters.Cv, + parameters.upper_bc, + ) + ) + if parameters.Cmod > 0: + futures.append( + pool.submit( + _cost_functions_numpy.calculate_model_gradient, + winds[0], + winds[1], + winds[2], + parameters.model_weights, + parameters.u_model, + parameters.v_model, + parameters.w_model, + parameters.Cmod, + ) + ) + if parameters.Cpoint > 0: + futures.append( + pool.submit( + _cost_functions_numpy.calculate_point_gradient, + winds[0], + winds[1], + parameters.x, + parameters.y, + parameters.z, + parameters.point_list, + parameters.Cpoint, + parameters.roi, + ) + ) + grad = sum(f.result() for f in futures) + else: + grad = _cost_functions_numpy.calculate_grad_radial_vel( + parameters.vrs, + parameters.els, + parameters.azs, winds[0], winds[1], winds[2], - parameters.dx, - parameters.dy, - parameters.dz, - parameters.Ut, - parameters.Vt, - coeff=parameters.Cv, + parameters.wts, + parameters.weights, + parameters.rmsVr, + coeff=parameters.Co, upper_bc=parameters.upper_bc, ) - if parameters.Cmod > 0: - grad += _cost_functions_numpy.calculate_model_gradient( - winds[0], - winds[1], - winds[2], - parameters.model_weights, - parameters.u_model, - parameters.v_model, - parameters.w_model, - coeff=parameters.Cmod, - ) + if parameters.Cm > 0: + grad += _cost_functions_numpy.calculate_mass_continuity_gradient( + winds[0], + winds[1], + winds[2], + parameters.z, + parameters.dx, + parameters.dy, + parameters.dz, + coeff=parameters.Cm, + upper_bc=parameters.upper_bc, + ) + + if parameters.Cx > 0 or parameters.Cy > 0 or parameters.Cz > 0: + grad += _cost_functions_numpy.calculate_smoothness_gradient( + winds[0], + winds[1], + winds[2], + parameters.dx, + parameters.dy, + parameters.dz, + Cx=parameters.Cx, + Cy=parameters.Cy, + Cz=parameters.Cz, + upper_bc=parameters.upper_bc, + ) + + if parameters.Cb > 0: + grad += _cost_functions_numpy.calculate_background_gradient( + winds[0], + winds[1], + winds[2], + parameters.bg_weights, + parameters.u_back, + parameters.v_back, + parameters.Cb, + ) + + if parameters.Cv > 0: + grad += _cost_functions_numpy.calculate_vertical_vorticity_gradient( + winds[0], + winds[1], + winds[2], + parameters.dx, + parameters.dy, + parameters.dz, + parameters.Ut, + parameters.Vt, + coeff=parameters.Cv, + upper_bc=parameters.upper_bc, + ) + + if parameters.Cmod > 0: + grad += _cost_functions_numpy.calculate_model_gradient( + winds[0], + winds[1], + winds[2], + parameters.model_weights, + parameters.u_model, + parameters.v_model, + parameters.w_model, + coeff=parameters.Cmod, + ) + + if parameters.Cpoint > 0: + grad += _cost_functions_numpy.calculate_point_gradient( + winds[0], + winds[1], + parameters.x, + parameters.y, + parameters.z, + parameters.point_list, + Cp=parameters.Cpoint, + roi=parameters.roi, + upper_bc=parameters.upper_bc, + ) - if parameters.Cpoint > 0: - grad += _cost_functions_numpy.calculate_point_gradient( - winds[0], - winds[1], - parameters.x, - parameters.y, - parameters.z, - parameters.point_list, - Cp=parameters.Cpoint, - roi=parameters.roi, - upper_bc=parameters.upper_bc, - ) # Let's see if we need to enforce strong boundary conditions if parameters.const_boundary_cond is True: grad = np.reshape( diff --git a/pydda/retrieval/wind_retrieve.py b/pydda/retrieval/wind_retrieve.py index 5fb50766..5a5a0a2f 100644 --- a/pydda/retrieval/wind_retrieve.py +++ b/pydda/retrieval/wind_retrieve.py @@ -183,6 +183,7 @@ def __init__(self): self.gtol = 1e-2 self.Jveltol = 100.0 self.const_boundary_cond = False + self.parallel = False def _get_dd_wind_field_scipy( @@ -231,6 +232,7 @@ def _get_dd_wind_field_scipy( tolerance=1e-8, const_boundary_cond=False, max_wind_mag=100.0, + parallel=True, ): global _wcurrmax global _wprevmax @@ -500,6 +502,15 @@ def _get_dd_wind_field_scipy( parameters.bg_weights[~np.isfinite(parameters.bg_weights)] = 0 parameters.weights[parameters.weights > 0] = 1 parameters.bg_weights[parameters.bg_weights > 0] = 1 + + # Zero out bg_weights at height levels where the interpolated background + # is NaN (i.e. outside the sounding's vertical range). Also replace NaN + # in u_back/v_back with 0 so those levels don't corrupt cost function + # arithmetic even though they carry zero weight. + nan_bg_levels = ~np.isfinite(parameters.u_back) | ~np.isfinite(parameters.v_back) + parameters.bg_weights[nan_bg_levels] = 0 + parameters.u_back = np.nan_to_num(parameters.u_back) + parameters.v_back = np.nan_to_num(parameters.v_back) sum_Vr = np.nansum(np.square(parameters.vrs * parameters.weights)) parameters.rmsVr = np.sqrt(np.nansum(sum_Vr) / np.nansum(parameters.weights)) @@ -559,6 +570,7 @@ def _get_dd_wind_field_scipy( parameters.upper_bc = upper_bc parameters.points = points parameters.point_list = points + parameters.parallel = parallel _wprevmax = np.zeros(parameters.grid_shape) _wcurrmax = np.zeros(parameters.grid_shape) iterations = 0 @@ -622,14 +634,21 @@ def loss_and_gradient(x): {"winds": max_wind_mag * jnp.ones(winds.shape)}, ) winds = jnp.array(winds) + # JIT-compile the cost function explicitly so the compilation + # delay is isolated and visible before the solver loop starts. + loss_and_gradient = jax.jit(loss_and_gradient) + print("Compiling JAX cost functions...") + loss_and_gradient({"winds": winds}) + print("Compilation complete.") solver = jaxopt.LBFGSB( loss_and_gradient, True, has_aux=False, maxiter=max_iterations, tol=tolerance, - jit=True, + jit=False, implicit_diff=False, + verbose=True, ) winds = {"winds": winds} winds, state = solver.run(winds, bounds=bounds) @@ -1430,6 +1449,11 @@ def get_dd_wind_field( Tolerance for :math:`L_{2}` norm of gradient before stopping. max_wind_mag: float Constrain the optimization to have :math:`|u|`, :math:`|v|`, and :math:`|w| < x` m/s. + parallel: bool + If True, enables parallelized cost and gradient computations for the scipy engine. + This vectorizes the radar loop in the radial velocity cost/gradient functions and + computes independent constraint gradients concurrently using a thread pool. + Default is False. Returns ======= diff --git a/pydda/tests/test_cost_functions.py b/pydda/tests/test_cost_functions.py index 8c6c6ed4..6614c216 100644 --- a/pydda/tests/test_cost_functions.py +++ b/pydda/tests/test_cost_functions.py @@ -22,36 +22,65 @@ JAX_AVAILABLE = False -def test_calculate_rad_velocity_cost(): +def _make_radvel_inputs(): + """Build shared test inputs for radial velocity cost/gradient tests.""" Grid = pyart.testing.make_empty_grid( (20, 20, 20), ((0, 10000), (-10000, 10000), (-10000, 10000)) ) - - # a zero field fdata3 = np.zeros((20, 20, 20)) Grid.fields["zero_field"] = {"data": fdata3, "units": "m/s"} Grid = pydda.io.read_from_pyart_grid(Grid) - vel_field = "zero_field" - vrs = [np.ma.array(Grid[vel_field].values)] + vrs = [np.ma.array(Grid["zero_field"].values)] azs = [Grid["AZ"].values] els = [Grid["EL"].values] + wts = [np.ma.zeros((20, 20, 20))] + weights = np.ones((1, 20, 20, 20)) + return vrs, azs, els, wts, weights + + +def test_calculate_rad_velocity_cost(): + vrs, azs, els, wts, weights = _make_radvel_inputs() u = np.zeros((20, 20, 20)) v = np.zeros((20, 20, 20)) w = np.zeros((20, 20, 20)) rmsVr = 1.0 - wts = [np.ma.zeros((20, 20, 20))] - weights = [np.ones((20, 20, 20))] cost = pydda.cost_functions.calculate_radial_vel_cost_function( vrs, azs, els, u, v, w, wts, rmsVr, weights ) grad = pydda.cost_functions.calculate_grad_radial_vel( - vrs, azs, els, u, v, w, wts, weights, rmsVr + vrs, els, azs, u, v, w, wts, weights, rmsVr ) assert cost == 0 assert np.all(grad == 0) +def test_calculate_rad_velocity_cost_parallel(): + """Vectorized (parallel=True) radial velocity cost and gradient match serial results.""" + vrs, azs, els, wts, weights = _make_radvel_inputs() + rng = np.random.default_rng(42) + u = rng.random((20, 20, 20)) + v = rng.random((20, 20, 20)) + w = rng.random((20, 20, 20)) + rmsVr = 1.0 + + serial_cost = pydda.cost_functions.calculate_radial_vel_cost_function( + vrs, azs, els, u, v, w, wts, rmsVr, weights, parallel=False + ) + parallel_cost = pydda.cost_functions.calculate_radial_vel_cost_function( + vrs, azs, els, u, v, w, wts, rmsVr, weights, parallel=True + ) + np.testing.assert_allclose(parallel_cost, serial_cost) + + serial_grad = pydda.cost_functions.calculate_grad_radial_vel( + vrs, els, azs, u, v, w, wts, weights, rmsVr, parallel=False + ) + parallel_grad = pydda.cost_functions.calculate_grad_radial_vel( + vrs, els, azs, u, v, w, wts, weights, rmsVr, parallel=True + ) + np.testing.assert_allclose(parallel_grad, serial_grad) + + @pytest.mark.skipif(not JAX_AVAILABLE, reason="Jax not installed") def test_calculate_rad_velocity_cost_jax(): """Test with a zero velocity field radar""" diff --git a/pydda/tests/test_retrieval.py b/pydda/tests/test_retrieval.py index f537bb6c..f50b44ca 100644 --- a/pydda/tests/test_retrieval.py +++ b/pydda/tests/test_retrieval.py @@ -189,6 +189,59 @@ def test_twpice_case(): assert w_max > 5 +def test_twpice_case_parallel(): + """TWP-ICE case with parallel=True should produce physically consistent results + and match the serial retrieval.""" + Grid0 = pydda.io.read_grid(pydda.tests.EXAMPLE_RADAR0) + Grid1 = pydda.io.read_grid(pydda.tests.EXAMPLE_RADAR1) + sounding = pyart.io.read_arm_sonde(pydda.tests.SOUNDING_PATH) + + Grid0 = pydda.initialization.make_wind_field_from_profile( + Grid0, sounding[1], vel_field="corrected_velocity" + ) + + common_kwargs = dict( + Co=100, + Cm=1500.0, + max_iterations=20, + Cz=0, + Cmod=0.0, + vel_name="corrected_velocity", + wind_tol=0.1, + refl_field="reflectivity", + frz=5000.0, + engine="scipy", + mask_outside_opt=True, + upper_bc=1, + ) + + Grids_serial, _ = pydda.retrieval.get_dd_wind_field( + [deepcopy(Grid0), deepcopy(Grid1)], **common_kwargs, parallel=False + ) + Grids_parallel, _ = pydda.retrieval.get_dd_wind_field( + [deepcopy(Grid0), deepcopy(Grid1)], **common_kwargs, parallel=True + ) + + # Physical sanity: mean flow to the southeast, updrafts present + u_mean = np.nanmean(Grids_parallel[0]["u"].values) + v_mean = np.nanmean(Grids_parallel[0]["v"].values) + w_max = np.nanmax(Grids_parallel[0]["w"].values) + assert u_mean > 0 + assert v_mean < 0 + assert w_max > 5 + + # Numerical equivalence with serial + np.testing.assert_allclose( + Grids_parallel[0]["u"].values, Grids_serial[0]["u"].values, rtol=1e-5 + ) + np.testing.assert_allclose( + Grids_parallel[0]["v"].values, Grids_serial[0]["v"].values, rtol=1e-5 + ) + np.testing.assert_allclose( + Grids_parallel[0]["w"].values, Grids_serial[0]["w"].values, rtol=1e-5 + ) + + def test_smoothing(): """A field of random numbers from 0 to 1 should smooth out to near 0.5""" diff --git a/setup.py b/setup.py index afa6d956..0252d678 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ LICENSE = "BSD" PLATFORMS = "Linux, Windows, OSX" MAJOR = 2 -MINOR = 3 +MINOR = 4 MICRO = 0 # SCRIPTS = glob.glob('scripts/*')