From d33ee66d85d31c78844f5cbf23d6541e91f5ffb6 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 4 Jun 2026 13:44:34 +0100 Subject: [PATCH] compiler: Fix DefFunction (re)construction --- devito/symbolics/extended_sympy.py | 9 ++++++++- tests/test_symbolics.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 4523fbd2f4..41e6ed7547 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -9,6 +9,7 @@ import sympy from sympy import Expr, Function, Number, Tuple, cacheit, sympify from sympy.core.decorators import call_highest_priority +from sympy.core.function import Application from sympy.logic.boolalg import BooleanFunction from devito.finite_differences.elementary import Max, Min @@ -718,7 +719,13 @@ def __new__(cls, name, arguments=None, template=None, **kwargs): if _template: args.append(Tuple(*_template)) - obj = Function.__new__(cls, *args) + # `Function.__new__` and `Application.__new__` are both cached by + # SymPy. DefFunction subclasses may attach reconstruction kwargs as + # side attributes after this base constructor returns; going through + # the cached route could then alias a previous object and mutate it + # during reconstruction. Call Application's uncached constructor + # explicitly instead of using super()/Function.__new__. + obj = Application.__new__.__wrapped__(cls, *args) obj._name = name obj._arguments = tuple(_arguments) obj._template = tuple(_template) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 48709d7cd1..5af78117a2 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -934,6 +934,36 @@ def __new__(cls, name=None, arguments=None, p0=None, p1=None, p2=None): assert func1.p1 == (g,) assert func1.p2 == 'bar' + def test_custom_def_function_reconstruction_no_aliasing(self): + + class MyDefFunction(DefFunction): + __rargs__ = ('name', 'arguments') + __rkwargs__ = ('p0',) + + def __new__(cls, name=None, arguments=None, p0=None): + obj = super().__new__(cls, name=name, arguments=arguments) + obj.p0 = p0 + return obj + + def _hashable_content(self): + return super()._hashable_content() + (self.p0,) + + grid = Grid(shape=(4, 4)) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + + func0 = MyDefFunction(name='foo', arguments=f.indexify(), p0=f) + h0 = hash(func0) + + func1 = func0.func(p0=g) + + assert func1 is not func0 + assert func1 != func0 + assert hash(func0) == h0 + assert func0.p0 is f + assert func1.p0 is g + def test_reduce_to_number(self): grid = Grid(shape=(4, 4)) x, _ = grid.dimensions