40
40
#include < sycl/nd_range.hpp> // for nd_range
41
41
#include < sycl/property_list.hpp> // for property_list
42
42
#include < sycl/range.hpp> // for range
43
+ #include < sycl/detail/kernel_name_based_cache_t.hpp>
43
44
44
45
#include < cstddef> // for size_t
45
46
#include < functional> // for function
@@ -2869,6 +2870,119 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
2869
2870
KernelNameBasedCachePtr);
2870
2871
}
2871
2872
2873
// Private helpers used by the handler-less submission path below. Only the
// declarations are visible here; the definitions live elsewhere.
//
// check_event_readiness: queried by submit_no_handler_v2 with MLastEvent to
// decide whether the scheduler-bypass path may be taken.
+ bool check_event_readiness (std::shared_ptr<detail::event_impl> EventImpl) const ;
2874
// Accessors for the UR device/context handles; submit_no_handler_v2 uses them
// to build the kernel-cache lookup key.
+ ur_device_handle_t get_device_ur_handle () const ;
2875
+ ur_context_handle_t get_context_ur_handle () const ;
2876
// Called by extract_args_from_lambda for kind_std_layout parameters, with
// Size taken from the parameter descriptor's `info` field.
+ void extract_args_set_arg_value (ur_kernel_handle_t Kernel,
2877
+ size_t NextTrueIndex, int Size, const void *ArgPtr) const ;
2878
// Called by extract_args_from_lambda for kind_pointer parameters, with Ptr
// being the pointer value read out of the kernel function object.
+ void extract_args_set_arg_pointer (ur_kernel_handle_t Kernel,
2879
+ size_t NextTrueIndex, const void *Ptr) const ;
2880
+
2881
+ detail::FastKernelCacheValPtr
2882
+ getKernel (ur_device_handle_t Device,
2883
+ ur_context_handle_t Context, detail::FastKernelSubcacheT *KernelSubcacheHint) const {
2884
+
2885
+ const detail::FastKernelSubcacheEntriesT &SubcacheEntries =
2886
+ KernelSubcacheHint->Entries ;
2887
+ detail::FastKernelSubcacheReadLockT SubcacheLock{KernelSubcacheHint->Mutex };
2888
+ const detail::FastKernelCacheKeyT RequiredKey (Device, Context);
2889
+ // Search for the kernel in the subcache.
2890
+ auto It = std::find_if (SubcacheEntries.begin (), SubcacheEntries.end (),
2891
+ [&](const detail::FastKernelEntryT &Entry) {
2892
+ return Entry.Key == RequiredKey;
2893
+ });
2894
+ if (It != SubcacheEntries.end ()) {
2895
+ // traceKernel("Kernel fetched.", KernelName, true);
2896
+ return It->Value ;
2897
+ }
2898
+
2899
+ return detail::FastKernelCacheValPtr ();
2900
+ }
2901
+
2902
+ void extract_args_from_lambda (ur_kernel_handle_t Kernel, void *KernelFuncPtr,
2903
+ const detail::KernelArgMask *EliminatedArgMask, int KernelNumParams,
2904
+ detail::kernel_param_desc_t (*KernelParamDescGetter)(int )) const {
2905
+ auto setFunc = [this , Kernel, KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
2906
+ size_t NextTrueIndex) {
2907
+ const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
2908
+ switch (ParamDesc.kind ) {
2909
+ case detail::kernel_param_kind_t ::kind_std_layout: {
2910
+ int Size = ParamDesc.info ;
2911
+ extract_args_set_arg_value (Kernel, NextTrueIndex, Size, ArgPtr);
2912
+ break ;
2913
+ }
2914
+ case detail::kernel_param_kind_t ::kind_pointer: {
2915
+ const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2916
+ extract_args_set_arg_pointer (Kernel, NextTrueIndex, Ptr);
2917
+ break ;
2918
+ }
2919
+ default :
2920
+ throw std::runtime_error (" Direct kernel argument copy failed." );
2921
+ }
2922
+ };
2923
+
2924
+ if (!EliminatedArgMask || EliminatedArgMask->size () == 0 ) {
2925
+ for (int I = 0 ; I < KernelNumParams; ++I) {
2926
+ const detail::kernel_param_desc_t &Param = KernelParamDescGetter (I);
2927
+ setFunc (Param, I);
2928
+ }
2929
+ } else {
2930
+ size_t NextTrueIndex = 0 ;
2931
+ for (int I = 0 ; I < KernelNumParams; ++I) {
2932
+ const detail::kernel_param_desc_t &Param = KernelParamDescGetter (I);
2933
+ if ((*EliminatedArgMask)[I])
2934
+ continue ;
2935
+ setFunc (Param, NextTrueIndex);
2936
+ ++NextTrueIndex;
2937
+ }
2938
+ }
2939
+ }
2940
+
2941
+ template <typename KernelName, typename KernelType, int Dims>
2942
+ void submit_no_handler_v2 (nd_range<Dims> Range, const KernelType &KernelFunc) {
2943
+
2944
+ using NameT =
2945
+ typename detail::get_kernel_name_t <KernelName, KernelType>::name;
2946
+
2947
+ KernelType KernelFuncLocal = KernelFunc;
2948
+ void *KernelFuncPtr = reinterpret_cast <void *>(&KernelFuncLocal);
2949
+ int KernelNumParams = detail::getKernelNumParams<NameT>();
2950
+ detail::kernel_param_desc_t (*KernelParamDescGetter)(int ) = &(detail::getKernelParamDesc<NameT>);
2951
+ bool IsKernelESIMD = detail::isKernelESIMD<NameT>();
2952
+ bool HasSpecialCapt = detail::hasSpecialCaptures<NameT>();
2953
+ detail::KernelNameBasedCacheT *KernelNameBasedCachePtr = detail::getKernelNameBasedCache<NameT>();
2954
+ ur_device_handle_t Device = get_device_ur_handle ();
2955
+ ur_context_handle_t Context = get_context_ur_handle ();
2956
+
2957
+ assert (HasSpecialCapt == false );
2958
+ assert (IsKernelESIMD == false );
2959
+
2960
+ detail::FastKernelCacheValPtr KernelCacheVal = getKernel (
2961
+ Device, Context, &KernelNameBasedCachePtr->FastKernelSubcache );
2962
+
2963
+ std::lock_guard<std::mutex> Lock{*MMutex};
2964
+
2965
+ bool SchedulerBypass = KernelCacheVal &&
2966
+ (!MLastEvent || (check_event_readiness (MLastEvent)));
2967
+
2968
+ if (SchedulerBypass) {
2969
+ extract_args_from_lambda (KernelCacheVal->MKernelHandle , KernelFuncPtr,
2970
+ KernelCacheVal->MKernelArgMask , KernelNumParams, KernelParamDescGetter);
2971
+ submit_no_handler_impl (Range, KernelCacheVal);
2972
+ MLastEvent = nullptr ;
2973
+ } else {
2974
+ const char *KernelN = detail::getKernelName<NameT>();
2975
+
2976
+ std::shared_ptr<detail::event_impl> EventImpl = submit_no_handler_impl_v2 (
2977
+ Range, KernelN, KernelFuncPtr, KernelNumParams, KernelParamDescGetter,
2978
+ KernelNameBasedCachePtr, MLastEvent);
2979
+
2980
+ if (is_in_order ()) {
2981
+ MLastEvent = EventImpl;
2982
+ }
2983
+ }
2984
+ }
2985
+
2872
2986
public:
2873
2987
// / single_task version not using handler
2874
2988
template <typename KernelName = detail::auto_name, typename KernelType>
@@ -2888,6 +3002,29 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
2888
3002
submit_no_handler<KernelName, KernelType, Dims>(Range, KernelFunc);
2889
3003
}
2890
3004
3005
+ template <typename KernelName = detail::auto_name, int Dims,
3006
+ typename KernelType>
3007
+ void parallel_for_no_handler_v2 (nd_range<Dims> Range, const KernelType &KernelFunc) {
3008
+
3009
+ kernel_parallel_for<KernelName, sycl::nd_item<Dims>, KernelType,
3010
+ ext::oneapi::experimental::empty_properties_t >(KernelFunc);
3011
+ submit_no_handler_v2<KernelName, KernelType, Dims>(Range, KernelFunc);
3012
+ }
3013
+
3014
+ template <typename FuncT>
3015
+ std::enable_if_t <detail::check_fn_signature<std::remove_reference_t <FuncT>,
3016
+ void ()>::value ||
3017
+ detail::check_fn_signature<std::remove_reference_t<FuncT>,
3018
+ void(interop_handle)>::value>
3019
+ host_task_no_handler(FuncT &&Func) {
3020
+ std::lock_guard<std::mutex> Lock{*MMutex};
3021
+
3022
+ std::shared_ptr<detail::event_impl> EventImpl =
3023
+ host_task_no_handler_impl (std::move (Func), MLastEvent);
3024
+ if (is_in_order () && EventImpl) {
3025
+ MLastEvent = EventImpl;
3026
+ }
3027
+ }
2891
3028
2892
3029
2893
3030
// / parallel_for version with a kernel represented as a lambda + range that
@@ -3673,6 +3810,9 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
3673
3810
std::shared_ptr<detail::queue_impl> impl;
3674
3811
queue (std::shared_ptr<detail::queue_impl> impl) : impl(impl) {}
3675
3812
3813
// Event remembered from the most recent handler-less submission on an
// in-order queue; submit_no_handler_v2 resets it to null after a
// scheduler-bypass submission.
+ std::shared_ptr<detail::event_impl> MLastEvent;
3814
// Locked by submit_no_handler_v2 and host_task_no_handler around their use of
// MLastEvent. NOTE(review): it is dereferenced without a null check; confirm
// that every constructor initializes it.
+ std::shared_ptr<std::mutex> MMutex;
3815
+
3676
3816
template <class Obj >
3677
3817
friend const decltype (Obj::impl) &
3678
3818
detail::getSyclObjImpl(const Obj &SyclObject);
@@ -3799,6 +3939,19 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
3799
3939
int KernelNumParams, detail::kernel_param_desc_t (*KernelParamDescGetter)(int ),
3800
3940
detail::KernelNameBasedCacheT *KernelNameBasedCachePtr) const ;
3801
3941
3942
// Submission used by submit_no_handler_v2 when the scheduler-bypass path is
// not taken. It receives the kernel name, the address of the kernel function
// object, the parameter descriptor getter, the kernel-name-based cache and
// the queue's last event, and returns the event of the new submission.
// Definition is not visible here.
+ template <int Dims>
3943
+ std::shared_ptr<detail::event_impl> submit_no_handler_impl_v2 (
3944
+ nd_range<Dims> Range, const char *KernelName, void *KernelFunc,
3945
+ int KernelNumParams, detail::kernel_param_desc_t (*KernelParamDescGetter)(int ),
3946
+ detail::KernelNameBasedCacheT *KernelNameBasedCachePtr,
3947
+ std::shared_ptr<detail::event_impl> LastEvent) const ;
3948
+
3949
// Submission used by submit_no_handler_v2 on the scheduler-bypass path, after
// the kernel arguments have been set from the function object. Returns no
// event.
+ template <int Dims>
3950
+ void submit_no_handler_impl (nd_range<Dims> Range, detail::FastKernelCacheValPtr &KernelCacheVal) const ;
3951
+
3952
// Backend of host_task_no_handler: takes the wrapped host task and the
// queue's last event. The caller checks the returned event for null before
// storing it.
+ std::shared_ptr<detail::event_impl> host_task_no_handler_impl (
3953
+ std::function<void ()> &&Func, std::shared_ptr<detail::event_impl> LastEventImpl);
3954
+
3802
3955
// / Submits a command group function object to the queue, in order to be
3803
3956
// / scheduled for execution on the device.
3804
3957
// /
0 commit comments