Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_iallreduce.hpp
Go to the documentation of this file.
1// @HEADER
2// ***********************************************************************
3//
4// Tpetra: Templated Linear Algebra Services Package
5// Copyright (2008) Sandia Corporation
6//
7// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8// the U.S. Government retains certain rights in this software.
9//
10// Redistribution and use in source and binary forms, with or without
11// modification, are permitted provided that the following conditions are
12// met:
13//
14// 1. Redistributions of source code must retain the above copyright
15// notice, this list of conditions and the following disclaimer.
16//
17// 2. Redistributions in binary form must reproduce the above copyright
18// notice, this list of conditions and the following disclaimer in the
19// documentation and/or other materials provided with the distribution.
20//
21// 3. Neither the name of the Corporation nor the names of the
22// contributors may be used to endorse or promote products derived from
23// this software without specific prior written permission.
24//
25// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36//
37// Questions? Contact Michael A. Heroux (maherou@sandia.gov)
38//
39// ************************************************************************
40// @HEADER
41
42#ifndef TPETRA_DETAILS_IALLREDUCE_HPP
43#define TPETRA_DETAILS_IALLREDUCE_HPP
44
60
61#include "TpetraCore_config.h"
62#include "Teuchos_EReductionType.hpp"
63#ifdef HAVE_TPETRACORE_MPI
66#endif // HAVE_TPETRACORE_MPI
67#include "Tpetra_Details_temporaryViewUtils.hpp"
69#include "Kokkos_Core.hpp"
70#include <memory>
71#include <stdexcept>
72#include <type_traits>
73#include <functional>
74
75#ifndef DOXYGEN_SHOULD_SKIP_THIS
76namespace Teuchos {
77 // forward declaration of Comm
78 template<class OrdinalType> class Comm;
79} // namespace Teuchos
80#endif // NOT DOXYGEN_SHOULD_SKIP_THIS
81
82namespace Tpetra {
83namespace Details {
84
85#ifdef HAVE_TPETRACORE_MPI
86std::string getMpiErrorString (const int errCode);
87#endif
88
96public:
98 virtual ~CommRequest () {}
99
104 virtual void wait () {}
105
109 virtual void cancel () {}
110};
111
112// Don't rely on anything in this namespace.
113namespace Impl {
114
116std::shared_ptr<CommRequest>
117emptyCommRequest ();
118
119#ifdef HAVE_TPETRACORE_MPI
120#if MPI_VERSION >= 3
121template<typename InputViewType, typename OutputViewType, typename ResultViewType>
122struct MpiRequest : public CommRequest
123{
124 MpiRequest(const InputViewType& send, const OutputViewType& recv, const ResultViewType& result, MPI_Request req_)
125 : sendBuf(send), recvBuf(recv), resultBuf(result), req(req_)
126 {}
127
128 ~MpiRequest()
129 {
130 //this is a no-op if wait() or cancel() have already been called
131 cancel();
132 }
133
138 void wait () override
139 {
140 if (req != MPI_REQUEST_NULL) {
141 const int err = MPI_Wait (&req, MPI_STATUS_IGNORE);
142 TEUCHOS_TEST_FOR_EXCEPTION
143 (err != MPI_SUCCESS, std::runtime_error,
144 "MpiCommRequest::wait: MPI_Wait failed with error \""
145 << getMpiErrorString (err));
146 // MPI_Wait should set the MPI_Request to MPI_REQUEST_NULL on
147 // success. We'll do it here just to be conservative.
148 req = MPI_REQUEST_NULL;
149 //Since recvBuf contains the result, copy it to the user's resultBuf.
150 Kokkos::deep_copy(resultBuf, recvBuf);
151 }
152 }
153
157 void cancel () override
158 {
159 //BMK: Per https://www.mpi-forum.org/docs/mpi-3.1/mpi31-report/node126.htm,
160 //MPI_Cancel cannot be used for collectives like iallreduce.
161 req = MPI_REQUEST_NULL;
162 }
163
164private:
165 InputViewType sendBuf;
166 OutputViewType recvBuf;
167 ResultViewType resultBuf;
168 //This request is active if and only if req != MPI_REQUEST_NULL.
169 MPI_Request req;
170};
171
174MPI_Request
175iallreduceRaw (const void* sendbuf,
176 void* recvbuf,
177 const int count,
178 MPI_Datatype mpiDatatype,
179 const Teuchos::EReductionType op,
180 MPI_Comm comm);
181#endif
182
184void
185allreduceRaw (const void* sendbuf,
186 void* recvbuf,
187 const int count,
188 MPI_Datatype mpiDatatype,
189 const Teuchos::EReductionType op,
190 MPI_Comm comm);
191
192template<class InputViewType, class OutputViewType>
193std::shared_ptr<CommRequest>
194iallreduceImpl (const InputViewType& sendbuf,
195 const OutputViewType& recvbuf,
196 const ::Teuchos::EReductionType op,
197 const ::Teuchos::Comm<int>& comm)
198{
199 using Packet = typename InputViewType::non_const_value_type;
200 if(comm.getSize() == 1)
201 {
202 Kokkos::deep_copy(recvbuf, sendbuf);
203 return emptyCommRequest();
204 }
205 Packet examplePacket;
206 MPI_Datatype mpiDatatype = sendbuf.extent(0) ?
207 MpiTypeTraits<Packet>::getType (examplePacket) :
208 MPI_BYTE;
209 bool datatypeNeedsFree = MpiTypeTraits<Packet>::needsFree;
210 MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos (comm);
211 //Note BMK: Nonblocking collectives like iallreduce cannot use GPU buffers.
212 //See https://www.open-mpi.org/faq/?category=runcuda#mpi-cuda-support
213 auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
214 auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
215 std::shared_ptr<CommRequest> req;
216 //Next, if input/output alias and comm is an intercomm, make a deep copy of input.
217 //Not possible to do in-place allreduce for intercomm.
218 if(isInterComm(comm) && sendMPI.data() == recvMPI.data())
219 {
220 //Can't do in-place collective on an intercomm,
221 //so use a separate 1D copy as the input.
222 Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing("tempInput"), sendMPI.extent(0));
223 for(size_t i = 0; i < sendMPI.extent(0); i++)
224 tempInput(i) = sendMPI.data()[i];
225#if MPI_VERSION >= 3
226 //MPI 3+: use async allreduce
227 MPI_Request mpiReq = iallreduceRaw((const void*) tempInput.data(), (void*) recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
228 req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(tempInput), decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
229#else
230 //Older MPI: Iallreduce not available. Instead do blocking all-reduce and return empty request.
231 allreduceRaw((const void*) sendMPI.data(), (void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
232 Kokkos::deep_copy(recvbuf, recvMPI);
233 req = emptyCommRequest();
234#endif
235 }
236 else
237 {
238#if MPI_VERSION >= 3
239 //MPI 3+: use async allreduce
240 MPI_Request mpiReq = iallreduceRaw((const void*) sendMPI.data(), (void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
241 req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(sendMPI), decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
242#else
243 //Older MPI: Iallreduce not available. Instead do blocking all-reduce and return empty request.
244 allreduceRaw((const void*) sendMPI.data(), (void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
245 Kokkos::deep_copy(recvbuf, recvMPI);
246 req = emptyCommRequest();
247#endif
248 }
249 if(datatypeNeedsFree)
250 MPI_Type_free(&mpiDatatype);
251 return req;
252}
253
254#else
255
256//No MPI: reduction is always the same as input.
257template<class InputViewType, class OutputViewType>
258std::shared_ptr<CommRequest>
259iallreduceImpl (const InputViewType& sendbuf,
260 const OutputViewType& recvbuf,
261 const ::Teuchos::EReductionType,
262 const ::Teuchos::Comm<int>&)
263{
264 Kokkos::deep_copy(recvbuf, sendbuf);
265 return emptyCommRequest();
266}
267
268#endif // HAVE_TPETRACORE_MPI
269
270} // namespace Impl
271
272//
273// SKIP DOWN TO HERE
274//
275
301template<class InputViewType, class OutputViewType>
302std::shared_ptr<CommRequest>
303iallreduce (const InputViewType& sendbuf,
304 const OutputViewType& recvbuf,
305 const ::Teuchos::EReductionType op,
306 const ::Teuchos::Comm<int>& comm)
307{
308 static_assert (Kokkos::is_view<InputViewType>::value,
309 "InputViewType must be a Kokkos::View specialization.");
310 static_assert (Kokkos::is_view<OutputViewType>::value,
311 "OutputViewType must be a Kokkos::View specialization.");
312 constexpr int rank = static_cast<int> (OutputViewType::rank);
313 static_assert (static_cast<int> (InputViewType::rank) == rank,
314 "InputViewType and OutputViewType must have the same rank.");
315 static_assert (rank == 0 || rank == 1,
316 "InputViewType and OutputViewType must both have "
317 "rank 0 or rank 1.");
318 typedef typename OutputViewType::non_const_value_type packet_type;
319 static_assert (std::is_same<typename OutputViewType::value_type,
320 packet_type>::value,
321 "OutputViewType must be a nonconst Kokkos::View.");
322 static_assert (std::is_same<typename InputViewType::non_const_value_type,
323 packet_type>::value,
324 "InputViewType and OutputViewType must be Views "
325 "whose entries have the same type.");
326 //Make sure layouts are contiguous (don't accept strided 1D view)
327 static_assert (!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
328 "Input/Output views must be contiguous (not LayoutStride)");
329 static_assert (!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
330 "Input/Output views must be contiguous (not LayoutStride)");
331
332 return Impl::iallreduceImpl<InputViewType, OutputViewType> (sendbuf, recvbuf, op, comm);
333}
334
335std::shared_ptr<CommRequest>
336iallreduce (const int localValue,
337 int& globalValue,
338 const ::Teuchos::EReductionType op,
339 const ::Teuchos::Comm<int>& comm);
340
341} // namespace Details
342} // namespace Tpetra
343
344#endif // TPETRA_DETAILS_IALLREDUCE_HPP
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
Add specializations of Teuchos::Details::MpiTypeTraits for Kokkos::complex<float> and Kokkos::complex...
Declaration of Tpetra::Details::extractMpiCommFromTeuchos.
Base class for the request (more or less a future) representing a pending nonblocking MPI operation.
virtual ~CommRequest()
Destructor (virtual for memory safety of derived classes).
virtual void cancel()
Cancel the pending communication request.
virtual void wait()
Wait on this communication request to complete.
Implementation details of Tpetra.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator.
Namespace Tpetra contains the class and methods constituting the Tpetra library.