Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions plugin/sycl/common/optional_weight.cc
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
/*!
* Copyright by Contributors 2017-2025
*/
#include <sycl/sycl.hpp>

#include "../../../src/common/optional_weight.h"

#include <sycl/sycl.hpp>

#include "../device_manager.h"

namespace xgboost::common::sycl_impl {
double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {
sycl::DeviceManager device_manager;
auto* qu = device_manager.GetQueue(ctx->Device());

template <typename T>
T ElementWiseSum(::sycl::queue* qu, OptionalWeights const& weights) {
const auto* data = weights.Data();
double result = 0;
T result = 0;
{
::sycl::buffer<double> buff(&result, 1);
::sycl::buffer<T> buff(&result, 1);
qu->submit([&](::sycl::handler& cgh) {
auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>());
cgh.parallel_for<>(::sycl::range<1>(weights.Size()), reduction,
[=](::sycl::id<1> pid, auto& sum) {
size_t i = pid[0];
sum += data[i];
});
}).wait_and_throw();
auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>());
cgh.parallel_for<>(::sycl::range<1>(weights.Size()), reduction,
[=](::sycl::id<1> pid, auto& sum) {
size_t i = pid[0];
sum += data[i];
});
}).wait_and_throw();
}

return result;
}

double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) {
sycl::DeviceManager device_manager;
auto* qu = device_manager.GetQueue(ctx->Device());

bool has_fp64_support = qu->get_device().has(::sycl::aspect::fp64);
if (has_fp64_support) {
return ElementWiseSum<double>(qu, weights);
} else {
return ElementWiseSum<float>(qu, weights);
}
}
} // namespace xgboost::common::sycl_impl
24 changes: 24 additions & 0 deletions plugin/sycl/common/stats.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*!
* Copyright by Contributors 2017-2025
*/
#include "../../../src/common/stats.h"

#include <sycl/sycl.hpp>

#include "../device_manager.h"

namespace xgboost::common::sycl_impl {
/**
 * @brief SYCL implementation of Mean: writes the arithmetic mean of `v`
 *        into `out(0)`.
 *
 * Mirrors `cuda_impl::Mean` and the CPU fallback in src/common/stats.cc,
 * both of which produce the mean (sum divided by element count), not the sum.
 *
 * @param ctx context whose device selects the SYCL queue.
 * @param v   input vector; assumed device-accessible — TODO confirm caller
 *            calls SetDevice() first.
 * @param out length-1 output vector; out(0) is overwritten
 *            (initialize_to_identity), not accumulated into.
 */
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out) {
  sycl::DeviceManager device_manager;
  auto* qu = device_manager.GetQueue(ctx->Device());

  // Divide inside the kernel: the original accumulated a plain sum, which
  // disagrees with the Mean contract of the CUDA/CPU implementations.
  const float n = static_cast<float>(v.Size());
  qu->submit([&](::sycl::handler& cgh) {
    auto reduction = ::sycl::reduction(&(out(0)), 0.0f, ::sycl::plus<float>(),
                                       ::sycl::property::reduction::initialize_to_identity());
    cgh.parallel_for<>(::sycl::range<1>(v.Size()), reduction, [=](::sycl::id<1> pid, auto& sum) {
      size_t i = pid[0];
      sum += v(i) / n;  // fix: was `sum += v(i);` — a sum, not a mean
    });
  }).wait_and_throw();
}
}  // namespace xgboost::common::sycl_impl
9 changes: 6 additions & 3 deletions plugin/sycl/context_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@
* \file context_helper.cc
*/

#include <sycl/sycl.hpp>
#include "context_helper.h"

#include <sycl/sycl.hpp>

#include "device_manager.h"
#include "context_helper.h"

namespace xgboost {
namespace sycl {

DeviceOrd DeviceFP64(const DeviceOrd& device) {
DeviceManager device_manager;
bool support_fp64 = device_manager.GetQueue(device)->get_device().has(::sycl::aspect::fp64);
bool support_fp64 = true;
if (device.IsSycl()) {
support_fp64 = device_manager.GetQueue(device)->get_device().has(::sycl::aspect::fp64);
}
if (support_fp64) {
return device;
} else {
Expand Down
121 changes: 63 additions & 58 deletions plugin/sycl/device_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,73 +10,78 @@ namespace xgboost {
namespace sycl {

/**
 * @brief Map an xgboost device specification to a cached SYCL queue.
 *
 * When an explicit ordinal is given (or running distributed), the ordinal —
 * or the collective rank, wrapped modulo the number of matching devices —
 * indexes into the registered device list of the requested kind
 * (any / cpu / gpu). Otherwise the SYCL runtime's default/cpu/gpu selector
 * picks the device and we look up its pre-built queue.
 *
 * @return non-owning pointer to a queue owned by the device register;
 *         valid for the lifetime of the process.
 */
::sycl::queue* DeviceManager::GetQueue(const DeviceOrd& device_spec) const {
  if (!device_spec.IsSycl()) {
    LOG(WARNING) << "Sycl kernel is executed with non-sycl context: " << device_spec.Name() << ". "
                 << "Default sycl device_selector will be used.";
  }

  size_t queue_idx;
  bool not_use_default_selector =
      (device_spec.ordinal != kDefaultOrdinal) || (collective::IsDistributed());
  DeviceRegister& device_register = GetDevicesRegister();
  if (not_use_default_selector) {
    if (device_spec.IsSyclDefault()) {
      auto& devices = device_register.devices;
      // Wrap the rank so more workers than devices still map to a valid device.
      const int device_idx = collective::IsDistributed() ? collective::GetRank() % devices.size()
                                                         : device_spec.ordinal;
      CHECK_LT(device_idx, devices.size());
      queue_idx = device_idx;
    } else if (device_spec.IsSyclCPU()) {
      auto& cpu_devices_idxes = device_register.cpu_devices_idxes;
      const int device_idx = collective::IsDistributed()
                                 ? collective::GetRank() % cpu_devices_idxes.size()
                                 : device_spec.ordinal;
      CHECK_LT(device_idx, cpu_devices_idxes.size());
      queue_idx = cpu_devices_idxes[device_idx];
    } else if (device_spec.IsSyclGPU()) {
      auto& gpu_devices_idxes = device_register.gpu_devices_idxes;
      const int device_idx = collective::IsDistributed()
                                 ? collective::GetRank() % gpu_devices_idxes.size()
                                 : device_spec.ordinal;
      CHECK_LT(device_idx, gpu_devices_idxes.size());
      queue_idx = gpu_devices_idxes[device_idx];
    } else {
      LOG(WARNING) << device_spec << " is not sycl, sycl:cpu or sycl:gpu";
      auto device = ::sycl::queue(::sycl::default_selector_v).get_device();
      queue_idx = device_register.devices.at(device);
    }
  } else {
    // No explicit ordinal: let the SYCL runtime's selector choose, then map
    // the chosen device back to its registered queue index.
    if (device_spec.IsSyclCPU()) {
      auto device = ::sycl::queue(::sycl::cpu_selector_v).get_device();
      queue_idx = device_register.devices.at(device);
    } else if (device_spec.IsSyclGPU()) {
      auto device = ::sycl::queue(::sycl::gpu_selector_v).get_device();
      queue_idx = device_register.devices.at(device);
    } else {
      auto device = ::sycl::queue(::sycl::default_selector_v).get_device();
      queue_idx = device_register.devices.at(device);
    }
  }
  return &(device_register.queues[queue_idx]);
}

/**
 * @brief Lazily build and return the process-wide SYCL device register.
 *
 * On first use, enumerates all SYCL devices, logs them, creates one queue
 * per device, and records the indices of CPU and GPU devices separately.
 *
 * Thread safety: the mutex is acquired BEFORE the emptiness check. The
 * original checked `devices.size() == 0` first and locked afterwards, so two
 * threads racing past the check could both populate the register, pushing
 * duplicate queues. Taking the lock first closes that window; the extra
 * uncontended lock on subsequent calls is cheap.
 */
DeviceManager::DeviceRegister& DeviceManager::GetDevicesRegister() const {
  static DeviceRegister device_register;

  std::lock_guard<std::mutex> guard(device_registering_mutex);
  if (device_register.devices.size() == 0) {
    std::vector<::sycl::device> devices = ::sycl::device::get_devices();
    for (size_t i = 0; i < devices.size(); i++) {
      LOG(INFO) << "device_index = " << i
                << ", name = " << devices[i].get_info<::sycl::info::device::name>();
    }

    for (size_t i = 0; i < devices.size(); i++) {
      device_register.devices[devices[i]] = i;
      device_register.queues.push_back(::sycl::queue(devices[i]));
      if (devices[i].is_cpu()) {
        device_register.cpu_devices_idxes.push_back(i);
      } else if (devices[i].is_gpu()) {
        device_register.gpu_devices_idxes.push_back(i);
      }
    }
  }
  return device_register;
}

} // namespace sycl
Expand Down
11 changes: 8 additions & 3 deletions src/common/linalg_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,14 @@ void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
#elif defined(SYCL_LANGUAGE_VERSION)
/**
 * @brief Apply `fn` element-wise over `t` (SYCL build variant).
 *
 * Fast path: when the tensor lives on the CPU, run the CPU kernel directly
 * without going through the context's device dispatch — this also covers the
 * case where `ctx` targets a SYCL device but the tensor itself is host-side.
 * Otherwise dispatch on the context's device; the CUDA branch is
 * unreachable in this translation unit (hence the FATAL).
 */
template <typename T, std::int32_t D, typename Fn, auto _tag = detail::SysTag()>
void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
  if (t.Device().IsCPU()) {
    cpu_impl::ElementWiseKernel(t, ctx->Threads(), std::forward<Fn>(fn));
  } else {
    ctx->DispatchDevice(
        [&] { cpu_impl::ElementWiseKernel(t, ctx->Threads(), std::forward<Fn>(fn)); },
        [&] { LOG(FATAL) << "Invalid TU"; },
        [&] { ::xgboost::sycl::linalg::ElementWiseKernel(t, std::forward<Fn>(fn)); });
  }
}
#else
template <typename T, std::int32_t D, typename Fn, auto _tag = detail::SysTag()>
Expand Down
2 changes: 2 additions & 0 deletions src/common/stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::Vector<

if (ctx->IsCUDA()) {
cuda_impl::Mean(ctx, v, out->View(ctx->Device()));
} else if (ctx->IsSycl()) {
sycl_impl::Mean(ctx, v, out->View(ctx->Device()));
} else {
auto h_v = v;
float n = v.Size();
Expand Down
13 changes: 12 additions & 1 deletion src/common/stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "xgboost/linalg.h" // TensorView,VectorView
#include "xgboost/logging.h" // CHECK_GE

#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_SYCL)
#include "common.h" // AssertGPUSupport
#endif

Expand Down Expand Up @@ -140,6 +140,17 @@ inline void WeightedSampleMean(Context const*, bool, linalg::MatrixView<float co
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl

namespace sycl_impl {
// SYCL backend for Mean(): reduces vector `v` into out(0) on the SYCL device
// selected by `ctx`. Implemented in the SYCL plugin (plugin/sycl/common/stats.cc).
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out);

#if !defined(XGBOOST_USE_SYCL)
// Stub for builds without the SYCL plugin: reaching here means a SYCL device
// was requested in a binary compiled without SYCL support, so fail loudly.
// NOTE(review): AssertGPUSupport comes from "common.h", which this header only
// includes when neither CUDA nor SYCL is defined — confirm the CUDA-only build
// (SYCL off, CUDA on) still compiles this stub.
inline void Mean(Context const*, linalg::VectorView<float const>, linalg::VectorView<float>) {
  common::AssertGPUSupport();
}

#endif  // !defined(XGBOOST_USE_SYCL)
}  // namespace sycl_impl

/**
* @brief Calculate medians for each column of the input matrix.
*/
Expand Down
6 changes: 4 additions & 2 deletions src/objective/multiclass_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ class SoftmaxMultiClassObj : public ObjFunction {
<< "Number of weights should be equal to number of data points.";
}
info.weights_.SetDevice(device);
auto weights = common::MakeOptionalWeights(this->ctx_->Device(), info.weights_);
auto weights = common::MakeOptionalWeights(device, info.weights_);

preds.SetDevice(device);
auto predt = linalg::MakeTensorView(this->ctx_, &preds, n_samples, n_classes);
Context cpu_context = Context();
auto predt = linalg::MakeTensorView(device == ctx_->Device() ? this->ctx_ : &cpu_context,
&preds, n_samples, n_classes);
CHECK_EQ(labels.Shape(1), 1);
auto y1d = labels.Slice(linalg::All(), 0);
CHECK_EQ(y1d.Shape(0), info.num_row_);
Expand Down
Loading