From 2bf6dbe47d999dc5be90c322d6ce02ec52fbd161 Mon Sep 17 00:00:00 2001 From: opluss Date: Mon, 22 Jun 2026 18:09:56 +0800 Subject: [PATCH] fix: resolve initializer name clobbering in rewrite rule registration When multiple rewrite matches produce initializers with the same name, the second registration would silently overwrite the first in the graph.initializers dict. The original Value became an unregistered dangling reference, causing toposort validation failure on serialization. The root cause was two independent for-loops: the first checked for duplicates but its `continue` only affected itself (dead code), while the second unconditionally overwrote all entries. Fix: merge into a single loop that detects name conflicts and auto-generates a unique suffix (_1, _2, ...) before registering. Signed-off-by: opluss --- onnxscript/rewriter/_rewrite_rule.py | 17 ++++++-- .../rules/common/_basic_rules_test.py | 42 +++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index d1cdf7c5dd..996b194036 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -712,10 +712,21 @@ def _apply_to_graph_or_function( initializers = graph_or_function.initializers for initializer in delta.new_initializers: if initializer.name in initializers: + existing = initializers[initializer.name] + if existing is initializer: + # Same Value already registered, skip. + continue + # Name conflict with a different Value: generate a unique name. + base_name = initializer.name + counter = 1 + while f"{base_name}_{counter}" in initializers: + counter += 1 if verbose: - print(f"Initializer {initializer.name} already exists.") - continue - for initializer in delta.new_initializers: + print( + f"Initializer '{initializer.name}' already exists. " + f"Renaming to '{base_name}_{counter}'." + ) + initializer.name = f"{base_name}_{counter}" initializers[initializer.name] = initializer # type: ignore[index] # TODO: This does not yet handle the problem of determining the correct insertion point # for inserted nodes in the case of patterns with multiple output-nodes. The following diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py index 67ebdaa495..8476f64715 100644 --- a/onnxscript/rewriter/rules/common/_basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -554,6 +554,48 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg): self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, error_msg) + def test_reshape_reshape_shared_shape_no_clobbering(self): + """Two Reshape->Reshape chains sharing the same shape with -1 should not clobber initializers.""" + import onnx_ir + + model = onnx_ir.from_onnx_text( + """\ + +test (float[2, 3] input1, float[2, 6] input2) => (float[2, 3] out1, float[2, 6] out2) { + mid1 = Reshape (input1, shape_mid_a) + mid2 = Reshape (input2, shape_mid_b) + out1 = Reshape (mid1, shared_shape) + out2 = Reshape (mid2, shared_shape) +}""", + initializers=[ + ir.Tensor(np.array([6], dtype=np.int64), name="shape_mid_a"), + ir.Tensor(np.array([12], dtype=np.int64), name="shape_mid_b"), + ir.Tensor(np.array([2, -1], dtype=np.int64), name="shared_shape"), + ], + ) + + count = _basic_rules.reshape_reshape_rule.apply_to_model(model) + self.assertEqual(count, 2) + + # All Reshape shape inputs must be registered initializers + for node in model.graph: + if node.op_type == "Reshape": + shape_val = node.inputs[1] + self.assertIn( + shape_val.name, + model.graph.initializers, + f"Shape '{shape_val.name}' not registered as initializer", + ) + self.assertIs( + model.graph.initializers[shape_val.name], + shape_val, + f"Shape '{shape_val.name}' maps to a different Value object", + ) + + # Model must serialize and pass onnx checker + proto = onnx_ir.to_proto(model) + onnx.checker.check_model(proto, full_check=True) + class Flatten2ReshapeTest(unittest.TestCase): @staticmethod