diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index 3f87496c62..8ef479955e 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -415,7 +415,7 @@ def _print_DefFunction(self, expr): return f"{expr.name}{template}({args})" def _print_SizeOf(self, expr): - return f'sizeof({self._print(expr.intype)}{self._print(expr.stars)})' + return f'sizeof({self._print(expr.ctype)}{self._print(expr.stars)})' def _print_MathFunction(self, expr): return f"{self.ns(expr)}{self._print_DefFunction(expr)}" diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 41e6ed7547..86979c8bbd 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -980,7 +980,6 @@ class SizeOf(DefFunction): __rargs__ = ('intype', 'stars') def __new__(cls, intype, stars=None, **kwargs): - stars = stars or '' if not isinstance(intype, (str, ReservedWord)): ctype = dtype_to_ctype(intype) for k, v in ctypes_vector_mapper.items(): @@ -990,15 +989,40 @@ def __new__(cls, intype, stars=None, **kwargs): else: intype = ctypes_to_cstr(ctype) - newobj = super().__new__(cls, 'sizeof', arguments=f'{intype}{stars}', **kwargs) - newobj.stars = stars - newobj.intype = intype + stars = stars or '' + if not all(c == '*' for c in str(stars)): + raise ValueError("`stars` must be a string of zero or more `*` characters") + + if not isinstance(intype, (str, ReservedWord)): + intype = ctypes_vector_mapper[intype].__name__ - return newobj + return super().__new__(cls, 'sizeof', arguments=(intype, stars), **kwargs) @property def args(self): - return super().args[1] + return self._arguments + + @property + def intype(self): + return self.arguments[0] + + @cached_property + def ctype(self): + for v in ctypes_vector_mapper.values(): + if str(self.intype) == v.__name__: + return v + return self.intype + + @property + def stars(self): + return self.arguments[1] + + def __str__(self): + try: + intype = ctypes_to_cstr(self.ctype) + except TypeError: + intype = str(self.ctype) + return f"sizeof({intype}{self.stars})" def rfunc(func, item, *args): diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 5af78117a2..aab8502a20 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -1,4 +1,4 @@ -from ctypes import c_void_p +from ctypes import c_uint64, c_void_p import numpy as np import pytest @@ -1262,6 +1262,30 @@ def test_print_div(): assert cstr == 'sizeof(int)/sizeof(long)' +def test_sizeof(): + sizeof_ctype = SizeOf(c_uint64) + str_pointer0 = SizeOf('float', stars='*') + str_pointer1 = SizeOf('float', '*') + str_simple = SizeOf('int') + complex_size = SizeOf(np.complex64) + + assert sizeof_ctype.arguments == (ReservedWord('unsigned long'), ReservedWord('')) + assert str_pointer0.arguments == (ReservedWord('float'), ReservedWord('*')) + assert complex_size.arguments == (ReservedWord('c_complex'), ReservedWord('')) + + # Printing + assert ccode(sizeof_ctype) == 'sizeof(unsigned long)' + assert ccode(str_pointer0) == ccode(str_pointer1) == 'sizeof(float*)' + assert ccode(str_simple) == 'sizeof(int)' + assert str(complex_size) == 'sizeof(c_complex)' + assert ccode(complex_size) == 'sizeof(float _Complex)' + + # Reconstruction + assert ccode(str_pointer0.func(*str_pointer0.args)) == 'sizeof(float*)' + assert ccode(str_pointer0.func('int', stars='**')) == 'sizeof(int**)' + assert ccode(complex_size.func(*complex_size.args)) == 'sizeof(float _Complex)' + + def test_customdtype_complex(): """ Test that `CustomDtype` doesn't brak is_imag