45 #ifndef KOKKOS_WORKGRAPHPOLICY_HPP 46 #define KOKKOS_WORKGRAPHPOLICY_HPP 48 #include <impl/Kokkos_AnalyzePolicy.hpp> 49 #include <Kokkos_Crs.hpp> 54 template <
class functor_type,
class execution_space,
class... policy_args>
// Execution policy for work expressed as a dependence graph: each work
// item may run only after all work items it depends on have completed.
template <class... Properties>
class WorkGraphPolicy : public Kokkos::Impl::PolicyTraits<Properties...> {
  // Aliases required by the Kokkos execution-policy concept.
  using execution_policy = WorkGraphPolicy<Properties...>;
  using self_type        = WorkGraphPolicy<Properties...>;
  using traits           = Kokkos::Impl::PolicyTraits<Properties...>;
  using index_type       = typename traits::index_type;
  // A "member" of this policy is simply a work-item index.
  using member_type      = index_type;
  using execution_space  = typename traits::execution_space;
  using memory_space     = typename execution_space::memory_space;

  // The work dependence graph in compressed row storage:
  // row_map(w)..row_map(w+1) indexes the entries that depend on w.
  graph_type const m_graph;
// Insert ready work item 'w' into the ready queue.
//
// m_queue layout, with N = m_graph.numRows() (grounded by the index
// arithmetic in push_work / pop_work / completed_work):
//   m_queue[0   ..   N-1]  ready queue slots
//   m_queue[N   .. 2*N-1]  per-task unsatisfied-dependence counts
//   m_queue[2*N]           begin hint;  m_queue[2*N+1]  end hint
//
// A slot is claimed by atomically bumping the end hint; the claimed
// slot must still hold END_TOKEN (empty), otherwise the queue
// invariant is violated and we abort.
KOKKOS_INLINE_FUNCTION
void push_work(const std::int32_t w) const noexcept {
  const std::int32_t N = m_graph.numRows();

  std::int32_t volatile* const ready_queue = &m_queue[0];
  std::int32_t volatile* const end_hint    = &m_queue[2 * N + 1];

  // Reserve the next slot past the current end of the ready queue.
  const std::int32_t j = atomic_fetch_add(end_hint, 1);

  // Queue overflow, or the reserved slot was not empty: fatal error.
  if ((N <= j) || (END_TOKEN != atomic_exchange(ready_queue + j, w))) {
    Kokkos::abort("WorkGraphPolicy push_work error");
  }
}
// Attempt to claim a work item from the ready queue.
//
// Returns:
//   >= 0            the claimed work-item index
//   END_TOKEN       no work is ready right now (more may arrive later)
//   COMPLETED_TOKEN the entire ready queue has been consumed
KOKKOS_INLINE_FUNCTION
std::int32_t pop_work() const noexcept {
  const std::int32_t N = m_graph.numRows();

  std::int32_t volatile* const ready_queue = &m_queue[0];
  std::int32_t volatile* const begin_hint  = &m_queue[2 * N];

  // The begin hint is only a hint: it is guaranteed to be less than or
  // equal to the actual begin location in the queue, so scanning forward
  // from it never skips unclaimed work.
  for (std::int32_t i = *begin_hint; i < N; ++i) {
    const std::int32_t w = ready_queue[i];

    if (w == END_TOKEN) {
      // Reached the end of the pushed work: nothing ready at the moment.
      return END_TOKEN;
    }

    // Try to claim slot i by swapping in BEGIN_TOKEN; succeeds only if
    // the slot still holds the value we read.
    if ((w != BEGIN_TOKEN) &&
        (w == atomic_compare_exchange(ready_queue + i, w,
                                      (std::int32_t)BEGIN_TOKEN))) {
      // Claim succeeded: advance the begin hint and hand out the work.
      atomic_increment(begin_hint);
      return w;
    }
    // Otherwise another thread claimed slot i first; keep scanning.
  }

  return COMPLETED_TOKEN;
}
// Mark work item 'w' completed: atomically decrement the dependence
// count of every task that depends on w, and push any task whose count
// drops to zero onto the ready queue.
KOKKOS_INLINE_FUNCTION
void completed_work(std::int32_t w) const noexcept {
  // Flush the completed functor's memory accesses so dependent tasks
  // observe them before they are allowed to start.
  Kokkos::memory_fence();

  const std::int32_t N = m_graph.numRows();

  // Dependence counts occupy the second N entries of m_queue.
  std::int32_t volatile* const count_queue = &m_queue[N];

  // Tasks depending on w are listed in entries[row_map(w)..row_map(w+1)).
  const std::int32_t B = m_graph.row_map(w);
  const std::int32_t E = m_graph.row_map(w + 1);

  for (std::int32_t i = B; i < E; ++i) {
    const std::int32_t j = m_graph.entries(i);
    // fetch_add returns the old value; old == 1 means this was the last
    // unsatisfied dependence of task j.
    if (1 == atomic_fetch_add(count_queue + j, -1)) {
      push_work(j);
    }
  }
}
186 KOKKOS_INLINE_FUNCTION
187 void operator()(
const TagInit,
int i)
const noexcept {
188 m_queue[i] = i < m_graph.numRows() ? END_TOKEN : 0;
191 KOKKOS_INLINE_FUNCTION
192 void operator()(
const TagCount,
int i)
const noexcept {
193 std::int32_t
volatile*
const count_queue = &m_queue[m_graph.numRows()];
195 atomic_increment(count_queue + m_graph.entries[i]);
198 KOKKOS_INLINE_FUNCTION
199 void operator()(
const TagReady,
int w)
const noexcept {
200 std::int32_t
const*
const count_queue = &m_queue[m_graph.numRows()];
202 if (0 == count_queue[w]) push_work(w);
205 execution_space space()
const {
return execution_space(); }
// Construct the policy for a dependence graph.
//
// Allocates the queue (N ready slots + N dependence counts + 2 hints,
// hence numRows()*2 + 2) and runs three parallel setup passes:
//   TagInit  - initialize queue slots, counts, and hints
//   TagCount - tally each task's unsatisfied dependences
//   TagReady - push all initially-ready tasks onto the ready queue
WorkGraphPolicy(const graph_type& arg_graph)
    : m_graph(arg_graph),
      m_queue(view_alloc("queue", WithoutInitializing),
              arg_graph.numRows() * 2 + 2) {
  {  // Initialize the full queue storage.
    using policy_type  = RangePolicy<std::int32_t, execution_space, TagInit>;
    using closure_type = Kokkos::Impl::ParallelFor<self_type, policy_type>;
    const closure_type closure(*this, policy_type(0, m_queue.size()));
    closure.execute();
    execution_space().fence();
  }
  {  // Count dependences: one increment per graph entry.
    using policy_type  = RangePolicy<std::int32_t, execution_space, TagCount>;
    using closure_type = Kokkos::Impl::ParallelFor<self_type, policy_type>;
    const closure_type closure(*this, policy_type(0, m_graph.entries.size()));
    closure.execute();
    execution_space().fence();
  }
  {  // Seed the ready queue with tasks that have no dependences.
    using policy_type  = RangePolicy<std::int32_t, execution_space, TagReady>;
    using closure_type = Kokkos::Impl::ParallelFor<self_type, policy_type>;
    const closure_type closure(*this, policy_type(0, m_graph.numRows()));
    closure.execute();
    execution_space().fence();
  }
}
239 #ifdef KOKKOS_ENABLE_SERIAL 240 #include "impl/Kokkos_Serial_WorkGraphPolicy.hpp" 243 #ifdef KOKKOS_ENABLE_OPENMP 244 #include "OpenMP/Kokkos_OpenMP_WorkGraphPolicy.hpp" 247 #ifdef KOKKOS_ENABLE_CUDA 248 #include "Cuda/Kokkos_Cuda_WorkGraphPolicy.hpp" 251 #ifdef KOKKOS_ENABLE_HIP 252 #include "HIP/Kokkos_HIP_WorkGraphPolicy.hpp" 255 #ifdef KOKKOS_ENABLE_THREADS 256 #include "Threads/Kokkos_Threads_WorkGraphPolicy.hpp" 259 #ifdef KOKKOS_ENABLE_HPX 260 #include "HPX/Kokkos_HPX_WorkGraphPolicy.hpp"
The work dependence graph, held as a compressed row storage (CRS) array.
Implementation of the ParallelFor operator, which has a partial specialization for each enabled device backend (Serial, OpenMP, CUDA, HIP, Threads, HPX).