Skip to content

Feat/dev datatype#22

Open
bitzyz wants to merge 2 commits intofeat/dev-infrafrom
feat/dev-datatype
Open

Feat/dev datatype#22
bitzyz wants to merge 2 commits intofeat/dev-infrafrom
feat/dev-datatype

Conversation

@bitzyz
Copy link

@bitzyz bitzyz commented Mar 16, 2026

  1. ✅Moved device-specific cast.h/common.h files from common/ to respective device folders (cuda/, cpu/, cambricon/)
  2. ✅Introduced ValueTag and TypeTag templates for compile-time type-safe runtime dispatch
  3. ✅Enhanced DataType struct with device-aware type mappings
  4. ✅Updated all kernel files to use the new dispatch mechanism

@bitzyz bitzyz self-assigned this Mar 16, 2026
@bitzyz bitzyz force-pushed the feat/dev-datatype branch 2 times, most recently from d479f8c to 7cba646 Compare March 16, 2026 02:28
@bitzyz
Copy link
Author

bitzyz commented Mar 16, 2026

image

@bitzyz
Copy link
Author

bitzyz commented Mar 16, 2026

image

@bitzyz
Copy link
Author

bitzyz commented Mar 16, 2026

寒武纪平台编译正常
image

@bitzyz bitzyz requested a review from voltjia March 16, 2026 02:45
@bitzyz
Copy link
Author

bitzyz commented Mar 16, 2026

沐曦平台:
image

@bitzyz bitzyz force-pushed the feat/dev-datatype branch from 7cba646 to db13a72 Compare March 16, 2026 03:15
@bitzyz bitzyz force-pushed the feat/dev-datatype branch from db13a72 to 1c72cfe Compare March 16, 2026 03:26
@bitzyz
Copy link
Author

bitzyz commented Mar 16, 2026

天数平台:
img_v3_02vr_4c1517ed-6d42-45f8-a08d-0f380d8d416g


#include "base/gemm.h"
#include "cambricon/common.h"
#include "../common.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Image


#include <utility>

#include "../cast.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Image


#include <cmath>

#include "../cast.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Image


#include <utility>

#include "../cast.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Image


#include <cmath>

#include "../cast.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Image

: false));

if (!handled) {
// TODO(lzm): change to logging.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些 TODO 咋给去掉了。

};

template <DataType dtype>
// Forward declaration.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几处新加的注释可以去掉。如果决定不去掉需要用 Markdown。

using TypeMapTypeDevice = typename TypeMap<dtype, device>::type;

template <typename T>
template <typename T, Device::Type D>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Image

按照 style guide,这个地方的 D 应该改成 d,因为它应该遵循变量命名,而不是类型命名。下同。


def _torch_rms_norm(input, weight, *, eps=1e-6, out=None):
return torch.nn.functional.rms_norm(input, input.shape[-1:], weight=weight, eps=eps)

No newline at end of file
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方是不是多加了空格,有过 ruff format && ruff check 嘛?

| `-DWITH_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n
| `-DWITH_METAX=[ON\|OFF]` | Compile the MetaX implementation | n
| `-DGENERATE_PYTHON_BINDINGS=[ON\|OFF]` | Generate Python bindings | n
| Option | Functionality | Default
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最右列要不再右移一格,这样离最近的多一个空格。

void operator()(const Tensor input, const Tensor other,
Tensor out) const override {
DispatchFunc<AllTypes>(
DispatchFunc<Backend::device_value, AllTypes>(
Copy link

@Ziminli Ziminli Mar 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

啊,这样能编过的么?所以这里的意图是准备只在这个平台分发这些类型吗?按理来说之前其他平台倒是都不需要这样写,毕竟因为只会编译自己所在的平台目录下的代码嘛。目前如果想做多类型混合分发的话,应该是类似这样:

DispatchFunc<FloatTypes, List<Device::Type::kCpu, Device::Type::kNvidia>>(
      {static_cast<int64_t>(DataType::kFloat32),
       static_cast<int64_t>(Device::Type::kNvidia)},
      0,
      [](auto list_tag) {
        constexpr DataType DT = static_cast<DataType>(ListGet<0>(list_tag));
        constexpr Device::Type Dev =
            static_cast<Device::Type>(ListGet<1>(list_tag));
        using T = TypeMapType<DT>;
        // Remaining logic ...
      },
      "MixedDispatch", List<>{});

之后我会加一个方便一点的高层封装,就可以这样:

DispatchFunc<FloatTypes, List<Device::Type::kCpu, Device::Type::kNvidia>>(
      {static_cast<int64_t>(DataType::kFloat32),
       static_cast<int64_t>(Device::Type::kNvidia)},
      [](auto list_tag) {
        constexpr DataType DT = static_cast<DataType>(ListGet<0>(list_tag));
        constexpr Device::Type Dev =
            static_cast<Device::Type>(ListGet<1>(list_tag));
        using T = TypeMapType<DT>;
        // Remaining logic ...
      },
      "MixedDispatch");

本 PR 其他同类地方同理。

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原来如此,那这个文件可以还原,原本的 dispatcher 就是支持任意类型和多类型混合分发的。

Comment on lines +213 to +228
template <>
struct TypeMap<DataType::kFloat16, Device::Type::kNvidia> {
using type = half;
};
template <>
struct DataTypeMap<half, Device::Type::kNvidia> {
static constexpr DataType value = DataType::kFloat16;
};
template <>
struct TypeMap<DataType::kBFloat16, Device::Type::kNvidia> {
using type = __nv_bfloat16;
};
template <>
struct DataTypeMap<__nv_bfloat16, Device::Type::kNvidia> {
static constexpr DataType value = DataType::kBFloat16;
};
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些应该也是为了配合 dispatcher 的改动?应该可以还原了。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants