Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunctio

class TestDeduceTypeConstraints(unittest.TestCase):
_SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN = (
"_aten_as_strided_onnx",
"_aten_unfold_onnx",
"_aten_embedding_bag_onnx",
"_aten_embedding_bag_1d_padding_idx_onnx",
Expand Down
130 changes: 69 additions & 61 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,72 +817,80 @@ def aten_argwhere(self: TensorType) -> TensorType:

@torch_op("aten::as_strided", trace_only=True)
def aten_as_strided(
    self: TTensor,
    size: Sequence[INT64],
    stride: Sequence[INT64],
    storage_offset: Optional[INT64] = None,
) -> TTensor:
    """as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)

    Reproduce the strided view as a single ``Gather`` over the flattened input.

    For an output element at position ``(i_0, ..., i_{n-1})`` the value is read
    from flat index ``storage_offset + sum_d i_d * stride[d]``. This function
    builds a tensor of those flat indices, shaped like ``size``, and gathers
    ``self`` (reshaped to 1-D) with it.

    Args:
        self: Input tensor. Its flattened form stands in for the storage, so
            the result matches ``torch.as_strided`` when ``self`` is a
            contiguous view of its whole storage.
        size: Output shape, one entry per output dimension. Entries may be
            Python ints or runtime values.
        stride: Step, in elements of the flattened input, taken along each
            output dimension. Expected to have the same length as ``size``.
        storage_offset: Flat index of the first element read. ``None`` is
            treated as ``0``.

    Returns:
        The gathered tensor, shaped according to ``size``.
    """

    # The rank is known at trace time even when individual sizes are not.
    rank = len(size)
    # `self_flatten` is the input as a 1-D tensor; Gather indexes into it.
    self_flatten = op.Reshape(self, op.Constant(value_ints=[-1]))

    # A missing storage_offset means "start at the beginning of the storage".
    if storage_offset is None:
        storage_offset = 0

    if (
        all(isinstance(s, int) for s in size)
        and all(isinstance(s, int) for s in stride)
        and isinstance(storage_offset, int)
    ):
        # Static fast path: every size/stride/offset is a Python int, so the
        # full index tensor is computed with NumPy and emitted as one constant.
        # Start from the storage_offset; per-dimension contributions are added in.
        indices = np.array(storage_offset, dtype=np.int64)
        for dim, (dim_size, dim_stride) in enumerate(zip(size, stride)):
            # Contribution of dimension `dim`: index i_dim adds i_dim * stride[dim].
            add_value = np.arange(dim_size, dtype=np.int64) * dim_stride
            # Reshape the 1-D contribution so it broadcasts along `dim` only
            # (length dim_size at position `dim`, length 1 everywhere else);
            # the running sum then builds the full n-D index grid.
            broadcast_shape = [1] * rank
            broadcast_shape[dim] = dim_size
            indices = indices + add_value.reshape(broadcast_shape)
        # `indices` now has shape `size`; gathering yields the strided view.
        return op.Gather(self_flatten, op.Constant(value=ir.tensor(indices)))

    # Dynamic path: at least one value is not a Python int, so the index tensor
    # is built with ONNX ops that mirror the NumPy math above. The loop below is
    # a Python loop over the static rank, so it is unrolled while tracing.
    zero = op.Constant(value_int=0)
    one = op.Constant(value_int=1)
    # Reshaping to `empty_shape` (shape []) turns a value into a 0-D scalar.
    empty_shape = op.Constant(value=ir.tensor(np.array([], dtype=np.int64)))
    # Start the running index from storage_offset as a scalar; runtime SymInt
    # values are assumed to be INT64.
    indices = op.Reshape(storage_offset, empty_shape)
    for dim in range(rank):
        # Reshape this dimension's size and stride to scalars.
        dim_size = op.Reshape(size[dim], empty_shape)
        dim_stride = op.Reshape(stride[dim], empty_shape)
        # add_value = arange(dim_size) * dim_stride: a 1-D tensor holding the
        # offsets contributed by index 0..dim_size-1 along `dim`.
        add_value = op.Mul(op.Range(zero, dim_size, one), dim_stride)
        # Insert singleton axes everywhere except `dim` so the contribution
        # broadcasts along `dim` only, matching `reshape(broadcast_shape)` in
        # the static path. For rank 1 there is nothing to unsqueeze.
        unsqueeze_axes = [axis for axis in range(rank) if axis != dim]
        if unsqueeze_axes:
            add_value = op.Unsqueeze(add_value, op.Constant(value_ints=unsqueeze_axes))
        indices = op.Add(indices, add_value)

    # `indices` now has shape `size`; gathering yields the strided view.
    return op.Gather(self_flatten, indices)


def aten_as_strided_copy(
Expand Down
74 changes: 74 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,80 @@ def forward(self, x):
got = onnx_program.call_reference({"x": inputs[0]})
torch.testing.assert_close(expected, got[0])

def test_aten_as_strided_static_multi_dim(self):
    """Export a 2-D ``as_strided`` call whose size, stride and offset are constants.

    The module reads a (2, 3) view with strides (4, 1) starting at storage
    offset 2 from a 4x6 float32 input, and the exported program is handed to
    ``_testing.assert_onnx_program``.
    """

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.as_strided(x, (2, 3), (4, 1), 2)

    # 24 consecutive values arranged as a 4x6 matrix.
    sample = torch.arange(24, dtype=torch.float32).reshape(4, 6)
    exported = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_static_single_dim(self):
    """Export a 1-D ``as_strided`` call with constant size and stride and no offset.

    The module takes 4 elements with stride 2 from a 12-element float32 vector;
    ``storage_offset`` is left at its default. The exported program is handed to
    ``_testing.assert_onnx_program``.
    """

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.as_strided(x, (4,), (2,))

    sample = torch.arange(12, dtype=torch.float32)
    exported = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_static_overlapping(self):
    """Export an ``as_strided`` call whose output elements overlap in storage.

    With size (3, 3) and strides (1, 1), output position (i, j) reads flat
    index i + j, so several output elements share the same input element.
    The exported program is handed to ``_testing.assert_onnx_program``.
    """

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.as_strided(x, (3, 3), (1, 1))

    sample = torch.arange(10, dtype=torch.float32)
    exported = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_static_scalar(self):
    """Export an ``as_strided`` call with empty size and stride (rank-0 output).

    The module requests a 0-dimensional view at storage offset 3 of a
    12-element float32 vector. The exported program is handed to
    ``_testing.assert_onnx_program``.
    """

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.as_strided(x, (), (), 3)

    sample = torch.arange(12, dtype=torch.float32)
    exported = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_dynamic_size(self):
    """Export an ``as_strided`` call whose first output size depends on the input length.

    Dimension 0 of the input is declared dynamic under the name "length", and
    the module derives the first output size from ``x.shape[0] - 1``; this is
    intended to exercise the path where sizes are not plain constants. The
    exported program is handed to ``_testing.assert_onnx_program``.
    """

    class Model(torch.nn.Module):
        def forward(self, x):
            n = x.shape[0] - 1
            return torch.as_strided(x, (n, 2), (1, 1))

    sample = torch.arange(12, dtype=torch.float32)
    export_options = {
        "dynamic_shapes": ({0: "length"},),
        "dynamo": True,
        "verbose": False,
    }
    exported = torch.onnx.export(Model(), (sample,), **export_options)
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_dynamic_size_with_offset(self):
    """Export a 1-D ``as_strided`` call with an input-dependent size and a constant offset.

    Dimension 0 of the input is declared dynamic under the name "length"; the
    module takes ``x.shape[0] - 2`` elements with stride 1 starting at storage
    offset 1. The exported program is handed to
    ``_testing.assert_onnx_program``.
    """

    class Model(torch.nn.Module):
        def forward(self, x):
            n = x.shape[0] - 2
            return torch.as_strided(x, (n,), (1,), 1)

    sample = torch.arange(12, dtype=torch.float32)
    export_options = {
        "dynamic_shapes": ({0: "length"},),
        "dynamo": True,
        "verbose": False,
    }
    exported = torch.onnx.export(Model(), (sample,), **export_options)
    _testing.assert_onnx_program(exported)


# Run this module's tests with the unittest runner when executed as a script.
if __name__ == "__main__":
    unittest.main()
Loading