Skip to content

Commit 363370b

Browse files
wip
1 parent c12280c commit 363370b

8 files changed

Lines changed: 106 additions & 46 deletions

File tree

tests/test_ref.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class MyStruct2(xo.Struct):
164164

165165

166166
@for_all_test_contexts
167-
def no_test_unionref(test_context):
167+
def test_unionref(test_context):
168168
arr = xo.Float64[:]([1, 2, 3], _context=test_context)
169169
buf = arr._buffer
170170
string = xo.String("Test", _buffer=buf)
@@ -185,14 +185,16 @@ class MyStructRef(xo.Struct):
185185

186186

187187
@for_all_test_contexts
188-
def no_test_array_of_unionrefs(test_context):
188+
def test_array_of_unionrefs(test_context):
189189
class MyStructA(xo.Struct):
190190
a = xo.Float64
191191

192192
class MyStructB(xo.Struct):
193193
a = xo.Int32
194194

195-
Element = xo.Ref[MyStructA, MyStructB]
195+
class Element(xo.UnionRef):
196+
_reftypes = [MyStructA, MyStructB]
197+
196198
ArrOfUnionRefs = Element[:]
197199

198200
aoref = ArrOfUnionRefs(10, _context=test_context)
@@ -338,3 +340,46 @@ class MyUnion(xo.UnionRef):
338340
_ref = [xo.Float64, xo.Int32]
339341

340342
assert MyUnion._has_refs
343+
344+
345+
def test_move_arrays_of_unionref():
346+
class Circle(xo.Struct):
347+
radius = xo.UInt64
348+
349+
class Rectangle(xo.Struct):
350+
width = xo.UInt64
351+
height = xo.UInt64
352+
353+
class Shape(xo.UnionRef):
354+
_reftypes = [Circle, Rectangle]
355+
356+
class StructHasRef(xo.Struct):
357+
shapes = Shape[:]
358+
359+
class NestedHasRef(xo.Struct):
360+
padding = xo.UInt64
361+
nested = StructHasRef
362+
363+
common_buffer = xo.context_default.new_buffer()
364+
365+
circle = Circle(radius=5, _buffer=common_buffer)
366+
rectangle = Rectangle(width=6, height=7, _buffer=common_buffer)
367+
368+
has_ref = StructHasRef(shapes=[circle, rectangle, circle], _buffer=common_buffer)
369+
circle.radius = 10
370+
371+
assert has_ref.shapes[0].radius == 10
372+
assert has_ref.shapes[1].width == 6
373+
assert has_ref.shapes[1].height == 7
374+
assert has_ref.shapes[2].radius == 10
375+
376+
nested_has_ref = NestedHasRef(padding=42, nested=has_ref, _buffer=common_buffer)
377+
assert nested_has_ref.nested.shapes[0].radius == 10
378+
assert nested_has_ref.nested.shapes[1].width == 6
379+
assert nested_has_ref.nested.shapes[1].height == 7
380+
assert nested_has_ref.nested.shapes[2].radius == 10
381+
382+
# Check that no copies were made
383+
assert nested_has_ref.nested.shapes[0]._offset == has_ref.shapes[0]._offset == circle._offset
384+
assert nested_has_ref.nested.shapes[1]._offset == has_ref.shapes[1]._offset == rectangle._offset
385+
assert nested_has_ref.nested.shapes[2]._offset == has_ref.shapes[2]._offset == circle._offset

xobjects/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .string import String
2020
from .struct import Struct, Field
2121
from .ref import Ref, UnionRef
22+
from .union import Union
2223

2324
from .context_cpu import ContextCpu
2425
from .context_pyopencl import ContextPyopencl

xobjects/_patch_pyopencl_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _infer_fccont(arr):
7272
def copy_non_cont(src, dest, custom_itemsize=None, skip_typecheck=False):
7373
assert src.shape == dest.shape
7474

75-
# The case float -> complex just works (by using the src itemsize)
75+
# The case float -> complex just works (by using the headers itemsize)
7676
if not (src.dtype == np.float64 and dest.dtype == np.complex128):
7777
if not skip_typecheck:
7878
assert src.dtype == dest.dtype

xobjects/array.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# ########################################### #
55

66
import logging
7+
from typing import TypeVar, Iterator, Protocol, Container, Iterable, Type, Self
78

89
import numpy as np
910

@@ -175,6 +176,14 @@ def __init__(self, cls):
175176
self.cls = cls
176177

177178

179+
ArrayItemType = TypeVar("ArrayItemType")
180+
181+
182+
class ContainerInstance(Protocol[ArrayItemType]):
183+
def __contains__(self, item: ArrayItemType) -> bool: ...
184+
def __iter__(self) -> Iterator[ArrayItemType]: ...
185+
186+
178187
class MetaArray(type):
179188
def __new__(cls, name, bases, data):
180189
if "_itemtype" in data: # specialized class
@@ -266,7 +275,7 @@ class Array(metaclass=MetaArray):
266275
_data_offset: int
267276

268277
@classmethod
269-
def mk_arrayclass(cls, itemtype, shape):
278+
def mk_arrayclass(cls, itemtype: TypeVar, shape) -> Type[ContainerInstance[TypeVar] & Type[Self]]:
270279
if type(shape) in (int, slice):
271280
shape = (shape,)
272281
order = list(range(len(shape)))
@@ -478,13 +487,13 @@ def _to_buffer(cls, buffer, offset, value, info=None):
478487
if not isinstance(value, buffer.context.nplike_array_type):
479488
value = buffer.context.nparray_to_context_array(value)
480489
buffer.update_from_nplike(coffset, cls._itemtype._dtype, value)
481-
elif isinstance(value, cls):
490+
elif isinstance(value, cls) and not cls._has_refs:
482491
if value._size == info.size:
483492
buffer.update_from_xbuffer(
484493
offset, value._buffer, value._offset, value._size
485494
)
486495
else:
487-
raise ValueError("Value {value} not compatible size")
496+
raise ValueError(f"Value {value} not compatible size")
488497
elif value is None: # no value to initialize
489498
if is_scalar(cls._itemtype):
490499
pass # leave uninitialized

xobjects/capi.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,14 @@ def gen_c_pointed(target: Arg, conf):
208208
size = gen_c_size_from_arg(target, conf)
209209
ret = gen_c_type_from_arg(target, conf)
210210

211-
if target.pointer or is_compound(target.atype) or is_string(target.atype):
211+
if target.pointer or is_compound(target.atype):
212212
chartype = gen_pointer(conf.get("chartype", "char") + "*", conf)
213213
return f"({ret})(({chartype}) obj+offset)"
214214

215+
if is_string(target.atype):
216+
chartype = gen_pointer(conf.get("chartype", "char") + "*", conf)
217+
return f"({chartype}) obj+offset+8" # first entry in the string is it's size
218+
215219
rettype = gen_pointer(ret + "*", conf)
216220
if size == 1:
217221
return f"*(({rettype}) obj+offset)"

xobjects/context_cpu.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,10 @@ def compile_kernel(
464464
module_name + ".o",
465465
]
466466

467-
for ff in files_to_remove:
468-
if os.path.exists(ff):
469-
os.remove(ff)
467+
if 'XOBJECTS_RETAIN_BUILD_FILES' not in os.environ:
468+
for ff in files_to_remove:
469+
if os.path.exists(ff):
470+
os.remove(ff)
470471

471472
def _build_sources(
472473
self,

xobjects/ref.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
# ########################################### #
55

66
import logging
7+
from typing import Dict
78

89
import numpy as np
910

10-
from .typeutils import Info, dispatch_arg, allocate_on_buffer, default_conf
11+
from .context import Source, Kernel
12+
from .typeutils import Info, allocate_on_buffer, default_conf
1113
from .scalar import Int64
1214
from .array import Array
1315

@@ -269,21 +271,24 @@ def _gen_data_paths(cls, base=None):
269271
return paths
270272

271273
@classmethod
272-
def _gen_c_decl(cls, conf=default_conf):
274+
def _gen_c_decl(cls, conf=default_conf) -> str:
275+
"""Generate method declarations for ``cls``."""
273276
from . import capi
274277

275278
paths = cls._gen_data_paths()
276279
return capi.gen_cdefs(cls, paths, conf)
277280

278281
@classmethod
279-
def _gen_c_api(cls, conf=default_conf):
282+
def _gen_c_api(cls, conf=default_conf) -> Source:
283+
"""Generate method definitions for ``cls``."""
280284
from . import capi
281285

282286
paths = cls._gen_data_paths()
283287
return capi.gen_code(cls, paths, conf)
284288

285289
@classmethod
286-
def _gen_kernels(cls, conf=default_conf):
290+
def _gen_kernels(cls, conf=default_conf) -> Dict[str, Kernel]:
291+
"""Generate kernel definitions for ``cls``."""
287292
from . import capi
288293

289294
paths = cls._gen_data_paths()

xobjects/struct.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,11 @@
4848
import logging
4949
from typing import Callable, Optional
5050

51-
from .typeutils import (
52-
allocate_on_buffer,
53-
dispatch_arg,
54-
Info,
55-
_to_slot_size,
56-
_is_dynamic,
57-
default_conf,
58-
)
59-
60-
from .general import Print
61-
from .scalar import Int64
6251
from .array import Array
63-
from .context import Source, Arg, Kernel
64-
from .context_cpu import ContextCpu
52+
from .context import Source
53+
from .scalar import Int64
54+
from .typeutils import (Info, _to_slot_size, allocate_on_buffer, default_conf,
55+
dispatch_arg)
6556

6657
log = logging.getLogger(__name__)
6758

@@ -307,24 +298,28 @@ def _to_buffer(cls, buffer, offset, value, info=None):
307298
buffer.update_from_xbuffer(
308299
offset, value._buffer, value._offset, value._size
309300
)
310-
else: # value must be a dict, again potential disctructive
311-
if info is None:
312-
info = cls._inspect_args(value)
313-
if cls._size is None:
314-
Int64._to_buffer(buffer, offset, info.size)
315-
if hasattr(
316-
info, "_offsets"
317-
): # if it has a least two dynamic fields
318-
cls._set_offsets(buffer, offset, info._offsets)
319-
extra = getattr(info, "extra", {})
320-
for field in cls._fields:
321-
fvalue = field.value_from_args(value)
322-
if field.is_reference:
323-
foffset = offset + info._offsets[field.index]
324-
else:
325-
foffset = offset + field.offset
326-
finfo = extra.get(field.index)
327-
field.ftype._to_buffer(buffer, foffset, fvalue, finfo)
301+
return
302+
303+
if info is None:
304+
info = cls._inspect_args(value)
305+
306+
if cls._size is None:
307+
Int64._to_buffer(buffer, offset, info.size)
308+
309+
if hasattr(
310+
info, "_offsets"
311+
): # if it has at least two dynamic fields
312+
cls._set_offsets(buffer, offset, info._offsets)
313+
314+
extra = getattr(info, "extra", {})
315+
for field in cls._fields:
316+
fvalue = field.value_from_args(value)
317+
if field.is_reference:
318+
foffset = offset + info._offsets[field.index]
319+
else:
320+
foffset = offset + field.offset
321+
finfo = extra.get(field.index)
322+
field.ftype._to_buffer(buffer, foffset, fvalue, finfo)
328323

329324
def _update(self, value):
330325
# check if direct copy is possible

0 commit comments

Comments
 (0)