diff --git a/quest/src/gpu/gpu_thrust.cuh b/quest/src/gpu/gpu_thrust.cuh index 864cca5f8..4ab9dd6f1 100644 --- a/quest/src/gpu/gpu_thrust.cuh +++ b/quest/src/gpu/gpu_thrust.cuh @@ -508,72 +508,64 @@ struct functor_getFidelityTerm { }; -template struct functor_projectStateVec { - // this functor multiplies an amp with zero or a - // renormalisation codfficient, depending on whether + // this functor multiplies an amp with zero or a + // renormalisation coefficient, depending on whether // the basis state of the amp has qubits in a particular // configuration. This is used to project statevector - // qubits into a particular measurement outcome - - int* targetsPtr; - int numTargets, rank; - qindex retainValue; + // qubits into a particular measurement outcome. + // + // The projected qubits and their target outcomes are encoded as bitmasks + // (qubitMask = the projected qubit positions, valueMask = those positions + // set to their measured outcomes), so the test reduces to a single masked + // comparison. This needs no device-side qubit list, avoiding a per-call + // host->device copy that dominates runtime for small Quregs (#749). It is + // also branch/loop-free, so no compile-time NumTargets unrolling is needed. + + qindex qubitMask, valueMask; qreal renorm; - functor_projectStateVec( - int* targetsPtr, int numTargets, - qindex retainValue, qreal renorm - ) : - targetsPtr(targetsPtr), numTargets(numTargets), - retainValue(retainValue), renorm(renorm) - { - assert_numTargsMatchesTemplateParam(numTargets, NumTargets); - } + functor_projectStateVec(qindex qubitMask, qindex valueMask, qreal renorm) : + qubitMask(qubitMask), valueMask(valueMask), renorm(renorm) + { } __host__ __device__ gpu_qcomp operator()(qindex n, gpu_qcomp amp) { - // use the compile-time value if possible, to auto-unroll the getValueOfBits() loop below - SET_VAR_AT_COMPILE_TIME(int, numBits, NumTargets, numTargets); - - // return amp scaled by zero or renorm, depending on whether n has projected substate - qindex val = getValueOfBits(n, targetsPtr, numBits); - qreal fac = renorm * (val == retainValue); + // keep amp (scaled by renorm) iff its projected qubits match the outcome + qreal fac = renorm * ((n & qubitMask) == valueMask); return fac * amp; } }; -template struct functor_projectDensMatr { - // this functor multiplies an amp with zero or a + // this functor multiplies an amp with zero or a // renormalisation coefficient, depending on whether // the basis state of the amp has qubits in a particular // configuration. This is used to project density matrix - // qubits into a particular measurement outcome - - int* targetsPtr; - int numTargets, rank, numQuregQubits; - qindex logNumAmpsPerNode, retainValue; + // qubits into a particular measurement outcome. + // + // Like functor_projectStateVec, the projected qubits and their outcomes are + // encoded as bitmasks (qubitMask, valueMask), so no device-side qubit list + // is needed - avoiding a per-call host->device copy at small Quregs (#749). + + qindex qubitMask, valueMask; + int rank, numQuregQubits; + qindex logNumAmpsPerNode; qreal renorm; functor_projectDensMatr( - int* targetsPtr, int numTargets, int rank, int numQuregQubits, - qindex logNumAmpsPerNode, qindex retainValue, qreal renorm + qindex qubitMask, qindex valueMask, int rank, int numQuregQubits, + qindex logNumAmpsPerNode, qreal renorm ) : - targetsPtr(targetsPtr), numTargets(numTargets), rank(rank), numQuregQubits(numQuregQubits), - logNumAmpsPerNode(logNumAmpsPerNode), retainValue(retainValue), renorm(renorm) - { - assert_numTargsMatchesTemplateParam(numTargets, NumTargets); - } + qubitMask(qubitMask), valueMask(valueMask), rank(rank), numQuregQubits(numQuregQubits), + logNumAmpsPerNode(logNumAmpsPerNode), renorm(renorm) + { } __host__ __device__ gpu_qcomp operator()(qindex n, gpu_qcomp amp) { - // use the compile-time value if possible, to auto-unroll the getValueOfBits() loop below - SET_VAR_AT_COMPILE_TIME(int, numBits, NumTargets, numTargets); - // i = global index of nth local amp qindex i = concatenateBits(rank, n, logNumAmpsPerNode); @@ -581,11 +573,10 @@ struct functor_projectDensMatr { qindex r = getBitsRightOfIndex(i, numQuregQubits); qindex c = getBitsLeftOfIndex(i, numQuregQubits-1); - qindex v1 = getValueOfBits(r, targetsPtr, numBits); - qindex v2 = getValueOfBits(c, targetsPtr, numBits); - - // multiply amp with renorm or zero if values disagree with given outcomes - qreal fac = renorm * (v1 == v2) * (retainValue == v1); + // keep amp (scaled by renorm) iff both row and column basis states have + // the projected qubits in the measured outcome + bool match = ((r & qubitMask) == valueMask) && ((c & qubitMask) == valueMask); + qreal fac = renorm * match; return fac * amp; } }; @@ -1018,10 +1009,12 @@ gpu_qcomp thrust_densmatr_calcExpecFullStateDiagMatr_sub(Qureg qureg, FullStateD template void thrust_statevec_multiQubitProjector_sub(Qureg qureg, ConstList64 qubits, ConstList64 outcomes, qreal renorm) { - devints devQubits = getDevInts(qubits); - qindex retainValue = getIntegerFromBits(outcomes.data(), outcomes.size()); - auto projFunctor = functor_projectStateVec( - getPtr(devQubits), qubits.size(), retainValue, renorm); + // encode the projected qubits and their outcomes as host-side bitmasks, + // avoiding a per-call host->device copy of the qubit list (#749). NumQubits + // is retained only for dispatch uniformity; the masked functor needs no unroll. + qindex qubitMask = util_getBitMask(qubits); + qindex valueMask = util_getBitMask(qubits, outcomes); + auto projFunctor = functor_projectStateVec(qubitMask, valueMask, renorm); auto indIter = thrust::make_counting_iterator(QINDEX_ZERO); auto ampIter = getStartPtr(qureg); @@ -1034,11 +1027,14 @@ void thrust_statevec_multiQubitProjector_sub(Qureg qureg, ConstList64 qubits, Co template void thrust_densmatr_multiQubitProjector_sub(Qureg qureg, ConstList64 qubits, ConstList64 outcomes, qreal renorm) { - devints devQubits = getDevInts(qubits); - qindex retainValue = getIntegerFromBits(outcomes.data(), outcomes.size()); - auto projFunctor = functor_projectDensMatr( - getPtr(devQubits), qubits.size(), qureg.rank, qureg.numQubits, - qureg.logNumAmpsPerNode, retainValue, renorm); + // encode the projected qubits and outcomes as host-side bitmasks, avoiding a + // per-call host->device copy of the qubit list (#749). NumQubits is retained + // only for dispatch uniformity; the masked functor needs no unroll. + qindex qubitMask = util_getBitMask(qubits); + qindex valueMask = util_getBitMask(qubits, outcomes); + auto projFunctor = functor_projectDensMatr( + qubitMask, valueMask, qureg.rank, qureg.numQubits, + qureg.logNumAmpsPerNode, renorm); auto indIter = thrust::make_counting_iterator(QINDEX_ZERO); auto ampIter = getStartPtr(qureg);