Migrate graph surgeries to onnx_ir + onnxscript rewriter (batch 1: infra + first surgeries) #2550

Open
justinchuby wants to merge 13 commits into main from justinchu/graph-surgeries-ir-rewriter
@justinchuby

Contributor

Describe your changes

First batch of an incremental migration of graph_surgeries.py off the
protobuf / OnnxDAG approach onto the ONNX IR (onnx_ir) + onnxscript
rewriter.

Infrastructure

  • Add a RewriteRuleSurgeon(Surgeon) base class. Subclasses implement rules()
    returning an onnxscript.rewriter.pattern.RewriteRuleSet; the base applies it
    to the IR model via call_ir. This lets local subgraph pattern replacements be
    expressed declaratively, and the rewriter handles operand commutativity,
    use-count bookkeeping, and dead-node cleanup for us.
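
The operand commutativity mentioned above can be pictured with a toy rewriter over plain tuples. This is only a sketch of the idea: it does not use the onnxscript API, and the names `match_mul_reciprocal` and `rewrite` are hypothetical.

```python
# Toy expression trees: ("Op", arg1, arg2, ...) tuples; leaves are strings.
# This is NOT the onnxscript rewriter, only an illustration of what
# matching a commutative operator in both operand orders means.


def match_mul_reciprocal(expr, commute=True):
    """Return (a, x) if expr is Mul(a, Reciprocal(x)), else None."""
    if not (isinstance(expr, tuple) and len(expr) == 3 and expr[0] == "Mul"):
        return None
    orders = [(expr[1], expr[2])]
    if commute:
        # Also try the swapped operand order, since Mul is commutative.
        orders.append((expr[2], expr[1]))
    for a, b in orders:
        if isinstance(b, tuple) and len(b) == 2 and b[0] == "Reciprocal":
            return a, b[1]
    return None


def rewrite(expr, commute=True):
    """Rewrite bottom-up, replacing Mul(a, Reciprocal(x)) with Div(a, x)."""
    if not isinstance(expr, tuple):
        return expr
    expr = (expr[0], *(rewrite(arg, commute) for arg in expr[1:]))
    hit = match_mul_reciprocal(expr, commute)
    if hit is not None:
        a, x = hit
        return ("Div", a, x)
    return expr
```

With `commute=False` the toy matcher only recognizes the reciprocal in the second operand position, which is the case a single hand-written pattern would otherwise miss.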

First surgeries ported

  • ReciprocalMulToDiv: a * Reciprocal(x) -> Div(a, x) (commute=True covers
    both Mul operand orders; a shared Reciprocal is preserved automatically).
  • ReplaceErfWithTanh: Erf(x) -> Tanh(x * 605/503), emitting the scale as an
    initializer of the input's floating-point dtype; non-float inputs are skipped.
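
As a quick sanity check of the Erf -> Tanh substitution, the sketch below compares `math.erf` with `tanh(x * 605/503)` on a grid. The grid range and the helper names are assumptions made for illustration; only the 605/503 scale comes from the description above.

```python
import math

SCALE = 605 / 503  # scale constant quoted in the PR description


def tanh_approx_erf(x: float) -> float:
    """Approximate erf(x) as tanh(x * SCALE)."""
    return math.tanh(x * SCALE)


def max_abs_error(lo: float = -4.0, hi: float = 4.0, steps: int = 801) -> float:
    """Largest |erf(x) - tanh(x * SCALE)| over an evenly spaced grid."""
    worst = 0.0
    for i in range(steps):
        x = lo + (hi - lo) * i / (steps - 1)
        worst = max(worst, abs(math.erf(x) - tanh_approx_erf(x)))
    return worst


if __name__ == "__main__":
    print(f"max abs error on [-4, 4]: {max_abs_error():.4f}")
```

On this grid the two curves stay within a few hundredths of each other, so the replacement is an approximation of Erf rather than an exact identity.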

This trims ~120 lines of manual proto walking. Subsequent batches will port the
remaining pattern-based surgeries (Gemm<->MatMul+Add, QDQ passes, RMSNorm
variants, decompositions, ...) and move the whole-graph surgeries (rename/expose
I/O, Non4D*, dedup, TieWordEmbeddings, ...) to plain onnx_ir.

Test change

  • test_replace_erf_with_tanh now reads the scale initializer via
    numpy_helper.to_array instead of .float_data, so it is agnostic to whether
    the tensor is stored as raw_data or float_data (IR emits raw_data).
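
The reason `.float_data` alone is fragile can be sketched without onnx installed. `FakeTensor` and `read_float32` below are hypothetical stand-ins for a float32 tensor and for a storage-agnostic reader; they only mimic the two storage fields and assume little-endian `raw_data`.

```python
import struct
from dataclasses import dataclass, field


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a float32 tensor; not the onnx TensorProto class."""

    float_data: list = field(default_factory=list)
    raw_data: bytes = b""


def read_float32(tensor: FakeTensor) -> list:
    """Return the values regardless of which storage field is populated."""
    if tensor.raw_data:
        count = len(tensor.raw_data) // 4
        # raw_data holds little-endian float32 bytes.
        return list(struct.unpack(f"<{count}f", tensor.raw_data))
    return list(tensor.float_data)


typed = FakeTensor(float_data=[1.5])
raw = FakeTensor(raw_data=struct.pack("<f", 1.5))
```

Both tensors decode to the same value through `read_float32`, while `raw.float_data` is empty, which is the situation a test that reads `.float_data` directly would trip over.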

No behavior change for users; the two ported surgeries produce equivalent graphs.

Checklist before requesting a review

  • Add unit tests for this change. (existing surgery tests cover both; all pass)
  • Make sure all tests can pass. (test_graph_surgeries.py: 82 passed, 2 skipped)
  • Update documents if necessary.
  • Lint and apply fixes to your code by running lintrunner -a
  • Is this a user-facing change? No — internal refactor, equivalent output.

(Optional) Issue link

Introduce a RewriteRuleSurgeon base class that lets graph surgeries be expressed
as onnxscript rewrite rules over the ONNX IR model, instead of manual protobuf /
OnnxDAG manipulation. Subclasses implement rules() returning a RewriteRuleSet;
the base applies them via call_ir, so the rewriter handles operand commutativity,
use-count bookkeeping, and dead-node cleanup.

Port the first two pattern-based surgeries to this base:
- ReciprocalMulToDiv: a * Reciprocal(x) -> Div(a, x) (commute=True covers both
  operand orders).
- ReplaceErfWithTanh: Erf(x) -> Tanh(x * 605/503), emitting the scale as an
  initializer of the input's floating-point dtype.

This is the first batch of an incremental migration of graph_surgeries.py off the
protobuf/OnnxDAG approach; subsequent batches will port the remaining pattern-based
surgeries (Gemm<->MatMul+Add, QDQ, RMSNorm variants, decompositions, ...) and move
the whole-graph surgeries to plain onnx_ir.

Update the ReplaceErfWithTanh test to read the scale via numpy_helper.to_array so
it is agnostic to raw_data vs float_data tensor storage.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

Copilot AI review requested due to automatic review settings July 1, 2026 18:01

Copilot AI left a comment

Contributor

Pull request overview

This PR is the first batch of an incremental refactor of GraphSurgeries patterns from manual ONNX proto/DAG manipulation to ONNX IR (onnx_ir) plus onnxscript’s pattern rewriter, with a small test adjustment to accommodate IR-emitted initializers.

Changes:

  • Introduces a RewriteRuleSurgeon base class to implement local graph surgeries as onnxscript.rewriter.pattern rewrite rule sets applied on the IR model.
  • Ports two surgeries to the rewrite-rule approach: ReplaceErfWithTanh and ReciprocalMulToDiv.
  • Updates the ReplaceErfWithTanh unit test to read initializer values via numpy_helper.to_array (works for both raw_data and float_data storage).

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| olive/passes/onnx/graph_surgeries.py | Adds rewrite-rule surgeon infrastructure and migrates two surgeries to ONNX IR + onnxscript rewriter. |
| test/passes/onnx/test_graph_surgeries.py | Makes the scale-constant assertion robust to initializer serialization format (raw_data vs float_data). |

Comment on lines +266 to +270
# ir.DataType -> numpy dtype for the emitted scale initializer.
_DTYPE_MAP: ClassVar[dict] = {
ir.DataType.FLOAT: np.float32,
ir.DataType.FLOAT16: np.float16,
ir.DataType.DOUBLE: np.float64,

Contributor Author

Restored BFLOAT16 support (via ml_dtypes), so the scale is emitted in the input's floating-point dtype as documented. Verified a bf16 Erf lowers to Mul+Tanh with a bf16 scale. Fixed in 838de67.

justinchuby and others added 9 commits July 1, 2026 19:14

Convert RemoveGidxFromMatMulNBits from protobuf iteration to an onnx_ir call_ir
implementation: drop a sorted (identity-permutation) g_idx input via
resize_inputs and prune the now-unused g_idx initializer. Behavior unchanged.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

- InferShapes: delegate to onnx_ir ShapeInferencePass.
- RemoveShapes: clear type/shape on intermediate values (empties value_info).
- RemoveInputs: drop named graph inputs and their node references via onnx_ir,
  removing nodes left with no inputs.

Behavior unchanged; verified by existing surgery tests.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

Rebuild ZeroOutInput on the onnx_ir API: read the target input's shape/dtype from
the IR value, emit a zero Constant, and rewire the node input. Update the test to
read the constant via numpy_helper.to_array (IR stores tensors as raw_data).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

Reimplement RemoveMemcpy on onnx_ir: bypass 1-in/1-out MemcpyToHost/MemcpyFromHost
nodes via Value.replace_all_uses_with (which follows consumers into subgraphs),
recurse into Loop/If/Scan subgraphs, preserve public output names on the output
boundary, and re-order with TopologicalSortPass. Replaces ~185 lines of manual
proto bypass/rename/topo-sort logic with ~40. Behavior verified by the 4 existing
RemoveMemcpy tests.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

Rewrite ReplaceAttentionMaskValue on onnx_ir: clamp below-threshold entries in
float Constant/ConstantOfShape node values and initializers whose consumers are
all mask-compatible ops. Behavior unchanged; verified by the existing test.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

Convert RMSNorm, SimplifiedLayerNorm, and Pow/ReduceSum norm graph surgeries to mutate onnx_ir directly while preserving weight scaling, all-ones weights, and ReduceMean opset handling. Full graph surgery tests pass and lint is clean.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

Convert the shape-dependent MatMul/Add and Gemm rewrites to the onnx_ir Surgeon path while preserving reshape, Relu, and transB handling. Targeted and full graph surgery tests pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

Convert the GQA RoPE cache, attention-mask sequence length,
and quantized-output exposure surgeries to operate through onnx_ir.
Behavior is unchanged; graph surgery tests and lint pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

Convert both graph surgeons to implement call_ir using onnx_ir while preserving their quantized initializer creation, shared-weight rewiring, output-name handling, and cleanup behavior. Verified with ad-hoc tiny ONNX models for both surgeries plus the existing graph_surgeries pytest module.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
# value_info is emitted for intermediate values that carry a type/shape;
# clearing those on non-output node results empties graph.value_info.
graph_outputs = set(model.graph.outputs)
for node in model.graph:

Contributor Author

use model.graph.all_nodes()

Contributor Author

Done — RemoveShapes now iterates model.graph.all_nodes() so value_info is cleared in subgraphs too, while preserving every graph's declared outputs. Fixed in 838de67.

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 1 out of 2 changed files in this pull request and generated 5 comments.

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment on lines +447 to +458
scale_initializer = quantized_node.inputs[1]
if scale_initializer is None or scale_initializer.name not in graph.initializers:
raise ValueError(f"Scale initializer '{quantized_node.inputs[1].name}' not found.")
scale_value = scale_initializer.const_value.numpy()[0]
self._add_scale(graph, scale_value)

zero_point_initializer = quantized_node.inputs[2]
if zero_point_initializer is None or zero_point_initializer.name not in graph.initializers:
raise ValueError(f"Zero point initializer '{quantized_node.inputs[2].name}' not found.")
zero_point_value = zero_point_initializer.const_value.numpy()[0]
self._add_zero_point(graph, zero_point_value, zero_point_initializer.dtype)
return model

Contributor Author

Guarded against missing/None scale & zero-point inputs (and fewer than 3 inputs) instead of dereferencing .name on None. Fixed in 838de67.
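
A minimal sketch of the kind of guard described here, assuming a plain list of optional inputs. `FakeValue` and `require_initializer` are hypothetical names used for illustration and are not part of onnx_ir or the PR.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Set


@dataclass
class FakeValue:
    """Hypothetical stand-in for an IR value; only the name is modeled."""

    name: str


def require_initializer(
    inputs: Sequence[Optional[FakeValue]],
    index: int,
    initializers: Set[str],
    label: str,
) -> FakeValue:
    """Return inputs[index] if it names an initializer, else raise ValueError."""
    # Covers both a node with too few inputs and an explicit None input.
    value = inputs[index] if index < len(inputs) else None
    if value is None or value.name not in initializers:
        # Never touch .name on None while building the error message.
        shown = value.name if value is not None else f"<missing input {index}>"
        raise ValueError(f"{label} initializer '{shown}' not found.")
    return value
```

The point of the sketch is that the error path itself must not dereference the value that was just found to be `None`.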

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment on lines +1559 to +1563
  # remove the Greater -> If Nodes
- if_node_name = dag.get_producer(cache_names["cos_cache"])
- greater_node_name = dag.get_parents(if_node_name)[0]
- dag.remove_node(if_node_name)
- dag.remove_node(greater_node_name)
+ if_node = cache_values["cos_cache"].producer()
+ greater_node = next(inp.producer() for inp in if_node.inputs if inp is not None and inp.producer() is not None)
+ graph.remove(if_node, safe=True)
+ graph.remove(greater_node, safe=True)

Contributor Author

Now only removes the If-condition producer when it is confirmed to be a Greater node with no remaining consumers, avoiding removing an unrelated node or raising StopIteration. Fixed in 838de67.
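
The check described here can be sketched with a stand-in node type. `FakeNode`, its `other_consumers` field, and `removable_condition_producer` are hypothetical and only model the decision logic, not the onnx_ir calls.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class FakeNode:
    """Hypothetical stand-in for an IR node; not the onnx_ir class."""

    op_type: str
    other_consumers: int = 0  # uses of its output besides the If being removed


def removable_condition_producer(
    producers: Iterable[Optional[FakeNode]],
) -> Optional[FakeNode]:
    """Return the If-condition producer only when it is safe to remove."""
    # next(..., None) yields None instead of raising StopIteration when
    # none of the If inputs has a producing node.
    candidate = next((p for p in producers if p is not None), None)
    if candidate is None:
        return None
    if candidate.op_type != "Greater" or candidate.other_consumers > 0:
        return None
    return candidate
```

Returning `None` lets the caller skip the removal quietly instead of deleting an unrelated node or failing on an empty iterator.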

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment on lines +2431 to +2432
  # Ensure com.microsoft opset is declared
- dag.set_opset_import("com.microsoft", 1)
+ model.opset_imports[MSFT_DOMAIN] = 1

Contributor Author

No longer downgrades an existing com.microsoft opset — bumps it up to at least 1 with max(existing, 1). Fixed in 838de67.
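
The `max(existing, 1)` idea can be shown on a plain dictionary. This is a sketch in which a `dict` stands in for the model's opset imports and `ensure_min_opset` is a hypothetical helper name.

```python
MSFT_DOMAIN = "com.microsoft"


def ensure_min_opset(opset_imports: dict, domain: str, minimum: int = 1) -> None:
    """Declare `domain` at `minimum` or keep a higher version already present."""
    # A missing domain counts as version 0, so it is raised to `minimum`.
    opset_imports[domain] = max(opset_imports.get(domain, 0), minimum)


fresh = {"": 17}
ensure_min_opset(fresh, MSFT_DOMAIN)  # adds com.microsoft at version 1

newer = {"": 17, MSFT_DOMAIN: 2}
ensure_min_opset(newer, MSFT_DOMAIN)  # leaves the existing version 2 untouched
```

A plain assignment to 1 would have overwritten the version 2 in the second case, which is the downgrade the review comment flagged.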

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment on lines +2561 to +2562
  # Ensure com.microsoft opset is declared
- dag.set_opset_import("com.microsoft", 1)
+ model.opset_imports[MSFT_DOMAIN] = 1

Contributor Author

Same fix as QuantizeEmbeddingInt8 — opset is bumped up to at least 1 without downgrading. Fixed in 838de67.

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment on lines 1615 to 1617
input_ids = next(graph_input for graph_input in graph.inputs if graph_input.name == "input_ids")
batch_size = input_ids.shape[0]
seq_len_shapes = {"past_seq_len": [batch_size, 1], "total_seq_len": []}

Contributor Author

Defaults the batch dim to 1 when input_ids is missing or its shape is unknown (dynamic models). Fixed in 838de67.
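
One way to picture this fallback, assuming graph inputs are modeled as a mapping from name to an optional shape list. `batch_dim_or_default` and `seq_len_shapes` are hypothetical helpers, and treating an empty shape like an unknown one is an assumption of this sketch.

```python
from typing import Dict, Optional, Sequence


def batch_dim_or_default(
    inputs: Dict[str, Optional[Sequence]],
    name: str = "input_ids",
    default: int = 1,
):
    """Return the first dim of `name`, or `default` if it cannot be read."""
    shape = inputs.get(name)  # None when the input is absent or has no shape
    if not shape:
        return default
    return shape[0]


def seq_len_shapes(inputs: Dict[str, Optional[Sequence]]) -> dict:
    """Build the shapes for the two sequence-length inputs."""
    batch = batch_dim_or_default(inputs)
    return {"past_seq_len": [batch, 1], "total_seq_len": []}
```

Compared with the `next(...)` plus `shape[0]` lookup in the snippet above, this version neither raises when `input_ids` is absent nor indexes into a missing shape.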

justinchuby and others added 3 commits July 1, 2026 21:18

Add regression coverage for two previously untested migrated surgeries:
- PowReduceSumPowDiv2LpNorm: Pow(2)->ReduceSum->Pow(0.5)->Div collapses to LpNormalization.
- QuantizeEmbeddingInt8: an embed_tokens Gather over an FP16 weight becomes an INT8
  GatherBlockQuantized with a uint8 quantized table.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>

- RemoveShapes: iterate model.graph.all_nodes() so value_info is cleared in
  subgraphs too, preserving every graph's declared outputs.
- ReplaceErfWithTanh: restore BFLOAT16 support (via ml_dtypes) so the scale is
  emitted in the input's floating-point dtype as documented.
- ExposeQuantizedOutput: guard against missing/None scale/zero-point inputs
  instead of dereferencing .name on None (and assuming >=3 inputs).
- RemoveRopeMultiCache: only remove the If-condition producer when it is a
  Greater node with no remaining consumers (avoid removing an unrelated node
  or raising StopIteration).
- QuantizeEmbeddingInt8 / ShareEmbeddingLmHead: do not downgrade an existing
  com.microsoft opset version; bump up to at least 1.
- AttentionMaskToSequenceLengths: default batch dim to 1 when input_ids is
  missing or its shape is unknown (dynamic models).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>