diff --git a/pytensor/link/mlx/dispatch/nlinalg.py b/pytensor/link/mlx/dispatch/nlinalg.py
index 7303305aa6..dc4b270a72 100644
--- a/pytensor/link/mlx/dispatch/nlinalg.py
+++ b/pytensor/link/mlx/dispatch/nlinalg.py
@@ -1,7 +1,16 @@
 import mlx.core as mx
 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
-from pytensor.tensor.nlinalg import SVD, KroneckerProduct, MatrixInverse, MatrixPinv
+from pytensor.tensor.nlinalg import (
+    SVD,
+    Det,
+    Eig,
+    Eigh,
+    KroneckerProduct,
+    MatrixInverse,
+    MatrixPinv,
+    SLogDet,
+)
 
 
 @mlx_funcify.register(SVD)
@@ -49,6 +58,71 @@ def kron(a, b):
     return kron
 
 
+def _lu_det_parts(x):
+    """Shared helper: compute sign and log|det| via LU factorization."""
+    lu, pivots = mx.linalg.lu_factor(x, stream=mx.cpu)
+    diag_u = mx.diagonal(lu, stream=mx.cpu)
+    n_swaps = mx.sum(
+        pivots != mx.arange(pivots.shape[0], dtype=pivots.dtype, stream=mx.cpu),
+        stream=mx.cpu,
+    )
+    pivot_sign = 1 - 2 * (n_swaps % 2)
+    sign = mx.multiply(
+        pivot_sign,
+        mx.prod(mx.sign(diag_u, stream=mx.cpu), stream=mx.cpu),
+        stream=mx.cpu,
+    )
+    logabsdet = mx.sum(
+        mx.log(mx.abs(diag_u, stream=mx.cpu), stream=mx.cpu),
+        stream=mx.cpu,
+    )
+    return sign, logabsdet
+
+
+@mlx_funcify.register(Det)
+def mlx_funcify_Det(op, node, **kwargs):
+    X_dtype = getattr(mx, node.inputs[0].dtype)
+
+    def det(x):
+        sign, logabsdet = _lu_det_parts(x.astype(dtype=X_dtype, stream=mx.cpu))
+        return mx.multiply(sign, mx.exp(logabsdet, stream=mx.cpu), stream=mx.cpu)
+
+    return det
+
+
+@mlx_funcify.register(SLogDet)
+def mlx_funcify_SLogDet(op, node, **kwargs):
+    X_dtype = getattr(mx, node.inputs[0].dtype)
+
+    def slogdet(x):
+        return _lu_det_parts(x.astype(dtype=X_dtype, stream=mx.cpu))
+
+    return slogdet
+
+
+@mlx_funcify.register(Eig)
+def mlx_funcify_Eig(op, node, **kwargs):
+    X_dtype = getattr(mx, node.inputs[0].dtype)
+
+    def eig(x):
+        return mx.linalg.eig(x.astype(dtype=X_dtype, stream=mx.cpu), stream=mx.cpu)
+
+    return eig
+
+
+@mlx_funcify.register(Eigh)
+def mlx_funcify_Eigh(op, node, **kwargs):
+    uplo = op.UPLO
+    X_dtype = getattr(mx, node.inputs[0].dtype)
+
+    def eigh(x):
+        return mx.linalg.eigh(
+            x.astype(dtype=X_dtype, stream=mx.cpu), UPLO=uplo, stream=mx.cpu
+        )
+
+    return eigh
+
+
 @mlx_funcify.register(MatrixInverse)
 def mlx_funcify_MatrixInverse(op, node, **kwargs):
     X_dtype = getattr(mx, node.inputs[0].dtype)
diff --git a/pytensor/link/mlx/dispatch/slinalg.py b/pytensor/link/mlx/dispatch/slinalg.py
index 6c102b7f67..88ff3e5553 100644
--- a/pytensor/link/mlx/dispatch/slinalg.py
+++ b/pytensor/link/mlx/dispatch/slinalg.py
@@ -3,7 +3,34 @@
 import mlx.core as mx
 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
-from pytensor.tensor.slinalg import LU, Cholesky, Solve, SolveTriangular
+from pytensor.tensor.slinalg import (
+    LU,
+    QR,
+    Cholesky,
+    Eigvalsh,
+    LUFactor,
+    PivotToPermutations,
+    Solve,
+    SolveTriangular,
+)
+
+
+@mlx_funcify.register(Eigvalsh)
+def mlx_funcify_Eigvalsh(op, node, **kwargs):
+    UPLO = "L" if op.lower else "U"
+    X_dtype = getattr(mx, node.inputs[0].dtype)
+
+    def eigvalsh(a, b=None):
+        if b is not None:
+            raise NotImplementedError(
+                "mlx.core.linalg.eigvalsh does not support generalized "
+                "eigenvector problems (b != None)"
+            )
+        return mx.linalg.eigvalsh(
+            a.astype(dtype=X_dtype, stream=mx.cpu), UPLO=UPLO, stream=mx.cpu
+        )
+
+    return eigvalsh
 
 
 @mlx_funcify.register(Cholesky)
@@ -82,3 +109,70 @@ def lu(a):
         )
 
     return lu
+
+
+@mlx_funcify.register(QR)
+def mlx_funcify_QR(op, node, **kwargs):
+    mode = op.mode
+    A_dtype = getattr(mx, node.inputs[0].dtype)
+
+    if mode not in ("economic", "r"):
+        raise NotImplementedError(
+            f"mode='{mode}' is not supported in the MLX backend. "
+            "Only 'economic' and 'r' modes are available."
+        )
+
+    # mx.linalg.qr is called below without any pivoting argument and only Q/R
+    # are returned, so refuse a pivoted Op here rather than silently ignoring
+    # the request.
+    if getattr(op, "pivoting", False):
+        raise NotImplementedError(
+            "pivoting=True is not supported in the MLX backend. "
+            "Only unpivoted QR is available."
+        )
+
+    def qr(a):
+        Q, R = mx.linalg.qr(a.astype(dtype=A_dtype, stream=mx.cpu), stream=mx.cpu)
+        if mode == "r":
+            M = a.shape[-2]
+            K = R.shape[-2]
+            if M > K:
+                # Pytensor follows scipy convention for mode = 'r', which returns R with the same
+                # leading shape as the input.
+                pad_width = [(0, 0)] * (R.ndim - 2) + [(0, M - K), (0, 0)]
+                return mx.pad(R, pad_width, stream=mx.cpu)
+            return R
+        return Q, R
+
+    return qr
+
+
+@mlx_funcify.register(LUFactor)
+def mlx_funcify_LUFactor(op, node, **kwargs):
+    A_dtype = getattr(mx, node.inputs[0].dtype)
+
+    def lu_factor(a):
+        lu, pivots = mx.linalg.lu_factor(
+            a.astype(dtype=A_dtype, stream=mx.cpu), stream=mx.cpu
+        )
+        return lu, pivots.astype(mx.int32, stream=mx.cpu)
+
+    return lu_factor
+
+
+@mlx_funcify.register(PivotToPermutations)
+def mlx_funcify_PivotToPermutations(op, **kwargs):
+    inverse = op.inverse
+
+    def pivot_to_permutations(pivots):
+        pivots = mx.array(pivots)
+        n = pivots.shape[0]
+        p_inv = mx.arange(n, dtype=mx.int32)
+        # Apply the swaps i <-> pivots[i] in order, starting from the identity.
+        for i in range(n):
+            p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
+        if inverse:
+            return p_inv
+        return mx.argsort(p_inv)
+
+    return pivot_to_permutations
diff --git a/tests/link/mlx/test_nlinalg.py b/tests/link/mlx/test_nlinalg.py
index 1df408c735..fcfc2cc26a 100644
--- a/tests/link/mlx/test_nlinalg.py
+++ b/tests/link/mlx/test_nlinalg.py
@@ -2,12 +2,69 @@
 
 import numpy as np
 import pytest
+from packaging.version import parse as V
 
 import pytensor.tensor as pt
 from pytensor import config
+from pytensor.compile.mode import get_mode
 from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode
 
 
+mx = pytest.importorskip("mlx.core")
+
+
+def test_mlx_det():
+    rng = np.random.default_rng(15)
+
+    A = pt.matrix(name="A")
+    A_val = rng.normal(size=(3, 3)).astype(config.floatX)
+
+    out = pt.linalg.det(A)
+
+    compare_mlx_and_py([A], [out], [A_val])
+
+
+def test_mlx_slogdet():
+    rng = np.random.default_rng(15)
+
+    A = pt.matrix(name="A")
+    A_val = rng.normal(size=(3, 3)).astype(config.floatX)
+
+    sign, logabsdet = pt.linalg.slogdet(A)
+
+    compare_mlx_and_py([A], [sign, logabsdet], [A_val], mlx_mode=get_mode("MLX"))
+
+
+@pytest.mark.skipif(
+    V(mx.__version__) < V("0.30.1"),
+    reason="mx.linalg.eig causes a Fatal Python error (Abort trap) on MLX <0.30.1 "
+    "(maybe -- the exact version cutoff is unknown)",
+)
+def test_mlx_eig():
+    rng = np.random.default_rng(15)
+
+    M = rng.normal(size=(3, 3))
+    A_val = (M @ M.T).astype(config.floatX)
+
+    A = pt.matrix(name="A")
+    outs = pt.linalg.eig(A)
+
+    compare_mlx_and_py([A], outs, [A_val])
+
+
+@pytest.mark.parametrize("UPLO", ["L", "U"])
+def test_mlx_eigh(UPLO):
+    rng = np.random.default_rng(15)
+
+    M = rng.normal(size=(3, 3))
+    A_val = (M @ M.T).astype(config.floatX)
+
+    A = pt.matrix(name="A")
+    outs = pt.linalg.eigh(A, UPLO=UPLO)
+
+    compare_mlx_and_py([A], outs, [A_val])
+
+
 @pytest.mark.parametrize("compute_uv", [True, False])
 def test_mlx_svd(compute_uv):
     rng = np.random.default_rng(15)
diff --git a/tests/link/mlx/test_slinalg.py b/tests/link/mlx/test_slinalg.py
index 88d92bf43e..3c942aec02 100644
--- a/tests/link/mlx/test_slinalg.py
+++ b/tests/link/mlx/test_slinalg.py
@@ -107,3 +107,54 @@ def test_mlx_LU():
         mlx_mode=mlx_mode,
         assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
     )
+
+
+@pytest.mark.parametrize("lower", [True, False])
+def test_mlx_eigvalsh(lower):
+    rng = np.random.default_rng(15)
+
+    M = rng.normal(size=(3, 3))
+    A_val = (M @ M.T).astype(config.floatX)
+
+    A = pt.matrix(name="A")
+    B = pt.matrix(name="B")
+    out = pt.linalg.eigvalsh(A, B, lower=lower)
+
+    compare_mlx_and_py([A, B], [out], [A_val, None])
+
+
+def test_mlx_lu_factor():
+    rng = np.random.default_rng(15)
+
+    A = pt.matrix(name="A")
+    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
+
+    out = pt.linalg.lu_factor(A)
+
+    compare_mlx_and_py([A], out, [A_val])
+
+
+def test_mlx_pivot_to_permutations():
+    rng = np.random.default_rng(15)
+
+    A = pt.matrix(name="A")
+    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
+
+    from pytensor.tensor.slinalg import pivot_to_permutation
+
+    lu_and_pivots = pt.linalg.lu_factor(A)
+    out = pivot_to_permutation(lu_and_pivots[1])
+
+    compare_mlx_and_py([A], [out], [A_val])
+
+
+@pytest.mark.parametrize("mode", ["economic", "r"])
+def test_mlx_qr(mode):
+    rng = np.random.default_rng(15)
+
+    A = pt.matrix(name="A")
+    A_val = rng.normal(size=(5, 3)).astype(config.floatX)
+
+    out = pt.linalg.qr(A, mode=mode)
+
+    compare_mlx_and_py([A], out, [A_val])