Skip to content

[NFC][SYCL] Pass queue_impl by raw ptr in commands.hpp #19004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions sycl/source/detail/cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,14 +725,10 @@ class CGHostTask : public CG {
std::shared_ptr<detail::context_impl> MContext;
std::vector<ArgDesc> MArgs;

CGHostTask(std::shared_ptr<HostTask> HostTask,
std::shared_ptr<detail::queue_impl> Queue,
CGHostTask(std::shared_ptr<HostTask> HostTask, detail::queue_impl *Queue,
std::shared_ptr<detail::context_impl> Context,
std::vector<ArgDesc> Args, CG::StorageInitHelper CGData,
CGType Type, detail::code_location loc = {})
: CG(Type, std::move(CGData), std::move(loc)),
MHostTask(std::move(HostTask)), MQueue(Queue), MContext(Context),
MArgs(std::move(Args)) {}
CGType Type, detail::code_location loc = {});
};

} // namespace detail
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {

return std::make_unique<sycl::detail::CGHostTask>(
sycl::detail::CGHostTask(
std::move(HostTaskSPtr), CommandGroupPtr->MQueue,
std::move(HostTaskSPtr), CommandGroupPtr->MQueue.get(),
CommandGroupPtr->MContext, std::move(NewArgs), std::move(Data),
CommandGroupPtr->getType(), Loc));
}
Expand Down
5 changes: 4 additions & 1 deletion sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,9 +650,12 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
// for in order ones.
void revisitUnenqueuedCommandsState(const EventImplPtr &CompletedHostTask);

static ContextImplPtr getContext(const QueueImplPtr &Queue) {
static ContextImplPtr getContext(queue_impl *Queue) {
return Queue ? Queue->getContextImplPtr() : nullptr;
}
static ContextImplPtr getContext(const QueueImplPtr &Queue) {
return getContext(Queue.get());
}

// Must be called under MMutex protection
void doUnenqueuedCommandCleanup(
Expand Down
118 changes: 62 additions & 56 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,14 @@ static unsigned long long getQueueID(const std::shared_ptr<queue_impl> &Queue) {
}
#endif

static context_impl *getContext(const QueueImplPtr &Queue) {
static context_impl *getContext(queue_impl *Queue) {
if (Queue)
return &Queue->getContextImpl();
return nullptr;
}
static context_impl *getContext(const std::shared_ptr<queue_impl> &Queue) {
return getContext(Queue.get());
}

#ifdef __SYCL_ENABLE_GNU_DEMANGLING
struct DemangleHandle {
Expand Down Expand Up @@ -510,7 +513,7 @@ void Command::waitForPreparedHostEvents() const {
HostEvent->waitInternal();
}

void Command::waitForEvents(QueueImplPtr Queue,
void Command::waitForEvents(queue_impl *Queue,
std::vector<EventImplPtr> &EventImpls,
ur_event_handle_t &Event) {
#ifndef NDEBUG
Expand Down Expand Up @@ -566,12 +569,12 @@ void Command::waitForEvents(QueueImplPtr Queue,
/// references to event_impl class members because Command
/// should not outlive the event connected to it.
Command::Command(
CommandType Type, QueueImplPtr Queue,
CommandType Type, queue_impl *Queue,
ur_exp_command_buffer_handle_t CommandBuffer,
const std::vector<ur_exp_command_buffer_sync_point_t> &SyncPoints)
: MQueue(std::move(Queue)),
MEvent(MQueue ? detail::event_impl::create_device_event(*MQueue)
: detail::event_impl::create_incomplete_host_event()),
: MQueue(Queue ? Queue->shared_from_this() : nullptr),
MEvent(Queue ? detail::event_impl::create_device_event(*Queue)
: detail::event_impl::create_incomplete_host_event()),
MPreparedDepsEvents(MEvent->getPreparedDepsEvents()),
MPreparedHostDepsEvents(MEvent->getPreparedHostDepsEvents()), MType(Type),
MCommandBuffer(CommandBuffer), MSyncPointDeps(SyncPoints) {
Expand Down Expand Up @@ -1034,7 +1037,7 @@ void Command::copySubmissionCodeLocation() {
#endif
}

AllocaCommandBase::AllocaCommandBase(CommandType Type, QueueImplPtr Queue,
AllocaCommandBase::AllocaCommandBase(CommandType Type, queue_impl *Queue,
Requirement Req,
AllocaCommandBase *LinkedAllocaCmd,
bool IsConst)
Expand Down Expand Up @@ -1077,10 +1080,10 @@ bool AllocaCommandBase::supportsPostEnqueueCleanup() const { return false; }

bool AllocaCommandBase::readyForCleanup() const { return false; }

AllocaCommand::AllocaCommand(QueueImplPtr Queue, Requirement Req,
AllocaCommand::AllocaCommand(queue_impl *Queue, Requirement Req,
bool InitFromUserData,
AllocaCommandBase *LinkedAllocaCmd, bool IsConst)
: AllocaCommandBase(CommandType::ALLOCA, std::move(Queue), std::move(Req),
: AllocaCommandBase(CommandType::ALLOCA, Queue, std::move(Req),
LinkedAllocaCmd, IsConst),
MInitFromUserData(InitFromUserData) {
// Node event must be created before the dependent edge is added to this
Expand Down Expand Up @@ -1115,7 +1118,7 @@ ur_result_t AllocaCommand::enqueueImp() {

if (!MQueue) {
// Do not need to make allocation if we have a linked device allocation
Command::waitForEvents(MQueue, EventImpls, UREvent);
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
MEvent->setHandle(UREvent);

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1155,12 +1158,11 @@ void AllocaCommand::printDot(std::ostream &Stream) const {
}
}

AllocaSubBufCommand::AllocaSubBufCommand(QueueImplPtr Queue, Requirement Req,
AllocaSubBufCommand::AllocaSubBufCommand(queue_impl *Queue, Requirement Req,
AllocaCommandBase *ParentAlloca,
std::vector<Command *> &ToEnqueue,
std::vector<Command *> &ToCleanUp)
: AllocaCommandBase(CommandType::ALLOCA_SUB_BUF, std::move(Queue),
std::move(Req),
: AllocaCommandBase(CommandType::ALLOCA_SUB_BUF, Queue, std::move(Req),
/*LinkedAllocaCmd*/ nullptr, /*IsConst*/ false),
MParentAlloca(ParentAlloca) {
// Node event must be created before the dependent edge
Expand Down Expand Up @@ -1241,8 +1243,8 @@ void AllocaSubBufCommand::printDot(std::ostream &Stream) const {
}
}

ReleaseCommand::ReleaseCommand(QueueImplPtr Queue, AllocaCommandBase *AllocaCmd)
: Command(CommandType::RELEASE, std::move(Queue)), MAllocaCmd(AllocaCmd) {
ReleaseCommand::ReleaseCommand(queue_impl *Queue, AllocaCommandBase *AllocaCmd)
: Command(CommandType::RELEASE, Queue), MAllocaCmd(AllocaCmd) {
emitInstrumentationDataProxy();
}

Expand Down Expand Up @@ -1295,9 +1297,9 @@ ur_result_t ReleaseCommand::enqueueImp() {
}

if (NeedUnmap) {
const QueueImplPtr &Queue = CurAllocaIsHost
? MAllocaCmd->MLinkedAllocaCmd->getQueue()
: MAllocaCmd->getQueue();
queue_impl *Queue = CurAllocaIsHost
? MAllocaCmd->MLinkedAllocaCmd->getQueue()
: MAllocaCmd->getQueue();

assert(Queue);

Expand Down Expand Up @@ -1328,7 +1330,7 @@ ur_result_t ReleaseCommand::enqueueImp() {
}
ur_event_handle_t UREvent = nullptr;
if (SkipRelease)
Command::waitForEvents(MQueue, EventImpls, UREvent);
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
else {
if (auto Result = callMemOpHelper(
MemoryManager::release, getContext(MQueue),
Expand Down Expand Up @@ -1366,11 +1368,10 @@ bool ReleaseCommand::supportsPostEnqueueCleanup() const { return false; }
bool ReleaseCommand::readyForCleanup() const { return false; }

MapMemObject::MapMemObject(AllocaCommandBase *SrcAllocaCmd, Requirement Req,
void **DstPtr, QueueImplPtr Queue,
void **DstPtr, queue_impl *Queue,
access::mode MapMode)
: Command(CommandType::MAP_MEM_OBJ, std::move(Queue)),
MSrcAllocaCmd(SrcAllocaCmd), MSrcReq(std::move(Req)), MDstPtr(DstPtr),
MMapMode(MapMode) {
: Command(CommandType::MAP_MEM_OBJ, Queue), MSrcAllocaCmd(SrcAllocaCmd),
MSrcReq(std::move(Req)), MDstPtr(DstPtr), MMapMode(MapMode) {
emitInstrumentationDataProxy();
}

Expand Down Expand Up @@ -1430,9 +1431,9 @@ void MapMemObject::printDot(std::ostream &Stream) const {
}

UnMapMemObject::UnMapMemObject(AllocaCommandBase *DstAllocaCmd, Requirement Req,
void **SrcPtr, QueueImplPtr Queue)
: Command(CommandType::UNMAP_MEM_OBJ, std::move(Queue)),
MDstAllocaCmd(DstAllocaCmd), MDstReq(std::move(Req)), MSrcPtr(SrcPtr) {
void **SrcPtr, queue_impl *Queue)
: Command(CommandType::UNMAP_MEM_OBJ, Queue), MDstAllocaCmd(DstAllocaCmd),
MDstReq(std::move(Req)), MSrcPtr(SrcPtr) {
emitInstrumentationDataProxy();
}

Expand Down Expand Up @@ -1516,11 +1517,11 @@ MemCpyCommand::MemCpyCommand(Requirement SrcReq,
AllocaCommandBase *SrcAllocaCmd,
Requirement DstReq,
AllocaCommandBase *DstAllocaCmd,
QueueImplPtr SrcQueue, QueueImplPtr DstQueue)
: Command(CommandType::COPY_MEMORY, std::move(DstQueue)),
MSrcQueue(SrcQueue), MSrcReq(std::move(SrcReq)),
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)),
MDstAllocaCmd(DstAllocaCmd) {
queue_impl *SrcQueue, queue_impl *DstQueue)
: Command(CommandType::COPY_MEMORY, DstQueue),
MSrcQueue(SrcQueue ? SrcQueue->shared_from_this() : nullptr),
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
MDstReq(std::move(DstReq)), MDstAllocaCmd(DstAllocaCmd) {
if (MSrcQueue) {
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
}
Expand Down Expand Up @@ -1652,7 +1653,7 @@ ur_result_t UpdateHostRequirementCommand::enqueueImp() {
waitForPreparedHostEvents();
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
ur_event_handle_t UREvent = nullptr;
Command::waitForEvents(MQueue, EventImpls, UREvent);
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
MEvent->setHandle(UREvent);

assert(MSrcAllocaCmd && "Expected valid alloca command");
Expand Down Expand Up @@ -1689,11 +1690,11 @@ void UpdateHostRequirementCommand::printDot(std::ostream &Stream) const {
MemCpyCommandHost::MemCpyCommandHost(Requirement SrcReq,
AllocaCommandBase *SrcAllocaCmd,
Requirement DstReq, void **DstPtr,
QueueImplPtr SrcQueue,
QueueImplPtr DstQueue)
: Command(CommandType::COPY_MEMORY, std::move(DstQueue)),
MSrcQueue(SrcQueue), MSrcReq(std::move(SrcReq)),
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
queue_impl *SrcQueue, queue_impl *DstQueue)
: Command(CommandType::COPY_MEMORY, DstQueue),
MSrcQueue(SrcQueue ? SrcQueue->shared_from_this() : nullptr),
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
if (MSrcQueue) {
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
}
Expand Down Expand Up @@ -1735,7 +1736,7 @@ ContextImplPtr MemCpyCommandHost::getWorkerContext() const {
}

ur_result_t MemCpyCommandHost::enqueueImp() {
const QueueImplPtr &Queue = MWorkerQueue;
queue_impl *Queue = MWorkerQueue.get();
waitForPreparedHostEvents();
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
std::vector<ur_event_handle_t> RawEvents = getUrEvents(EventImpls);
Expand Down Expand Up @@ -1774,7 +1775,7 @@ EmptyCommand::EmptyCommand() : Command(CommandType::EMPTY_TASK, nullptr) {
ur_result_t EmptyCommand::enqueueImp() {
waitForPreparedHostEvents();
ur_event_handle_t UREvent = nullptr;
waitForEvents(MQueue, MPreparedDepsEvents, UREvent);
waitForEvents(MQueue.get(), MPreparedDepsEvents, UREvent);
MEvent->setHandle(UREvent);
return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -1858,9 +1859,9 @@ void MemCpyCommandHost::printDot(std::ostream &Stream) const {
}

UpdateHostRequirementCommand::UpdateHostRequirementCommand(
QueueImplPtr Queue, Requirement Req, AllocaCommandBase *SrcAllocaCmd,
queue_impl *Queue, Requirement Req, AllocaCommandBase *SrcAllocaCmd,
void **DstPtr)
: Command(CommandType::UPDATE_REQUIREMENT, std::move(Queue)),
: Command(CommandType::UPDATE_REQUIREMENT, Queue),
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(Req)), MDstPtr(DstPtr) {

emitInstrumentationDataProxy();
Expand Down Expand Up @@ -1956,11 +1957,10 @@ static std::string_view cgTypeToString(detail::CGType Type) {
}

ExecCGCommand::ExecCGCommand(
std::unique_ptr<detail::CG> CommandGroup, QueueImplPtr Queue,
std::unique_ptr<detail::CG> CommandGroup, queue_impl *Queue,
bool EventNeeded, ur_exp_command_buffer_handle_t CommandBuffer,
const std::vector<ur_exp_command_buffer_sync_point_t> &Dependencies)
: Command(CommandType::RUN_CG, std::move(Queue), CommandBuffer,
Dependencies),
: Command(CommandType::RUN_CG, Queue, CommandBuffer, Dependencies),
MEventNeeded(EventNeeded), MCommandGroup(std::move(CommandGroup)) {
if (MCommandGroup->getType() == detail::CGType::CodeplayHostTask) {
MEvent->setSubmittedQueue(
Expand Down Expand Up @@ -2777,20 +2777,18 @@ void enqueueImpKernel(
}
}

ur_result_t enqueueReadWriteHostPipe(const QueueImplPtr &Queue,
ur_result_t enqueueReadWriteHostPipe(queue_impl &Queue,
const std::string &PipeName, bool blocking,
void *ptr, size_t size,
std::vector<ur_event_handle_t> &RawEvents,
detail::event_impl *OutEventImpl,
bool read) {
assert(Queue &&
"ReadWrite host pipe submissions should have an associated queue");
detail::HostPipeMapEntry *hostPipeEntry =
ProgramManager::getInstance().getHostPipeEntry(PipeName);

ur_program_handle_t Program = nullptr;
device Device = Queue->get_device();
ContextImplPtr ContextImpl = Queue->getContextImplPtr();
device Device = Queue.get_device();
ContextImplPtr ContextImpl = Queue.getContextImplPtr();
std::optional<ur_program_handle_t> CachedProgram =
ContextImpl->getProgramForHostPipe(Device, hostPipeEntry);
if (CachedProgram)
Expand All @@ -2799,17 +2797,16 @@ ur_result_t enqueueReadWriteHostPipe(const QueueImplPtr &Queue,
// If there was no cached program, build one.
device_image_plain devImgPlain =
ProgramManager::getInstance().getDeviceImageFromBinaryImage(
hostPipeEntry->getDevBinImage(), Queue->get_context(),
Queue->get_device());
hostPipeEntry->getDevBinImage(), Queue.get_context(), Device);
device_image_plain BuiltImage = ProgramManager::getInstance().build(
std::move(devImgPlain), {std::move(Device)}, {});
Program = getSyclObjImpl(BuiltImage)->get_ur_program_ref();
}
assert(Program && "Program for this hostpipe is not compiled.");

const AdapterPtr &Adapter = Queue->getAdapter();
const AdapterPtr &Adapter = Queue.getAdapter();

ur_queue_handle_t ur_q = Queue->getHandleRef();
ur_queue_handle_t ur_q = Queue.getHandleRef();
ur_result_t Error;

ur_event_handle_t UREvent = nullptr;
Expand Down Expand Up @@ -3667,7 +3664,7 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
if (!EventImpl) {
EventImpl = MEvent.get();
}
return enqueueReadWriteHostPipe(MQueue, pipeName, blocking, hostPtr,
return enqueueReadWriteHostPipe(*MQueue, pipeName, blocking, hostPtr,
typeSize, RawEvents, EventImpl, read);
}
case CGType::ExecCommandBuffer: {
Expand Down Expand Up @@ -3802,7 +3799,7 @@ bool ExecCGCommand::readyForCleanup() const {
}

UpdateCommandBufferCommand::UpdateCommandBufferCommand(
QueueImplPtr Queue,
queue_impl *Queue,
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
Nodes)
Expand All @@ -3813,7 +3810,7 @@ ur_result_t UpdateCommandBufferCommand::enqueueImp() {
waitForPreparedHostEvents();
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
ur_event_handle_t UREvent = nullptr;
Command::waitForEvents(MQueue, EventImpls, UREvent);
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
MEvent->setHandle(UREvent);

auto CheckAndFindAlloca = [](Requirement *Req, const DepDesc &Dep) {
Expand Down Expand Up @@ -3885,6 +3882,15 @@ void UpdateCommandBufferCommand::printDot(std::ostream &Stream) const {
void UpdateCommandBufferCommand::emitInstrumentationData() {}
bool UpdateCommandBufferCommand::producesPiEvent() const { return false; }

CGHostTask::CGHostTask(std::shared_ptr<HostTask> HostTask,
detail::queue_impl *Queue,
std::shared_ptr<detail::context_impl> Context,
std::vector<ArgDesc> Args, CG::StorageInitHelper CGData,
CGType Type, detail::code_location loc)
: CG(Type, std::move(CGData), std::move(loc)),
MHostTask(std::move(HostTask)),
MQueue(Queue ? Queue->shared_from_this() : nullptr), MContext(Context),
MArgs(std::move(Args)) {}
} // namespace detail
} // namespace _V1
} // namespace sycl
Loading