Skip to content

Add TorchVision classification models: SqueezeNet, DenseNet, ShuffleN…#1550

Open
alinpahontu2912 wants to merge 3 commits intodotnet:mainfrom
alinpahontu2912:torchvision_models
Open

Add TorchVision classification models: SqueezeNet, DenseNet, ShuffleN…#1550
alinpahontu2912 wants to merge 3 commits intodotnet:mainfrom
alinpahontu2912:torchvision_models

Conversation

@alinpahontu2912
Copy link
Member

Fixes #588
Add 5 new model families (21 variants) ported from PyTorch torchvision:

  • SqueezeNet 1.0/1.1
  • DenseNet-121/161/169/201
  • ShuffleNet V2 x0.5/x1.0/x1.5/x2.0
  • EfficientNet B0-B7, EfficientNet V2 S/M/L
  • MNASNet 0.5/0.75/1.0/1.3

All models support pre-trained weight loading via weights_file/skipfc parameters with state_dict keys matching PyTorch exactly.

Tests added for all new model families.

TODO: The following torchvision classification models are not yet implemented:

  • RegNet (Y/X variants)
  • ConvNeXt (Tiny, Small, Base, Large)
  • Vision Transformer / ViT (B-16, B-32, L-16, L-32, H-14)
  • Swin Transformer (T, S, B)
  • Swin Transformer V2 (T, S, B)
  • MaxViT (T)

alinpahontu2912 and others added 3 commits February 27, 2026 15:04
…etV2, EfficientNet, MNASNet

Add 5 new model families (21 variants) ported from PyTorch torchvision:

- SqueezeNet 1.0/1.1
- DenseNet-121/161/169/201
- ShuffleNet V2 x0.5/x1.0/x1.5/x2.0
- EfficientNet B0-B7, EfficientNet V2 S/M/L
- MNASNet 0.5/0.75/1.0/1.3

All models support pre-trained weight loading via weights_file/skipfc
parameters with state_dict keys matching PyTorch exactly.

Tests added for all new model families.

TODO: The following torchvision classification models are not yet implemented:
- RegNet (Y/X variants)
- ConvNeXt (Tiny, Small, Base, Large)
- Vision Transformer / ViT (B-16, B-32, L-16, L-32, H-14)
- Swin Transformer (T, S, B)
- Swin Transformer V2 (T, S, B)
- MaxViT (T)

Closes dotnet#586

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Fix EfficientNetB0 and EfficientNetV2S named_children order to match
  field declaration order (features, avgpool, classifier)
- Fix DenseNet121 state_dict count from 242 to 727 to reflect proper
  registration of all dense layers via register_module

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
EfficientNetV2S (782 state dict entries) is the largest non-skipped
model and causes the test host process to crash from memory pressure
when run alongside all other model tests. Skip it following the same
pattern used for other large EfficientNet variants (B1-B7, V2M, V2L).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds additional TorchVision classification model families to TorchSharp, aligning module structure/state_dict keys with torchvision to enable loading pre-trained weights from exported PyTorch state_dicts.

Changes:

  • Introduces 5 new model families (SqueezeNet, DenseNet, ShuffleNetV2, EfficientNet V1/V2, MNASNet) with torchvision-style factory methods and optional weight loading (weights_file/skipfc).
  • Adds/extends unit tests to validate basic model structure (named children), state_dict sizes for select variants, and forward output shapes.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
test/TorchSharpTest/TestTorchVision.cs Adds tests covering the new model families (some heavy variants explicitly skipped for CI stability).
src/TorchVision/models/SqueezeNet.cs Implements SqueezeNet 1.0/1.1 modules + factory methods with optional weight loading.
src/TorchVision/models/DenseNet.cs Implements DenseNet-121/161/169/201 modules + factory methods with optional weight loading.
src/TorchVision/models/ShuffleNetV2.cs Implements ShuffleNet V2 x0.5/x1.0/x1.5/x2.0 modules + factory methods with optional weight loading.
src/TorchVision/models/EfficientNet.cs Implements EfficientNet B0–B7 and EfficientNetV2 S/M/L modules + factory methods with optional weight loading.
src/TorchVision/models/MNASNet.cs Implements MNASNet 0.5/0.75/1.0/1.3 modules + factory methods with optional weight loading.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +131 to +133
throw new ArgumentOutOfRangeException($"stride should be 1 or 2 instead of {stride}");
if (kernel_size != 3 && kernel_size != 5)
throw new ArgumentOutOfRangeException($"kernel_size should be 3 or 5 instead of {kernel_size}");
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ArgumentOutOfRangeException is being constructed with a single interpolated string. In .NET that overload treats the string as paramName, not the error message, which produces misleading exception details. Use an overload that specifies paramName (e.g., nameof(stride) / nameof(kernel_size)) and a separate message/value.

Suggested change
throw new ArgumentOutOfRangeException($"stride should be 1 or 2 instead of {stride}");
if (kernel_size != 3 && kernel_size != 5)
throw new ArgumentOutOfRangeException($"kernel_size should be 3 or 5 instead of {kernel_size}");
throw new ArgumentOutOfRangeException(nameof(stride), $"stride should be 1 or 2 instead of {stride}.");
if (kernel_size != 3 && kernel_size != 5)
throw new ArgumentOutOfRangeException(nameof(kernel_size), $"kernel_size should be 3 or 5 instead of {kernel_size}.");

Copilot uses AI. Check for mistakes.
private static Module<Tensor, Tensor> _stack(long in_ch, long out_ch, long kernel_size, long stride, long exp_factor, int repeats, double bn_momentum)
{
if (repeats < 1)
throw new ArgumentOutOfRangeException($"repeats should be >= 1, instead got {repeats}");
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as above: ArgumentOutOfRangeException is constructed with a single string, which is treated as the parameter name. Use an overload that passes paramName (e.g., nameof(repeats)) and a proper message (and ideally the actual value).

Suggested change
throw new ArgumentOutOfRangeException($"repeats should be >= 1, instead got {repeats}");
throw new ArgumentOutOfRangeException(nameof(repeats), repeats, $"repeats should be >= 1, instead got {repeats}");

Copilot uses AI. Check for mistakes.
private static int _round_to_multiple_of(double val, int divisor, double round_up_bias = 0.9)
{
if (round_up_bias <= 0.0 || round_up_bias >= 1.0)
throw new ArgumentOutOfRangeException($"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}");
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArgumentOutOfRangeException is instantiated with a single string which becomes the exception's paramName, not the message. Prefer new ArgumentOutOfRangeException(nameof(round_up_bias), round_up_bias, "...") (or similar) so callers get a correct parameter name and an informative message/value.

Suggested change
throw new ArgumentOutOfRangeException($"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}");
throw new ArgumentOutOfRangeException(
nameof(round_up_bias),
round_up_bias,
$"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}");

Copilot uses AI. Check for mistakes.
: base(nameof(MNASNet))
{
if (alpha <= 0.0)
throw new ArgumentOutOfRangeException($"alpha should be greater than 0.0 instead of {alpha}");
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArgumentOutOfRangeException is being created with a single interpolated string; that overload uses the string as paramName, so the exception message will be unhelpful. Use an overload that supplies nameof(alpha) and the offending value (and/or a separate message).

Suggested change
throw new ArgumentOutOfRangeException($"alpha should be greater than 0.0 instead of {alpha}");
throw new ArgumentOutOfRangeException(nameof(alpha), alpha, "alpha should be greater than 0.0.");

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

TorchVision: Add models with pre-trained weights.

2 participants