diff --git a/tests/test_kernel.py b/tests/test_kernel.py index e381c11..5f6cebd 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -4,7 +4,9 @@ # ########################################### # import cffi +import pytest import sysconfig +from pathlib import Path import numpy as np diff --git a/xobjects/context_cpu.py b/xobjects/context_cpu.py index 90952bf..4bf53a2 100644 --- a/xobjects/context_cpu.py +++ b/xobjects/context_cpu.py @@ -150,8 +150,7 @@ def __init__(self, omp_num_threads=0): """ super().__init__() self.omp_num_threads = omp_num_threads - if omp_num_threads == 0: - self.allow_prebuilt_kernels = True + self.allow_prebuilt_kernels = True def __str__(self): if not self.openmp_enabled: @@ -415,8 +414,18 @@ def compile_kernel( log.debug(f"cffi def {pyname} {signature}") if self.openmp_enabled: - ffi_interface.cdef("void omp_set_num_threads(int);") - ffi_interface.cdef("int omp_get_max_threads();") + # The wrapper is needed to ensure that the omp functions are linked + ffi_interface.cdef("void xo_omp_set_num_threads(int);") + ffi_interface.cdef("int xo_omp_get_max_threads();") + specialized_source += """ + void xo_omp_set_num_threads(int num_threads) { + omp_set_num_threads(num_threads); + } + + int xo_omp_get_max_threads(void) { + return omp_get_max_threads(); + } + """ # Compile xtr_compile_args = ["-std=c99", "-DXO_CONTEXT_CPU"] @@ -528,8 +537,8 @@ def _load_kernel_module( spec.loader.exec_module(module) if self.openmp_enabled: - self.omp_set_num_threads = module.lib.omp_set_num_threads - self.omp_get_max_threads = module.lib.omp_get_max_threads + self.omp_set_num_threads = module.lib.xo_omp_set_num_threads + self.omp_get_max_threads = module.lib.xo_omp_get_max_threads return module @@ -768,10 +777,8 @@ def update_from_nplike(self, offset, dest_dtype, value): value = nplike_to_numpy(value) if dest_dtype != value.dtype: value = value.astype(dtype=dest_dtype) # make a copy - src = value.view("int8") - self.buffer[offset : offset + src.nbytes] = value.flatten().view( - "int8" - ) + src = value.flatten().view("int8") + self.buffer[offset : offset + src.nbytes] = src def to_bytearray(self, offset, nbytes): """copy in byte array: used in update_from_xbuffer""" diff --git a/xobjects/struct.py b/xobjects/struct.py index f11c547..67fb062 100644 --- a/xobjects/struct.py +++ b/xobjects/struct.py @@ -518,7 +518,6 @@ def compile_class_kernels( extra_classes=(), extra_compile_args=(), ): - if context.allow_prebuilt_kernels: _print_state = Print.suppress Print.suppress = True @@ -532,6 +531,7 @@ def compile_class_kernels( config={}, tracker_element_classes=[], classes=list(extra_classes) + [cls], + context=context, ) except ImportError: kernel_info = None