-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathcuda-matrix-multiply.h
More file actions
43 lines (35 loc) · 1.3 KB
/
cuda-matrix-multiply.h
File metadata and controls
43 lines (35 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#ifndef DYNET_CUDA_MATRIX_MULTIPLY_H__
#define DYNET_CUDA_MATRIX_MULTIPLY_H__
#ifdef __CUDACC__
#include "dynet/tensor.h"
#include "dynet/devices.h"
#include "dynet/cuda.h"
namespace dynet {

// Computes y = l * r + beta * y on the GPU with cuBLAS SGEMM.
// This follows cuBLAS's C = alpha * op(A) * op(B) + beta * C form, with no
// transposes, alpha = kSCALAR_ONE and beta = *acc_scalar.
//
// Operands are handed to cuBLAS as column-major matrices whose leading
// dimension equals their row count. So each batch element is assumed to be a
// dense rows() x cols() block.
//
// Parameters:
//   dev        - GPU device; its cublas_handle issues every GEMM.
//   l          - left operand, [x, y] per batch element.
//   r          - right operand, [y, z] per batch element.
//   y          - output, [x, z] per batch element. cuBLAS reads its existing
//                contents as C, scaled by beta.
//   acc_scalar - pointer to beta, the scale applied to the existing y.
//                NOTE(review): whether this (and kSCALAR_ONE) must be a host
//                or device pointer depends on the pointer mode configured on
//                dev.cublas_handle, which is not visible here. Confirm.
//
// Batching:
//   * l has one batch element and r's batch count equals y's: the batches of
//     r and y are folded into the column dimension and a single GEMM is
//     issued.
//   * otherwise: one GEMM per batch element of y. In this case r must have
//     either one batch element or the same number as l.
inline void CUDAMatrixMultiply(const Device_GPU & dev, const Tensor& l, const Tensor& r, Tensor& y, const float* acc_scalar) {
  if(l.d.bd == 1 && r.d.bd == y.d.bd) {
    // If the left side has one batch, multiply by columns
    // [x, z, b] = [x, y] * [y, z, b]
    // -> [x, z*b] = [x, y], [y, z*b]
    // This assumes consecutive batch elements of r and y are stored
    // back-to-back, so in column-major order they form one wide matrix.
    CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
          y.d.rows(), y.d.cols() * y.d.batch_elems(), l.d.cols(),
          kSCALAR_ONE,
          l.v, l.d.rows(),
          r.v, r.d.rows(),
          acc_scalar, y.v, y.d.rows()));
  } else {
    // Otherwise, loop over the batches.
    // r must either be a single (broadcast) batch element or match l's batch
    // count exactly. The check must use ==. With != the condition holds for
    // almost every input, and mismatched batch sizes would go unreported.
    DYNET_ASSERT(r.d.bd == 1 || r.d.bd == l.d.bd,
                 "Number of batch elements in matrix multiply must match, but got: " << r.d.bd << ", " << l.d.bd);
    for(unsigned b = 0; b < y.d.bd; ++b) {
      // NOTE(review): this relies on Tensor::batch_ptr mapping every b to the
      // single batch element of an operand whose bd == 1. Confirm in
      // dynet/tensor.h.
      CUBLAS_CHECK(cublasSgemm(dev.cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
            y.d.rows(), y.d.cols(), l.d.cols(),
            kSCALAR_ONE,
            l.batch_ptr(b), l.d.rows(),
            r.batch_ptr(b), r.d.rows(),
            acc_scalar, y.batch_ptr(b), y.d.rows()));
    }
  }
}

}
#endif
#endif