diff --git a/core/include/detray/propagator/actor_chain.hpp b/core/include/detray/propagator/actor_chain.hpp index 2c50ecd073..27272c7386 100644 --- a/core/include/detray/propagator/actor_chain.hpp +++ b/core/include/detray/propagator/actor_chain.hpp @@ -11,6 +11,7 @@ #include "detray/definitions/containers.hpp" #include "detray/definitions/detail/qualifiers.hpp" #include "detray/propagator/base_actor.hpp" +#include "detray/propagator/composite_actor.hpp" #include "detray/utils/tuple.hpp" #include "detray/utils/tuple_helpers.hpp" @@ -28,7 +29,7 @@ namespace detray { /// The states of the actors need to be passed to the chain in an external tuple /// /// @tparam actors_t the types of the actors in the chain. -template +template class actor_chain { public: @@ -62,7 +63,8 @@ class actor_chain { DETRAY_HOST_DEVICE static constexpr auto make_default_actor_states() { // Only possible if each state is default initializable - if constexpr (std::default_initializable) { + if constexpr ((std::default_initializable && + ...)) { return state_tuple{}; } else { return std::nullopt; @@ -82,7 +84,7 @@ class actor_chain { /// @param actr the actor (might be a composite actor) /// @param states states of all actors (only bare actors) /// @param p_state the state of the propagator (stepper and navigator) - template DETRAY_HOST_DEVICE inline void run(const actor_t &actr, actor_states_t &states, @@ -133,7 +135,7 @@ class actor_chain<> { using state_ref_tuple = dtuple<>; /// Empty states replaces a real actor states container - struct state {}; + using state = state_tuple; /// Call to actors does nothing. /// @@ -145,11 +147,25 @@ class actor_chain<> { /*Do nothing*/ } + /// @returns the actor list + DETRAY_HOST_DEVICE constexpr const actor_tuple &actors() const { + return m_actors; + } + + /// @returns an empty state + DETRAY_HOST_DEVICE + static consteval state_tuple make_default_actor_states() { + return dtuple<>{}; + } + /// @returns an empty state DETRAY_HOST_DEVICE static constexpr state_ref_tuple setup_actor_states( const state_tuple &) { return {}; } + + private: + [[no_unique_address]] actor_tuple m_actors = {}; }; } // namespace detray diff --git a/core/include/detray/propagator/actors.hpp b/core/include/detray/propagator/actors.hpp index 0a3394844e..408336c40b 100644 --- a/core/include/detray/propagator/actors.hpp +++ b/core/include/detray/propagator/actors.hpp @@ -10,6 +10,6 @@ // Include all core actors #include "detray/propagator/actor_chain.hpp" #include "detray/propagator/actors/aborters.hpp" -#include "detray/propagator/actors/parameter_resetter.hpp" -#include "detray/propagator/actors/parameter_transporter.hpp" +#include "detray/propagator/actors/parameter_updater.hpp" #include "detray/propagator/actors/pointwise_material_interactor.hpp" +#include "detray/propagator/concepts.hpp" diff --git a/core/include/detray/propagator/actors/parameter_resetter.hpp b/core/include/detray/propagator/actors/parameter_resetter.hpp deleted file mode 100644 index 270f8e569e..0000000000 --- a/core/include/detray/propagator/actors/parameter_resetter.hpp +++ /dev/null @@ -1,45 +0,0 @@ -/** Detray library, part of the ACTS project (R&D line) - * - * (c) 2022-2024 CERN for the benefit of the ACTS project - * - * Mozilla Public License Version 2.0 - */ - -#pragma once - -// Project include(s). -#include "detray/definitions/algebra.hpp" -#include "detray/definitions/detail/qualifiers.hpp" -#include "detray/definitions/track_parametrization.hpp" -#include "detray/geometry/tracking_surface.hpp" -#include "detray/propagator/base_actor.hpp" -#include "detray/propagator/detail/jacobian_engine.hpp" - -namespace detray { - -template -struct parameter_resetter : actor { - - template - DETRAY_HOST_DEVICE void operator()(propagator_state_t& propagation) const { - - const auto& navigation = propagation._navigation; - auto& stepping = propagation._stepping; - - // Do covariance transport when the track is on surface - if (!(navigation.is_on_sensitive() || - navigation.encountered_sf_material())) { - return; - } - - // Update free params after bound params were changed by actors - const auto sf = navigation.get_surface(); - stepping() = sf.bound_to_free_vector(propagation._context, - stepping.bound_params()); - - // Reset jacobian transport to identity matrix - stepping.reset_transport_jacobian(); - } -}; - -} // namespace detray diff --git a/core/include/detray/propagator/actors/parameter_transporter.hpp b/core/include/detray/propagator/actors/parameter_transporter.hpp deleted file mode 100644 index cd32c67baf..0000000000 --- a/core/include/detray/propagator/actors/parameter_transporter.hpp +++ /dev/null @@ -1,131 +0,0 @@ -/** Detray library, part of the ACTS project (R&D line) - * - * (c) 2022-2024 CERN for the benefit of the ACTS project - * - * Mozilla Public License Version 2.0 - */ - -#pragma once - -// Project include(s). -#include "detray/definitions/algebra.hpp" -#include "detray/definitions/detail/qualifiers.hpp" -#include "detray/definitions/track_parametrization.hpp" -#include "detray/geometry/tracking_surface.hpp" -#include "detray/propagator/base_actor.hpp" -#include "detray/propagator/detail/jacobian_engine.hpp" - -namespace detray { - -template -struct parameter_transporter : actor { - - /// @name Type definitions for the struct - /// @{ - using scalar_type = dscalar; - // Transformation matching this struct - using transform3_type = dtransform3D; - // bound matrix type - using bound_matrix_t = bound_matrix; - // Matrix type for bound to free jacobian - using bound_to_free_matrix_t = bound_to_free_matrix; - /// @} - - struct get_full_jacobian_kernel { - - template - DETRAY_HOST_DEVICE inline bound_matrix_t operator()( - const mask_group_t& /*mask_group*/, const index_t& /*index*/, - const transform3_type& trf3, - const bound_to_free_matrix_t& bound_to_free_jacobian, - const material* vol_mat_ptr, - const stepper_state_t& stepping) const { - - using frame_t = typename mask_group_t::value_type::shape:: - template local_frame_type; - - using jacobian_engine_t = detail::jacobian_engine; - - using free_matrix_t = free_matrix; - using free_to_bound_matrix_t = - typename jacobian_engine_t::free_to_bound_matrix_type; - - // Free to bound jacobian at the destination surface - const free_to_bound_matrix_t free_to_bound_jacobian = - jacobian_engine_t::free_to_bound_jacobian(trf3, stepping()); - - // Path correction factor - const free_matrix_t path_correction = - jacobian_engine_t::path_correction( - stepping().pos(), stepping().dir(), stepping.dtds(), - stepping.dqopds(vol_mat_ptr), trf3); - - const free_matrix_t correction_term = - matrix::identity() + path_correction; - - return free_to_bound_jacobian * correction_term * - stepping.transport_jacobian() * bound_to_free_jacobian; - } - }; - - template - DETRAY_HOST_DEVICE void operator()(propagator_state_t& propagation) const { - auto& stepping = propagation._stepping; - const auto& navigation = propagation._navigation; - - // Do covariance transport when the track is on surface - if (!(navigation.is_on_sensitive() || - navigation.encountered_sf_material())) { - return; - } - - // Geometry context for this track - const auto& gctx = propagation._context; - - // Current Surface - const auto sf = navigation.get_surface(); - - // Bound track params of departure surface - auto& bound_params = stepping.bound_params(); - - // Covariance is transported only when the previous surface is an - // actual tracking surface. (i.e. This disables the covariance transport - // from curvilinear frame) - if (!bound_params.surface_link().is_invalid()) { - - // Previous surface - tracking_surface prev_sf{navigation.detector(), - bound_params.surface_link()}; - - const bound_to_free_matrix_t bound_to_free_jacobian = - prev_sf.bound_to_free_jacobian(gctx, bound_params); - - auto vol = navigation.get_volume(); - const auto vol_mat_ptr = - vol.has_material() ? vol.material_parameters(stepping().pos()) - : nullptr; - stepping.set_full_jacobian( - sf.template visit_mask( - sf.transform(gctx), bound_to_free_jacobian, vol_mat_ptr, - propagation._stepping)); - - // Calculate surface-to-surface covariance transport - const bound_matrix_t new_cov = - stepping.full_jacobian() * bound_params.covariance() * - matrix::transpose(stepping.full_jacobian()); - - stepping.bound_params().set_covariance(new_cov); - } - - // Convert free to bound vector - bound_params.set_parameter_vector( - sf.free_to_bound_vector(gctx, stepping())); - - // Set surface link - bound_params.set_surface_link(sf.barcode()); - } - -}; // namespace detray - -} // namespace detray diff --git a/core/include/detray/propagator/actors/parameter_updater.hpp b/core/include/detray/propagator/actors/parameter_updater.hpp new file mode 100644 index 0000000000..a56bdcf7de --- /dev/null +++ b/core/include/detray/propagator/actors/parameter_updater.hpp @@ -0,0 +1,243 @@ +/** Detray library, part of the ACTS project (R&D line) + * + * (c) 2022-2024 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +// Project include(s). +#include "detray/definitions/algebra.hpp" +#include "detray/definitions/detail/qualifiers.hpp" +#include "detray/definitions/track_parametrization.hpp" +#include "detray/geometry/tracking_surface.hpp" +#include "detray/propagator/base_actor.hpp" +#include "detray/propagator/composite_actor.hpp" +#include "detray/propagator/detail/jacobian_engine.hpp" +#include "detray/utils/curvilinear_frame.hpp" + +namespace detray { + +template +struct parameter_transporter : actor { + + /// @name Type definitions for the struct + /// @{ + using scalar_type = dscalar; + // Transformation matching this struct + using transform3_type = dtransform3D; + // The track parameters bound to the current sensitive/material surface + using bound_track_parameters_type = bound_track_parameters; + // bound matrix type + using bound_matrix_type = bound_matrix; + // Matrix type for bound to free jacobian + using bound_to_free_matrix_type = bound_to_free_matrix; + /// @} + + struct state { + + friend parameter_transporter; + + state() = default; + + /// Start from free track parameters + DETRAY_HOST_DEVICE + explicit state(const free_track_parameters& free_params) { + init(free_params); + } + + /// Start from bound track parameters + DETRAY_HOST_DEVICE + explicit state(const bound_track_parameters_type& bound_params) + : m_bound_params{bound_params} {} + + /// @returns bound track parameters - const access + DETRAY_HOST_DEVICE + bound_track_parameters_type& bound_params() { return m_bound_params; } + + /// Initialize the state from bound track parameters + DETRAY_HOST_DEVICE + void init(const bound_track_parameters_type& bound_params) { + m_bound_params = bound_params; + } + + /// Initialize the state from free track parameters + DETRAY_HOST_DEVICE + void init(const free_track_parameters& free_params) { + + curvilinear_frame cf(free_params); + + // Set bound track parameters + m_bound_params.set_parameter_vector(cf.m_bound_vec); + + // A dummy covariance - should not be used + m_bound_params.set_covariance( + matrix::identity()); + + // An invalid barcode - should not be used + m_bound_params.set_surface_link(geometry::barcode{}); + } + + /// @returns bound track parameters. + DETRAY_HOST_DEVICE + const bound_track_parameters_type& bound_params() const { + return m_bound_params; + } + + /// @returns the current full Jacbian. + DETRAY_HOST_DEVICE + inline const bound_matrix_type& full_jacobian() const { + return m_full_jacobian; + } + + private: + /// Set new full Jacbian. + DETRAY_HOST_DEVICE + inline void set_full_jacobian(const bound_matrix_type& jac) { + m_full_jacobian = jac; + } + + /// Full jacobian + bound_matrix_type m_full_jacobian = + matrix::identity(); + + /// bound covariance + bound_track_parameters_type m_bound_params{}; + }; + + struct get_full_jacobian_kernel { + + template + DETRAY_HOST_DEVICE inline bound_matrix_type operator()( + const mask_group_t& /*mask_group*/, const index_t& /*index*/, + const transform3_type& trf3, + const bound_to_free_matrix_type& bound_to_free_jacobian, + const material* vol_mat_ptr, + const stepper_state_t& stepping) const { + + using frame_t = typename mask_group_t::value_type::shape:: + template local_frame_type; + + using jacobian_engine_t = detail::jacobian_engine; + + using free_matrix_t = free_matrix; + using free_to_bound_matrix_t = + typename jacobian_engine_t::free_to_bound_matrix_type; + + // Free to bound jacobian at the destination surface + const free_to_bound_matrix_t free_to_bound_jacobian = + jacobian_engine_t::free_to_bound_jacobian(trf3, stepping()); + + // Path correction factor + const free_matrix_t path_correction = + jacobian_engine_t::path_correction( + stepping().pos(), stepping().dir(), stepping.dtds(), + stepping.dqopds(vol_mat_ptr), trf3); + + const free_matrix_t correction_term = + matrix::identity() + path_correction; + + return free_to_bound_jacobian * correction_term * + stepping.transport_jacobian() * bound_to_free_jacobian; + } + }; + + template + DETRAY_HOST_DEVICE void operator()(state& transporter_state, + propagator_state_t& propagation) const { + auto& stepping = propagation._stepping; + const auto& navigation = propagation._navigation; + + // Do covariance transport when the track is on surface + if (!(navigation.is_on_sensitive() || + navigation.encountered_sf_material())) { + return; + } + + // Geometry context for this track + const auto& gctx = propagation._context; + + // Current Surface + const auto sf = navigation.get_surface(); + + // Bound track params of departure surface + auto& bound_params = transporter_state.bound_params(); + + // Covariance is transported only when the previous surface is an + // actual tracking surface. (i.e. This disables the covariance transport + // from curvilinear frame) + if (!bound_params.surface_link().is_invalid()) { + + // Previous surface + tracking_surface prev_sf{navigation.detector(), + bound_params.surface_link()}; + + const bound_to_free_matrix_type bound_to_free_jacobian = + prev_sf.bound_to_free_jacobian(gctx, bound_params); + + auto vol = navigation.get_volume(); + const auto vol_mat_ptr = + vol.has_material() ? vol.material_parameters(stepping().pos()) + : nullptr; + transporter_state.set_full_jacobian( + sf.template visit_mask( + sf.transform(gctx), bound_to_free_jacobian, vol_mat_ptr, + propagation._stepping)); + + // Calculate surface-to-surface covariance transport + const bound_matrix_type new_cov = + transporter_state.full_jacobian() * bound_params.covariance() * + matrix::transpose(transporter_state.full_jacobian()); + + bound_params.set_covariance(new_cov); + } + + // Convert free to bound vector + bound_params.set_parameter_vector( + sf.free_to_bound_vector(gctx, stepping())); + + // Set surface link + bound_params.set_surface_link(sf.barcode()); + } +}; + +/// Update the stepper state after the bound track parameters were updated +template +struct parameter_resetter : actor { + + template + DETRAY_HOST_DEVICE void operator()( + const parameter_transporter::state& transporter_state, + propagator_state_t& propagation) const { + + const auto& navigation = propagation._navigation; + auto& stepping = propagation._stepping; + + // Do covariance transport when the track is on surface + if (!(navigation.is_on_sensitive() || + navigation.encountered_sf_material())) { + return; + } + + typename propagator_state_t::detector_type::geometry_context ctx{}; + + // Update free params after bound params were changed by actors + const auto sf = navigation.get_surface(); + stepping() = + sf.bound_to_free_vector(ctx, transporter_state.bound_params()); + + // Reset jacobian transport to identity matrix + stepping.reset_transport_jacobian(); + } +}; + +/// Call actors that depend on the bound track parameters safely together with +/// the parameter transporter and parameter resetter +template +using parameter_updater = + composite_actor, transporter_observers..., + parameter_resetter>; + +} // namespace detray diff --git a/core/include/detray/propagator/actors/pointwise_material_interactor.hpp b/core/include/detray/propagator/actors/pointwise_material_interactor.hpp index 5643bad10d..dd5ac503a9 100644 --- a/core/include/detray/propagator/actors/pointwise_material_interactor.hpp +++ b/core/include/detray/propagator/actors/pointwise_material_interactor.hpp @@ -15,6 +15,7 @@ #include "detray/materials/detail/concepts.hpp" #include "detray/materials/detail/material_accessor.hpp" #include "detray/materials/interaction.hpp" +#include "detray/propagator/actors/parameter_updater.hpp" #include "detray/propagator/base_actor.hpp" #include "detray/tracks/bound_track_parameters.hpp" #include "detray/utils/ranges.hpp" @@ -127,7 +128,9 @@ struct pointwise_material_interactor : actor { template DETRAY_HOST_DEVICE inline void operator()( - state &interactor_state, propagator_state_t &prop_state) const { + state &interactor_state, + parameter_transporter::state &transporter_state, + propagator_state_t &prop_state) const { interactor_state.reset(); @@ -139,7 +142,7 @@ struct pointwise_material_interactor : actor { auto &stepping = prop_state._stepping; this->update(prop_state._context, stepping.particle_hypothesis(), - stepping.bound_params(), interactor_state, + transporter_state.bound_params(), interactor_state, static_cast(navigation.direction()), navigation.get_surface()); } diff --git a/core/include/detray/propagator/base_actor.hpp b/core/include/detray/propagator/base_actor.hpp index 3144d5a44d..f239b1f516 100644 --- a/core/include/detray/propagator/base_actor.hpp +++ b/core/include/detray/propagator/base_actor.hpp @@ -7,15 +7,8 @@ #pragma once -// Propagate include(s) -#include "detray/definitions/containers.hpp" -#include "detray/definitions/detail/qualifiers.hpp" -#include "detray/utils/tuple_helpers.hpp" - // System include(s) -#include #include -#include namespace detray { @@ -28,167 +21,4 @@ struct actor { struct state {}; }; -namespace detail { -/// Extrac the tuple of actor states from an actor type -/// @{ -// Simple actor: No observers -template -struct get_state_tuple { - private: - using state_t = typename actor_t::state; - - // Remove empty default state of base actor type from tuple - using principal = std::conditional_t, - dtuple<>, dtuple>; - using principal_ref = - std::conditional_t, dtuple<>, - dtuple>; - - public: - using type = principal; - using ref_type = principal_ref; -}; - -// Composite actor: Has observers -template -requires(!std::same_as::observer_states, - void>) struct get_state_tuple { - private: - using principal_actor_t = typename actor_t::actor_type; - - using principal = typename get_state_tuple::type; - using principal_ref = typename get_state_tuple::ref_type; - - using observers = typename actor_t::observer_states; - using observer_refs = typename actor_t::observer_state_refs; - - public: - using type = detail::tuple_cat_t; - using ref_type = detail::tuple_cat_t; -}; - -/// Tuple of state types -template -using state_tuple_t = get_state_tuple::type; - -/// Tuple of references -template -using state_ref_tuple_t = get_state_tuple::ref_type; -/// @} - -} // namespace detail - -/// Composition of actors -/// -/// The composition represents an actor together with its observers. In -/// addition to running its own implementation, it notifies its observing actors -/// -/// @tparam principal_actor_t the actor the compositions implements itself. -/// @tparam observers a pack of observing actors that get called on the updated -/// actor state of the compositions actor implementation. -template -class composite_actor final : public principal_actor_t { - - public: - /// Tag whether this is a composite type (hides the def in the actor) - struct is_comp_actor : public std::true_type {}; - - /// The composite is an actor in itself. - using actor_type = principal_actor_t; - using state = typename actor_type::state; - - /// Tuple of states of observing actors - using observer_states = - detail::tuple_cat_t...>; - using observer_state_refs = - detail::tuple_cat_t...>; - - /// Call to the implementation of the actor (the actor possibly being an - /// observer itself) - /// - /// First runs its own implementation, then passes the updated state to its - /// observers. - /// - /// @param states the states of all actors in the chain - /// @param p_state the state of the propagator (stepper and navigator) - /// @param subject_state the state of the actor this actor observes. Uses - /// a dummy type if this is not an observing actor. - template - DETRAY_HOST_DEVICE void operator()( - actor_states_t &states, propagator_state_t &p_state, - subj_state_t &&subject_state = {}) const { - - // State of the primary actor that is implement by this composite actor - auto &actor_state = detail::get(states); - - // Do your own work ... - // Two cases: This is a simple actor or observing actor (pass on its - // subject's state) - if constexpr (std::is_same_v) { - actor_type::operator()(actor_state, p_state); - } else { - actor_type::operator()(actor_state, p_state, - std::forward(subject_state)); - } - - // ... then run the observers on the updated state - notify(m_observers, states, actor_state, p_state, - std::make_index_sequence{}); - } - - private: - /// Notifies the observing actors for composite and simple actor case. - /// - /// @param observer one of the observers - /// @param states the states of all actors in the chain - /// @param actor_state the state of this compositions actor as the subject - /// to all of its observers - /// @param p_state the state of the propagator (stepper and navigator) - template - DETRAY_HOST_DEVICE inline void notify(const observer_t &observer, - actor_states_t &states, - actor_impl_state_t &actor_state, - propagator_state_t &p_state) const { - // Two cases: observer is a simple actor or a composite actor - if constexpr (!typename observer_t::is_comp_actor()) { - // No actor state defined (empty) - if constexpr (std::same_as) { - observer(actor_state, p_state); - } else { - observer(detail::get(states), - actor_state, p_state); - } - } else { - observer(states, actor_state, p_state); - } - } - - /// Resolve the observer notification. - /// - /// Unrolls the observer types and runs the notification for each of them. - /// - /// @param observer_list all observers of the actor - /// @param states the states of all actors in the chain - /// @param actor_state the state of this compositions actor as the subject - /// to all of its observers - /// @param p_state the state of the propagator (stepper and navigator) - template - DETRAY_HOST_DEVICE inline void notify( - const dtuple &observer_list, actor_states_t &states, - actor_impl_state_t &actor_state, propagator_state_t &p_state, - std::index_sequence /*ids*/) const { - - (notify(detail::get(observer_list), states, actor_state, - p_state), - ...); - } - - /// Keep the observers (might be composites again) - [[no_unique_address]] dtuple m_observers = {}; -}; - } // namespace detray diff --git a/core/include/detray/propagator/base_stepper.hpp b/core/include/detray/propagator/base_stepper.hpp index df90ebdd6e..cd154103e0 100644 --- a/core/include/detray/propagator/base_stepper.hpp +++ b/core/include/detray/propagator/base_stepper.hpp @@ -18,7 +18,6 @@ #include "detray/propagator/constrained_step.hpp" #include "detray/propagator/stepping_config.hpp" #include "detray/tracks/tracks.hpp" -#include "detray/utils/curvilinear_frame.hpp" namespace detray { @@ -58,31 +57,17 @@ class base_stepper { /// @note It has to cast into a const track via the call operation. struct state { - /// Sets track parameters. - DETRAY_HOST_DEVICE - explicit state(const free_track_parameters_type &free_params) - : m_track(free_params) { - - curvilinear_frame cf(free_params); - - // Set bound track parameters - m_bound_params.set_parameter_vector(cf.m_bound_vec); - - // A dummy covariance - should not be used - m_bound_params.set_covariance( - matrix::identity()); - - // An invalid barcode - should not be used - m_bound_params.set_surface_link(geometry::barcode{}); - } + /// Construct state from free track parameters. + DETRAY_HOST_DEVICE explicit state( + const free_track_parameters_type &free_params) + : m_track{free_params} {} - /// Sets track parameters from bound track parameter. + /// Sets track parameters from bound track parameters. template DETRAY_HOST_DEVICE state( const bound_track_parameters_type &bound_params, const detector_t &det, - const typename detector_t::geometry_context &ctx) - : m_bound_params(bound_params) { + const typename detector_t::geometry_context &ctx) { // Departure surface const auto sf = tracking_surface{det, bound_params.surface_link()}; @@ -99,16 +84,6 @@ class base_stepper { DETRAY_HOST_DEVICE const free_track_parameters_type &operator()() const { return m_track; } - /// @returns bound track parameters - const access - DETRAY_HOST_DEVICE - bound_track_parameters_type &bound_params() { return m_bound_params; } - - /// @returns bound track parameters - non-const access - DETRAY_HOST_DEVICE - const bound_track_parameters_type &bound_params() const { - return m_bound_params; - } - /// Get stepping direction DETRAY_HOST_DEVICE inline step::direction direction() const { @@ -181,18 +156,6 @@ class base_stepper { return m_jac_transport; } - /// @returns the current full Jacbian. - DETRAY_HOST_DEVICE - inline const bound_matrix_type &full_jacobian() const { - return m_full_jacobian; - } - - /// Set new full Jacbian. - DETRAY_HOST_DEVICE - inline void set_full_jacobian(const bound_matrix_type &jac) { - m_full_jacobian = jac; - } - /// Reset transport Jacbian. DETRAY_HOST_DEVICE inline void reset_transport_jacobian() { @@ -230,13 +193,6 @@ class base_stepper { /// Jacobian transport matrix free_matrix_type m_jac_transport = matrix::identity(); - /// Full jacobian - bound_matrix_type m_full_jacobian = - matrix::identity(); - - /// Bound covariance - bound_track_parameters_type m_bound_params; - /// Free track parameters free_track_parameters_type m_track; diff --git a/core/include/detray/propagator/composite_actor.hpp b/core/include/detray/propagator/composite_actor.hpp new file mode 100644 index 0000000000..65311b5e57 --- /dev/null +++ b/core/include/detray/propagator/composite_actor.hpp @@ -0,0 +1,138 @@ +/** Detray library, part of the ACTS project (R&D line) + * + * (c) 2022-2025 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +// Propagate include(s) +#include "detray/definitions/containers.hpp" +#include "detray/definitions/detail/qualifiers.hpp" +#include "detray/propagator/base_actor.hpp" +#include "detray/propagator/concepts.hpp" +#include "detray/propagator/detail/type_traits.hpp" +#include "detray/utils/tuple_helpers.hpp" + +// System include(s) +#include +#include +#include + +namespace detray { +/// Composition of actors +/// +/// The composition represents an actor together with its observers. In +/// addition to running its own implementation, it notifies its observing actors +/// +/// @tparam principal_actor_t the actor the compositions implements itself. +/// @tparam observers a pack of observing actors that get called on the updated +/// actor state of the compositions actor implementation. +template +class composite_actor final : public principal_actor_t { + + public: + /// Tag whether this is a composite type (hides the def in the actor) + struct is_comp_actor : public std::true_type {}; + + /// The composite is an actor in itself. + using actor_type = principal_actor_t; + using state = typename actor_type::state; + + /// Tuple of states of observing actors + using observer_states = + detail::tuple_cat_t...>; + using observer_state_refs = + detail::tuple_cat_t...>; + + /// Call to the implementation of the actor (the actor possibly being an + /// observer itself) + /// + /// First runs its own implementation, then passes the updated state to its + /// observers. + /// + /// @param states the states of all actors in the chain + /// @param p_state the state of the propagator (stepper and navigator) + /// @param subject_state the state of the actor this actor observes. Uses + /// a dummy type if this is not an observing actor. + template + DETRAY_HOST_DEVICE void operator()( + actor_states_t &states, propagator_state_t &p_state, + subj_state_t &&subject_state = {}) const { + + // State of the primary actor that is implement by this composite actor + auto &actor_state = detail::get(states); + + // Do your own work ... + // Two cases: This is a simple actor or observing actor (pass on its + // subject's state) + if constexpr (std::same_as) { + actor_type::operator()(actor_state, p_state); + } else { + actor_type::operator()(actor_state, p_state, + std::forward(subject_state)); + } + + // ... then run the observers on the updated state + notify(m_observers, states, actor_state, p_state, + std::make_index_sequence{}); + } + + private: + /// Notifies the observing actors for composite and simple actor case. + /// + /// @param observer one of the observers + /// @param states the states of all actors in the chain + /// @param actor_state the state of this compositions actor as the subject + /// to all of its observers + /// @param p_state the state of the propagator (stepper and navigator) + template + DETRAY_HOST_DEVICE inline void notify(const observer_t &observer, + actor_states_t &states, + state &actor_state, + propagator_state_t &p_state) const { + // Two cases: observer is a simple actor or a composite actor + if constexpr (!concepts::composite_actor) { + // No actor state defined (empty) + if constexpr (std::same_as) { + observer(actor_state, p_state); + } else { + observer(detail::get(states), + actor_state, p_state); + } + } else { + observer(states, actor_state, p_state); + } + } + + /// Resolve the observer notification. + /// + /// Unrolls the observer types and runs the notification for each of them. + /// + /// @param observer_list all observers of the actor + /// @param states the states of all actors in the chain + /// @param actor_state the state of this compositions actor as the subject + /// to all of its observers + /// @param p_state the state of the propagator (stepper and navigator) + template + DETRAY_HOST_DEVICE inline void notify( + const dtuple &observer_list, actor_states_t &states, + state &actor_state, propagator_state_t &p_state, + std::index_sequence /*ids*/) const { + + (notify(detail::get(observer_list), states, actor_state, + p_state), + ...); + } + + /// Keep the observers (might be composites again) + [[no_unique_address]] dtuple m_observers = {}; +}; + +} // namespace detray diff --git a/core/include/detray/propagator/concepts.hpp b/core/include/detray/propagator/concepts.hpp new file mode 100644 index 0000000000..ed1803c56f --- /dev/null +++ b/core/include/detray/propagator/concepts.hpp @@ -0,0 +1,63 @@ +/** Detray library, part of the ACTS project (R&D line) + * + * (c) 2025 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +// Include all core actors +#include "detray/propagator/base_actor.hpp" +#include "detray/utils/concepts.hpp" +#include "detray/utils/tuple_helpers.hpp" + +// System include(s) +#include +#include + +namespace detray::concepts { + +/// Concept for a simple actor +template +concept actor = std::derived_from &&requires(const A a) { + typename A::state; +}; + +/// Concept for an actor including observing actors +template +concept composite_actor = + actor &&A::is_comp_actor::value &&requires(const A ca) { + typename A::observer_states; + typename A::observer_state_refs; +}; + +/// Concept for the actor chain that is being run in the propagator +template +concept actor_chain = requires(const A c, typename A::state_tuple s) { + typename A::actor_tuple; + typename A::state_tuple; + typename A::state_ref_tuple; + + { c.actors() } + ->std::same_as; + + { c.make_default_actor_states() } + ->detray::concepts::any_of; + + { c.setup_actor_states(s) } + ->std::same_as; +}; + +/// Check if a state type belongs to an actor or actor chain +template +concept is_state_of = (actor_chain && + (detail::is_permutation_v, + typename T::state_ref_tuple> || + detail::is_permutation_v, + typename T::state_tuple>)) || + + (actor && + std::same_as, typename T::state>); + +} // namespace detray::concepts diff --git a/core/include/detray/propagator/detail/type_traits.hpp b/core/include/detray/propagator/detail/type_traits.hpp new file mode 100644 index 0000000000..3d1943e5c1 --- /dev/null +++ b/core/include/detray/propagator/detail/type_traits.hpp @@ -0,0 +1,67 @@ +/** Detray library, part of the ACTS project (R&D line) + * + * (c) 2025 CERN for the benefit of the ACTS project + * + * Mozilla Public License Version 2.0 + */ + +#pragma once + +// Include all core actors +#include "detray/propagator/base_actor.hpp" +#include "detray/utils/tuple.hpp" +#include "detray/utils/tuple_helpers.hpp" + +// System include(s) +#include + +namespace detray::detail { + +/// Extrac the tuple of actor states from an actor type +/// @{ +// Simple actor: No observers +template +struct get_state_tuple { + private: + using state_t = typename actor_t::state; + + // Remove empty default state of base actor type from tuple + using principal = std::conditional_t, + dtuple<>, dtuple>; + using principal_ref = + std::conditional_t, dtuple<>, + dtuple>; + + public: + using type = principal; + using ref_type = principal_ref; +}; + +// Composite actor: Has observers +template +requires(!std::same_as::observer_states, + void>) struct get_state_tuple { + private: + using principal_actor_t = typename actor_t::actor_type; + + using principal = typename get_state_tuple::type; + using principal_ref = typename get_state_tuple::ref_type; + + using observers = typename actor_t::observer_states; + using observer_refs = typename actor_t::observer_state_refs; + + public: + using type = detail::tuple_cat_t; + using ref_type = detail::tuple_cat_t; +}; + +/// Tuple of state types +template +using state_tuple_t = get_state_tuple::type; + +/// Tuple of references +template +using state_ref_tuple_t = get_state_tuple::ref_type; +/// @} + +} // namespace detray::detail diff --git a/core/include/detray/propagator/propagator.hpp b/core/include/detray/propagator/propagator.hpp index 2373f60364..dc205a550d 100644 --- a/core/include/detray/propagator/propagator.hpp +++ b/core/include/detray/propagator/propagator.hpp @@ -14,6 +14,7 @@ #include "detray/navigation/navigator.hpp" #include "detray/propagator/actor_chain.hpp" #include "detray/propagator/base_stepper.hpp" +#include "detray/propagator/concepts.hpp" #include "detray/propagator/propagation_config.hpp" #include "detray/tracks/tracks.hpp" @@ -27,7 +28,8 @@ namespace detray { /// /// @tparam stepper_t for the transport /// @tparam navigator_t for the navigation -template +template struct propagator { using stepper_type = stepper_t; @@ -48,7 +50,7 @@ struct propagator { navigator_t m_navigator; /// Register the actor types - const actor_chain_t run_actors{}; + const actor_chain_type run_actors{}; /// Construct from a propagator configuration DETRAY_HOST_DEVICE @@ -151,8 +153,9 @@ struct propagator { /// @note If the return value of this function is true, a propagation step /// can be taken afterwards. template - DETRAY_HOST_DEVICE void propagate_init( - state &propagation, actor_states_t actor_state_refs) const { + requires concepts::is_state_of + DETRAY_HOST_DEVICE void propagate_init( + state &propagation, actor_states_t actor_state_refs) const { auto &navigation = propagation._navigation; auto &stepping = propagation._stepping; auto &context = propagation._context; @@ -180,9 +183,10 @@ struct propagator { /// @note If the return value of this function is true, another step can /// be taken afterwards. template - DETRAY_HOST_DEVICE bool propagate_step( - state &propagation, bool is_init, - actor_states_t actor_state_refs) const { + requires concepts::is_state_of + DETRAY_HOST_DEVICE bool propagate_step( + state &propagation, bool is_init, + actor_states_t actor_state_refs) const { auto &navigation = propagation._navigation; auto &stepping = propagation._stepping; auto &context = propagation._context; @@ -244,9 +248,10 @@ struct propagator { /// /// @return propagation success. template - DETRAY_HOST_DEVICE bool propagate( - state &propagation, - actor_states_t actor_state_refs = dtuple<>{}) const { + requires concepts::is_state_of + DETRAY_HOST_DEVICE bool propagate( + state &propagation, + actor_states_t actor_state_refs = dtuple<>{}) const { propagate_init(propagation, actor_state_refs); bool is_init = true; @@ -281,8 +286,9 @@ struct propagator { /// /// @return propagation success. template - DETRAY_HOST_DEVICE bool propagate_sync( - state &propagation, actor_states_t actor_state_refs) const { + requires concepts::is_state_of + DETRAY_HOST_DEVICE bool propagate_sync( + state &propagation, actor_states_t actor_state_refs) const { propagate_init(propagation, actor_state_refs); bool is_init = true; diff --git a/core/include/detray/utils/concepts.hpp b/core/include/detray/utils/concepts.hpp index 6b59926d4d..8de6508aa4 100644 --- a/core/include/detray/utils/concepts.hpp +++ b/core/include/detray/utils/concepts.hpp @@ -18,6 +18,9 @@ namespace detray::concepts { +template +concept any_of = std::disjunction_v...>; + /// Arithmetic types template concept arithmetic = std::is_arithmetic_v; diff --git a/core/include/detray/utils/tuple_helpers.hpp b/core/include/detray/utils/tuple_helpers.hpp index 7d1fe27340..b299ffdb43 100644 --- a/core/include/detray/utils/tuple_helpers.hpp +++ b/core/include/detray/utils/tuple_helpers.hpp @@ -199,4 +199,41 @@ template using tuple_cat_t = typename tuple_cat_type::type; /// @} +/// Check for equality of tuple types modulo permutation +/// @{ +template +struct is_permutation : public std::false_type {}; + +template <> +struct is_permutation, dtuple<>> : public std::true_type {}; + +template +struct is_permutation, dtuple> { + + using T1 = dtuple; + using T2 = dtuple; + + template + static consteval bool compare() { + if constexpr (detray::detail::has_type_v) { + if constexpr (sizeof...(U) == 0u) { + return true; + } else { + return compare(); + } + } else { + return false; + } + } + + static constexpr bool value = + ((detray::detail::tuple_size_v == + detray::detail::tuple_size_v)&&compare() && + compare()); +}; + +template +constexpr bool is_permutation_v = is_permutation::value; +/// @} + } // namespace detray::detail diff --git a/tests/benchmarks/cpu/propagation.cpp b/tests/benchmarks/cpu/propagation.cpp index c422726ed9..f99ebc10b6 100644 --- a/tests/benchmarks/cpu/propagation.cpp +++ b/tests/benchmarks/cpu/propagation.cpp @@ -46,10 +46,8 @@ int main(int argc, char** argv) { using field_t = bfield::const_field_t; using stepper_t = rk_stepper; using empty_chain_t = actor_chain<>; - using default_chain = - actor_chain, - pointwise_material_interactor, - parameter_resetter>; + using default_chain = actor_chain>>; vecmem::host_memory_resource host_mr; @@ -113,8 +111,10 @@ int main(int argc, char** argv) { dtuple<> empty_state{}; pointwise_material_interactor::state interactor_state{}; + parameter_updater::state transporter_state{}; - auto actor_states = detail::make_tuple(interactor_state); + auto actor_states = + detail::make_tuple(transporter_state, interactor_state); // // Register benchmarks diff --git a/tests/benchmarks/cuda/propagation.cpp b/tests/benchmarks/cuda/propagation.cpp index 219928154e..18077f907a 100644 --- a/tests/benchmarks/cuda/propagation.cpp +++ b/tests/benchmarks/cuda/propagation.cpp @@ -108,8 +108,10 @@ int main(int argc, char** argv) { dtuple<> empty_state{}; pointwise_material_interactor::state interactor_state{}; + parameter_updater::state transporter_state{}; - auto actor_states = detail::make_tuple(interactor_state); + auto actor_states = + detail::make_tuple(transporter_state, interactor_state); // // Register benchmarks diff --git a/tests/benchmarks/include/detray/benchmarks/cpu/propagation_benchmark.hpp b/tests/benchmarks/include/detray/benchmarks/cpu/propagation_benchmark.hpp index e4469ade1e..dff03b616b 100644 --- a/tests/benchmarks/include/detray/benchmarks/cpu/propagation_benchmark.hpp +++ b/tests/benchmarks/include/detray/benchmarks/cpu/propagation_benchmark.hpp @@ -104,6 +104,17 @@ struct host_propagation_bm : public benchmark_base { &track) { // Fresh copy of actor states actor_states_t actor_states(*input_actor_states); + + // Init the parameter transport, if present + using transporter_state_t = parameter_transporter::state; + if constexpr (detail::has_type_v) { + // @TODO: Make non-owning to avoid the copy + auto &transporter_state = + detail::get(actor_states); + transporter_state.init(track); + } + // Tuple of references to pass to the propagator typename actor_chain_t::state_ref_tuple actor_state_refs = actor_chain_t::setup_actor_states(actor_states); diff --git a/tests/benchmarks/include/detray/benchmarks/device/cuda/propagation_benchmark.cu b/tests/benchmarks/include/detray/benchmarks/device/cuda/propagation_benchmark.cu index ba45520ffd..34f8358480 100644 --- a/tests/benchmarks/include/detray/benchmarks/device/cuda/propagation_benchmark.cu +++ b/tests/benchmarks/include/detray/benchmarks/device/cuda/propagation_benchmark.cu @@ -43,9 +43,19 @@ __global__ void __launch_bounds__(256, 4) propagator_benchmark_kernel( propagator_device_t p{cfg}; // Create the actor states on a fresh copy - typename actor_chain_t::state_tuple actor_states = *device_actor_state_ptr; + using actor_states_t = typename actor_chain_t::state_tuple; + actor_states_t actor_states = *device_actor_state_ptr; auto actor_state_refs = actor_chain_t::setup_actor_states(actor_states); + // Init the parameter transport, if present + using transporter_state_t = parameter_transporter::state; + if constexpr (detail::has_type_v) { + // @TODO: Make non-owning to avoid the copy + auto &transporter_state = + detail::get(actor_states); + transporter_state.init(tracks.at(gid)); + } + // Create the propagator state // The track gets copied into the stepper state, so that the diff --git a/tests/benchmarks/include/detray/benchmarks/device/cuda/propagation_benchmark.hpp b/tests/benchmarks/include/detray/benchmarks/device/cuda/propagation_benchmark.hpp index 10fe03e8b1..d9c4df0e5c 100644 --- a/tests/benchmarks/include/detray/benchmarks/device/cuda/propagation_benchmark.hpp +++ b/tests/benchmarks/include/detray/benchmarks/device/cuda/propagation_benchmark.hpp @@ -47,9 +47,8 @@ template using empty_chain = actor_chain<>; template -using default_chain = actor_chain, - pointwise_material_interactor, - parameter_resetter>; +using default_chain = actor_chain< + parameter_updater>>; using const_field_t = bfield::const_bknd_t; diff --git a/tests/include/detray/test/device/cuda/material_validation.cu b/tests/include/detray/test/device/cuda/material_validation.cu index 7ae7ddb775..289803d58e 100644 --- a/tests/include/detray/test/device/cuda/material_validation.cu +++ b/tests/include/detray/test/device/cuda/material_validation.cu @@ -34,14 +34,14 @@ __global__ void material_validation_kernel( using navigator_t = navigator; // Propagator with full covariance transport, pathlimit aborter and // material tracer + using pathlimit_aborter_t = pathlimit_aborter; + using parameter_updater_t = + parameter_updater>; using material_tracer_t = material_validator::material_tracer; - using pathlimit_aborter_t = pathlimit_aborter; - using actor_chain_t = - actor_chain, - parameter_resetter, - pointwise_material_interactor, - material_tracer_t>; + + using actor_chain_t = actor_chain; using propagator_t = propagator; detector_device_t det(det_data); @@ -62,11 +62,13 @@ __global__ void material_validation_kernel( // Create the actor states typename pathlimit_aborter_t::state aborter_state{cfg.stepping.path_limit}; + typename parameter_transporter::state transporter_state{ + tracks[trk_id]}; typename pointwise_material_interactor::state interactor_state{}; typename material_tracer_t::state mat_tracer_state{mat_steps.at(trk_id)}; - auto actor_states = - ::detray::tie(aborter_state, interactor_state, mat_tracer_state); + auto actor_states = ::detray::tie(aborter_state, transporter_state, + interactor_state, mat_tracer_state); // Run propagation typename navigator_t::state::view_type nav_view{}; diff --git a/tests/include/detray/test/device/propagator_test.hpp b/tests/include/detray/test/device/propagator_test.hpp index 38c7d5f7d2..b7292caa84 100644 --- a/tests/include/detray/test/device/propagator_test.hpp +++ b/tests/include/detray/test/device/propagator_test.hpp @@ -76,16 +76,14 @@ struct propagator_test_config { using step_tracer_host_t = step_tracer; using step_tracer_device_t = step_tracer; using pathlimit_aborter_t = pathlimit_aborter; +using parameter_updater_t = + parameter_updater>; + using actor_chain_host_t = - actor_chain, - pointwise_material_interactor, - parameter_resetter>; + actor_chain; using actor_chain_device_t = - actor_chain, - pointwise_material_interactor, - parameter_resetter>; + actor_chain; /// Precompute the tracks template > @@ -134,9 +132,10 @@ inline auto run_propagation_host(vecmem::memory_resource *mr, tracer_state.collect_only_on_surface(true); typename pathlimit_aborter_t::state pathlimit_state{ cfg.stepping.path_limit}; + parameter_transporter::state transporter_state{trk}; pointwise_material_interactor::state interactor_state{}; - auto actor_states = - detray::tie(tracer_state, pathlimit_state, interactor_state); + auto actor_states = detray::tie(tracer_state, pathlimit_state, + transporter_state, interactor_state); typename propagator_host_t::state state(trk, field, det); diff --git a/tests/include/detray/test/utils/simulation/random_scatterer.hpp b/tests/include/detray/test/utils/simulation/random_scatterer.hpp index 710eb2b821..de642a19db 100644 --- a/tests/include/detray/test/utils/simulation/random_scatterer.hpp +++ b/tests/include/detray/test/utils/simulation/random_scatterer.hpp @@ -16,6 +16,7 @@ #include "detray/geometry/tracking_surface.hpp" #include "detray/materials/detail/concepts.hpp" #include "detray/materials/interaction.hpp" +#include "detray/propagator/actors/parameter_updater.hpp" #include "detray/propagator/base_actor.hpp" #include "detray/tracks/bound_track_parameters.hpp" #include "detray/utils/axis_rotation.hpp" @@ -128,8 +129,10 @@ struct random_scatterer : actor { }; template - DETRAY_HOST inline void operator()(state& simulator_state, - propagator_state_t& prop_state) const { + DETRAY_HOST inline void operator()( + state& simulator_state, + parameter_transporter::state& transporter_state, + propagator_state_t& prop_state) const { // @Todo: Make context part of propagation state using detector_type = typename propagator_state_t::detector_type; @@ -143,7 +146,7 @@ struct random_scatterer : actor { auto& stepping = prop_state._stepping; const auto& ptc = stepping.particle_hypothesis(); - auto& bound_params = stepping.bound_params(); + auto& bound_params = transporter_state.bound_params(); const auto sf = navigation.get_surface(); const scalar_type cos_inc_angle{ sf.cos_angle(geo_context_type{}, bound_params.dir(), @@ -168,8 +171,8 @@ struct random_scatterer : actor { simulator_state.generator); // Update Phi and Theta - stepping.bound_params().set_phi(vector::phi(new_dir)); - stepping.bound_params().set_theta(vector::theta(new_dir)); + bound_params.set_phi(vector::phi(new_dir)); + bound_params.set_theta(vector::theta(new_dir)); // Flag renavigation of the current candidate prop_state._navigation.set_high_trust(); diff --git a/tests/include/detray/test/validation/material_validation_utils.hpp b/tests/include/detray/test/validation/material_validation_utils.hpp index 8b93d62ca0..1103f99223 100644 --- a/tests/include/detray/test/validation/material_validation_utils.hpp +++ b/tests/include/detray/test/validation/material_validation_utils.hpp @@ -148,26 +148,43 @@ struct material_tracer : detray::actor { vector_t> m_mat_steps{}; }; + /// Run as observer to the parameter transporter (covariance transportsa) + template + DETRAY_HOST_DEVICE void operator()( + state &tracer, + parameter_transporter::state &transporter_state, + const propagator_state_t &prop_state) const { + + record_track_dir(tracer, prop_state); + + // Only count material if navigator encountered it + const auto &navigation = prop_state._navigation; + if (!navigation.encountered_sf_material()) { + return; + } + + // For now use default context + typename propagator_state_t::detector_type::geometry_context gctx{}; + + const auto &track_param = transporter_state.bound_params(); + dpoint2D loc_pos = track_param.bound_local(); + + record_mat_step(tracer, gctx, navigation.get_surface(), loc_pos, + track_param.dir()); + } + + /// Run in a propagation chain without covariance transport template DETRAY_HOST_DEVICE void operator()( state &tracer, const propagator_state_t &prop_state) const { using algebra_t = typename propagator_state_t::detector_type::algebra_type; - using vector3_t = dvector3D; - using point2_t = dpoint2D; - const auto &navigation = prop_state._navigation; - - // Record the initial track direction - vector3_t glob_dir = prop_state._stepping().dir(); - if (detray::detail::is_invalid_value(tracer.m_mat_record.eta) && - detray::detail::is_invalid_value(tracer.m_mat_record.phi)) { - tracer.m_mat_record.eta = vector::eta(glob_dir); - tracer.m_mat_record.phi = vector::phi(glob_dir); - } + record_track_dir(tracer, prop_state); // Only count material if navigator encountered it + const auto &navigation = prop_state._navigation; if (!navigation.encountered_sf_material()) { return; } @@ -179,20 +196,39 @@ struct material_tracer : detray::actor { const auto sf = navigation.get_surface(); // Track direction and bound position on current surface - point2_t loc_pos{}; - - // Get the local track position from the bound track parameters, - // if covariance transport is enabled in the propagation - if constexpr (detail::has_type_v, - typename propagator_state_t:: - actor_chain_type::actor_tuple>) { - const auto &track_param = prop_state._stepping.bound_params(); - loc_pos = track_param.bound_local(); - } else { - const auto &track_param = prop_state._stepping(); - glob_dir = track_param.dir(); - loc_pos = sf.global_to_bound(gctx, track_param.pos(), glob_dir); + const auto &track_param = prop_state._stepping(); + dvector3D glob_dir = track_param.dir(); + dpoint2D loc_pos = + sf.global_to_bound(gctx, track_param.pos(), glob_dir); + + record_mat_step(tracer, gctx, sf, loc_pos, glob_dir); + } + + private: + /// Record the track direction + template + DETRAY_HOST_DEVICE inline auto record_track_dir( + state &tracer, const propagator_state_t &prop_state) const { + using algebra_t = + typename propagator_state_t::detector_type::algebra_type; + using vector3_t = dvector3D; + + // Record the initial track direction + vector3_t glob_dir = prop_state._stepping().dir(); + if (detray::detail::is_invalid_value(tracer.m_mat_record.eta) && + detray::detail::is_invalid_value(tracer.m_mat_record.phi)) { + tracer.m_mat_record.eta = vector::eta(glob_dir); + tracer.m_mat_record.phi = vector::phi(glob_dir); } + } + + /// Record the data for a material step + template + DETRAY_HOST_DEVICE inline auto record_mat_step( + state &tracer, const typename detector_t::geometry_context &gctx, + const tracking_surface sf, + const dpoint2D &loc_pos, + const dvector3D &glob_dir) const { // Fetch the material parameters and pathlength through the material const auto mat_params = sf.template visit_material( @@ -237,11 +273,11 @@ inline auto record_material( using material_tracer_t = material_validator::material_tracer; using pathlimit_aborter_t = pathlimit_aborter; - using actor_chain_t = - actor_chain, - parameter_resetter, - pointwise_material_interactor, - material_tracer_t>; + using parameter_updater_t = + parameter_updater>; + + using actor_chain_t = actor_chain; using propagator_t = propagator; // Propagator @@ -250,11 +286,12 @@ inline auto record_material( // Build actor and propagator states typename pathlimit_aborter_t::state pathlimit_aborter_state{ cfg.stepping.path_limit}; + typename parameter_transporter::state transporter_state{track}; typename pointwise_material_interactor::state interactor_state{}; typename material_tracer_t::state mat_tracer_state{*host_mr}; - auto actor_states = detray::tie(pathlimit_aborter_state, interactor_state, - mat_tracer_state); + auto actor_states = detray::tie(pathlimit_aborter_state, transporter_state, + interactor_state, mat_tracer_state); typename propagator_t::state propagation{track, det, cfg.context}; diff --git a/tests/integration_tests/cpu/material/material_interaction.cpp b/tests/integration_tests/cpu/material/material_interaction.cpp index 8e6af4972d..011bf5b8c7 100644 --- a/tests/integration_tests/cpu/material/material_interaction.cpp +++ b/tests/integration_tests/cpu/material/material_interaction.cpp @@ -40,6 +40,9 @@ using test_algebra = test::algebra; using scalar = test::scalar; using covariance_t = typename bound_track_parameters::covariance_type; +using interactor_t = pointwise_material_interactor; + +static_assert(detray::concepts::actor); // Test is done for muon namespace { @@ -72,10 +75,10 @@ GTEST_TEST(detray_material, telescope_geometry_energy_loss) { using navigator_t = navigator; using stepper_t = line_stepper; using interactor_t = pointwise_material_interactor; + using parameter_updater_t = parameter_updater; using pathlimit_aborter_t = pathlimit_aborter; - using actor_chain_t = - actor_chain, - interactor_t, parameter_resetter>; + + using actor_chain_t = actor_chain; using propagator_t = propagator; // Propagator is built from the stepper and navigator @@ -97,10 +100,12 @@ GTEST_TEST(detray_material, telescope_geometry_energy_loss) { geometry::barcode{}.set_index(0u), bound_vector, bound_cov); pathlimit_aborter_t::state aborter_state{}; + parameter_transporter::state bound_updater{bound_param}; interactor_t::state interactor_state{}; // Create actor states tuples - auto actor_states = detray::tie(aborter_state, interactor_state); + auto actor_states = + detray::tie(aborter_state, bound_updater, interactor_state); propagator_t::state state(bound_param, det); state.do_debug = true; @@ -110,7 +115,7 @@ GTEST_TEST(detray_material, telescope_geometry_energy_loss) { << state.debug_stream.str() << std::endl; // new momentum - const scalar newP{state._stepping.bound_params().p(ptc.charge())}; + const scalar newP{bound_updater.bound_params().p(ptc.charge())}; // mass const auto mass = ptc.mass(); @@ -123,7 +128,7 @@ GTEST_TEST(detray_material, telescope_geometry_energy_loss) { // New qop variance const scalar new_var_qop{ - getter::element(state._stepping.bound_params().covariance(), + getter::element(bound_updater.bound_params().covariance(), e_bound_qoverp, e_bound_qoverp)}; // Interaction object @@ -192,10 +197,10 @@ GTEST_TEST(detray_material, telescope_geometry_scattering_angle) { using navigator_t = navigator; using stepper_t = line_stepper; using simulator_t = random_scatterer; + using parameter_updater_t = parameter_updater; using pathlimit_aborter_t = pathlimit_aborter; - using actor_chain_t = - actor_chain, - simulator_t, parameter_resetter>; + + using actor_chain_t = actor_chain; using propagator_t = propagator; // Propagator is built from the stepper and navigator @@ -224,12 +229,14 @@ GTEST_TEST(detray_material, telescope_geometry_scattering_angle) { for (std::size_t i = 0u; i < n_samples; i++) { pathlimit_aborter_t::state aborter_state{}; + parameter_transporter::state bound_updater{bound_param}; // Seed = sample id simulator_t::state simulator_state{i}; simulator_state.do_energy_loss = false; // Create actor states tuples - auto actor_states = detray::tie(aborter_state, simulator_state); + auto actor_states = + detray::tie(aborter_state, bound_updater, simulator_state); propagator_t::state state(bound_param, det); state.do_debug = true; @@ -238,11 +245,11 @@ GTEST_TEST(detray_material, telescope_geometry_scattering_angle) { ASSERT_TRUE(p.propagate(state, actor_states)) << state.debug_stream.str() << std::endl; - const auto& final_param = state._stepping.bound_params(); + const auto& final_param = bound_updater.bound_params(); // Updated phi and theta variance if (i == 0u) { - pointwise_material_interactor{}.update_angle_variance( + interactor_t{}.update_angle_variance( bound_cov, traj.dir(), simulator_state.projected_scattering_angle); } diff --git a/tests/integration_tests/cpu/propagator/backward_propagation.cpp b/tests/integration_tests/cpu/propagator/backward_propagation.cpp index 5f98210bcb..2356a98c8c 100644 --- a/tests/integration_tests/cpu/propagator/backward_propagation.cpp +++ b/tests/integration_tests/cpu/propagator/backward_propagation.cpp @@ -67,10 +67,10 @@ TEST_P(BackwardPropagation, backward_propagation) { using navigator_t = navigator; using rk_stepper_t = rk_stepper; - using actor_chain_t = - actor_chain, - pointwise_material_interactor, - parameter_resetter>; + using parameter_updater_t = + parameter_updater>; + using actor_chain_t = actor_chain; using propagator_t = propagator; // Particle hypothesis @@ -90,6 +90,7 @@ TEST_P(BackwardPropagation, backward_propagation) { geometry::barcode{}.set_index(0u), bound_vector, bound_cov); // Actors + parameter_transporter::state bound_updater{bound_param0}; pointwise_material_interactor::state interactor{}; propagation::config prop_cfg{}; @@ -104,13 +105,13 @@ TEST_P(BackwardPropagation, backward_propagation) { fw_state.do_debug = true; // Run propagator - p.propagate(fw_state, detray::tie(interactor)); + p.propagate(fw_state, detray::tie(bound_updater, interactor)); // Print the debug stream // std::cout << fw_state.debug_stream.str() << std::endl; - // Bound state after propagation - const auto& bound_param1 = fw_state._stepping.bound_params(); + // Bound state after propagation (snapshot) + const auto bound_param1 = bound_updater.bound_params(); // Check if the track reaches the final surface EXPECT_EQ(bound_param0.surface_link().volume(), 4095u); @@ -127,13 +128,13 @@ TEST_P(BackwardPropagation, backward_propagation) { bw_state._navigation.set_direction(navigation::direction::e_backward); // Run propagator - p.propagate(bw_state, detray::tie(interactor)); + p.propagate(bw_state, detray::tie(bound_updater, interactor)); // Print the debug stream // std::cout << bw_state.debug_stream.str() << std::endl; - // Bound state after propagation - const auto& bound_param2 = bw_state._stepping.bound_params(); + // Bound state after propagation (snapshot) + const auto bound_param2 = bound_updater.bound_params(); // Check if the track reaches the initial surface EXPECT_EQ(bound_param2.surface_link().volume(), 0u); diff --git a/tests/integration_tests/cpu/propagator/jacobian_validation.cpp b/tests/integration_tests/cpu/propagator/jacobian_validation.cpp index 2fe14a4e98..58e471e103 100644 --- a/tests/integration_tests/cpu/propagator/jacobian_validation.cpp +++ b/tests/integration_tests/cpu/propagator/jacobian_validation.cpp @@ -11,8 +11,7 @@ #include "detray/detectors/bfield.hpp" #include "detray/navigation/intersection/helix_intersector.hpp" #include "detray/navigation/navigator.hpp" -#include "detray/propagator/actors/parameter_resetter.hpp" -#include "detray/propagator/actors/parameter_transporter.hpp" +#include "detray/propagator/actors/parameter_updater.hpp" #include "detray/propagator/propagator.hpp" #include "detray/propagator/rk_stepper.hpp" @@ -396,8 +395,10 @@ struct bound_getter : actor { }; template - DETRAY_HOST_DEVICE void operator()(state& actor_state, - propagator_state_t& propagation) const { + DETRAY_HOST_DEVICE void operator()( + state& actor_state, + const parameter_transporter::state& transporter_state, + propagator_state_t& propagation) const { auto& navigation = propagation._navigation; auto& stepping = propagation._stepping; @@ -428,7 +429,7 @@ struct bound_getter : actor { if ((navigation.is_on_sensitive() || navigation.is_on_passive()) && navigation.barcode().index() == 0u) { - actor_state.m_param_departure = stepping.bound_params(); + actor_state.m_param_departure = transporter_state.bound_params(); } // Get the bound track parameters and jacobian at the destination // surface @@ -437,8 +438,8 @@ struct bound_getter : actor { actor_state.m_path_length = stepping.path_length(); actor_state.m_abs_path_length = stepping.abs_path_length(); - actor_state.m_param_destination = stepping.bound_params(); - actor_state.m_jacobi = stepping.full_jacobian(); + actor_state.m_param_destination = transporter_state.bound_params(); + actor_state.m_jacobi = transporter_state.full_jacobian(); // Stop navigation if the destination surface found propagation._heartbeat &= navigation.exit(); @@ -476,7 +477,7 @@ bound_getter::state evaluate_bound_param( bound_getter::state bound_getter_state{}; bound_getter_state.track_ID = trk_count; bound_getter_state.m_min_path_length = detector_length * 0.75f; - auto actor_states = detray::tie(bound_getter_state); + auto actor_states = detray::tie(transporter_state, bound_getter_state); // Init propagator states for the reference track typename propagator_t::state state(initial_param, field, det); @@ -521,11 +522,12 @@ bound_param_vector_type get_displaced_bound_vector( typename propagator_t::state dstate(dparam, field, det); // Actor states + parameter_transporter::state transporter_state{dparam}; bound_getter::state bound_getter_state{}; bound_getter_state.track_ID = trk_count; bound_getter_state.m_min_path_length = detector_length * 0.75f; - auto actor_states = detray::tie(bound_getter_state); + auto actor_states = detray::tie(transporter_state, bound_getter_state); dstate.set_particle(ptc); dstate._stepping .template set_constraint( @@ -1565,9 +1567,9 @@ int main(int argc, char** argv) { const inhom_bfield_t inhom_bfield = bfield::create_inhom_field(); // Actor chain type - using actor_chain_t = actor_chain, - bound_getter, - parameter_resetter>; + using parameter_updater_t = + parameter_updater>; + using actor_chain_t = actor_chain; // Iterate over reference (pilot) tracks for a rectangular telescope // geometry and Jacobian calculation diff --git a/tests/integration_tests/cpu/propagator/propagator.cpp b/tests/integration_tests/cpu/propagator/propagator.cpp index e3a02eaff9..4810b4b908 100644 --- a/tests/integration_tests/cpu/propagator/propagator.cpp +++ b/tests/integration_tests/cpu/propagator/propagator.cpp @@ -57,19 +57,18 @@ struct helix_inspector : actor { /// Check that the stepper remains on the right helical track for its pos. template DETRAY_HOST_DEVICE void operator()( - state& inspector_state, const propagator_state_t& prop_state) const { + state& inspector_state, + parameter_transporter::state& transporter_state, + const propagator_state_t& prop_state) const { const auto& navigation = prop_state._navigation; const auto& stepping = prop_state._stepping; + const auto& bound_params = transporter_state.bound_params(); typename propagator_state_t::detector_type::geometry_context ctx{}; // Update inspector state inspector_state._nav_status.push_back(navigation.status()); - // The propagation does not start on a surface, skipp the inital path - if (!stepping.bound_params().surface_link().is_invalid()) { - inspector_state.path_from_surface += stepping.step_size(); - } // Nothing has happened yet (first call of actor chain) if (stepping.path_length() < tol || @@ -77,16 +76,16 @@ struct helix_inspector : actor { return; } - if (stepping.bound_params().surface_link().is_invalid()) { + if (bound_params.surface_link().is_invalid()) { return; } // Surface - const auto sf = tracking_surface{ - navigation.detector(), stepping.bound_params().surface_link()}; + const auto sf = tracking_surface{navigation.detector(), + bound_params.surface_link()}; const free_track_parameters free_params = - sf.bound_to_free_vector(ctx, stepping.bound_params()); + sf.bound_to_free_vector(ctx, bound_params); const auto last_pos = free_params.pos(); @@ -113,6 +112,10 @@ struct helix_inspector : actor { inspector_state.path_from_surface * tol * 10.f); } } + // The propagation does not start on a surface, skipp the inital path + if (!bound_params.surface_link().is_invalid()) { + inspector_state.path_from_surface += stepping.step_size(); + } // Reset path from surface if (navigation.is_on_sensitive() || navigation.encountered_sf_material()) { @@ -203,12 +206,16 @@ TEST_P(PropagatorWithRkStepper, rk4_propagator_const_bfield) { using policy_t = stepper_rk_policy; using stepper_t = rk_stepper; - // Include helix actor to check track position/covariance + + // Parameter update scheme: Include material interaction and helix + // inspector to check track position/covariance + using parameter_updater_t = + parameter_updater, + helix_inspector>; + using actor_chain_t = - actor_chain, - parameter_transporter, - pointwise_material_interactor, - parameter_resetter>; + actor_chain, parameter_updater_t>; using propagator_t = propagator; // Build detector @@ -313,12 +320,11 @@ TEST_P(PropagatorWithRkStepper, rk4_propagator_inhom_bfield) { using policy_t = stepper_rk_policy; using stepper_t = rk_stepper; - // Include helix actor to check track position/covariance + using parameter_updater_t = + parameter_updater>; using actor_chain_t = - actor_chain, - parameter_transporter, - pointwise_material_interactor, - parameter_resetter>; + actor_chain, parameter_updater_t>; using propagator_t = propagator; // Build detector and magnetic field @@ -341,13 +347,16 @@ TEST_P(PropagatorWithRkStepper, rk4_propagator_inhom_bfield) { // Build actor states: the helix inspector can be shared pathlimit_aborter::state unlimted_aborter_state{}; pathlimit_aborter::state pathlimit_aborter_state{path_limit}; + parameter_transporter::state transporter_state{track}; + parameter_transporter::state transporter_state_lim{ + lim_track}; pointwise_material_interactor::state interactor_state{}; // Create actor states tuples - auto actor_states = - detray::tie(unlimted_aborter_state, interactor_state); - auto lim_actor_states = - detray::tie(pathlimit_aborter_state, interactor_state); + auto actor_states = detray::tie(unlimted_aborter_state, + transporter_state, interactor_state); + auto lim_actor_states = detray::tie( + pathlimit_aborter_state, transporter_state_lim, interactor_state); // Init propagator states propagator_t::state state(track, bfield, det); diff --git a/tests/integration_tests/device/cuda/propagator_cuda_kernel.cu b/tests/integration_tests/device/cuda/propagator_cuda_kernel.cu index 5f6951ab9a..0b10cc1aff 100644 --- a/tests/integration_tests/device/cuda/propagator_cuda_kernel.cu +++ b/tests/integration_tests/device/cuda/propagator_cuda_kernel.cu @@ -51,11 +51,12 @@ __global__ void propagator_test_kernel( step_tracer_device_t::state tracer_state(steps.at(gid)); tracer_state.collect_only_on_surface(true); pathlimit_aborter_t::state aborter_state{cfg.stepping.path_limit}; + parameter_transporter::state transporter_state{tracks[gid]}; pointwise_material_interactor::state interactor_state{}; // Create the actor states - auto actor_states = - ::detray::tie(tracer_state, aborter_state, interactor_state); + auto actor_states = ::detray::tie(tracer_state, aborter_state, + transporter_state, interactor_state); // Create the propagator state typename propagator_device_t::state state(tracks[gid], field_data, det); diff --git a/tests/integration_tests/device/sycl/propagator_kernel.sycl b/tests/integration_tests/device/sycl/propagator_kernel.sycl index d421ff8127..b373ec5cda 100644 --- a/tests/integration_tests/device/sycl/propagator_kernel.sycl +++ b/tests/integration_tests/device/sycl/propagator_kernel.sycl @@ -64,14 +64,16 @@ void propagator_test( // Create actor states step_tracer_device_t::state tracer_state(steps.at(gid)); tracer_state.collect_only_on_surface(true); - pathlimit_aborter_t::state aborter_state{ - cfg.stepping.path_limit}; + pathlimit_aborter_t::state aborter_state{cfg.stepping.path_limit}; + parameter_transporter::state transporter_state{ + tracks[gid]}; pointwise_material_interactor::state interactor_state{}; // Create the actor states - auto actor_states = ::detray::tie( - tracer_state, aborter_state, interactor_state); + auto actor_states = + ::detray::tie(tracer_state, aborter_state, + transporter_state, interactor_state); // Create the propagator state typename propagator_device_t::state state(tracks[gid], field_data, dev_det); diff --git a/tests/tools/src/cpu/propagation_benchmark.cpp b/tests/tools/src/cpu/propagation_benchmark.cpp index 1dbc7be0d8..9f43f7f345 100644 --- a/tests/tools/src/cpu/propagation_benchmark.cpp +++ b/tests/tools/src/cpu/propagation_benchmark.cpp @@ -8,11 +8,7 @@ // Project include(s) #include "detray/detectors/bfield.hpp" #include "detray/navigation/navigator.hpp" -#include "detray/propagator/actor_chain.hpp" -#include "detray/propagator/actors/aborters.hpp" -#include "detray/propagator/actors/parameter_resetter.hpp" -#include "detray/propagator/actors/parameter_transporter.hpp" -#include "detray/propagator/actors/pointwise_material_interactor.hpp" +#include "detray/propagator/actors.hpp" #include "detray/propagator/rk_stepper.hpp" #include "detray/tracks/tracks.hpp" #include "detray/utils/type_list.hpp" @@ -61,10 +57,8 @@ int main(int argc, char** argv) { using field_t = bfield::const_field_t; using stepper_t = rk_stepper; using empty_chain_t = actor_chain<>; - using default_chain = - actor_chain, - pointwise_material_interactor, - parameter_resetter>; + using default_chain = actor_chain>>; // Host memory resource vecmem::host_memory_resource host_mr; @@ -150,8 +144,10 @@ int main(int argc, char** argv) { dtuple<> empty_state{}; pointwise_material_interactor::state interactor_state{}; + parameter_updater::state transporter_state{}; - auto actor_states = detail::make_tuple(interactor_state); + auto actor_states = + detail::make_tuple(transporter_state, interactor_state); // // Register benchmarks diff --git a/tests/tools/src/cpu/propagation_scaling.cpp b/tests/tools/src/cpu/propagation_scaling.cpp index a781d93eb9..c65fd776ff 100644 --- a/tests/tools/src/cpu/propagation_scaling.cpp +++ b/tests/tools/src/cpu/propagation_scaling.cpp @@ -8,11 +8,7 @@ // Project include(s) #include "detray/detectors/bfield.hpp" #include "detray/navigation/navigator.hpp" -#include "detray/propagator/actor_chain.hpp" -#include "detray/propagator/actors/aborters.hpp" -#include "detray/propagator/actors/parameter_resetter.hpp" -#include "detray/propagator/actors/parameter_transporter.hpp" -#include "detray/propagator/actors/pointwise_material_interactor.hpp" +#include "detray/propagator/actors.hpp" #include "detray/propagator/rk_stepper.hpp" #include "detray/tracks/tracks.hpp" #include "detray/utils/type_list.hpp" @@ -62,10 +58,8 @@ int main(int argc, char** argv) { using field_t = bfield::const_field_t; using stepper_t = rk_stepper; using empty_chain_t = actor_chain<>; - using default_chain = - actor_chain, - pointwise_material_interactor, - parameter_resetter>; + using default_chain = actor_chain>>; // Code for the openMP scheduling scheme, // @see https://www.openmp.org/spec-html/5.0/openmpsu121.html @@ -178,8 +172,10 @@ int main(int argc, char** argv) { dtuple<> empty_state{}; pointwise_material_interactor::state interactor_state{}; + parameter_updater::state transporter_state{}; - auto actor_states = detail::make_tuple(interactor_state); + auto actor_states = + detail::make_tuple(transporter_state, interactor_state); // // Register benchmarks diff --git a/tests/tools/src/cuda/propagation_benchmark_cuda.cpp b/tests/tools/src/cuda/propagation_benchmark_cuda.cpp index 48534f2a31..a32b1cbd2f 100644 --- a/tests/tools/src/cuda/propagation_benchmark_cuda.cpp +++ b/tests/tools/src/cuda/propagation_benchmark_cuda.cpp @@ -8,11 +8,7 @@ // Project include(s) #include "detray/detectors/bfield.hpp" #include "detray/navigation/navigator.hpp" -#include "detray/propagator/actor_chain.hpp" -#include "detray/propagator/actors/aborters.hpp" -#include "detray/propagator/actors/parameter_resetter.hpp" -#include "detray/propagator/actors/parameter_transporter.hpp" -#include "detray/propagator/actors/pointwise_material_interactor.hpp" +#include "detray/propagator/actors.hpp" #include "detray/propagator/rk_stepper.hpp" #include "detray/tracks/tracks.hpp" #include "detray/utils/type_list.hpp" @@ -145,8 +141,10 @@ int main(int argc, char** argv) { dtuple<> empty_state{}; pointwise_material_interactor::state interactor_state{}; + parameter_updater::state transporter_state{}; - auto actor_states = detail::make_tuple(interactor_state); + auto actor_states = + detail::make_tuple(transporter_state, interactor_state); // // Register benchmarks diff --git a/tests/unit_tests/cpu/propagator/actor_chain.cpp b/tests/unit_tests/cpu/propagator/actor_chain.cpp index e67dcbcab3..b488d6e4d1 100644 --- a/tests/unit_tests/cpu/propagator/actor_chain.cpp +++ b/tests/unit_tests/cpu/propagator/actor_chain.cpp @@ -10,6 +10,7 @@ #include "detray/definitions/units.hpp" #include "detray/propagator/base_actor.hpp" +#include "detray/propagator/concepts.hpp" // GTest include(s). #include @@ -27,6 +28,12 @@ struct print_actor : detray::actor { struct state { std::stringstream stream{}; + state() = default; + state(state &&) noexcept = default; + state(const state &) = delete; + state &operator=(state const &) = delete; + state &operator=(state &&) noexcept = default; + std::string to_string() const { return stream.str(); } }; @@ -114,6 +121,26 @@ using chain = composite_actor; // Test the actor chain on some dummy actor types GTEST_TEST(detray_propagator, actor_chain) { + static_assert(detray::concepts::actor); + static_assert(detray::concepts::actor); + static_assert(detray::concepts::actor); + static_assert(detray::concepts::actor); + static_assert(detray::concepts::actor); + static_assert(detray::concepts::actor); + static_assert(detray::concepts::actor); + static_assert(detray::concepts::actor); + static_assert(detray::concepts::actor); + + static_assert(!detray::concepts::composite_actor); + static_assert(!detray::concepts::composite_actor); + static_assert(detray::concepts::composite_actor); + static_assert(detray::concepts::composite_actor); + static_assert(detray::concepts::composite_actor); + static_assert(detray::concepts::composite_actor); + static_assert(detray::concepts::composite_actor); + static_assert(detray::concepts::composite_actor); + static_assert(detray::concepts::composite_actor); + // The actor states (can be reused between actors) example_actor_t::state example_state{}; print_actor::state printer_state{}; @@ -128,6 +155,10 @@ GTEST_TEST(detray_propagator, actor_chain) { // Chain of actors using actor_chain_t = actor_chain; + + static_assert(detray::concepts::actor_chain>); + static_assert(detray::concepts::actor_chain); + // Run actor_chain_t run_actors{}; run_actors(actor_states, prop_state); diff --git a/tests/unit_tests/cpu/propagator/covariance_transport.cpp b/tests/unit_tests/cpu/propagator/covariance_transport.cpp index de7257409d..b1da29e745 100644 --- a/tests/unit_tests/cpu/propagator/covariance_transport.cpp +++ b/tests/unit_tests/cpu/propagator/covariance_transport.cpp @@ -13,8 +13,8 @@ #include "detray/geometry/shapes/unbounded.hpp" #include "detray/navigation/navigator.hpp" #include "detray/propagator/actor_chain.hpp" -#include "detray/propagator/actors/parameter_resetter.hpp" -#include "detray/propagator/actors/parameter_transporter.hpp" +#include "detray/propagator/actors/parameter_updater.hpp" +#include "detray/propagator/concepts.hpp" #include "detray/propagator/line_stepper.hpp" #include "detray/propagator/propagator.hpp" #include "detray/tracks/ray.hpp" @@ -54,6 +54,9 @@ constexpr scalar tol{1e-6f}; GTEST_TEST(detray_propagator, covariance_transport) { + static_assert(detray::concepts::actor>); + static_assert(detray::concepts::actor>); + vecmem::host_memory_resource host_mr; // Build in x-direction from given module positions @@ -70,8 +73,7 @@ GTEST_TEST(detray_propagator, covariance_transport) { using navigator_t = navigator; using cline_stepper_t = line_stepper; - using actor_chain_t = actor_chain, - parameter_resetter>; + using actor_chain_t = actor_chain>; using propagator_t = propagator; @@ -92,16 +94,19 @@ GTEST_TEST(detray_propagator, covariance_transport) { const bound_track_parameters bound_param0( geometry::barcode{}.set_index(0u), bound_vector, bound_cov); + // Actors + parameter_transporter::state bound_updater{bound_param0}; + propagation::config prop_cfg{}; prop_cfg.navigation.overstep_tolerance = -100.f * unit::um; propagator_t p{prop_cfg}; propagator_t::state propagation(bound_param0, det, prop_cfg.context); // Run propagator - p.propagate(propagation); + p.propagate(propagation, detray::tie(bound_updater)); // Bound state after one turn propagation - const auto& bound_param1 = propagation._stepping.bound_params(); + const auto& bound_param1 = bound_updater.bound_params(); // Check if the track reaches the final surface EXPECT_EQ(bound_param0.surface_link().volume(), 4095u); diff --git a/tests/unit_tests/cpu/utils/tuple_helpers.cpp b/tests/unit_tests/cpu/utils/tuple_helpers.cpp index 31f2b289af..c73c14f1c2 100644 --- a/tests/unit_tests/cpu/utils/tuple_helpers.cpp +++ b/tests/unit_tests/cpu/utils/tuple_helpers.cpp @@ -66,4 +66,25 @@ GTEST_TEST(detray_utils, tuple_helpers) { EXPECT_FLOAT_EQ(detail::get(d_tuple), 1.0f); EXPECT_EQ(detail::get(d_tuple), 2UL); EXPECT_EQ(detail::get(d_tuple), std::string("detray::tuple")); + + // Check type concatenation + static_assert( + std::same_as< + detail::tuple_cat_t, std::tuple, + std::tuple<>, std::tuple>, + std::tuple>); + + static_assert( + std::same_as, dtuple, + dtuple<>, dtuple>, + dtuple>); + + // Permutation check + static_assert(detail::is_permutation_v, dtuple<>>); + static_assert(detail::is_permutation_v, + dtuple>); + static_assert(!detail::is_permutation_v, + dtuple>); + static_assert(!detail::is_permutation_v, + dtuple>); } diff --git a/tutorials/src/device/cuda/propagation.hpp b/tutorials/src/device/cuda/propagation.hpp index f9efc923e4..5a5441b14b 100644 --- a/tutorials/src/device/cuda/propagation.hpp +++ b/tutorials/src/device/cuda/propagation.hpp @@ -54,10 +54,14 @@ using device_field_t = using stepper_t = rk_stepper; // Actors + +// Add the material interaction to the bound track parameter update scheme +using parameter_updater_t = parameter_updater< + detray::tutorial::algebra_t, + pointwise_material_interactor>; +// Make actor call chain using actor_chain_t = - actor_chain, parameter_transporter, - pointwise_material_interactor, - parameter_resetter>; + actor_chain, parameter_updater_t>; // Propagator using propagator_t = propagator; diff --git a/tutorials/src/device/cuda/propagation_kernel.cu b/tutorials/src/device/cuda/propagation_kernel.cu index 5b653ab4b1..d77878fb93 100644 --- a/tutorials/src/device/cuda/propagation_kernel.cu +++ b/tutorials/src/device/cuda/propagation_kernel.cu @@ -40,9 +40,13 @@ __global__ void propagation_kernel( // Create actor states detray::pathlimit_aborter::state aborter_state{path_limit}; - detray::pointwise_material_interactor::state interactor_state{}; + detray::parameter_transporter::state + transporter_state{tracks[gid]}; + detray::pointwise_material_interactor::state + interactor_state{}; - auto actor_states = detray::tie(aborter_state, interactor_state); + auto actor_states = + detray::tie(aborter_state, transporter_state, interactor_state); // Create the propagator state for the track detray::tutorial::propagator_t::state state(tracks[gid], field_data, det);