42#ifndef TPETRA_DETAILS_IALLREDUCE_HPP
43#define TPETRA_DETAILS_IALLREDUCE_HPP
61#include "TpetraCore_config.h"
62#include "Teuchos_EReductionType.hpp"
63#ifdef HAVE_TPETRACORE_MPI
67#include "Tpetra_Details_temporaryViewUtils.hpp"
69#include "Kokkos_Core.hpp"
75#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Forward declaration of the Comm class template, parameterized on an ordinal
// type. The enclosing namespace is not visible in this listing; presumably
// Teuchos, given the ::Teuchos::Comm<int> parameters used later - confirm.
78 template<
class OrdinalType>
class Comm;
85#ifdef HAVE_TPETRACORE_MPI
// Declaration only (no body in this file view): takes an MPI error code and
// returns a std::string. wait() further down streams the returned string into
// the exception message it raises when MPI_Wait fails.
86std::string getMpiErrorString (
const int errCode);
116std::shared_ptr<CommRequest>
119#ifdef HAVE_TPETRACORE_MPI
121template<
typename InputViewType,
typename OutputViewType,
typename ResultViewType>
// Completes the pending request, if there is one.
//
// When req is not MPI_REQUEST_NULL, calls MPI_Wait on it (status ignored).
// If MPI_Wait returns anything other than MPI_SUCCESS,
// TEUCHOS_TEST_FOR_EXCEPTION is invoked with std::runtime_error and a message
// that includes getMpiErrorString(err).
138 void wait ()
override
140 if (req != MPI_REQUEST_NULL) {
141 const int err = MPI_Wait (&req, MPI_STATUS_IGNORE);
142 TEUCHOS_TEST_FOR_EXCEPTION
143 (err != MPI_SUCCESS, std::runtime_error,
// NOTE(review): the message opens an escaped double quote before the error
// string, and nothing visible here appends a closing quote afterwards.
// NOTE(review): the message names "MpiCommRequest" while the class is
// instantiated elsewhere in this file as MpiRequest - confirm which is meant.
144 "MpiCommRequest::wait: MPI_Wait failed with error \""
145 << getMpiErrorString (err));
// Reset the handle so that a repeated wait() skips the MPI_Wait branch.
148 req = MPI_REQUEST_NULL;
// Copy the contents of recvBuf into resultBuf. Whether this copy sits inside
// the branch above cannot be determined from this listing (the closing braces
// are missing) - confirm against the real source.
150 Kokkos::deep_copy(resultBuf, recvBuf);
// Drops the pending request by resetting req to MPI_REQUEST_NULL.
// No MPI call (such as MPI_Cancel or MPI_Request_free) appears in the lines
// visible here; after this, wait() will not enter its MPI_Wait branch.
157 void cancel ()
override
161 req = MPI_REQUEST_NULL;
// Buffers held by the request object.
// sendBuf is not read by any code visible in this listing; presumably it is
// kept so the send data stays alive while the operation is pending - confirm.
165 InputViewType sendBuf;
// Source of the Kokkos::deep_copy performed in wait().
166 OutputViewType recvBuf;
// Destination of the Kokkos::deep_copy performed in wait().
167 ResultViewType resultBuf;
175iallreduceRaw (
const void* sendbuf,
178 MPI_Datatype mpiDatatype,
179 const Teuchos::EReductionType op,
185allreduceRaw (
const void* sendbuf,
188 MPI_Datatype mpiDatatype,
189 const Teuchos::EReductionType op,
// All-reduce of sendbuf into recvbuf over comm using reduction op, returning a
// CommRequest handle for the operation.
//
// Parameters:
//   sendbuf - input view; its non_const_value_type is the element type (Packet).
//   recvbuf - output view that receives the reduced values.
//   op      - Teuchos reduction type, forwarded to iallreduceRaw / allreduceRaw.
//   comm    - communicator; a size of 1 short-circuits to a plain copy.
//
// NOTE(review): this listing is missing lines (braces, and presumably
// preprocessor conditionals and else-branches). The nonblocking and blocking
// sections below appear back to back; presumably they are alternatives chosen
// at compile time - confirm against the real source. The final return
// statement is also not visible here.
192template<
class InputViewType,
class OutputViewType>
193std::shared_ptr<CommRequest>
194iallreduceImpl (
const InputViewType& sendbuf,
195 const OutputViewType& recvbuf,
196 const ::Teuchos::EReductionType op,
197 const ::Teuchos::Comm<int>& comm)
199 using Packet =
typename InputViewType::non_const_value_type;
// Single-process communicator: copy input to output and hand back the
// request produced by emptyCommRequest().
200 if(comm.getSize() == 1)
202 Kokkos::deep_copy(recvbuf, sendbuf);
203 return emptyCommRequest();
// examplePacket is only used as the argument to MpiTypeTraits::getType.
205 Packet examplePacket;
// The datatype comes from MpiTypeTraits when sendbuf.extent(0) is nonzero;
// the alternative operand of this ternary is missing from the listing.
206 MPI_Datatype mpiDatatype = sendbuf.extent(0) ?
207 MpiTypeTraits<Packet>::getType (examplePacket) :
// NOTE(review): needsFree is read from the traits regardless of which ternary
// operand was chosen above; confirm that freeing is valid when the
// alternative datatype was selected.
209 bool datatypeNeedsFree = MpiTypeTraits<Packet>::needsFree;
210 MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos (comm);
// Obtain the views whose data() pointers are passed to the MPI wrappers.
// Presumably toMPISafe yields buffers MPI may access directly; the meaning of
// the 'false' template argument is not visible here - confirm.
213 auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
214 auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
215 std::shared_ptr<CommRequest> req;
// Special case: comm is an intercommunicator and the send and receive views
// share the same data pointer.
218 if(
isInterComm(comm) && sendMPI.data() == recvMPI.data())
// Make a separate, uninitialized host copy of the input (tempInput) and fill
// it element by element from sendMPI.
222 Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing(
"tempInput"), sendMPI.extent(0));
223 for(
size_t i = 0; i < sendMPI.extent(0); i++)
224 tempInput(i) = sendMPI.data()[i];
// Nonblocking section: start the reduction from tempInput into recvMPI and
// wrap the MPI_Request, together with the buffers and the caller's recvbuf,
// in an MpiRequest.
227 MPI_Request mpiReq = iallreduceRaw((
const void*) tempInput.data(), (
void*) recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
228 req = std::shared_ptr<CommRequest>(
new MpiRequest<
decltype(tempInput),
decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
// Blocking section: reduce immediately, copy the result to recvbuf, and use
// emptyCommRequest() as the returned handle.
// NOTE(review): this call passes sendMPI rather than tempInput, although it
// appears to belong to the branch entered when sendMPI and recvMPI share a
// data pointer - confirm whether tempInput was intended.
231 allreduceRaw((
const void*) sendMPI.data(), (
void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
232 Kokkos::deep_copy(recvbuf, recvMPI);
233 req = emptyCommRequest();
// General case (presumably the else-branch of the check above - confirm):
// nonblocking section operating directly on sendMPI / recvMPI.
240 MPI_Request mpiReq = iallreduceRaw((
const void*) sendMPI.data(), (
void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
241 req = std::shared_ptr<CommRequest>(
new MpiRequest<
decltype(sendMPI),
decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
// General case, blocking section: same steps as the blocking section above.
244 allreduceRaw((
const void*) sendMPI.data(), (
void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
245 Kokkos::deep_copy(recvbuf, recvMPI);
246 req = emptyCommRequest();
// Release the datatype when the traits say it must be freed.
// NOTE(review): in the nonblocking sections the request may still be pending
// at this point; confirm against the MPI_Type_free documentation that
// in-flight operations are unaffected.
249 if(datatypeNeedsFree)
250 MPI_Type_free(&mpiDatatype);
// Variant of iallreduceImpl that ignores the reduction type and the
// communicator (both parameters are unnamed): it copies sendbuf into recvbuf
// with Kokkos::deep_copy and returns the request from emptyCommRequest().
// Presumably this is the counterpart used when MPI is not enabled; the
// preprocessor lines that select it are not visible in this listing - confirm.
257template<
class InputViewType,
class OutputViewType>
258std::shared_ptr<CommRequest>
259iallreduceImpl (
const InputViewType& sendbuf,
260 const OutputViewType& recvbuf,
261 const ::Teuchos::EReductionType,
262 const ::Teuchos::Comm<int>&)
264 Kokkos::deep_copy(recvbuf, sendbuf);
265 return emptyCommRequest();
// Compile-time-checked front end that forwards to Impl::iallreduceImpl and
// returns its CommRequest handle.
//
// NOTE(review): the function-name line and the sendbuf / recvbuf parameter
// lines are missing from this listing (presumably this is iallreduce, given
// the header guard and the Impl call - confirm). Two static_assert conditions
// below are also truncated.
301template<
class InputViewType,
class OutputViewType>
302std::shared_ptr<CommRequest>
305 const ::Teuchos::EReductionType op,
306 const ::Teuchos::Comm<int>& comm)
// Both template arguments must be Kokkos::View specializations.
308 static_assert (Kokkos::is_view<InputViewType>::value,
309 "InputViewType must be a Kokkos::View specialization.");
310 static_assert (Kokkos::is_view<OutputViewType>::value,
311 "OutputViewType must be a Kokkos::View specialization.");
// The two views must have equal rank, and that rank must be 0 or 1.
312 constexpr int rank =
static_cast<int> (OutputViewType::rank);
313 static_assert (
static_cast<int> (InputViewType::rank) ==
rank,
314 "InputViewType and OutputViewType must have the same rank.");
315 static_assert (
rank == 0 ||
rank == 1,
316 "InputViewType and OutputViewType must both have "
317 "rank 0 or rank 1.");
318 typedef typename OutputViewType::non_const_value_type packet_type;
// Per its message, the output view must be nonconst; the second std::is_same
// operand is missing from this listing (presumably packet_type - confirm).
319 static_assert (std::is_same<
typename OutputViewType::value_type,
321 "OutputViewType must be a nonconst Kokkos::View.");
// Per its message, input and output entry types must match; the second
// std::is_same operand is likewise missing from this listing.
322 static_assert (std::is_same<
typename InputViewType::non_const_value_type,
324 "InputViewType and OutputViewType must be Views "
325 "whose entries have the same type.");
// Kokkos::LayoutStride is rejected for both views.
327 static_assert (!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
328 "Input/Output views must be contiguous (not LayoutStride)");
329 static_assert (!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
330 "Input/Output views must be contiguous (not LayoutStride)");
// All checks are compile-time; at run time this only forwards the arguments.
332 return Impl::iallreduceImpl<InputViewType, OutputViewType> (
sendbuf,
recvbuf, op, comm);
335std::shared_ptr<CommRequest>
338 const ::Teuchos::EReductionType op,
339 const ::Teuchos::Comm<int>& comm);
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
Add specializations of Teuchos::Details::MpiTypeTraits for Kokkos::complex<float> and Kokkos::complex<double>.
Struct that holds views of the contents of a CrsMatrix.
Base class for the request (more or less a future) representing a pending nonblocking MPI operation.
virtual ~CommRequest()
Destructor (virtual for memory safety of derived classes).
virtual void cancel()
Cancel the pending communication request.
virtual void wait()
Wait on this communication request to complete.
Implementation details of Tpetra.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator.
Namespace Tpetra contains the class and methods constituting the Tpetra library.