59 using Tpetra::global_size_t;
60 using Teuchos::ArrayView;
65 typedef Tpetra::Map<LocalOrdinal,GlobalOrdinal,Node> Map;
68 const global_size_t num_global_entries = map.getGlobalNumElements();
69 const size_t num_local_entries = map.getLocalNumElements();
71 ArrayView<const GlobalOrdinal> element_list =
72 map.getLocalElementList();
75 const global_size_t flat_num_global_entries = num_global_entries*block_size;
76 const size_t flat_num_local_entries = num_local_entries * block_size;
78 Array<GlobalOrdinal> flat_element_list(flat_num_local_entries);
79 for (
size_t i=0; i<num_local_entries; ++i)
81 flat_element_list[i*block_size+
j] = element_list[i]*block_size+
j;
85 rcp(
new Map(flat_num_global_entries, flat_element_list(),
86 flat_index_base, map.getComm()));
98 const Tpetra::CrsGraph<LocalOrdinal,GlobalOrdinal,Node>& graph,
99 Teuchos::RCP<
const Tpetra::Map<LocalOrdinal,GlobalOrdinal,Node> >& flat_domain_map,
100 Teuchos::RCP<
const Tpetra::Map<LocalOrdinal,GlobalOrdinal,Node> >& flat_range_map,
102 using Teuchos::ArrayRCP;
106 typedef Tpetra::Map<LocalOrdinal,GlobalOrdinal,Node> Map;
107 typedef Tpetra::CrsGraph<LocalOrdinal,GlobalOrdinal,Node> Graph;
110 if (flat_domain_map == Teuchos::null)
111 flat_domain_map =
create_flat_map(*(graph.getDomainMap()), block_size);
114 if (flat_range_map == Teuchos::null)
118 RCP<const Map> flat_col_map =
124 RCP<const Map> flat_row_map;
125 if (graph.getRangeMap() == graph.getRowMap())
126 flat_row_map = flat_range_map;
131 auto row_offsets = graph.getLocalRowPtrsHost();
132 auto col_indices = graph.getLocalIndicesHost();
133 const size_t num_row = graph.getLocalNumRows();
134 const size_t num_col_indices = col_indices.size();
135 ArrayRCP<size_t> flat_row_offsets(num_row*block_size+1);
136 ArrayRCP<LocalOrdinal> flat_col_indices(num_col_indices * block_size);
137 for (
size_t row=0; row<num_row; ++row) {
138 const size_t row_beg = row_offsets[row];
139 const size_t row_end = row_offsets[row+1];
140 const size_t num_col = row_end - row_beg;
142 const size_t flat_row = row*block_size +
j;
143 const size_t flat_row_beg = row_beg*block_size +
j*num_col;
144 flat_row_offsets[flat_row] = flat_row_beg;
145 for (
size_t entry=0; entry<num_col; ++entry) {
148 flat_col_indices[flat_row_beg+entry] = flat_col;
152 flat_row_offsets[num_row*block_size] = num_col_indices*block_size;
155 RCP<Graph> flat_graph =
156 rcp(
new Graph(flat_row_map, flat_col_map,
157 flat_row_offsets, flat_col_indices));
158 flat_graph->fillComplete(flat_domain_map, flat_range_map);
173 Teuchos::RCP<
const Tpetra::Map<
LocalOrdinal,
GlobalOrdinal,Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_domain_map,
174 Teuchos::RCP<
const Tpetra::Map<
LocalOrdinal,
GlobalOrdinal,Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_range_map,
176 using Teuchos::ArrayRCP;
180 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
181 typedef Tpetra::Map<LocalOrdinal,GlobalOrdinal,Node> Map;
182 typedef Tpetra::CrsGraph<LocalOrdinal,GlobalOrdinal,Node> Graph;
183 typedef typename Graph::local_graph_device_type::row_map_type::non_const_type RowPtrs;
184 typedef typename Graph::local_graph_device_type::entries_type::non_const_type LocalIndices;
187 if (flat_domain_map == Teuchos::null)
188 flat_domain_map =
create_flat_map(*(graph.getDomainMap()), block_size);
191 if (flat_range_map == Teuchos::null)
195 RCP<const Map> flat_col_map =
201 RCP<const Map> flat_row_map;
202 if (graph.getRangeMap() == graph.getRowMap())
203 flat_row_map = flat_range_map;
208 auto row_offsets = graph.getLocalRowPtrsHost();
209 auto col_indices = graph.getLocalIndicesHost();
210 const size_t num_row = graph.getLocalNumRows();
211 const size_t num_col_indices = col_indices.size();
212 RowPtrs flat_row_offsets(
"row_ptrs", num_row*block_size+1);
213 LocalIndices flat_col_indices(
"col_indices", num_col_indices * block_size);
216 for (
size_t row=0; row<num_row; ++row) {
217 const size_t row_beg = row_offsets[row];
218 const size_t row_end = row_offsets[row+1];
219 const size_t num_col = row_end - row_beg;
221 const size_t flat_row = row*block_size +
j;
222 const size_t flat_row_beg = row_beg*block_size +
j*num_col;
223 flat_row_offsets_host[flat_row] = flat_row_beg;
224 for (
size_t entry=0; entry<num_col; ++entry) {
227 flat_col_indices_host[flat_row_beg+entry] = flat_col;
231 flat_row_offsets_host[num_row*block_size] = num_col_indices*block_size;
236 RCP<Graph> flat_graph =
237 rcp(
new Graph(flat_row_map, flat_col_map,
238 flat_row_offsets, flat_col_indices));
239 flat_graph->fillComplete(flat_domain_map, flat_range_map);
254 const Teuchos::RCP<
const Tpetra::Map<LocalOrdinal,GlobalOrdinal,Node> >& flat_map) {
255 using Teuchos::ArrayRCP;
260 typedef typename Storage::value_type BaseScalar;
261 typedef Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> Vector;
262 typedef Tpetra::MultiVector<BaseScalar,LocalOrdinal,GlobalOrdinal,Node> FlatVector;
269 Vector& vec =
const_cast<Vector&
>(vec_const);
272 ArrayRCP<Scalar> vec_vals = vec.get1dViewNonConst();
273 const size_t vec_size = vec_vals.size();
276 BaseScalar *flat_vec_ptr =
277 reinterpret_cast<BaseScalar*
>(vec_vals.getRawPtr());
278 ArrayRCP<BaseScalar> flat_vec_vals =
279 Teuchos::arcp(flat_vec_ptr, 0, vec_size*mp_size,
false);
282 const size_t stride = vec.getStride();
283 const size_t flat_stride = stride * mp_size;
284 const size_t num_vecs = vec.getNumVectors();
285 RCP<FlatVector> flat_vec =
286 rcp(
new FlatVector(flat_map, flat_vec_vals, flat_stride, num_vecs));
319 const Teuchos::RCP<
const Tpetra::Map<LocalOrdinal,GlobalOrdinal,Node> >& flat_map) {
320 using Teuchos::ArrayRCP;
325 typedef typename Storage::value_type BaseScalar;
326 typedef Tpetra::MultiVector<BaseScalar,LocalOrdinal,GlobalOrdinal,Node> FlatVector;
332 ArrayRCP<Scalar> vec_vals = vec.get1dViewNonConst();
333 const size_t vec_size = vec_vals.size();
336 BaseScalar *flat_vec_ptr =
337 reinterpret_cast<BaseScalar*
>(vec_vals.getRawPtr());
338 ArrayRCP<BaseScalar> flat_vec_vals =
339 Teuchos::arcp(flat_vec_ptr, 0, vec_size*mp_size,
false);
342 const size_t stride = vec.getStride();
343 const size_t flat_stride = stride * mp_size;
344 const size_t num_vecs = vec.getNumVectors();
345 RCP<FlatVector> flat_vec =
346 rcp(
new FlatVector(flat_map, flat_vec_vals, flat_stride, num_vecs));
449 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec,
451 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
455 typedef typename Storage::value_type BaseScalar;
456 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
457 typedef Tpetra::MultiVector<BaseScalar,LocalOrdinal,GlobalOrdinal,Node> FlatVector;
458 typedef typename FlatVector::dual_view_type::t_dev flat_view_type;
464 typedef Tpetra::MultiVector<Sacado::MP::Vector<Storage>,
LocalOrdinal,
GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<Device> > mv_type;
465 mv_type& vec_nc =
const_cast<mv_type&
>(vec);
468 flat_view_type flat_vals = vec_nc.getLocalViewDevice(Tpetra::Access::ReadWrite);
471 RCP<FlatVector> flat_vec = rcp(
new FlatVector(flat_map, flat_vals));
486 Kokkos::Compat::KokkosDeviceWrapperNode<Device> >& vec,
488 Kokkos::Compat::KokkosDeviceWrapperNode<Device> > >& flat_map) {
492 typedef typename Storage::value_type BaseScalar;
493 typedef Kokkos::Compat::KokkosDeviceWrapperNode<Device>
Node;
494 typedef Tpetra::MultiVector<BaseScalar,LocalOrdinal,GlobalOrdinal,Node> FlatVector;
495 typedef typename FlatVector::dual_view_type::t_dev flat_view_type;
498 flat_view_type flat_vals = vec.getLocalViewDevice(Tpetra::Access::ReadWrite);
501 RCP<FlatVector> flat_vec = rcp(
new FlatVector(flat_map, flat_vals));
516 const Teuchos::RCP<
const Tpetra::CrsGraph<LocalOrdinal,GlobalOrdinal,Node> >& flat_graph,
518 using Teuchos::ArrayView;
519 using Teuchos::Array;
524 typedef typename Storage::value_type BaseScalar;
525 typedef Tpetra::CrsMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node> Matrix;
526 typedef Tpetra::CrsMatrix<BaseScalar,LocalOrdinal,GlobalOrdinal,Node> FlatMatrix;
529 RCP<FlatMatrix> flat_mat = rcp(
new FlatMatrix(flat_graph));
532 const size_t num_rows = mat.getLocalNumRows();
533 const size_t max_cols = mat.getLocalMaxNumRowEntries();
534 typename Matrix::local_inds_host_view_type indices, flat_indices;
535 typename Matrix::values_host_view_type values;
536 Array<BaseScalar> flat_values(max_cols);
537 for (
size_t row=0; row<num_rows; ++row) {
538 mat.getLocalRowView(row, indices, values);
539 const size_t num_col = mat.getNumEntriesInLocalRow(row);
542 for (
size_t j=0;
j<num_col; ++
j)
543 flat_values[
j] = values[
j].coeff(i);
544 flat_graph->getLocalRowView(flat_row, flat_indices);
545 flat_mat->replaceLocalValues(flat_row, Kokkos::Compat::getConstArrayView(flat_indices),
546 flat_values(0, num_col));
549 flat_mat->fillComplete(flat_graph->getDomainMap(),
550 flat_graph->getRangeMap());
Teuchos::RCP< const Tpetra::MultiVector< typename Storage::value_type, LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode< Device > > > create_flat_vector_view(const Tpetra::MultiVector< Sacado::UQ::PCE< Storage >, LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode< Device > > &vec, const Teuchos::RCP< const Tpetra::Map< LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode< Device > > > &flat_map)
Teuchos::RCP< Tpetra::CrsMatrix< typename Storage::value_type, LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode< Device > > > create_flat_matrix(const Tpetra::CrsMatrix< Sacado::UQ::PCE< Storage >, LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode< Device > > &mat, const Teuchos::RCP< const Tpetra::CrsGraph< LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode< Device > > > &flat_graph, const Teuchos::RCP< const Tpetra::CrsGraph< LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode< Device > > > &cijk_graph, const CijkType &cijk_dev)
Teuchos::RCP< Tpetra::CrsGraph< LocalOrdinal, GlobalOrdinal, Node > > create_flat_mp_graph(const Tpetra::CrsGraph< LocalOrdinal, GlobalOrdinal, Node > &graph, Teuchos::RCP< const Tpetra::Map< LocalOrdinal, GlobalOrdinal, Node > > &flat_domain_map, Teuchos::RCP< const Tpetra::Map< LocalOrdinal, GlobalOrdinal, Node > > &flat_range_map, const LocalOrdinal block_size)