Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions runtime/executor/test/executor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ TEST_F(ExecutorTest, OpRegistration) {
auto s2 = register_kernel(Kernel("test_2", test_op));
ASSERT_EQ(Error::Ok, s1);
ASSERT_EQ(Error::Ok, s2);
ET_EXPECT_DEATH(
[]() { (void)register_kernel(Kernel("test", test_op)); }(), "");
// Duplicate registration should succeed and skip gracefully
auto s3 = register_kernel(Kernel("test", test_op));
ASSERT_EQ(Error::Ok, s3);

ASSERT_TRUE(registry_has_op_function("test"));
ASSERT_TRUE(registry_has_op_function("test_2"));
Expand Down
10 changes: 7 additions & 3 deletions runtime/kernel/operator_registry.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -82,16 +82,20 @@

for (const auto& kernel : kernels) {
// Linear search. This is fine if the number of kernels is small.
bool is_duplicate = false;
for (size_t i = 0; i < num_registered_kernels; i++) {
Kernel k = registered_kernels[i];
if (strcmp(kernel.name_, k.name_) == 0 &&
kernel.kernel_key_ == k.kernel_key_) {
ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name);
ET_LOG(Info, "Skipping duplicate registration of %s, from %s", k.name_, lib_name);
ET_LOG_KERNEL_KEY(k.kernel_key_);
return Error::RegistrationAlreadyRegistered;
is_duplicate = true;
break;
}
}
registered_kernels[num_registered_kernels++] = kernel;
if (!is_duplicate) {
registered_kernels[num_registered_kernels++] = kernel;
}
}
ET_LOG(
Debug,
Expand Down
11 changes: 7 additions & 4 deletions runtime/kernel/test/kernel_double_registration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ TEST_F(KernelDoubleRegistrationTest, Basic) {
"aten::add.out",
"v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3",
[](KernelRuntimeContext&, Span<EValue*>) {})};
Error err = Error::RegistrationAlreadyRegistered;

ET_EXPECT_DEATH(
{ (void)register_kernels({kernels}); },
std::to_string(static_cast<uint32_t>(err)));
// First registration should succeed
Error err = register_kernels({kernels});
EXPECT_EQ(err, Error::Ok);

// Second registration should succeed but skip the duplicate
err = register_kernels({kernels});
EXPECT_EQ(err, Error::Ok);
}
41 changes: 34 additions & 7 deletions runtime/kernel/test/operator_registry_test.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -192,12 +192,16 @@
EXPECT_TRUE(registry_has_op_function("foo"));
}

TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceDie) {
TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceSkipsDuplicate) {
Kernel kernels[] = {
Kernel("foo", [](KernelRuntimeContext&, Span<EValue*>) {}),
Kernel("foo", [](KernelRuntimeContext&, Span<EValue*>) {})};
Span<const Kernel> kernels_span = Span<const Kernel>(kernels);
ET_EXPECT_DEATH((void)register_kernels(kernels_span), "registration failed");
// Should succeed and skip the duplicate
Error err = register_kernels(kernels_span);
EXPECT_EQ(err, Error::Ok);
// Verify the operator was registered
EXPECT_TRUE(registry_has_op_function("foo"));
}

TEST_F(OperatorRegistryTest, KernelKeyEquals) {
Expand Down Expand Up @@ -387,7 +391,7 @@
ASSERT_EQ(val_2, 50);
}

TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) {
TEST_F(OperatorRegistryTest, DoubleRegisterKernelsSkipsDuplicate) {
std::array<char, kKernelKeyBufSize> buf_long_contiguous;
Error err = make_kernel_key(
{{ScalarType::Long, {0, 1, 2, 3}}},
Expand All @@ -406,10 +410,33 @@
(void)context;
*(stack[0]) = Scalar(50);
});
Kernel kernels[] = {kernel_1, kernel_2};
// clang-tidy off
ET_EXPECT_DEATH((void)register_kernels(kernels), "registration failed");
// clang-tidy on

// Register first kernel
err = register_kernels({kernel_1});
ASSERT_EQ(err, Error::Ok);

// Attempt to register duplicate - should succeed but skip
err = register_kernels({kernel_2});
ASSERT_EQ(err, Error::Ok);

// Verify first registration was kept (returns 100, not 50)
Tensor::DimOrderType dims[] = {0, 1, 2, 3};
auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
Span<const TensorMeta> user_kernel_key(meta);

EXPECT_TRUE(registry_has_op_function("test::baz", user_kernel_key));
Result<OpFunction> op = get_op_function_from_registry("test::baz", user_kernel_key);
ASSERT_EQ(op.error(), Error::Ok);

EValue values[1];
values[0] = Scalar(0);
EValue* evalues[1];
evalues[0] = &values[0];
KernelRuntimeContext context{};

(*op)(context, Span<EValue*>(evalues));
ASSERT_EQ(values[0].toScalar().to<int64_t>(), 100);
}

TEST_F(OperatorRegistryTest, ExecutorChecksKernel) {
Expand Down
Loading