Feature Request: Device context manager for temporary device switching #1586

@bdice

Description

Problem

There is no context manager for temporarily switching the active CUDA device in cuda.core. Users who need to perform work on a specific device and then restore the previous device must manually manage this state:

from cuda.core import Device

# Save current state, switch, do work, restore
dev0 = Device(0)
dev0.set_current()
# ... do work on device 0 ...

dev1 = Device(1)
dev1.set_current()
# ... do work on device 1 ...

dev0.set_current()  # manually restore -- easy to forget, not exception-safe

This pattern is error-prone (the restore call is not exception-safe) and verbose compared to the idiomatic Python with statement. The pattern appears in real-world code such as dask-cuda, which implemented a workaround for this missing feature.
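For comparison, a minimal sketch of what an exception-safe version of the manual pattern looks like today, using only the existing set_current() API; the context manager proposed below replaces this boilerplate:

from cuda.core import Device

dev0 = Device(0)
dev0.set_current()
try:
    dev1 = Device(1)
    dev1.set_current()
    # ... do work on device 1 ...
finally:
    dev0.set_current()  # restored even if the block above raises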

Proposed Design

Add __enter__ and __exit__ methods to Device, making it usable as a context manager that temporarily activates the device's primary context and restores the previous state on exit.

API

from cuda.core import Device

dev0 = Device(0)
dev0.set_current()
# ... do work on device 0 ...

with Device(1) as device:
    # device 1 is now current
    stream = device.create_stream()
    # ...

# device 0 is automatically restored here

Semantics

On __enter__:

  1. Query the current CUDA context via cuCtxGetCurrent and save it on the context manager instance.
  2. Call self.set_current() (which uses the primary context for this device via cuCtxSetCurrent).
  3. Return self.

On __exit__:

  1. Restore the saved context via cuCtxSetCurrent. If the saved context was NULL, set NULL (no active context).
  2. Do NOT suppress exceptions (return False).
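A usage sketch illustrating these semantics (assuming the proposed __enter__/__exit__): the saved context is restored before the exception leaves the with block, and the exception itself is not suppressed.

from cuda.core import Device

Device(0).set_current()

try:
    with Device(1):
        # device 1 is current here
        raise RuntimeError("boom")
except RuntimeError:
    pass  # exception propagated; __exit__ returned False

# device 0's primary context is current again: __exit__ restored the
# saved context before the exception left the with block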

Key design properties

Stateless restoration (no Python-side stack). Each __enter__ call queries the actual CUDA driver state for the current context rather than maintaining a Python-side stack. On __exit__, it restores exactly what was saved. This is the critical lesson from CuPy's experience (cupy/cupy#6965, cupy/cupy#7427): libraries that maintain their own stack of previous devices break interoperability with libraries that use the CUDA API directly to check device state. By always querying and restoring the driver-level state, we interoperate correctly with PyTorch, CuPy, and any other library that uses cudaGetDevice/cudaSetDevice or cuCtxGetCurrent/cuCtxSetCurrent.
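For example (illustrative only; assumes PyTorch is installed and the proposed __enter__/__exit__ exist), even if another library switches devices inside the with block, a nested __enter__ saves whatever the driver actually reports, so __exit__ restores the true prior state rather than a stale Python-side record:

import torch
from cuda.core import Device

with Device(1):
    # Another library switches the current device behind cuda.core's back.
    torch.cuda.set_device(0)
    torch.cuda.synchronize()  # force the runtime to bind device 0's context

    with Device(1):
        ...  # __enter__ saved the driver-level state: device 0's context

    # __exit__ restored device 0 -- the state the driver reported -- not the
    # device a Python-side stack would have remembered (device 1)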

Reentrant and reusable. Because Device is a thread-local singleton and the saved-context state is stored per-__enter__ invocation (not on the Device object itself), the context manager is both reusable and reentrant:

dev0 = Device(0)
dev1 = Device(1)

with dev0:
    with dev1:
        with dev0:  # reentrant -- works correctly
            ...
        # dev1 restored
    # dev0 restored
# original state restored

To achieve reentrancy, the saved context must NOT be stored on self (the Device singleton). Instead, use a thread-local stack or return a helper object from __enter__ that holds the saved state. The simplest correct approach: store a per-thread stack of saved contexts on the Device class (or module-level), pushing on __enter__ and popping on __exit__.

Implementation sketch (in _device.pyx):

# Module-level: add a per-thread stack for saved contexts
# (reuse existing _tls threading.local())

def __enter__(self):
    # Query actual CUDA state -- do NOT use a Python-side device cache
    prev_ctx = handle_return(cuCtxGetCurrent())
    # Store on a per-thread stack so nested `with` works
    if not hasattr(_tls, '_ctx_stack'):
        _tls._ctx_stack = []
    _tls._ctx_stack.append(prev_ctx)
    self.set_current()
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    prev_ctx = _tls._ctx_stack.pop()
    handle_return(cuCtxSetCurrent(prev_ctx))
    return False

Note: The stack here is NOT a device stack -- it is a stack of saved CUcontext values that the __exit__ restores. Each entry corresponds to exactly one __enter__ call. This is fundamentally different from CuPy's old broken approach which tracked a stack of device IDs and queried that stack instead of the CUDA API.

Interoperability with other libraries. Because cuda.core uses cuCtxSetCurrent (driver API) while PyTorch and CuPy use cudaGetDevice/cudaSetDevice (runtime API), and both paths manipulate the same underlying driver state, cross-library nesting works:

with torch.cuda.device(1):
    with Device(2):
        # Both torch and cuda.core see device 2
        ...
    # torch sees device 1 again (cuda.core restored the context)

Note that correct cross-library nesting depends on each library querying the CUDA API for the current device on context exit rather than relying on a cached value. Libraries that follow this pattern (including CuPy v12+ and the CUDA runtime API) will interoperate correctly.

Alternatives Considered

1. Separate Device.activate() method returning a context manager

with Device(1).activate():
    ...

This avoids adding __enter__/__exit__ to the singleton Device object. However, it adds API surface for no practical benefit -- the saved context state can be stored on a thread-local stack rather than on the Device instance, making Device itself safe to use directly as a reentrant context manager. The with Device(1): syntax is also more natural and matches PyTorch's with torch.cuda.device(1): pattern.

Rejected because it adds unnecessary indirection.

2. Do nothing -- recommend set_current() only

Per CuPy's internal policy, context managers for device switching are banned in CuPy's own codebase because they are footguns for library developers. The argument is that set_current() is explicit and unambiguous.

However, cuda.core targets end users (not just library internals), and the context manager pattern is:

  • Exception-safe by default
  • Idiomatic Python
  • Already provided by PyTorch and CuPy (for end users)
  • Requested by downstream users (dask-cuda)

Rejected as the sole approach; set_current() remains the recommended pattern for library code that needs precise control.

3. Use cuCtxPushCurrent/cuCtxPopCurrent instead of cuCtxSetCurrent

The CUDA driver provides an explicit context stack via push/pop. Using this would make nesting trivially correct. However, Device.set_current() currently uses cuCtxSetCurrent for primary contexts (not push/pop), and mixing the two models is fragile. The push/pop model also does not interoperate with libraries using cudaSetDevice (runtime API). The current approach of save-via-query/restore-via-set is correct and interoperable.

Rejected because it would diverge from the runtime API model that other libraries use.
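For contrast, a rough sketch of what the rejected push/pop variant could look like. This is hypothetical: it assumes the same handle_return helper as the sketch above and a device_id attribute on Device, and it omits the matching cuDevicePrimaryCtxRelease bookkeeping that real code would need.

def __enter__(self):
    # Retain this device's primary context and push it onto the driver's
    # own context stack.
    ctx = handle_return(cuDevicePrimaryCtxRetain(self.device_id))
    handle_return(cuCtxPushCurrent(ctx))
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    # Pop restores whatever was below on the driver's stack -- but only code
    # that also uses the push/pop model benefits; runtime-API users
    # (cudaSetDevice) do not participate in that stack.
    handle_return(cuCtxPopCurrent())
    return False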

Open Questions

  1. Should __enter__ call set_current() even if this device is already current? Calling cuCtxSetCurrent with the already-current context is cheap (no-op at the driver level) and keeps the implementation simple. The alternative (check-and-skip) adds complexity for negligible performance gain. Recommendation: always call set_current().

  2. What should __enter__ do if set_current() has never been called on this device? Currently, many Device properties require set_current() to have been called first (_check_context_initialized). The context manager should unconditionally call set_current(), initializing the device if needed. This is the natural expectation: with Device(1): should make device 1 ready for use.

  3. Should we document cross-library interop expectations? We should document that with Device(N): works correctly for cuda.core code, and that cross-library nesting works as long as the other library's context manager correctly queries CUDA state on exit rather than relying on a cached value.

Test Plan

  • Basic usage: with Device(0): sets device 0 as current, restores on exit.
  • Exception safety: device is restored even when an exception is raised inside the with block (see the pytest sketch after this list).
  • Nesting (same device): with dev0: with dev0: works without error.
  • Nesting (different devices): with dev0: with dev1: correctly restores dev0 on exit of inner block.
  • Deep nesting / reentrancy: with dev0: with dev1: with dev0: with dev1: restores correctly at each level.
  • Device remains usable after context manager exit (singleton not corrupted).
  • Multi-GPU: requires 2+ GPUs. Verify cudaGetDevice() (runtime API) reflects the device set by the context manager.
  • Thread safety: context manager state is per-thread (uses thread-local storage), so concurrent threads using different devices should not interfere.
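A sketch of the basic exception-safety check, assuming the proposed __enter__/__exit__, at least two visible GPUs, and the cuda.bindings.runtime module for an independent view of the driver state:

import pytest
from cuda.bindings import runtime
from cuda.core import Device

def current_device_id():
    # Ask the runtime API directly so the assertion does not depend on any
    # Python-side cache inside cuda.core.
    err, dev = runtime.cudaGetDevice()
    assert err == runtime.cudaError_t.cudaSuccess
    return dev

def test_restores_previous_device_on_exception():
    Device(0).set_current()
    with pytest.raises(RuntimeError):
        with Device(1):
            assert current_device_id() == 1
            raise RuntimeError("boom")
    # Device 0 must be current again even though the block raised.
    assert current_device_id() == 0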

Metadata

Labels: cuda.core (Everything related to the cuda.core module), triage (Needs the team's attention)