Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions backends/xnnpack/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,11 @@ def define_node(
stride = cast(List[int], node.args[3])
padding = cast(List[int], node.args[4])
dilation = cast(List[int], node.args[5])
output_padding = cast(List[int], node.args[7])
if len(padding) == 1:
padding = padding + padding

# args[7] = output padding
check_or_raise(
all(out_pad == 0 for out_pad in cast(List[int], node.args[7])),
"XNNPACK does not support output padding",
)
if len(output_padding) == 1:
output_padding = output_padding + output_padding

check_or_raise(
len(stride) == 2, "XNNPACK currently only supports 2D convolution"
Expand All @@ -165,8 +162,8 @@ def define_node(
kwargs["group_input_channels"] = group_input_channels
kwargs["group_output_channels"] = group_output_channels
kwargs["groups"] = groups
kwargs["adjustment_height"] = 0
kwargs["adjustment_width"] = 0
kwargs["adjustment_height"] = output_padding[0]
kwargs["adjustment_width"] = output_padding[1]
kwargs["flags"] = 0

if is_depthwise_conv:
Expand Down
12 changes: 0 additions & 12 deletions backends/xnnpack/partition/config/gemm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,18 +382,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
)
return False

# XNNPACK does not support non-zero output padding in transposed
# convolutions.
if is_transpose and any(
out_pad != 0 for out_pad in cast(List[int], node.args[7])
):
why(
node,
"XNNPACK does not support transposed convolutions with"
"non-zero output padding",
)
return False

if (
is_transpose
and weight_quant_params is not None
Expand Down
194 changes: 176 additions & 18 deletions backends/xnnpack/test/ops/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,17 +656,20 @@ def get_inputs(self):
conv_count=1,
)

def test_padded_output_tconv(self):
class TConv2d(torch.nn.Module):
def __init__(self):
def test_fp32_tconv_output_padding(self):
    """Test fp32 transposed convolution with non-zero output padding.

    Exercises both asymmetric and symmetric ``output_padding`` values to
    confirm the transposed conv is lowered (the companion op_conv2d change
    in this diff maps output padding onto XNNPACK's
    ``adjustment_height`` / ``adjustment_width`` kwargs).
    """

    class TConv2dOutputPadding(torch.nn.Module):
        def __init__(self, output_padding):
            super().__init__()
            # Flag read by the test harness to treat this as a transposed conv.
            self.transpose = True
            self.conv = torch.nn.ConvTranspose2d(
                in_channels=2,
                out_channels=1,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
                output_padding=output_padding,
                dilation=(1, 1),
                groups=1,
                bias=True,
            ).to(torch.float)

        def forward(self, x):
            return self.conv(x)

        def get_inputs(self):
            return (torch.randn(1, 2, 8, 8),)

    # Test asymmetric output padding (0, 1)
    self._test(TConv2dOutputPadding(output_padding=(0, 1)))

    # Test symmetric output padding (1, 1)
    self._test(TConv2dOutputPadding(output_padding=(1, 1)))
def test_qs8_tconv_output_padding(self):
    """Test quantized transposed convolution with non-zero output padding.

    Same module shape as the fp32 variant, but lowered through the
    symmetric quantization config so the qs8 path is exercised too.
    """

    class TConv2dOutputPadding(torch.nn.Module):
        def __init__(self, output_padding):
            super().__init__()
            # Flag read by the test harness to treat this as a transposed conv.
            self.transpose = True
            self.conv = torch.nn.ConvTranspose2d(
                in_channels=2,
                out_channels=1,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
                output_padding=output_padding,
                dilation=(1, 1),
                groups=1,
                bias=True,
            ).to(torch.float)

        def forward(self, x):
            return self.conv(x)

        def get_inputs(self):
            return (torch.randn(1, 2, 8, 8),)

    # Test asymmetric output padding (0, 1) with quantization
    self._test(
        TConv2dOutputPadding(output_padding=(0, 1)),
        quant_config=get_symmetric_quantization_config(),
    )

    # Test symmetric output padding (1, 1) with quantization
    self._test(
        TConv2dOutputPadding(output_padding=(1, 1)),
        quant_config=get_symmetric_quantization_config(),
    )

def test_fp32_tconv_output_padding_large_stride(self):
    """Test transposed convolution with larger output padding and stride values.

    PyTorch requires output_padding < stride (or dilation) per dimension,
    so (3, 3) is the maximum valid output padding for stride 4.
    """

    class TConv2dLargeOutputPadding(torch.nn.Module):
        def __init__(self, stride, output_padding):
            super().__init__()
            # Flag read by the test harness to treat this as a transposed conv.
            self.transpose = True
            self.conv = torch.nn.ConvTranspose2d(
                in_channels=8,
                out_channels=16,
                kernel_size=(5, 5),
                stride=stride,
                padding=(2, 2),
                output_padding=output_padding,
                dilation=(1, 1),
                groups=1,
                bias=True,
            ).to(torch.float)

        def forward(self, x):
            return self.conv(x)

        def get_inputs(self):
            return (torch.randn(2, 8, 16, 16),)

    # Test with stride=4 and output_padding=(3, 3) - maximum valid for stride 4
    self._test(TConv2dLargeOutputPadding(stride=(4, 4), output_padding=(3, 3)))

    # Test with stride=3 and asymmetric output_padding=(2, 1)
    self._test(TConv2dLargeOutputPadding(stride=(3, 3), output_padding=(2, 1)))

    # Test with asymmetric stride and output_padding
    self._test(TConv2dLargeOutputPadding(stride=(4, 3), output_padding=(3, 2)))

def test_fp32_tconv_output_padding_various_shapes(self):
    """Test transposed convolution with output padding on various input shapes.

    Covers large kernels, rectangular kernels, asymmetric stride/padding,
    small spatial inputs with maximal output padding, and batch size > 1.
    """

    class TConv2dVariousShapes(torch.nn.Module):
        def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            height,
            width,
        ):
            super().__init__()
            # Flag read by the test harness to treat this as a transposed conv.
            self.transpose = True
            # Spatial dims retained so get_inputs can build a matching input.
            self.height = height
            self.width = width
            self.in_channels = in_channels
            self.conv = torch.nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                dilation=(1, 1),
                groups=1,
                bias=True,
            ).to(torch.float)

        def forward(self, x):
            return self.conv(x)

        def get_inputs(self):
            return (torch.randn(1, self.in_channels, self.height, self.width),)

    # Test with larger kernel (7x7), stride=2, output_padding=1
    self._test(
        TConv2dVariousShapes(
            in_channels=3,
            out_channels=32,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            output_padding=(1, 1),
            height=32,
            width=32,
        )
    )

    # Test with rectangular kernel and asymmetric output padding
    self._test(
        TConv2dVariousShapes(
            in_channels=16,
            out_channels=8,
            kernel_size=(3, 5),
            stride=(2, 3),
            padding=(1, 2),
            output_padding=(1, 2),
            height=24,
            width=32,
        )
    )

    # Test with small spatial dimensions but larger output padding
    self._test(
        TConv2dVariousShapes(
            in_channels=4,
            out_channels=4,
            kernel_size=(4, 4),
            stride=(4, 4),
            padding=(0, 0),
            output_padding=(3, 3),
            height=4,
            width=4,
        )
    )

    # Test with batch size > 1 and asymmetric everything
    model = TConv2dVariousShapes(
        in_channels=6,
        out_channels=12,
        kernel_size=(5, 3),
        stride=(3, 2),
        padding=(2, 1),
        output_padding=(2, 1),
        height=20,
        width=15,
    )
    # Override get_inputs (instance attribute shadows the method) to use a
    # larger batch than the per-instance height/width default of batch 1.
    model.get_inputs = lambda: (torch.randn(4, 6, 20, 15),)
    self._test(model)

def test_dq_conv2d(self) -> None:
model = Conv2d(
in_channels=3,
Expand Down
Loading