diff --git a/backends/arm/_passes/ANALYSIS_expensive_transposes.md b/backends/arm/_passes/ANALYSIS_expensive_transposes.md new file mode 100644 index 00000000000..54c1450320e --- /dev/null +++ b/backends/arm/_passes/ANALYSIS_expensive_transposes.md @@ -0,0 +1,320 @@ +# Analysis: Expensive Transposes in Control Ceres Model + +## Executive Summary + +The Control Ceres model has expensive transpose operations that Vela implements as long sequences of `NPU_OP_POOL` (1x1 AvgPool) operations on Ethos-U55. These transposes are **NOT** fused by the current passes because they surround `Reshape` (view_copy) operations. + +## Key Finding: Reshape Requires Transposes for NCHW ↔ NHWC Conversion + +The most expensive transposes in the model are generated by patterns involving `Reshape` operations: + +``` +Pattern: Transpose → Reshape → Transpose + [1,2,14,72] → [1,28,1,72] → [1,1,72,28] + (NHWC→NCHW) (reshape) (NCHW→NHWC) +``` + +The transposes are **required** because: +1. The TOSA/Vela backend requires NHWC layout for Conv2D operations +2. The PyTorch model uses NCHW layout internally +3. `view_copy` (Reshape) changes tensor dimensions and requires consistent layout + +## Expensive Transpose Inventory + +### Highest Priority (Around Reshape Operations) + +| Transpose ID | Tensor Size | % Time | Cycles | Pattern | +|--------------|-------------|--------|--------|---------| +| `tosa_transpose_default_8` | 54,048 bytes | 2.51% | 252 | T→Reshape→**T** | +| `tosa_transpose_default_7` | 54,048 bytes | 2.51% | 252 | **T**→Reshape→T | +| `tosa_transpose_default_9` | 53,984 bytes | 2.51% | 756 | Rescale→**T**→Reshape | +| `tosa_transpose_default_10` | 53,984 bytes | 2.51% | 252 | Reshape→**T** | +| `tosa_transpose_default_11` | 51,680 bytes | 2.40% | 126 | Reshape→**T** | +| `tosa_transpose_default_12` | 51,616 bytes | 2.40% | 378 | **T**→Reshape | + +**Total time for these 6 transposes: ~15% of model execution** + +### Medium Priority (Other Transposes) + +| Transpose ID | Tensor Size | % Time | Cycles | +|--------------|-------------|--------|--------| +| `tosa_transpose_default_4` | 49,664 bytes | 2.31% | 189 | +| `tosa_transpose_default_3` | 47,008 bytes | 2.18% | 95 | +| `tosa_transpose_default_2` | 44,992 bytes | 2.09% | 95 | +| `tosa_transpose_default_1` | 40,960 bytes | 1.90% | 95 | +| `tosa_transpose_default_6` | 33,312 bytes | 1.55% | 178 | +| `tosa_transpose_default_5` | 32,320 bytes | 1.50% | 224 | +| `tosa_transpose_default` | 32,416 bytes | 1.50% | 133 | + +### Lower Priority (Final Output Transposes) + +| Transpose ID | Tensor Size | % Time | Cycles | Location | +|--------------|-------------|--------|--------|----------| +| `tosa_transpose_default_13` | 28,032 bytes | 1.30% | 62 | Near end | +| `tosa_transpose_default_18` | 24,160 bytes | 1.12% | 252 | Final | +| `tosa_transpose_default_17` | 21,648 bytes | 1.00% | 126 | Final | +| `tosa_transpose_default_16` | 17,616 bytes | 0.82% | 126 | Final | +| `tosa_transpose_default_15` | 9,552 bytes | 0.44% | 126 | Final | +| `tosa_transpose_default_14` | 8,864 bytes | 0.41% | 66 | Final | + +## Graph Pattern Analysis + +### Pattern 1: Around Reshape Operations (indices ~797-802) +``` +Transpose (permute_copy_6) + ↓ +Clamp (aten_clamp_default_2) + ↓ +Rescale (tosa_rescale_default_24) + ↓ +Transpose (tosa_transpose_default_7) ← EXPENSIVE + ↓ +Reshape (aten_view_copy_default_2) + ↓ +Transpose (tosa_transpose_default_8) ← EXPENSIVE +``` + +### Pattern 2: Rescale → Transpose → Reshape (indices ~838-841) +``` +Rescale (tosa_rescale_default_37) + ↓ +Transpose (tosa_transpose_default_9) ← EXPENSIVE + ↓ +Reshape (aten_view_copy_default_5) + ↓ +Transpose (tosa_transpose_default_10) ← EXPENSIVE +``` + +### Pattern 3: Clamp → Rescale → Reshape → Transpose (indices ~856-859) +``` +Clamp (aten_clamp_default_3) + ↓ +Rescale (tosa_rescale_default_45) + ↓ +Reshape (aten_view_copy_default_8) + ↓ +Transpose (tosa_transpose_default_11) ← EXPENSIVE +``` + +## Investigation Questions + +### 1. Are Order Annotations Correct? +- **Question**: Are the TOSA transposes using the correct permutation orders? +- **Investigation**: Check `ToTosaMemoryFormatPass` to verify permutation logic + +### 2. Where Are Transposes Inserted? +- **Source**: `ToTosaMemoryFormatPass` in `executorch/backends/arm/_passes/to_tosa_memory_format_pass.py` +- **Purpose**: Convert NCHW (PyTorch default) to NHWC (TOSA/Ethos requirement) + +### 3. Can We Eliminate Earlier? +- **Option A**: Modify model to use NHWC throughout (training change) +- **Option B**: Fuse Transpose through Reshape mathematically +- **Option C**: Handle reshape in NHWC space directly + +## Root Cause Analysis: Why FuseTransposeReshapeTransposePass Doesn't Help + +### Finding: The Reshapes Are Not Simple Dimension Combinations + +The `FuseTransposeReshapeTransposePass` can only fuse patterns where the reshape is a simple dimension combination/split (e.g., `[1, 2, 14, 72]` → `[1, 28, 72]` which combines dims 1,2 into dim 1). + +However, the expensive transposes in Control Ceres have **complex reshapes that reorder dimensions**: + +``` +tosa_transpose_default_7: OFM: [1, 2, 14, 72] +reshape (view_copy_2): [1, 2, 14, 72] → [1, 1, 72, 28] ← This is NOT a simple combine/split! +tosa_transpose_default_8: IFM: [1, 1, 72, 28] +``` + +The reshape `[1, 2, 14, 72]` → `[1, 1, 72, 28]` involves: +- Combining `2 * 14 = 28` +- But also **moving** the 72 channel dimension to a different position + +This is equivalent to: +1. Flatten: `[1, 2, 14, 72]` → `[1, 2016]` (total elements: 2016) +2. Reshape: `[1, 2016]` → `[1, 1, 72, 28]` + +The `_get_shape_indices` function returns `None` for such reshapes, so the fusion is skipped. + +### Implication + +**The transposes cannot be fused with the current approach** because the reshape involves both dimension combining AND reordering. These transposes are mathematically necessary for the reshape to work correctly. + +### Possible Alternative Strategies + +1. **Modify the model architecture** to avoid such reshape patterns +2. **Use NHWC-native operations** in the model to eliminate the need for transposes +3. **Investigate Vela optimizations** to make transposes more efficient +4. **Create a more sophisticated fusion pass** that can handle arbitrary reshapes (complex mathematical analysis required) + +## Next Steps + +1. [x] Read `ToTosaMemoryFormatPass` to understand transpose insertion logic +2. [ ] Identify where reshapes are created in the model +3. [ ] Investigate if `Transpose → Reshape → Transpose` can be mathematically fused +4. [ ] Check if the model can use NHWC-compatible reshape shapes +5. [ ] Consider creating `FuseTransposeThroughReshapePass` if mathematically feasible + +--- + +## 🚀 New Strategy: Compile-Time Transpose Folding for Static Tensors + +### Critical Finding: FuseConstantArgsPass SKIPS Transposes + +The `FuseConstantArgsPass` (lines 142-148 in `fuse_constant_ops_pass.py`) **explicitly SKIPS** `TRANSPOSE.default` operations: + +```python +if node.target in [ + exir_ops.backend.tosa.MATMUL.default, + exir_ops.backend.tosa.RESCALE.default, + exir_ops.backend.tosa.RESIZE.default, + exir_ops.backend.tosa.TABLE.default, + exir_ops.backend.tosa.TRANSPOSE.default, # <-- SKIPPED! +]: + continue +``` + +This means that even when a transpose operates on a static tensor (weight/constant), it is NOT folded at compile time. + +### Why Transposes Are Currently Not Folded + +The comment history doesn't explain why transposes were excluded. Possible reasons: +1. **Concern about tensor size increase** - But transposes preserve tensor size +2. **Special handling needed for shape metadata** - Transposing changes `tosa_dim_order` +3. **Simply not implemented yet** + +### Proposed Solution: FoldConstantTransposePass + +Create a new pass that specifically folds transposes on constant/static tensors at compile time. + +#### How It Would Work + +1. **Identify transpose nodes** on static tensors (parameters, buffers, lifted tensor constants) +2. **Actually permute the tensor data** at compile time using `tensor.permute(order).contiguous()` +3. **Create a new constant placeholder** with the permuted data +4. **Remove the transpose node** and rewire users to the new constant + +#### Pattern Before: +``` +static_weight (placeholder) -> TRANSPOSE [0,2,3,1] -> Conv2D +``` + +#### Pattern After: +``` +static_weight_nhwc (placeholder, data already permuted) -> Conv2D +``` + +### Example Implementation (Conceptual) + +```python +class FoldConstantTransposePass(ArmPass): + """Folds transposes on static tensors at compile time.""" + + def __init__(self, exported_program: ExportedProgram, *args, **kwargs): + super().__init__(*args, **kwargs) + self.exported_program = exported_program + + def call(self, graph_module): + modified = False + for node in list(graph_module.graph.nodes): + if node.target != exir_ops.backend.tosa.TRANSPOSE.default: + continue + + input_node = node.all_input_nodes[0] + if not is_param_node(self.exported_program, input_node): + continue # Not a static tensor + + # Get the static tensor data + tensor = get_param_tensor(self.exported_program, input_node) + perm = node.args[1] # Permutation order + + # Actually permute the data at compile time + permuted_tensor = tensor.permute(perm).contiguous() + + # Create new constant placeholder with permuted data + with graph_module.graph.inserting_before(input_node): + const_node = create_constant_placeholder( + self.exported_program, + graph=graph_module.graph, + kind=get_constant_placeholder_kind(self.exported_program, input_node), + name=f"{input_node.name}_permuted", + data=permuted_tensor, + persistent_buffer=is_persistent_buffer(self.exported_program, input_node), + ) + + # Update metadata + const_node.meta["tosa_dim_order"] = node.meta.get("tosa_dim_order", tuple(range(len(perm)))) + + # Rewire users and remove transpose + node.replace_all_uses_with(const_node) + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, modified) +``` + +### Impact Analysis + +#### Transposes That Could Be Folded: +- Weight transposes for Conv2D (NCHW→NHWC) +- Constant tensor transposes for MatMul +- Lifted tensor constants used in reshape patterns + +#### Transposes That CANNOT Be Folded: +- Transposes on dynamic activations (runtime data) +- Transposes on model inputs + +### Precedent: RewriteConvPass._reshape_weights() + +The ARM backend already has precedent for compile-time weight reordering in `RewriteConvPass._reshape_weights()` (lines 115-161): + +```python +def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None: + weight_tensor = get_param_tensor(self.exported_program, weight_node) + + reshaped_weight_tensor = ( + weight_tensor.permute(HWCM_ORDER) + .reshape(...) + .permute(NHWC_INVERSE_ORDER) + ) + + # Update state_dict with permuted tensor + self.exported_program.state_dict[param_name] = reshaped_weight_tensor +``` + +### Questions to Investigate + +1. **Why was TRANSPOSE.default excluded from FuseConstantArgsPass?** + - Search for git history or comments + +2. **Does folding transposes break any downstream passes?** + - `ToTosaMemoryFormatPass` annotations + - TOSA serialization + +3. **Are there edge cases?** + - Transpose Conv2D weights (special handling at line 430) + - Multi-user constant nodes + +### Next Steps for Implementation + +1. [ ] Search for why transposes were excluded from FuseConstantArgsPass +2. [ ] Create FoldConstantTransposePass +3. [ ] Add to pass pipeline AFTER FuseConstantArgsPass +4. [ ] Test on Control Ceres model to measure impact +5. [ ] Measure Vela cycle reduction + +## Related Files + +- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/to_tosa_memory_format_pass.py` +- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/arm_pass_manager.py` +- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/fuse_transpose_sandwich_pass.py` +- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/propagate_transposes_through_rescale_pass.py` +- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/fuse_constant_ops_pass.py` (FuseConstantArgsPass) +- `/data/users/eliamesefe/fbsource/fbcode/executorch/backends/arm/_passes/rewrite_conv_pass.py` (_reshape_weights precedent) + +--- + +**Date**: 2026-03-06 +**Author**: Eli Amesefe +**Related Work**: Transpose fusion passes for Ethos-U55 optimization diff --git a/backends/arm/_passes/TODO_relaxed_concat_pattern.md b/backends/arm/_passes/TODO_relaxed_concat_pattern.md new file mode 100644 index 00000000000..a78d8c9e106 --- /dev/null +++ b/backends/arm/_passes/TODO_relaxed_concat_pattern.md @@ -0,0 +1,81 @@ +# TODO: Investigate relaxed transpose-concat pattern matching (multiple users) + +## Background + +`PropagateTransposesThroughConcatPass` was added (commit `e7b5bd6d`) to target the pattern: +``` +[T(perm), T(perm), ...] → Concat(dim=d) → T(inv_perm) +``` + +Current pass requirements: +1. All Concat inputs must be transposes with the **same** permutation +2. All input transposes must have **only one user** (the Concat) +3. Output transpose must have the **inverse** permutation + +## Problem + +In the Control Ceres model, the actual patterns have: +- Input transposes with **multiple users** (not just Concat) +- Input transposes with **different permutations** + +This prevents the pass from intercepting patterns in Control Ceres. + +Example from Vela command stream: +``` +766 Transpose tosa_transpose_default_5 +767 Transpose tosa_transpose_default +768 Concat aten_cat_default +769 Transpose tosa_transpose_default_6 +``` + +## Investigation Scope + +1. **Analyze the actual graph patterns** in Control Ceres to understand: + - What users do the input transposes have besides Concat? + - What are the permutations of the input transposes? + +2. **Evaluate relaxation options**: + - **Option A**: Allow input transposes with multiple users + - Requires duplicating the transpose for other users + - May increase graph size but reduce total transpose ops + - **Option B**: Handle mixed permutations by only propagating matching subsets + - Only some inputs participate in the optimization + - **Option C**: Target a different pattern entirely + - E.g., single-input concat optimization + +3. **Assess impact**: + - Correctness: Ensure no semantic changes + - Performance: Measure actual cycle reduction on Ethos-U55 + +## Related Work + +- `FuseTransposeSandwichPass`: Targets T1 → op → T2 patterns +- `PropagateTransposesThroughRescalePass`: Targets T1 → Rescale → T2 patterns +- `FuseConsecutiveTransposesPass`: Targets T1 → T2 sequences + +## Files + +- `/fbcode/executorch/backends/arm/_passes/propagate_transposes_through_concat_pass.py` +- `/fbcode/executorch/backends/arm/_passes/arm_pass_manager.py` + +## Test Command + +```bash +buck2 test @fbcode//mode/dev fbcode//frl/ctrl/torchstream/torchstream/pt2/tests:test_pt2_emg_lowering -- 'test_combined_control_ceres_u55' --print-passing-details +``` + +## BEFORE Metrics (baseline) + +- NPU cycles: 1,356,652 cycles/batch +- Total cycles: 1,369,354 cycles/batch +- Total SRAM used: 2,103.83 KiB +- NPU operators: 367 (100.0%) +- Total Transpose ops: 536 (tosa_transpose: 273, aten_permute_copy: 278) + +## Author + +eliamesefe@meta.com + +## Date Created + +2026-03-05 diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 36e3fe004d9..1b754e596c3 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -102,10 +102,25 @@ QuantizeClampArgumentsPass, ) from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa +from .fuse_consecutive_transposes_pass import FuseConsecutiveTransposesPass # noqa +from .fuse_transpose_reshape_transpose_pass import ( # noqa + FuseTransposeReshapeTransposePass, +) +from .fuse_transpose_reshape_linear_pass import ( # noqa + FuseTransposeReshapeLinearPass, +) +from .fuse_transpose_sandwich_pass import FuseTransposeSandwichPass # noqa +from .propagate_transposes_through_rescale_pass import ( # noqa + PropagateTransposesThroughRescalePass, +) +from .propagate_transposes_through_concat_pass import ( # noqa + PropagateTransposesThroughConcatPass, +) from .fuse_constant_ops_pass import ( # noqa ComputeConstantOpsAOTPass, FuseConstantArgsPass, ) +from .fold_constant_transpose_pass import FoldConstantTransposePass # noqa from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b3f9fd2ef8a..030921e9a19 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -97,6 +97,13 @@ DecorateFp32toInt32CastingPass, FoldAndAnnotateQParamsPass, FuseBatchNorm2dPass, + FuseConsecutiveTransposesPass, + FuseTransposeReshapeTransposePass, + FuseTransposeReshapeLinearPass, + FuseTransposeSandwichPass, + PropagateTransposesThroughConcatPass, + PropagateTransposesThroughRescalePass, + FoldConstantTransposePass, FuseConstantArgsPass, FuseDuplicateUsersPass, FuseEqualPlaceholdersPass, @@ -386,9 +393,16 @@ def _tosa_pipeline( # Postprocessing/cleanup passes self.add_passes( [ - CastInt64BuffersToInt32Pass(exported_program), + CastInt64BuffersToInt32Pass(exported_program), FuseEqualPlaceholdersPass(exported_program), - ToTosaMemoryFormatPass(exported_program), + ToTosaMemoryFormatPass(exported_program), + FoldConstantTransposePass(exported_program), + PropagateTransposesThroughConcatPass(), + PropagateTransposesThroughRescalePass(), + FuseConsecutiveTransposesPass(), + FuseTransposeReshapeLinearPass(), + FuseTransposeReshapeTransposePass(), + FuseTransposeSandwichPass(), RemoveNoopPass(), InsertRescalePass(), ] diff --git a/backends/arm/_passes/fold_constant_transpose_pass.py b/backends/arm/_passes/fold_constant_transpose_pass.py new file mode 100644 index 00000000000..2adf28ead81 --- /dev/null +++ b/backends/arm/_passes/fold_constant_transpose_pass.py @@ -0,0 +1,234 @@ +# Copyright 2025-2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Fold TOSA TRANSPOSE operations on constant tensors at compile time. + +This pass identifies TRANSPOSE operations where the input is a static tensor +(parameter, buffer, or lifted tensor constant) and folds the transpose at +compile time by: +1. Actually permuting the tensor data +2. Creating a new constant placeholder with the permuted data +3. Removing the transpose node and rewiring users + +This eliminates runtime transpose operations on static tensors like weights, +which is especially important for Ethos-U55 where Vela implements transposes +as expensive NPU_OP_POOL (1x1 AvgPool) sequences. +""" + +import logging +from typing import Sequence, Set, Type + +import torch +import torch.fx +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + get_constant_placeholder_kind, + get_param_tensor, + is_param_node, + is_persistent_buffer, +) +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + delete_constant_placeholder, +) +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +logger = logging.getLogger(__name__) + + +def _get_transpose_perm(node: torch.fx.Node) -> list[int] | None: + """Extract the permutation order from a TRANSPOSE node. + + Args: + node: A node with target exir_ops.backend.tosa.TRANSPOSE.default + + Returns: + The permutation order as a list of ints, or None if not extractable. + """ + if node.target != exir_ops.backend.tosa.TRANSPOSE.default: + return None + + if len(node.args) < 2: + return None + + perm = node.args[1] + if isinstance(perm, (list, tuple)): + return list(perm) + + return None + + +class FoldConstantTransposePass(ArmPass): + """Folds TOSA TRANSPOSE operations on constant tensors at compile time. + + This pass transforms patterns like: + static_weight (placeholder) -> TRANSPOSE [0,2,3,1] -> Conv2D + + Into: + static_weight_transposed (placeholder with pre-permuted data) -> Conv2D + + This eliminates runtime transposes on static tensors, which is especially + beneficial for Ethos-U55 where Vela implements transposes as expensive + NPU_OP_POOL (1x1 AvgPool) sequences. + + Example: + Before: weight (NCHW) -> TRANSPOSE -> Conv2D (expects NHWC) + After: weight_nhwc (already NHWC) -> Conv2D + + Note: + This pass only folds transposes on constant/static tensors. Transposes + on dynamic activations (runtime data) cannot be folded and must be + executed at runtime. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.exported_program = exported_program + + def _fold_transpose( + self, + graph_module: torch.fx.GraphModule, + transpose_node: torch.fx.Node, + input_node: torch.fx.Node, + perm: Sequence[int], + ) -> bool: + """Fold a single transpose operation on a constant tensor. + + Args: + graph_module: The graph module being transformed. + transpose_node: The TRANSPOSE node to fold. + input_node: The constant input node (parameter/buffer/lifted constant). + perm: The permutation order. + + Returns: + True if the transpose was successfully folded, False otherwise. + """ + # Get the original tensor data + tensor = get_param_tensor(self.exported_program, input_node) + if tensor is None: + logger.debug( + f"FoldConstantTransposePass: Could not get tensor for {input_node.name}" + ) + return False + + # Validate permutation + if len(perm) != tensor.dim(): + logger.warning( + f"FoldConstantTransposePass: Permutation length {len(perm)} does not " + f"match tensor rank {tensor.dim()} for {transpose_node.name}" + ) + return False + + # Actually permute the data at compile time + try: + permuted_tensor = tensor.permute(perm).contiguous() + except Exception as e: + logger.warning( + f"FoldConstantTransposePass: Failed to permute tensor for " + f"{transpose_node.name}: {e}" + ) + return False + + # Determine the kind and persistence of the original constant + try: + input_kind = get_constant_placeholder_kind(self.exported_program, input_node) + except RuntimeError: + logger.debug( + f"FoldConstantTransposePass: {input_node.name} is not a constant placeholder" + ) + return False + + persistent = is_persistent_buffer(self.exported_program, input_node) + + # Create new constant placeholder with permuted data + with graph_module.graph.inserting_before(input_node): + const_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=input_kind, + name=f"{input_node.name}_transposed", + data=permuted_tensor, + persistent_buffer=persistent if persistent is not None else True, + ) + + # Copy relevant metadata from the transpose output + if "tosa_dim_order" in transpose_node.meta: + const_node.meta["tosa_dim_order"] = transpose_node.meta["tosa_dim_order"] + if "tosa_spatial_rank" in transpose_node.meta: + const_node.meta["tosa_spatial_rank"] = transpose_node.meta["tosa_spatial_rank"] + + # Replace all uses of the transpose node with the new constant + transpose_node.replace_all_uses_with(const_node) + + logger.debug( + f"FoldConstantTransposePass: Folded transpose {transpose_node.name} " + f"on constant {input_node.name} with perm={perm}" + ) + + return True + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + nodes_to_remove: Set[torch.fx.Node] = set() + input_nodes_to_check: Set[torch.fx.Node] = set() + + # Process TOSA transpose nodes + for node in list(graph_module.graph.nodes): + if node.target != exir_ops.backend.tosa.TRANSPOSE.default: + continue + + # Get permutation + perm = _get_transpose_perm(node) + if perm is None: + continue + + # Get input node + input_nodes = node.all_input_nodes + if len(input_nodes) == 0: + continue + input_node = input_nodes[0] + + # Check if input is a constant tensor + if not is_param_node(self.exported_program, input_node): + continue + + # Skip if input node has multiple users (other than this transpose) + # to avoid duplicating the constant + if len(input_node.users) > 1: + logger.debug( + f"FoldConstantTransposePass: Skipping {node.name} because input " + f"{input_node.name} has multiple users" + ) + continue + + # Fold the transpose + if self._fold_transpose(graph_module, node, input_node, perm): + modified = True + nodes_to_remove.add(node) + input_nodes_to_check.add(input_node) + + # Clean up removed transpose nodes + if modified: + graph_module.graph.eliminate_dead_code() + + # Try to clean up orphaned input nodes + for input_node in input_nodes_to_check: + if len(input_node.users) == 0: + try: + delete_constant_placeholder(self.exported_program, input_node) + except Exception as e: + logger.debug( + f"FoldConstantTransposePass: Could not delete orphaned " + f"placeholder {input_node.name}: {e}" + ) + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/fuse_consecutive_transposes_pass.py b/backends/arm/_passes/fuse_consecutive_transposes_pass.py new file mode 100644 index 00000000000..13ce096099e --- /dev/null +++ b/backends/arm/_passes/fuse_consecutive_transposes_pass.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Pass to fuse consecutive transpose/permute operations. + +This pass identifies chains of transpose/permute operations and either: +1. Removes both if they cancel out (result is identity permutation) +2. Fuses them into a single permute with combined dimensions + +This optimization reduces runtime overhead by eliminating redundant memory +movement operations. +""" + +import logging +from typing import List, Sequence, Set, Tuple, Type + +import executorch.backends.arm.tosa.dialect # noqa: F401 - loads TOSA dialect +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch._ops import OpOverload + +logger = logging.getLogger(__name__) + + +_PERMUTE_TARGETS: Tuple[OpOverload, ...] = ( + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.backend.tosa.TRANSPOSE.default, +) + + +def _compose_permutations(perm1: Sequence[int], perm2: Sequence[int]) -> List[int]: + """Compose two permutations: result[i] = perm1[perm2[i]]. + + Given two consecutive permutations, computes the equivalent single permutation. + + Args: + perm1: First permutation (applied first to data) + perm2: Second permutation (applied second to data) + + Returns: + Combined permutation that has the same effect as applying perm1 then perm2 + """ + return [perm1[i] for i in perm2] + + +def _is_identity_permutation(perm: Sequence[int]) -> bool: + """Check if a permutation is the identity (no-op). + + Args: + perm: Permutation to check + + Returns: + True if perm[i] == i for all i (identity permutation) + """ + return list(perm) == list(range(len(perm))) + + +class FuseConsecutiveTransposesPass(ArmPass): + """Fuse consecutive transpose/permute operations. + + This pass looks for patterns like: + x -> permute(dims1) -> permute(dims2) -> y + + And transforms them to either: + x -> y (if dims1 and dims2 cancel out) + Or: + x -> permute(combined_dims) -> y (single fused permute) + + This is inspired by bolt/nn/espresso/transforms/fuse_ops.py:fuse_transposes + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + + # Iterate until no more fusions can be made + while True: + fused_this_iteration = False + + for node in list(graph_module.graph.nodes): + if node.op != "call_function": + continue + if node.target not in _PERMUTE_TARGETS: + continue + + # Check if input is also a permute + input_node = node.args[0] + if not isinstance(input_node, torch.fx.Node): + continue + if input_node.op != "call_function": + continue + if input_node.target not in _PERMUTE_TARGETS: + continue + + # We have permute -> permute pattern + permute1 = input_node + permute2 = node + + # Get the permutation dimensions + dims1 = permute1.args[1] + dims2 = permute2.args[1] + + if not isinstance(dims1, (list, tuple)): + continue + if not isinstance(dims2, (list, tuple)): + continue + + # Normalize to lists + dims1 = list(dims1) + dims2 = list(dims2) + + if len(dims1) != len(dims2): + # Permutations must have same rank + continue + + # Compose the permutations + combined_dims = _compose_permutations(dims1, dims2) + + if _is_identity_permutation(combined_dims): + # Two permutes cancel out - remove both + logger.debug( + f"Removing canceling permutes: " + f"permute({dims1}) -> permute({dims2})" + ) + permute2.replace_all_uses_with(permute1.args[0]) + else: + # Fuse into single permute + logger.debug( + f"Fusing permutes: " + f"permute({dims1}) -> permute({dims2}) => permute({combined_dims})" + ) + + with graph_module.graph.inserting_before(permute1): + new_permute = graph_module.graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(permute1.args[0], combined_dims), + ) + # Copy metadata from the output permute + new_permute.meta = permute2.meta.copy() + permute2.replace_all_uses_with(new_permute) + + fused_this_iteration = True + modified = True + break # Restart iteration after modification + + if not fused_this_iteration: + break + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/fuse_transpose_reshape_linear_pass.py b/backends/arm/_passes/fuse_transpose_reshape_linear_pass.py new file mode 100644 index 00000000000..7601e39b022 --- /dev/null +++ b/backends/arm/_passes/fuse_transpose_reshape_linear_pass.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2024-2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Fuse transpose -> reshape -> linear patterns. + +A common artifact from reordering dimensions is that transposes are +inserted before FC/Linear layers. Usually this looks something like: + Transpose -> Reshape -> Linear +where the Reshape flattens all dimensions except the batch dimension. + +This pass eliminates the transpose by applying the inverse of the transpose +to the Linear layer's weights instead. + +Inspired by bolt/nn/espresso/transforms/fuse_ops.py:fuse_transpose_reshape_fc +""" + +import logging +from typing import Optional, Set, Type + +import executorch.backends.arm.tosa.dialect # noqa: F401 - loads TOSA dialect +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch import fx + +logger = logging.getLogger(__name__) + + +# Supported permute/transpose targets +_PERMUTE_TARGETS = ( + torch.ops.aten.permute_copy.default, + exir_ops.backend.tosa.TRANSPOSE.default, +) + + +def _get_permute_dims(node: fx.Node) -> Optional[list[int]]: + """Extract the permutation dimensions from a permute node.""" + if node.target not in _PERMUTE_TARGETS: + return None + dims = node.args[1] + if isinstance(dims, (list, tuple)): + return list(dims) + return None + + +def _get_reshape_shape(node: fx.Node) -> Optional[list[int]]: + """Extract the shape from a view/reshape node.""" + if node.target not in ( + torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, + ): + return None + shape = node.args[1] + if isinstance(shape, (list, tuple)): + return list(shape) + return None + + +def _is_linear_node(node: fx.Node) -> bool: + """Check if a node is a linear operation.""" + if node.op != "call_function": + return False + return node.target in ( + torch.ops.aten.linear.default, + torch.ops.aten.mm.default, + torch.ops.aten.addmm.default, + ) + + +def _get_weight_node(node: fx.Node) -> Optional[fx.Node]: + """Get the weight tensor node from a linear operation.""" + if node.target == torch.ops.aten.linear.default: + if len(node.args) >= 2: + return node.args[1] + elif node.target == torch.ops.aten.mm.default: + if len(node.args) >= 2: + return node.args[1] + elif node.target == torch.ops.aten.addmm.default: + if len(node.args) >= 3: + return node.args[2] + return None + + +class FuseTransposeReshapeLinearPass(ArmPass): + """ + Fuses transpose -> reshape -> linear patterns by folding the transpose + into the linear layer's weights. + + This pass identifies patterns where: + 1. A permute is followed by a reshape + 2. Which is followed by a linear/mm operation + 3. The transpose does not modify the batch dimension + 4. The reshape flattens all non-batch dimensions + + Instead of transposing at runtime, the pass applies the inverse transpose + to the linear weights at compile time. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self) -> None: + super().__init__() + self._graph_module: Optional[fx.GraphModule] = None + + def _find_patterns( + self, graph_module: fx.GraphModule + ) -> list[tuple[fx.Node, fx.Node, fx.Node]]: + """Find all transpose -> reshape -> linear patterns.""" + patterns = [] + graph = graph_module.graph + + for node in graph.nodes: + if node.op != "call_function": + continue + + dims = _get_permute_dims(node) + if dims is None: + continue + + transpose = node + + users = list(transpose.users.keys()) + if len(users) != 1: + continue + + reshape = users[0] + if reshape.op != "call_function": + continue + + reshape_shape = _get_reshape_shape(reshape) + if reshape_shape is None: + continue + + reshape_users = list(reshape.users.keys()) + if len(reshape_users) != 1: + continue + + linear = reshape_users[0] + if not _is_linear_node(linear): + continue + + patterns.append((transpose, reshape, linear)) + + return patterns + + def call(self, graph_module: fx.GraphModule) -> PassResult: + self._graph_module = graph_module + modified = False + + patterns = self._find_patterns(graph_module) + + for transpose, reshape, linear in patterns: + dims = _get_permute_dims(transpose) + reshape_shape = _get_reshape_shape(reshape) + + if dims is None or reshape_shape is None: + continue + + if dims[0] != 0: + logger.debug( + "Transpose modifies batch dimension, skipping fusion" + ) + continue + + if len(reshape_shape) != 2 or reshape_shape[0] not in (-1, 1): + logger.debug( + "Reshape does not flatten to 2D with batch dim, skipping" + ) + continue + + input_node = transpose.args[0] + if not isinstance(input_node, fx.Node): + continue + + input_val = input_node.meta.get("val") + if input_val is None: + continue + input_shape = list(input_val.shape) + + transpose_val = transpose.meta.get("val") + if transpose_val is None: + continue + + if input_shape[0] != transpose_val.shape[0]: + logger.debug("Batch dimension changed during transpose, skipping") + continue + + reshape_val = reshape.meta.get("val") + if reshape_val is None: + continue + + if reshape_val.shape[0] != transpose_val.shape[0]: + logger.debug("Batch dimension changed during reshape, skipping") + continue + + weight_node = _get_weight_node(linear) + if weight_node is None or not isinstance(weight_node, fx.Node): + logger.debug("Cannot find weight node for linear, skipping") + continue + + weight_val = weight_node.meta.get("val") + if weight_val is None: + continue + + inv_transpose = [dims.index(i) for i in range(len(dims))] + + inner_shape = transpose_val.shape[1:] + + try: + new_weight_shape = (-1,) + tuple(inner_shape) + weight_data = weight_val.reshape(new_weight_shape) + weight_data = weight_data.permute(inv_transpose) + weight_data = weight_data.reshape(weight_val.shape) + except (RuntimeError, ValueError) as e: + logger.debug(f"Cannot reshape weights: {e}, skipping") + continue + + logger.info( + f"Fusing transpose({dims}) into linear weights, " + f"inverse permutation: {inv_transpose}" + ) + + with graph_module.graph.inserting_before(transpose): + new_reshape = graph_module.graph.call_function( + torch.ops.aten.view_copy.default, + (input_node, reshape_shape), + ) + new_reshape.meta["val"] = reshape_val + + reshape.replace_all_uses_with(new_reshape) + + graph_module.graph.erase_node(reshape) + graph_module.graph.erase_node(transpose) + + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/fuse_transpose_reshape_transpose_pass.py b/backends/arm/_passes/fuse_transpose_reshape_transpose_pass.py new file mode 100644 index 00000000000..57c74ea74f9 --- /dev/null +++ b/backends/arm/_passes/fuse_transpose_reshape_transpose_pass.py @@ -0,0 +1,282 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2024-2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Fuse transpose -> reshape -> transpose patterns. + +When reordering the graph (e.g from NCHW -> NHWC), a reshape is considered +non-reorderable, and transposes are added before the input and after the output. +In certain situations, we can transform this (transpose -> reshape -> transpose) +pattern into a single transpose followed by a (different) reshape. + +Example: Consider a reshape on an NCHW tensor that reshapes the batch and channel +dimensions into the channel dimension: + (N, C, H, W) -> reshape -> (1, (N, C), H, W) + +If both input and output tensors are reordered to NHWC: + (N, H, W, C) + -> transpose -> (N, C, H, W) + -> reshape -> (1, (N, C), H, W) + -> transpose -> (1, H, W, (N, C)) + +This is equivalent to: + (N, H, W, C) -> transpose -> (H, W, N, C) -> reshape -> (1, H, W, (N, C)) + +Inspired by bolt/nn/espresso/transforms/fuse_ops.py:fuse_transpose_reshape_transpose +""" + +import logging +from typing import Optional, Set, Type + +import executorch.backends.arm.tosa.dialect # noqa: F401 - loads TOSA dialect +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch import fx + +logger = logging.getLogger(__name__) + + +# Supported permute/transpose targets +_PERMUTE_TARGETS = ( + torch.ops.aten.permute_copy.default, + exir_ops.backend.tosa.TRANSPOSE.default, +) + +# Supported reshape/view targets (both ATen and Edge dialects) +_RESHAPE_TARGETS = ( + torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, + exir_ops.edge.aten.view_copy.default, +) + + +def _get_permute_dims(node: fx.Node) -> Optional[list[int]]: + """Extract the permutation dimensions from a permute node.""" + if node.target not in _PERMUTE_TARGETS: + return None + dims = node.args[1] + if isinstance(dims, (list, tuple)): + return list(dims) + return None + + +def _get_reshape_shape(node: fx.Node) -> Optional[list[int]]: + """Extract the shape from a view/reshape node.""" + if node.target not in _RESHAPE_TARGETS: + return None + shape = node.args[1] + if isinstance(shape, (list, tuple)): + return list(shape) + return None + + +def _get_shape_indices( + input_shape: list[int], output_shape: list[int] +) -> Optional[list[list[int]]]: + """ + Compute which input dimensions map to each output dimension. + + For each output dimension, returns a list of input dimension indices + that were combined to create it. + + Returns None if the reshape is not a simple combination/split of dimensions. + """ + if not input_shape or not output_shape: + return None + + input_idx = 0 + result = [] + + for out_dim in output_shape: + if out_dim == -1: + return None + + indices = [] + accumulated = 1 + + while accumulated < out_dim and input_idx < len(input_shape): + indices.append(input_idx) + accumulated *= input_shape[input_idx] + input_idx += 1 + + if accumulated == out_dim: + if not indices and input_idx < len(input_shape): + if input_shape[input_idx] == 1: + indices.append(input_idx) + input_idx += 1 + elif input_shape[input_idx] == out_dim: + indices.append(input_idx) + input_idx += 1 + else: + return None + result.append(indices) + elif accumulated > out_dim: + return None + else: + return None + + if input_idx != len(input_shape): + return None + + return result + + +class FuseTransposeReshapeTransposePass(ArmPass): + """ + Fuses transpose -> reshape -> transpose patterns. + + This pass identifies patterns where: + 1. A permute is followed by a reshape/view + 2. Which is followed by another permute + + And transforms them into a single permute followed by a reshape with + the combined effect. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self) -> None: + super().__init__() + self._graph_module: Optional[fx.GraphModule] = None + + def _find_patterns( + self, graph_module: fx.GraphModule + ) -> list[tuple[fx.Node, fx.Node, fx.Node]]: + """Find all transpose -> reshape -> transpose patterns.""" + patterns = [] + graph = graph_module.graph + + for node in graph.nodes: + if node.op != "call_function": + continue + + dims1 = _get_permute_dims(node) + if dims1 is None: + continue + + transpose1 = node + + users = list(transpose1.users.keys()) + if len(users) != 1: + continue + + reshape = users[0] + if reshape.op != "call_function": + continue + + reshape_shape = _get_reshape_shape(reshape) + if reshape_shape is None: + continue + + reshape_users = list(reshape.users.keys()) + if len(reshape_users) != 1: + continue + + transpose2 = reshape_users[0] + if transpose2.op != "call_function": + continue + + dims2 = _get_permute_dims(transpose2) + if dims2 is None: + continue + + patterns.append((transpose1, reshape, transpose2)) + + return patterns + + def call(self, graph_module: fx.GraphModule) -> PassResult: + self._graph_module = graph_module + modified = False + + patterns = self._find_patterns(graph_module) + + for transpose1, reshape, transpose2 in patterns: + dims1 = _get_permute_dims(transpose1) + dims2 = _get_permute_dims(transpose2) + reshape_shape = _get_reshape_shape(reshape) + + if dims1 is None or dims2 is None or reshape_shape is None: + continue + + input_node = transpose1.args[0] + if not isinstance(input_node, fx.Node): + continue + + input_val = input_node.meta.get("val") + if input_val is None: + continue + + input_shape = list(input_val.shape) + + transposed1_shape = [input_shape[i] for i in dims1] + + shape_indices = _get_shape_indices(transposed1_shape, reshape_shape) + if shape_indices is None: + logger.warning( + f"FuseTransposeReshapeTransposePass: Cannot compute shape indices for " + f"reshape {reshape.name}: transposed1_shape={transposed1_shape}, " + f"reshape_shape={reshape_shape}" + ) + continue + + original = list(range(len(dims1))) + transposed1 = [original[i] for i in dims1] + + reshaped = [tuple(transposed1[i] for i in s) for s in shape_indices] + + transposed2 = [reshaped[i] for i in dims2] + + new_transpose_axes = [i for s in transposed2 for i in s] + + if len(new_transpose_axes) != len(input_shape): + logger.debug( + "New transpose axes length mismatch, skipping fusion" + ) + continue + + output_val = transpose2.meta.get("val") + if output_val is None: + continue + output_shape = list(output_val.shape) + + logger.info( + f"Fusing transpose({dims1}) -> reshape({reshape_shape}) -> transpose({dims2}) " + f"into transpose({new_transpose_axes}) -> reshape({output_shape})" + ) + + with graph_module.graph.inserting_before(transpose1): + new_permute = graph_module.graph.call_function( + torch.ops.aten.permute_copy.default, + (input_node, new_transpose_axes), + ) + + if input_val is not None: + new_permute.meta["val"] = input_val.permute(new_transpose_axes) + + new_reshape = graph_module.graph.call_function( + torch.ops.aten.view_copy.default, + (new_permute, output_shape), + ) + + if output_val is not None: + new_reshape.meta["val"] = output_val + + transpose2.replace_all_uses_with(new_reshape) + + graph_module.graph.erase_node(transpose2) + graph_module.graph.erase_node(reshape) + graph_module.graph.erase_node(transpose1) + + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/fuse_transpose_sandwich_pass.py b/backends/arm/_passes/fuse_transpose_sandwich_pass.py new file mode 100644 index 00000000000..c554bd0978d --- /dev/null +++ b/backends/arm/_passes/fuse_transpose_sandwich_pass.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Pass to fuse transpose-op-transpose patterns where the middle op is layout-invariant. + +This pass identifies patterns like: + T(perm1) → LayoutInvariantOp → T(perm2) + +And if perm1 and perm2 cancel out (their composition is identity), removes both +transposes. This is particularly effective for TOSA graphs where ToTosaMemoryFormatPass +inserts NCHW↔NHWC transposes at operation boundaries. + +Layout-invariant operations include elementwise ops (add, mul, relu, sigmoid, etc.) +that don't care about data layout because they operate independently on each element. +""" + +import logging +from typing import List, Sequence, Set, Tuple, Type + +import executorch.backends.arm.tosa.dialect # noqa: F401 - loads TOSA dialect +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch._ops import OpOverload + +logger = logging.getLogger(__name__) + + +# Transpose/permute targets we can fuse +_PERMUTE_TARGETS: Tuple[OpOverload, ...] = ( + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.backend.tosa.TRANSPOSE.default, +) + + +# Layout-invariant operations - these don't depend on data layout +# They operate element-wise and produce the same result regardless of memory format +_LAYOUT_INVARIANT_OPS: Set[OpOverload] = { + # Elementwise unary ops + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.log.default, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.sin.default, + exir_ops.edge.aten.cos.default, + exir_ops.edge.aten.floor.default, + exir_ops.edge.aten.ceil.default, + exir_ops.edge.aten.round.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.clamp.Tensor, + exir_ops.edge.aten.hardswish.default, + exir_ops.edge.aten.hardsigmoid.default, + exir_ops.edge.aten.leaky_relu.default, + exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.silu.default, + exir_ops.edge.aten.reciprocal.default, + # Elementwise binary ops (when both inputs have same layout) + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, +} + + +def _compose_permutations(perm1: Sequence[int], perm2: Sequence[int]) -> List[int]: + """Compose two permutations: result[i] = perm1[perm2[i]]. + + Given two consecutive permutations, computes the equivalent single permutation. + """ + return [perm1[i] for i in perm2] + + +def _is_identity_permutation(perm: Sequence[int]) -> bool: + """Check if a permutation is the identity (no-op).""" + return list(perm) == list(range(len(perm))) + + +def _inverse_permutation(perm: Sequence[int]) -> List[int]: + """Compute the inverse of a permutation.""" + inv = [0] * len(perm) + for i, p in enumerate(perm): + inv[p] = i + return inv + + +def _get_permutation(node: torch.fx.Node) -> List[int] | None: + """Extract permutation from a transpose/permute node.""" + if node.op != "call_function" or node.target not in _PERMUTE_TARGETS: + return None + if len(node.args) < 2: + return None + dims = node.args[1] + if isinstance(dims, (list, tuple)): + return list(dims) + return None + + +def _is_layout_invariant(node: torch.fx.Node) -> bool: + """Check if a node is a layout-invariant operation.""" + if node.op != "call_function": + return False + return node.target in _LAYOUT_INVARIANT_OPS + + +def _get_single_non_constant_input( + node: torch.fx.Node, +) -> torch.fx.Node | None: + """Get the single non-constant input to a node, or None if multiple/none.""" + non_const_inputs = [] + for arg in node.args: + if isinstance(arg, torch.fx.Node): + # Check if it's a constant (placeholder that's a param or buffer) + if arg.op == "get_attr": + continue + non_const_inputs.append(arg) + if len(non_const_inputs) == 1: + return non_const_inputs[0] + return None + + +class FuseTransposeSandwichPass(ArmPass): + """Fuse transpose-op-transpose patterns where permutations cancel out. + + This pass looks for patterns like: + x -> permute(perm1) -> layout_invariant_op -> permute(perm2) -> y + + And transforms them to: + x -> layout_invariant_op -> y (if perm1 and perm2 cancel out) + + This is effective for removing TOSA transposes inserted by ToTosaMemoryFormatPass + when they surround layout-invariant operations like ReLU, Add, etc. + + This pattern is common in TOSA graphs where: + - Input transpose: NCHW → NHWC for TOSA + - Layout-invariant op (relu, add, etc.) + - Output transpose: NHWC → NCHW back to PyTorch format + + Example: + Before: T([0,2,3,1]) -> ReLU -> T([0,3,1,2]) + After: ReLU (transposes removed, compose([0,2,3,1], [0,3,1,2]) = identity) + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + + # Iterate until no more fusions can be made + while True: + fused_this_iteration = False + + for node in list(graph_module.graph.nodes): + # Look for second transpose in pattern: T1 -> Op -> T2 + perm2 = _get_permutation(node) + if perm2 is None: + continue + + # Get the middle operation (input to T2) + middle_node = node.args[0] + if not isinstance(middle_node, torch.fx.Node): + continue + + # Check if middle is layout-invariant + if not _is_layout_invariant(middle_node): + continue + + # Check if middle has only one user (T2) + if len(middle_node.users) != 1: + continue + + # Get the input to the middle op that's a transpose + input_to_middle = _get_single_non_constant_input(middle_node) + if input_to_middle is None: + continue + + perm1 = _get_permutation(input_to_middle) + if perm1 is None: + continue + + # Check if T1 has only one user (the middle op) + if len(input_to_middle.users) != 1: + continue + + # Check if permutations have same rank + if len(perm1) != len(perm2): + continue + + # Compose permutations and check if they cancel + composed = _compose_permutations(perm1, perm2) + + if _is_identity_permutation(composed): + # Permutations cancel out - remove both transposes + logger.debug( + f"Removing sandwich pattern: " + f"T({perm1}) -> {middle_node.target} -> T({perm2})" + ) + + # Get original input (input to T1) + original_input = input_to_middle.args[0] + + # Rewire middle op to use original input + new_args = list(middle_node.args) + for i, arg in enumerate(new_args): + if arg is input_to_middle: + new_args[i] = original_input + middle_node.args = tuple(new_args) + + # Rewire users of T2 to use middle op directly + node.replace_all_uses_with(middle_node) + + fused_this_iteration = True + modified = True + break # Restart iteration after modification + + if not fused_this_iteration: + break + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/propagate_transposes_through_concat_pass.py b/backends/arm/_passes/propagate_transposes_through_concat_pass.py new file mode 100644 index 00000000000..7db36cecbb7 --- /dev/null +++ b/backends/arm/_passes/propagate_transposes_through_concat_pass.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Pass to propagate transposes through Concat operations. + +Concat is a layout-aware operation that concatenates tensors along a specific +dimension. When all inputs to a Concat have the same transpose permutation, +and the output is transposed with the inverse permutation, the transposes +can be eliminated by adjusting the concat dimension. + +Pattern targeted: + [T(perm1), T(perm1), ...] → Concat(dim=d) → T(perm2) + +If perm1 and perm2 compose to identity (i.e., perm2 is the inverse of perm1), +the pattern can be simplified to: + [inputs...] → Concat(dim=perm1[d]) → outputs + +This is particularly effective for TOSA graphs where: +- ToTosaMemoryFormatPass inserts NCHW↔NHWC transposes at graph boundaries +- Concat operations often have transposed inputs that are re-transposed after +- Example: T(NCHW→NHWC) → Concat → T(NHWC→NCHW) +""" + +import logging +from typing import List, Set, Tuple, Type + +import executorch.backends.arm.tosa.dialect # noqa: F401 - loads TOSA dialect +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch._ops import OpOverload + +logger = logging.getLogger(__name__) + + +# Transpose/permute targets we can propagate +_PERMUTE_TARGETS: Tuple[OpOverload, ...] = ( + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.backend.tosa.TRANSPOSE.default, +) + +# Concat targets +_CONCAT_TARGETS: Tuple[OpOverload, ...] = ( + exir_ops.edge.aten.cat.default, +) + + +def _get_permutation(node: torch.fx.Node) -> List[int] | None: + """Extract permutation from a transpose/permute node.""" + if node.op != "call_function" or node.target not in _PERMUTE_TARGETS: + return None + if len(node.args) < 2: + return None + dims = node.args[1] + if isinstance(dims, (list, tuple)): + return list(dims) + return None + + +def _is_concat(node: torch.fx.Node) -> bool: + """Check if a node is a Concat operation.""" + if node.op != "call_function": + return False + return node.target in _CONCAT_TARGETS + + +def _compose_permutations(perm1: List[int], perm2: List[int]) -> List[int]: + """Compose two permutations: result[i] = perm1[perm2[i]].""" + return [perm1[i] for i in perm2] + + +def _is_identity_permutation(perm: List[int]) -> bool: + """Check if a permutation is identity.""" + return perm == list(range(len(perm))) + + +def _inverse_permutation(perm: List[int]) -> List[int]: + """Compute the inverse of a permutation.""" + inv = [0] * len(perm) + for i, p in enumerate(perm): + inv[p] = i + return inv + + +class PropagateTransposesThroughConcatPass(ArmPass): + """Propagate transposes through Concat operations. + + This pass looks for patterns like: + [T(perm), T(perm), ...] → Concat(dim=d) → T(inv_perm) + + Where all inputs to Concat have the same permutation perm, and the output + is transposed with the inverse permutation. In this case, the transposes + cancel out, and we can eliminate them by adjusting the concat dimension. + + Transformation: + Before: [T(perm)(x1), T(perm)(x2)] → Concat(dim=d) → T(inv_perm) → y + After: [x1, x2] → Concat(dim=perm[d]) → y + + The concat dimension is adjusted because: + - Original: Concat on transposed data along dim d + - New: Concat on original data along the dimension that maps to d after transpose + + Example for NCHW→NHWC (perm=[0,2,3,1]): + - Original: Concat(dim=3) on NHWC data → Transpose to NCHW + - After: Concat(dim=1) on NCHW data (since perm[3]=1, but we need inverse) + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + + # Iterate until no more fusions can be made + while True: + fused_this_iteration = False + + for node in list(graph_module.graph.nodes): + # Look for pattern: Concat → T (start from the output transpose) + output_perm = _get_permutation(node) + if output_perm is None: + continue + + # Get the input to the output transpose (should be Concat) + concat_node = node.args[0] + if not isinstance(concat_node, torch.fx.Node): + continue + + if not _is_concat(concat_node): + continue + + # Check if Concat has only one user (the output transpose) + if len(concat_node.users) != 1: + continue + + # Get the concat inputs and dimension + concat_inputs = concat_node.args[0] + if not isinstance(concat_inputs, (list, tuple)): + continue + + # Get concat dimension (default is 0) + concat_dim = concat_node.args[1] if len(concat_node.args) > 1 else 0 + + # Check if all concat inputs are transposes with the same permutation + input_perms = [] + input_sources = [] + all_valid = True + + for inp in concat_inputs: + if not isinstance(inp, torch.fx.Node): + all_valid = False + break + + inp_perm = _get_permutation(inp) + if inp_perm is None: + all_valid = False + break + + # Check single user (only this Concat) + if len(inp.users) != 1: + all_valid = False + break + + input_perms.append(inp_perm) + input_sources.append(inp.args[0]) + + if not all_valid: + continue + + if len(input_perms) == 0: + continue + + # Check all input permutations are the same + first_perm = input_perms[0] + if not all(perm == first_perm for perm in input_perms): + continue + + # Check if input and output permutations have the same rank + if len(first_perm) != len(output_perm): + continue + + # Check if input_perm → output_perm composes to identity + composed = _compose_permutations(first_perm, output_perm) + + if not _is_identity_permutation(composed): + # Permutations don't cancel out + continue + + # We have the pattern! All transposes can be eliminated. + # New concat dimension: when we remove the input transposes, + # the data is in the original layout. The concat dimension + # needs to be adjusted. + # + # Original: data → T(perm) → Concat(dim=d) → T(inv_perm) + # The concat happens on transposed data at dimension d. + # In original layout, this corresponds to dimension inv_perm[d]. + # + # But wait - we're removing BOTH the input and output transposes. + # So the new concat dimension should be: first_perm.index(concat_dim) + # i.e., find where concat_dim came from in the original layout. + # + # Actually, the inverse permutation tells us this: + # inv_perm[d] gives the original dimension that maps to d. + + inv_first_perm = _inverse_permutation(first_perm) + new_concat_dim = inv_first_perm[concat_dim] + + logger.debug( + f"Propagating transposes through Concat: " + f"T({first_perm}) x {len(input_perms)} inputs → Concat(dim={concat_dim}) → T({output_perm}) " + f"=> Concat(dim={new_concat_dim})" + ) + + # Create new concat node with adjusted dimension + with graph_module.graph.inserting_before(concat_node): + new_concat = graph_module.graph.call_function( + concat_node.target, + args=(list(input_sources), new_concat_dim), + kwargs=dict(concat_node.kwargs), + ) + new_concat.meta = node.meta.copy() + + # Replace the output transpose with the new concat + node.replace_all_uses_with(new_concat) + + fused_this_iteration = True + modified = True + break # Restart iteration after modification + + if not fused_this_iteration: + break + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/propagate_transposes_through_rescale_pass.py b/backends/arm/_passes/propagate_transposes_through_rescale_pass.py new file mode 100644 index 00000000000..bc11c1500c8 --- /dev/null +++ b/backends/arm/_passes/propagate_transposes_through_rescale_pass.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Pass to propagate transposes through TOSA Rescale operations. + +TOSA Rescale operations are elementwise (per-element scaling and zero-point +adjustment), meaning they are layout-invariant. This pass identifies patterns +where transposes surround a Rescale operation and propagates them through, +enabling subsequent passes to fuse or eliminate consecutive transposes. + +Pattern targeted: + T(perm1) → Rescale → ... → T(perm2) + +After propagation: + Rescale → T(perm1) → ... → T(perm2) + +This allows FuseConsecutiveTransposesPass to then merge T(perm1) and T(perm2) +if they are now adjacent or can be composed. + +This is particularly effective for TOSA graphs where: +- ToTosaMemoryFormatPass inserts NCHW↔NHWC transposes +- Rescale operations are inserted for quantization/dequantization +- The pattern Transpose → Rescale → Conv → Rescale → Transpose is common +""" + +import logging +from typing import List, Sequence, Set, Tuple, Type + +import executorch.backends.arm.tosa.dialect # noqa: F401 - loads TOSA dialect +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch._ops import OpOverload + +logger = logging.getLogger(__name__) + + +# Transpose/permute targets we can propagate +_PERMUTE_TARGETS: Tuple[OpOverload, ...] = ( + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.backend.tosa.TRANSPOSE.default, +) + +# TOSA Rescale operation - elementwise and layout-invariant +_RESCALE_TARGETS: Tuple[OpOverload, ...] = (exir_ops.backend.tosa.RESCALE.default,) + +# Additional layout-invariant operations through which we can propagate transposes +_LAYOUT_INVARIANT_OPS: Set[OpOverload] = { + # Elementwise unary ops + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.log.default, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.hardswish.default, + exir_ops.edge.aten.hardsigmoid.default, + exir_ops.edge.aten.leaky_relu.default, + exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.silu.default, + exir_ops.edge.aten.reciprocal.default, +} + + +def _get_permutation(node: torch.fx.Node) -> List[int] | None: + """Extract permutation from a transpose/permute node.""" + if node.op != "call_function" or node.target not in _PERMUTE_TARGETS: + return None + if len(node.args) < 2: + return None + dims = node.args[1] + if isinstance(dims, (list, tuple)): + return list(dims) + return None + + +def _is_rescale(node: torch.fx.Node) -> bool: + """Check if a node is a TOSA Rescale operation.""" + if node.op != "call_function": + return False + return node.target in _RESCALE_TARGETS + + +def _is_layout_invariant(node: torch.fx.Node) -> bool: + """Check if a node is a layout-invariant operation.""" + if node.op != "call_function": + return False + return node.target in _LAYOUT_INVARIANT_OPS or node.target in _RESCALE_TARGETS + + +def _permute_shape(shape: List[int], perm: List[int]) -> List[int]: + """Apply permutation to a shape.""" + return [shape[i] for i in perm] + + +def _inverse_permutation(perm: Sequence[int]) -> List[int]: + """Compute the inverse of a permutation.""" + inv = [0] * len(perm) + for i, p in enumerate(perm): + inv[p] = i + return inv + + +def _get_single_user(node: torch.fx.Node) -> torch.fx.Node | None: + """Get the single user of a node, or None if multiple or no users.""" + users = list(node.users.keys()) + if len(users) == 1: + return users[0] + return None + + +class PropagateTransposesThroughRescalePass(ArmPass): + """Fuse T1 -> Rescale -> T2 patterns by removing Rescale from the middle. + + This pass looks for patterns like: + x -> permute(perm1) -> rescale -> permute(perm2) -> y + + Where Rescale is a layout-invariant operation, and transforms them by + composing the permutations. Since Rescale is elementwise (operates on + each element independently), it doesn't care about memory layout. + + This is different from simple consecutive transpose fusion because + the Rescale operation sits between the two transposes. + + Pattern: + Before: T1(perm1) -> Rescale -> T2(perm2) + After: Rescale -> T_combined(compose(perm1, perm2)) + + If compose(perm1, perm2) is identity, both transposes are eliminated: + Before: T1(perm1) -> Rescale -> T2(inverse(perm1)) + After: Rescale (transposes removed) + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + + # Iterate until no more fusions can be made + while True: + fused_this_iteration = False + + for node in list(graph_module.graph.nodes): + # Look for the pattern: T1 -> Rescale -> T2 + # Start by finding T2 (second transpose) + perm2 = _get_permutation(node) + if perm2 is None: + continue + + # Get the input to T2 (should be Rescale) + rescale_node = node.args[0] + if not isinstance(rescale_node, torch.fx.Node): + continue + + if not _is_rescale(rescale_node): + continue + + # Check if Rescale has only one user (T2) + if len(rescale_node.users) != 1: + continue + + # Get the input to Rescale (should be T1) + transpose1_node = rescale_node.args[0] + if not isinstance(transpose1_node, torch.fx.Node): + continue + + perm1 = _get_permutation(transpose1_node) + if perm1 is None: + continue + + # Check if T1 has only one user (Rescale) + if len(transpose1_node.users) != 1: + continue + + # Check if permutations have same rank + if len(perm1) != len(perm2): + continue + + # We have the pattern: T1(perm1) -> Rescale -> T2(perm2) + # Compose the permutations + composed = [perm1[i] for i in perm2] + + # Get the original input (input to T1) + original_input = transpose1_node.args[0] + + if composed == list(range(len(composed))): + # Identity permutation - remove both transposes + logger.debug( + f"Removing T({perm1}) -> Rescale -> T({perm2}) pattern " + f"(composes to identity)" + ) + + # Rewire: Rescale now takes the original input + new_args = list(rescale_node.args) + new_args[0] = original_input + rescale_node.args = tuple(new_args) + + # Rewire: users of T2 now use Rescale + node.replace_all_uses_with(rescale_node) + + else: + # Non-identity - replace with single transpose after Rescale + logger.debug( + f"Fusing T({perm1}) -> Rescale -> T({perm2}) " + f"=> Rescale -> T({composed})" + ) + + # Rewire: Rescale now takes the original input + new_args = list(rescale_node.args) + new_args[0] = original_input + rescale_node.args = tuple(new_args) + + # Create new combined transpose after Rescale + with graph_module.graph.inserting_after(rescale_node): + new_transpose = graph_module.graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(rescale_node, composed), + ) + # Copy metadata from T2 + new_transpose.meta = node.meta.copy() + + # Rewire: users of T2 now use new_transpose + node.replace_all_uses_with(new_transpose) + + fused_this_iteration = True + modified = True + break # Restart iteration after modification + + if not fused_this_iteration: + break + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index ba55f9bc81f..65a1584a53a 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -205,6 +205,21 @@ def get_batch_prod_dim(shape, spatial_rank): return (N_old != N_new) or (C_old != C_new) + @staticmethod + def _is_nop_transpose(shape, perm) -> bool: + """Return ``True`` when a transpose only permutes size-1 dimensions. + + A transpose is a NOP (no-operation) when the relative order of + all non-size-1 dimensions is unchanged — permuting size-1 dims + does not alter the physical byte layout. + + Example: ``[14, 72, 1, 1]`` with perm ``(0, 1, 3, 2)`` → True + (only the two trailing size-1 dims swap). + """ + old_order = [i for i, s in enumerate(shape) if s != 1] + new_order = [i for i, s in zip(perm, [shape[p] for p in perm]) if s != 1] + return old_order == new_order + @staticmethod def insert_input_transpose(node, input_node, graph_module): """Ensure an input tensor is converted to channels-last ordering by @@ -254,7 +269,7 @@ def insert_output_transpose(node, graph_module): # Guard: mem_format must be a true permutation for the current rank assert sorted(mem_format) == list( range(rank) - ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" + ), f"bad perm {mem_format} for rank {rank} in insert_output_transpose" with graph_module.graph.inserting_after(node): permute_node = create_node( @@ -279,6 +294,104 @@ def insert_output_transpose(node, graph_module): for user in users: user.replace_input_with(node, permute_node) + @staticmethod + def _get_shape_indices( + src_shape: list[int], tgt_shape: list[int] + ) -> list[list[int]] | None: + """Greedy dimension matching for reshape operations. + + For each target dimension, greedily consumes contiguous source + dimensions whose product equals the target size. Size-1 target + dimensions that do not correspond to any source dimension produce + empty index lists (inserted dims). + + Returns ``None`` when no valid mapping exists. + """ + src_idx = 0 + result: list[list[int]] = [] + + for tgt_dim in tgt_shape: + if tgt_dim <= 0: + return None + + indices: list[int] = [] + remaining = tgt_dim + + while src_idx < len(src_shape) and remaining % src_shape[src_idx] == 0: + indices.append(src_idx) + remaining //= src_shape[src_idx] + src_idx += 1 + if remaining == 1: + break + + if remaining != 1: + return None + + result.append(indices) + + if src_idx != len(src_shape): + return None + + return result + + @staticmethod + def _is_monotonic(indices: list[list[int]]) -> bool: + """Return ``True`` when all non-empty index groups are strictly + ordered — i.e. each group's indices follow the previous group's. + """ + last_max = -1 + for group in indices: + if not group: + continue + if group[0] <= last_max: + return False + last_max = group[-1] + return True + + @staticmethod + def _is_nhwc_safe_reshape( + input_shape, output_shape, input_sr, output_sr # noqa: ARG004 + ) -> bool: + """Detect whether a 4-D+ reshape can operate directly on NHWC data. + + By the time ``ToTosaMemoryFormatPass`` runs, 4-D tensor shapes in + ``meta["val"]`` are already in NHWC physical order (the channel + dimension sits at position ``rank - spatial_rank - 1``, not at + position 1 as in NCHW). We therefore check the shape indices on + the **raw** input/output shapes — no extra permutation is needed. + + Returns ``True`` when: + 1. The reshape has monotonic shape_indices (each output dim maps + to a contiguous, in-order group of input dims), AND + 2. The channel dimension is preserved alone (not merged with + spatial dims). + """ + rank_in = len(input_shape) + rank_out = len(output_shape) + if rank_in < 4 or rank_out < 4: + return False + + indices = ToTosaMemoryFormatPass._get_shape_indices( + list(input_shape), list(output_shape) + ) + if indices is None: + return False + + if not ToTosaMemoryFormatPass._is_monotonic(indices): + return False + + # In the TOSA pipeline the physical memory order is NHWC. + # The channel dimension in NHWC is always the **last** axis + # (position ``rank - 1``). It must appear *alone* in its + # output group — if it is merged with spatial dims the reshape + # would reorder channel data and the optimisation is invalid. + channel_idx = rank_in - 1 + for group in indices: + if channel_idx in group: + return len(group) == 1 + # Channel dim not consumed by any group — conservative reject. + return False + @staticmethod def _insert_view_transpose( input_shape, output_shape, node, input_node, graph_module @@ -300,6 +413,14 @@ def _insert_view_transpose( output_sr, ) + # When the NHWC-space reshape has monotonic shape_indices the + # view_copy can operate directly on NHWC data — no transposes + # are needed. + if channel_reshape and ToTosaMemoryFormatPass._is_nhwc_safe_reshape( + input_shape, output_shape, input_sr, output_sr + ): + return + if ( channel_reshape or nhwc_to_nchw ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr): @@ -328,10 +449,44 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): - 1D/2D tensors """ - for node in graph_module.graph.nodes: + for node in list(graph_module.graph.nodes): if node.op != "call_function": continue + # Eliminate model-level permute_copy ops that are redundant + # with the tosa_dim_order annotation. When a permute_copy's + # permutation matches the channels-last order (or its + # inverse), the permute does the same NCHW↔NHWC conversion + # that tosa_dim_order already handles — keeping both would + # double-convert. Replace with view_copy (identity reshape). + if node.target in ( + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.permute.default, + ): + perm = list(node.args[1]) + rank = len(perm) + sr = node.meta.get("tosa_spatial_rank", 0) + + if rank >= 3 and sr >= 1: + cl_order = list( + self._channels_last_order(rank, sr) + ) + cl_inv = list( + self._channels_last_inverse_order(rank, sr) + ) + if perm == cl_order or perm == cl_inv: + input_node = node.args[0] + output_shape = list(node.meta["val"].shape) + with graph_module.graph.inserting_before(node): + view_node = graph_module.graph.call_function( + exir_ops.edge.aten.view_copy.default, + (input_node, output_shape), + ) + view_node.meta = dict(node.meta) + node.replace_all_uses_with(view_node) + graph_module.graph.erase_node(node) + continue + # Transpose views elif node.target == exir_ops.edge.aten.view_copy.default: input_node = node.args[0] diff --git a/backends/arm/arm_vela.py b/backends/arm/arm_vela.py index 76afcdd23f4..b729219dc51 100644 --- a/backends/arm/arm_vela.py +++ b/backends/arm/arm_vela.py @@ -63,7 +63,7 @@ def vela_bin_pack_io(prefix, data): def vela_compile( tosa_flatbuffer: bytes, args: List[str], - verbose: bool = False, + verbose: bool = True, intermediate_path: str | None = None, ): """Compile a TOSA graph to a binary stream for ArmBackendEthosU using diff --git a/backends/arm/test/passes/test_fuse_consecutive_transposes_pass.py b/backends/arm/test/passes/test_fuse_consecutive_transposes_pass.py new file mode 100644 index 00000000000..0e1db684e42 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_consecutive_transposes_pass.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for FuseConsecutiveTransposesPass. +""" + +from typing import Tuple + +import torch +from executorch.backends.arm._passes import FuseConsecutiveTransposesPass +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +input_t = Tuple[torch.Tensor] + + +class DoublePermuteIdentityModule(torch.nn.Module): + """Two permutes that cancel each other out (identity).""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NCHW -> NHWC -> NCHW = identity + x = x.permute(0, 2, 3, 1) # NCHW -> NHWC + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + return x + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 3, 4, 5),) + + +def test_fuse_consecutive_transposes_identity_removes_both(): + """Test that two canceling permutes are both removed.""" + module = DoublePermuteIdentityModule() + pipeline = PassPipeline[input_t]( + module, + DoublePermuteIdentityModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2, + }, + ops_after_pass={}, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default", + ], + pass_list=[FuseConsecutiveTransposesPass], + ) + pipeline.run() + + +class DoublePermuteFusionModule(torch.nn.Module): + """Two permutes that don't cancel but can be fused into one.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # (0,1,2,3) -> (0,2,1,3) -> (0,3,2,1) + # Combined: (0,1,2,3)[i for i in (0,3,2,1)] applied to (0,2,1,3) + # = (0,2,1,3) permuted by (0,3,2,1) = [0,3,1,2] + x = x.permute(0, 2, 1, 3) + x = x.permute(0, 3, 2, 1) + return x + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 3, 4, 5),) + + +def test_fuse_consecutive_transposes_fuses_to_one(): + """Test that two non-canceling permutes are fused into one.""" + module = DoublePermuteFusionModule() + pipeline = PassPipeline[input_t]( + module, + DoublePermuteFusionModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + pass_list=[FuseConsecutiveTransposesPass], + ) + pipeline.run() + + +class TriplePermuteModule(torch.nn.Module): + """Three consecutive permutes that should be fused.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) # (0,1,2,3) -> (0,2,3,1) + x = x.permute(0, 2, 3, 1) # -> (0,3,1,2) + x = x.permute(0, 2, 3, 1) # -> (0,1,2,3) = identity! + return x + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 3, 4, 5),) + + +def test_fuse_consecutive_transposes_triple_identity(): + """Test that three permutes that form identity are all removed.""" + module = TriplePermuteModule() + pipeline = PassPipeline[input_t]( + module, + TriplePermuteModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 3, + }, + ops_after_pass={}, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default", + ], + pass_list=[FuseConsecutiveTransposesPass], + ) + pipeline.run() + + +class SinglePermuteModule(torch.nn.Module): + """Single permute that should not be affected.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(0, 2, 3, 1) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 3, 4, 5),) + + +def test_fuse_consecutive_transposes_single_permute_unchanged(): + """Test that a single permute is not affected.""" + module = SinglePermuteModule() + pipeline = PassPipeline[input_t]( + module, + SinglePermuteModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + pass_list=[FuseConsecutiveTransposesPass], + ) + pipeline.run() + + +class PermuteWithOpBetweenModule(torch.nn.Module): + """Permutes separated by another op should not be fused.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) + x = x + 1.0 # Op between permutes + x = x.permute(0, 3, 1, 2) + return x + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 3, 4, 5),) + + +def test_fuse_consecutive_transposes_not_consecutive(): + """Test that permutes separated by other ops are not fused.""" + module = PermuteWithOpBetweenModule() + pipeline = PassPipeline[input_t]( + module, + PermuteWithOpBetweenModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2, + }, + pass_list=[FuseConsecutiveTransposesPass], + ) + pipeline.run() + + +class Permute3DIdentityModule(torch.nn.Module): + """Two 3D permutes that cancel each other out.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) + x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) = identity + return x + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 3, 4),) + + +def test_fuse_consecutive_transposes_3d_identity(): + """Test that two 3D canceling permutes are removed.""" + module = Permute3DIdentityModule() + pipeline = PassPipeline[input_t]( + module, + Permute3DIdentityModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2, + }, + ops_after_pass={}, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default", + ], + pass_list=[FuseConsecutiveTransposesPass], + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_transpose_reshape_linear_pass.py b/backends/arm/test/passes/test_fuse_transpose_reshape_linear_pass.py new file mode 100644 index 00000000000..04ea267b3b6 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_transpose_reshape_linear_pass.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2024-2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for FuseTransposeReshapeLinearPass. +""" + +import unittest + +import torch +from executorch.backends.arm._passes.fuse_transpose_reshape_linear_pass import ( + FuseTransposeReshapeLinearPass, +) +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + + +class TransposeReshapeLinearModel(torch.nn.Module): + """Model with transpose -> reshape -> linear pattern.""" + + def __init__( + self, dims: list[int], in_features: int, out_features: int + ) -> None: + super().__init__() + self.dims = dims + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(self.dims) + batch_size = x.shape[0] + x = x.reshape(batch_size, -1) + x = self.linear(x) + return x + + +class TransposeReshapeLinearWithBiasModel(torch.nn.Module): + """Model with transpose -> reshape -> linear (with bias) pattern.""" + + def __init__( + self, dims: list[int], in_features: int, out_features: int + ) -> None: + super().__init__() + self.dims = dims + self.linear = torch.nn.Linear(in_features, out_features, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(self.dims) + batch_size = x.shape[0] + x = x.reshape(batch_size, -1) + x = self.linear(x) + return x + + +class SingleLinearModel(torch.nn.Module): + """Model with just a linear layer - should not be modified.""" + + def __init__(self, in_features: int, out_features: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +class TransposeModifiesBatchModel(torch.nn.Module): + """Model where transpose modifies batch dim - should NOT be fused.""" + + def __init__(self, in_features: int, out_features: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(1, 0, 2, 3) + x = x.reshape(x.shape[0], -1) + x = self.linear(x) + return x + + +def _count_permute_nodes(graph_module: torch.fx.GraphModule) -> int: + """Count the number of permute_copy nodes in the graph.""" + count = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.permute_copy.default: + count += 1 + return count + + +def _count_view_nodes(graph_module: torch.fx.GraphModule) -> int: + """Count the number of view_copy nodes in the graph.""" + count = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in ( + torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, + ): + count += 1 + return count + + +class TestFuseTransposeReshapeLinearPass(unittest.TestCase): + """Tests for FuseTransposeReshapeLinearPass.""" + + def test_basic_fusion(self) -> None: + """Test basic fusion of transpose -> reshape -> linear.""" + model = TransposeReshapeLinearModel( + dims=[0, 2, 3, 1], + in_features=512, + out_features=128, + ) + example_input = torch.randn(1, 64, 8, 1) + + pipeline = PassPipeline[torch.nn.Module](model, example_input) + permute_count_before = _count_permute_nodes(pipeline.graph_module) + + pipeline.run_passes([FuseTransposeReshapeLinearPass()]) + + permute_count_after = _count_permute_nodes(pipeline.graph_module) + + self.assertLess( + permute_count_after, + permute_count_before, + "Expected permute nodes to be reduced after fusion", + ) + + def test_transpose_modifies_batch_not_fused(self) -> None: + """Ensure transpose that modifies batch dim is NOT fused.""" + model = TransposeModifiesBatchModel( + in_features=512, + out_features=128, + ) + example_input = torch.randn(2, 1, 64, 8) + + pipeline = PassPipeline[torch.nn.Module](model, example_input) + permute_count_before = _count_permute_nodes(pipeline.graph_module) + + pipeline.run_passes([FuseTransposeReshapeLinearPass()]) + + permute_count_after = _count_permute_nodes(pipeline.graph_module) + + self.assertEqual( + permute_count_before, + permute_count_after, + "Should NOT fuse when transpose modifies batch dimension", + ) + + def test_single_linear_unchanged(self) -> None: + """Ensure single linear models are not modified.""" + model = SingleLinearModel(in_features=512, out_features=128) + example_input = torch.randn(1, 512) + + pipeline = PassPipeline[torch.nn.Module](model, example_input) + permute_count_before = _count_permute_nodes(pipeline.graph_module) + view_count_before = _count_view_nodes(pipeline.graph_module) + + pipeline.run_passes([FuseTransposeReshapeLinearPass()]) + + permute_count_after = _count_permute_nodes(pipeline.graph_module) + view_count_after = _count_view_nodes(pipeline.graph_module) + + self.assertEqual(permute_count_before, permute_count_after) + self.assertEqual(view_count_before, view_count_after) + + def test_nhwc_to_flatten_pattern(self) -> None: + """Test the common NHWC flatten pattern.""" + model = TransposeReshapeLinearModel( + dims=[0, 2, 3, 1], + in_features=512, + out_features=256, + ) + example_input = torch.randn(1, 64, 8, 1) + + pipeline = PassPipeline[torch.nn.Module](model, example_input) + pipeline.run_passes([FuseTransposeReshapeLinearPass()]) + + permute_count = _count_permute_nodes(pipeline.graph_module) + + self.assertEqual(permute_count, 0, "Expected all permutes to be fused") + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/arm/test/passes/test_fuse_transpose_reshape_transpose_pass.py b/backends/arm/test/passes/test_fuse_transpose_reshape_transpose_pass.py new file mode 100644 index 00000000000..ba04d34ee23 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_transpose_reshape_transpose_pass.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2024-2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for FuseTransposeReshapeTransposePass. +""" + +import unittest + +import torch +from executorch.backends.arm._passes.fuse_transpose_reshape_transpose_pass import ( + FuseTransposeReshapeTransposePass, +) +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + + +class TransposeReshapeTransposeModel(torch.nn.Module): + """Model with transpose -> reshape -> transpose pattern.""" + + def __init__(self, dims1: list[int], shape: list[int], dims2: list[int]) -> None: + super().__init__() + self.dims1 = dims1 + self.shape = shape + self.dims2 = dims2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(self.dims1) + x = x.view(self.shape) + x = x.permute(self.dims2) + return x + + +class SingleTransposeModel(torch.nn.Module): + """Model with only a single transpose - should not be modified.""" + + def __init__(self, dims: list[int]) -> None: + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(self.dims) + + +class TransposeReshapeModel(torch.nn.Module): + """Model with transpose -> reshape but no second transpose.""" + + def __init__(self, dims: list[int], shape: list[int]) -> None: + super().__init__() + self.dims = dims + self.shape = shape + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(self.dims) + x = x.view(self.shape) + return x + + +def _count_permute_nodes(graph_module: torch.fx.GraphModule) -> int: + """Count the number of permute_copy nodes in the graph.""" + count = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.permute_copy.default: + count += 1 + return count + + +def _count_view_nodes(graph_module: torch.fx.GraphModule) -> int: + """Count the number of view_copy nodes in the graph.""" + count = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in ( + torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, + ): + count += 1 + return count + + +class TestFuseTransposeReshapeTransposePass(unittest.TestCase): + """Tests for FuseTransposeReshapeTransposePass.""" + + def test_basic_fusion(self) -> None: + """Test basic fusion of transpose -> reshape -> transpose.""" + model = TransposeReshapeTransposeModel( + dims1=[0, 2, 3, 1], + shape=[1, 8, 8, 64], + dims2=[0, 3, 1, 2], + ) + example_input = torch.randn(1, 64, 8, 8) + + pipeline = PassPipeline[torch.nn.Module](model, example_input) + pipeline.run_passes([FuseTransposeReshapeTransposePass()]) + graph_module = pipeline.graph_module + + permute_count = _count_permute_nodes(graph_module) + view_count = _count_view_nodes(graph_module) + + self.assertEqual(permute_count, 1, "Expected exactly 1 permute after fusion") + self.assertEqual(view_count, 1, "Expected exactly 1 view after fusion") + + def test_numerical_correctness(self) -> None: + """Verify that the pass produces numerically correct results.""" + model = TransposeReshapeTransposeModel( + dims1=[0, 2, 3, 1], + shape=[1, 8, 8, 64], + dims2=[0, 3, 1, 2], + ) + example_input = torch.randn(1, 64, 8, 8) + + expected_output = model(example_input) + + pipeline = PassPipeline[torch.nn.Module](model, example_input) + pipeline.run_passes([FuseTransposeReshapeTransposePass()]) + + actual_output = pipeline.graph_module(example_input) + + torch.testing.assert_close(actual_output, expected_output) + + def test_single_transpose_unchanged(self) -> None: + """Ensure single transpose models are not modified.""" + model = SingleTransposeModel(dims=[0, 2, 3, 1]) + example_input = torch.randn(1, 64, 8, 8) + + pipeline = PassPipeline[torch.nn.Module](model, example_input) + permute_count_before = _count_permute_nodes(pipeline.graph_module) + + pipeline.run_passes([FuseTransposeReshapeTransposePass()]) + + permute_count_after = _count_permute_nodes(pipeline.graph_module) + + self.assertEqual(permute_count_before, permute_count_after) + + def test_transpose_reshape_only_unchanged(self) -> None: + """Ensure transpose -> reshape (without second transpose) is not modified.""" + model = TransposeReshapeModel(dims=[0, 2, 3, 1], shape=[1, 512]) + example_input = torch.randn(1, 64, 8, 1) + + pipeline = PassPipeline[torch.nn.Module](model, example_input) + permute_count_before = _count_permute_nodes(pipeline.graph_module) + view_count_before = _count_view_nodes(pipeline.graph_module) + + pipeline.run_passes([FuseTransposeReshapeTransposePass()]) + + permute_count_after = _count_permute_nodes(pipeline.graph_module) + view_count_after = _count_view_nodes(pipeline.graph_module) + + self.assertEqual(permute_count_before, permute_count_after) + self.assertEqual(view_count_before, view_count_after) + + def test_nchw_to_nhwc_pattern(self) -> None: + """Test the NCHW -> NHWC reordering pattern from the docstring.""" + model = TransposeReshapeTransposeModel( + dims1=[0, 3, 1, 2], + shape=[1, 512, 8, 8], + dims2=[0, 2, 3, 1], + ) + example_input = torch.randn(1, 8, 8, 64) + + expected_output = model(example_input) + + pipeline = PassPipeline[torch.nn.Module](model, example_input) + pipeline.run_passes([FuseTransposeReshapeTransposePass()]) + + permute_count = _count_permute_nodes(pipeline.graph_module) + view_count = _count_view_nodes(pipeline.graph_module) + + self.assertEqual(permute_count, 1) + self.assertEqual(view_count, 1) + + actual_output = pipeline.graph_module(example_input) + torch.testing.assert_close(actual_output, expected_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index dfd57aa7e61..9c524642ec1 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -177,11 +177,76 @@ def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) +class NHWCSafeSpatialMerge(torch.nn.Module): + """Test-module with a 4D->4D reshape that merges spatial dims H*W while + preserving the last-dim channel. + + For models with view_copy shapes [1,2,14,72]->[1,28,1,72] where C=2 + sits at NCHW position 1 and the last dim (72) is the NHWC channel that gets + preserved. ``_is_nhwc_safe_reshape`` detects that shape_indices on the raw + shapes are monotonic with the last dim alone, so no transposes are inserted + around the view_copy. + + Setup: conv2d (forces NHWC, C=2) -> view_copy -> add (keeps in NHWC). + """ + + ops_before_pass: Dict[str, int] = {} + # Only the 2 I/O transposes for the conv, NO extra transposes from view_copy + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2 + } + ops_not_after_pass: List[str] = [] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=2, out_channels=2, kernel_size=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) # forces NHWC path; output [1, 2, 14, 72] + x = x.view(1, 28, 1, 72) # spatial merge: H*W=2*14->28, last dim 72 preserved + return x + x # keep result 4-D in NHWC + + def get_inputs(self) -> input_t: + return (torch.randn(1, 2, 14, 72),) + + +class NHWCUnsafeChannelChange(torch.nn.Module): + """Test-module with a 4D->4D reshape that is NOT NHWC-safe because the + target shape cannot be produced by a monotonic merge of NHWC input dims. + The pass MUST still insert transposes around the view_copy. + """ + + ops_before_pass: Dict[str, int] = {} + # conv I/O transposes (2) + view_copy transposes (2) = 4 + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4 + } + ops_not_after_pass: List[str] = [] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=72, out_channels=72, kernel_size=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) # output [1, 72, 2, 14] + x = x.view(1, 14, 2, 72) # not NHWC-safe (channels shuffled) + return x + x + + def get_inputs(self) -> input_t: + return (torch.randn(1, 72, 2, 14),) + + modules: Dict[str, ModuleMetadata] = { "no_nhwc": NoNHWC(), "parallel_clusters": ParallelClusters(), "serial_clusters": SerialClusters(), "reshapes": Reshapes(), + "nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(), + "nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(), } @@ -209,3 +274,92 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No module_nn = cast(torch.nn.Module, module) pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) pipeline.run() + + +# --- Direct unit tests for NHWC-safe reshape helpers --- + + +def test_get_shape_indices_spatial_merge(): + """[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C.""" + indices = ToTosaMemoryFormatPass._get_shape_indices( + [1, 2, 14, 72], [1, 28, 1, 72] + ) + assert indices == [[0], [1, 2], [], [3]] + + +def test_get_shape_indices_identity(): + """Same shape => each dim maps to itself.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [2, 3, 4]) + assert indices == [[0], [1], [2]] + + +def test_get_shape_indices_full_merge(): + """[2, 3, 4] -> [24]: merge all dims into one.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24]) + assert indices == [[0, 1, 2]] + + +def test_get_shape_indices_incompatible(): + """Sizes that don't divide => None.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4]) + assert indices is None + + +def test_get_shape_indices_size_one_insert(): + """[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4]) + assert indices is not None + assert indices == [[0], [], [1]] + + +def test_is_monotonic_true(): + assert ToTosaMemoryFormatPass._is_monotonic([[0], [1, 2], [], [3]]) + assert ToTosaMemoryFormatPass._is_monotonic([[0], [], [1], [2, 3]]) + assert ToTosaMemoryFormatPass._is_monotonic([[], [0, 1, 2]]) + + +def test_is_monotonic_false(): + assert not ToTosaMemoryFormatPass._is_monotonic([[1], [0]]) + assert not ToTosaMemoryFormatPass._is_monotonic([[0, 2], [1]]) + + +def test_is_nhwc_safe_forward(): + """Shapes already NHWC by the time the pass runs. + [1,2,14,72] -> [1,28,1,72], sr=2 -> NHWC-safe (spatial merge, C=72 preserved). + """ + assert ToTosaMemoryFormatPass._is_nhwc_safe_reshape( + [1, 2, 14, 72], [1, 28, 1, 72], input_sr=2, output_sr=2 + ) + + +def test_is_nhwc_safe_non_4d(): + """Reshapes below rank 4 are never NHWC-safe.""" + assert not ToTosaMemoryFormatPass._is_nhwc_safe_reshape( + [6, 4], [24], input_sr=0, output_sr=0 + ) + + +def test_is_nop_transpose_size1_swap(): + """[14, 72, 1, 1] with perm (0, 1, 3, 2) only swaps trailing size-1 dims.""" + assert ToTosaMemoryFormatPass._is_nop_transpose([14, 72, 1, 1], (0, 1, 3, 2)) + + +def test_is_nop_transpose_real_reorder(): + """[14, 72, 1, 1] with perm (1, 0, 2, 3) swaps non-size-1 dims.""" + assert not ToTosaMemoryFormatPass._is_nop_transpose([14, 72, 1, 1], (1, 0, 2, 3)) + + +def test_is_nop_transpose_all_size1(): + """[1, 1, 1, 1] with any perm is always a NOP.""" + assert ToTosaMemoryFormatPass._is_nop_transpose([1, 1, 1, 1], (3, 2, 1, 0)) + + +def test_is_nop_transpose_identity(): + """Identity permutation is always a NOP.""" + assert ToTosaMemoryFormatPass._is_nop_transpose([2, 3, 4], (0, 1, 2)) + + +def test_is_nop_transpose_nhwc_on_size1_spatial(): + """[1, 28, 1, 72] with channels_last (0,2,3,1): non-size-1 dims 28,72 + change relative order (28→pos3, 72→pos2) → NOT a NOP.""" + assert not ToTosaMemoryFormatPass._is_nop_transpose([1, 28, 1, 72], (0, 2, 3, 1))