Skip to content
326 changes: 195 additions & 131 deletions pytensor/compile/builders.py

Large diffs are not rendered by default.

97 changes: 93 additions & 4 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import warnings
import weakref
from collections.abc import (
Hashable,
Iterable,
Expand All @@ -14,6 +15,7 @@
Any,
Generic,
Optional,
Self,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -693,7 +695,7 @@ def _reduce(self):
return cls, (self.id, self.type)

def _str(self):
return f"*{self.id}-{var_type.__str__(self)}"
return f"i{self.id}"

new_type = type(
type_name, (cls, var_type), {"__reduce__": _reduce, "__str__": _str}
Expand Down Expand Up @@ -788,6 +790,93 @@ def value(self):
return self.data


def _get_frozen_output(apply_node: "FrozenApply", index: int) -> Variable:
    """Return ``apply_node.outputs[index]``.

    Lives at module level so pickle can reference it when reconstructing
    a FrozenApply output Variable.
    """
    frozen_outputs = apply_node.outputs
    return frozen_outputs[index]


def _make_frozen_output_reduce(out: Variable):
    """Build a ``__reduce_ex__`` override for one FrozenApply output Variable.

    The returned callable pickles the (interned) owner node together with the
    output's position, so unpickling resolves to the canonical cached output
    instead of creating a fresh Variable.
    """
    frozen_owner = out.owner
    position = out.index

    def __reduce_ex__(protocol):
        return (_get_frozen_output, (frozen_owner, position))

    return __reduce_ex__


class FrozenApply(Apply):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It still seems like we're getting a lot of complexity for this FrozenApply thing that is not present in regular interned Variables? Is it due to pickling? Why is it trickier?

Copy link
Copy Markdown
Member Author

@jessegrabowski jessegrabowski Mar 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point of FrozenApply is to allow equality-based comparison of FrozenFunctionGraph outputs. I went through an intermediate plan that used a tuple spec of (op, input_refs, n_outputs) per node that would also have worked, but then we're maintaining that machinery. If we want to be able to do fg1 == fg2, we have to have an abstraction somewhere that allows robust serialization/hashing of nodes (including constants, which have been a repeated challenge during this PR)

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is that the variables being hash-consed already achieves that, you could even have stuck with regular Apply objects holding the hash consed variables as inputs/outputs.

I suggested a frozen apply just so it would use tuples and reduce the risk of accidentally mutating them, but they never seemed necessary (to me) for the goal of hash/equality

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's not necessary per se, but there has to be some kind of hashable topological representation somewhere so we can rebuild the graph. The tuple-spec or "fingerprint" was one approach that interns the variables directly, and just the whole graph topology in a list of nested tuples. I settled on FrozenApply because it also encodes the same topology but in a representation that feels more "pytensor native". The downside is that it adds this intermediate object that we have to go through to get the variables themselves.

"""An immutable, globally-interned Apply node for frozen graphs.

Uses tuples for ``inputs`` and ``outputs`` so mutation raises ``TypeError``
at the language level. Interned by ``(op, cache_key(inputs))`` —
constructing a ``FrozenApply`` with the same op and input variables returns
the cached instance.

Constants are keyed by ``(type, data_bytes)`` so that two independently
created Constants with the same value resolve to the same cached node.
"""

_cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

@staticmethod
def _input_to_key(inp: Variable):
    """Map an input Variable to a hashable, value-based cache-key element.

    Non-Constant inputs (NominalVariables, FrozenApply outputs) are already
    globally interned, so the Variable itself (identity) serves as the key.
    Constants are keyed by their raw byte representation so that two
    independently created equal constants (NaN included) collide on the same
    key. Object-dtype data (e.g. slices) falls back to ``signature()``
    because its byte representation stores pointers, not values.
    """
    if not isinstance(inp, Constant):
        return inp
    arr = np.asarray(inp.data)
    if arr.dtype.kind == "O":
        return inp.signature()
    return (inp.type, arr.tobytes(), arr.dtype.str, arr.shape)

def __new__(
    cls,
    op: "Op",
    inputs: tuple[Variable, ...],
    output_types: tuple["Type", ...],
):
    """Return the interned FrozenApply for ``(op, inputs)``, creating it if needed.

    Setup happens in ``__new__`` (not ``__init__``) so that a cache hit can
    return the cached instance untouched; ``__init__`` would run afterwards
    and overwrite the cached outputs, breaking identity-based equality.
    """
    # Value-based key: Constants keyed by bytes, interned variables by identity.
    cache_key = (op, tuple(cls._input_to_key(i) for i in inputs))
    cached = cls._cache.get(cache_key)
    if cached is not None:
        return cached

    instance = object.__new__(cls)
    instance.op = op
    # Tuples (not lists) so accidental mutation raises TypeError.
    instance.inputs = inputs  # type: ignore[assignment]
    instance.outputs = tuple(  # type: ignore[assignment]
        t.variable_type(type=t, owner=instance, index=i)
        for i, t in enumerate(output_types)
    )
    # Give each output Variable a __reduce__ that resolves to the
    # canonical output on unpickle, avoiding fresh Variable objects.
    for out in instance.outputs:
        out.__reduce_ex__ = _make_frozen_output_reduce(out)  # type: ignore[method-assign]
Comment on lines +859 to +862
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ELI5?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FrozenApply is interned, but its output variables are not. When we unpickle stuff, Python will make a new object so that a is not unpickle(pickle(a)). Our variable equality is identity based, so now a == unpickle(pickle(a)) fails. FrozenFunctionGraph also depends on comparing outputs by equality, so without this we end up with fg1 != fg2

The patch makes it so that when unpickling, we go look for the (interned) "canonical" version of the output (e.g. cached_apply.outputs[0]) and use that. Note that the variable itself isn't interned, just the apply. So this patch is basically letting pickle know about and use this relationship, rather than creating orphan duplicates.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the apply is interned the variables are also (and vice versa). I assumed we were going to intern the variables because you may have variables without apply but not the other way around

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We intern the variables, but only transitively. We have to go reach into the FrozenApply to get them.

# Fresh scratchpad; the cache is weak-valued, so unreferenced nodes
# are reclaimed by the garbage collector along with their cache entry.
instance.tag = Scratchpad()
cls._cache[cache_key] = instance
return instance
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the frozenapply to be hash-consed? Isn't it enough if the input/output variables are? Wondering if we can remove some extra code that way. The Apply doesn't do much anyway

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is tied up in the current identity-based equality scheme. If we remove the FrozenApply interning, FrozenFunctionGraph.init will create new Apply nodes with new output Variables, so we lose output1 is output2 and thus fg1 == fg2.

That's not to say we couldn't move the equality check responsibility inside FFG, but it's a slightly different design.


def __init__(self, op, inputs, output_types):
    """Intentionally empty: ``__new__`` fully initializes (or reuses) the node."""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it in __new__ and not __init__?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we want to return cached variables for the outputs. __new__ is run before an object exists, so we have time to do the output injection. If we put the setup in __init__, on cache hit:

  • call new
  • put the cached outputs into self.outputs
  • call init
  • overwrite the cached output with new outputs, breaking the identity equality


def clone(self, clone_inner_graph: bool = False) -> Self:
    """Return ``self``: frozen nodes are immutable, so no copy is ever made."""
    return self

def __reduce__(self):
    """Pickle via the constructor so unpickling goes through the intern cache."""
    types_of_outputs = tuple(out.type for out in self.outputs)
    return (type(self), (self.op, self.inputs, types_of_outputs))


def clone(
inputs: Sequence[Variable],
outputs: Sequence[Variable],
Expand Down Expand Up @@ -1104,14 +1193,14 @@ def equal_computations(

for x, y in zip(xs, ys, strict=True):
if not isinstance(x, Variable) and not isinstance(y, Variable):
return np.array_equal(x, y)
return np.array_equal(x, y, equal_nan=True)
if not isinstance(x, Variable):
if isinstance(y, Constant):
return np.array_equal(y.data, x)
return np.array_equal(y.data, x, equal_nan=True)
return False
if not isinstance(y, Variable):
if isinstance(x, Constant):
return np.array_equal(x.data, y)
return np.array_equal(x.data, y, equal_nan=True)
return False
x_is_owned, y_is_owned = (x.owner is not None, y.owner is not None)
if x_is_owned != y_is_owned:
Expand Down
186 changes: 184 additions & 2 deletions pytensor/graph/fg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""

import time
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Any, Union, cast
Expand All @@ -9,6 +10,8 @@
from pytensor.graph.basic import (
Apply,
AtomicVariable,
Constant,
NominalVariable,
Variable,
clone_get_equiv,
)
Expand All @@ -22,12 +25,25 @@
toposort_with_orderings,
vars_between,
)
from pytensor.graph.utils import MetaObject, MissingInputError
from pytensor.graph.utils import MissingInputError


ClientType = tuple[Apply, int]


class AbstractFunctionGraph(ABC):
    """Read-only interface shared by FunctionGraph and FrozenFunctionGraph."""

    # Graph boundary: the ordered input and output variables.
    inputs: Sequence[Variable]
    outputs: Sequence[Variable]
    # All Apply nodes and Variables lying between inputs and outputs.
    apply_nodes: set[Apply]
    variables: set[Variable]
    # Maps each variable to the (node, input_index) pairs that consume it.
    clients: dict[Variable, list[ClientType]]

    @abstractmethod
    def toposort(self) -> list[Apply]: ...


class Output(Op):
"""A dummy `Op` that represents an output variable in a `FunctionGraph`."""

Expand All @@ -46,7 +62,7 @@ def __str__(self):
return f"output[{self.idx}]"


class FunctionGraph(MetaObject):
class FunctionGraph(AbstractFunctionGraph):
r"""
A `FunctionGraph` represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies an PyTensor function.
Expand Down Expand Up @@ -911,3 +927,169 @@ def dprint(self, **kwargs):
from pytensor.printing import debugprint

return debugprint(self, **kwargs)

def freeze(self) -> "FrozenFunctionGraph":
    """Snapshot this graph as an immutable, hashable FrozenFunctionGraph."""
    frozen = FrozenFunctionGraph(self.inputs, self.outputs)
    return frozen


class FrozenFunctionGraph(AbstractFunctionGraph):
    """Immutable, hashable function graph for inner graphs of Ops.

    All internal nodes are globally interned via ``FrozenApply``. Two
    ``FrozenFunctionGraph`` instances built from structurally identical source
    graphs share the same interned output objects, so equality reduces to
    identity comparison on the outputs tuple.

    Use ``FunctionGraph.freeze()`` or ``FrozenFunctionGraph(inputs, outputs)``
    to create instances.

    .. code-block:: python

        from pytensor.scalar.basic import float64, add
        from pytensor.graph.fg import FunctionGraph

        x, y = float64("x"), float64("y")
        frozen = FunctionGraph([x, y], [add(x, y)]).freeze()
        frozen2 = FunctionGraph([x, y], [add(x, y)]).freeze()

        assert frozen == frozen2
        assert {frozen: "value"}[frozen2] == "value"
    """

    def __init__(
        self,
        inputs: Sequence[Variable],
        outputs: Sequence[Variable],
    ):
        # Local import to avoid a circular import with pytensor.graph.basic.
        from pytensor.graph.basic import FrozenApply

        # Replace user inputs with positionally-keyed NominalVariables so that
        # structurally identical graphs end up with identical input objects.
        nominal_inputs = tuple(
            NominalVariable(i, inp.type) for i, inp in enumerate(inputs)
        )

        # memo maps original variables -> frozen/interned replacements.
        memo: dict[Variable, Variable] = dict(zip(inputs, nominal_inputs, strict=True))

        for node in toposort(outputs, blockers=inputs):
            for inp in node.inputs:
                if inp not in memo:
                    # Orphans are allowed only when self-contained
                    # (Constants / AtomicVariables); anything else means a
                    # variable was omitted from the graph inputs.
                    if isinstance(inp, (Constant, AtomicVariable)):
                        memo[inp] = inp
                    else:
                        raise ValueError(
                            f"Non-Constant, non-AtomicVariable orphan {inp} found "
                            "in the graph. All variables must be graph inputs, "
                            "Constants, AtomicVariables, or produced by Apply "
                            "nodes reachable from the inputs."
                        )

            new_inputs = tuple(memo[i] for i in node.inputs)
            output_types = tuple(out.type for out in node.outputs)
            # FrozenApply is interned: identical (op, inputs) pairs return the
            # same cached node, and therefore the same output Variables.
            new_node = FrozenApply(node.op, new_inputs, output_types)

            memo.update(zip(node.outputs, new_node.outputs, strict=True))

        # Handle outputs that are Constants or AtomicVariables not
        # encountered during toposort (e.g. a graph with no Apply nodes)
        for o in outputs:
            if o not in memo and isinstance(o, (Constant, AtomicVariable)):
                memo[o] = o

        try:
            frozen_outputs = tuple(memo[o] for o in outputs)
        except KeyError:
            unmapped = [o for o in outputs if o not in memo]
            # ``from None``: the internal KeyError adds no useful context.
            raise ValueError(
                f"Output variable {unmapped[0]} could not be mapped to a frozen "
                "graph variable. All outputs must be graph inputs, "
                "constants, or produced by Apply nodes reachable from "
                "the inputs."
            ) from None

        self.inputs: tuple[Variable, ...] = nominal_inputs
        self.outputs: tuple[Variable, ...] = frozen_outputs
        # Name outputs positionally for readable debug printouts.
        for i, out in enumerate(frozen_outputs):
            out.name = f"o{i}"
        # Lazily computed caches for the read-only graph views below.
        self._variables: set[Variable] | None = None
        self._apply_nodes: set[Apply] | None = None
        self._clients: dict[Variable, list[ClientType]] | None = None
        self._toposort: list[Apply] | None = None

    def __reduce__(self):
        """Pickle by re-running ``__init__``, which re-interns every node."""
        return FrozenFunctionGraph, (self.inputs, self.outputs)

    def __hash__(self):
        # Outputs are interned, so hashing the tuple is stable across
        # structurally identical graphs.
        return hash(self.outputs)

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, FrozenFunctionGraph):
            return False
        # Interning makes identity-based tuple comparison sufficient.
        return self.outputs == other.outputs and self.inputs == other.inputs

    def __repr__(self):
        return f"FrozenFunctionGraph(inputs={list(self.inputs)}, outputs={list(self.outputs)})"

    def __copy__(self):
        # Immutable: copying returns the same object.
        return self

    def __deepcopy__(self, memo):
        # Immutable: deep-copying returns the same object.
        return self

    @property
    def apply_nodes(self) -> set[Apply]:  # type: ignore[override]
        """All Apply nodes between ``inputs`` and ``outputs`` (computed lazily)."""
        if self._apply_nodes is None:
            self._apply_nodes = set(applys_between(self.inputs, self.outputs))
        return self._apply_nodes

    def toposort(self) -> list[Apply]:
        """Apply nodes in dependencies-first topological order (cached)."""
        if self._toposort is None:
            self._toposort = list(toposort(self.outputs, blockers=self.inputs))
        return self._toposort

    @property
    def variables(self) -> set[Variable]:  # type: ignore[override]
        """All Variables between ``inputs`` and ``outputs`` (computed lazily)."""
        if self._variables is None:
            self._variables = set(vars_between(self.inputs, self.outputs))
        return self._variables

    @property
    def clients(self) -> dict[Variable, list[ClientType]]:  # type: ignore[override]
        """Map each variable to the ``(node, input_index)`` pairs consuming it."""
        if self._clients is None:
            clients: dict[Variable, list[ClientType]] = {v: [] for v in self.inputs}
            for node in self.toposort():
                for i, inp in enumerate(node.inputs):
                    clients.setdefault(inp, []).append((node, i))
                for out in node.outputs:
                    clients.setdefault(out, [])
            self._clients = clients
        return self._clients

    def unfreeze(self) -> "FunctionGraph":
        """Return a mutable FunctionGraph with fresh mutable Apply nodes."""
        # memo maps frozen variables -> fresh mutable replacements.
        memo: dict[Variable, Variable] = {inp: inp.type() for inp in self.inputs}

        for node in self.toposort():
            for i in node.inputs:
                if i not in memo:
                    # Atomic variables (constants etc.) are safe to share;
                    # everything else gets a fresh clone.
                    memo[i] = i if isinstance(i, AtomicVariable) else i.clone()
            new_inputs = [memo[i] for i in node.inputs]
            new_node = Apply(
                node.op,
                new_inputs,
                [o.type() for o in node.outputs],
            )
            # strict=True: matches __init__ and guards against a silent
            # length mismatch between frozen and rebuilt outputs.
            memo.update(zip(node.outputs, new_node.outputs, strict=True))

        new_inputs = [memo[i] for i in self.inputs]
        new_outputs = [memo[o] for o in self.outputs]
        return FunctionGraph(new_inputs, new_outputs, clone=False)
Loading
Loading