Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- bump: patch
changes:
fixed:
- Optimisation improvements for loading tax-benefit systems (caching).
- Replaced O(N×K) numpy.select in vectorial parameter lookups with O(N) index-based selection, and cached build_from_node results, reducing US simulation compute time by 30% (from 12.8s to 9.0s).
16 changes: 13 additions & 3 deletions policyengine_core/parameters/parameter_node_at_instant.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,19 @@ def __getitem__(
) -> Union["ParameterNodeAtInstant", VectorialParameterNodeAtInstant]:
# If fancy indexing is used, cast to a vectorial node
if isinstance(key, numpy.ndarray):
return parameters.VectorialParameterNodeAtInstant.build_from_node(
self
)[key]
# Cache the vectorial node to avoid rebuilding the recarray on
# every call — build_from_node is expensive (walks the full
# parameter subtree each time).
try:
vectorial = self._vectorial_node
except AttributeError:
vectorial = (
parameters.VectorialParameterNodeAtInstant.build_from_node(
self
)
)
self._vectorial_node = vectorial
return vectorial[key]
return self._children[key]

def __iter__(self) -> Iterable:
Expand Down
289 changes: 225 additions & 64 deletions policyengine_core/parameters/vectorial_parameter_node_at_instant.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,87 +197,248 @@ def __getitem__(self, key: str) -> Any:
return self.__getattr__(key)
# If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1']
elif isinstance(key, numpy.ndarray):
if not numpy.issubdtype(key.dtype, numpy.str_):
# In case the key is not a string vector, stringify it
if key.dtype == object and issubclass(type(key[0]), Enum):
enum = type(key[0])
key = numpy.select(
[key == item for item in enum],
[str(item.name) for item in enum],
default="unknown",
)
elif isinstance(key, EnumArray):
enum = key.possible_values
key = numpy.select(
[key == item.index for item in enum],
[item.name for item in enum],
default="unknown",
)
else:
names = self.dtype.names
# Build name→child-index mapping (cached on instance)
if not hasattr(self, "_name_to_child_idx"):
self._name_to_child_idx = {
name: i for i, name in enumerate(names)
}

name_to_child_idx = self._name_to_child_idx
n = len(key)
SENTINEL = len(names)

# Convert key to integer indices directly, avoiding
# expensive intermediate string arrays where possible.
if isinstance(key, EnumArray):
# EnumArray: map enum int codes → child indices via
# a pre-built lookup table (O(N), no string comparison).
enum = key.possible_values
cache_key = id(enum)
if not hasattr(self, "_enum_lut_cache"):
self._enum_lut_cache = {}
lut = self._enum_lut_cache.get(cache_key)
if lut is None:
enum_items = list(enum)
max_code = max(item.index for item in enum_items) + 1
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
for item in enum_items:
child_idx = name_to_child_idx.get(item.name)
if child_idx is not None:
lut[item.index] = child_idx
self._enum_lut_cache[cache_key] = lut
idx = lut[numpy.asarray(key)]
elif key.dtype == object and len(key) > 0 and issubclass(type(key[0]), Enum):
# Object array of Enum instances
enum = type(key[0])
cache_key = id(enum)
if not hasattr(self, "_enum_lut_cache"):
self._enum_lut_cache = {}
lut = self._enum_lut_cache.get(cache_key)
if lut is None:
enum_items = list(enum)
max_code = max(item.index for item in enum_items) + 1
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
for item in enum_items:
child_idx = name_to_child_idx.get(str(item.name))
if child_idx is not None:
lut[item.index] = child_idx
self._enum_lut_cache[cache_key] = lut
codes = numpy.array(
[v.index for v in key], dtype=numpy.intp
)
idx = lut[codes]
else:
# String keys: map via dict lookup
if not numpy.issubdtype(key.dtype, numpy.str_):
key = key.astype("str")
names = list(
self.dtype.names
) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2']
conditions = [key == name for name in names]
# Vectorised dict lookup using numpy unique + scatter
uniq, inverse = numpy.unique(key, return_inverse=True)
uniq_idx = numpy.array(
[name_to_child_idx.get(u, SENTINEL) for u in uniq],
dtype=numpy.intp,
)
idx = uniq_idx[inverse]

# Gather values by child index using take on a stacked array.
values = [self.vector[name] for name in names]

# NumPy 2.x requires all arrays in numpy.select to have identical dtypes
# For structured arrays with different field sets, we need to normalize them
if (
is_structured = (
len(values) > 0
and hasattr(values[0].dtype, "names")
and values[0].dtype.names
):
# Check if all values have the same dtype
)

if is_structured:
dtypes_match = all(
val.dtype == values[0].dtype for val in values
)
v0_len = len(values[0])

if not dtypes_match:
# Find the union of all field names across all values, preserving first seen order
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)

# Create unified dtype with all fields
unified_dtype = numpy.dtype(
[(f, "<f8") for f in all_fields]
)
if v0_len <= 1:
# 1-element structured arrays: simple concat + index
if not dtypes_match:
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)

# Cast all values to unified dtype
values_cast = []
for val in values:
casted = numpy.zeros(len(val), dtype=unified_dtype)
for field in val.dtype.names:
casted[field] = val[field]
values_cast.append(casted)
unified_dtype = numpy.dtype(
[(f, "<f8") for f in all_fields]
)

default = numpy.zeros(
len(values_cast[0]), dtype=unified_dtype
)
# Fill with NaN
for field in unified_dtype.names:
default[field] = numpy.nan
values_cast = []
for val in values:
casted = numpy.zeros(
len(val), dtype=unified_dtype
)
for field in val.dtype.names:
casted[field] = val[field]
values_cast.append(casted)

result = numpy.select(conditions, values_cast, default)
default = numpy.zeros(1, dtype=unified_dtype)
for field in unified_dtype.names:
default[field] = numpy.nan
stacked = numpy.concatenate(
values_cast + [default]
)
result = stacked[idx]
else:
default = numpy.full(
1, numpy.nan, dtype=values[0].dtype
)
stacked = numpy.concatenate(values + [default])
result = stacked[idx]
else:
# All dtypes match, use original logic
default = numpy.full_like(values[0], numpy.nan)
result = numpy.select(conditions, values, default)
# N-element structured arrays: check if fields are
# simple scalars (fast path) or nested records
# (fall back to numpy.select).
first_field = values[0].dtype.names[0]
field_dtype = values[0][first_field].dtype
is_nested = (
hasattr(field_dtype, "names")
and field_dtype.names is not None
)

if is_nested:
# Nested structured: fall back to numpy.select
conditions = [
idx == i for i in range(len(values))
]
if not dtypes_match:
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)
unified_dtype = numpy.dtype(
[(f, "<f8") for f in all_fields]
)
values_cast = []
for val in values:
casted = numpy.zeros(
len(val), dtype=unified_dtype
)
for field in val.dtype.names:
casted[field] = val[field]
values_cast.append(casted)
default = numpy.zeros(
v0_len, dtype=unified_dtype
)
for field in unified_dtype.names:
default[field] = numpy.nan
result = numpy.select(
conditions, values_cast, default
)
else:
default = numpy.full_like(
values[0], numpy.nan
)
result = numpy.select(
conditions, values, default
)
else:
# Flat structured: fast per-field indexing
if not dtypes_match:
all_fields = []
seen = set()
for val in values:
for field in val.dtype.names:
if field not in seen:
all_fields.append(field)
seen.add(field)
unified_dtype = numpy.dtype(
[(f, "<f8") for f in all_fields]
)
values_unified = []
for val in values:
casted = numpy.zeros(
len(val), dtype=unified_dtype
)
for field in val.dtype.names:
casted[field] = val[field]
values_unified.append(casted)
field_names = all_fields
result_dtype = unified_dtype
else:
values_unified = values
field_names = values[0].dtype.names
result_dtype = values[0].dtype

result = numpy.empty(n, dtype=result_dtype)
arange_n = numpy.arange(v0_len)
for field in field_names:
field_stack = numpy.empty(
(len(values_unified) + 1, v0_len),
dtype=numpy.float64,
)
for i, v in enumerate(values_unified):
field_stack[i] = v[field]
field_stack[-1] = numpy.nan
result[field] = field_stack[
idx, arange_n
]
else:
# Non-structured array case
default = numpy.full_like(
values[0] if values else self.vector[key[0]], numpy.nan
)
result = numpy.select(conditions, values, default)
# Non-structured: values are either scalars (1-elem arrays)
# or N-element vectors (after prior vectorial indexing).
if values:
v0 = numpy.asarray(values[0])
if v0.ndim == 0 or v0.shape[0] <= 1:
# Scalar per child: 1D lookup
scalar_vals = numpy.empty(
len(values) + 1, dtype=numpy.float64
)
for i, v in enumerate(values):
scalar_vals[i] = float(v)
scalar_vals[-1] = numpy.nan
result = scalar_vals[idx]
else:
# N-element vectors: stack into (K+1, N) matrix
m = v0.shape[0]
stacked = numpy.empty(
(len(values) + 1, m), dtype=numpy.float64
)
for i, v in enumerate(values):
stacked[i] = v
stacked[-1] = numpy.nan
result = stacked[idx, numpy.arange(m)]
else:
result = numpy.full(n, numpy.nan)

# Check for unexpected keys (NaN results from missing keys)
# Check for unexpected keys
if helpers.contains_nan(result):
unexpected_keys = set(key).difference(self.vector.dtype.names)
unexpected_keys = set(
numpy.asarray(key, dtype=str)
if not numpy.issubdtype(
numpy.asarray(key).dtype, numpy.str_
)
else key
).difference(self.vector.dtype.names)
if unexpected_keys:
unexpected_key = unexpected_keys.pop()
raise ParameterNotFoundError(
Expand Down
Loading