Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backends/arm/_passes/decompose_grouped_conv_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,10 @@ def call_operator(self, op, args, kwargs, meta):

input_node = args[0]
if DecomposeGroupedConvPass._is_depthwise_conv(input_node, groups, transposed):
# This is a depthwise convolution which is handled elsewhere
return super().call_operator(op, args, kwargs, meta)
# Conv2D depthwise maps to TOSA DEPTHWISE_CONV2D — handled in RewriteConvPass.
# Conv3D has no DEPTHWISE_CONV3D, so fall through and decompose like grouped conv.
if len(input_node.data.shape) != 5:
return super().call_operator(op, args, kwargs, meta)

weight_node = args[1]
bias_node = args[2]
Expand Down
10 changes: 5 additions & 5 deletions backends/arm/_passes/rewrite_conv_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:

def _is_conv3d(self, rank, groups) -> bool:
if rank == 5:
# A Conv3D is considered depthwise if Group == InChannels and
# Group * N == OutChannels, where N is a possitive integer.
# Currently we do not support depthwise or grouped conv3d.
# @TODO Add grouped/depthwise conv3d support or reject in partitioner.
# Both grouped and depthwise Conv3D are decomposed into groups==1
# convolutions by DecomposeGroupedConvPass before reaching here.
# This guard is defense-in-depth for paths that bypass that pass.
if groups != 1:
raise RuntimeError(
"CONV3D with groups != 1 is not supported in the Arm backend."
"CONV3D with groups != 1 reached unexpectedly; "
"DecomposeGroupedConvPass should have decomposed it first."
)
return True
return False
Expand Down
50 changes: 39 additions & 11 deletions backends/arm/test/ops/test_conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,32 @@ def forward(self, x):
return self.conv(x)


class GroupedConv3d(torch.nn.Module):
"""Non-depthwise grouped Conv3d (in_channels != groups).

Split into ``groups`` plain convolutions by DecomposeGroupedConvPass, so it
is delegated unlike the depthwise case.

"""

def __init__(self, dtype=torch.float):
super().__init__()
self.dtype = dtype
self.conv = torch.nn.Conv3d(
in_channels=4,
out_channels=4,
kernel_size=(3, 3, 3),
padding=1,
groups=2,
).to(dtype)

def get_inputs(self):
return (torch.randn(1, 4, 8, 8, 8).to(self.dtype),)

def forward(self, x):
return self.conv(x)


conv3d_2x2_3x2x14x14_nobias = Conv3d(
in_channels=2,
out_channels=3,
Expand Down Expand Up @@ -623,19 +649,21 @@ def test_convolution_3d_tosa_INT_multi_op():


def test_convolution_3d_tosa_FP_depthwise():
"""Depthwise or Grouped Conv3d should be rejected until grouped support
exists.
"""Depthwise Conv3d should be delegated, decomposed into groups==1
convolutions by DecomposeGroupedConvPass.
"""
model = DepthwiseConv3d()
pipeline = TosaPipelineFP[input_t](
model,
model.get_inputs(),
aten_op,
exir_op,
run_on_tosa_ref_model=False,
)
with pytest.raises(RuntimeError, match="CONV3D with groups != 1"):
pipeline.run()
pipeline = TosaPipelineFP[input_t](model, model.get_inputs(), aten_op, exir_op)
pipeline.run()


def test_convolution_3d_tosa_FP_grouped():
"""Non-depthwise grouped Conv3d should be delegated, decomposed into
groups==1 convolutions by DecomposeGroupedConvPass.
"""
model = GroupedConv3d()
pipeline = TosaPipelineFP[input_t](model, model.get_inputs(), aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", test_data_INT)
Expand Down
Loading