Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/base/rms_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class RmsNorm : public Operator<RmsNorm> {
RmsNorm(const Tensor input, const Tensor weight, Tensor out)
: RmsNorm{input, weight, 1e-6f, out} {}

// TODO: Type of `eps` should be `std::optional<float>` instead of `float`.
virtual void operator()(const Tensor input, const Tensor weight, float eps,
Tensor out) const = 0;

Expand Down
39 changes: 20 additions & 19 deletions src/cuda/rms_norm/kernel.cuh
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
#ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_
#define INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstddef>
#include <cstdint>
#include <cub/block/block_reduce.cuh>

#include "common/cuda/cast.h"
#include "common/cuda/kernel_commons.h"

namespace infini::ops {

namespace {

// Block-wide sum of squares of `data_ptr[0..count)`, accumulated in TCompute.
//
// Each thread forms a strided partial sum (thread t handles elements
// t, t + block_size, t + 2*block_size, ...), converting each element with
// Cast<TCompute> before squaring, then the partials are combined with
// cub::BlockReduce. Per cub::BlockReduce::Sum semantics, only thread 0
// receives the valid block aggregate; other threads' return values are
// undefined.
//
// NOTE(review): assumes the kernel is launched with blockDim.x == block_size
// -- confirm at the call sites.
template <unsigned int block_size, typename TData, typename TCompute>
__device__ __forceinline__ TCompute SumSquared(const TData* data_ptr,
                                               size_t count) {
  TCompute ss = 0;
  for (size_t i = threadIdx.x; i < count; i += block_size) {
    TCompute value = Cast<TCompute>(data_ptr[i]);
    ss += value * value;
  }
  using BlockReduce = cub::BlockReduce<TCompute, block_size>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  return BlockReduce(temp_storage).Sum(ss);
}

} // namespace

template <unsigned int block_size, typename Compute, typename Data,
typename Weight>
__global__ void RmsNormKernel(Data* __restrict__ y, int64_t stride_y_batch,
template <unsigned int block_size, typename TCompute, typename TData,
typename TWeight>
__global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch,
int64_t stride_y_nhead,
const Data* __restrict__ x,
const TData* __restrict__ x,
int64_t stride_x_batch, int64_t stride_x_nhead,
const Weight* __restrict__ w, size_t nhead,
const TWeight* __restrict__ w, size_t nhead,
size_t dim, float epsilon) {
size_t batch_idx = blockIdx.x / nhead;
size_t head_idx = blockIdx.x % nhead;
Expand All @@ -42,16 +42,17 @@ __global__ void RmsNormKernel(Data* __restrict__ y, int64_t stride_y_batch,
auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead;
auto w_ptr = w;

Compute ss = SumSquared<block_size, Data, Compute>(x_ptr, dim);
TCompute ss = SumSquared<block_size, TData, TCompute>(x_ptr, dim);

__shared__ Compute rms;
__shared__ TCompute rms;
if (threadIdx.x == 0) {
rms = Compute(rsqrtf(ss / Compute(dim) + epsilon));
rms = Cast<TCompute>(rsqrtf(ss / Cast<TCompute>(dim) + epsilon));
}
__syncthreads();

for (size_t i = threadIdx.x; i < dim; i += block_size) {
y_ptr[i] = Data(Compute(x_ptr[i]) * Compute(w_ptr[i]) * rms);
y_ptr[i] =
Cast<TData>(Cast<TCompute>(x_ptr[i]) * Cast<TCompute>(w_ptr[i]) * rms);
}
}

Expand Down
18 changes: 7 additions & 11 deletions src/cuda/rms_norm/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@

#include <cstdint>

// clang-format off
#include <cuda_runtime.h> // TODO: Remove this
// clang-format on

#include "base/rms_norm.h"
#include "common/cuda/kernel_commons.h"
#include "cuda/rms_norm/kernel.cuh"
Expand Down Expand Up @@ -45,13 +41,13 @@ class CudaRmsNorm : public RmsNorm {
[&](auto tag) {
using T = typename decltype(tag)::type;

#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \
RmsNormKernel<BLOCK_SIZE, float, T, T> \
<<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<T*>(out.data()), stride_out_batch, \
stride_out_nhead, reinterpret_cast<const T*>(input.data()), \
stride_input_batch, stride_input_nhead, \
reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps_);
#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \
RmsNormKernel<BLOCK_SIZE, float, T, T> \
<<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<T*>(out.data()), stride_out_batch, \
stride_out_nhead, reinterpret_cast<const T*>(input.data()), \
stride_input_batch, stride_input_nhead, \
reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥把 eps_ 改成了 eps,但是其他的没变?这处更改是必要的嘛?


if (block_size == CUDA_BLOCK_SIZE_2048) {
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048)
Expand Down
31 changes: 31 additions & 0 deletions src/metax/rms_norm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef INFINI_OPS_METAX_RMS_NORM_KERNEL_H_
#define INFINI_OPS_METAX_RMS_NORM_KERNEL_H_

#include <utility>

// clang-format off
#include <mcr/mc_runtime.h>
// clang-format on

#include "cuda/rms_norm/kernel.h"

namespace infini::ops {

namespace rms_norm {

// Backend tag for the MetaX platform: supplies the stream type that the
// shared CudaRmsNorm implementation is parameterized on, mapping it to
// MetaX's mcStream_t.
struct MetaxBackend {
  using stream_t = mcStream_t;
};

} // namespace rms_norm

// Specializes the RmsNorm operator for the MetaX device by reusing the CUDA
// implementation verbatim, parameterized with MetaxBackend so stream handles
// are mcStream_t. Constructors are inherited unchanged from CudaRmsNorm.
template <>
class Operator<RmsNorm, Device::Type::kMetax>
    : public CudaRmsNorm<rms_norm::MetaxBackend> {
 public:
  using CudaRmsNorm<rms_norm::MetaxBackend>::CudaRmsNorm;
};

} // namespace infini::ops

#endif