Commit 0410c83

Metal backend: decompose linear with bias
1 parent 9011930, commit 0410c83

4 files changed: 144 additions & 2 deletions

File tree

backends/apple/metal/metal_backend.py
backends/apple/metal/passes/__init__.py
backends/apple/metal/passes/decompose_linear_pass.py
backends/apple/metal/tests/test_modules.py

backends/apple/metal/metal_backend.py
6 additions & 2 deletions

@@ -44,8 +44,12 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
 
     @classmethod
     def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
-        """Return Metal-specific passes (currently none)"""
-        return []
+        """Return Metal-specific passes"""
+        from executorch.backends.apple.metal.passes.decompose_linear_pass import (
+            DecomposeLinearPass,
+        )
+
+        return [DecomposeLinearPass()]
 
     @classmethod
     def get_aoti_compile_options(
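
For reference, the rewrite preserves aten.linear's numerics. A minimal eager-mode sketch of the same op sequence (illustration only, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(127, 7)   # 2D input (M, K)
w = torch.randn(101, 7)   # nn.Linear stores weight as (N, K)
b = torch.randn(101)      # bias (N,)

# aten.linear computes x @ w.T + b. The pass emits the same math as
# unsqueeze -> t -> matmul -> add -> squeeze, so the 2D-with-bias case
# never reaches the addmm fast path.
decomposed = (torch.matmul(x.unsqueeze(0), w.t()) + b).squeeze(0)
torch.testing.assert_close(decomposed, F.linear(x, w, b))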
backends/apple/metal/passes/__init__.py
11 additions & 0 deletions

@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.apple.metal.passes.decompose_linear_pass import (  # noqa: F401
+    DecomposeLinearPass,
+)
+
+__all__ = ["DecomposeLinearPass"]
backends/apple/metal/passes/decompose_linear_pass.py
111 additions & 0 deletions

@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class DecomposeLinearPass(ExportPass):
+    """
+    Decompose aten.linear into matmul + add to avoid addmm.
+
+    For 2D inputs, we unsqueeze to 3D before decomposition to force the matmul
+    code path instead of addmm. The C++ implementation of aten.linear directly
+    calls addmm for 2D inputs with bias, which would require implementing
+    aoti_torch_mps_addmm_out. By unsqueezing to 3D, we force the matmul path,
+    then squeeze back to 2D.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        modified = False
+        graph = graph_module.graph
+
+        for node in graph.nodes:
+            # Check if this is a linear operation
+            is_linear = False
+
+            if node.op == "call_function":
+                # Match both edge dialect and core aten linear operators
+                if node.target == exir_ops.edge.aten.linear.default:
+                    is_linear = True
+                elif node.target == torch.ops.aten.linear.default:
+                    is_linear = True
+
+            if is_linear:
+                # Get input, weight, and bias arguments
+                input_node = node.args[0]
+                weight_node = node.args[1]
+                bias_node = node.args[2] if len(node.args) > 2 else None
+
+                with graph.inserting_before(node):
+                    # Determine which ops to use based on the input operator
+                    target_str = str(node.target)
+
+                    if "executorch_exir_dialects_edge" in target_str:
+                        # Use edge dialect operators
+                        t_op = exir_ops.edge.aten.t.default
+                        matmul_op = exir_ops.edge.aten.matmul.default
+                        add_op = exir_ops.edge.aten.add.Tensor
+                        unsqueeze_op = exir_ops.edge.aten.unsqueeze.default
+                        squeeze_op = exir_ops.edge.aten.squeeze.dims
+                    else:
+                        # Use core aten operators
+                        t_op = torch.ops.aten.t.default
+                        matmul_op = torch.ops.aten.matmul.default
+                        add_op = torch.ops.aten.add.Tensor
+                        unsqueeze_op = torch.ops.aten.unsqueeze.default
+                        squeeze_op = torch.ops.aten.squeeze.dims
+
+                    # Check if input is 2D
+                    needs_unsqueeze = False
+                    if hasattr(input_node, "meta") and "val" in input_node.meta:
+                        if len(input_node.meta["val"].shape) == 2:
+                            needs_unsqueeze = True
+
+                    # Unsqueeze 2D input to 3D: (M, K) -> (1, M, K)
+                    current_input = input_node
+                    if needs_unsqueeze:
+                        current_input = graph.call_function(
+                            unsqueeze_op,
+                            args=(input_node, 0),
+                        )
+
+                    # Decompose linear: matmul(input, weight.T) + bias
+                    weight_t = graph.call_function(
+                        t_op,
+                        args=(weight_node,),
+                    )
+
+                    matmul_result = graph.call_function(
+                        matmul_op,
+                        args=(current_input, weight_t),
+                    )
+
+                    if bias_node is not None:
+                        result = graph.call_function(
+                            add_op,
+                            args=(matmul_result, bias_node),
+                        )
+                    else:
+                        result = matmul_result
+
+                    # Squeeze 3D output back to 2D: (1, M, N) -> (M, N)
+                    if needs_unsqueeze:
+                        result = graph.call_function(
+                            squeeze_op,
+                            args=(result, [0]),
+                        )
+
+                # Replace all uses of the linear node with the decomposed result
+                node.replace_all_uses_with(result)
+                graph.erase_node(node)
+                modified = True
+
+        if modified:
+            graph_module.recompile()
+
+        return PassResult(graph_module, modified)
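
To inspect the pass's effect outside the backend lowering flow, one could run it directly over an exported graph. A minimal sketch (assumptions: torch.export keeps aten.linear un-decomposed in its default IR, and the pass is invoked standalone rather than via MetalBackend.get_custom_passes):

import torch
from torch.export import export

from executorch.backends.apple.metal.passes import DecomposeLinearPass


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(7, 101, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# export() yields a GraphModule containing a torch.ops.aten.linear.default
# call_function node (assuming no decompositions have been applied yet).
gm = export(TinyLinear(), (torch.randn(127, 7),)).module()

# The pass rewrites that node into t -> unsqueeze -> matmul -> add -> squeeze.
result = DecomposeLinearPass()(gm)
print("modified:", result.modified)
result.graph_module.graph.print_tabular()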

backends/apple/metal/tests/test_modules.py
16 additions & 0 deletions

@@ -208,6 +208,22 @@ def forward(self, x: torch.Tensor):
 }
 
 
+# -------------------------------------------------------------------------
+class LinearWithBias(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(7, 101, bias=True)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+
+MODULE_REGISTRY["linear_bias"] = {
+    "model_class": LinearWithBias,
+    "input_shapes": [(127, 7)],
+    "description": "Simple linear layer model with bias",
+}
+
 # -------------------------------------------------------------------------
 class LinearNoBiasInt4(nn.Module):
     def __init__(self):
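
Note that the (127, 7) input is 2D, so this module exercises precisely the unsqueeze-to-3D branch of DecomposeLinearPass, and the 7 -> 101 dimensions keep the matmul non-square. A quick eager sanity check (a sketch, not part of the test suite):

import torch

# Assumes LinearWithBias from test_modules.py is in scope;
# nn.Linear(7, 101) maps (127, 7) -> (127, 101).
m = LinearWithBias()
assert m(torch.randn(127, 7)).shape == (127, 101)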
