Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_TsqrAdaptor.hpp
Go to the documentation of this file.
1// @HEADER
2// ***********************************************************************
3//
4// Tpetra: Templated Linear Algebra Services Package
5// Copyright (2008) Sandia Corporation
6//
7// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8// the U.S. Government retains certain rights in this software.
9//
10// Redistribution and use in source and binary forms, with or without
11// modification, are permitted provided that the following conditions are
12// met:
13//
14// 1. Redistributions of source code must retain the above copyright
15// notice, this list of conditions and the following disclaimer.
16//
17// 2. Redistributions in binary form must reproduce the above copyright
18// notice, this list of conditions and the following disclaimer in the
19// documentation and/or other materials provided with the distribution.
20//
21// 3. Neither the name of the Corporation nor the names of the
22// contributors may be used to endorse or promote products derived from
23// this software without specific prior written permission.
24//
25// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36//
37// ************************************************************************
38// @HEADER
39
40#ifndef TPETRA_TSQRADAPTOR_HPP
41#define TPETRA_TSQRADAPTOR_HPP
42
46
47#include "Tpetra_ConfigDefs.hpp"
48
49#ifdef HAVE_TPETRA_TSQR
50# include "Tsqr_NodeTsqrFactory.hpp" // create intranode TSQR object
51# include "Tsqr.hpp" // full (internode + intranode) TSQR
52# include "Tsqr_DistTsqr.hpp" // internode TSQR
53// Subclass of TSQR::MessengerBase, implemented using Teuchos
54// communicator template helper functions
55# include "Tsqr_TeuchosMessenger.hpp"
56# include "Tpetra_MultiVector.hpp"
57# include "Teuchos_ParameterListAcceptorDefaultBase.hpp"
58# include <stdexcept>
59
60namespace Tpetra {
61
83 template<class MV>
84 class TsqrAdaptor : public Teuchos::ParameterListAcceptorDefaultBase {
85 public:
86 using scalar_type = typename MV::scalar_type;
87 using ordinal_type = typename MV::local_ordinal_type;
88 using dense_matrix_type =
89 Teuchos::SerialDenseMatrix<ordinal_type, scalar_type>;
90 using magnitude_type =
91 typename Teuchos::ScalarTraits<scalar_type>::magnitudeType;
92
93 private:
94 using node_tsqr_factory_type =
95 TSQR::NodeTsqrFactory<scalar_type, ordinal_type,
96 typename MV::device_type>;
97 using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
98 using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
99 using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
100
101 TSQR::MatView<ordinal_type, scalar_type>
102 get_mat_view(MV& X)
103 {
104 TEUCHOS_ASSERT( ! tsqr_.is_null() );
105 // FIXME (mfh 18 Oct 2010, 22 Dec 2019) Check Teuchos::Comm<int>
106 // object in Q to make sure it is the same communicator as the
107 // one we are using in our dist_tsqr_type implementation.
108
109 const ordinal_type lclNumRows(X.getLocalLength());
110 const ordinal_type numCols(X.getNumVectors());
111 scalar_type* X_ptr = nullptr;
112 // LAPACK and BLAS functions require "LDA" >= 1, even if the
113 // corresponding matrix dimension is zero.
114 ordinal_type X_stride = 1;
115 if(tsqr_->wants_device_memory()) {
116 auto X_view = X.getLocalViewDevice(Access::ReadWrite);
117 X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
118 X_stride = static_cast<ordinal_type>(X_view.stride(1));
119 if(X_stride == 0) {
120 X_stride = ordinal_type(1); // see note above
121 }
122 }
123 else {
124 auto X_view = X.getLocalViewHost(Access::ReadWrite);
125 X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
126 X_stride = static_cast<ordinal_type>(X_view.stride(1));
127 if(X_stride == 0) {
128 X_stride = ordinal_type(1); // see note above
129 }
130 }
131 using mat_view_type = TSQR::MatView<ordinal_type, scalar_type>;
132 return mat_view_type(lclNumRows, numCols, X_ptr, X_stride);
133 }
134
135 public:
142 TsqrAdaptor(const Teuchos::RCP<Teuchos::ParameterList>& plist) :
143 nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
144 distTsqr_(new dist_tsqr_type),
145 tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
146 {
147 setParameterList(plist);
148 }
149
151 TsqrAdaptor() :
152 nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
153 distTsqr_(new dist_tsqr_type),
154 tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
155 {
156 setParameterList(Teuchos::null);
157 }
158
160 Teuchos::RCP<const Teuchos::ParameterList>
161 getValidParameters() const
162 {
163 if(defaultParams_.is_null()) {
164 auto params = Teuchos::parameterList("TSQR implementation");
165 params->set("NodeTsqr", *(nodeTsqr_->getValidParameters()));
166 params->set("DistTsqr", *(distTsqr_->getValidParameters()));
167 defaultParams_ = params;
168 }
169 return defaultParams_;
170 }
171
197 void
198 setParameterList(const Teuchos::RCP<Teuchos::ParameterList>& plist)
199 {
200 auto params = plist.is_null() ?
201 Teuchos::parameterList(*getValidParameters()) : plist;
202 using Teuchos::sublist;
203 nodeTsqr_->setParameterList(sublist(params, "NodeTsqr"));
204 distTsqr_->setParameterList(sublist(params, "DistTsqr"));
205
206 this->setMyParamList(params);
207 }
208
230 void
231 factorExplicit(MV& A,
232 MV& Q,
233 dense_matrix_type& R,
234 const bool forceNonnegativeDiagonal=false)
235 {
236 TEUCHOS_TEST_FOR_EXCEPTION
237 (! A.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
238 "factorExplicit: Input MultiVector A must have constant stride.");
239 TEUCHOS_TEST_FOR_EXCEPTION
240 (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
241 "factorExplicit: Input MultiVector Q must have constant stride.");
242 prepareTsqr(Q); // Finish initializing TSQR.
243 TEUCHOS_ASSERT( ! tsqr_.is_null() );
244
245 auto A_view = get_mat_view(A);
246 auto Q_view = get_mat_view(Q);
247 constexpr bool contiguousCacheBlocks = false;
248 tsqr_->factorExplicitRaw(A_view.extent(0),
249 A_view.extent(1),
250 A_view.data(), A_view.stride(1),
251 Q_view.data(), Q_view.stride(1),
252 R.values(), R.stride(),
253 contiguousCacheBlocks,
254 forceNonnegativeDiagonal);
255 }
256
287 int
288 revealRank(MV& Q,
289 dense_matrix_type& R,
290 const magnitude_type& tol)
291 {
292 TEUCHOS_TEST_FOR_EXCEPTION
293 (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
294 "revealRank: Input MultiVector Q must have constant stride.");
295 prepareTsqr(Q); // Finish initializing TSQR.
296
297 auto Q_view = get_mat_view(Q);
298 constexpr bool contiguousCacheBlocks = false;
299 return tsqr_->revealRankRaw(Q_view.extent(0),
300 Q_view.extent(1),
301 Q_view.data(), Q_view.stride(1),
302 R.values(), R.stride(),
303 tol, contiguousCacheBlocks);
304 }
305
306 private:
308 Teuchos::RCP<node_tsqr_type> nodeTsqr_;
309
311 Teuchos::RCP<dist_tsqr_type> distTsqr_;
312
314 Teuchos::RCP<tsqr_type> tsqr_;
315
317 mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
318
320 bool ready_ = false;
321
342 void
343 prepareTsqr(const MV& mv)
344 {
345 if(! ready_) {
346 prepareDistTsqr(mv);
347 ready_ = true;
348 }
349 }
350
357 void
358 prepareDistTsqr(const MV& mv)
359 {
360 using Teuchos::RCP;
361 using Teuchos::rcp_implicit_cast;
362 using mess_type = TSQR::TeuchosMessenger<scalar_type>;
363 using base_mess_type = TSQR::MessengerBase<scalar_type>;
364
365 auto comm = mv.getMap()->getComm();
366 RCP<mess_type> mess(new mess_type(comm));
367 auto messBase = rcp_implicit_cast<base_mess_type>(mess);
368 distTsqr_->init(messBase);
369 }
370 };
371
372} // namespace Tpetra
373
374#endif // HAVE_TPETRA_TSQR
375
376#endif // TPETRA_TSQRADAPTOR_HPP
Namespace Tpetra contains the class and methods constituting the Tpetra library.