-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathio.cc
More file actions
314 lines (296 loc) · 12.1 KB
/
io.cc
File metadata and controls
314 lines (296 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
#include "dynet/io.h"
#include "dynet/tensor.h"
#include "dynet/except.h"
#include "dynet/str-util.h"
using namespace std;
using namespace dynet;
// Number of significant decimal digits used when serializing float32 to text.
// Nine digits (the value of std::numeric_limits<float>::max_digits10) are
// needed for every float32 to survive a text round trip under the default
// stream float format; eight digits lose the last bit for some values
// (for example the first float above 1000.0f).
// We should probably use std::hexfloat, but it's not supported by some
// older incomplete implementations of C++11.
static const int FLOAT32_PRECISION = 9;
// Tells whether `s` is acceptable as a key.
// The empty string is accepted, the lone string "/" is rejected, and any
// string containing a space or a '#' character is rejected.
bool valid_key(const std::string & s) {
  // A bare "/" is never a usable key.
  if (s == "/") return false;
  // An empty string has no forbidden character, so it is accepted here too.
  return s.find_first_of(" #") == std::string::npos;
}
// Tells whether `s` is acceptable as a key for a whole parameter collection.
// The empty string is accepted; any other key has to begin with "/" and also
// pass valid_key().
bool valid_pc_key(const std::string & s) {
  return s.empty() || (startswith(s, "/") && valid_key(s));
}
// Returns true exactly when `p` reports, through has_grad(), that it has no
// gradient; the savers use this to decide whether a gradient line is written.
bool grad_is_zero(const ParameterStorageBase & p){
  if (p.has_grad()) return false;
  return true;
}
// Parses one header line of a text model file.
// The whitespace-separated fields are read in this order: record type tag,
// record name, dimensions, payload byte count, and an optional gradient marker.
//   line       - the header line to parse
//   type       - receives the record type tag
//   name       - receives the record name
//   dim        - receives the dimensions
//   byte_count - receives the size in bytes of the payload following the header
//   zero_grad  - set to true only when the marker is "ZERO_GRAD"; set to false
//                otherwise, including headers without any marker (older files)
void read_param_header(string line, string &type, string &name, Dim& dim,size_t& byte_count, bool& zero_grad){
  // Callers reuse one flag for many header lines, so always start from
  // "gradient present"; otherwise a ZERO_GRAD record would leak into every
  // record that follows it.
  zero_grad = false;
  // Read header
  istringstream iss(line);
  iss >> type >> name >> dim >> byte_count;
  // The marker is optional (for backward compatibility): a failed extraction
  // leaves the flag false.
  string grad;
  if (iss >> grad)
    zero_grad = (grad == "ZERO_GRAD");
}
// Opens `filename` for writing a text model.
// When `append` is true the stream is opened with ofstream::app, otherwise
// with ofstream::out. A stream that is not usable after opening is reported
// through DYNET_RUNTIME_ERR.
TextFileSaver::TextFileSaver(const string & filename, bool append)
  : datastream(filename, append ? ofstream::app : ofstream::out) {
  if (!datastream) {
    DYNET_RUNTIME_ERR("Could not write model to " << filename);
  }
}
void TextFileSaver::save(const ParameterCollection & model,
const string & key) {
if (!valid_pc_key(key))
DYNET_INVALID_ARG("Key should start with '/' and could not include ' ' or '#': " << key);
string key_ = key;
if (key_.back() != '/') key_ += "/";
const ParameterCollectionStorage & storage = model.get_storage();
if(key.size() == 0) {
for (auto & p : storage.params) save(*p, key);
for (auto & p : storage.lookup_params) save(*p, key);
} else {
size_t strip_size = model.get_fullname().size();
for (auto & p : storage.params)
save(*p, key_ + p->name.substr(strip_size));
for (auto & p : storage.lookup_params)
save(*p, key_ + p->name.substr(strip_size));
}
}
// Saves a single parameter under `key`.
// A key rejected by valid_key() is reported through DYNET_INVALID_ARG; the
// actual writing is delegated to the storage-level overload.
void TextFileSaver::save(const Parameter & param,
                         const string & key) {
  const bool key_ok = valid_key(key);
  if (!key_ok) {
    DYNET_INVALID_ARG("Key could not include ' ' or '#': " << key);
  }
  save(*param.p, key);
}
// Saves a single lookup parameter under `key`.
// A key rejected by valid_key() is reported through DYNET_INVALID_ARG; the
// actual writing is delegated to the storage-level overload.
void TextFileSaver::save(const LookupParameter & param,
                         const string & key) {
  const bool key_ok = valid_key(key);
  if (!key_ok) {
    DYNET_INVALID_ARG("Key could not include ' ' or '#': " << key);
  }
  save(*param.p, key);
}
// Writes one "#Parameter#" record to the output stream.
// The record is a header line (type tag, name, dimensions, payload size in
// bytes, and a ZERO_GRAD / FULL_GRAD marker) followed by the payload: one line
// of values and, unless grad_is_zero(p) holds, one line of gradients.
// The record name is `key` when it is non-empty, otherwise p.name.
void TextFileSaver::save(const ParameterStorage & p,
                         const string & key) {
  // The payload is built in memory first because the header has to state its
  // size in bytes.
  std::ostringstream buffer;
  buffer.precision(FLOAT32_PRECISION);
  buffer << dynet::as_vector(p.values) << endl;
  const bool zero_grad = grad_is_zero(p);
  if(!zero_grad)
    buffer << dynet::as_vector(p.g) << endl;
  // str() returns a copy of the whole buffer, so take it exactly once.
  const std::string payload = buffer.str();
  datastream << "#Parameter# " << (key.size() > 0 ? key : p.name) << ' '
             << p.dim << ' ' << payload.size()
             << (zero_grad ? " ZERO_GRAD" : " FULL_GRAD") << endl;
  datastream.write(payload.data(), payload.size());
}
// Writes one "#LookupParameter#" record to the output stream.
// The record is a header line (type tag, name, dimensions, payload size in
// bytes, and a ZERO_GRAD / FULL_GRAD marker) followed by the payload: one line
// of values and, unless grad_is_zero(p) holds, one line of gradients.
// The record name is `key` when it is non-empty, otherwise p.name.
void TextFileSaver::save(const LookupParameterStorage & p,
                         const string & key) {
  // The payload is built in memory first because the header has to state its
  // size in bytes.
  std::ostringstream buffer;
  buffer.precision(FLOAT32_PRECISION);
  buffer << dynet::as_vector(p.all_values) << endl;
  const bool zero_grad = grad_is_zero(p);
  if(!zero_grad)
    buffer << dynet::as_vector(p.all_grads) << endl;
  // str() returns a copy of the whole buffer, so take it exactly once.
  const std::string payload = buffer.str();
  datastream << "#LookupParameter# " << (key.size() > 0 ? key : p.name) << ' '
             << p.all_dim << ' ' << payload.size()
             << (zero_grad ? " ZERO_GRAD" : " FULL_GRAD") << endl;
  datastream.write(payload.data(), payload.size());
}
// Remembers the name of the model file to read from.
// Nothing is opened here; each loading method opens its own input stream on
// the stored name.
TextFileLoader::TextFileLoader(const string & filename)
  : dataname(filename) {
}
// Fills the parameters and lookup parameters of `model` from the model file.
// Records are consumed in file order and assigned to the model's parameters /
// lookup parameters in storage order. With a non-empty `key`, only records
// whose name starts with `key` (with a trailing '/' ensured) are used and the
// others are skipped.
// Reported through DYNET_RUNTIME_ERR: unreadable file, unknown record type,
// more matching records than model objects, a dimension mismatch, or a final
// count of loaded records that differs from the model's object count.
void TextFileLoader::populate(ParameterCollection & model, const string & key) {
  ifstream datastream(dataname);
  if(!datastream) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
  string line, type, name;
  bool zero_grad = false;
  Dim dim;
  size_t byte_count = 0;
  vector<float> values;
  Tensor *value_t = nullptr, *grad_t = nullptr;
  size_t param_id = 0, lookup_id = 0;
  ParameterCollectionStorage & storage = model.get_storage();
  // Prefix used for name matching. back() must not be called on an empty
  // string (undefined behaviour); with an empty key the prefix is never used.
  string key_ = key;
  if (key_.empty() || key_.back() != '/') key_ += "/";
  while(getline(datastream, line)) {
    // The flag is shared across records, so clear it before each header;
    // otherwise one ZERO_GRAD record would mark every later record as well.
    zero_grad = false;
    read_param_header(line, type, name, dim, byte_count, zero_grad);
    // Skip ones that don't match: jump over the payload using its byte count
    if(key.size() != 0 && name.substr(0, key_.size()) != key_) {
      size_t offset = static_cast<size_t>(datastream.tellg()) + byte_count;
      datastream.seekg(offset);
      continue;
    }
    if(type == "#Parameter#") {
      // Load a parameter
      values.resize(dim.size());
      if(param_id >= storage.params.size())
        DYNET_RUNTIME_ERR("Too many parameters to load in populated model at " << name);
      ParameterStorage & param = *storage.params[param_id++];
      if(param.dim != dim)
        DYNET_RUNTIME_ERR("Dimensions of parameter " << name << " looked up from file (" << dim <<
                          ") do not match parameters to be populated (" << param.dim << ")");
      value_t = &param.values;
      grad_t = &param.g;
    } else if(type == "#LookupParameter#") {
      // Load a lookup parameter
      values.resize(dim.size());
      if(lookup_id >= storage.lookup_params.size())
        DYNET_RUNTIME_ERR("Too many lookup parameters in populated model at " << name);
      LookupParameterStorage & param = *storage.lookup_params[lookup_id++];
      if(param.all_dim != dim)
        DYNET_RUNTIME_ERR("Dimensions of lookup parameter " << name << " lookup up from file (" << dim <<
                          ") do not match parameters to be populated (" << param.all_dim << ")");
      value_t = &param.all_values;
      grad_t = &param.all_grads;
    } else {
      DYNET_RUNTIME_ERR("Bad parameter specification in model: " << line);
    }
    // First payload line: the values
    { getline(datastream, line); istringstream iss(line); iss >> values; }
    TensorTools::set_elements(*value_t, values);
    if(!zero_grad){
      // Second payload line: the gradients
      { getline(datastream, line); istringstream iss(line); iss >> values; }
      TensorTools::set_elements(*grad_t, values);
    } else {
      TensorTools::zero(*grad_t);
    }
  }
  if(param_id != storage.params.size() || lookup_id != storage.lookup_params.size())
    DYNET_RUNTIME_ERR("Number of parameter/lookup parameter objects loaded from file (" <<
                      param_id << '/' << lookup_id << ") did not match number to be populated (" <<
                      storage.params.size() << '/' << storage.lookup_params.size() << ')');
}
// Fills an existing parameter from the "#Parameter#" record named `key`.
// The file is scanned record by record; records that do not match are jumped
// over using the payload byte count from their header.
// Reported through DYNET_INVALID_ARG: an empty key. Reported through
// DYNET_RUNTIME_ERR: unreadable file, dimension mismatch, or key not found.
void TextFileLoader::populate(Parameter & param,
                              const string & key) {
  if(key == "")
    DYNET_INVALID_ARG("TextFileLoader.populate() requires non-empty key");
  ifstream datastream(dataname);
  if(!datastream) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
  string line, type, name;
  bool zero_grad = false;
  Dim dim;
  size_t byte_count = 0;
  while(getline(datastream, line)) {
    // Clear the flag for every header so that a ZERO_GRAD marker on a skipped
    // record cannot cause the matching record's gradient to be zeroed.
    zero_grad = false;
    read_param_header(line, type, name, dim, byte_count, zero_grad);
    if(type == "#Parameter#" && name == key) {
      if(param.p->dim != dim)
        DYNET_RUNTIME_ERR("Attempted to populate parameter where arguments don't match (" << param.p->dim << " != " << dim << ")");
      vector<float> values(dim.size());
      // First payload line: the values
      { getline(datastream, line); istringstream iss(line); iss >> values; }
      TensorTools::set_elements(param.get_storage().values, values);
      if(!zero_grad){
        // Second payload line: the gradients
        { getline(datastream, line); istringstream iss(line); iss >> values; }
        TensorTools::set_elements(param.get_storage().g, values);
      } else {
        TensorTools::zero(param.get_storage().g);
      }
      return;
    }
    // Not the wanted record: jump over its payload.
    size_t offset = static_cast<size_t>(datastream.tellg()) + byte_count;
    datastream.seekg(offset);
  }
  DYNET_RUNTIME_ERR("Could not find key " << key << " in the model file");
}
// Fills an existing lookup parameter from the "#LookupParameter#" record
// named `key`.
// The file is scanned record by record; records that do not match are jumped
// over using the payload byte count from their header.
// Reported through DYNET_INVALID_ARG: an empty key. Reported through
// DYNET_RUNTIME_ERR: unreadable file, dimension mismatch, or key not found.
void TextFileLoader::populate(LookupParameter & lookup_param,
                              const string & key) {
  if(key == "")
    DYNET_INVALID_ARG("TextFileLoader.populate() requires non-empty key");
  ifstream datastream(dataname);
  if(!datastream) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
  string line, type, name;
  bool zero_grad = false;
  Dim dim;
  size_t byte_count = 0;
  while(getline(datastream, line)) {
    // Clear the flag for every header so that a ZERO_GRAD marker on a skipped
    // record cannot cause the matching record's gradient to be zeroed.
    zero_grad = false;
    read_param_header(line, type, name, dim, byte_count, zero_grad);
    if(type == "#LookupParameter#" && name == key) {
      if(lookup_param.p->all_dim != dim)
        DYNET_RUNTIME_ERR("Attempted to populate lookup parameter where arguments don't match (" << lookup_param.p->all_dim << " != " << dim << ")");
      vector<float> values(dim.size());
      // First payload line: the values
      { getline(datastream, line); istringstream iss(line); iss >> values; }
      TensorTools::set_elements(lookup_param.get_storage().all_values, values);
      if(!zero_grad){
        // Second payload line: the gradients
        { getline(datastream, line); istringstream iss(line); iss >> values; }
        TensorTools::set_elements(lookup_param.get_storage().all_grads, values);
      } else {
        TensorTools::zero(lookup_param.get_storage().all_grads);
      }
      return;
    }
    // Not the wanted record: jump over its payload.
    size_t offset = static_cast<size_t>(datastream.tellg()) + byte_count;
    datastream.seekg(offset);
  }
  DYNET_RUNTIME_ERR("Could not find key " << key << " in the model file");
}
// Creates a new parameter in `model` from the "#Parameter#" record named
// `key` and returns it.
// The new parameter takes the dimensions and the name stored in the record.
// Records that do not match are jumped over using the payload byte count from
// their header.
// Reported through DYNET_INVALID_ARG: an empty key. Reported through
// DYNET_RUNTIME_ERR: unreadable file or key not found.
Parameter TextFileLoader::load_param(ParameterCollection & model,
                                     const string & key) {
  if(key == "")
    DYNET_INVALID_ARG("TextFileLoader.load_param() requires non-empty key");
  ifstream datastream(dataname);
  if(!datastream) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
  string line, type, name;
  bool zero_grad = false;
  Dim dim;
  size_t byte_count = 0;
  while(getline(datastream, line)) {
    // Clear the flag for every header so that a ZERO_GRAD marker on a skipped
    // record cannot cause the matching record's gradient to be zeroed.
    zero_grad = false;
    read_param_header(line, type, name, dim, byte_count, zero_grad);
    if(type == "#Parameter#" && name == key) {
      Parameter param = model.add_parameters(dim);
      param.get_storage().name = name;
      vector<float> values(dim.size());
      // First payload line: the values
      { getline(datastream, line); istringstream iss(line); iss >> values; }
      TensorTools::set_elements(param.get_storage().values, values);
      if(!zero_grad){
        // Second payload line: the gradients
        { getline(datastream, line); istringstream iss(line); iss >> values; }
        TensorTools::set_elements(param.get_storage().g, values);
      } else {
        TensorTools::zero(param.get_storage().g);
      }
      return param;
    }
    // Not the wanted record: jump over its payload.
    size_t offset = static_cast<size_t>(datastream.tellg()) + byte_count;
    datastream.seekg(offset);
  }
  DYNET_RUNTIME_ERR("Could not find key " << key << " in the model file");
}
// Creates a new lookup parameter in `model` from the "#LookupParameter#"
// record named `key` and returns it.
// The last stored dimension is passed to add_lookup_parameters() as the table
// size and the remaining dimensions as its Dim argument; the new lookup
// parameter takes the name stored in the record.
// Records that do not match are jumped over using the payload byte count from
// their header.
// Reported through DYNET_INVALID_ARG: an empty key. Reported through
// DYNET_RUNTIME_ERR: unreadable file or key not found.
LookupParameter TextFileLoader::load_lookup_param(ParameterCollection & model,
                                                  const string & key) {
  if(key == "")
    DYNET_INVALID_ARG("TextFileLoader.load_lookup_param() requires non-empty key");
  ifstream datastream(dataname);
  if(!datastream) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
  string line, type, name;
  bool zero_grad = false;
  Dim dim;
  size_t byte_count = 0;
  while(getline(datastream, line)) {
    // Clear the flag for every header so that a ZERO_GRAD marker on a skipped
    // record cannot cause the matching record's gradient to be zeroed.
    zero_grad = false;
    read_param_header(line, type, name, dim, byte_count, zero_grad);
    if(type == "#LookupParameter#" && name == key) {
      // Size the buffer from the full stored dimensions before the last one
      // is split off below.
      vector<float> values(dim.size());
      // NOTE(review): assumes the record has at least one dimension - confirm.
      size_t size = dim[dim.nd-1]; dim.nd--;
      LookupParameter lookup_param = model.add_lookup_parameters(size, dim);
      lookup_param.get_storage().name = name;
      // First payload line: the values
      { getline(datastream, line); istringstream iss(line); iss >> values; }
      TensorTools::set_elements(lookup_param.get_storage().all_values, values);
      if(!zero_grad){
        // Second payload line: the gradients
        { getline(datastream, line); istringstream iss(line); iss >> values; }
        TensorTools::set_elements(lookup_param.get_storage().all_grads, values);
      } else {
        TensorTools::zero(lookup_param.get_storage().all_grads);
      }
      return lookup_param;
    }
    // Not the wanted record: jump over its payload.
    size_t offset = static_cast<size_t>(datastream.tellg()) + byte_count;
    datastream.seekg(offset);
  }
  DYNET_RUNTIME_ERR("Could not find key " << key << " in the model file");
}