Add TorchVision classification models: SqueezeNet, DenseNet, ShuffleN… #1550
alinpahontu2912 wants to merge 3 commits into dotnet:main
…etV2, EfficientNet, MNASNet

Add 5 new model families (21 variants) ported from PyTorch torchvision:
- SqueezeNet 1.0/1.1
- DenseNet-121/161/169/201
- ShuffleNet V2 x0.5/x1.0/x1.5/x2.0
- EfficientNet B0-B7, EfficientNet V2 S/M/L
- MNASNet 0.5/0.75/1.0/1.3

All models support pre-trained weight loading via weights_file/skipfc parameters with state_dict keys matching PyTorch exactly. Tests added for all new model families.

TODO: The following torchvision classification models are not yet implemented:
- RegNet (Y/X variants)
- ConvNeXt (Tiny, Small, Base, Large)
- Vision Transformer / ViT (B-16, B-32, L-16, L-32, H-14)
- Swin Transformer (T, S, B)
- Swin Transformer V2 (T, S, B)
- MaxViT (T)

Closes dotnet#586

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Fix EfficientNetB0 and EfficientNetV2S named_children order to match field declaration order (features, avgpool, classifier)
- Fix DenseNet121 state_dict count from 242 to 727 to reflect proper registration of all dense layers via register_module

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
EfficientNetV2S (782 state dict entries) is the largest non-skipped model and causes the test host process to crash from memory pressure when run alongside all other model tests. Skip it, following the same pattern used for other large EfficientNet variants (B1-B7, V2M, V2L).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
Adds additional TorchVision classification model families to TorchSharp, aligning module structure/state_dict keys with torchvision to enable loading pre-trained weights from exported PyTorch state_dicts.
Changes:
- Introduces 5 new model families (SqueezeNet, DenseNet, ShuffleNetV2, EfficientNet V1/V2, MNASNet) with torchvision-style factory methods and optional weight loading (weights_file/skipfc).
- Adds/extends unit tests to validate basic model structure (named children), state_dict sizes for select variants, and forward output shapes.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| test/TorchSharpTest/TestTorchVision.cs | Adds tests covering the new model families (some heavy variants explicitly skipped for CI stability). |
| src/TorchVision/models/SqueezeNet.cs | Implements SqueezeNet 1.0/1.1 modules + factory methods with optional weight loading. |
| src/TorchVision/models/DenseNet.cs | Implements DenseNet-121/161/169/201 modules + factory methods with optional weight loading. |
| src/TorchVision/models/ShuffleNetV2.cs | Implements ShuffleNet V2 x0.5/x1.0/x1.5/x2.0 modules + factory methods with optional weight loading. |
| src/TorchVision/models/EfficientNet.cs | Implements EfficientNet B0–B7 and EfficientNetV2 S/M/L modules + factory methods with optional weight loading. |
| src/TorchVision/models/MNASNet.cs | Implements MNASNet 0.5/0.75/1.0/1.3 modules + factory methods with optional weight loading. |
| throw new ArgumentOutOfRangeException($"stride should be 1 or 2 instead of {stride}"); | ||
| if (kernel_size != 3 && kernel_size != 5) | ||
| throw new ArgumentOutOfRangeException($"kernel_size should be 3 or 5 instead of {kernel_size}"); |
The ArgumentOutOfRangeException is being constructed with a single interpolated string. In .NET that overload treats the string as paramName, not the error message, which produces misleading exception details. Use an overload that specifies paramName (e.g., nameof(stride) / nameof(kernel_size)) and a separate message/value.
Suggested change:
| - throw new ArgumentOutOfRangeException($"stride should be 1 or 2 instead of {stride}"); |
| - if (kernel_size != 3 && kernel_size != 5) |
| - throw new ArgumentOutOfRangeException($"kernel_size should be 3 or 5 instead of {kernel_size}"); |
| + throw new ArgumentOutOfRangeException(nameof(stride), $"stride should be 1 or 2 instead of {stride}."); |
| + if (kernel_size != 3 && kernel_size != 5) |
| + throw new ArgumentOutOfRangeException(nameof(kernel_size), $"kernel_size should be 3 or 5 instead of {kernel_size}."); |
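To make the overload difference concrete, here is a minimal standalone sketch (not code from this PR). The class `ArgumentCheckDemo` and its two helper methods are hypothetical names used only for illustration. It shows that the single-string constructor stores the whole sentence in `ParamName`, while the three-argument overload keeps the parameter name, the offending value, and the message separate.

```csharp
using System;

// Hypothetical demo class; not part of TorchSharp or this PR.
public static class ArgumentCheckDemo
{
    // Pattern flagged in the review: the single string is bound to paramName.
    public static ArgumentOutOfRangeException SingleString(int stride) =>
        new ArgumentOutOfRangeException($"stride should be 1 or 2 instead of {stride}");

    // Suggested pattern: paramName, actualValue, and message passed separately.
    public static ArgumentOutOfRangeException NamedWithValue(int stride) =>
        new ArgumentOutOfRangeException(nameof(stride), stride, "stride should be 1 or 2.");

    public static void Main()
    {
        var bad = SingleString(3);
        // The whole sentence ends up as the parameter name.
        Console.WriteLine(bad.ParamName);    // stride should be 1 or 2 instead of 3

        var good = NamedWithValue(3);
        Console.WriteLine(good.ParamName);   // stride
        Console.WriteLine(good.ActualValue); // 3
    }
}
```

The exact text of `Message` differs between .NET runtimes, so the sketch inspects only `ParamName` and `ActualValue`.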
| private static Module<Tensor, Tensor> _stack(long in_ch, long out_ch, long kernel_size, long stride, long exp_factor, int repeats, double bn_momentum) | ||
| { | ||
| if (repeats < 1) | ||
| throw new ArgumentOutOfRangeException($"repeats should be >= 1, instead got {repeats}"); |
Same issue as above: ArgumentOutOfRangeException is constructed with a single string, which is treated as the parameter name. Use an overload that passes paramName (e.g., nameof(repeats)) and a proper message (and ideally the actual value).
Suggested change:
| - throw new ArgumentOutOfRangeException($"repeats should be >= 1, instead got {repeats}"); |
| + throw new ArgumentOutOfRangeException(nameof(repeats), repeats, $"repeats should be >= 1, instead got {repeats}"); |
| private static int _round_to_multiple_of(double val, int divisor, double round_up_bias = 0.9) | ||
| { | ||
| if (round_up_bias <= 0.0 || round_up_bias >= 1.0) | ||
| throw new ArgumentOutOfRangeException($"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}"); |
ArgumentOutOfRangeException is instantiated with a single string which becomes the exception's paramName, not the message. Prefer new ArgumentOutOfRangeException(nameof(round_up_bias), round_up_bias, "...") (or similar) so callers get a correct parameter name and an informative message/value.
Suggested change:
| - throw new ArgumentOutOfRangeException($"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}"); |
| + throw new ArgumentOutOfRangeException( |
| + nameof(round_up_bias), |
| + round_up_bias, |
| + $"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}"); |
| : base(nameof(MNASNet)) | ||
| { | ||
| if (alpha <= 0.0) | ||
| throw new ArgumentOutOfRangeException($"alpha should be greater than 0.0 instead of {alpha}"); |
ArgumentOutOfRangeException is being created with a single interpolated string; that overload uses the string as paramName, so the exception message will be unhelpful. Use an overload that supplies nameof(alpha) and the offending value (and/or a separate message).
Suggested change:
| - throw new ArgumentOutOfRangeException($"alpha should be greater than 0.0 instead of {alpha}"); |
| + throw new ArgumentOutOfRangeException(nameof(alpha), alpha, "alpha should be greater than 0.0."); |
Fixes #588