47 #ifndef MUELU_AGGREGATES_KOKKOS_DEF_HPP 48 #define MUELU_AGGREGATES_KOKKOS_DEF_HPP 50 #include <Xpetra_Map.hpp> 51 #include <Xpetra_Vector.hpp> 52 #include <Xpetra_MultiVectorFactory.hpp> 53 #include <Xpetra_VectorFactory.hpp> 55 #include "MueLu_LWGraph_kokkos.hpp" 61 template <
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
// Constructor taking an LWGraph_kokkos by value.
// Visible statements: vertex2AggId_ and procWinner_ are both built with
// LOVectorFactory::Build on graph.GetImportMap(), isRoot_ is filled with
// false, and aggregatesIncludeGhosts_ is set to true.
// NOTE(review): the embedded listing numbers jump (63 -> 66 -> 69 -> 73 -> 76),
// so statements are missing from this extract (for instance, no allocation of
// isRoot_ is visible here) -- confirm against the complete file.
62 Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::
63 Aggregates_kokkos(LWGraph_kokkos graph) {
// Both vectors are laid out on the graph's import map.
66 vertex2AggId_ = LOVectorFactory::Build(graph.GetImportMap());
69 procWinner_ = LOVectorFactory::Build(graph.GetImportMap());
// Set every entry of isRoot_ to false.
73 Kokkos::deep_copy(isRoot_,
false);
76 aggregatesIncludeGhosts_ =
true;
79 template <
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
// Constructor taking the map directly (const RCP<const Map>&).
// Visible statements mirror the LWGraph_kokkos overload, except that
// vertex2AggId_ and procWinner_ are built on `map` itself: isRoot_ is filled
// with false and aggregatesIncludeGhosts_ is set to true.
// NOTE(review): the embedded listing numbers jump (81 -> 84 -> 87 -> 91 -> 94),
// so statements are missing from this extract (no allocation of isRoot_ is
// visible here) -- confirm against the complete file.
80 Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::
81 Aggregates_kokkos(
const RCP<const Map>& map) {
84 vertex2AggId_ = LOVectorFactory::Build(map);
87 procWinner_ = LOVectorFactory::Build(map);
// Set every entry of isRoot_ to false.
91 Kokkos::deep_copy(isRoot_,
false);
94 aggregatesIncludeGhosts_ =
true;
97 template <
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
// Returns the number of locally owned vertices in each aggregate.
//
// forceRecompute: when false and aggregateSizes_ is non-empty, the stored
//                 aggregateSizes_ is returned without recomputing.
//
// Otherwise a view of length numAggregates_ is allocated and, for every index
// i in [0, procWinner.size()), the entry for aggregate vertex2AggId(i, 0) is
// incremented when procWinner(i, 0) equals this process's rank. The result is
// stored in aggregateSizes_ and returned.
// NOTE(review): this is a const member that assigns aggregateSizes_, so that
// member is presumably declared mutable -- confirm in the class declaration.
// NOTE(review): closing braces (of the cache check, the lambda, and the
// function) are not visible in this extract.
98 typename Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::aggregates_sizes_type::const_type
99 Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::ComputeAggregateSizes(
bool forceRecompute)
const {
// Reuse the stored sizes unless the caller asks for a recompute.
100 if (aggregateSizes_.size() && !forceRecompute) {
101 return aggregateSizes_;
105 aggregates_sizes_type aggregateSizes(
"aggregates", numAggregates_);
// Rank of this process; used below to count only vertices this rank won.
107 int myPID = GetMap()->getComm()->getRank();
109 auto vertex2AggId = vertex2AggId_->getDeviceLocalView(Xpetra::Access::ReadOnly);
110 auto procWinner = procWinner_ ->getDeviceLocalView(Xpetra::Access::ReadOnly);
// View type with the Kokkos::Atomic trait appended, initialized from
// aggregateSizes; the increments in the kernel go through this view because
// several vertices can map to the same aggregate.
112 typename AppendTrait<decltype(aggregateSizes_), Kokkos::Atomic>::type aggregateSizesAtomic = aggregateSizes;
113 Kokkos::parallel_for(
"MueLu:Aggregates:ComputeAggregateSizes:for", range_type(0,procWinner.size()),
114 KOKKOS_LAMBDA(
const LO i) {
115 if (procWinner(i, 0) == myPID)
116 aggregateSizesAtomic(vertex2AggId(i, 0))++;
// Store the freshly computed sizes for later calls.
119 aggregateSizes_ = aggregateSizes;
121 return aggregateSizes;
126 template <
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
// Builds the aggregates graph (a local_graph_type assembled from `cols` and
// `rows`) and stores it in graph_.
//
// Visible steps:
//   1. Compare graph_.numRows() with numAggregates_ (the branch body is not
//      visible; presumably it returns the already-built graph_ -- confirm).
//   2. Fetch read-only device views of vertex2AggId_ / procWinner_ and the
//      per-aggregate sizes from ComputeAggregateSizes().
//   3. Allocate `rows` with numAggregates+1 entries and run a parallel_scan
//      over the aggregates (scan body not visible in this extract).
//   4. Copy `rows` into `offsets`, which the column-fill kernel then advances.
//   5. Read numNNZ back through a host copy and allocate `cols` with numNNZ
//      entries.
//   6. For every index i whose procWinner(i, 0) equals this rank, atomically
//      fetch-and-increment offsets(vertex2AggId(i, 0)) to obtain `idx`.
//   7. Raise Exceptions::RuntimeError through TEUCHOS_TEST_FOR_EXCEPTION when
//      realnnz != numNNZ, then assign graph_.
// NOTE(review): the embedded listing numbers skip repeatedly (135 -> 138,
// 147 -> 153, 156 -> 162, 171 -> 176, 179 -> end), so the declarations of
// numNNZ / numNNZ_host / numNNZ_device / realnnz, the use of `idx`, and the
// return statements are missing from this extract.
// NOTE(review): this const member assigns graph_, so graph_ is presumably
// declared mutable -- confirm in the class declaration.
127 typename Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::local_graph_type
128 Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::GetGraph()
const {
129 using row_map_type =
typename local_graph_type::row_map_type;
130 using entries_type =
typename local_graph_type::entries_type;
131 using size_type =
typename local_graph_type::size_type;
// Local copy of the member so the kernels below do not reference `this`
// (presumed intent -- confirm).
133 auto numAggregates = numAggregates_;
135 if (static_cast<LO>(graph_.numRows()) == numAggregates)
138 auto vertex2AggId = vertex2AggId_->getDeviceLocalView(Xpetra::Access::ReadOnly);
139 auto procWinner = procWinner_ ->getDeviceLocalView(Xpetra::Access::ReadOnly);
140 auto sizes = ComputeAggregateSizes();
// One entry per aggregate plus one extra trailing entry.
143 typename row_map_type::non_const_type rows(
"Agg_rows", numAggregates+1);
146 Kokkos::parallel_scan(
"MueLu:Aggregates:GetGraph:compute_rows", range_type(0, numAggregates),
147 KOKKOS_LAMBDA(
const LO i, LO& update,
const bool& final_pass) {
// `offsets` starts as a copy of `rows`; it is the array modified by the
// atomic fetch-add below, leaving `rows` itself untouched by that kernel.
153 decltype(rows) offsets(Kokkos::ViewAllocateWithoutInitializing(
"Agg_offsets"), numAggregates+1);
154 Kokkos::deep_copy(offsets, rows);
156 int myPID = GetMap()->getComm()->getRank();
// Bring the device-side nonzero count to the host.
162 Kokkos::deep_copy(numNNZ_host, numNNZ_device);
163 numNNZ = numNNZ_host();
165 typename entries_type::non_const_type cols(Kokkos::ViewAllocateWithoutInitializing(
"Agg_cols"), numNNZ);
167 Kokkos::parallel_reduce(
"MueLu:Aggregates:GetGraph:compute_cols", range_type(0, procWinner.size()),
168 KOKKOS_LAMBDA(
const LO i,
size_t& nnz) {
169 if (procWinner(i, 0) == myPID) {
// Element type of `offsets`, used so the increment has a matching type.
170 typedef typename std::remove_reference< decltype( offsets(0) ) >::type atomic_incr_type;
// Atomically add 1 to this aggregate's offset and keep the value returned
// by the fetch-add in idx.
171 auto idx = Kokkos::atomic_fetch_add( &offsets(vertex2AggId(i,0)), atomic_incr_type(1));
// Consistency check between the reduced count and the expected count.
176 TEUCHOS_TEST_FOR_EXCEPTION(realnnz != numNNZ, Exceptions::RuntimeError,
177 "MueLu: Internal error: Something is wrong with aggregates graph construction: numNNZ = " << numNNZ <<
" != " << realnnz <<
" = realnnz");
179 graph_ = local_graph_type(cols, rows);
184 template <
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
// Returns a std::string describing this object.
// NOTE(review): only the signature is visible in this extract; the body
// (listing lines 186-187) is missing, so the exact text returned cannot be
// stated from here.
185 std::string Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::description()
const {
189 template <
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
// Prints information about the aggregates.
//
// out:       Teuchos::FancyOStream passed in by the caller.
// verbLevel: requested verbosity level.
//
// The only visible statement writes "Global number of aggregates: " followed
// by GetNumGlobalAggregates() and std::endl to out0.
// NOTE(review): the declaration of out0 and any use of `out` / `verbLevel`
// (listing lines 191-193) are missing from this extract -- confirm how out0
// relates to `out` in the complete file.
190 void Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::print(Teuchos::FancyOStream& out,
const Teuchos::EVerbosityLevel verbLevel)
const {
194 out0 <<
"Global number of aggregates: " << GetNumGlobalAggregates() << std::endl;
197 template <
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
// Returns the global aggregate count as a GlobalOrdinal.
// The local count from GetNumAggregates() is cast to GO and passed to
// MueLu_sumAll together with the communicator of vertex2AggId_'s map; the
// value written to nGlobalAggregates is returned.
// NOTE(review): MueLu_sumAll is a macro defined elsewhere; presumably it is a
// sum reduction over all ranks of the communicator -- confirm.
198 GlobalOrdinal Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::GetNumGlobalAggregates()
const {
199 LO nAggregates = GetNumAggregates();
// Left uninitialized here; filled in by MueLu_sumAll on the next line.
200 GO nGlobalAggregates;
201 MueLu_sumAll(vertex2AggId_->getMap()->getComm(), (GO)nAggregates, nGlobalAggregates);
202 return nGlobalAggregates;
205 template <
class LocalOrdinal,
class GlobalOrdinal,
class DeviceType>
// Returns the map of vertex2AggId_ (the result of vertex2AggId_->getMap()) as
// a const RCP<const Xpetra::Map<...>>.
206 const RCP<const Xpetra::Map<LocalOrdinal,GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType>> >
207 Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType>>::GetMap()
const {
208 return vertex2AggId_->getMap();
213 #endif // MUELU_AGGREGATES_KOKKOS_DEF_HPP
#define MueLu_sumAll(rcpComm, in, out)
std::string toString(const T &what)
Little helper function to convert non-string types to strings.
Namespace for MueLu classes and methods.
MueLu::DefaultGlobalOrdinal GlobalOrdinal
#define MUELU_UNAGGREGATED
#define MUELU_DESCRIBE
Helper macro for implementing Describable::describe() for BaseClass objects.
virtual std::string description() const
Return a simple one-line description of this object.