Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions infini_train/include/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@

#include "infini_train/include/datatype.h"

/**
 * General Utility Macros
 */
// Expands to its argument unchanged; used to force one additional
// macro-expansion pass on X.
#define EXPAND(X) X
// This macro lets you pass an arbitrary expression that may contain internal
// commas to another macro without having the commas causing the expression
// to be interpreted as being multiple arguments.
// Basically an alternative for __VA_OPT__ before C++20.
// ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Dispatch_v2.h
#define WRAP(...) __VA_ARGS__
// Token concatenation with one level of indirection so that macro arguments
// are fully expanded before being pasted together.
#define CAT(a, b) CAT_(a, b)
#define CAT_(a, b) a##b

// Integer ceiling division: CEIL_DIV(7, 2) == 4.
// Assumes y > 0 and that (x + y - 1) does not overflow.
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
// Streams MSG at severity LEVEL, appending the call-site file:line.
// NOTE(review): relies on a LOG(LEVEL) macro (presumably glog) being in
// scope at the point of use — confirm the include.
#define LOG_LOC(LEVEL, MSG) LOG(LEVEL) << MSG << " at " << __FILE__ << ":" << __LINE__
// Logs a FATAL message (with call-site file:line) naming an unsupported data
// type, prefixed by CONTEXT_IDENTIFIER (e.g. the kernel/function name).
// Fix: the body previously referenced `dtype` instead of the DTYPE parameter,
// so it only compiled when the caller happened to have a variable literally
// named `dtype` in scope.
#define LOG_UNSUPPORTED_DTYPE(DTYPE, CONTEXT_IDENTIFIER)                                                               \
    LOG_LOC(FATAL, WRAP(CONTEXT_IDENTIFIER << ": Unsupported data type: "                                              \
                            + kDataTypeToDesc.at(static_cast<infini_train::DataType>(DTYPE))))

inline std::vector<int64_t> ComputeStrides(const std::vector<int64_t> &dims) {
std::vector<int64_t> strides(dims.size(), 1);
Expand Down
35 changes: 28 additions & 7 deletions infini_train/include/common/cpu/common_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,41 @@
#include <type_traits>
#include <utility>

#include "infini_train/include/datatype.h"

namespace infini_train::common::cpu {

namespace detail {

// FP16/BF16 don't support implicit conversion, so we route through float.
// Core CPU scalar conversion. FP16/BF16 don't support implicit conversion,
// so conversions involving them are routed through float.
//
// Fix: the same-type branch previously did `return x;`, which always copies;
// forwarding lets rvalue arguments be moved into the return value.
template <typename DST, typename SRC> DST CastImpl(SRC &&x) {
    using SrcBase = std::remove_cvref_t<SRC>; // strip ref + cv to compare types (C++20)
    if constexpr (std::is_same_v<DST, SrcBase>) {
        // Same type: preserve the value category (move from rvalues).
        return std::forward<SRC>(x);
    } else if constexpr (std::is_same_v<DST, FP16> || std::is_same_v<DST, BF16>) {
        // Destination is a framework 16-bit type: convert via float
        return DST(static_cast<float>(std::forward<SRC>(x)));
    } else if constexpr (std::is_same_v<SrcBase, FP16> || std::is_same_v<SrcBase, BF16>) {
        // Source is a framework 16-bit type: widen to float first
        return static_cast<DST>(static_cast<float>(x));
    } else {
        // Ordinary arithmetic/user conversion.
        return static_cast<DST>(std::forward<SRC>(x));
    }
}

} // namespace detail

/**
 * Converts a value between arbitrary types, including framework FP16/BF16.
 * Perfect forwarding preserves the argument's value category
 * (lvalues/rvalues).
 *
 * @tparam DST Destination type (must be a non-reference object type)
 * @tparam SRC Source type (deduced)
 * @param x Input value
 * @return Value converted to DST type
 */
template <typename DST, typename SRC> DST Cast(SRC &&x) {
    static_assert(!std::is_reference_v<DST>, "Cast cannot return reference types");

    return detail::CastImpl<DST>(std::forward<SRC>(x));
}

} // namespace infini_train::common::cpu
1 change: 1 addition & 0 deletions infini_train/include/core/device_guard.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class DeviceGuardImpl {
// Device management
// ----------------------------------------------------------------------

// FIXME(dcj): impl should only bind with device type
virtual Device GetDevice() const = 0;

virtual void SetDevice(Device device) const;
Expand Down
122 changes: 122 additions & 0 deletions infini_train/include/core/dtype_bridge.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#pragma once

#include <cstdint>
#include <cstring>
#include <type_traits>

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

namespace infini_train::core {

/**
* Dtype bridge
*
* Purpose:
* - Define the backend-agnostic mapping protocol from framework scalar types
* (e.g. infini_train::FP16/BF16) to backend-native scalar types
* (e.g. __half / __nv_bfloat16 / vendor fp16/bf16 types).
*
* Design notes:
* - This header MUST remain backend-agnostic.
* - Framework public code should only depend on infini_train::FP16/BF16.
* - Backend code provides specializations of NativeScalar<Dev, Scalar>.
* - ScalarConvert provides optional value-level conversion helpers.
*/

// -----------------------------------------------------------------------------
// NativeScalar: framework scalar -> backend native scalar mapping
// -----------------------------------------------------------------------------
// Primary template intentionally undefined.
// Each backend specializes the scalar types it supports; instantiating it
// for an unsupported (device, scalar) pair is a compile-time error.
template <Device::DeviceType Dev, typename Scalar> struct NativeScalar;

// Shorthand for NativeScalar<Dev, Scalar>::type.
template <Device::DeviceType Dev, typename Scalar> using NativeScalar_t = typename NativeScalar<Dev, Scalar>::type;

// Optional convenience alias for CUDA call sites.
// Keep only one copy here; backend files should NOT redefine it.
template <typename Scalar> using NativeScalarCUDA_t = NativeScalar_t<Device::DeviceType::kCUDA, Scalar>;

// -----------------------------------------------------------------------------
// Bitcast utilities
// -----------------------------------------------------------------------------
/**
 * Reinterprets the object representation of `from` as a value of type `To`
 * (the pre-C++20 equivalent of std::bit_cast, implemented with memcpy so it
 * is well-defined — no strict-aliasing UB).
 *
 * @tparam To   Destination type; trivially copyable, same size as From
 * @tparam From Source type; trivially copyable
 * @param from  Value whose bits are copied
 * @return A To with the identical bit pattern
 */
template <typename To, typename From> [[nodiscard]] inline To Bitcast(const From &from) noexcept {
    static_assert(sizeof(To) == sizeof(From), "Bitcast requires same size");
    static_assert(std::is_trivially_copyable_v<To>, "Bitcast To must be trivially copyable");
    static_assert(std::is_trivially_copyable_v<From>, "Bitcast From must be trivially copyable");

    To to{};
    std::memcpy(&to, &from, sizeof(To));
    return to;
}

// -----------------------------------------------------------------------------
// HasNativeScalar: detect whether a NativeScalar specialization exists
// -----------------------------------------------------------------------------
template <Device::DeviceType Dev, typename Scalar, typename = void> struct HasNativeScalar : std::false_type {};

template <Device::DeviceType Dev, typename Scalar>
struct HasNativeScalar<Dev, Scalar, std::void_t<typename NativeScalar<Dev, Scalar>::type>> : std::true_type {};

template <Device::DeviceType Dev, typename Scalar>
inline constexpr bool HasNativeScalar_v = HasNativeScalar<Dev, Scalar>::value;

// -----------------------------------------------------------------------------
// ScalarConvert: framework scalar <-> backend native scalar conversion glue
// -----------------------------------------------------------------------------
// Primary template intentionally undefined by default.
// Backends may specialize this if simple bitcast is insufficient.
// The trailing `Enable` parameter leaves room for SFINAE-constrained
// partial specializations.
template <Device::DeviceType Dev, typename Scalar, typename Enable = void> struct ScalarConvert;

// Default FP16 glue: round-trips the raw 16-bit pattern between the framework
// type and the backend-native type — no numeric conversion takes place.
template <Device::DeviceType Dev> struct ScalarConvert<Dev, infini_train::FP16, void> {
    static_assert(HasNativeScalar_v<Dev, infini_train::FP16>,
                  "Missing NativeScalar specialization for FP16 on this backend");

    using Native = NativeScalar_t<Dev, infini_train::FP16>;

    // Framework -> backend-native: copy the stored bits verbatim.
    static inline Native ToNative(infini_train::FP16 value) noexcept {
        static_assert(sizeof(Native) == sizeof(uint16_t), "Native FP16 must be 16-bit");
        return Bitcast<Native>(value.x);
    }

    // Backend-native -> framework: copy the bits back into the payload field.
    static inline infini_train::FP16 FromNative(Native native) noexcept {
        static_assert(sizeof(Native) == sizeof(uint16_t), "Native FP16 must be 16-bit");
        infini_train::FP16 result{};
        result.x = Bitcast<uint16_t>(native);
        return result;
    }
};

// Default BF16 glue: round-trips the raw 16-bit pattern between the framework
// type and the backend-native type — no numeric conversion takes place.
template <Device::DeviceType Dev> struct ScalarConvert<Dev, infini_train::BF16, void> {
    static_assert(HasNativeScalar_v<Dev, infini_train::BF16>,
                  "Missing NativeScalar specialization for BF16 on this backend");

    using Native = NativeScalar_t<Dev, infini_train::BF16>;

    // Framework -> backend-native: copy the stored bits verbatim.
    static inline Native ToNative(infini_train::BF16 value) noexcept {
        static_assert(sizeof(Native) == sizeof(uint16_t), "Native BF16 must be 16-bit");
        return Bitcast<Native>(value.x);
    }

    // Backend-native -> framework: copy the bits back into the payload field.
    static inline infini_train::BF16 FromNative(Native native) noexcept {
        static_assert(sizeof(Native) == sizeof(uint16_t), "Native BF16 must be 16-bit");
        infini_train::BF16 result{};
        result.x = Bitcast<uint16_t>(native);
        return result;
    }
};

// -----------------------------------------------------------------------------
// Convenience wrappers
// -----------------------------------------------------------------------------
template <Device::DeviceType Dev, typename Scalar> inline NativeScalar_t<Dev, Scalar> ToNative(Scalar v) noexcept {
return ScalarConvert<Dev, Scalar>::ToNative(v);
}

template <Device::DeviceType Dev, typename Scalar> inline Scalar FromNative(NativeScalar_t<Dev, Scalar> v) noexcept {
return ScalarConvert<Dev, Scalar>::FromNative(v);
}

} // namespace infini_train::core
Loading
Loading