Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 192 additions & 3 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
import os
import platform
import warnings
from typing import Any, Collection, List, Optional, Sequence, Union
from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union

import sympy
import torch
from torch.export import ExportedProgram
from torch.export.graph_signature import InputKind
from torch.fx.node import Target
from torch.utils._sympy.numbers import int_oo
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile
Expand Down Expand Up @@ -407,6 +410,7 @@ def cross_compile_for_windows(
trt_arg_inputs,
trt_kwarg_inputs,
settings,
graph_signature=exported_program.graph_signature,
)
return trt_gm

Expand Down Expand Up @@ -793,7 +797,12 @@ def compile(
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
)
trt_gm = compile_module(
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
gm,
trt_arg_inputs,
trt_kwarg_inputs,
settings,
engine_cache,
graph_signature=exported_program.graph_signature,
)
return trt_gm

Expand Down Expand Up @@ -898,6 +907,170 @@ def _insert_complex_io_adapters(
partitioned_module.recompile()


def _build_user_symbol_bounds(
gm: torch.fx.GraphModule,
sample_arg_inputs: Sequence[Input],
sample_kwarg_inputs: dict[Any, Any],
graph_signature: torch.export.graph_signature.ExportGraphSignature,
) -> Dict[sympy.Symbol, Tuple[int, int]]:
"""Map ``sympy.Symbol -> (min, max)`` from dynamic ``Input``s, used to
fill ``Dim.DYNAMIC`` upper bounds without mutating ``ShapeEnv``.

Validates against finite exporter bounds: ``user_max > exp_max`` and
``user_min < exp_min`` raise (TRT would reject those shapes at runtime);
a strict subset narrows the engine profile to the user's bounds (info
log only); the ``user_min=1, exp_min=2`` case warns -- it's PyTorch's
0/1 specialization artifact, not a user error.
"""
all_placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
placeholders = [
p
for p, s in zip(all_placeholders, graph_signature.input_specs)
if s.kind == InputKind.USER_INPUT
]

in_spec = getattr(gm, "_in_spec", None)
assert in_spec is not None, "Exported graph module missing _in_spec"
flat_inputs = in_spec.flatten_up_to((list(sample_arg_inputs), sample_kwarg_inputs))

user_symbol_bounds: Dict[sympy.Symbol, Tuple[int, int]] = {}

for node, inp in zip(placeholders, flat_inputs):
if not (isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC):
continue
fake_val = node.meta.get("val")
if not isinstance(fake_val, torch.Tensor):
continue

min_shape = inp.shape["min_shape"]
max_shape = inp.shape["max_shape"]

if len(fake_val.size()) != len(min_shape):
raise ValueError(
f"Input '{node.target}' has {len(fake_val.size())} dimensions in "
f"the exported program, but the provided Input specifies "
f"{len(min_shape)} dimensions. Ensure Input.min_shape, "
f"Input.opt_shape, and Input.max_shape each have "
f"{len(fake_val.size())} entries."
)

for d, dim in enumerate(fake_val.size()):
if not isinstance(dim, torch.SymInt):
if min_shape[d] != dim or max_shape[d] != dim:
raise ValueError(
f"Input '{node.target}' dim {d} is static (size={int(dim)}) "
f"in the exported program, but the provided Input has "
f"min_shape[{d}]={min_shape[d]}, max_shape[{d}]={max_shape[d]}. "
f"Static dimensions must be fixed."
)
continue
expr = dim.node.expr
# Composite exprs (e.g. ``2*s0``) are recomputed by
# ``ShapeEnv.bound_sympy``; overriding them directly would lie.
if not isinstance(expr, sympy.Symbol):
logger.debug(
"Input '%s' dim %d is a composite symbolic expression (%s) "
"bounded by another dynamic dimension; its range will be "
"derived from constituent symbols via bound_sympy.",
node.target,
d,
expr,
)
continue
if expr in user_symbol_bounds:
continue
user_min = int(min_shape[d])
user_max = int(max_shape[d])
user_symbol_bounds[expr] = (user_min, user_max)
logger.debug(
"Recorded user-supplied bounds for %s: [%d, %d]",
expr,
user_min,
user_max,
)

# The exported program may already bound this symbol to a finite
# range (e.g. Dim("batch", min=10, max=20)). The compiled TRT
# engine's optimization profile follows that range; any shape
# outside it is rejected by TensorRT at runtime
# (IExecutionContext::setInputShape "satisfyProfile" check).
# Validate the user's Input range against it here -- at compile
# time -- before they hit that opaque runtime error on a shape
# they explicitly declared in Input.min_shape / Input.max_shape.
shape_env = getattr(dim.node, "shape_env", None)
if shape_env is None:
continue
exp_range = shape_env.var_to_range.get(expr)
if exp_range is None:
continue
exp_lower = exp_range.lower
exp_upper = exp_range.upper
exp_max_unbounded = exp_upper is int_oo or exp_upper == sympy.oo
if exp_max_unbounded:
# Dim.DYNAMIC: user fills the gap (intended use).
continue
try:
exp_min = int(exp_lower)
exp_max = int(exp_upper)
except (TypeError, ValueError):
continue
if user_min == exp_min and user_max == exp_max:
continue

mismatch = (
f"Dynamic dimension '{expr}': "
f"Input range [{user_min}, {user_max}] vs "
f"exported program range [{exp_min}, {exp_max}]."
)

if user_max > exp_max:
raise ValueError(
f"{mismatch} Input.max_shape ({user_max}) exceeds the "
f"exported program's max ({exp_max}). The program was "
f"exported with this dimension bounded to "
f"[{exp_min}, {exp_max}], so the compiled TensorRT engine "
f"cannot accept shapes above {exp_max}. Either re-export "
f"with Dim('{expr}', max={user_max}) or set "
f"Input.max_shape <= {exp_max}."
)

if user_min < exp_min:
# 1->2 is the 0/1 specialization artifact, not a user error.
if user_min == 1 and exp_min == 2:
logger.warning(
"%s Input.min_shape=1 but the exported program's min "
"is 2 (PyTorch 0/1 specialization -- Dim(min=1) is "
"recorded as min=2). The compiled engine's min will "
"be 2.",
mismatch,
)
continue
raise ValueError(
f"{mismatch} Input.min_shape ({user_min}) is below the "
f"exported program's min ({exp_min}). The program was "
f"exported with this dimension bounded to "
f"[{exp_min}, {exp_max}], so the compiled TensorRT engine "
f"cannot accept shapes below {exp_min}. Either re-export "
f"with Dim('{expr}', min={user_min}) or set "
f"Input.min_shape >= {exp_min}."
)

# Strict subset: engine profile narrows to the user's bounds
# (applied in ``extract_var_range_info``). Not a warning -- the
# user got exactly what they asked for.
logger.info(
"%s Narrowing engine profile to user bounds [%d, %d] "
"(exported program range was [%d, %d]).",
mismatch,
user_min,
user_max,
exp_min,
exp_max,
)

return user_symbol_bounds


@fn_supports_debugger # type: ignore[misc]
def compile_module(
gm: torch.fx.GraphModule,
Expand All @@ -906,6 +1079,9 @@ def compile_module(
settings: CompilationSettings = CompilationSettings(),
engine_cache: Optional[BaseEngineCache] = None,
*,
graph_signature: Optional[
torch.export.graph_signature.ExportGraphSignature
] = None,
_debugger_config: Optional[DebuggerConfig] = None,
) -> torch.fx.GraphModule:
"""Compile a traced FX module
Expand All @@ -929,6 +1105,17 @@ def compile_module(
if sample_kwarg_inputs is None:
sample_kwarg_inputs = {}

# Forwarded to the partitioner to fill Dim.DYNAMIC upper bounds.
# Read-only w.r.t. ShapeEnv so range_constraints survive save/re-export.
# graph_signature is None on the torch.compile path, which has no ExportedProgram.
user_symbol_bounds = (
_build_user_symbol_bounds(
gm, sample_arg_inputs, sample_kwarg_inputs, graph_signature
)
if graph_signature is not None
else {}
)

# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

Expand Down Expand Up @@ -1110,7 +1297,9 @@ def preserve_module_specs(
)

# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.construct_submodule_inputs(submodule)
submodule_inputs = partitioning.construct_submodule_inputs(
submodule, user_symbol_bounds=user_symbol_bounds
)

assert submodule_inputs is not None

Expand Down
38 changes: 32 additions & 6 deletions py/torch_tensorrt/dynamo/partitioning/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
from typing import Any, Dict, Optional, Sequence, Set, Tuple

import sympy
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily

from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.utils import (
COMPLEX_TO_REAL_DTYPE,
Expand All @@ -20,11 +20,14 @@ def construct_dynamic_input(
input_dtype: torch.dtype,
name: str = "",
is_shape_tensor: bool = False,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Input:
"""
Constructs a torch_tensorrt.Input based on a symbolic input
Args:
input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values)
user_symbol_bounds: Optional ``{sym: (min, max)}`` map; forwarded to
:func:`extract_var_range_info` to fill unbounded exporter uppers.
Returns:
A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
"""
Expand All @@ -33,7 +36,9 @@ def construct_dynamic_input(
max_shape = []
for d, dim in enumerate(input_shape):
if isinstance(dim, torch.SymInt):
min_max_opt = extract_var_range_info(dim)
min_max_opt = extract_var_range_info(
dim, user_symbol_bounds=user_symbol_bounds
)
unwrapped_min_max_opt: Dict[str, int] = {}
if "min" not in min_max_opt or min_max_opt["min"] is None:
logger.warning(
Expand Down Expand Up @@ -85,9 +90,12 @@ def get_input(
dtype: torch.dtype,
name: str = "",
is_shape_tensor: bool = False,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Input:
"""
Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs
Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs.

``user_symbol_bounds`` is forwarded to :func:`construct_dynamic_input`.
"""
if dtype in COMPLEX_TO_REAL_DTYPE:
real_dtype = COMPLEX_TO_REAL_DTYPE[dtype]
Expand All @@ -106,19 +114,25 @@ def get_input(
dtype,
name=name,
is_shape_tensor=is_shape_tensor,
user_symbol_bounds=user_symbol_bounds,
)
else:
return Input(
shape=input_shape, dtype=dtype, name=name, is_shape_tensor=is_shape_tensor
)


def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
def construct_submodule_inputs(
module: torch.fx.GraphModule,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Sequence[Input]:
"""
Construct torch_tensorrt Inputs based on the module inputs.
The module inputs will have meta data which has the shape and dtype info
Args:
module: Input FX GraphModule
user_symbol_bounds: Optional ``{sym: (min, max)}`` map; forwarded to
:func:`get_input` to fill unbounded exporter uppers.
Returns:
Sequence of torch_tensorrt.Input's representing inputs to given module
"""
Expand All @@ -134,7 +148,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
if isinstance(input_meta, (FakeTensor, torch.Tensor)):
input_shape = input_meta.size()
torchtrt_inputs.append(
get_input(input_shape, input_meta.dtype, name=input.name)
get_input(
input_shape,
input_meta.dtype,
name=input.name,
user_symbol_bounds=user_symbol_bounds,
)
)
elif isinstance(input_meta, torch.SymInt):
# Assuming sym_integers | shape inputs always have torch.int64 dtype
Expand All @@ -144,6 +163,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
torch.int64,
name=input.name,
is_shape_tensor=True,
user_symbol_bounds=user_symbol_bounds,
)
)
elif isinstance(input_meta, torch.SymFloat):
Expand All @@ -153,6 +173,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
torch.float32,
name=input.name,
is_shape_tensor=False, # Only SymInt inputs are treated as shape tensors
user_symbol_bounds=user_symbol_bounds,
)
)
else:
Expand All @@ -164,7 +185,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
input_meta = input.meta["tensor_meta"]
input_shape = input_meta.shape
torchtrt_inputs.append(
get_input(input_shape, input_meta.dtype, name=input.name)
get_input(
input_shape,
input_meta.dtype,
name=input.name,
user_symbol_bounds=user_symbol_bounds,
)
)
else:
raise AssertionError(
Expand Down
Loading
Loading