Skip to content

Commit 7f039b2

Browse files
3l1facebook-github-bot
authored andcommitted
Eliminate redundant NCHW↔NHWC permute_copy and NHWC-safe view_copy transposes in ToTosaMemoryFormatPass (#18167)
Summary: Two optimizations in ToTosaMemoryFormatPass to reduce TOSA TRANSPOSE nodes: 1. NHWC-safe reshape detection: When a 4D→4D view_copy has monotonic shape_indices on the raw shapes and preserves the last dimension (NHWC channel), skip inserting input/output transposes. The view_copy can operate directly on NHWC data. 2. Redundant permute_copy elimination: Model-level permute_copy ops whose permutation matches channels_last_order (NCHW→NHWC) or its inverse (NHWC→NCHW) are redundant with the tosa_dim_order annotation that already handles format conversion. Replace them with view_copy (identity reshape) to avoid generating TOSA TRANSPOSE nodes. Handles both 4D (rank>=4, sr>=2) and 3D (rank>=3, sr>=1) permutations. Reviewed By: digantdesai Differential Revision: D96432610
1 parent 4f900b2 commit 7f039b2

3 files changed

Lines changed: 312 additions & 3 deletions

File tree

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 157 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,21 @@ def get_batch_prod_dim(shape, spatial_rank):
205205

206206
return (N_old != N_new) or (C_old != C_new)
207207

208+
@staticmethod
209+
def _is_nop_transpose(shape, perm) -> bool:
210+
"""Return ``True`` when a transpose only permutes size-1 dimensions.
211+
212+
A transpose is a NOP (no-operation) when the relative order of
213+
all non-size-1 dimensions is unchanged — permuting size-1 dims
214+
does not alter the physical byte layout.
215+
216+
Example: ``[14, 72, 1, 1]`` with perm ``(0, 1, 3, 2)`` → True
217+
(only the two trailing size-1 dims swap).
218+
"""
219+
old_order = [i for i, s in enumerate(shape) if s != 1]
220+
new_order = [i for i, s in zip(perm, [shape[p] for p in perm]) if s != 1]
221+
return old_order == new_order
222+
208223
@staticmethod
209224
def insert_input_transpose(node, input_node, graph_module):
210225
"""Ensure an input tensor is converted to channels-last ordering by
@@ -254,7 +269,7 @@ def insert_output_transpose(node, graph_module):
254269
# Guard: mem_format must be a true permutation for the current rank
255270
assert sorted(mem_format) == list(
256271
range(rank)
257-
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
272+
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"
258273

259274
with graph_module.graph.inserting_after(node):
260275
permute_node = create_node(
@@ -279,6 +294,104 @@ def insert_output_transpose(node, graph_module):
279294
for user in users:
280295
user.replace_input_with(node, permute_node)
281296

297+
@staticmethod
298+
def _get_shape_indices(
299+
src_shape: list[int], tgt_shape: list[int]
300+
) -> list[list[int]] | None:
301+
"""Greedy dimension matching for reshape operations.
302+
303+
For each target dimension, greedily consumes contiguous source
304+
dimensions whose product equals the target size. Size-1 target
305+
dimensions that do not correspond to any source dimension produce
306+
empty index lists (inserted dims).
307+
308+
Returns ``None`` when no valid mapping exists.
309+
"""
310+
src_idx = 0
311+
result: list[list[int]] = []
312+
313+
for tgt_dim in tgt_shape:
314+
if tgt_dim <= 0:
315+
return None
316+
317+
indices: list[int] = []
318+
remaining = tgt_dim
319+
320+
while src_idx < len(src_shape) and remaining % src_shape[src_idx] == 0:
321+
indices.append(src_idx)
322+
remaining //= src_shape[src_idx]
323+
src_idx += 1
324+
if remaining == 1:
325+
break
326+
327+
if remaining != 1:
328+
return None
329+
330+
result.append(indices)
331+
332+
if src_idx != len(src_shape):
333+
return None
334+
335+
return result
336+
337+
@staticmethod
338+
def _is_monotonic(indices: list[list[int]]) -> bool:
339+
"""Return ``True`` when all non-empty index groups are strictly
340+
ordered — i.e. each group's indices follow the previous group's.
341+
"""
342+
last_max = -1
343+
for group in indices:
344+
if not group:
345+
continue
346+
if group[0] <= last_max:
347+
return False
348+
last_max = group[-1]
349+
return True
350+
351+
@staticmethod
352+
def _is_nhwc_safe_reshape(
353+
input_shape, output_shape, input_sr, output_sr # noqa: ARG004
354+
) -> bool:
355+
"""Detect whether a 4-D+ reshape can operate directly on NHWC data.
356+
357+
By the time ``ToTosaMemoryFormatPass`` runs, 4-D tensor shapes in
358+
``meta["val"]`` are already in NHWC physical order (the channel
359+
dimension sits at position ``rank - spatial_rank - 1``, not at
360+
position 1 as in NCHW). We therefore check the shape indices on
361+
the **raw** input/output shapes — no extra permutation is needed.
362+
363+
Returns ``True`` when:
364+
1. The reshape has monotonic shape_indices (each output dim maps
365+
to a contiguous, in-order group of input dims), AND
366+
2. The channel dimension is preserved alone (not merged with
367+
spatial dims).
368+
"""
369+
rank_in = len(input_shape)
370+
rank_out = len(output_shape)
371+
if rank_in < 4 or rank_out < 4:
372+
return False
373+
374+
indices = ToTosaMemoryFormatPass._get_shape_indices(
375+
list(input_shape), list(output_shape)
376+
)
377+
if indices is None:
378+
return False
379+
380+
if not ToTosaMemoryFormatPass._is_monotonic(indices):
381+
return False
382+
383+
# In the TOSA pipeline the physical memory order is NHWC.
384+
# The channel dimension in NHWC is always the **last** axis
385+
# (position ``rank - 1``). It must appear *alone* in its
386+
# output group — if it is merged with spatial dims the reshape
387+
# would reorder channel data and the optimisation is invalid.
388+
channel_idx = rank_in - 1
389+
for group in indices:
390+
if channel_idx in group:
391+
return len(group) == 1
392+
# Channel dim not consumed by any group — conservative reject.
393+
return False
394+
282395
@staticmethod
283396
def _insert_view_transpose(
284397
input_shape, output_shape, node, input_node, graph_module
@@ -300,6 +413,14 @@ def _insert_view_transpose(
300413
output_sr,
301414
)
302415

416+
# When the NHWC-space reshape has monotonic shape_indices the
417+
# view_copy can operate directly on NHWC data — no transposes
418+
# are needed.
419+
if channel_reshape and ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
420+
input_shape, output_shape, input_sr, output_sr
421+
):
422+
return
423+
303424
if (
304425
channel_reshape or nhwc_to_nchw
305426
) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr):
@@ -328,10 +449,44 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
328449
- 1D/2D tensors
329450
330451
"""
331-
for node in graph_module.graph.nodes:
452+
for node in list(graph_module.graph.nodes):
332453
if node.op != "call_function":
333454
continue
334455

456+
# Eliminate model-level permute_copy ops that are redundant
457+
# with the tosa_dim_order annotation. When a permute_copy's
458+
# permutation matches the channels-last order (or its
459+
# inverse), the permute does the same NCHW↔NHWC conversion
460+
# that tosa_dim_order already handles — keeping both would
461+
# double-convert. Replace with view_copy (identity reshape).
462+
if node.target in (
463+
exir_ops.edge.aten.permute_copy.default,
464+
exir_ops.edge.aten.permute.default,
465+
):
466+
perm = list(node.args[1])
467+
rank = len(perm)
468+
sr = node.meta.get("tosa_spatial_rank", 0)
469+
470+
if rank >= 3 and sr >= 1:
471+
cl_order = list(
472+
self._channels_last_order(rank, sr)
473+
)
474+
cl_inv = list(
475+
self._channels_last_inverse_order(rank, sr)
476+
)
477+
if perm == cl_order or perm == cl_inv:
478+
input_node = node.args[0]
479+
output_shape = list(node.meta["val"].shape)
480+
with graph_module.graph.inserting_before(node):
481+
view_node = graph_module.graph.call_function(
482+
exir_ops.edge.aten.view_copy.default,
483+
(input_node, output_shape),
484+
)
485+
view_node.meta = dict(node.meta)
486+
node.replace_all_uses_with(view_node)
487+
graph_module.graph.erase_node(node)
488+
continue
489+
335490
# Transpose views
336491
elif node.target == exir_ops.edge.aten.view_copy.default:
337492
input_node = node.args[0]

backends/arm/arm_vela.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def vela_bin_pack_io(prefix, data):
6363
def vela_compile(
6464
tosa_flatbuffer: bytes,
6565
args: List[str],
66-
verbose: bool = False,
66+
verbose: bool = True,
6767
intermediate_path: str | None = None,
6868
):
6969
"""Compile a TOSA graph to a binary stream for ArmBackendEthosU using

backends/arm/test/passes/test_to_tosa_memory_format.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,76 @@ def get_inputs(self) -> input_t:
177177
return (torch.rand(4, 4, 4, 4),)
178178

179179

180+
class NHWCSafeSpatialMerge(torch.nn.Module):
181+
"""Test-module with a 4D->4D reshape that merges spatial dims H*W while
182+
preserving the last-dim channel.
183+
184+
For models with view_copy shapes [1,2,14,72]->[1,28,1,72] where C=2
185+
sits at NCHW position 1 and the last dim (72) is the NHWC channel that gets
186+
preserved. ``_is_nhwc_safe_reshape`` detects that shape_indices on the raw
187+
shapes are monotonic with the last dim alone, so no transposes are inserted
188+
around the view_copy.
189+
190+
Setup: conv2d (forces NHWC, C=2) -> view_copy -> add (keeps in NHWC).
191+
"""
192+
193+
ops_before_pass: Dict[str, int] = {}
194+
# Only the 2 I/O transposes for the conv, NO extra transposes from view_copy
195+
ops_after_pass: Dict[str, int] = {
196+
"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2
197+
}
198+
ops_not_after_pass: List[str] = []
199+
200+
def __init__(self):
201+
super().__init__()
202+
self.conv = torch.nn.Conv2d(
203+
in_channels=2, out_channels=2, kernel_size=1, bias=False
204+
)
205+
206+
def forward(self, x: torch.Tensor) -> torch.Tensor:
207+
x = self.conv(x) # forces NHWC path; output [1, 2, 14, 72]
208+
x = x.view(1, 28, 1, 72) # spatial merge: H*W=2*14->28, last dim 72 preserved
209+
return x + x # keep result 4-D in NHWC
210+
211+
def get_inputs(self) -> input_t:
212+
return (torch.randn(1, 2, 14, 72),)
213+
214+
215+
class NHWCUnsafeChannelChange(torch.nn.Module):
216+
"""Test-module with a 4D->4D reshape that is NOT NHWC-safe because the
217+
target shape cannot be produced by a monotonic merge of NHWC input dims.
218+
The pass MUST still insert transposes around the view_copy.
219+
"""
220+
221+
ops_before_pass: Dict[str, int] = {}
222+
# conv I/O transposes (2) + view_copy transposes (2) = 4
223+
ops_after_pass: Dict[str, int] = {
224+
"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4
225+
}
226+
ops_not_after_pass: List[str] = []
227+
228+
def __init__(self):
229+
super().__init__()
230+
self.conv = torch.nn.Conv2d(
231+
in_channels=72, out_channels=72, kernel_size=1, bias=False
232+
)
233+
234+
def forward(self, x: torch.Tensor) -> torch.Tensor:
235+
x = self.conv(x) # output [1, 72, 2, 14]
236+
x = x.view(1, 14, 2, 72) # not NHWC-safe (channels shuffled)
237+
return x + x
238+
239+
def get_inputs(self) -> input_t:
240+
return (torch.randn(1, 72, 2, 14),)
241+
242+
180243
modules: Dict[str, ModuleMetadata] = {
181244
"no_nhwc": NoNHWC(),
182245
"parallel_clusters": ParallelClusters(),
183246
"serial_clusters": SerialClusters(),
184247
"reshapes": Reshapes(),
248+
"nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(),
249+
"nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(),
185250
}
186251

187252

@@ -209,3 +274,92 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
209274
module_nn = cast(torch.nn.Module, module)
210275
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
211276
pipeline.run()
277+
278+
279+
# --- Direct unit tests for NHWC-safe reshape helpers ---
280+
281+
282+
def test_get_shape_indices_spatial_merge():
283+
"""[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C."""
284+
indices = ToTosaMemoryFormatPass._get_shape_indices(
285+
[1, 2, 14, 72], [1, 28, 1, 72]
286+
)
287+
assert indices == [[0], [1, 2], [], [3]]
288+
289+
290+
def test_get_shape_indices_identity():
291+
"""Same shape => each dim maps to itself."""
292+
indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [2, 3, 4])
293+
assert indices == [[0], [1], [2]]
294+
295+
296+
def test_get_shape_indices_full_merge():
297+
"""[2, 3, 4] -> [24]: merge all dims into one."""
298+
indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24])
299+
assert indices == [[0, 1, 2]]
300+
301+
302+
def test_get_shape_indices_incompatible():
303+
"""Sizes that don't divide => None."""
304+
indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4])
305+
assert indices is None
306+
307+
308+
def test_get_shape_indices_size_one_insert():
309+
"""[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle."""
310+
indices = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4])
311+
assert indices is not None
312+
assert indices == [[0], [], [1]]
313+
314+
315+
def test_is_monotonic_true():
316+
assert ToTosaMemoryFormatPass._is_monotonic([[0], [1, 2], [], [3]])
317+
assert ToTosaMemoryFormatPass._is_monotonic([[0], [], [1], [2, 3]])
318+
assert ToTosaMemoryFormatPass._is_monotonic([[], [0, 1, 2]])
319+
320+
321+
def test_is_monotonic_false():
322+
assert not ToTosaMemoryFormatPass._is_monotonic([[1], [0]])
323+
assert not ToTosaMemoryFormatPass._is_monotonic([[0, 2], [1]])
324+
325+
326+
def test_is_nhwc_safe_forward():
327+
"""Shapes already NHWC by the time the pass runs.
328+
[1,2,14,72] -> [1,28,1,72], sr=2 -> NHWC-safe (spatial merge, C=72 preserved).
329+
"""
330+
assert ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
331+
[1, 2, 14, 72], [1, 28, 1, 72], input_sr=2, output_sr=2
332+
)
333+
334+
335+
def test_is_nhwc_safe_non_4d():
336+
"""Reshapes below rank 4 are never NHWC-safe."""
337+
assert not ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
338+
[6, 4], [24], input_sr=0, output_sr=0
339+
)
340+
341+
342+
def test_is_nop_transpose_size1_swap():
343+
"""[14, 72, 1, 1] with perm (0, 1, 3, 2) only swaps trailing size-1 dims."""
344+
assert ToTosaMemoryFormatPass._is_nop_transpose([14, 72, 1, 1], (0, 1, 3, 2))
345+
346+
347+
def test_is_nop_transpose_real_reorder():
348+
"""[14, 72, 1, 1] with perm (1, 0, 2, 3) swaps non-size-1 dims."""
349+
assert not ToTosaMemoryFormatPass._is_nop_transpose([14, 72, 1, 1], (1, 0, 2, 3))
350+
351+
352+
def test_is_nop_transpose_all_size1():
353+
"""[1, 1, 1, 1] with any perm is always a NOP."""
354+
assert ToTosaMemoryFormatPass._is_nop_transpose([1, 1, 1, 1], (3, 2, 1, 0))
355+
356+
357+
def test_is_nop_transpose_identity():
358+
"""Identity permutation is always a NOP."""
359+
assert ToTosaMemoryFormatPass._is_nop_transpose([2, 3, 4], (0, 1, 2))
360+
361+
362+
def test_is_nop_transpose_nhwc_on_size1_spatial():
363+
"""[1, 28, 1, 72] with channels_last (0,2,3,1): non-size-1 dims 28,72
364+
change relative order (28→pos3, 72→pos2) → NOT a NOP."""
365+
assert not ToTosaMemoryFormatPass._is_nop_transpose([1, 28, 1, 72], (0, 2, 3, 1))

0 commit comments

Comments
 (0)