-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
49 lines (35 loc) · 1.04 KB
/
train.py
File metadata and controls
49 lines (35 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
'''
Description: the train function for the CNN
Author: Yange Cao
Version: 1.0
Date: November 15th, 2018
'''
from network import *
from tqdm import tqdm
import pickle
def train(img_dim,
          img_depth,
          f1,
          f2,
          num_f1,
          num_f2,
          batch_size,
          lr,
          beta1,
          beta2,
          num_epochs,
          save_path):
    """Train the CNN with mini-batch ``adam`` updates and pickle the result.

    Parameters are created with ``init_paras`` and the training split is
    loaded with ``get_data_cifar10``. On every epoch the training data is cut
    into consecutive (unshuffled) slices of ``batch_size`` rows, the last of
    which may be shorter. Each slice is passed to ``adam``, and the parameters
    it returns are fed into the next call. When training ends, the list
    ``[params, cost]`` is pickled to ``save_path``.

    Args:
        img_dim, img_depth: image geometry, forwarded to ``init_paras`` and
            ``adam``. These are presumably the spatial size and channel
            count; TODO confirm against ``network``.
        f1, f2, num_f1, num_f2: layer hyper-parameters forwarded to
            ``init_paras``. These are presumably the filter sizes and filter
            counts of the two conv layers; TODO confirm against ``network``.
        batch_size: number of training rows per mini-batch. Must be positive.
        lr: learning rate, forwarded to ``adam``.
        beta1, beta2: forwarded to ``adam``. These are presumably the Adam
            moment decay rates.
        num_epochs: number of full passes over the training data.
        save_path: path of the file the pickled ``[params, cost]`` list is
            written to (opened in binary write mode, so it is overwritten).

    Returns:
        The ``cost`` list returned by the last ``adam`` call, or an empty
        list if no batch was processed.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    # range() raises an opaque error for a zero step and silently yields no
    # batches for a negative one, so reject both up front.
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))

    # Bind `params` before the loops. This lets adam's output be fed back on
    # every batch, and keeps the name defined for the final save even when
    # no batch runs (num_epochs == 0 or empty training data).
    params = init_paras(img_dim, img_depth, f1, f2, num_f1, num_f2)
    train_data, _test_data = get_data_cifar10()
    cost = []
    print("LR:" + str(lr) + ", Batch Size:" + str(batch_size))

    num_rows = train_data.shape[0]
    for _epoch in range(num_epochs):
        # Consecutive slices in stored order; the final batch may be smaller.
        batches = [train_data[k:k + batch_size]
                   for k in range(0, num_rows, batch_size)]
        progress = tqdm(batches)
        for batch in progress:
            # Pass the latest parameters so each update builds on the last.
            params, cost = adam(img_dim, img_depth, batch, params,
                                lr, beta1, beta2, cost)
            progress.set_description("Cost: {}".format(cost[-1]))

    with open(save_path, 'wb') as fh:
        pickle.dump([params, cost], fh)
    return cost