Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 50 additions & 54 deletions quest/src/gpu/gpu_thrust.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -508,84 +508,75 @@ struct functor_getFidelityTerm {
};


template <int NumTargets>
struct functor_projectStateVec {

// this functor multiplies an amp with zero or a
// renormalisation codfficient, depending on whether
// this functor multiplies an amp with zero or a
// renormalisation coefficient, depending on whether
// the basis state of the amp has qubits in a particular
// configuration. This is used to project statevector
// qubits into a particular measurement outcome

int* targetsPtr;
int numTargets, rank;
qindex retainValue;
// qubits into a particular measurement outcome.
//
// The projected qubits and their target outcomes are encoded as bitmasks
// (qubitMask = the projected qubit positions, valueMask = those positions
// set to their measured outcomes), so the test reduces to a single masked
// comparison. This needs no device-side qubit list, avoiding a per-call
// host->device copy that dominates runtime for small Quregs (#749). It is
// also branch/loop-free, so no compile-time NumTargets unrolling is needed.

qindex qubitMask, valueMask;
qreal renorm;

functor_projectStateVec(
int* targetsPtr, int numTargets,
qindex retainValue, qreal renorm
) :
targetsPtr(targetsPtr), numTargets(numTargets),
retainValue(retainValue), renorm(renorm)
{
assert_numTargsMatchesTemplateParam(numTargets, NumTargets);
}
functor_projectStateVec(qindex qubitMask, qindex valueMask, qreal renorm) :
qubitMask(qubitMask), valueMask(valueMask), renorm(renorm)
{ }

__host__ __device__ gpu_qcomp operator()(qindex n, gpu_qcomp amp) {

// use the compile-time value if possible, to auto-unroll the getValueOfBits() loop below
SET_VAR_AT_COMPILE_TIME(int, numBits, NumTargets, numTargets);

// return amp scaled by zero or renorm, depending on whether n has projected substate
qindex val = getValueOfBits(n, targetsPtr, numBits);
qreal fac = renorm * (val == retainValue);
// keep amp (scaled by renorm) iff its projected qubits match the outcome
qreal fac = renorm * ((n & qubitMask) == valueMask);
return fac * amp;
}
};


template <int NumTargets>
struct functor_projectDensMatr {

// this functor multiplies an amp with zero or a
// this functor multiplies an amp with zero or a
// renormalisation coefficient, depending on whether
// the basis state of the amp has qubits in a particular
// configuration. This is used to project density matrix
// qubits into a particular measurement outcome

int* targetsPtr;
int numTargets, rank, numQuregQubits;
qindex logNumAmpsPerNode, retainValue;
// qubits into a particular measurement outcome.
//
// Like functor_projectStateVec, the projected qubits and their outcomes are
// encoded as bitmasks (qubitMask, valueMask), so no device-side qubit list
// is needed - avoiding a per-call host->device copy at small Quregs (#749).

qindex qubitMask, valueMask;
int rank, numQuregQubits;
qindex logNumAmpsPerNode;
qreal renorm;

functor_projectDensMatr(
int* targetsPtr, int numTargets, int rank, int numQuregQubits,
qindex logNumAmpsPerNode, qindex retainValue, qreal renorm
qindex qubitMask, qindex valueMask, int rank, int numQuregQubits,
qindex logNumAmpsPerNode, qreal renorm
) :
targetsPtr(targetsPtr), numTargets(numTargets), rank(rank), numQuregQubits(numQuregQubits),
logNumAmpsPerNode(logNumAmpsPerNode), retainValue(retainValue), renorm(renorm)
{
assert_numTargsMatchesTemplateParam(numTargets, NumTargets);
}
qubitMask(qubitMask), valueMask(valueMask), rank(rank), numQuregQubits(numQuregQubits),
logNumAmpsPerNode(logNumAmpsPerNode), renorm(renorm)
{ }

__host__ __device__ gpu_qcomp operator()(qindex n, gpu_qcomp amp) {

// use the compile-time value if possible, to auto-unroll the getValueOfBits() loop below
SET_VAR_AT_COMPILE_TIME(int, numBits, NumTargets, numTargets);

// i = global index of nth local amp
qindex i = concatenateBits(rank, n, logNumAmpsPerNode);

// r, c = global row and column indices of nth local amp
qindex r = getBitsRightOfIndex(i, numQuregQubits);
qindex c = getBitsLeftOfIndex(i, numQuregQubits-1);

qindex v1 = getValueOfBits(r, targetsPtr, numBits);
qindex v2 = getValueOfBits(c, targetsPtr, numBits);

// multiply amp with renorm or zero if values disagree with given outcomes
qreal fac = renorm * (v1 == v2) * (retainValue == v1);
// keep amp (scaled by renorm) iff both row and column basis states have
// the projected qubits in the measured outcome
bool match = ((r & qubitMask) == valueMask) && ((c & qubitMask) == valueMask);
qreal fac = renorm * match;
return fac * amp;
}
};
Expand Down Expand Up @@ -1018,10 +1009,12 @@ gpu_qcomp thrust_densmatr_calcExpecFullStateDiagMatr_sub(Qureg qureg, FullStateD
template <int NumQubits>
void thrust_statevec_multiQubitProjector_sub(Qureg qureg, ConstList64 qubits, ConstList64 outcomes, qreal renorm) {

devints devQubits = getDevInts(qubits);
qindex retainValue = getIntegerFromBits(outcomes.data(), outcomes.size());
auto projFunctor = functor_projectStateVec<NumQubits>(
getPtr(devQubits), qubits.size(), retainValue, renorm);
// encode the projected qubits and their outcomes as host-side bitmasks,
// avoiding a per-call host->device copy of the qubit list (#749). NumQubits
// is retained only for dispatch uniformity; the masked functor needs no unroll.
qindex qubitMask = util_getBitMask(qubits);
qindex valueMask = util_getBitMask(qubits, outcomes);
auto projFunctor = functor_projectStateVec(qubitMask, valueMask, renorm);

auto indIter = thrust::make_counting_iterator(QINDEX_ZERO);
auto ampIter = getStartPtr(qureg);
Expand All @@ -1034,11 +1027,14 @@ void thrust_statevec_multiQubitProjector_sub(Qureg qureg, ConstList64 qubits, Co
template <int NumQubits>
void thrust_densmatr_multiQubitProjector_sub(Qureg qureg, ConstList64 qubits, ConstList64 outcomes, qreal renorm) {

devints devQubits = getDevInts(qubits);
qindex retainValue = getIntegerFromBits(outcomes.data(), outcomes.size());
auto projFunctor = functor_projectDensMatr<NumQubits>(
getPtr(devQubits), qubits.size(), qureg.rank, qureg.numQubits,
qureg.logNumAmpsPerNode, retainValue, renorm);
// encode the projected qubits and outcomes as host-side bitmasks, avoiding a
// per-call host->device copy of the qubit list (#749). NumQubits is retained
// only for dispatch uniformity; the masked functor needs no unroll.
qindex qubitMask = util_getBitMask(qubits);
qindex valueMask = util_getBitMask(qubits, outcomes);
auto projFunctor = functor_projectDensMatr(
qubitMask, valueMask, qureg.rank, qureg.numQubits,
qureg.logNumAmpsPerNode, renorm);

auto indIter = thrust::make_counting_iterator(QINDEX_ZERO);
auto ampIter = getStartPtr(qureg);
Expand Down