diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index d1cdf7c5dd..996b194036 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -712,10 +712,21 @@ def _apply_to_graph_or_function( initializers = graph_or_function.initializers for initializer in delta.new_initializers: if initializer.name in initializers: + existing = initializers[initializer.name] + if existing is initializer: + # Same Value already registered, skip. + continue + # Name conflict with a different Value: generate a unique name. + base_name = initializer.name + counter = 1 + while f"{base_name}_{counter}" in initializers: + counter += 1 if verbose: - print(f"Initializer {initializer.name} already exists.") - continue - for initializer in delta.new_initializers: + print( + f"Initializer '{initializer.name}' already exists. " + f"Renaming to '{base_name}_{counter}'." + ) + initializer.name = f"{base_name}_{counter}" initializers[initializer.name] = initializer # type: ignore[index] # TODO: This does not yet handle the problem of determining the correct insertion point # for inserted nodes in the case of patterns with multiple output-nodes. The following diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py index 67ebdaa495..8476f64715 100644 --- a/onnxscript/rewriter/rules/common/_basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -554,6 +554,48 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg): self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, error_msg) + def test_reshape_reshape_shared_shape_no_clobbering(self): + """Two Reshape->Reshape chains sharing the same shape with -1 should not clobber initializers.""" + import onnx_ir + + model = onnx_ir.from_onnx_text( + """\ + +test (float[2, 3] input1, float[2, 6] input2) => (float[2, 3] out1, float[2, 6] out2) { + mid1 = Reshape (input1, shape_mid_a) + mid2 = Reshape (input2, shape_mid_b) + out1 = Reshape (mid1, shared_shape) + out2 = Reshape (mid2, shared_shape) +}""", + initializers=[ + ir.Tensor(np.array([6], dtype=np.int64), name="shape_mid_a"), + ir.Tensor(np.array([12], dtype=np.int64), name="shape_mid_b"), + ir.Tensor(np.array([2, -1], dtype=np.int64), name="shared_shape"), + ], + ) + + count = _basic_rules.reshape_reshape_rule.apply_to_model(model) + self.assertEqual(count, 2) + + # All Reshape shape inputs must be registered initializers + for node in model.graph: + if node.op_type == "Reshape": + shape_val = node.inputs[1] + self.assertIn( + shape_val.name, + model.graph.initializers, + f"Shape '{shape_val.name}' not registered as initializer", + ) + self.assertIs( + model.graph.initializers[shape_val.name], + shape_val, + f"Shape '{shape_val.name}' maps to a different Value object", + ) + + # Model must serialize and pass onnx checker + proto = onnx_ir.to_proto(model) + onnx.checker.check_model(proto, full_check=True) + class Flatten2ReshapeTest(unittest.TestCase): @staticmethod