Migrate graph surgeries to onnx_ir + onnxscript rewriter (batch 1: infra + first surgeries)#2550
Migrate graph surgeries to onnx_ir + onnxscript rewriter (batch 1: infra + first surgeries)#2550justinchuby wants to merge 13 commits into
Conversation
Introduce a RewriteRuleSurgeon base class that lets graph surgeries be expressed as onnxscript rewrite rules over the ONNX IR model, instead of manual protobuf / OnnxDAG manipulation. Subclasses implement rules() returning a RewriteRuleSet; the base applies them via call_ir, so the rewriter handles operand commutativity, use-count bookkeeping, and dead-node cleanup. Port the first two pattern-based surgeries to this base: - ReciprocalMulToDiv: a * Reciprocal(x) -> Div(a, x) (commute=True covers both operand orders). - ReplaceErfWithTanh: Erf(x) -> Tanh(x * 605/503), emitting the scale as an initializer of the input's floating-point dtype. This is the first batch of an incremental migration of graph_surgeries.py off the protobuf/OnnxDAG approach; subsequent batches will port the remaining pattern-based surgeries (Gemm<->MatMul+Add, QDQ, RMSNorm variants, decompositions, ...) and move the whole-graph surgeries to plain onnx_ir. Update the ReplaceErfWithTanh test to read the scale via numpy_helper.to_array so it is agnostic to raw_data vs float_data tensor storage. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
This PR is the first batch of an incremental refactor of GraphSurgeries patterns from manual ONNX proto/DAG manipulation to ONNX IR (onnx_ir) plus onnxscript’s pattern rewriter, with a small test adjustment to accommodate IR-emitted initializers.
Changes:
- Introduces a
RewriteRuleSurgeonbase class to implement local graph surgeries asonnxscript.rewriter.patternrewrite rule sets applied on the IR model. - Ports two surgeries to the rewrite-rule approach:
ReplaceErfWithTanhandReciprocalMulToDiv. - Updates the
ReplaceErfWithTanhunit test to read initializer values vianumpy_helper.to_array(works for bothraw_dataandfloat_datastorage).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
olive/passes/onnx/graph_surgeries.py |
Adds rewrite-rule surgeon infrastructure and migrates two surgeries to ONNX IR + onnxscript rewriter. |
test/passes/onnx/test_graph_surgeries.py |
Makes the scale-constant assertion robust to initializer serialization format (raw_data vs float_data). |
| # ir.DataType -> numpy dtype for the emitted scale initializer. | ||
| _DTYPE_MAP: ClassVar[dict] = { | ||
| ir.DataType.FLOAT: np.float32, | ||
| ir.DataType.FLOAT16: np.float16, | ||
| ir.DataType.DOUBLE: np.float64, |
There was a problem hiding this comment.
Restored BFLOAT16 support (via ml_dtypes), so the scale is emitted in the input's floating-point dtype as documented. Verified a bf16 Erf lowers to Mul+Tanh with a bf16 scale. Fixed in 838de67.
Convert RemoveGidxFromMatMulNBits from protobuf iteration to an onnx_ir call_ir implementation: drop a sorted (identity-permutation) g_idx input via resize_inputs and prune the now-unused g_idx initializer. Behavior unchanged. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
- InferShapes: delegate to onnx_ir ShapeInferencePass. - RemoveShapes: clear type/shape on intermediate values (empties value_info). - RemoveInputs: drop named graph inputs and their node references via onnx_ir, removing nodes left with no inputs. Behavior unchanged; verified by existing surgery tests. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Rebuild ZeroOutInput on the onnx_ir API: read the target input's shape/dtype from the IR value, emit a zero Constant, and rewire the node input. Update the test to read the constant via numpy_helper.to_array (IR stores tensors as raw_data). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Reimplement RemoveMemcpy on onnx_ir: bypass 1-in/1-out MemcpyToHost/MemcpyFromHost nodes via Value.replace_all_uses_with (which follows consumers into subgraphs), recurse into Loop/If/Scan subgraphs, preserve public output names on the output boundary, and re-order with TopologicalSortPass. Replaces ~185 lines of manual proto bypass/rename/topo-sort logic with ~40. Behavior verified by the 4 existing RemoveMemcpy tests. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Rewrite ReplaceAttentionMaskValue on onnx_ir: clamp below-threshold entries in float Constant/ConstantOfShape node values and initializers whose consumers are all mask-compatible ops. Behavior unchanged; verified by the existing test. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert RMSNorm, SimplifiedLayerNorm, and Pow/ReduceSum norm graph surgeries to mutate onnx_ir directly while preserving weight scaling, all-ones weights, and ReduceMean opset handling. Full graph surgery tests pass and lint is clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert the shape-dependent MatMul/Add and Gemm rewrites to the onnx_ir Surgeon path while preserving reshape, Relu, and transB handling. Targeted and full graph surgery tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert the GQA RoPE cache, attention-mask sequence length, and quantized-output exposure surgeries to operate through onnx_ir. Behavior is unchanged; graph surgery tests and lint pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Convert both graph surgeons to implement call_ir using onnx_ir while preserving their quantized initializer creation, shared-weight rewiring, output-name handling, and cleanup behavior. Verified with ad-hoc tiny ONNX models for both surgeries plus the existing graph_surgeries pytest module. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
| # value_info is emitted for intermediate values that carry a type/shape; | ||
| # clearing those on non-output node results empties graph.value_info. | ||
| graph_outputs = set(model.graph.outputs) | ||
| for node in model.graph: |
There was a problem hiding this comment.
use model.graph.all_nodes()
There was a problem hiding this comment.
Done — RemoveShapes now iterates model.graph.all_nodes() so value_info is cleared in subgraphs too, while preserving every graph's declared outputs. Fixed in 838de67.
| scale_initializer = quantized_node.inputs[1] | ||
| if scale_initializer is None or scale_initializer.name not in graph.initializers: | ||
| raise ValueError(f"Scale initializer '{quantized_node.inputs[1].name}' not found.") | ||
| scale_value = scale_initializer.const_value.numpy()[0] | ||
| self._add_scale(graph, scale_value) | ||
|
|
||
| zero_point_initializer = quantized_node.inputs[2] | ||
| if zero_point_initializer is None or zero_point_initializer.name not in graph.initializers: | ||
| raise ValueError(f"Zero point initializer '{quantized_node.inputs[2].name}' not found.") | ||
| zero_point_value = zero_point_initializer.const_value.numpy()[0] | ||
| self._add_zero_point(graph, zero_point_value, zero_point_initializer.dtype) | ||
| return model |
There was a problem hiding this comment.
Guarded against missing/None scale & zero-point inputs (and fewer than 3 inputs) instead of dereferencing .name on None. Fixed in 838de67.
| # remove the Greater -> If Nodes | ||
| if_node_name = dag.get_producer(cache_names["cos_cache"]) | ||
| greater_node_name = dag.get_parents(if_node_name)[0] | ||
| dag.remove_node(if_node_name) | ||
| dag.remove_node(greater_node_name) | ||
| if_node = cache_values["cos_cache"].producer() | ||
| greater_node = next(inp.producer() for inp in if_node.inputs if inp is not None and inp.producer() is not None) | ||
| graph.remove(if_node, safe=True) | ||
| graph.remove(greater_node, safe=True) |
There was a problem hiding this comment.
Now only removes the If-condition producer when it is confirmed to be a Greater node with no remaining consumers, avoiding removing an unrelated node or raising StopIteration. Fixed in 838de67.
| # Ensure com.microsoft opset is declared | ||
| dag.set_opset_import("com.microsoft", 1) | ||
| model.opset_imports[MSFT_DOMAIN] = 1 |
There was a problem hiding this comment.
No longer downgrades an existing com.microsoft opset — bumps it up to at least 1 with max(existing, 1). Fixed in 838de67.
| # Ensure com.microsoft opset is declared | ||
| dag.set_opset_import("com.microsoft", 1) | ||
| model.opset_imports[MSFT_DOMAIN] = 1 |
There was a problem hiding this comment.
Same fix as QuantizeEmbeddingInt8 — opset is bumped up to at least 1 without downgrading. Fixed in 838de67.
| input_ids = next(graph_input for graph_input in graph.inputs if graph_input.name == "input_ids") | ||
| batch_size = input_ids.shape[0] | ||
| seq_len_shapes = {"past_seq_len": [batch_size, 1], "total_seq_len": []} |
There was a problem hiding this comment.
Defaults the batch dim to 1 when input_ids is missing or its shape is unknown (dynamic models). Fixed in 838de67.
Add regression coverage for two previously untested migrated surgeries: - PowReduceSumPowDiv2LpNorm: Pow(2)->ReduceSum->Pow(0.5)->Div collapses to LpNormalization. - QuantizeEmbeddingInt8: an embed_tokens Gather over an FP16 weight becomes an INT8 GatherBlockQuantized with a uint8 quantized table. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
- RemoveShapes: iterate model.graph.all_nodes() so value_info is cleared in subgraphs too, preserving every graph's declared outputs. - ReplaceErfWithTanh: restore BFLOAT16 support (via ml_dtypes) so the scale is emitted in the input's floating-point dtype as documented. - ExposeQuantizedOutput: guard against missing/None scale/zero-point inputs instead of dereferencing .name on None (and assuming >=3 inputs). - RemoveRopeMultiCache: only remove the If-condition producer when it is a Greater node with no remaining consumers (avoid removing an unrelated node or raising StopIteration). - QuantizeEmbeddingInt8 / ShareEmbeddingLmHead: do not downgrade an existing com.microsoft opset version; bump up to at least 1. - AttentionMaskToSequenceLengths: default batch dim to 1 when input_ids is missing or its shape is unknown (dynamic models). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Describe your changes
First batch of an incremental migration of
graph_surgeries.pyoff theprotobuf /
OnnxDAGapproach onto the ONNX IR (onnx_ir) +onnxscriptrewriter.
Infrastructure
RewriteRuleSurgeon(Surgeon)base class. Subclasses implementrules()returning an
onnxscript.rewriter.pattern.RewriteRuleSet; the base applies itto the IR model via
call_ir. This lets local subgraph pattern replacements beexpressed declaratively, and the rewriter handles operand commutativity,
use-count bookkeeping, and dead-node cleanup for us.
First surgeries ported
ReciprocalMulToDiv:a * Reciprocal(x)→Div(a, x)(commute=Truecoversboth
Muloperand orders; a sharedReciprocalis preserved automatically).ReplaceErfWithTanh:Erf(x)→Tanh(x * 605/503), emitting the scale as aninitializer of the input's floating-point dtype; non-float inputs are skipped.
This trims ~120 lines of manual proto walking. Subsequent batches will port the
remaining pattern-based surgeries (
Gemm↔MatMul+Add, QDQ passes, RMSNormvariants, decompositions, ...) and move the whole-graph surgeries (rename/expose
I/O,
Non4D*, dedup,TieWordEmbeddings, ...) to plainonnx_ir.Test change
test_replace_erf_with_tanhnow reads the scale initializer vianumpy_helper.to_arrayinstead of.float_data, so it is agnostic to whetherthe tensor is stored as
raw_dataorfloat_data(IR emitsraw_data).No behavior change for users; the two ported surgeries produce equivalent graphs.
Checklist before requesting a review
test_graph_surgeries.py: 82 passed, 2 skipped)lintrunner -a(Optional) Issue link