diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index aa0e1b3e28..26e47be278 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -467,7 +467,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 
 MLX = Mode(
     MLXLinker(),
-    RewriteDatabaseQuery(include=["fast_run"]),
+    RewriteDatabaseQuery(include=["fast_run", "mlx"]),
 )
 
 FAST_COMPILE = Mode(
diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py
index ac59f1809c..8b6447a1fd 100644
--- a/pytensor/link/mlx/dispatch/__init__.py
+++ b/pytensor/link/mlx/dispatch/__init__.py
@@ -9,6 +9,7 @@
 import pytensor.link.mlx.dispatch.shape
 import pytensor.link.mlx.dispatch.subtensor
 import pytensor.link.mlx.dispatch.tensor_basic
+import pytensor.link.mlx.dispatch.einsum
 import pytensor.link.mlx.dispatch.signal
 import pytensor.link.mlx.dispatch.signal.conv
 import pytensor.link.mlx.dispatch.blockwise
diff --git a/pytensor/link/mlx/dispatch/einsum.py b/pytensor/link/mlx/dispatch/einsum.py
new file mode 100644
index 0000000000..1961a10ab6
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/einsum.py
@@ -0,0 +1,14 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.tensor.einsum import Einsum
+
+
+@mlx_funcify.register(Einsum)
+def mlx_funcify_Einsum(op, **kwargs):
+    subscripts = op.subscripts
+
+    def einsum(*operands):
+        return mx.einsum(subscripts, *operands)
+
+    return einsum
diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py
index fea4c73d5c..6ec9619ce6 100644
--- a/pytensor/link/mlx/linker.py
+++ b/pytensor/link/mlx/linker.py
@@ -10,6 +10,7 @@ class MLXLinker(JITLinker):
         "local_careduce_fusion",
         "inplace",
         "scan_save_mem_prealloc",
+        "inline_einsum",
     )
 
     def __init__(self, use_compile=True, *args, **kwargs):
diff --git a/pytensor/tensor/rewriting/einsum.py b/pytensor/tensor/rewriting/einsum.py
index 5e9fe2d026..5ee8acc92a 100644
--- a/pytensor/tensor/rewriting/einsum.py
+++ b/pytensor/tensor/rewriting/einsum.py
@@ -36,14 +36,14 @@ def optimize_einsum_inner_graph(
     return [new_out]
 
 
-@register_specialize
+@register_specialize("inline_einsum")
 @node_rewriter([Einsum])
 def inline_optimized_einsum(
     fgraph: FunctionGraph, node: Apply
 ) -> list[TensorVariable] | None:
     """Inline einsums that are already optimized.
 
-    This allows the inner garph to be optimized with the rest of the graph, now that we got ordering right.
+    This allows the inner graph to be optimized with the rest of the graph, now that we have the ordering right.
     """
 
     op: Einsum = node.op
diff --git a/tests/link/mlx/test_einsum.py b/tests/link/mlx/test_einsum.py
new file mode 100644
index 0000000000..712045d492
--- /dev/null
+++ b/tests/link/mlx/test_einsum.py
@@ -0,0 +1,51 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+mx = pytest.importorskip("mlx.core")
+
+
+def test_mlx_einsum():
+    subscripts = "ij, jk, kl -> il"
+    shapes = {
+        "x": (3, 5),
+        "y": (5, 2),
+        "z": (2, 4),
+    }
+    x, y, z = (np.random.rand(*shape) for shape in shapes.values())
+    x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
+    out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
+    compare_mlx_and_py([x_pt, y_pt, z_pt], [out], [x, y, z])
+
+
+def test_ellipsis_einsum():
+    subscripts = "...i,...i->..."
+    x = np.random.rand(2, 5)
+    y = np.random.rand(2, 5)
+
+    x_pt = pt.tensor("x", shape=x.shape)
+    y_pt = pt.tensor("y", shape=y.shape)
+    out = pt.einsum(subscripts, x_pt, y_pt)
+    compare_mlx_and_py([x_pt, y_pt], [out], [x, y])
+
+
+def test_einsum_trace():
+    subscripts = "ii->"
+    x_pt = pt.matrix("x")
+    x_val = np.random.rand(5, 5)
+    out = pt.einsum(subscripts, x_pt)
+    compare_mlx_and_py([x_pt], [out], [x_val])
+
+
+def test_einsum_batched_outer_product():
+    a = pt.matrix("a", dtype="float32")
+    b = pt.matrix("b", dtype="float32")
+    out = pt.einsum("bi,bj->bij", a, b)
+
+    a_val = np.random.normal(size=(5, 3)).astype("float32")
+    b_val = np.random.normal(size=(5, 2)).astype("float32")
+
+    compare_mlx_and_py([a, b], [out], [a_val, b_val])
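
Taken together, the pieces fit as follows: `inline_optimized_einsum` is now tagged `"inline_einsum"`, `MLXLinker` excludes that tag so `Einsum` nodes survive graph rewriting, and the new dispatch lowers each surviving node to a single `mx.einsum` call. A minimal end-to-end sketch, assuming MLX is installed and that the `MLX` mode can be requested by its registered name like the other JIT backends (the variable names below are illustrative, not from the patch):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(3, 5))
    y = pt.tensor("y", shape=(5, 2))
    out = pt.einsum("ij,jk->ik", x, y)

    # Under mode="MLX" the Einsum op is not inlined into its component ops
    # (the linker excludes the "inline_einsum" rewrite) and is instead
    # lowered via mlx_funcify_Einsum to one mx.einsum call.
    fn = pytensor.function([x, y], out, mode="MLX")
    result = fn(np.random.rand(3, 5), np.random.rand(5, 2))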