Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion pytensor/link/mlx/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.nlinalg import SVD, KroneckerProduct, MatrixInverse, MatrixPinv
from pytensor.tensor.nlinalg import (
SVD,
Det,
Eig,
Eigh,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
SLogDet,
)


@mlx_funcify.register(SVD)
Expand Down Expand Up @@ -49,6 +58,71 @@ def kron(a, b):
return kron


def _lu_det_parts(x):
    """Shared helper: compute (sign, log|det|) of ``x`` via LU factorization.

    All work is pinned to the CPU stream because MLX's LU routines are
    CPU-only.
    """
    lu, pivots = mx.linalg.lu_factor(x, stream=mx.cpu)
    u_diag = mx.diagonal(lu, stream=mx.cpu)

    # Each pivot entry that differs from its own index records one row swap,
    # and every swap flips the determinant's sign.
    identity_perm = mx.arange(pivots.shape[0], dtype=pivots.dtype, stream=mx.cpu)
    swap_count = mx.sum(pivots != identity_perm, stream=mx.cpu)
    swap_sign = 1 - 2 * (swap_count % 2)

    # Overall sign = swap parity times the product of the U-diagonal signs.
    diag_sign = mx.prod(mx.sign(u_diag, stream=mx.cpu), stream=mx.cpu)
    sign = mx.multiply(swap_sign, diag_sign, stream=mx.cpu)

    # log|det| = sum(log|u_ii|); summing logs avoids overflow of the raw product.
    abs_diag = mx.abs(u_diag, stream=mx.cpu)
    logabsdet = mx.sum(mx.log(abs_diag, stream=mx.cpu), stream=mx.cpu)
    return sign, logabsdet


@mlx_funcify.register(Det)
def mlx_funcify_Det(op, node, **kwargs):
    """Dispatch the ``Det`` Op to an MLX implementation built on LU factors."""
    # Resolve the MLX dtype once, outside the returned thunk.
    X_dtype = getattr(mx, node.inputs[0].dtype)

    def det(x):
        x_cast = x.astype(dtype=X_dtype, stream=mx.cpu)
        sign, logabsdet = _lu_det_parts(x_cast)
        # Reassemble det = sign * exp(log|det|).
        magnitude = mx.exp(logabsdet, stream=mx.cpu)
        return mx.multiply(sign, magnitude, stream=mx.cpu)

    return det


@mlx_funcify.register(SLogDet)
def mlx_funcify_SLogDet(op, node, **kwargs):
    """Dispatch ``SLogDet``: return ``(sign, log|det|)`` computed via LU."""
    X_dtype = getattr(mx, node.inputs[0].dtype)

    def slogdet(x):
        x_cast = x.astype(dtype=X_dtype, stream=mx.cpu)
        sign, logabsdet = _lu_det_parts(x_cast)
        return sign, logabsdet

    return slogdet


@mlx_funcify.register(Eig)
def mlx_funcify_Eig(op, node, **kwargs):
    """Dispatch ``Eig`` to ``mx.linalg.eig`` on the CPU stream."""
    X_dtype = getattr(mx, node.inputs[0].dtype)

    def eig(x):
        x_cast = x.astype(dtype=X_dtype, stream=mx.cpu)
        return mx.linalg.eig(x_cast, stream=mx.cpu)

    return eig


@mlx_funcify.register(Eigh)
def mlx_funcify_Eigh(op, node, **kwargs):
    """Dispatch ``Eigh`` to ``mx.linalg.eigh`` on the CPU stream.

    The Op's ``UPLO`` setting ("L" or "U") selects which triangle of the
    input is read.
    """
    uplo = op.UPLO
    X_dtype = getattr(mx, node.inputs[0].dtype)

    def eigh(x):
        x_cast = x.astype(dtype=X_dtype, stream=mx.cpu)
        return mx.linalg.eigh(x_cast, UPLO=uplo, stream=mx.cpu)

    return eigh


@mlx_funcify.register(MatrixInverse)
def mlx_funcify_MatrixInverse(op, node, **kwargs):
X_dtype = getattr(mx, node.inputs[0].dtype)
Expand Down
86 changes: 85 additions & 1 deletion pytensor/link/mlx/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,34 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.slinalg import LU, Cholesky, Solve, SolveTriangular
from pytensor.tensor.slinalg import (
LU,
QR,
Cholesky,
Eigvalsh,
LUFactor,
PivotToPermutations,
Solve,
SolveTriangular,
)


@mlx_funcify.register(Eigvalsh)
def mlx_funcify_Eigvalsh(op, node, **kwargs):
    """Dispatch ``Eigvalsh``; only the standard (``b is None``) problem is supported."""
    UPLO = "L" if op.lower else "U"
    X_dtype = getattr(mx, node.inputs[0].dtype)

    def eigvalsh(a, b=None):
        # MLX exposes no generalized symmetric eigensolver, so reject b early.
        if b is not None:
            raise NotImplementedError(
                "mlx.core.linalg.eigvalsh does not support generalized "
                "eigenvector problems (b != None)"
            )
        a_cast = a.astype(dtype=X_dtype, stream=mx.cpu)
        return mx.linalg.eigvalsh(a_cast, UPLO=UPLO, stream=mx.cpu)

    return eigvalsh


@mlx_funcify.register(Cholesky)
Expand Down Expand Up @@ -82,3 +109,60 @@ def lu(a):
)

return lu


@mlx_funcify.register(QR)
def mlx_funcify_QR(op, node, **kwargs):
    """Dispatch ``QR``; MLX only provides the reduced (economic) factorization.

    Raises
    ------
    NotImplementedError
        At dispatch time, for any mode other than "economic" or "r".
    """
    mode = op.mode
    A_dtype = getattr(mx, node.inputs[0].dtype)

    if mode not in ("economic", "r"):
        raise NotImplementedError(
            f"mode='{mode}' is not supported in the MLX backend. "
            "Only 'economic' and 'r' modes are available."
        )

    def qr(a):
        Q, R = mx.linalg.qr(a.astype(dtype=A_dtype, stream=mx.cpu), stream=mx.cpu)
        if mode == "economic":
            return Q, R
        # mode == "r": PyTensor follows the scipy convention, where R keeps the
        # same leading shape as the input, so zero-pad the reduced R if needed.
        n_rows = a.shape[-2]
        r_rows = R.shape[-2]
        if n_rows <= r_rows:
            return R
        pad_width = [(0, 0)] * (R.ndim - 2) + [(0, n_rows - r_rows), (0, 0)]
        return mx.pad(R, pad_width, stream=mx.cpu)

    return qr


@mlx_funcify.register(LUFactor)
def mlx_funcify_LUFactor(op, node, **kwargs):
    """Dispatch ``LUFactor`` to ``mx.linalg.lu_factor`` on the CPU stream."""
    A_dtype = getattr(mx, node.inputs[0].dtype)

    def lu_factor(a):
        a_cast = a.astype(dtype=A_dtype, stream=mx.cpu)
        lu, piv = mx.linalg.lu_factor(a_cast, stream=mx.cpu)
        # Pivots are cast to int32 — presumably the dtype the Op's pivot
        # output declares; confirm against the Op definition.
        return lu, piv.astype(mx.int32, stream=mx.cpu)

    return lu_factor


@mlx_funcify.register(PivotToPermutations)
def mlx_funcify_PivotToPermutations(op, **kwargs):
    """Dispatch ``PivotToPermutations``: expand LAPACK-style pivot indices
    into a permutation vector.

    ``op.inverse`` selects which direction of the permutation is returned.
    """
    inverse = op.inverse

    def pivot_to_permutations(pivots):
        pivots = mx.array(pivots)
        n = pivots.shape[0]
        # Start from the identity permutation and replay each recorded row
        # swap in order (pivot i swaps positions i and pivots[i]).
        p_inv = mx.arange(n, dtype=mx.int32)
        for i in range(n):
            # NOTE(review): this tuple-swap relies on MLX `__setitem__` and on
            # both RHS reads being captured before either assignment mutates
            # p_inv — appears to hold for MLX arrays, but worth confirming.
            p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
        if inverse:
            return p_inv
        # The forward permutation is the argsort of its inverse.
        return mx.argsort(p_inv)

    return pivot_to_permutations
57 changes: 57 additions & 0 deletions tests/link/mlx/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,69 @@

import numpy as np
import pytest
from packaging.version import parse as V

import pytensor.tensor as pt
from pytensor import config
from pytensor.compile.mode import get_mode
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode


mx = pytest.importorskip("mlx.core")


def test_mlx_det():
    """Determinant of a random 3x3 matrix agrees with the Python backend."""
    rng = np.random.default_rng(15)

    A = pt.matrix(name="A")
    out = pt.linalg.det(A)
    A_val = rng.normal(size=(3, 3)).astype(config.floatX)

    compare_mlx_and_py([A], [out], [A_val])


def test_mlx_slogdet():
    """Sign and log|det| of a random 3x3 matrix agree with the Python backend."""
    rng = np.random.default_rng(15)

    A = pt.matrix(name="A")
    A_val = rng.normal(size=(3, 3)).astype(config.floatX)

    sign, logabsdet = pt.linalg.slogdet(A)

    # Rely on the default comparison mode like the sibling tests (e.g.
    # test_mlx_det); the explicit mlx_mode=get_mode("MLX") was redundant and
    # inconsistent with the rest of this file.
    compare_mlx_and_py([A], [sign, logabsdet], [A_val])


@pytest.mark.skipif(
    V(mx.__version__) < V("0.30.1"),
    reason="mx.linalg.eig causes a Fatal Python error (Abort trap) on MLX <0.30.1 "
    "(maybe -- the exact version cutoff is unknown)",
)
def test_mlx_eig():
    """Eigendecomposition of a symmetric matrix agrees with the Python backend."""
    rng = np.random.default_rng(15)

    # A symmetric positive semi-definite input keeps the spectrum real.
    base = rng.normal(size=(3, 3))
    A_val = (base @ base.T).astype(config.floatX)

    A = pt.matrix(name="A")
    outs = pt.linalg.eig(A)

    compare_mlx_and_py([A], outs, [A_val])


@pytest.mark.parametrize("UPLO", ["L", "U"])
def test_mlx_eigh(UPLO):
    """``eigh`` matches the Python backend for both triangle selections."""
    rng = np.random.default_rng(15)

    # A symmetric positive semi-definite input keeps the spectrum real.
    base = rng.normal(size=(3, 3))
    A_val = (base @ base.T).astype(config.floatX)

    A = pt.matrix(name="A")
    outs = pt.linalg.eigh(A, UPLO=UPLO)

    compare_mlx_and_py([A], outs, [A_val])


@pytest.mark.parametrize("compute_uv", [True, False])
def test_mlx_svd(compute_uv):
rng = np.random.default_rng(15)
Expand Down
51 changes: 51 additions & 0 deletions tests/link/mlx/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,54 @@ def test_mlx_LU():
mlx_mode=mlx_mode,
assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
)


@pytest.mark.parametrize("lower", [True, False])
def test_mlx_eigvalsh(lower):
    """``eigvalsh`` of a symmetric matrix matches the Python backend."""
    rng = np.random.default_rng(15)

    # Symmetrize so the input is valid for a symmetric eigensolver.
    M = rng.normal(size=(3, 3))
    A_val = (M @ M.T).astype(config.floatX)

    A = pt.matrix(name="A")
    B = pt.matrix(name="B")
    out = pt.linalg.eigvalsh(A, B, lower=lower)

    # NOTE(review): None is passed as the runtime value of the declared matrix
    # input B, while the MLX dispatcher raises NotImplementedError for any
    # b that is not None. Confirm this actually exercises the b-is-None path
    # (e.g. that compare_mlx_and_py / the Op tolerate a None input value)
    # rather than erroring before the comparison.
    compare_mlx_and_py([A, B], [out], [A_val, None])


def test_mlx_lu_factor():
    """LU factorization outputs agree with the Python backend."""
    rng = np.random.default_rng(15)

    A = pt.matrix(name="A")
    out = pt.linalg.lu_factor(A)
    A_val = rng.normal(size=(5, 5)).astype(config.floatX)

    compare_mlx_and_py([A], out, [A_val])


def test_mlx_pivot_to_permutations():
    """Pivot-to-permutation conversion agrees with the Python backend."""
    from pytensor.tensor.slinalg import pivot_to_permutation

    rng = np.random.default_rng(15)

    A = pt.matrix(name="A")
    A_val = rng.normal(size=(5, 5)).astype(config.floatX)

    # Feed the pivot output of lu_factor through pivot_to_permutation.
    factors = pt.linalg.lu_factor(A)
    out = pivot_to_permutation(factors[1])

    compare_mlx_and_py([A], [out], [A_val])


@pytest.mark.parametrize("mode", ["economic", "r"])
def test_mlx_qr(mode):
    """QR factorization matches the Python backend in both supported modes."""
    rng = np.random.default_rng(15)

    A = pt.matrix(name="A")
    out = pt.linalg.qr(A, mode=mode)
    # Tall (5x3) input exercises the R-padding path in mode="r".
    A_val = rng.normal(size=(5, 3)).astype(config.floatX)

    compare_mlx_and_py([A], out, [A_val])
Loading