-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathcfsm-builder.cc
More file actions
222 lines (198 loc) · 7.67 KB
/
cfsm-builder.cc
File metadata and controls
222 lines (198 loc) · 7.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#include "dynet/cfsm-builder.h"
#include "dynet/except.h"
#include "dynet/param-init.h"
#include <fstream>
#include <iostream>
using namespace std;
namespace dynet {
// True for the horizontal-whitespace characters used as field separators.
inline bool is_ws(char x) { return x == ' ' || x == '\t'; }
// Complement of is_ws: true for any character that is not a separator.
inline bool not_ws(char x) { return !(x == ' ' || x == '\t'); }
// Out-of-line destructor definition for the abstract interface.
SoftmaxBuilder::~SoftmaxBuilder() = default;
// Default constructor: builder with no parameters yet — presumably used
// by deserialization paths; verify against callers before removing.
StandardSoftmaxBuilder::StandardSoftmaxBuilder() {}
// Builds a flat softmax over `vocab_size` outputs from a `rep_dim`-dimensional
// input representation.  Registers a {vocab_size x rep_dim} weight matrix and
// a zero-initialized {vocab_size} bias in a named subcollection of `model`.
StandardSoftmaxBuilder::StandardSoftmaxBuilder(unsigned rep_dim, unsigned vocab_size, ParameterCollection& model) {
local_model = model.add_subcollection("standard-softmax-builder");
p_w = local_model.add_parameters({vocab_size, rep_dim});
p_b = local_model.add_parameters({vocab_size}, ParameterInitConst(0.f));
}
// Binds this builder's parameters into a freshly created computation graph.
// Must be called once per graph before neg_log_softmax / sample /
// full_log_distribution are used.
void StandardSoftmaxBuilder::new_graph(ComputationGraph& cg) {
pcg = &cg;
w = parameter(cg, p_w);
b = parameter(cg, p_b);
}
// Negative log-likelihood of `wordidx` under softmax(w * rep + b).
Expression StandardSoftmaxBuilder::neg_log_softmax(const Expression& rep, unsigned wordidx) {
  Expression scores = affine_transform({b, w, rep});
  return pickneglogsoftmax(scores, wordidx);
}
unsigned StandardSoftmaxBuilder::sample(const Expression& rep) {
Expression dist_expr = softmax(affine_transform({b, w, rep}));
vector<float> dist = as_vector(pcg->incremental_forward(dist_expr));
unsigned c = 0;
double p = rand01();
for (; c < dist.size(); ++c) {
p -= dist[c];
if (p < 0.0) { break; }
}
if (c == dist.size()) {
--c;
}
return c;
}
// Log-probabilities of every word in the vocabulary given `rep`.
Expression StandardSoftmaxBuilder::full_log_distribution(const Expression& rep) {
  Expression scores = affine_transform({b, w, rep});
  return log(softmax(scores));
}
// Default constructor: builder with no clusters or parameters yet —
// presumably used by deserialization paths; verify against callers.
ClassFactoredSoftmaxBuilder::ClassFactoredSoftmaxBuilder() {}
// Class-factored softmax: factors p(word | rep) into
// p(class(word) | rep) * p(word | class(word), rep).
// Reads the word -> cluster assignment from `cluster_file` (one
// "cluster word" pair per line, see read_cluster_file), then registers:
//   - a class-prediction layer ({num_clusters x rep_dim} + bias), and
//   - one word-prediction layer per multi-word cluster.
ClassFactoredSoftmaxBuilder::ClassFactoredSoftmaxBuilder(unsigned rep_dim,
const std::string& cluster_file,
Dict& word_dict,
ParameterCollection& model) {
read_cluster_file(cluster_file, word_dict);
const unsigned num_clusters = cdict.size();
local_model = model.add_subcollection("class-factored-softmax-builder");
p_r2c = local_model.add_parameters({num_clusters, rep_dim});
p_cbias = local_model.add_parameters({num_clusters}, ParameterInitConst(0.f));
p_rc2ws.resize(num_clusters);
p_rcwbiases.resize(num_clusters);
for (unsigned i = 0; i < num_clusters; ++i) {
auto& words = cidx2words[i]; // vector of word ids
const unsigned num_words_in_cluster = words.size();
if (num_words_in_cluster > 1) {
// for singleton clusters, we don't need these parameters, so
// we don't create them
p_rc2ws[i] = local_model.add_parameters({num_words_in_cluster, rep_dim});
p_rcwbiases[i] = local_model.add_parameters({num_words_in_cluster}, ParameterInitConst(0.f));
}
}
}
// Binds the class-prediction parameters into the new graph and resets the
// per-cluster expression caches (entries are re-created lazily by
// get_rc2w / get_rc2wbias as clusters are touched).
void ClassFactoredSoftmaxBuilder::new_graph(ComputationGraph& cg) {
  pcg = &cg;
  r2c = parameter(cg, p_r2c);
  cbias = parameter(cg, p_cbias);
  const unsigned num_clusters = cdict.size();
  rc2ws.assign(num_clusters, Expression());
  rc2biases.assign(num_clusters, Expression());
}
// -log p(wordidx | rep) = -log p(class | rep) - log p(wordidx | class, rep).
// Raises via DYNET_ARG_CHECK if `wordidx` was not assigned to any cluster.
Expression ClassFactoredSoftmaxBuilder::neg_log_softmax(const Expression& rep, unsigned wordidx) {
  // TODO check that new_graph has been called
  const int cluster = widx2cidx[wordidx];
  DYNET_ARG_CHECK(cluster >= 0,
      "Word ID " << wordidx << " missing from clusters in ClassFactoredSoftmaxBuilder::neg_log_softmax");
  Expression class_scores = affine_transform({cbias, r2c, rep});
  Expression class_nlp = pickneglogsoftmax(class_scores, cluster);
  // Singleton cluster: p(word | class) == 1, so the class term is the whole
  // negative log-likelihood.
  if (singleton_cluster[cluster]) return class_nlp;
  // Otherwise add the within-cluster word term.
  const unsigned word_row = widx2cwidx[wordidx];
  Expression& word_bias = get_rc2wbias(cluster);
  Expression& word_weights = get_rc2w(cluster);
  Expression word_scores = affine_transform({word_bias, word_weights, rep});
  Expression word_nlp = pickneglogsoftmax(word_scores, word_row);
  return class_nlp + word_nlp;
}
unsigned ClassFactoredSoftmaxBuilder::sample(const Expression& rep) {
// TODO check that new_graph has been called
Expression cscores = affine_transform({cbias, r2c, rep});
Expression cdist_expr = softmax(cscores);
auto cdist = as_vector(pcg->incremental_forward(cdist_expr));
unsigned c = 0;
double p = rand01();
for (; c < cdist.size(); ++c) {
p -= cdist[c];
if (p < 0.0) { break; }
}
if (c == cdist.size()) --c;
unsigned w = 0;
if (!singleton_cluster[c]) {
Expression& cwbias = get_rc2wbias(c);
Expression& r2cw = get_rc2w(c);
Expression wscores = affine_transform({cwbias, r2cw, rep});
Expression wdist_expr = softmax(wscores);
auto wdist = as_vector(pcg->incremental_forward(wdist_expr));
p = rand01();
for (; w < wdist.size(); ++w) {
p -= wdist[w];
if (p < 0.0) { break; }
}
if (w == wdist.size()) --w;
}
return cidx2words[c][w];
}
// Builds a log-distribution over the entire vocabulary.
// NOTE(review): `cscores` holds LOG class probabilities, but for multi-word
// clusters the per-word factor `pick(wdist, i)` is a raw probability
// (softmax output, not its log), so the sum mixes a probability with a
// log-probability.  The final log(softmax(...)) renormalizes these mixed
// scores, so the result is not exactly log p(class) + log p(word|class).
// Confirm whether this is intended before relying on calibrated values.
Expression ClassFactoredSoftmaxBuilder::full_log_distribution(const Expression& rep) {
vector<Expression> full_dist(widx2cidx.size());
Expression cscores = log(softmax(affine_transform({cbias, r2c, rep})));
// Words never assigned to a cluster get a large negative score.
for (unsigned i = 0; i < widx2cidx.size(); ++i) {
if (widx2cidx[i] == -1) {
// XXX: Should be -inf
full_dist[i] = input(*pcg, -10000);
}
}
// p_rc2ws.size() == number of clusters.
for (unsigned c = 0; c < p_rc2ws.size(); ++c) {
Expression cscore = pick(cscores, c);
if (singleton_cluster[c]) {
// Singleton: p(word | class) == 1, class score is the whole score.
for (unsigned i = 0; i < cidx2words[c].size(); ++i) {
unsigned w = cidx2words[c][i];
full_dist[w] = cscore;
}
}
else {
Expression& cwbias = get_rc2wbias(c);
Expression& r2cw = get_rc2w(c);
Expression wscores = affine_transform({cwbias, r2cw, rep});
Expression wdist = softmax(wscores);
for (unsigned i = 0; i < cidx2words[c].size(); ++i) {
unsigned w = cidx2words[c][i];
full_dist[w] = pick(wdist, i) + cscore;
}
}
}
return log(softmax(concatenate(full_dist)));
}
// Loads the word->cluster assignment from `cluster_file`.  Each line is a
// whitespace-separated "cluster word" pair; cluster names are interned in
// `cdict` and words in `word_dict`.  Populates widx2cidx (word -> cluster,
// -1 = unassigned), widx2cwidx (word -> row within its cluster),
// cidx2words (cluster -> word ids), and the singleton_cluster flags.
// Throws via DYNET_INVALID_ARG on a missing file or a malformed line.
void ClassFactoredSoftmaxBuilder::read_cluster_file(const std::string& cluster_file, Dict& word_dict) {
  cerr << "Reading clusters from " << cluster_file << " ...\n";
  ifstream in(cluster_file);
  if(!in)
    DYNET_INVALID_ARG("Could not find cluster file " << cluster_file << " in ClassFactoredSoftmax");
  int wc = 0;
  string line;
  while(getline(in, line)) {
    ++wc;
    const unsigned len = line.size();
    // Scan "<ws>cluster<ws>word".  Bounds-check BEFORE indexing: the
    // original tested line[i] first and relied on the special case that
    // std::string::operator[](size()) returns '\0'.
    unsigned startc = 0;
    while (startc < len && is_ws(line[startc])) { ++startc; }
    unsigned endc = startc;
    while (endc < len && not_ws(line[endc])) { ++endc; }
    unsigned startw = endc;
    while (startw < len && is_ws(line[startw])) { ++startw; }
    unsigned endw = startw;
    while (endw < len && not_ws(line[endw])) { ++endw; }
    // Both tokens must be non-empty (this also rejects blank lines).
    if(endc <= startc || startw <= endc || endw <= startw)
      DYNET_INVALID_ARG("Invalid format in cluster file " << cluster_file << " in ClassFactoredSoftmax");
    unsigned c = cdict.convert(line.substr(startc, endc - startc));
    unsigned word = word_dict.convert(line.substr(startw, endw - startw));
    if (word >= widx2cidx.size()) {
      widx2cidx.resize(word + 1, -1);  // -1 marks "not in any cluster"
      widx2cwidx.resize(word + 1);
    }
    widx2cidx[word] = c;
    if (c >= cidx2words.size()) cidx2words.resize(c + 1);
    auto& clusterwords = cidx2words[c];
    widx2cwidx[word] = clusterwords.size();  // row of this word inside its cluster
    clusterwords.push_back(word);
  }
  // Mark clusters with <= 1 word; their word-level parameters are skipped.
  singleton_cluster.resize(cidx2words.size());
  int scs = 0;
  for (unsigned i = 0; i < cidx2words.size(); ++i) {
    bool sc = cidx2words[i].size() <= 1;
    if (sc) scs++;
    singleton_cluster[i] = sc;
  }
  cerr << "Read " << wc << " words in " << cdict.size() << " clusters (" << scs << " singleton clusters)\n";
}
// Eagerly materializes every per-cluster expression: the get_rc2w /
// get_rc2wbias accessors create the Expression for a cluster on first use,
// so calling them here forces creation for all clusters up front.
void ClassFactoredSoftmaxBuilder::initialize_expressions() {
  const unsigned num_clusters = p_rc2ws.size();
  for (unsigned cluster = 0; cluster < num_clusters; ++cluster) {
    get_rc2w(cluster);
    get_rc2wbias(cluster);
  }
}
} // namespace dynet