diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..bfa35419 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: patch + changes: + fixed: + - Optimisation improvements for loading tax-benefit systems (caching). + - Replaced O(N×K) numpy.select in vectorial parameter lookups with O(N) index-based selection, and cached build_from_node results. US simulation compute -30%, from 12.8s to 9.0s. diff --git a/policyengine_core/parameters/parameter_node_at_instant.py b/policyengine_core/parameters/parameter_node_at_instant.py index 67f4695e..05881959 100644 --- a/policyengine_core/parameters/parameter_node_at_instant.py +++ b/policyengine_core/parameters/parameter_node_at_instant.py @@ -55,9 +55,19 @@ def __getitem__( ) -> Union["ParameterNodeAtInstant", VectorialParameterNodeAtInstant]: # If fancy indexing is used, cast to a vectorial node if isinstance(key, numpy.ndarray): - return parameters.VectorialParameterNodeAtInstant.build_from_node( - self - )[key] + # Cache the vectorial node to avoid rebuilding the recarray on + # every call — build_from_node is expensive (walks the full + # parameter subtree each time). + try: + vectorial = self._vectorial_node + except AttributeError: + vectorial = ( + parameters.VectorialParameterNodeAtInstant.build_from_node( + self + ) + ) + self._vectorial_node = vectorial + return vectorial[key] return self._children[key] def __iter__(self) -> Iterable: diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py index 9a7ce385..44cd0d44 100644 --- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py +++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py @@ -197,87 +197,248 @@ def __getitem__(self, key: str) -> Any: return self.__getattr__(key) # If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1'] elif isinstance(key, numpy.ndarray): - if not numpy.issubdtype(key.dtype, numpy.str_): - # In case the key is not a string vector, stringify it - if key.dtype == object and issubclass(type(key[0]), Enum): - enum = type(key[0]) - key = numpy.select( - [key == item for item in enum], - [str(item.name) for item in enum], - default="unknown", - ) - elif isinstance(key, EnumArray): - enum = key.possible_values - key = numpy.select( - [key == item.index for item in enum], - [item.name for item in enum], - default="unknown", - ) - else: + names = self.dtype.names + # Build name→child-index mapping (cached on instance) + if not hasattr(self, "_name_to_child_idx"): + self._name_to_child_idx = { + name: i for i, name in enumerate(names) + } + + name_to_child_idx = self._name_to_child_idx + n = len(key) + SENTINEL = len(names) + + # Convert key to integer indices directly, avoiding + # expensive intermediate string arrays where possible. + if isinstance(key, EnumArray): + # EnumArray: map enum int codes → child indices via + # a pre-built lookup table (O(N), no string comparison). + enum = key.possible_values + cache_key = id(enum) + if not hasattr(self, "_enum_lut_cache"): + self._enum_lut_cache = {} + lut = self._enum_lut_cache.get(cache_key) + if lut is None: + enum_items = list(enum) + max_code = max(item.index for item in enum_items) + 1 + lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) + for item in enum_items: + child_idx = name_to_child_idx.get(item.name) + if child_idx is not None: + lut[item.index] = child_idx + self._enum_lut_cache[cache_key] = lut + idx = lut[numpy.asarray(key)] + elif key.dtype == object and len(key) > 0 and issubclass(type(key[0]), Enum): + # Object array of Enum instances + enum = type(key[0]) + cache_key = id(enum) + if not hasattr(self, "_enum_lut_cache"): + self._enum_lut_cache = {} + lut = self._enum_lut_cache.get(cache_key) + if lut is None: + enum_items = list(enum) + max_code = max(item.index for item in enum_items) + 1 + lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp) + for item in enum_items: + child_idx = name_to_child_idx.get(str(item.name)) + if child_idx is not None: + lut[item.index] = child_idx + self._enum_lut_cache[cache_key] = lut + codes = numpy.array( + [v.index for v in key], dtype=numpy.intp + ) + idx = lut[codes] + else: + # String keys: map via dict lookup + if not numpy.issubdtype(key.dtype, numpy.str_): key = key.astype("str") - names = list( - self.dtype.names - ) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] - conditions = [key == name for name in names] + # Vectorised dict lookup using numpy unique + scatter + uniq, inverse = numpy.unique(key, return_inverse=True) + uniq_idx = numpy.array( + [name_to_child_idx.get(u, SENTINEL) for u in uniq], + dtype=numpy.intp, + ) + idx = uniq_idx[inverse] + + # Gather values by child index using take on a stacked array. values = [self.vector[name] for name in names] - # NumPy 2.x requires all arrays in numpy.select to have identical dtypes - # For structured arrays with different field sets, we need to normalize them - if ( + is_structured = ( len(values) > 0 and hasattr(values[0].dtype, "names") and values[0].dtype.names - ): - # Check if all values have the same dtype + ) + + if is_structured: dtypes_match = all( val.dtype == values[0].dtype for val in values ) + v0_len = len(values[0]) - if not dtypes_match: - # Find the union of all field names across all values, preserving first seen order - all_fields = [] - seen = set() - for val in values: - for field in val.dtype.names: - if field not in seen: - all_fields.append(field) - seen.add(field) - - # Create unified dtype with all fields - unified_dtype = numpy.dtype( - [(f, " None: @@ -813,6 +828,7 @@ def purge_cache_of_invalid_values(self) -> None: for _name, _period in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) + self._fast_cache.pop((_name, str(_period)), None) self.invalidated_caches = set() def calculate_add( @@ -1193,6 +1209,12 @@ def delete_arrays(self, variable: str, period: Period = None) -> None: True """ self.get_holder(variable).delete_arrays(period) + if period is None: + self._fast_cache = { + k: v for k, v in self._fast_cache.items() if k[0] != variable + } + else: + self._fast_cache.pop((variable, str(period)), None) def get_known_periods(self, variable: str) -> List[Period]: """ @@ -1281,8 +1303,15 @@ def clone( new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ("debug", "trace", "tracer", "branches"): + if key not in ( + "debug", + "trace", + "tracer", + "branches", + "_fast_cache", + ): new_dict[key] = value + new._fast_cache = {} new.persons = self.persons.clone(new) setattr(new, new.persons.entity.key, new.persons)