diff --git a/quest/src/core/bitwise.hpp b/quest/src/core/bitwise.hpp index f5266afa4..7ce211428 100644 --- a/quest/src/core/bitwise.hpp +++ b/quest/src/core/bitwise.hpp @@ -14,6 +14,11 @@ #include #endif +#if defined(__BMI2__) && (defined(__x86_64__) || defined(_M_X64)) && !defined(__NVCC__) && !defined(__HIP__) + #include + #define QUEST_HAVE_BMI2_INTRINSICS 1 +#endif + #include "quest/include/types.h" #include "quest/src/core/inliner.hpp" @@ -166,6 +171,24 @@ INLINE int getBitMaskParity(qindex mask) { INLINE qindex insertBits(qindex number, const int* bitIndices, int numIndices, int bitValue) { // bitIndices must be strictly increasing + +#ifdef QUEST_HAVE_BMI2_INTRINSICS + // Smaller fixed-size cases are usually unrolled by callers already. + if (numIndices > 5) { + unsigned long long insertionMask = 0; + + for (int i=0; i(number), ~insertionMask); + + if (bitValue) + inserted |= insertionMask; + + return static_cast(inserted); + } +#endif + for (int i=0; i 5) { + unsigned long long extractionMask = 0; + bool indicesAreIncreasing = true; + int prevInd = -1; + + for (int i=0; i prevInd; + prevInd = bitInd; + } + + // PEXT returns bits in ascending mask order, matching this API only for sorted indices. + if (indicesAreIncreasing) + return static_cast(_pext_u64(static_cast(number), extractionMask)); + } +#endif + qindex value = 0; for (int i=0; i + +#include "quest/src/core/bitwise.hpp" +#include "tests/utils/macros.hpp" + +#include + +using std::vector; + + +/* + * UTILITIES + */ + +#define TEST_CATEGORY \ + LABEL_UNIT_TAG "[bitwise]" + + +static qindex getRefInsertedBits(qindex number, const vector& bitIndices, int bitValue) { + + qindex out = 0; + int srcInd = 0; + int nextIns = 0; + + for (int dstInd=0; dstInd<63; dstInd++) { + if (nextIns < static_cast(bitIndices.size()) && dstInd == bitIndices[nextIns]) { + if (bitValue) + out |= QINDEX_ONE << dstInd; + nextIns++; + } else { + if ((number >> srcInd) & QINDEX_ONE) + out |= QINDEX_ONE << dstInd; + srcInd++; + } + } + + return out; +} + + +static qindex getRefValueOfBits(qindex number, const vector& bitIndices) { + + qindex out = 0; + + for (int i=0; i(bitIndices.size()); i++) + if ((number >> bitIndices[i]) & QINDEX_ONE) + out |= QINDEX_ONE << i; + + return out; +} + + +/** + * TESTS + * + * @ingroup unitbitwise + * @{ + */ + + +TEST_CASE( "insertBits", TEST_CATEGORY ) { + + SECTION( LABEL_CORRECTNESS ) { + vector numbers = {0, 1, 2, 5, 21, 0x12345, 0x6DB6DB}; + vector> bitIndices = { + {}, + {0}, + {1}, + {0, 1}, + {1, 3, 5}, + {0, 2, 6, 9}, + {4, 8, 12, 20} + }; + + for (auto number: numbers) + for (auto& inds: bitIndices) + for (int bitValue: {0, 1}) + REQUIRE( insertBits(number, inds.data(), static_cast(inds.size()), bitValue) == getRefInsertedBits(number, inds, bitValue) ); + } + + SECTION( LABEL_VALIDATION ) { + + // no validation! + SUCCEED( ); + } +} + + +TEST_CASE( "getValueOfBits", TEST_CATEGORY ) { + + SECTION( LABEL_CORRECTNESS ) { + vector numbers = {0, 1, 2, 5, 0x12345, 0xAAAAAAAA, 0x55555555}; + vector> bitIndices = { + {}, + {0}, + {1}, + {0, 1, 2, 3}, + {3, 1, 4, 0}, + {20, 0, 16, 8, 4} + }; + + for (auto number: numbers) + for (auto& inds: bitIndices) + REQUIRE( getValueOfBits(number, inds.data(), static_cast(inds.size())) == getRefValueOfBits(number, inds) ); + } + + SECTION( LABEL_VALIDATION ) { + + // no validation! + SUCCEED( ); + } +} + + +/** @} (end defgroup) */