[CUDA] Quantized GEMV #3180
Conversation
Just so I understand the comparison: in table 1, cuBLAS is doing FP32xFP32 and in table 2 cuBLAS is doing FP16xFP16?
Yeah, cuBLAS was measured with the activation dtype.
There was a mistake that the native dtype (
The sub-byte quants are still slow though. |
jagrit06 left a comment
Alright, I think we should merge this to work for Affine quants and work up from there
The results in the table give reasonable memory bandwidth numbers for the device, so we can merge and then improve.
I will caution against comparing TFLOPS directly though: not all ops are created equal, especially when our operation has a healthy amount of int/bit manipulation compared to the cuBLAS comparison point. I'd recommend you continue to focus on that, especially on the sub-byte types.
angeloskath left a comment
Looks great! Just some minor comments.
cutlass::NumericArrayConverter<float, Q, N> converter;
cutlass::Array<float, N> w_dq = converter(w_vec);
#pragma unroll
Probably like above, `w_dq = w_dq * scale + bias`.
For float it is slower; I think it is because there are no vectorized instructions and new registers would be used.
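For reference, a minimal sketch of the two variants being discussed, reusing the names from the snippet above (`w_dq`, `scale`, `bias`); the exact kernel code may differ. It assumes the `cutlass::multiplies` / `cutlass::plus` functor specializations for `cutlass::Array`, which provide element-wise array-scalar overloads.

```cpp
#include <cutlass/array.h>
#include <cutlass/functional.h>

// Current form: explicit per-element fma, as in the #pragma unroll loop above.
template <int N>
__device__ inline void scale_bias_loop(
    cutlass::Array<float, N>& w_dq, float scale, float bias) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    w_dq[i] = fmaf(w_dq[i], scale, bias);
  }
}

// Suggested form: array-level w_dq = w_dq * scale + bias. For half types this
// can lower to packed instructions; for float it lowers to the same scalar
// FFMAs and may only add register pressure, as noted in the reply above.
template <int N>
__device__ inline void scale_bias_array(
    cutlass::Array<float, N>& w_dq, float scale, float bias) {
  cutlass::multiplies<cutlass::Array<float, N>> mul;
  cutlass::plus<cutlass::Array<float, N>> add;
  w_dq = add(mul(w_dq, scale), bias);
}
```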
Refs #2536.
Implements a qmv kernel using CUTLASS to do vectorized dequantization and fma, which works for all types of quants.
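As a rough illustration (not the actual kernel), the vectorized dequantize-and-fma step could look like the sketch below; the `dequant_fma` signature and the names `w_vec`, `x_vec`, `scale`, `bias` are assumptions based on the snippet quoted above.

```cpp
#include <cutlass/array.h>
#include <cutlass/numeric_conversion.h>

template <typename Q, int N>
__device__ inline float dequant_fma(
    float acc,
    const cutlass::Array<Q, N>& w_vec,      // packed quantized weights
    const cutlass::Array<float, N>& x_vec,  // activation vector fragment
    float scale,
    float bias) {
  // Vectorized conversion from the quantized type to float.
  cutlass::NumericArrayConverter<float, Q, N> converter;
  cutlass::Array<float, N> w_dq = converter(w_vec);
  // Affine dequantization followed by dot-product accumulation.
#pragma unroll
  for (int i = 0; i < N; ++i) {
    acc = fmaf(w_dq[i] * scale + bias, x_vec[i], acc);
  }
  return acc;
}
```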
This kernel is fast for small problems with FP32xINT8, measured on an A100:
The memory bandwidth is, for some reason, lower for FP16xINT8:
Independent C++ source code for profiling the kernel
The memory bandwidth unfortunately drops to half for FP8/FP4/INT4 quants, which is likely because CUTLASS does not implement fast vectorized conversions for them. We can fix it by writing specializations of `dequant_fma`, and I'll continue that in follow-up PRs.

This PR also does some refactoring to dispatch `quantized_matmul` to the fastest kernel depending on the problem size. For now we still prefer `fp_qmv` over `qmv` for FP8/FP4 quants, but eventually I will merge `fp_qmv` into `qmv`.
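A rough sketch of how that dispatch could be structured, under stated assumptions: the `Kernel`/`QuantType` enums, the `qmm` fallback name, and the `M == 1` GEMV check are illustrative, and the real routine also takes the overall problem size into account.

```cpp
enum class QuantType { Affine, INT8, INT4, FP8, FP4 };
enum class Kernel { qmv, fp_qmv, qmm };

// Choose a kernel for x (M x K) times a quantized matrix (K x N).
Kernel choose_kernel(int M, QuantType qt) {
  bool is_gemv = (M == 1);
  bool fp_quant = (qt == QuantType::FP8 || qt == QuantType::FP4);
  if (!is_gemv) {
    return Kernel::qmm;  // larger problems take the matmul path
  }
  // fp_qmv stays preferred for FP8/FP4 until qmv gets fast vectorized
  // conversions for those types.
  return fp_quant ? Kernel::fp_qmv : Kernel::qmv;
}
```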