-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathschf.py
More file actions
41 lines (31 loc) · 1.39 KB
/
schf.py
File metadata and controls
41 lines (31 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import schnetpack as spk
from schnetpack.datasets import QM9
# Working directory for tutorial artifacts (the train/val/test split file below).
# exist_ok=True makes creation idempotent: it replaces the separate
# check-then-create (which was racy) and removes the mismatch between the
# literal that was checked ('qm9tut') and the path that was created (qm9tut).
qm9tut = './qm9tut'
os.makedirs(qm9tut, exist_ok=True)
# Open the QM9 database, which is expected to be downloaded already to ./qm9.db.
# Only the U0 property is requested (load_only), and uncharacterized molecules
# are removed.
qm9data = QM9(
    './qm9.db',
    load_only=[QM9.U0],
    remove_uncharacterized=True,
)
# Debugging aids; uncomment to inspect the dataset:
# print(qm9data.get_properties(0))
# print(len(qm9data))
# Partition the dataset: 100 molecules for training and 10 for validation.
# The test split presumably receives the remainder (verify against the
# schnetpack docs). Split indices are tied to qm9tut/split.npz.
train, val, test = spk.train_test_split(
    data=qm9data, num_train=100, num_val=10,
    split_file=os.path.join(qm9tut, "split.npz"),
)
# A single dataset sample (qm9data[0]) is a dictionary; wrapping it in an
# AtomsLoader is how it gets batched and fed to the model below.
# print(qm9data[0])
fulldata = spk.AtomsLoader([qm9data[0]], batch_size=1, shuffle=False)
# Reference atom energies for U0, passed to get_statistics as single_atom_ref.
atomrefs = qm9data.get_atomref(QM9.U0)
# Statistics must be gathered over many molecules. Over the one-molecule
# loader above, the standard deviation comes from a single sample and is
# degenerate, which makes it useless for standardizing targets. Use the
# training split instead.
train_loader = spk.AtomsLoader(train, batch_size=100, shuffle=False)
means, stddevs = train_loader.get_statistics(
    QM9.U0, divide_by_atoms=True, single_atom_ref=atomrefs
)
# SchNet representation network:
# - atom-basis / filter width 30, 20 Gaussians, 5 interaction blocks
# - cosine cutoff network with cutoff 4.0
# - return_intermediate=True presumably makes the forward pass also return
#   the intermediate per-interaction features (confirm in schnetpack docs)
sch_feat = spk.representation.SchNet(
    n_atom_basis=30,
    n_filters=30,
    n_gaussians=20,
    n_interactions=5,
    cutoff=4.0,
    cutoff_network=spk.nn.cutoff.CosineCutoff,
    return_intermediate=True,
)
# NOTE: constructing SymmetryFunctions (wACSF) with the arguments below fails, so it is left disabled:
#wacsf = spk.representation.SymmetryFunctions(n_radial=22, n_angular=5, zetas={1}, cutoff=spk.nn.cutoff.CosineCutoff, cutoff_radius=5.0, centered=False, crossterms=False, elements=frozenset({1, 6, 7, 8, 9}), sharez=True, trainz=False, initz='weighted', len_embedding=5, pairwise_elements=False)
# Push each batch from fulldata (here, the single molecule) through SchNet and
# print the result. Call the module instance rather than .forward() so that
# any registered PyTorch hooks run; calling .forward() directly skips them.
for batch in fulldata:
    f = sch_feat(batch)
    # print(batch)
    print(f)