From 40d07f95778e9406730c3426136fd48eb80a020d Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 00:05:42 -0500 Subject: [PATCH 1/8] Add MLX dispatch for eig --- pytensor/link/mlx/dispatch/nlinalg.py | 18 +++++++++++++++++- tests/link/mlx/test_nlinalg.py | 12 ++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/pytensor/link/mlx/dispatch/nlinalg.py b/pytensor/link/mlx/dispatch/nlinalg.py index 7303305aa6..aa0482674d 100644 --- a/pytensor/link/mlx/dispatch/nlinalg.py +++ b/pytensor/link/mlx/dispatch/nlinalg.py @@ -1,7 +1,13 @@ import mlx.core as mx from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.tensor.nlinalg import SVD, KroneckerProduct, MatrixInverse, MatrixPinv +from pytensor.tensor.nlinalg import ( + SVD, + Eig, + KroneckerProduct, + MatrixInverse, + MatrixPinv, +) @mlx_funcify.register(SVD) @@ -49,6 +55,16 @@ def kron(a, b): return kron +@mlx_funcify.register(Eig) +def mlx_funcify_Eig(op, node, **kwargs): + X_dtype = getattr(mx, node.inputs[0].dtype) + + def eig(x): + return mx.linalg.eig(x.astype(dtype=X_dtype, stream=mx.cpu), stream=mx.cpu) + + return eig + + @mlx_funcify.register(MatrixInverse) def mlx_funcify_MatrixInverse(op, node, **kwargs): X_dtype = getattr(mx, node.inputs[0].dtype) diff --git a/tests/link/mlx/test_nlinalg.py b/tests/link/mlx/test_nlinalg.py index 1df408c735..e33cb68b9a 100644 --- a/tests/link/mlx/test_nlinalg.py +++ b/tests/link/mlx/test_nlinalg.py @@ -8,6 +8,18 @@ from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode +def test_mlx_eig(): + rng = np.random.default_rng(15) + + M = rng.normal(size=(3, 3)) + A_val = (M @ M.T).astype(config.floatX) + + A = pt.matrix(name="A") + outs = pt.linalg.eig(A) + + compare_mlx_and_py([A], outs, [A_val]) + + @pytest.mark.parametrize("compute_uv", [True, False]) def test_mlx_svd(compute_uv): rng = np.random.default_rng(15) From bc50145915acc98a41db35611854da39cb6e3eef Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 00:07:02 -0500 Subject: [PATCH 2/8] Add MLX dispatch for eigh --- pytensor/link/mlx/dispatch/nlinalg.py | 14 ++++++++++++++ tests/link/mlx/test_nlinalg.py | 13 +++++++++++++ 2 files changed, 27 insertions(+) diff --git a/pytensor/link/mlx/dispatch/nlinalg.py b/pytensor/link/mlx/dispatch/nlinalg.py index aa0482674d..c4bd8ce3ae 100644 --- a/pytensor/link/mlx/dispatch/nlinalg.py +++ b/pytensor/link/mlx/dispatch/nlinalg.py @@ -4,6 +4,7 @@ from pytensor.tensor.nlinalg import ( SVD, Eig, + Eigh, KroneckerProduct, MatrixInverse, MatrixPinv, @@ -65,6 +66,19 @@ def eig(x): return eig +@mlx_funcify.register(Eigh) +def mlx_funcify_Eigh(op, node, **kwargs): + uplo = op.UPLO + X_dtype = getattr(mx, node.inputs[0].dtype) + + def eigh(x): + return mx.linalg.eigh( + x.astype(dtype=X_dtype, stream=mx.cpu), UPLO=uplo, stream=mx.cpu + ) + + return eigh + + @mlx_funcify.register(MatrixInverse) def mlx_funcify_MatrixInverse(op, node, **kwargs): X_dtype = getattr(mx, node.inputs[0].dtype) diff --git a/tests/link/mlx/test_nlinalg.py b/tests/link/mlx/test_nlinalg.py index e33cb68b9a..9eb5a84f40 100644 --- a/tests/link/mlx/test_nlinalg.py +++ b/tests/link/mlx/test_nlinalg.py @@ -20,6 +20,19 @@ def test_mlx_eig(): compare_mlx_and_py([A], outs, [A_val]) +@pytest.mark.parametrize("UPLO", ["L", "U"]) +def test_mlx_eigh(UPLO): + rng = np.random.default_rng(15) + + M = rng.normal(size=(3, 3)) + A_val = (M @ M.T).astype(config.floatX) + + A = pt.matrix(name="A") + outs = pt.linalg.eigh(A, UPLO=UPLO) + + compare_mlx_and_py([A], outs, [A_val]) + + @pytest.mark.parametrize("compute_uv", [True, False]) def test_mlx_svd(compute_uv): rng = np.random.default_rng(15) From ed37c46334d7943721af2b861c3cb38bbb81b432 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 00:18:05 -0500 Subject: [PATCH 3/8] Add MLX dispatch for eigvalsh --- pytensor/link/mlx/dispatch/slinalg.py | 25 +++++++++++++++++++++++++ tests/link/mlx/test_slinalg.py | 14 ++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/pytensor/link/mlx/dispatch/slinalg.py b/pytensor/link/mlx/dispatch/slinalg.py index 6c102b7f67..b663117c87 100644 --- a/pytensor/link/mlx/dispatch/slinalg.py +++ b/pytensor/link/mlx/dispatch/slinalg.py @@ -4,6 +4,31 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.slinalg import LU, Cholesky, Solve, SolveTriangular +from pytensor.tensor.slinalg import ( + Cholesky, + Eigvalsh, + LU, + Solve, + SolveTriangular, +) + + +@mlx_funcify.register(Eigvalsh) +def mlx_funcify_Eigvalsh(op, node, **kwargs): + UPLO = "L" if op.lower else "U" + X_dtype = getattr(mx, node.inputs[0].dtype) + + def eigvalsh(a, b=None): + if b is not None: + raise NotImplementedError( + "mlx.core.linalg.eigvalsh does not support generalized " + "eigenvector problems (b != None)" + ) + return mx.linalg.eigvalsh( + a.astype(dtype=X_dtype, stream=mx.cpu), UPLO=UPLO, stream=mx.cpu + ) + + return eigvalsh @mlx_funcify.register(Cholesky) diff --git a/tests/link/mlx/test_slinalg.py b/tests/link/mlx/test_slinalg.py index 88d92bf43e..8d16124ce5 100644 --- a/tests/link/mlx/test_slinalg.py +++ b/tests/link/mlx/test_slinalg.py @@ -107,3 +107,17 @@ def test_mlx_LU(): mlx_mode=mlx_mode, assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True), ) + + +@pytest.mark.parametrize("lower", [True, False]) +def test_mlx_eigvalsh(lower): + rng = np.random.default_rng(15) + + M = rng.normal(size=(3, 3)) + A_val = (M @ M.T).astype(config.floatX) + + A = pt.matrix(name="A") + B = pt.matrix(name="B") + out = pt.linalg.eigvalsh(A, B, lower=lower) + + compare_mlx_and_py([A, B], [out], [A_val, None]) From eff057317313b008213f5be13849a748c10099d7 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 00:18:15 -0500 Subject: [PATCH 4/8] Add MLX dispatch for LUFactor --- pytensor/link/mlx/dispatch/slinalg.py | 35 +++++++++++++++++++++++++-- tests/link/mlx/test_slinalg.py | 25 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/pytensor/link/mlx/dispatch/slinalg.py b/pytensor/link/mlx/dispatch/slinalg.py index b663117c87..937e07ed9f 100644 --- a/pytensor/link/mlx/dispatch/slinalg.py +++ b/pytensor/link/mlx/dispatch/slinalg.py @@ -3,11 +3,12 @@ import mlx.core as mx from pytensor.link.mlx.dispatch.basic import mlx_funcify -from pytensor.tensor.slinalg import LU, Cholesky, Solve, SolveTriangular from pytensor.tensor.slinalg import ( + LU, Cholesky, Eigvalsh, - LU, + LUFactor, + PivotToPermutations, Solve, SolveTriangular, ) @@ -107,3 +108,33 @@ def lu(a): ) return lu + + +@mlx_funcify.register(LUFactor) +def mlx_funcify_LUFactor(op, node, **kwargs): + A_dtype = getattr(mx, node.inputs[0].dtype) + + def lu_factor(a): + lu, pivots = mx.linalg.lu_factor( + a.astype(dtype=A_dtype, stream=mx.cpu), stream=mx.cpu + ) + return lu, pivots.astype(mx.int32, stream=mx.cpu) + + return lu_factor + + +@mlx_funcify.register(PivotToPermutations) +def mlx_funcify_PivotToPermutations(op, **kwargs): + inverse = op.inverse + + def pivot_to_permutations(pivots): + pivots = mx.array(pivots) + n = pivots.shape[0] + p_inv = mx.arange(n, dtype=mx.int32) + for i in range(n): + p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i] + if inverse: + return p_inv + return mx.argsort(p_inv) + + return pivot_to_permutations diff --git a/tests/link/mlx/test_slinalg.py b/tests/link/mlx/test_slinalg.py index 8d16124ce5..7c973b872c 100644 --- a/tests/link/mlx/test_slinalg.py +++ b/tests/link/mlx/test_slinalg.py @@ -121,3 +121,28 @@ def test_mlx_eigvalsh(lower): out = pt.linalg.eigvalsh(A, B, lower=lower) compare_mlx_and_py([A, B], [out], [A_val, None]) + + +def test_mlx_lu_factor(): + rng = np.random.default_rng(15) + + A = pt.matrix(name="A") + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + + out = pt.linalg.lu_factor(A) + + compare_mlx_and_py([A], out, [A_val]) + + +def test_mlx_pivot_to_permutations(): + rng = np.random.default_rng(15) + + A = pt.matrix(name="A") + A_val = rng.normal(size=(5, 5)).astype(config.floatX) + + from pytensor.tensor.slinalg import pivot_to_permutation + + lu_and_pivots = pt.linalg.lu_factor(A) + out = pivot_to_permutation(lu_and_pivots[1]) + + compare_mlx_and_py([A], [out], [A_val]) From d1691bb20ed5054709b59369766f6aeda2274b82 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 00:26:41 -0500 Subject: [PATCH 5/8] Add MLX dispatch for QR --- pytensor/link/mlx/dispatch/slinalg.py | 28 +++++++++++++++++++++++++++ tests/link/mlx/test_slinalg.py | 12 ++++++++++++ 2 files changed, 40 insertions(+) diff --git a/pytensor/link/mlx/dispatch/slinalg.py b/pytensor/link/mlx/dispatch/slinalg.py index 937e07ed9f..88ff3e5553 100644 --- a/pytensor/link/mlx/dispatch/slinalg.py +++ b/pytensor/link/mlx/dispatch/slinalg.py @@ -5,6 +5,7 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.slinalg import ( LU, + QR, Cholesky, Eigvalsh, LUFactor, @@ -110,6 +111,33 @@ def lu(a): return lu +@mlx_funcify.register(QR) +def mlx_funcify_QR(op, node, **kwargs): + mode = op.mode + A_dtype = getattr(mx, node.inputs[0].dtype) + + if mode not in ("economic", "r"): + raise NotImplementedError( + f"mode='{mode}' is not supported in the MLX backend. " + "Only 'economic' and 'r' modes are available." + ) + + def qr(a): + Q, R = mx.linalg.qr(a.astype(dtype=A_dtype, stream=mx.cpu), stream=mx.cpu) + if mode == "r": + M = a.shape[-2] + K = R.shape[-2] + if M > K: + # Pytensor follows scipy convention for mode = 'r', which returns R with the same + # leading shape as the input. + pad_width = [(0, 0)] * (R.ndim - 2) + [(0, M - K), (0, 0)] + return mx.pad(R, pad_width, stream=mx.cpu) + return R + return Q, R + + return qr + + @mlx_funcify.register(LUFactor) def mlx_funcify_LUFactor(op, node, **kwargs): A_dtype = getattr(mx, node.inputs[0].dtype) diff --git a/tests/link/mlx/test_slinalg.py b/tests/link/mlx/test_slinalg.py index 7c973b872c..3c942aec02 100644 --- a/tests/link/mlx/test_slinalg.py +++ b/tests/link/mlx/test_slinalg.py @@ -146,3 +146,15 @@ def test_mlx_pivot_to_permutations(): out = pivot_to_permutation(lu_and_pivots[1]) compare_mlx_and_py([A], [out], [A_val]) + + +@pytest.mark.parametrize("mode", ["economic", "r"]) +def test_mlx_qr(mode): + rng = np.random.default_rng(15) + + A = pt.matrix(name="A") + A_val = rng.normal(size=(5, 3)).astype(config.floatX) + + out = pt.linalg.qr(A, mode=mode) + + compare_mlx_and_py([A], out, [A_val]) From 5d65608864b20b67b770b5ac8fa11ca86ce596d2 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 00:37:40 -0500 Subject: [PATCH 6/8] Add MLX dispatch for Det and SlogDet --- pytensor/link/mlx/dispatch/nlinalg.py | 44 +++++++++++++++++++++++++++ tests/link/mlx/test_nlinalg.py | 23 ++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/pytensor/link/mlx/dispatch/nlinalg.py b/pytensor/link/mlx/dispatch/nlinalg.py index c4bd8ce3ae..dc4b270a72 100644 --- a/pytensor/link/mlx/dispatch/nlinalg.py +++ b/pytensor/link/mlx/dispatch/nlinalg.py @@ -3,11 +3,13 @@ from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.nlinalg import ( SVD, + Det, Eig, Eigh, KroneckerProduct, MatrixInverse, MatrixPinv, + SLogDet, ) @@ -56,6 +58,48 @@ def kron(a, b): return kron +def _lu_det_parts(x): + """Shared helper: compute sign and log|det| via LU factorization.""" + lu, pivots = mx.linalg.lu_factor(x, stream=mx.cpu) + diag_u = mx.diagonal(lu, stream=mx.cpu) + n_swaps = mx.sum( + pivots != mx.arange(pivots.shape[0], dtype=pivots.dtype, stream=mx.cpu), + stream=mx.cpu, + ) + pivot_sign = 1 - 2 * (n_swaps % 2) + sign = mx.multiply( + pivot_sign, + mx.prod(mx.sign(diag_u, stream=mx.cpu), stream=mx.cpu), + stream=mx.cpu, + ) + logabsdet = mx.sum( + mx.log(mx.abs(diag_u, stream=mx.cpu), stream=mx.cpu), + stream=mx.cpu, + ) + return sign, logabsdet + + +@mlx_funcify.register(Det) +def mlx_funcify_Det(op, node, **kwargs): + X_dtype = getattr(mx, node.inputs[0].dtype) + + def det(x): + sign, logabsdet = _lu_det_parts(x.astype(dtype=X_dtype, stream=mx.cpu)) + return mx.multiply(sign, mx.exp(logabsdet, stream=mx.cpu), stream=mx.cpu) + + return det + + +@mlx_funcify.register(SLogDet) +def mlx_funcify_SLogDet(op, node, **kwargs): + X_dtype = getattr(mx, node.inputs[0].dtype) + + def slogdet(x): + return _lu_det_parts(x.astype(dtype=X_dtype, stream=mx.cpu)) + + return slogdet + + @mlx_funcify.register(Eig) def mlx_funcify_Eig(op, node, **kwargs): X_dtype = getattr(mx, node.inputs[0].dtype) diff --git a/tests/link/mlx/test_nlinalg.py b/tests/link/mlx/test_nlinalg.py index 9eb5a84f40..728d088f2c 100644 --- a/tests/link/mlx/test_nlinalg.py +++ b/tests/link/mlx/test_nlinalg.py @@ -5,9 +5,32 @@ import pytensor.tensor as pt from pytensor import config +from pytensor.compile.mode import get_mode from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode +def test_mlx_det(): + rng = np.random.default_rng(15) + + A = pt.matrix(name="A") + A_val = rng.normal(size=(3, 3)).astype(config.floatX) + + out = pt.linalg.det(A) + + compare_mlx_and_py([A], [out], [A_val]) + + +def test_mlx_slogdet(): + rng = np.random.default_rng(15) + + A = pt.matrix(name="A") + A_val = rng.normal(size=(3, 3)).astype(config.floatX) + + sign, logabsdet = pt.linalg.slogdet(A) + + compare_mlx_and_py([A], [sign, logabsdet], [A_val], mlx_mode=get_mode("MLX")) + + def test_mlx_eig(): rng = np.random.default_rng(15) From 6584f1f4a38d719974fc413f4d2177db8fc98ab1 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 19:50:48 -0500 Subject: [PATCH 7/8] Skip eig test --- tests/link/mlx/test_nlinalg.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/link/mlx/test_nlinalg.py b/tests/link/mlx/test_nlinalg.py index 728d088f2c..07fe0fa4cf 100644 --- a/tests/link/mlx/test_nlinalg.py +++ b/tests/link/mlx/test_nlinalg.py @@ -1,7 +1,9 @@ from functools import partial +import mlx.core as mx import numpy as np import pytest +from packaging.version import parse as V import pytensor.tensor as pt from pytensor import config @@ -31,6 +33,11 @@ def test_mlx_slogdet(): compare_mlx_and_py([A], [sign, logabsdet], [A_val], mlx_mode=get_mode("MLX")) +@pytest.mark.skipif( + V(mx.__version__) < V("0.30.1"), + reason="mx.linalg.eig causes a Fatal Python error (Abort trap) on MLX <0.30.1 " + "(maybe -- the exact version cutoff is unknown)", +) def test_mlx_eig(): rng = np.random.default_rng(15) From 7748e416e1c0180b51a21a90e7588ad1c2a918de Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 20:50:25 -0500 Subject: [PATCH 8/8] importorskip test_nlinalg.py --- tests/link/mlx/test_nlinalg.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/link/mlx/test_nlinalg.py b/tests/link/mlx/test_nlinalg.py index 07fe0fa4cf..fcfc2cc26a 100644 --- a/tests/link/mlx/test_nlinalg.py +++ b/tests/link/mlx/test_nlinalg.py @@ -1,6 +1,5 @@ from functools import partial -import mlx.core as mx import numpy as np import pytest from packaging.version import parse as V @@ -11,6 +10,9 @@ from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode +mx = pytest.importorskip("mlx.core") + + def test_mlx_det(): rng = np.random.default_rng(15)