From 2b0c9096320f011467d288402181d73efa36d690 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Fri, 6 Mar 2026 09:56:28 +0000 Subject: [PATCH] feat: add common BF16/FP16 abstraction infrastructure --- infini_train/include/common/common.h | 16 +- infini_train/include/common/cpu/common_cpu.h | 35 +- infini_train/include/core/device_guard.h | 1 + infini_train/include/core/dtype_bridge.h | 122 ++++++ infini_train/include/datatype.h | 354 ++++++++++++----- infini_train/include/dispatcher.h | 362 +----------------- infini_train/include/dtype_dispatch.h | 325 ++++++++++++++++ infini_train/include/tensor.h | 4 +- infini_train/include/tensor_fill_impl.h | 28 ++ infini_train/src/core/cpu/cpu_dispatch.h | 24 ++ infini_train/src/core/cuda/cuda_dispatch.h | 42 ++ .../src/core/cuda/cuda_dtype_bridge.h | 20 + infini_train/src/core/cuda/cuda_stream.cc | 4 +- infini_train/src/core/cuda/cuda_stream.h | 2 +- infini_train/src/kernels/cpu/cast.cc | 14 +- infini_train/src/kernels/cpu/elementwise.cc | 4 +- infini_train/src/kernels/cpu/embedding.cc | 2 +- infini_train/src/kernels/cpu/gather.cc | 2 +- infini_train/src/kernels/cpu/layernorm.cc | 10 +- infini_train/src/kernels/cpu/linear.cc | 4 +- infini_train/src/kernels/cpu/outer.cc | 4 +- infini_train/src/kernels/cpu/slice.cc | 2 +- infini_train/src/kernels/cpu/split.cc | 2 +- infini_train/src/kernels/cpu/transform.cc | 2 +- .../src/kernels/cuda/accumulate_grad.cu | 44 ++- infini_train/src/kernels/cuda/cast.cu | 14 +- infini_train/src/kernels/cuda/comm.cu | 2 +- infini_train/src/kernels/cuda/concat.cu | 25 +- .../src/kernels/cuda/cross_entropy.cu | 47 ++- infini_train/src/kernels/cuda/elementwise.cu | 18 +- infini_train/src/kernels/cuda/embedding.cu | 7 +- infini_train/src/kernels/cuda/fill.cu | 3 +- infini_train/src/kernels/cuda/gather.cu | 8 +- infini_train/src/kernels/cuda/layernorm.cu | 15 +- infini_train/src/kernels/cuda/linear.cu | 33 +- infini_train/src/kernels/cuda/outer.cu | 7 +- infini_train/src/kernels/cuda/reduction.cu | 7 +- 
infini_train/src/kernels/cuda/slice.cu | 11 +- infini_train/src/kernels/cuda/softmax.cu | 4 +- infini_train/src/kernels/cuda/split.cu | 7 +- infini_train/src/kernels/cuda/stack.cu | 8 +- infini_train/src/kernels/cuda/transform.cu | 35 +- .../cuda/vocab_parallel_cross_entropy.cu | 4 +- infini_train/src/nn/init.cc | 11 +- infini_train/src/nn/parallel/process_group.cc | 2 +- infini_train/src/optimizer.cc | 9 +- infini_train/src/tensor.cc | 37 +- test/hook/test_precision_check.cc | 10 +- 48 files changed, 1071 insertions(+), 682 deletions(-) create mode 100644 infini_train/include/core/dtype_bridge.h create mode 100644 infini_train/include/dtype_dispatch.h create mode 100644 infini_train/include/tensor_fill_impl.h create mode 100644 infini_train/src/core/cpu/cpu_dispatch.h create mode 100644 infini_train/src/core/cuda/cuda_dispatch.h create mode 100644 infini_train/src/core/cuda/cuda_dtype_bridge.h diff --git a/infini_train/include/common/common.h b/infini_train/include/common/common.h index b6a02543..80cba728 100644 --- a/infini_train/include/common/common.h +++ b/infini_train/include/common/common.h @@ -7,11 +7,21 @@ #include "infini_train/include/datatype.h" +/** + * General Utility Macros + */ +#define EXPAND(X) X +// This macro lets you pass an arbitrary expression that may contain internal +// commas to another macro without having the commas causing the expression +// to be interpreted as being multiple arguments +// Basically an alternative for __VA_OPT__ before C++20 +// ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Dispatch_v2.h +#define WRAP(...)
__VA_ARGS__ +#define CAT(a, b) CAT_(a, b) +#define CAT_(a, b) a##b + #define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) #define LOG_LOC(LEVEL, MSG) LOG(LEVEL) << MSG << " at " << __FILE__ << ":" << __LINE__ -#define LOG_UNSUPPORTED_DTYPE(DTYPE, CONTEXT_IDENTIFIER) \ - LOG_LOC(FATAL, WRAP(CONTEXT_IDENTIFIER << ": Unsupported data type: " \ - + kDataTypeToDesc.at(static_cast(dtype)))) inline std::vector ComputeStrides(const std::vector &dims) { std::vector strides(dims.size(), 1); diff --git a/infini_train/include/common/cpu/common_cpu.h b/infini_train/include/common/cpu/common_cpu.h index d4c73e84..b8a01538 100644 --- a/infini_train/include/common/cpu/common_cpu.h +++ b/infini_train/include/common/cpu/common_cpu.h @@ -3,20 +3,41 @@ #include #include +#include "infini_train/include/datatype.h" + namespace infini_train::common::cpu { + +namespace detail { + +// FP16/BF16 don't support implicit conversion, so we route through float. +template DST CastImpl(SRC &&x) { + using SrcBase = std::remove_cvref_t; + if constexpr (std::is_same_v) { + return x; + } else if constexpr (std::is_same_v || std::is_same_v) { + // Destination is a framework 16-bit type: convert via float + return DST(static_cast(std::forward(x))); + } else if constexpr (std::is_same_v || std::is_same_v) { + // Source is a framework 16-bit type: widen to float first + return static_cast(static_cast(x)); + } else { + return static_cast(std::forward(x)); + } +} + +} // namespace detail + /** - * Converts a value between arbitrary types. This offers perfect - * forwarding which preserves value categories (lvalues/rvalues) + * Converts a value between arbitrary types, including framework FP16/BF16. 
* - * @tparam DST Destination type (deduced) + * @tparam DST Destination type * @tparam SRC Source type (deduced) - * @param x Input value (preserves const/volatile and value category) + * @param x Input value * @return Value converted to DST type */ template DST Cast(SRC &&x) { static_assert(!std::is_reference_v, "Cast cannot return reference types"); - - // TODO(lzm): add cpu-version fp16 and bf16 - return (DST)(std::forward(x)); + return detail::CastImpl(std::forward(x)); } + } // namespace infini_train::common::cpu diff --git a/infini_train/include/core/device_guard.h b/infini_train/include/core/device_guard.h index 36945ea1..098a2b0b 100644 --- a/infini_train/include/core/device_guard.h +++ b/infini_train/include/core/device_guard.h @@ -67,6 +67,7 @@ class DeviceGuardImpl { // Device management // ---------------------------------------------------------------------- + // FIXME(dcj): impl should only bind with device type virtual Device GetDevice() const = 0; virtual void SetDevice(Device device) const; diff --git a/infini_train/include/core/dtype_bridge.h b/infini_train/include/core/dtype_bridge.h new file mode 100644 index 00000000..d9575496 --- /dev/null +++ b/infini_train/include/core/dtype_bridge.h @@ -0,0 +1,122 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" + +namespace infini_train::core { + +/** + * Dtype bridge + * + * Purpose: + * - Define the backend-agnostic mapping protocol from framework scalar types + * (e.g. infini_train::FP16/BF16) to backend-native scalar types + * (e.g. __half / __nv_bfloat16 / vendor fp16/bf16 types). + * + * Design notes: + * - This header MUST remain backend-agnostic. + * - Framework public code should only depend on infini_train::FP16/BF16. + * - Backend code provides specializations of NativeScalar. + * - ScalarConvert provides optional value-level conversion helpers. 
+ */ + +// ----------------------------------------------------------------------------- +// NativeScalar: framework scalar -> backend native scalar mapping +// ----------------------------------------------------------------------------- +// Primary template intentionally undefined. +// Each backend specializes the scalar types it supports. +template struct NativeScalar; + +template using NativeScalar_t = typename NativeScalar::type; + +// Optional convenience alias for CUDA call sites. +// Keep only one copy here; backend files should NOT redefine it. +template using NativeScalarCUDA_t = NativeScalar_t; + +// ----------------------------------------------------------------------------- +// Bitcast utilities +// ----------------------------------------------------------------------------- +template inline To Bitcast(const From &from) noexcept { + static_assert(sizeof(To) == sizeof(From), "Bitcast requires same size"); + static_assert(std::is_trivially_copyable_v, "Bitcast To must be trivially copyable"); + static_assert(std::is_trivially_copyable_v, "Bitcast From must be trivially copyable"); + + To to{}; + std::memcpy(&to, &from, sizeof(To)); + return to; +} + +// ----------------------------------------------------------------------------- +// HasNativeScalar: detect whether a NativeScalar specialization exists +// ----------------------------------------------------------------------------- +template struct HasNativeScalar : std::false_type {}; + +template +struct HasNativeScalar::type>> : std::true_type {}; + +template +inline constexpr bool HasNativeScalar_v = HasNativeScalar::value; + +// ----------------------------------------------------------------------------- +// ScalarConvert: framework scalar <-> backend native scalar conversion glue +// ----------------------------------------------------------------------------- +// Primary template intentionally undefined by default. +// Backends may specialize this if simple bitcast is insufficient. 
+template struct ScalarConvert; + +// Default FP16 conversion: preserve raw 16-bit bit pattern. +template struct ScalarConvert { + static_assert(HasNativeScalar_v, + "Missing NativeScalar specialization for FP16 on this backend"); + + using Native = NativeScalar_t; + + static inline Native ToNative(infini_train::FP16 v) noexcept { + static_assert(sizeof(Native) == sizeof(uint16_t), "Native FP16 must be 16-bit"); + return Bitcast(v.x); + } + + static inline infini_train::FP16 FromNative(Native v) noexcept { + infini_train::FP16 out{}; + static_assert(sizeof(Native) == sizeof(uint16_t), "Native FP16 must be 16-bit"); + out.x = Bitcast(v); + return out; + } +}; + +// Default BF16 conversion: preserve raw 16-bit bit pattern. +template struct ScalarConvert { + static_assert(HasNativeScalar_v, + "Missing NativeScalar specialization for BF16 on this backend"); + + using Native = NativeScalar_t; + + static inline Native ToNative(infini_train::BF16 v) noexcept { + static_assert(sizeof(Native) == sizeof(uint16_t), "Native BF16 must be 16-bit"); + return Bitcast(v.x); + } + + static inline infini_train::BF16 FromNative(Native v) noexcept { + infini_train::BF16 out{}; + static_assert(sizeof(Native) == sizeof(uint16_t), "Native BF16 must be 16-bit"); + out.x = Bitcast(v); + return out; + } +}; + +// ----------------------------------------------------------------------------- +// Convenience wrappers +// ----------------------------------------------------------------------------- +template inline NativeScalar_t ToNative(Scalar v) noexcept { + return ScalarConvert::ToNative(v); +} + +template inline Scalar FromNative(NativeScalar_t v) noexcept { + return ScalarConvert::FromNative(v); +} + +} // namespace infini_train::core diff --git a/infini_train/include/datatype.h b/infini_train/include/datatype.h index 79f325db..bfef622e 100644 --- a/infini_train/include/datatype.h +++ b/infini_train/include/datatype.h @@ -1,14 +1,203 @@ #pragma once +#include +#include #include +#include 
#include +#include #include -#ifdef USE_CUDA -#include -#include -#endif namespace infini_train { + +// ----------------------------------------------------------------------------- +// Framework scalar types (16-bit storage + fallback scalar semantics) +// ----------------------------------------------------------------------------- +// FP16/BF16 are framework-level 16-bit scalar/storage types. +// They are used for: +// - framework type identity +// - dtype mapping +// - metadata / storage layout +// - CPU/reference/fallback conversion paths +// +// They are NOT intended to define backend-native arithmetic semantics. +// Backend kernels should use backend-specific type maps, e.g.: +// - CUDA: __half / __nv_bfloat16 +// - CPU : FP16 / BF16 / widened compute types (as needed) +// ----------------------------------------------------------------------------- + +namespace detail { + +// --------------------------- +// BF16 helpers +// --------------------------- +// Converts a float to its bfloat16 bit pattern (the upper 16 bits of the float32 +// encoding), rounding the 16 discarded mantissa bits to nearest, ties to even. +// NaN inputs map to a quiet NaN that keeps the sign and the top payload bits. +inline constexpr uint16_t FloatToBf16Bits(float value) { + const uint32_t bits = std::bit_cast<uint32_t>(value); + // NaN (exponent all ones, non-zero mantissa) must bypass the rounding below: + // adding the bias could carry a NaN into Inf or wrap it around to zero. + // Force the quiet bit so the truncated pattern is still a NaN. + if ((bits & 0x7fffffffu) > 0x7f800000u) { + return static_cast<uint16_t>((bits >> 16) | 0x0040u); + } + // lsb is the lowest bit that survives truncation; adding it to 0x7fff makes + // exact halfway cases round towards the even result. + const uint32_t lsb = (bits >> 16) & 1u; + const uint32_t rounding_bias = 0x7fffu + lsb; + return static_cast<uint16_t>((bits + rounding_bias) >> 16); +} + +// Widens a bfloat16 bit pattern to float by placing it in the upper 16 bits of +// a float32 encoding; the lower 16 bits are zero, so the conversion is exact. +inline constexpr float Bf16BitsToFloat(uint16_t bits) { + const uint32_t u32 = static_cast<uint32_t>(bits) << 16; + return std::bit_cast<float>(u32); +} + +// --------------------------- +// FP16 helpers +// Pure software IEEE-754 half <-> float conversion for framework fallback use.
+// --------------------------- +inline constexpr uint16_t FloatToFp16Bits(float value) { + const uint32_t bits = std::bit_cast(value); + + const uint32_t sign = (bits >> 16) & 0x8000u; + uint32_t mantissa = bits & 0x007fffffu; + int32_t exp = static_cast((bits >> 23) & 0xffu); + + // NaN / Inf + if (exp == 0xff) { + if (mantissa == 0) { + return static_cast(sign | 0x7c00u); // inf + } + return static_cast(sign | 0x7e00u); // quiet NaN + } + + // Zero / subnormal in float32 + if (exp == 0) { + return static_cast(sign); + } + + // Convert exponent bias: fp32 bias 127 -> fp16 bias 15 + exp = exp - 127 + 15; + + // Overflow -> inf + if (exp >= 0x1f) { + return static_cast(sign | 0x7c00u); + } + + // Underflow -> subnormal / zero + if (exp <= 0) { + if (exp < -10) { + return static_cast(sign); + } + + mantissa |= 0x00800000u; + + const int shift = 14 - exp; + uint32_t half_mant = mantissa >> shift; + + const uint32_t remainder = mantissa & ((1u << shift) - 1u); + const uint32_t halfway = 1u << (shift - 1); + if (remainder > halfway || (remainder == halfway && (half_mant & 1u))) { + ++half_mant; + } + + return static_cast(sign | half_mant); + } + + // Normal fp16 + uint32_t half_exp = static_cast(exp) << 10; + uint32_t half_mant = mantissa >> 13; + + const uint32_t round_bits = mantissa & 0x1fffu; + if (round_bits > 0x1000u || (round_bits == 0x1000u && (half_mant & 1u))) { + ++half_mant; + if (half_mant == 0x400u) { + half_mant = 0; + half_exp += 0x0400u; + if (half_exp >= 0x7c00u) { + return static_cast(sign | 0x7c00u); + } + } + } + + return static_cast(sign | half_exp | half_mant); +} + +inline constexpr float Fp16BitsToFloat(uint16_t bits) { + const uint32_t sign = (static_cast(bits & 0x8000u)) << 16; + const uint32_t exp = (bits >> 10) & 0x1fu; + const uint32_t mant = bits & 0x03ffu; + + uint32_t out = 0; + + if (exp == 0) { + if (mant == 0) { + out = sign; + } else { + uint32_t mantissa = mant; + int32_t e = -14; + while ((mantissa & 0x0400u) == 0) { + mantissa 
<<= 1; + --e; + } + mantissa &= 0x03ffu; + const uint32_t exp32 = static_cast(e + 127) << 23; + const uint32_t mant32 = mantissa << 13; + out = sign | exp32 | mant32; + } + } else if (exp == 0x1f) { + out = sign | 0x7f800000u | (mant << 13); + } else { + const uint32_t exp32 = static_cast(static_cast(exp) - 15 + 127) << 23; + const uint32_t mant32 = mant << 13; + out = sign | exp32 | mant32; + } + + return std::bit_cast(out); +} + +} // namespace detail + +struct alignas(2) FP16 { + uint16_t x{0}; + + struct from_bits_t {}; + static constexpr from_bits_t from_bits() { return {}; } + + constexpr FP16() = default; + constexpr FP16(uint16_t bits, from_bits_t) : x(bits) {} + + explicit constexpr FP16(float value) : x(detail::FloatToFp16Bits(value)) {} + explicit constexpr FP16(double value) : FP16(static_cast(value)) {} + explicit constexpr FP16(int value) : FP16(static_cast(value)) {} + explicit constexpr FP16(int64_t value) : FP16(static_cast(value)) {} + + explicit constexpr operator float() const { return detail::Fp16BitsToFloat(x); } + explicit constexpr operator double() const { return static_cast(static_cast(*this)); } + + FP16 &operator++() { + *this = FP16(static_cast(*this) + 1.0f); + return *this; + } +}; + +struct alignas(2) BF16 { + uint16_t x{0}; + + struct from_bits_t {}; + static constexpr from_bits_t from_bits() { return {}; } + + constexpr BF16() = default; + constexpr BF16(uint16_t bits, from_bits_t) : x(bits) {} + + explicit constexpr BF16(float value) : x(detail::FloatToBf16Bits(value)) {} + explicit constexpr BF16(double value) : BF16(static_cast(value)) {} + explicit constexpr BF16(int value) : BF16(static_cast(value)) {} + explicit constexpr BF16(int64_t value) : BF16(static_cast(value)) {} + + explicit constexpr operator float() const { return detail::Bf16BitsToFloat(x); } + explicit constexpr operator double() const { return static_cast(static_cast(*this)); } + + BF16 &operator++() { + *this = BF16(static_cast(*this) + 1.0f); + return *this; + 
} +}; + +// ----------------------------------------------------------------------------- +// DataType enum and metadata tables +// ----------------------------------------------------------------------------- enum class DataType : int8_t { kUINT8, kINT8, @@ -37,103 +226,91 @@ inline const std::unordered_map kDataTypeToDesc = { {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"}, }; -/** - * Compile-time type mapping from DataType enum to concrete C++ types. - * - * - Primary template: Declared but undefined to enforce specialization - * - Specializations: Explicit mappings (DataType::kFLOAT32 → float, etc) - * - TypeMap_t alias: Direct access to mapped type (TypeMap_t → int32_t) - * - * Enables type-safe generic code where operations dispatch based on DataType tokens, - * with zero runtime overhead. Extend by adding new specializations. - */ -template struct TypeMap; +// ----------------------------------------------------------------------------- +// Compile-time type mapping infrastructure +// ----------------------------------------------------------------------------- + +// Default framework scalar/storage mapping. +// This is the shared baseline mapping used by: +// - framework TypeMap +// - CPU backend type map +// - other backends that only need to override a few dtypes +template struct DefaultScalarTypeMap; + +// Forward framework mapping alias +template using DefaultScalarTypeMap_t = typename DefaultScalarTypeMap::type; + +// Framework compile-time type mapping: DataType -> framework C++ type +template struct TypeMap : DefaultScalarTypeMap {}; + template using TypeMap_t = typename TypeMap::type; -/** - * Compile-time type mapping from C++ types to DataType enum. - * - * Example usage: DataTypeMap::value // Returns DataType::kINT32 - * DataTypeMap_v for convenient access to the mapped value (e.g., DataTypeMap_v). 
- */ +// ----------------------------------------------------------------------------- +// Compile-time reverse mapping: framework C++ type -> DataType +// ----------------------------------------------------------------------------- template struct DataTypeMap; + template inline constexpr DataType DataTypeMap_v = DataTypeMap::value; -// Macro to define TypeMap specializations and reverse mappings -#define DEFINE_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \ - template <> struct TypeMap { \ +// Macro to define DefaultScalarTypeMap specialization + framework reverse mapping +#define DEFINE_DEFAULT_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \ + template <> struct DefaultScalarTypeMap { \ using type = CPP_TYPE; \ }; \ template <> struct DataTypeMap { \ static constexpr DataType value = DataType::ENUM_VALUE; \ }; -DEFINE_DATA_TYPE_MAPPING(kUINT8, uint8_t) -DEFINE_DATA_TYPE_MAPPING(kINT8, int8_t) -DEFINE_DATA_TYPE_MAPPING(kUINT16, uint16_t) -DEFINE_DATA_TYPE_MAPPING(kINT16, int16_t) -DEFINE_DATA_TYPE_MAPPING(kUINT32, uint32_t) -DEFINE_DATA_TYPE_MAPPING(kINT32, int32_t) -DEFINE_DATA_TYPE_MAPPING(kUINT64, uint64_t) -DEFINE_DATA_TYPE_MAPPING(kINT64, int64_t) -DEFINE_DATA_TYPE_MAPPING(kFLOAT32, float) -DEFINE_DATA_TYPE_MAPPING(kFLOAT64, double) - -#ifdef USE_CUDA -DEFINE_DATA_TYPE_MAPPING(kBFLOAT16, nv_bfloat16) -DEFINE_DATA_TYPE_MAPPING(kFLOAT16, half) -#else -// Non-CUDA fallbacks -template <> struct TypeMap { - using type = uint16_t; -}; -template <> struct TypeMap { - using type = uint16_t; -}; +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT8, uint8_t) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT8, int8_t) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT16, uint16_t) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT16, int16_t) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT32, uint32_t) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT32, int32_t) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT64, uint64_t) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT64, int64_t) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kBFLOAT16, BF16) 
+DEFINE_DEFAULT_DATA_TYPE_MAPPING(kFLOAT16, FP16) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kFLOAT32, float) +DEFINE_DEFAULT_DATA_TYPE_MAPPING(kFLOAT64, double) -// TODO(lzm): currently for non-CUDA/CPU, there's an ambiguity of uint16_t mapping to both kUINT16 and -// kFLOAT16/kBFLOAT16. When CPU custom bfloat16/float16 types are defined, we should replace uint16_t with those types. -#endif -#undef DEFINE_DATA_TYPE_MAPPING +#undef DEFINE_DEFAULT_DATA_TYPE_MAPPING -// Extends std::is_floating_point to support CUDA floating-point types. +// ----------------------------------------------------------------------------- +// Type traits extensions (framework fallback scalar semantics) +// ----------------------------------------------------------------------------- template struct is_floating_point_ext : std::is_floating_point {}; -// Extends std::is_arithmetic to support CUDA floating-point types. template struct is_arithmetic_ext : std::is_arithmetic {}; -// Specializations for CUDA types -#ifdef USE_CUDA -template <> struct is_floating_point_ext<__nv_bfloat16> : std::true_type {}; -template <> struct is_arithmetic_ext<__nv_bfloat16> : std::true_type {}; -template <> struct is_floating_point_ext<__half> : std::true_type {}; -template <> struct is_arithmetic_ext<__half> : std::true_type {}; -#endif +template <> struct is_floating_point_ext : std::true_type {}; +template <> struct is_arithmetic_ext : std::true_type {}; + +template <> struct is_floating_point_ext : std::true_type {}; +template <> struct is_arithmetic_ext : std::true_type {}; + +// ----------------------------------------------------------------------------- +// Promotion helpers (framework-level WidestType) +// ----------------------------------------------------------------------------- +namespace detail { -namespace { template struct LargerType { static constexpr size_t size1 = sizeof(T1); static constexpr size_t size2 = sizeof(T2); using type = std::conditional_t<(size1 >= size2), T1, T2>; }; -// 
Specializations of LargerType for the specific 16-bit FP combinations -#ifdef USE_CUDA -template <> struct LargerType<__nv_bfloat16, __half> { +template <> struct LargerType { using type = float; }; -template <> struct LargerType<__half, __nv_bfloat16> { +template <> struct LargerType { using type = float; }; -#endif /** - * @brief Finds the first type in a parameter pack that satisfies the given predicate. If no type matches, - * returns the last type in the pack (base case). - * - * @tparam Predicate Template template parameter that takes one type and provides a static `value` member - * @tparam Ts Parameter pack of types to check + * @brief Finds the first type in a parameter pack that satisfies the given predicate. + * If no type matches, returns the last type in the pack (base case). */ template