
Fix Stormcast example #1490

Merged
pzharrington merged 8 commits into NVIDIA:main from albertocarpentieri:acarpentieri/stormcast
Mar 20, 2026

Conversation

@albertocarpentieri
Contributor

Fix Stormcast example

Description

  • Added a safety check in examples/weather/stormcast/utils/parallel.py to skip nested_scatter when use_shard_tensor=False, returning tensors unchanged instead of sharding unnecessarily.
  • Refactored UNet wiring in examples/weather/stormcast/utils/nn.py and examples/weather/stormcast/utils/trainer.py to rely on model.hyperparameters (rather than duplicated top-level fields) and to pass use_apex_gn explicitly.
  • Updated StormCast configs to align with the new hyperparameter layout.
  • Expanded examples/weather/stormcast/test_training.py coverage to parametrize num_invariant_channels and to ensure diffusion condition lists include "invariant" only when invariants are actually provided.
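The safety check described in the first bullet can be sketched as follows. This is a minimal illustration of the described behavior, not the actual code in utils/parallel.py; the real nested_scatter is a sharding routine, stubbed out here so the guard's no-op path is visible.

```python
def nested_scatter(tensor):
    # Placeholder standing in for the real sharding routine in
    # utils/parallel.py, which requires a configured device mesh.
    raise RuntimeError("sharding unavailable without a device mesh")

def distribute_tensor(tensor, use_shard_tensor=True):
    # The fix: when sharding is disabled, return the input unchanged
    # instead of calling nested_scatter unconditionally.
    if not use_shard_tensor:
        return tensor
    return nested_scatter(tensor)
```

With use_shard_tensor=False, the input object is returned as-is and nested_scatter is never reached.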

Checklist

root added 2 commits March 11, 2026 09:13
Signed-off-by: root <root@pool0-01605.cm.cluster>
… into hyperparameters model section

Signed-off-by: root <root@pool0-01605.cm.cluster>
@greptile-apps
Contributor

greptile-apps Bot commented Mar 11, 2026

Greptile Summary

This PR fixes several issues in the StormCast training example. The changes consolidate model hyperparameter wiring to use model.hyperparameters (removing duplicated top-level fields like spatial_pos_embed, channel_mult, attn_resolutions), fix a crash in ParallelHelper.distribute_tensor when use_shard_tensor=False, explicitly thread use_apex_gn through the UNet construction path, align configs with the new layout, and improve test coverage by parametrizing num_invariant_channels.
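The hyperparameter consolidation can be pictured with a minimal sketch. The function name follows the summary, but the signature and the plain-dict return value are illustrative assumptions, not the real get_preconditioned_unet in utils/nn.py:

```python
def get_preconditioned_unet(model_cfg, use_apex_gn=False, **model_kwargs):
    # Sketch only: architecture knobs such as channel_mult and
    # attn_resolutions arrive via **model_kwargs (populated from the
    # config's hyperparameters block) rather than duplicated top-level
    # fields, while use_apex_gn is threaded through explicitly.
    config = {"use_apex_gn": use_apex_gn, **model_kwargs}
    return config  # stand-in for constructing the actual UNet

# Usage: unpack the hyperparameters block straight into the call.
hyperparameters = {"channel_mult": [1, 2, 2], "attn_resolutions": [16]}
unet_cfg = get_preconditioned_unet({}, use_apex_gn=True, **hyperparameters)
```

The point of the refactor is that adding a new architecture knob then only requires touching the config's hyperparameters block, not every call site.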

Key changes:

  • utils/parallel.py: Fixes a bug where distribute_tensor called nested_scatter unconditionally — now correctly returns the tensor unchanged when use_shard_tensor=False.
  • utils/nn.py: Removes spatial_embedding, channel_mult, and attn_resolutions as explicit parameters (these are now passed via **model_kwargs from the hyperparameters config block) and adds use_apex_gn as an explicit argument.
  • utils/trainer.py: Wires use_apex_gn into get_preconditioned_unet and passes hyperparameters through cleanly. Adds a log when invariant conditions are configured but the dataset provides none — though this uses .info instead of .warning, which may cause the misconfiguration to go unnoticed.
  • test_training.py: Parametrizes num_invariant_channels and correctly gates "invariant" on the condition list based on whether invariants are actually provided.
  • Config files: diffusion.yaml moves spatial_pos_embed: True into hyperparameters.additive_pos_embed: True; stormcast.yaml removes the now-obsolete top-level spatial_pos_embed field.
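The trainer.py logging concern raised above could be addressed along these lines. The function name and shapes here are hypothetical, sketching the reviewer's suggestion (warn rather than .info, and prune the stale entry so later "Model conditions" logs are accurate):

```python
import logging

logger = logging.getLogger("stormcast.trainer")

def reconcile_conditions(condition_list, invariants):
    # Sketch: if "invariant" is configured but the dataset supplies no
    # invariants, emit a warning (not .info) and drop the entry so a
    # subsequent "Model conditions" log does not list it as active.
    if "invariant" in condition_list and invariants is None:
        logger.warning(
            "Condition 'invariant' configured but dataset provides no "
            "invariants; removing it from the condition list."
        )
        condition_list = [c for c in condition_list if c != "invariant"]
    return condition_list
```

When invariants are present the list passes through unchanged, so the well-configured path is unaffected.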

Important Files Changed

  • examples/weather/stormcast/utils/parallel.py: Bug fix: distribute_tensor now correctly skips nested_scatter when use_shard_tensor=False, returning the tensor unchanged instead of attempting to shard it unnecessarily.
  • examples/weather/stormcast/utils/nn.py: Refactored get_preconditioned_unet to remove hardcoded spatial_embedding, channel_mult, and attn_resolutions defaults (now passed via **model_kwargs from the hyperparameters config) and added an explicit use_apex_gn parameter. Changes look correct.
  • examples/weather/stormcast/utils/trainer.py: Updated UNet wiring to use model_cfg.hyperparameters and to pass use_apex_gn explicitly. Added a log when invariant conditions are configured but the dataset provides no invariants; however, this uses .info instead of .warning, risking the misconfiguration being missed. The condition_list is also not cleaned up after the message, so a subsequent "Model conditions" log still lists "invariant" as active.
  • examples/weather/stormcast/test_training.py: Added num_invariant_channels parametrization to test_model_types, conditionally appending "invariant" to diffusion_conditions only when invariants are present. Condition lists are now set to conservative defaults per model type, with invariant added separately.
  • examples/weather/stormcast/datasets/mock.py: Added a num_invariant_channels parameter and implemented get_invariants(), returning a deterministic random array when channels > 0 and None otherwise. The implementation is clean and consistent with the dataset interface.
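The mock-dataset change in the last row can be sketched as below. Class name, dimensions, and the fixed seed are illustrative assumptions; only the num_invariant_channels parameter and the get_invariants() contract come from the description above:

```python
import numpy as np

class MockDataset:
    # Sketch of the described mock: num_invariant_channels controls
    # whether get_invariants() returns an array or None.
    def __init__(self, num_invariant_channels=0, height=16, width=16):
        self.num_invariant_channels = num_invariant_channels
        self.height = height
        self.width = width

    def get_invariants(self):
        # None when no invariant channels are requested, matching the
        # "only include 'invariant' when invariants are provided" test.
        if self.num_invariant_channels == 0:
            return None
        # Fixed seed makes the "random" invariants deterministic
        # across calls and across dataset instances.
        rng = np.random.default_rng(seed=0)
        return rng.standard_normal(
            (self.num_invariant_channels, self.height, self.width)
        ).astype(np.float32)
```

A parametrized test can then instantiate the dataset with 0 and nonzero channel counts and gate the "invariant" condition on whether get_invariants() returns None.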

Last reviewed commit: 558a869

Comment thread examples/weather/stormcast/utils/trainer.py
root and others added 2 commits March 12, 2026 02:05
Signed-off-by: root <root@pool0-01102.cm.cluster>
Signed-off-by: Alberto Carpentieri <acarpentieri@cw-dfw-cs-001-vscode-01.cm.cluster>
Comment thread examples/weather/stormcast/utils/trainer.py Outdated
root and others added 4 commits March 18, 2026 03:18
…y and updated readme

Signed-off-by: root <root@pool0-01763.cm.cluster>
Signed-off-by: root <root@pool0-01780.cm.cluster>
Signed-off-by: root <root@pool0-01780.cm.cluster>
Collaborator

@pzharrington pzharrington left a comment


LGTM!

@pzharrington
Collaborator

/bossom-ci

@pzharrington pzharrington enabled auto-merge March 20, 2026 03:46
@pzharrington
Collaborator

/blossom-ci

@pzharrington pzharrington added this pull request to the merge queue Mar 20, 2026
Merged via the queue into NVIDIA:main with commit 7e87e7b Mar 20, 2026
4 checks passed
nbren12 pushed a commit to nbren12/modulus that referenced this pull request Mar 24, 2026
* add check in parallel to avoid sharding if not required

Signed-off-by: root <root@pool0-01605.cm.cluster>

* add gn apex support, invariants testing and move hyperparams for unet into hyperparameters model section

Signed-off-by: root <root@pool0-01605.cm.cluster>

* add condition to parallel helper

Signed-off-by: root <root@pool0-01102.cm.cluster>

* add training sigma bin loss

Signed-off-by: Alberto Carpentieri <acarpentieri@cw-dfw-cs-001-vscode-01.cm.cluster>

* modify sigma bin updating strategy and move regression loss to loss.py and updated readme

Signed-off-by: root <root@pool0-01763.cm.cluster>

* fix merge

Signed-off-by: root <root@pool0-01780.cm.cluster>

* move SigmaBinTracker in loss.py to keep trainer cleaner

Signed-off-by: root <root@pool0-01780.cm.cluster>

---------

Signed-off-by: root <root@pool0-01605.cm.cluster>
Signed-off-by: root <root@pool0-01102.cm.cluster>
Signed-off-by: Alberto Carpentieri <acarpentieri@cw-dfw-cs-001-vscode-01.cm.cluster>
Signed-off-by: root <root@pool0-01763.cm.cluster>
Signed-off-by: root <root@pool0-01780.cm.cluster>
Co-authored-by: root <root@pool0-01605.cm.cluster>
Co-authored-by: root <root@pool0-01102.cm.cluster>
Co-authored-by: Alberto Carpentieri <acarpentieri@cw-dfw-cs-001-vscode-01.cm.cluster>
Co-authored-by: root <root@pool0-01763.cm.cluster>
Co-authored-by: root <root@pool0-01780.cm.cluster>