
[CUDA] Quantized GEMV #3180

Merged: zcbenz merged 1 commit into ml-explore:main from zcbenz:qmv on Mar 3, 2026

Conversation

@zcbenz (Collaborator) commented on Feb 27, 2026

Refs #2536.

Implements a qmv kernel that uses CUTLASS to do vectorized dequantization and fused multiply-add, and works for all quant types.
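
The core loop body looks roughly like this. A minimal sketch, assuming an affine quant with per-group scale and bias; the function name, float accumulator, and exact shapes are illustrative, not the PR's actual kernel:

```cpp
#include <cutlass/array.h>
#include <cutlass/functional.h>
#include <cutlass/numeric_conversion.h>

// `Q` is the quantized storage type (e.g. int8_t), `N` is the vector width.
template <typename Q, int N>
__device__ void dequant_fma_sketch(
    const cutlass::Array<Q, N>& w_vec,      // packed quantized weights
    const cutlass::Array<float, N>& x_vec,  // activations
    cutlass::Array<float, N>& acc,          // running accumulator
    float scale,
    float bias) {
  // Vectorized dequantization: one converter call handles all N elements.
  cutlass::NumericArrayConverter<float, Q, N> converter;
  cutlass::Array<float, N> w_dq = converter(w_vec);

  cutlass::Array<float, N> s, b;
  s.fill(scale);
  b.fill(bias);

  // Affine dequantization (w = q * scale + bias), then accumulate x * w.
  cutlass::multiply_add<cutlass::Array<float, N>> fma;
  w_dq = fma(w_dq, s, b);
  acc = fma(x_vec, w_dq, acc);
}
```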

This kernel is fast for small problems with FP32xINT8, measured on an A100:

| M | N | K | QMV (TFlop/s) | CUBLAS (TFlop/s) | QMV (GiB/s) | CUBLAS (GiB/s) | Speedup (x) |
|---|---|---|---|---|---|---|---|
| 1 | 4096 | 4096 | 2.40 | 0.71 | 1351.7 | 1414.2 | 3.39 |
| 1 | 8192 | 8192 | 2.26 | 0.85 | 1272.7 | 1694.4 | 2.67 |
| 1 | 16384 | 16384 | 2.47 | 0.87 | 1389.7 | 1749.5 | 2.82 |

The memory bandwidth is somehow lower for FP16xINT8:

| M | N | K | QMV (TFlop/s) | CUBLAS (TFlop/s) | QMV (GiB/s) | CUBLAS (GiB/s) | Speedup (x) |
|---|---|---|---|---|---|---|---|
| 1 | 4096 | 4096 | 2.10 | 1.47 | 1118.1 | 1466.7 | 1.43 |
| 1 | 8192 | 8192 | 1.95 | 1.49 | 1039.1 | 1494.6 | 1.31 |

Independent C++ source code for profiling the kernel

The memory bandwidth unfortunately drops to half for FP8/FP4/INT4 quants, likely because CUTLASS does not implement fast vectorized conversions for them. We can fix this by writing specializations of dequant_fma; I'll continue in follow-up PRs.
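
As a rough illustration of what such a specialization could do (the signature and name below are hypothetical, not the real dequant_fma interface in MLX), a 4-bit path can unpack nibbles with plain shifts and masks instead of going through the generic per-element converter:

```cpp
#include <cstdint>

// Hypothetical fast path for 4-bit affine quants: eight weights packed into
// one uint32_t, unpacked with shifts/masks the compiler can lower to cheap
// bit ops. NOT the actual MLX dequant_fma; signature and names are assumed.
__device__ void dequant_fma_int4_sketch(
    uint32_t w_packed,  // eight 4-bit unsigned weights
    const float* x,     // eight activations
    float& acc,         // running dot-product accumulator
    float scale,
    float bias) {
#pragma unroll
  for (int i = 0; i < 8; ++i) {
    // Extract the i-th nibble, apply the affine mapping w = q * scale + bias.
    float q = static_cast<float>((w_packed >> (4 * i)) & 0xFu);
    acc = fmaf(x[i], fmaf(q, scale, bias), acc);
  }
}
```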

This PR also does some refactoring to dispatch quantized_matmul to the fastest kernel depending on the problem size. For now we still prefer fp_qmv over qmv for FP8/FP4 quants, but eventually I will merge fp_qmv into qmv.
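
Schematically, the dispatch looks something like the following (the enum, function names, and the M threshold are illustrative stand-ins, not the exact MLX internals):

```cpp
#include <cstdio>

// Illustrative quant categories; MLX's real representation differs.
enum class QuantType { int8, int4, fp8, fp4 };

bool is_fp_quant(QuantType qt) {
  return qt == QuantType::fp8 || qt == QuantType::fp4;
}

// Stubs standing in for the real kernel launches.
void launch_fp_qmv() { std::puts("fp_qmv"); }
void launch_qmv() { std::puts("qmv"); }
void launch_qmm() { std::puts("qmm"); }

// Pick the GEMV-style kernel for skinny (memory-bound) problems and fall
// back to the tiled GEMM path otherwise. The M <= 4 cutoff is a guess.
void dispatch_quantized_matmul(int M, QuantType qt) {
  if (M <= 4) {
    if (is_fp_quant(qt)) {
      launch_fp_qmv();  // still preferred over qmv for FP8/FP4, per the PR
    } else {
      launch_qmv();
    }
  } else {
    launch_qmm();
  }
}
```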

Comment thread on mlx/backend/cuda/quantized/qmm/qmm.cpp (outdated)
@jagrit06 (Member) commented:

Just so I understand the comparison: in table 1, cuBLAS is doing FP32xFP32 and in table 2 cuBLAS is doing FP16xFP16?

@zcbenz (Collaborator, Author) commented on Feb 27, 2026

> Just so I understand the comparison: in table 1, cuBLAS is doing FP32xFP32 and in table 2 cuBLAS is doing FP16xFP16?

Yeah, cuBLAS was measured with the activation dtype.

@zcbenz (Collaborator, Author) commented on Mar 2, 2026

There was a mistake: the native dtype (__half) was being passed to CUTLASS instead of the CUTLASS dtype (cutlass::half_t). Fixing it makes FP16xINT8 a lot faster:

| M | N | K | QMV (TFlop/s) | CUBLAS (TFlop/s) | QMV (GiB/s) | CUBLAS (GiB/s) | Speedup (x) |
|---|---|---|---|---|---|---|---|
| 1 | 4096 | 4096 | 2.64 | 1.42 | 1406.0 | 1420.9 | 1.86 |
| 1 | 8192 | 8192 | 2.37 | 1.48 | 1261.5 | 1477.7 | 1.61 |
| 1 | 16384 | 16384 | 2.55 | 1.69 | 1352.4 | 1691.0 | 1.51 |

The sub-byte quants are still slow though.
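
For reference, the distinction behind the fix as a small sketch (the width of 8 is illustrative): CUTLASS only picks its vectorized int8-to-half conversion when given its own numeric type, cutlass::half_t, rather than CUDA's builtin __half.

```cpp
#include <cstdint>
#include <cutlass/array.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

// Using cutlass::half_t (not __half) as the destination type lets CUTLASS
// select its fast int8 -> half NumericArrayConverter specialization.
__device__ void convert_int8_to_half(
    const cutlass::Array<int8_t, 8>& w,
    cutlass::Array<cutlass::half_t, 8>& out) {
  cutlass::NumericArrayConverter<cutlass::half_t, int8_t, 8> converter;
  out = converter(w);
}
```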

@jagrit06 (Member) left a review comment:


Alright, I think we should merge this to get affine quants working and build up from there.

The results in the table give reasonable memory bandwidth numbers for the device, so we can merge and then improve.

I will caution against comparing TFlop/s directly, though; not all ops are created equal, especially when our operation has a healthy amount of int/bit manipulation compared to the cuBLAS comparison point. I'd recommend you continue to focus on that, especially for the sub-byte types.

@angeloskath (Member) left a review comment:

Looks great! Just some minor comments.

Comment thread on mlx/backend/cuda/quantized/qmm/fp_qmv.cu (outdated)

```cpp
// Vectorized dequantization: convert N quantized values to float at once.
cutlass::NumericArrayConverter<float, Q, N> converter;
cutlass::Array<float, N> w_dq = converter(w_vec);
#pragma unroll
```
@angeloskath (Member) commented:

Probably like above, `w_dq = w_dq * scale + bias`.

@zcbenz (Collaborator, Author) replied:

For float it is slower; I think that's because there are no vectorized instructions for it and extra registers would be used.

@zcbenz merged commit 3c56543 into ml-explore:main on Mar 3, 2026
16 checks passed
@zcbenz deleted the qmv branch on March 3, 2026.