From cc839d6a68bf1f00192fb8d0263548b3d8557060 Mon Sep 17 00:00:00 2001 From: stephantul Date: Wed, 25 Feb 2026 07:58:37 +0100 Subject: [PATCH 1/4] fix: if layers == 0, layers were not initialized --- model2vec/train/classifier.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 52b1f46..d07c401 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -66,15 +66,17 @@ def classes(self) -> np.ndarray: def construct_head(self) -> nn.Sequential: """Constructs a simple classifier head.""" - if self.n_layers == 0: - return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim)) - modules = [ - nn.Linear(self.embed_dim, self.hidden_dim), - nn.ReLU(), - ] - for _ in range(self.n_layers - 1): - modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()]) - modules.extend([nn.Linear(self.hidden_dim, self.out_dim)]) + modules = [] + if self.n_layers > 0: + # If we have a hidden layer, we should first project to hidden_dim + modules = [ + nn.Linear(self.embed_dim, self.hidden_dim), + nn.ReLU(), + ] + for _ in range(self.n_layers - 1): + modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()]) + # We always have a layer mapping from hidden to out. 
+ modules.append(nn.Linear(self.hidden_dim, self.out_dim)) for module in modules: if isinstance(module, nn.Linear): From d01391b000349814bd75fee0c894d8adaae74439 Mon Sep 17 00:00:00 2001 From: stephantul Date: Wed, 25 Feb 2026 08:01:44 +0100 Subject: [PATCH 2/4] typing --- model2vec/train/classifier.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index d07c401..ce87c8d 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -66,8 +66,10 @@ def classes(self) -> np.ndarray: def construct_head(self) -> nn.Sequential: """Constructs a simple classifier head.""" - modules = [] - if self.n_layers > 0: + modules: list[nn.Module] = [] + if self.n_layers == 0: + modules.append(nn.Linear(self.embed_dim, self.out_dim)) + else: # If we have a hidden layer, we should first project to hidden_dim modules = [ nn.Linear(self.embed_dim, self.hidden_dim), @@ -75,8 +77,8 @@ def construct_head(self) -> nn.Sequential: ] for _ in range(self.n_layers - 1): modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()]) - # We always have a layer mapping from hidden to out. - modules.append(nn.Linear(self.hidden_dim, self.out_dim)) + # We always have a layer mapping from hidden to out. + modules.append(nn.Linear(self.hidden_dim, self.out_dim)) for module in modules: if isinstance(module, nn.Linear): From 5eea426cd9d01e7ce2d6b5bd14ef337ff6ae521b Mon Sep 17 00:00:00 2001 From: stephantul Date: Wed, 25 Feb 2026 08:11:24 +0100 Subject: [PATCH 3/4] redo initialization --- model2vec/train/classifier.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index ce87c8d..1cc4beb 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -80,10 +80,16 @@ def construct_head(self) -> nn.Sequential: # We always have a layer mapping from hidden to out. 
+ modules.append(nn.Linear(self.hidden_dim, self.out_dim)) - for module in modules: - if isinstance(module, nn.Linear): - nn.init.kaiming_uniform_(module.weight) - nn.init.zeros_(module.bias) + linear_modules = [module for module in modules if isinstance(module, nn.Linear)] + if linear_modules: + *initial, last = linear_modules + for module in initial: + if isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, nonlinearity="relu") + nn.init.zeros_(module.bias) + # Final layer is not followed by a ReLU, so use Xavier instead of Kaiming init + nn.init.xavier_uniform_(last.weight) + nn.init.zeros_(last.bias) return nn.Sequential(*modules) From 6d76f31666d3fb2d525fd2feb3c9975cf9ae3e7d Mon Sep 17 00:00:00 2001 From: stephantul Date: Wed, 25 Feb 2026 08:17:57 +0100 Subject: [PATCH 4/4] remove check for linear --- model2vec/train/classifier.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 1cc4beb..8c49aef 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -84,9 +84,8 @@ def construct_head(self) -> nn.Sequential: if linear_modules: *initial, last = linear_modules for module in initial: - if isinstance(module, nn.Linear): - nn.init.kaiming_uniform_(module.weight, nonlinearity="relu") - nn.init.zeros_(module.bias) + nn.init.kaiming_uniform_(module.weight, nonlinearity="relu") + nn.init.zeros_(module.bias) # Final layer is not followed by a ReLU, so use Xavier instead of Kaiming init nn.init.xavier_uniform_(last.weight) nn.init.zeros_(last.bias)