45 #ifndef KOKKOS_KOKKOS_GRAPHNODE_HPP 46 #define KOKKOS_KOKKOS_GRAPHNODE_HPP 48 #include <Kokkos_Macros.hpp> 50 #include <impl/Kokkos_Error.hpp> 52 #include <Kokkos_Core_fwd.hpp> 53 #include <Kokkos_Graph_fwd.hpp> 54 #include <impl/Kokkos_GraphImpl_fwd.hpp> 55 #include <Kokkos_Parallel_Reduce.hpp> 56 #include <impl/Kokkos_GraphImpl_Utilities.hpp> 57 #include <impl/Kokkos_GraphImpl.hpp> 64 template <
// Template parameters of the graph node reference type.
// NOTE(review): this listing has dropped lines around here; a Predecessor
// parameter is used by the checks below but its declaration is not visible.
class ExecutionSpace,
class Kernel ,
// Compile-time check on Predecessor: it must be TypeErasedTag or a
// specialization of GraphNodeRef. The check follows an
// #ifndef KOKKOS_COMPILER_IBM whose matching #endif is not visible here.
74 #ifndef KOKKOS_COMPILER_IBM 76 std::is_same<Predecessor, TypeErasedTag>::value ||
77 Kokkos::Impl::is_specialization_of<Predecessor, GraphNodeRef>::value,
78 "Invalid predecessor template parameter given to GraphNodeRef");
// Compile-time check on ExecutionSpace: it must satisfy
// Kokkos::is_execution_space.
82 Kokkos::is_execution_space<ExecutionSpace>::value,
83 "Invalid execution space template parameter given to GraphNodeRef");
// Compile-time check on Kernel: it must satisfy is_graph_kernel, except when
// Predecessor is TypeErasedTag, in which case the kernel is not inspected.
85 static_assert(std::is_same<Predecessor, TypeErasedTag>::value ||
86 Kokkos::Impl::is_graph_kernel<Kernel>::value,
87 "Invalid kernel template parameter given to GraphNodeRef");
// Compile-time check relating the two: is_more_type_erased<Kernel,
// Predecessor> must be false.
89 static_assert(!Kokkos::Impl::is_more_type_erased<Kernel, Predecessor>::value,
90 "The kernel of a graph node can't be more type-erased than the " 100 using execution_space = ExecutionSpace;
// Member aliases that expose the Kernel and Predecessor template parameters.
101 using graph_kernel = Kernel;
102 using graph_predecessor = Predecessor;
// All GraphNodeRef specializations are mutual friends. The converting
// constructor in this class reads m_graph_impl and m_node_impl of a
// differently parameterized GraphNodeRef, which relies on this friendship.
111 template <
class,
class,
class>
112 friend class GraphNodeRef;
// GraphAccess is also a friend; _then_kernel goes through
// GraphAccess::make_graph_node_ref and GraphAccess::make_node_shared_ptr.
// NOTE(review): presumably this is what lets GraphAccess reach non-public
// members of this class -- confirm against the access specifiers, which are
// not visible in this listing.
113 friend struct Kokkos::Impl::GraphAccess;
// Graph implementation type for this execution space, held through a
// weak_ptr. Users of the pointer lock() it or test expired() before use.
121 using graph_impl_t = Kokkos::Impl::GraphImpl<ExecutionSpace>;
122 std::weak_ptr<graph_impl_t> m_graph_impl;
// Node implementation type for this (ExecutionSpace, Kernel, Predecessor)
// combination. NOTE(review): the alias-name line is missing from this
// listing; the member below refers to the type as node_impl_t.
131 Kokkos::Impl::GraphNodeImpl<ExecutionSpace, Kernel, Predecessor>;
// Shared ownership of the node implementation this reference points at.
132 std::shared_ptr<node_impl_t> m_node_impl;
141 node_impl_t& get_node_impl()
const {
return *m_node_impl.get(); }
// Lvalue-qualified accessor: the signature returns a const reference to a
// shared_ptr<node_impl_t>. NOTE(review): the body is missing from this
// listing; presumably it returns m_node_impl -- confirm.
142 std::shared_ptr<node_impl_t>
const& get_node_ptr() const& {
// Rvalue-qualified accessor: moves m_node_impl out of this (expiring)
// reference and returns it by value, so no extra shared owner is created.
145 std::shared_ptr<node_impl_t> get_node_ptr() && {
146 return std::move(m_node_impl);
// Returns a weak_ptr to the graph implementation by value.
// NOTE(review): the body is missing from this listing; presumably it returns
// m_graph_impl -- confirm.
148 std::weak_ptr<graph_impl_t> get_graph_weak_ptr()
const {
// Internal helper shared by the then_parallel_* members: creates a new graph
// node that runs arg_kernel, registers it with the graph, and records this
// node as its predecessor.
// NOTE(review): the return statement and closing brace are missing from this
// listing; presumably rv is returned -- confirm.
157 template <
class NextKernelDeduced>
158 auto _then_kernel(NextKernelDeduced&& arg_kernel)
const {
// Only GraphNodeKernelImpl specializations are accepted; anything else is
// reported as an internal error at compile time.
162 static_assert(Kokkos::Impl::is_specialization_of<
163 Kokkos::Impl::remove_cvref_t<NextKernelDeduced>,
164 Kokkos::Impl::GraphNodeKernelImpl>::value,
165 "Kokkos internal error");
// Obtain a strong pointer to the graph; the precondition is that the graph
// is still alive.
167 auto graph_ptr = m_graph_impl.lock();
168 KOKKOS_EXPECTS(
bool(graph_ptr))
170 using next_kernel_t = Kokkos::Impl::remove_cvref_t<NextKernelDeduced>;
// The new reference uses this GraphNodeRef type as its Predecessor parameter.
172 using return_t = GraphNodeRef<ExecutionSpace, next_kernel_t, GraphNodeRef>;
// Build the node implementation from this node's execution space instance,
// the forwarded kernel, and *this as predecessor, then wrap it in a
// reference. NOTE(review): an argument line of make_graph_node_ref appears
// to be missing from this listing.
174 auto rv = Kokkos::Impl::GraphAccess::make_graph_node_ref(
176 Kokkos::Impl::GraphAccess::make_node_shared_ptr<
177 typename return_t::node_impl_t>(
178 m_node_impl->execution_space_instance(),
179 Kokkos::Impl::_graph_node_kernel_ctor_tag{},
180 (NextKernelDeduced &&) arg_kernel,
182 Kokkos::Impl::_graph_node_predecessor_ctor_tag{}, *
this));
// Register the new node with the graph, then record the dependency of the
// new node on this one.
186 graph_ptr->add_node(rv.m_node_impl);
189 graph_ptr->add_predecessor(rv.m_node_impl, *
this);
// Postcondition: the new reference holds a non-null node implementation.
190 KOKKOS_ENSURES(
bool(rv.m_node_impl))
197 GraphNodeRef(std::weak_ptr<graph_impl_t> arg_graph_impl,
198 std::shared_ptr<node_impl_t> arg_node_impl)
199 : m_graph_impl(std::move(arg_graph_impl)),
200 m_node_impl(std::move(arg_node_impl)) {}
213 GraphNodeRef() noexcept = default;
214 GraphNodeRef(GraphNodeRef const&) = default;
215 GraphNodeRef(GraphNodeRef&&) noexcept = default;
216 GraphNodeRef& operator=(GraphNodeRef const&) = default;
217 GraphNodeRef& operator=(GraphNodeRef&&) noexcept = default;
218 ~GraphNodeRef() = default;
// Converting constructor from a GraphNodeRef that has the same execution
// space but different Kernel and/or Predecessor parameters.
// NOTE(review): the template opener and the tail of the enable_if are
// missing from this listing.
227 class OtherKernel, class OtherPredecessor,
228 typename std::enable_if_t<
// Excluded when the source type is exactly this type, so this overload does
// not compete with the copy constructor.
230 !std::is_same<GraphNodeRef, GraphNodeRef<execution_space, OtherKernel,
231 OtherPredecessor>>::value &&
// The source kernel type must pass is_compatible_type_erasure against this
// kernel type.
233 Kokkos::Impl::is_compatible_type_erasure<OtherKernel,
234 graph_kernel>::value &&
// The same compatibility requirement applies to the predecessor types.
236 Kokkos::Impl::is_compatible_type_erasure<
237 OtherPredecessor, graph_predecessor>::value,
// Copies both pointers from the source reference; the new object is
// initialized from the same graph and node pointers as other.
241 GraphNodeRef<execution_space, OtherKernel, OtherPredecessor> const& other)
242 : m_graph_impl(other.m_graph_impl), m_node_impl(other.m_node_impl) {}
// Adds a parallel-for kernel node after this node and returns the result of
// _then_kernel.
//   arg_name   - label given to the kernel
//   arg_policy - execution policy; this overload participates only when the
//                cv/ref-stripped Policy satisfies is_execution_policy
//   functor    - callable stored in the kernel
// NOTE(review): the template opener, the enable_if tail, a static_assert
// opener and the closing brace are missing from this listing.
258 class Policy,
class Functor,
259 typename std::enable_if<
262 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
265 auto then_parallel_for(std::string arg_name, Policy&& arg_policy,
266 Functor&& functor)
const {
// Preconditions: the graph is still alive and this reference holds a node.
268 KOKKOS_EXPECTS(!m_graph_impl.expired())
269 KOKKOS_EXPECTS(
bool(m_node_impl))
277 using policy_t = Kokkos::Impl::remove_cvref_t<Policy>;
// The execution space of the policy must be the execution space of the graph.
280 std::is_same<
typename policy_t::execution_space,
281 execution_space>::value,
284 "Execution Space mismatch between execution policy and graph");
// Attach KernelInGraphProperty to the policy; the resulting type can differ
// from Policy, hence the decltype below.
286 auto policy = Experimental::require((Policy &&) arg_policy,
287 Kokkos::Impl::KernelInGraphProperty{});
289 using next_policy_t = decltype(policy);
290 using next_kernel_t =
291 Kokkos::Impl::GraphNodeKernelImpl<ExecutionSpace, next_policy_t,
292 std::decay_t<Functor>,
293 Kokkos::ParallelForTag>;
// Construct the kernel from the name, the execution space instance carried by
// the policy, the forwarded functor and the policy, then chain it.
// NOTE(review): the last argument casts the local policy object (type
// next_policy_t) to Policy&& -- confirm that cast is intended.
294 return this->_then_kernel(next_kernel_t{std::move(arg_name), policy.space(),
295 (Functor &&) functor,
296 (Policy &&) policy});
// Unnamed convenience overload: forwards policy and functor to the overload
// taking a name, passing an empty string as the name. Enabled under the same
// is_execution_policy condition.
// NOTE(review): the template opener, the enable_if tail and the closing brace
// are missing from this listing.
300 class Policy,
class Functor,
301 typename std::enable_if<
304 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
307 auto then_parallel_for(Policy&& policy, Functor&& functor)
const {
309 return this->then_parallel_for(
"", (Policy &&) policy,
310 (Functor &&) functor);
// Overload taking a name and an iteration count n instead of a policy.
// Delegates to another then_parallel_for overload with the moved name and
// the forwarded functor.
// NOTE(review): the argument line between name and functor is missing from
// this listing; presumably it builds a range policy from n -- TODO confirm.
313 template <
class Functor>
314 auto then_parallel_for(std::string name, std::size_t n,
315 Functor&& functor)
const {
317 return this->then_parallel_for(std::move(name),
319 (Functor &&) functor);
// Unnamed overload taking an iteration count n: forwards to the
// (name, n, functor) overload with an empty string as the name.
// NOTE(review): the closing brace is missing from this listing.
322 template <
class Functor>
323 auto then_parallel_for(std::size_t n, Functor&& functor)
const {
325 return this->then_parallel_for(
"", n, (Functor &&) functor);
// Adds a parallel-reduce kernel node after this node and returns the result
// of _then_kernel. Participates in overload resolution only when the
// cv/ref-stripped Policy satisfies is_execution_policy.
// NOTE(review): several lines are missing from this listing, including the
// template opener, the rest of the parameter list (functor and return_value
// are used below but their declarations are not visible), a static_assert
// opener, some alias-name lines and the closing brace.
335 class Policy,
class Functor,
class ReturnType,
336 typename std::enable_if<
339 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
342 auto then_parallel_reduce(std::string arg_name, Policy&& arg_policy,
// Preconditions: the graph is still alive (strong pointer obtained here and
// reused below) and this reference holds a node.
345 auto graph_impl_ptr = m_graph_impl.lock();
346 KOKKOS_EXPECTS(
bool(graph_impl_ptr))
347 KOKKOS_EXPECTS(
bool(m_node_impl))
356 using policy_t =
typename std::remove_cv<
357 typename std::remove_reference<Policy>::type>::type;
// The execution space of the policy must be the execution space of the graph.
359 std::is_same<
typename policy_t::execution_space,
360 execution_space>::value,
363 "Execution Space mismatch between execution policy and graph");
// Runtime rejection: if parallel_reduce_needs_fence reports that this result
// target would require a fence, throw a runtime exception. The message
// directs callers to use a Kokkos::View for the result instead.
371 if (Kokkos::Impl::parallel_reduce_needs_fence(
372 graph_impl_ptr->get_execution_space(), return_value)) {
373 Kokkos::Impl::throw_runtime_exception(
374 "Parallel reductions in graphs can't operate on Reducers that " 375 "reference a scalar because they can't complete synchronously. Use a " 376 "Kokkos::View instead and keep in mind the result will only be " 377 "available once the graph is submitted (or in tasks that depend on " 383 using return_type_remove_cvref =
typename std::remove_cv<
384 typename std::remove_reference<ReturnType>::type>::type;
// Compile-time check: the result target must be a View or a Reducer.
385 static_assert(Kokkos::is_view<return_type_remove_cvref>::value ||
386 Kokkos::is_reducer<return_type_remove_cvref>::value,
387 "Output argument to parallel reduce in a graph must be a " 388 "View or a Reducer");
// Reducers keep their type; any other result type is const-qualified.
// NOTE(review): the alias-name line is missing from this listing; the
// adapter below refers to the type as return_type.
391 std::conditional_t<Kokkos::is_reducer<return_type_remove_cvref>::value,
392 return_type_remove_cvref,
393 const return_type_remove_cvref>;
394 using functor_type = Kokkos::Impl::remove_cvref_t<Functor>;
// Adapter types that supply the value type, the functor type and the reducer
// type used for the kernel below.
397 using return_value_adapter =
398 Kokkos::Impl::ParallelReduceReturnValue<void, return_type,
400 using functor_adaptor = Kokkos::Impl::ParallelReduceFunctorType<
401 functor_type, Policy,
typename return_value_adapter::value_type,
// Attach KernelInGraphProperty to the policy; the resulting type can differ
// from Policy, hence the decltype below.
406 auto policy = Experimental::require((Policy &&) arg_policy,
407 Kokkos::Impl::KernelInGraphProperty{});
409 using next_policy_t = decltype(policy);
410 using next_kernel_t = Kokkos::Impl::GraphNodeKernelImpl<
411 ExecutionSpace, next_policy_t,
typename functor_adaptor::functor_type,
412 Kokkos::ParallelReduceTag,
typename return_value_adapter::reducer_type>;
// Construct the kernel from the name, the execution space of the graph, the
// forwarded functor, the policy and the adapted return value, then chain it.
// NOTE(review): the policy argument casts the local policy object (type
// next_policy_t) to Policy&& -- confirm that cast is intended.
414 return this->_then_kernel(next_kernel_t{
415 std::move(arg_name), graph_impl_ptr->get_execution_space(),
416 (Functor &&) functor, (Policy &&) policy,
417 return_value_adapter::return_value(return_value, functor)});
// Unnamed convenience overload: forwards the policy and functor to the
// overload taking a name, passing an empty string as the name. Enabled under
// the same is_execution_policy condition.
// NOTE(review): the template opener, the enable_if tail, the rest of the
// parameter list, the final call argument and the closing brace are missing
// from this listing.
421 class Policy,
class Functor,
class ReturnType,
422 typename std::enable_if<
425 is_execution_policy<Kokkos::Impl::remove_cvref_t<Policy>>::value,
428 auto then_parallel_reduce(Policy&& arg_policy, Functor&& functor,
430 return this->then_parallel_reduce(
"", (Policy &&) arg_policy,
431 (Functor &&) functor,
// Overload taking a label and an end index (of the execution space size_type)
// instead of a policy. Delegates to another then_parallel_reduce overload,
// forwarding the functor and the return value.
// NOTE(review): the rest of the parameter list and the leading call arguments
// are missing from this listing; presumably the label and a range policy
// built from idx_end are passed -- TODO confirm.
435 template <
class Functor,
class ReturnType>
436 auto then_parallel_reduce(std::string label,
437 typename execution_space::size_type idx_end,
440 return this->then_parallel_reduce(
442 (Functor &&) functor, (
ReturnType &&) return_value);
// Unnamed overload taking only an end index: forwards to the
// (label, idx_end, ...) overload with an empty string as the label.
// NOTE(review): the rest of the parameter list, the final call argument and
// the closing brace are missing from this listing.
445 template <
class Functor,
class ReturnType>
446 auto then_parallel_reduce(
typename execution_space::size_type idx_end,
449 return this->then_parallel_reduce(
"", idx_end, (Functor &&) functor,
462 #endif // KOKKOS_KOKKOS_GRAPHNODE_HPP
Execution policy for work over a range of an integral type.