diff --git a/quest/src/comm/comm_routines.cpp b/quest/src/comm/comm_routines.cpp index cf6956454..60232046f 100644 --- a/quest/src/comm/comm_routines.cpp +++ b/quest/src/comm/comm_routines.cpp @@ -242,6 +242,56 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { } +void exchangeSubBufferChunks(Qureg qureg, const vector& pairRanks, const vector& recvTagBases, qindex sendTagBase, qindex chunkSize) { +#if QUEST_COMPILE_MPI + + if (pairRanks.empty()) + return; + + MPI_Comm mpiComm = comm_getMpiComm(); + + qindex sendInd = getSubBufferSendInd(qureg); + qindex recvInd = getBufferRecvInd(); + + auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(chunkSize); + qindex maxTagBase = sendTagBase; + for (qindex tagBase : recvTagBases) + maxTagBase = std::max(maxTagBase, tagBase); + + qindex numTaggedMessages = numMessages * (maxTagBase + 1); + if (numTaggedMessages > getMaxNumMessages()) + error_commNumMessagesExceedTagMax(); + + qindex numRequests = 2 * numMessages * pairRanks.size(); + vector requests(numRequests, MPI_REQUEST_NULL); + + qindex reqInd = 0; + for (qindex c=0; c<(qindex) pairRanks.size(); c++) { + qindex chunkOffset = c * chunkSize; + + for (qindex m=0; m(recvTagBases[c]*numMessages + m); + int sendTag = static_cast(sendTagBase*numMessages + m); + qindex messageOffset = chunkOffset + m*messageSize; + + MPI_Irecv( + &qureg.cpuCommBuffer[recvInd + messageOffset], + messageSize, MPI_QCOMP, pairRanks[c], recvTag, mpiComm, &requests[reqInd++]); + + MPI_Isend( + &qureg.cpuCommBuffer[sendInd + messageOffset], + messageSize, MPI_QCOMP, pairRanks[c], sendTag, mpiComm, &requests[reqInd++]); + } + } + + MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE); + +#else + error_commButEnvNotDistributed(); +#endif +} + + /* * PRIVATE ASYNC SEND AND RECEIVE @@ -533,6 +583,25 @@ void comm_exchangeSubBuffers(Qureg qureg, qindex numAmps, int pairRank) { } +void comm_exchangeSubBufferChunks(Qureg qureg, const vector& pairRanks, const vector& recvTagBases, qindex sendTagBase, qindex chunkSize) { + + qindex sendInd = getSubBufferSendInd(qureg); + qindex recvInd = getBufferRecvInd(); + qindex totalSize = chunkSize * pairRanks.size(); + + assert_commBoundsAreValid(qureg, sendInd, recvInd, totalSize); + assert_bufferSendRecvDoesNotOverlap(sendInd, recvInd, totalSize); + assert_commQuregIsDistributed(qureg); + if (pairRanks.size() != recvTagBases.size()) + error_commGivenInconsistentNumSubArraysANodes(); + + for (int pairRank : pairRanks) + assert_pairRankIsDistinct(qureg, pairRank); + + exchangeSubBufferChunks(qureg, pairRanks, recvTagBases, sendTagBase, chunkSize); +} + + void comm_asynchSendSubBuffer(Qureg qureg, qindex numElems, int pairRank) { auto [sendInd, recvInd] = getSubBufferSendRecvInds(qureg); diff --git a/quest/src/comm/comm_routines.hpp b/quest/src/comm/comm_routines.hpp index e75e889f6..97fcb08a2 100644 --- a/quest/src/comm/comm_routines.hpp +++ b/quest/src/comm/comm_routines.hpp @@ -31,6 +31,8 @@ void comm_exchangeAmpsToBuffers(Qureg qureg, int pairRank); void comm_exchangeSubBuffers(Qureg qureg, qindex numAmpsAndRecvInd, int pairRank); +void comm_exchangeSubBufferChunks(Qureg qureg, const vector& pairRanks, const vector& recvTagBases, qindex sendTagBase, qindex chunkSize); + void comm_asynchSendSubBuffer(Qureg qureg, qindex numElems, int pairRank); void comm_receiveArrayToBuffer(Qureg qureg, qindex numElems, int pairRank); @@ -81,4 +83,4 @@ vector comm_gatherStringsToRoot(char* localChars, int maxNumLocalCh -#endif // COMM_ROUTINES_HPP \ No newline at end of file +#endif // COMM_ROUTINES_HPP diff --git a/quest/src/core/localiser.cpp b/quest/src/core/localiser.cpp index 83a23b921..63759834d 100644 --- a/quest/src/core/localiser.cpp +++ b/quest/src/core/localiser.cpp @@ -24,8 +24,10 @@ #include "quest/src/core/localiser.hpp" #include "quest/src/core/accelerator.hpp" #include "quest/src/comm/comm_config.hpp" +#include "quest/src/comm/comm_indices.hpp" #include "quest/src/comm/comm_routines.hpp" #include "quest/src/cpu/cpu_config.hpp" +#include "quest/src/cpu/cpu_subroutines.hpp" #include "quest/src/gpu/gpu_config.hpp" #include @@ -893,6 +895,85 @@ void localiser_statevec_anyCtrlSwap(Qureg qureg, ConstList64 ctrls, ConstList64 */ +qindex getBitMaskOfQubitsInPattern(ConstList64 qubits, qindex pattern) { + + qindex mask = 0; + for (size_t i=0; i remotePatterns; + vector pairRanks; + + for (qindex pattern=0; pattern waveRanks; + vector recvTagBases; + + for (qindex c=0; c= 2 && + ctrls.empty() && + qureg.isDistributed && + !qureg.isGpuAccelerated + ) { + List64 prefixInds = lists_getEmptyList64(); + for (int prefixTarg : prefixTargs) + prefixInds.push_back(util_getPrefixInd(prefixTarg, qureg)); + + multiSwapBetweenPrefixAndSuffix(qureg, suffixTargs, prefixInds); + return; + } + + // otherwise, fall back to per-SWAP communication + for (size_t i=0; i qindex cpu_statevec_packAmpsIntoBuffer(Qureg qureg, ConstList64 qubitInds, ConstList64 qubitStates); +void cpu_statevec_packAmpsIntoBufferAtOffset(Qureg qureg, ConstList64 sortedQubits, qindex qubitStateMask, qindex bufferOffset); + +void cpu_statevec_unpackAmpsFromBufferAtOffset(Qureg qureg, ConstList64 sortedQubits, qindex qubitStateMask, qindex bufferOffset); + qindex cpu_statevec_packPairSummedAmpsIntoBuffer(Qureg qureg, int qubit1, int qubit2, int qubit3, int bit2); @@ -202,4 +206,4 @@ void cpu_statevec_initDebugState_sub(Qureg qureg); void cpu_statevec_initUnnormalisedUniformlyRandomPureStateAmps_sub(Qureg qureg); -#endif // CPU_SUBROUTINES_HPP \ No newline at end of file +#endif // CPU_SUBROUTINES_HPP