From f901b5cc8c2f9e7638876c6f3ba868011f25cca9 Mon Sep 17 00:00:00 2001 From: gongchensu Date: Tue, 17 Mar 2026 02:42:44 +0000 Subject: [PATCH] feat(ops): add MetaX backend for `RmsNorm` - add MetaX `RmsNorm` operator specialization - make the shared CUDA-style rms_norm kernel compatible with MetaX - forward runtime `eps` when launching the kernel --- src/base/rms_norm.h | 1 + src/cuda/rms_norm/kernel.cuh | 39 ++++++++++++++++++------------------ src/cuda/rms_norm/kernel.h | 18 +++++++---------- src/metax/rms_norm/kernel.h | 31 ++++++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 30 deletions(-) create mode 100644 src/metax/rms_norm/kernel.h diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h index 3b40a1c..65f44b3 100644 --- a/src/base/rms_norm.h +++ b/src/base/rms_norm.h @@ -25,6 +25,7 @@ class RmsNorm : public Operator { RmsNorm(const Tensor input, const Tensor weight, Tensor out) : RmsNorm{input, weight, 1e-6f, out} {} + // TODO: Type of `eps` should be `std::optional<float>` instead of `float`. 
virtual void operator()(const Tensor input, const Tensor weight, float eps, Tensor out) const = 0; diff --git a/src/cuda/rms_norm/kernel.cuh b/src/cuda/rms_norm/kernel.cuh index 98383f3..10228a6 100644 --- a/src/cuda/rms_norm/kernel.cuh +++ b/src/cuda/rms_norm/kernel.cuh @@ -1,39 +1,39 @@ #ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_ #define INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_ -#include -#include - #include #include #include +#include "common/cuda/cast.h" +#include "common/cuda/kernel_commons.h" + namespace infini::ops { namespace { -template -__device__ __forceinline__ Compute SumSquared(const Data* data_ptr, - size_t count) { - Compute ss = 0; +template +__device__ __forceinline__ TCompute SumSquared(const TData* data_ptr, + size_t count) { + TCompute ss = 0; for (size_t i = threadIdx.x; i < count; i += block_size) { - Compute val = Compute(data_ptr[i]); - ss += val * val; + TCompute value = Cast(data_ptr[i]); + ss += value * value; } - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; return BlockReduce(temp_storage).Sum(ss); } } // namespace -template -__global__ void RmsNormKernel(Data* __restrict__ y, int64_t stride_y_batch, +template +__global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch, int64_t stride_y_nhead, - const Data* __restrict__ x, + const TData* __restrict__ x, int64_t stride_x_batch, int64_t stride_x_nhead, - const Weight* __restrict__ w, size_t nhead, + const TWeight* __restrict__ w, size_t nhead, size_t dim, float epsilon) { size_t batch_idx = blockIdx.x / nhead; size_t head_idx = blockIdx.x % nhead; @@ -42,16 +42,17 @@ __global__ void RmsNormKernel(Data* __restrict__ y, int64_t stride_y_batch, auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead; auto w_ptr = w; - Compute ss = SumSquared(x_ptr, dim); + TCompute ss = SumSquared(x_ptr, dim); - __shared__ Compute rms; + __shared__ TCompute rms; if (threadIdx.x == 0) { - 
rms = Compute(rsqrtf(ss / Compute(dim) + epsilon)); + rms = Cast(rsqrtf(ss / Cast(dim) + epsilon)); } __syncthreads(); for (size_t i = threadIdx.x; i < dim; i += block_size) { - y_ptr[i] = Data(Compute(x_ptr[i]) * Compute(w_ptr[i]) * rms); + y_ptr[i] = + Cast(Cast(x_ptr[i]) * Cast(w_ptr[i]) * rms); } } diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index f450e91..dc28ee5 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -3,10 +3,6 @@ #include -// clang-format off -#include // TODO: Remove this -// clang-format on - #include "base/rms_norm.h" #include "common/cuda/kernel_commons.h" #include "cuda/rms_norm/kernel.cuh" @@ -45,13 +41,13 @@ class CudaRmsNorm : public RmsNorm { [&](auto tag) { using T = typename decltype(tag)::type; -#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ - RmsNormKernel \ - <<>>( \ - reinterpret_cast(out.data()), stride_out_batch, \ - stride_out_nhead, reinterpret_cast(input.data()), \ - stride_input_batch, stride_input_nhead, \ - reinterpret_cast(weight.data()), nhead_, dim_, eps_); +#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ + RmsNormKernel \ + <<>>( \ + reinterpret_cast(out.data()), stride_out_batch, \ + stride_out_nhead, reinterpret_cast(input.data()), \ + stride_input_batch, stride_input_nhead, \ + reinterpret_cast(weight.data()), nhead_, dim_, eps); if (block_size == CUDA_BLOCK_SIZE_2048) { LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048) diff --git a/src/metax/rms_norm/kernel.h b/src/metax/rms_norm/kernel.h new file mode 100644 index 0000000..b724552 --- /dev/null +++ b/src/metax/rms_norm/kernel.h @@ -0,0 +1,31 @@ +#ifndef INFINI_OPS_METAX_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_METAX_RMS_NORM_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/rms_norm/kernel.h" + +namespace infini::ops { + +namespace rms_norm { + +struct MetaxBackend { + using stream_t = mcStream_t; +}; + +} // namespace rms_norm + +template <> +class Operator + : public CudaRmsNorm { + 
public: + using CudaRmsNorm::CudaRmsNorm; +}; + +} // namespace infini::ops + +#endif