-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathexec.h
More file actions
90 lines (81 loc) · 3.82 KB
/
exec.h
File metadata and controls
90 lines (81 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#ifndef DYNET_EXEC_H
#define DYNET_EXEC_H
#include "dynet/dynet.h"
namespace dynet {
/// Abstract interface for evaluating a ComputationGraph.
///
/// An engine computes node values (forward), node gradients (backward), and
/// can drop its cached results (invalidate). Concrete engines derive from this
/// class; it is constructed only by subclasses (protected constructor).
/// NOTE(review): the engine stores a reference to the graph, so the graph must
/// outlive the engine.
class ExecutionEngine {
 public:
  virtual ~ExecutionEngine();

  /// Drop all cached evaluation results.
  virtual void invalidate() = 0;
  /// Drop cached results; the argument's exact meaning is defined by each
  /// implementation (presumably a node index to invalidate from — confirm).
  virtual void invalidate(unsigned i) = 0;

  /// Run the forward pass; returns a tensor owned by the engine.
  virtual const Tensor& forward() = 0;
  /// Run the forward pass for node `i`; returns a tensor owned by the engine.
  virtual const Tensor& forward(VariableIndex i) = 0;
  /// Forward on multiple nodes; one tensor pointer per requested index.
  /// Non-pure: a default implementation is provided out of line.
  virtual std::vector<const Tensor*> forward(std::vector<VariableIndex> is);

  /// Use this when nodes have been added to the graph and only the new
  /// parts need to be evaluated.
  virtual const Tensor& incremental_forward() = 0;
  /// Incremental forward up to node `i`.
  virtual const Tensor& incremental_forward(VariableIndex i) = 0;

  /// Cached forward value of node `i`.
  virtual const Tensor& get_value(VariableIndex i) = 0;
  /// Gradient associated with node `i`.
  virtual const Tensor& get_gradient(VariableIndex i) = 0;

  /// Backward pass; the effect of `full` is implementation-defined.
  virtual void backward(bool full = false) = 0;
  /// Backward pass starting from node `i`.
  virtual void backward(VariableIndex i, bool full = false) = 0;

 protected:
  explicit ExecutionEngine(const ComputationGraph& graph) : cg(graph) {}

  // Graph being evaluated (non-owning reference).
  const ComputationGraph& cg;
  // Starts at 0; maintained by subclasses.
  VariableIndex backward_computed{0};
};
class SimpleExecutionEngine : public ExecutionEngine {
public:
explicit SimpleExecutionEngine(const ComputationGraph& cg) : ExecutionEngine(cg), num_nodes_evaluated(0) {}
void invalidate() override;
void invalidate(unsigned i) override;
const Tensor& forward() override;
const Tensor& forward(VariableIndex i) override;
const Tensor& incremental_forward() override; // if you want to add nodes and evaluate just the new parts
const Tensor& incremental_forward(VariableIndex i) override;
const Tensor& get_value(VariableIndex i) override;
const Tensor& get_gradient(VariableIndex i) override;
void backward(bool full = false) override;
void backward(VariableIndex i, bool full = false) override;
private:
std::vector<Tensor> nfxs;
std::vector<Tensor> ndEdfs;
VariableIndex num_nodes_evaluated;
};
/// Bookkeeping record for one batch; BatchedExecutionEngine keeps one per
/// batch in its `batches` vector.
struct BatchInfo {
  BatchInfo() {}

  // The forward tensor; may be null if the batch is a singleton.
  Tensor nfx;
  // Pseudo node used for the batched calculation; nullptr if not needed.
  // NOTE(review): raw pointer, not released by this struct — whoever sets it
  // is responsible for freeing it; confirm against exec.cc.
  Node* pseudo_node = nullptr;
  // IDs of the nodes that make up this batch.
  std::vector<VariableIndex> ids;
  // Concat mode codes: 0 = no need to concat, 1 = need to concat,
  // 2 = need to concat but already contiguous in memory.
  // Presumably one entry per argument — confirm.
  std::vector<int> concat;
  // Argument tensors (concatenated where required).
  std::vector<const Tensor*> arg_nfxs;
};
class BatchedExecutionEngine : public ExecutionEngine {
public:
explicit BatchedExecutionEngine(const ComputationGraph& cg) : ExecutionEngine(cg), num_nodes_evaluated(0), num_batches_evaluated(0) { }
~BatchedExecutionEngine() { garbage_collect(); }
void invalidate() override;
void invalidate(unsigned i) override;
const Tensor& forward() override;
const Tensor& forward(VariableIndex i) override;
const Tensor& incremental_forward() override; // if you want to add nodes and evaluate just the new parts
const Tensor& incremental_forward(VariableIndex i) override;
const Tensor& get_value(VariableIndex i) override;
const Tensor& get_gradient(VariableIndex i) override;
void backward(bool full = false) override;
void backward(VariableIndex i, bool full = false) override;
void garbage_collect();
private:
const Tensor& incremental_forward_no_update(VariableIndex i, int autobatch_strategy);
void combine_tensors(std::vector<VariableIndex> batch_ids, int aid, Tensor &tout);
void accumulate_tensors(const Tensor& my_ndEdf, std::vector<VariableIndex> batch_ids, int aid);
const Tensor& get_nfx(VariableIndex i);
std::vector<Tensor> nfx_cache;
std::vector<Tensor> ndEdfs;
VariableIndex num_nodes_evaluated, num_batches_evaluated;
// Information about the batched computation graph
std::vector<VariableIndex> node2batch; // length: number of nodes
std::vector<size_t> node2offset, node2size; // length: number of nodes
std::vector<BatchInfo> batches; // length: number of batches
SigMap sigmap;
};
} // namespace dynet
#endif