diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py
index 52b1f46..8c49aef 100644
--- a/model2vec/train/classifier.py
+++ b/model2vec/train/classifier.py
@@ -66,20 +66,29 @@ def classes(self) -> np.ndarray:
 
     def construct_head(self) -> nn.Sequential:
         """Constructs a simple classifier head."""
+        modules: list[nn.Module] = []
         if self.n_layers == 0:
-            return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
-        modules = [
-            nn.Linear(self.embed_dim, self.hidden_dim),
-            nn.ReLU(),
-        ]
-        for _ in range(self.n_layers - 1):
-            modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
-        modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])
-
-        for module in modules:
-            if isinstance(module, nn.Linear):
-                nn.init.kaiming_uniform_(module.weight)
+            modules.append(nn.Linear(self.embed_dim, self.out_dim))
+        else:
+            # If we have a hidden layer, we should first project to hidden_dim.
+            modules = [
+                nn.Linear(self.embed_dim, self.hidden_dim),
+                nn.ReLU(),
+            ]
+            for _ in range(self.n_layers - 1):
+                modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
+            # We always have a layer mapping from hidden to out.
+            modules.append(nn.Linear(self.hidden_dim, self.out_dim))
+
+        linear_modules = [module for module in modules if isinstance(module, nn.Linear)]
+        if linear_modules:
+            *initial, last = linear_modules
+            for module in initial:
+                nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
                 nn.init.zeros_(module.bias)
+            # The final layer gets Xavier rather than Kaiming init.
+            nn.init.xavier_uniform_(last.weight)
+            nn.init.zeros_(last.bias)
 
         return nn.Sequential(*modules)
 
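For reference, a minimal standalone sketch of the head this patch constructs. The build_head name and the example dims are hypothetical; the real method reads embed_dim, hidden_dim, out_dim, and n_layers from self. It illustrates the two behavioral changes: the n_layers == 0 probe is no longer returned early with PyTorch's default initialization, and the output layer now gets Xavier rather than Kaiming init.

import torch.nn as nn

def build_head(embed_dim: int, hidden_dim: int, out_dim: int, n_layers: int) -> nn.Sequential:
    # Hypothetical free-function mirror of the patched construct_head.
    modules: list[nn.Module] = []
    if n_layers == 0:
        # Single linear probe straight from embedding to logits.
        modules.append(nn.Linear(embed_dim, out_dim))
    else:
        # Project to hidden_dim first, then stack n_layers - 1 hidden blocks.
        modules = [nn.Linear(embed_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers - 1):
            modules.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        modules.append(nn.Linear(hidden_dim, out_dim))

    # Kaiming-initialize every layer that feeds a ReLU; Xavier for the output.
    *initial, last = [m for m in modules if isinstance(m, nn.Linear)]
    for m in initial:
        nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
        nn.init.zeros_(m.bias)
    nn.init.xavier_uniform_(last.weight)
    nn.init.zeros_(last.bias)
    return nn.Sequential(*modules)

# n_layers=0 -> Linear(256, 10); n_layers=2 -> Linear-ReLU-Linear-ReLU-Linear.
print(build_head(256, 512, 10, n_layers=0))
print(build_head(256, 512, 10, n_layers=2))

Kaiming uniform is the standard choice for layers followed by a ReLU; the output layer feeds no nonlinearity, so the patch falls back to Xavier there, and a 0-layer head's single linear is treated as that final layer.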