Skip to content

Commit e1b265d

Browse files
committed
[WIP] No handler submit POC - header based version
1 parent a1971a3 commit e1b265d

File tree

12 files changed

+438
-20
lines changed

12 files changed

+438
-20
lines changed

sycl/source/detail/kernel_name_based_cache_t.hpp renamed to sycl/include/sycl/detail/kernel_name_based_cache_t.hpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
//===----------------------------------------------------------------------===//
88
#pragma once
99

10-
#include <detail/hashers.hpp>
11-
#include <detail/kernel_arg_mask.hpp>
10+
//#include <detail/hashers.hpp>
11+
//#include <sycl/detail/kernel_arg_mask.hpp>
1212
#include <emhash/hash_table8.hpp>
1313
#include <sycl/detail/spinlock.hpp>
1414
#include <sycl/detail/ur.hpp>
@@ -20,6 +20,7 @@ namespace sycl {
2020
inline namespace _V1 {
2121
namespace detail {
2222
// FastKernelCacheKeyT: key of a fast-kernel-cache entry, a pair of
// (UR device handle, UR context handle).
// KernelArgMask: one bool per kernel argument; FastKernelCacheVal stores a
// pointer to such a mask as its "eliminated kernel argument mask".
// NOTE(review): KernelArgMask is declared here because the
// kernel_arg_mask.hpp include above is commented out. Presumably it must stay
// identical to the alias in detail/kernel_arg_mask.hpp, which
// kernel_program_cache.hpp still includes alongside this header - confirm.
// <vector> and <utility> are not among the includes visible in this hunk.
using FastKernelCacheKeyT = std::pair<ur_device_handle_t, ur_context_handle_t>;
23+
using KernelArgMask = std::vector<bool>;
2324

2425
struct FastKernelCacheVal {
2526
ur_kernel_handle_t MKernelHandle; /* UR kernel handle pointer. */
@@ -29,25 +30,27 @@ struct FastKernelCacheVal {
2930
const KernelArgMask *MKernelArgMask; /* Eliminated kernel argument mask. */
3031
ur_program_handle_t MProgramHandle; /* UR program handle corresponding to
3132
this kernel. */
32-
const Adapter &MAdapterPtr; /* We can keep reference to the adapter
33+
/*const Adapter &MAdapterPtr;*/ /* We can keep reference to the adapter
3334
because during 2-stage shutdown the kernel
3435
cache is destroyed deliberately before the
3536
adapter. */
3637

3738
FastKernelCacheVal(ur_kernel_handle_t KernelHandle, std::mutex *Mutex,
3839
const KernelArgMask *KernelArgMask,
39-
ur_program_handle_t ProgramHandle,
40-
const Adapter &AdapterPtr)
40+
ur_program_handle_t ProgramHandle)
41+
//const Adapter &AdapterPtr)
4142
: MKernelHandle(KernelHandle), MMutex(Mutex),
42-
MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle),
43-
MAdapterPtr(AdapterPtr) {}
43+
MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle)
44+
/*MAdapterPtr(AdapterPtr)*/ {}
4445

4546
~FastKernelCacheVal() {
47+
/*
4648
if (MKernelHandle)
4749
MAdapterPtr.call<sycl::detail::UrApiKind::urKernelRelease>(MKernelHandle);
4850
if (MProgramHandle)
4951
MAdapterPtr.call<sycl::detail::UrApiKind::urProgramRelease>(
5052
MProgramHandle);
53+
*/
5154
MKernelHandle = nullptr;
5255
MMutex = nullptr;
5356
MKernelArgMask = nullptr;

sycl/include/sycl/khr/free_function_commands.hpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,28 +150,28 @@ void launch_grouped(handler &h, range<3> r, range<3> size,
150150
}
151151

152152
/// Launches kernel \p k on queue \p q over nd_range<1>(r, size) without
/// creating a handler. The removed ('-') lines called
/// queue::parallel_for_no_handler; the added ('+') lines call
/// queue::parallel_for_no_handler_v2.
/// \p codeLoc is accepted but deliberately unused (cast to void).
/// NOTE(review): the added signature takes the queue by non-const reference,
/// so const queues and temporaries no longer bind to this overload - confirm
/// that existing callers are fine with that.
template <typename KernelType>
153-
void launch_grouped(const queue &q, range<1> r, range<1> size,
153+
void launch_grouped(queue &q, range<1> r, range<1> size,
154154
const KernelType &k,
155155
const sycl::detail::code_location &codeLoc =
156156
sycl::detail::code_location::current()) {
157157
(void)codeLoc;
158-
q.parallel_for_no_handler(nd_range<1>(r, size), k);
158+
q.parallel_for_no_handler_v2(nd_range<1>(r, size), k);
159159
}
160160
/// Launches kernel \p k on queue \p q over nd_range<2>(r, size) without
/// creating a handler. The removed ('-') lines called
/// queue::parallel_for_no_handler; the added ('+') lines call
/// queue::parallel_for_no_handler_v2.
/// \p codeLoc is accepted but deliberately unused (cast to void).
/// NOTE(review): the added signature takes the queue by non-const reference,
/// so const queues and temporaries no longer bind to this overload - confirm
/// that existing callers are fine with that.
template <typename KernelType>
161-
void launch_grouped(const queue &q, range<2> r, range<2> size,
161+
void launch_grouped(queue &q, range<2> r, range<2> size,
162162
const KernelType &k,
163163
const sycl::detail::code_location &codeLoc =
164164
sycl::detail::code_location::current()) {
165165
(void)codeLoc;
166-
q.parallel_for_no_handler(nd_range<2>(r, size), k);
166+
q.parallel_for_no_handler_v2(nd_range<2>(r, size), k);
167167
}
168168
/// Launches kernel \p k on queue \p q over nd_range<3>(r, size) without
/// creating a handler. The removed ('-') lines called
/// queue::parallel_for_no_handler; the added ('+') lines call
/// queue::parallel_for_no_handler_v2.
/// \p codeLoc is accepted but deliberately unused (cast to void).
/// NOTE(review): the added signature takes the queue by non-const reference,
/// so const queues and temporaries no longer bind to this overload - confirm
/// that existing callers are fine with that.
template <typename KernelType>
169-
void launch_grouped(const queue &q, range<3> r, range<3> size,
169+
void launch_grouped(queue &q, range<3> r, range<3> size,
170170
const KernelType &k,
171171
const sycl::detail::code_location &codeLoc =
172172
sycl::detail::code_location::current()) {
173173
(void)codeLoc;
174-
q.parallel_for_no_handler(nd_range<3>(r, size), k);
174+
q.parallel_for_no_handler_v2(nd_range<3>(r, size), k);
175175
}
176176

177177
template <typename... Args>
@@ -297,6 +297,11 @@ void launch_task(const queue &q, const kernel &k, Args &&...args) {
297297
[&](handler &h) { launch_task(h, k, std::forward<Args>(args)...); });
298298
}
299299

300+
template <typename FuncT>
301+
void launch_host_task(queue &q, FuncT &&Func) {
302+
q.host_task_no_handler(std::move(Func));
303+
}
304+
300305
/// Free-function form of handler::memcpy: forwards \p dest, \p src and
/// \p numBytes unchanged to h.memcpy().
inline void memcpy(handler &h, void *dest, const void *src, size_t numBytes) {
301306
h.memcpy(dest, src, numBytes);
302307
}

sycl/include/sycl/queue.hpp

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include <sycl/nd_range.hpp> // for nd_range
4141
#include <sycl/property_list.hpp> // for property_list
4242
#include <sycl/range.hpp> // for range
43+
#include <sycl/detail/kernel_name_based_cache_t.hpp>
4344

4445
#include <cstddef> // for size_t
4546
#include <functional> // for function
@@ -2869,6 +2870,119 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
28692870
KernelNameBasedCachePtr);
28702871
}
28712872

2873+
// Helpers for the handler-less submission path. Only declarations are
// visible here; the notes below describe how this header uses them.
//  - check_event_readiness: queried by submit_no_handler_v2 on MLastEvent;
//    a true result (or no last event) allows the scheduler-bypass path.
//  - get_device_ur_handle / get_context_ur_handle: provide the
//    (device, context) pair that getKernel uses as the fast-cache key.
//  - extract_args_set_arg_value: called by extract_args_from_lambda for
//    kind_std_layout parameters with the argument index, the descriptor's
//    `info` field as Size, and a pointer into the kernel function object.
//  - extract_args_set_arg_pointer: called by extract_args_from_lambda for
//    kind_pointer parameters with the argument index and the pointer value
//    read out of the kernel function object.
// NOTE(review): check_event_readiness takes the shared_ptr by value, which
// copies it on every call - confirm whether that is intended.
bool check_event_readiness(std::shared_ptr<detail::event_impl> EventImpl) const;
2874+
ur_device_handle_t get_device_ur_handle() const;
2875+
ur_context_handle_t get_context_ur_handle() const;
2876+
void extract_args_set_arg_value(ur_kernel_handle_t Kernel,
2877+
size_t NextTrueIndex, int Size, const void *ArgPtr) const;
2878+
void extract_args_set_arg_pointer(ur_kernel_handle_t Kernel,
2879+
size_t NextTrueIndex, const void *Ptr) const;
2880+
2881+
/// Looks up a cached kernel for the (Device, Context) pair in the given
/// kernel subcache.
///
/// \param Device             UR device handle, first half of the lookup key.
/// \param Context            UR context handle, second half of the lookup key.
/// \param KernelSubcacheHint Subcache to search. It is dereferenced without a
///                           null check, so it must point to a valid subcache.
/// \return The value of the first entry whose key equals (Device, Context),
///         or an empty FastKernelCacheValPtr if no entry matches.
detail::FastKernelCacheValPtr
getKernel(ur_device_handle_t Device, ur_context_handle_t Context,
          detail::FastKernelSubcacheT *KernelSubcacheHint) const {
  const detail::FastKernelSubcacheEntriesT &Entries =
      KernelSubcacheHint->Entries;
  // The subcache lock is held for the whole scan.
  detail::FastKernelSubcacheReadLockT Guard{KernelSubcacheHint->Mutex};
  const detail::FastKernelCacheKeyT WantedKey(Device, Context);

  for (const detail::FastKernelEntryT &Candidate : Entries) {
    if (Candidate.Key == WantedKey) {
      // Kernel tracing is not wired up on this path yet.
      return Candidate.Value;
    }
  }

  // Cache miss: the caller receives an empty pointer.
  return detail::FastKernelCacheValPtr();
}
2901+
2902+
/// Sets the arguments of \p Kernel directly from the contents of a kernel
/// function object, skipping parameters flagged in the eliminated-arg mask.
///
/// \param Kernel                UR kernel handle receiving the arguments.
/// \param KernelFuncPtr         Address of the kernel function object; each
///                              parameter is read at its descriptor's offset.
/// \param EliminatedArgMask     Optional mask. A true element at position I
///                              means parameter I is skipped. A null or empty
///                              mask means no parameter is skipped.
/// \param KernelNumParams       Number of parameter descriptors to visit.
/// \param KernelParamDescGetter Returns the descriptor of parameter I.
/// \throws std::runtime_error for any parameter kind other than
///         kind_std_layout or kind_pointer.
void extract_args_from_lambda(ur_kernel_handle_t Kernel, void *KernelFuncPtr,
    const detail::KernelArgMask *EliminatedArgMask, int KernelNumParams,
    detail::kernel_param_desc_t (*KernelParamDescGetter)(int)) const {
  const bool HasMask = EliminatedArgMask && EliminatedArgMask->size() != 0;
  const char *CaptureBase = static_cast<const char *>(KernelFuncPtr);

  // Index under which the next kept parameter is set on the kernel. Without
  // a mask nothing is skipped, so it always equals the parameter position.
  size_t ArgIndex = 0;
  for (int ParamIdx = 0; ParamIdx < KernelNumParams; ++ParamIdx) {
    const detail::kernel_param_desc_t &Desc = KernelParamDescGetter(ParamIdx);
    if (HasMask && (*EliminatedArgMask)[ParamIdx])
      continue;

    const void *ArgPtr = CaptureBase + Desc.offset;
    if (Desc.kind == detail::kernel_param_kind_t::kind_std_layout) {
      // The value is passed by address together with the descriptor's info.
      int Size = Desc.info;
      extract_args_set_arg_value(Kernel, ArgIndex, Size, ArgPtr);
    } else if (Desc.kind == detail::kernel_param_kind_t::kind_pointer) {
      // The kernel object stores a pointer; pass the pointer value itself.
      const void *Ptr = *static_cast<const void *const *>(ArgPtr);
      extract_args_set_arg_pointer(Kernel, ArgIndex, Ptr);
    } else {
      throw std::runtime_error("Direct kernel argument copy failed.");
    }
    ++ArgIndex;
  }
}
2940+
2941+
template <typename KernelName, typename KernelType, int Dims>
2942+
void submit_no_handler_v2(nd_range<Dims> Range, const KernelType &KernelFunc) {
2943+
2944+
using NameT =
2945+
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
2946+
2947+
KernelType KernelFuncLocal = KernelFunc;
2948+
void *KernelFuncPtr = reinterpret_cast<void *>(&KernelFuncLocal);
2949+
int KernelNumParams = detail::getKernelNumParams<NameT>();
2950+
detail::kernel_param_desc_t (*KernelParamDescGetter)(int) = &(detail::getKernelParamDesc<NameT>);
2951+
bool IsKernelESIMD = detail::isKernelESIMD<NameT>();
2952+
bool HasSpecialCapt = detail::hasSpecialCaptures<NameT>();
2953+
detail::KernelNameBasedCacheT *KernelNameBasedCachePtr = detail::getKernelNameBasedCache<NameT>();
2954+
ur_device_handle_t Device = get_device_ur_handle();
2955+
ur_context_handle_t Context = get_context_ur_handle();
2956+
2957+
assert(HasSpecialCapt == false);
2958+
assert(IsKernelESIMD == false);
2959+
2960+
detail::FastKernelCacheValPtr KernelCacheVal = getKernel(
2961+
Device, Context, &KernelNameBasedCachePtr->FastKernelSubcache);
2962+
2963+
std::lock_guard<std::mutex> Lock{*MMutex};
2964+
2965+
bool SchedulerBypass = KernelCacheVal &&
2966+
(!MLastEvent || (check_event_readiness(MLastEvent)));
2967+
2968+
if (SchedulerBypass) {
2969+
extract_args_from_lambda(KernelCacheVal->MKernelHandle, KernelFuncPtr,
2970+
KernelCacheVal->MKernelArgMask, KernelNumParams, KernelParamDescGetter);
2971+
submit_no_handler_impl(Range, KernelCacheVal);
2972+
MLastEvent = nullptr;
2973+
} else {
2974+
const char *KernelN = detail::getKernelName<NameT>();
2975+
2976+
std::shared_ptr<detail::event_impl> EventImpl = submit_no_handler_impl_v2(
2977+
Range, KernelN, KernelFuncPtr, KernelNumParams, KernelParamDescGetter,
2978+
KernelNameBasedCachePtr, MLastEvent);
2979+
2980+
if (is_in_order()) {
2981+
MLastEvent = EventImpl;
2982+
}
2983+
}
2984+
}
2985+
28722986
public:
28732987
/// single_task version not using handler
28742988
template <typename KernelName = detail::auto_name, typename KernelType>
@@ -2888,6 +3002,29 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
28883002
submit_no_handler<KernelName, KernelType, Dims>(Range, KernelFunc);
28893003
}
28903004

3005+
template <typename KernelName = detail::auto_name, int Dims,
3006+
typename KernelType>
3007+
void parallel_for_no_handler_v2(nd_range<Dims> Range, const KernelType &KernelFunc) {
3008+
3009+
kernel_parallel_for<KernelName, sycl::nd_item<Dims>, KernelType,
3010+
ext::oneapi::experimental::empty_properties_t>(KernelFunc);
3011+
submit_no_handler_v2<KernelName, KernelType, Dims>(Range, KernelFunc);
3012+
}
3013+
3014+
template <typename FuncT>
3015+
std::enable_if_t<detail::check_fn_signature<std::remove_reference_t<FuncT>,
3016+
void()>::value ||
3017+
detail::check_fn_signature<std::remove_reference_t<FuncT>,
3018+
void(interop_handle)>::value>
3019+
host_task_no_handler(FuncT &&Func) {
3020+
std::lock_guard<std::mutex> Lock{*MMutex};
3021+
3022+
std::shared_ptr<detail::event_impl> EventImpl =
3023+
host_task_no_handler_impl(std::move(Func), MLastEvent);
3024+
if (is_in_order() && EventImpl) {
3025+
MLastEvent = EventImpl;
3026+
}
3027+
}
28913028

28923029

28933030
/// parallel_for version with a kernel represented as a lambda + range that
@@ -3673,6 +3810,9 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
36733810
std::shared_ptr<detail::queue_impl> impl;
36743811
queue(std::shared_ptr<detail::queue_impl> impl) : impl(impl) {}
36753812

3813+
// State used by the handler-less entry points submit_no_handler_v2 and
// host_task_no_handler:
//  - MLastEvent: set to the event returned by a submission when the queue is
//    in-order, and reset to nullptr after a scheduler-bypass kernel
//    submission.
//  - MMutex: locked with std::lock_guard around the MLastEvent check/update
//    and the submission call in both entry points.
// NOTE(review): MMutex is dereferenced without a null check; where it is
// initialized is not visible in this header - confirm that every queue
// constructor sets it.
std::shared_ptr<detail::event_impl> MLastEvent;
3814+
std::shared_ptr<std::mutex> MMutex;
3815+
36763816
template <class Obj>
36773817
friend const decltype(Obj::impl) &
36783818
detail::getSyclObjImpl(const Obj &SyclObject);
@@ -3799,6 +3939,19 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
37993939
int KernelNumParams, detail::kernel_param_desc_t (*KernelParamDescGetter)(int),
38003940
detail::KernelNameBasedCacheT *KernelNameBasedCachePtr) const;
38013941

3942+
// Declaration only. Called by submit_no_handler_v2 when the scheduler-bypass
// path is not taken. It receives the kernel name, the address of the local
// kernel function object, the parameter count and descriptor getter, the
// kernel-name-based cache and the queue's MLastEvent. The returned event is
// stored in MLastEvent by the caller when the queue is in-order.
template<int Dims>
3943+
std::shared_ptr<detail::event_impl> submit_no_handler_impl_v2(
3944+
nd_range<Dims> Range, const char *KernelName, void *KernelFunc,
3945+
int KernelNumParams, detail::kernel_param_desc_t (*KernelParamDescGetter)(int),
3946+
detail::KernelNameBasedCacheT *KernelNameBasedCachePtr,
3947+
std::shared_ptr<detail::event_impl> LastEvent) const;
3948+
3949+
// Declaration only. Called by submit_no_handler_v2 on the scheduler-bypass
// path, after extract_args_from_lambda has set the kernel arguments, with
// the launch range and the cached kernel value returned by getKernel.
template<int Dims>
3950+
void submit_no_handler_impl(nd_range<Dims> Range, detail::FastKernelCacheValPtr &KernelCacheVal) const;
3951+
3952+
// Declaration only. Called by host_task_no_handler, under the queue mutex,
// with the host-task callable and the queue's MLastEvent. The caller stores
// a non-null returned event in MLastEvent when the queue is in-order.
// NOTE(review): unlike the other *_impl declarations here, this one is not
// const-qualified - confirm that is intentional.
std::shared_ptr<detail::event_impl> host_task_no_handler_impl(
3953+
std::function<void()> &&Func, std::shared_ptr<detail::event_impl> LastEventImpl);
3954+
38023955
/// Submits a command group function object to the queue, in order to be
38033956
/// scheduled for execution on the device.
38043957
///

sycl/source/detail/global_handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include <detail/adapter.hpp>
1515
#include <detail/config.hpp>
1616
#include <detail/global_handler.hpp>
17-
#include <detail/kernel_name_based_cache_t.hpp>
17+
#include <sycl/detail/kernel_name_based_cache_t.hpp>
1818
#include <detail/platform_impl.hpp>
1919
#include <detail/program_manager/program_manager.hpp>
2020
#include <detail/scheduler/scheduler.hpp>

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
#include "sycl/exception.hpp"
1212
#include <detail/config.hpp>
1313
#include <detail/kernel_arg_mask.hpp>
14-
#include <detail/kernel_name_based_cache_t.hpp>
14+
#include <detail/hashers.hpp>
15+
#include <sycl/detail/kernel_name_based_cache_t.hpp>
1516
#include <detail/platform_impl.hpp>
1617
#include <detail/unordered_multimap.hpp>
1718
#include <sycl/detail/common.hpp>

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,7 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
11861186
// nullptr for the mutex.
11871187
auto [Kernel, ArgMask] = BuildF();
11881188
return std::make_shared<FastKernelCacheVal>(
1189-
Kernel, nullptr, ArgMask, Program, *ContextImpl.getAdapter().get());
1189+
Kernel, nullptr, ArgMask, Program); //*ContextImpl.getAdapter().get());
11901190
}
11911191

11921192
auto BuildResult = Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
@@ -1195,7 +1195,7 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
11951195
const KernelArgMaskPairT &KernelArgMaskPair = BuildResult->Val;
11961196
auto ret_val = std::make_shared<FastKernelCacheVal>(
11971197
KernelArgMaskPair.first, &(BuildResult->MBuildResultMutex),
1198-
KernelArgMaskPair.second, Program, *ContextImpl.getAdapter().get());
1198+
KernelArgMaskPair.second, Program); //*ContextImpl.getAdapter().get());
11991199
// If caching is enabled, one copy of the kernel handle will be
12001200
// stored in FastKernelCacheVal, and one is in
12011201
// KernelProgramCache::MKernelsPerProgramCache. To cover

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <detail/device_global_map_entry.hpp>
1313
#include <detail/host_pipe_map_entry.hpp>
1414
#include <detail/kernel_arg_mask.hpp>
15-
#include <detail/kernel_name_based_cache_t.hpp>
15+
#include <sycl/detail/kernel_name_based_cache_t.hpp>
1616
#include <detail/spec_constant_impl.hpp>
1717
#include <sycl/detail/cg_types.hpp>
1818
#include <sycl/detail/common.hpp>

0 commit comments

Comments
 (0)