Skip to content

Commit 175ae02

Browse files
committed
Changes will need to be merged into a middle branch that is up to date with main
1 parent ca2ae1a commit 175ae02

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

bioneuralnet/downstream_task/dpmon.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class DPMON:
6969
lr (float): Learning rate for the optimizer.
7070
weight_decay (float): L2 weight decay (regularization) coefficient.
7171
tune (bool): If True, perform hyperparameter tuning before final training.
72+
tune_trails (int): Number of trials to perform if tune=True.
7273
gpu (bool): If True, use GPU if available.
7374
cv (bool): If True, use K-fold cross-validation; otherwise use repeated train/test splits.
7475
cuda (int): CUDA device index to use when gpu=True.
@@ -98,6 +99,7 @@ def __init__(
9899
lr: float = 1e-1,
99100
weight_decay: float = 1e-4,
100101
tune: bool = False,
102+
tune_trails: int = 10,
101103
gpu: bool = False,
102104
cv: bool = False,
103105
cuda: int = 0,
@@ -156,6 +158,7 @@ def __init__(
156158
self.lr = lr
157159
self.weight_decay = weight_decay
158160
self.tune = tune
161+
self.tune_trails = tune_trails
159162
self.gpu = gpu
160163
self.cuda = cuda
161164
self.seed = seed
@@ -205,6 +208,7 @@ def run(self) -> Tuple[pd.DataFrame, object, torch.Tensor | None]:
205208
"gpu": self.gpu,
206209
"cuda": self.cuda,
207210
"tune": self.tune,
211+
"tune_trials": self.tune_trails,
208212
"seed": self.seed,
209213
"seed_trials": self.seed_trials,
210214
"cv": self.cv,
@@ -892,8 +896,8 @@ def short_dirname_creator(trial):
892896

893897
gpu_per_trial = 0.05 if use_gpu else 0.0
894898

895-
num_samples = 50
896-
max_retries = 5
899+
num_samples = dpmon_params['tune_trials']
900+
max_retries = 4
897901

898902
seed_trials = dpmon_params.get("seed_trials", False)
899903

@@ -916,6 +920,7 @@ def short_dirname_creator(trial):
916920
config=pipeline_configs,
917921
num_samples=num_samples,
918922
verbose=0,
923+
log_to_file=True,
919924
scheduler=scheduler,
920925
stop=stopper,
921926
name="tune_dp",

0 commit comments

Comments
 (0)