Skip to content

Commit 9bf3f28

Browse files
fix linear bias decomposition invokation (#18168)
Fix parakeet/voxtral realtime export using PyTorch release. In both `export_parakeet_tdt.py` and `export_voxtral_rt.py`, replaced the manual decomposition dictionary with `torch.export.default_decompositions()` and added the custom `_linear_bias_decomposition` for `torch.ops.aten.linear.default`.
1 parent 5c1b080 commit 9bf3f28

2 files changed

Lines changed: 6 additions & 8 deletions

File tree

examples/models/parakeet/export_parakeet_tdt.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -509,13 +509,11 @@ def _create_metal_partitioners(programs):
509509

510510
# Run decompositions for non-preprocessor programs
511511
updated_programs = {}
512+
decomp_table = torch.export.default_decompositions()
513+
decomp_table[torch.ops.aten.linear.default] = _linear_bias_decomposition
512514
for key, ep in programs.items():
513-
# print(f"Running decompositions for {key}")
514-
# print(ep.graph_module)
515515
if key != "preprocessor":
516-
updated_programs[key] = ep.run_decompositions(
517-
{torch.ops.aten.linear.default: _linear_bias_decomposition}
518-
)
516+
updated_programs[key] = ep.run_decompositions(decomp_table)
519517
else:
520518
updated_programs[key] = ep
521519

examples/models/voxtral_realtime/export_voxtral_rt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,10 +394,10 @@ def lower_to_executorch(programs, metadata, backend="xnnpack"):
394394

395395
# Run decompositions for Metal backend
396396
updated_programs = {}
397+
decomp_table = torch.export.default_decompositions()
398+
decomp_table[torch.ops.aten.linear.default] = _linear_bias_decomposition
397399
for key, ep in programs.items():
398-
updated_programs[key] = ep.run_decompositions(
399-
{torch.ops.aten.linear.default: _linear_bias_decomposition}
400-
)
400+
updated_programs[key] = ep.run_decompositions(decomp_table)
401401
programs = updated_programs
402402

403403
partitioner = {}

0 commit comments

Comments
 (0)