Skip to content

[NFC][SYCL] Use raw context_impl & in event_impl::[set|get]Context #19007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions sycl/source/detail/event_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ void event_impl::initContextIfNeeded() {
return;

const device SyclDevice;
this->setContextImpl(
detail::queue_impl::getDefaultOrNew(*detail::getSyclObjImpl(SyclDevice)));
MIsHostEvent = false;
MContext =
detail::queue_impl::getDefaultOrNew(*detail::getSyclObjImpl(SyclDevice));
assert(MContext);
}

event_impl::~event_impl() {
Expand Down Expand Up @@ -140,9 +142,10 @@ void event_impl::setHandle(const ur_event_handle_t &UREvent) {
MEvent.store(UREvent);
}

const ContextImplPtr &event_impl::getContextImpl() {
context_impl &event_impl::getContextImpl() {
initContextIfNeeded();
return MContext;
assert(MContext && "Trying to get context from a host event!");
return *MContext;
}

const AdapterPtr &event_impl::getAdapter() {
Expand All @@ -152,9 +155,9 @@ const AdapterPtr &event_impl::getAdapter() {

void event_impl::setStateIncomplete() { MState = HES_NotComplete; }

void event_impl::setContextImpl(const ContextImplPtr &Context) {
MIsHostEvent = Context == nullptr;
MContext = Context;
void event_impl::setContextImpl(context_impl &Context) {
MIsHostEvent = false;
MContext = Context.shared_from_this();
}

event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
Expand All @@ -178,7 +181,7 @@ event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
event_impl::event_impl(queue_impl &Queue, private_tag)
: MQueue{Queue.weak_from_this()},
MIsProfilingEnabled{Queue.MIsProfilingEnabled} {
this->setContextImpl(Queue.getContextImplPtr());
this->setContextImpl(Queue.getContextImpl());
MState.store(HES_Complete);
}

Expand Down
10 changes: 3 additions & 7 deletions sycl/source/detail/event_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,21 +173,17 @@ class event_impl : public std::enable_shared_from_this<event_impl> {
void setHandle(const ur_event_handle_t &UREvent);

/// Returns context that is associated with this event.
///
/// \return a shared pointer to a valid context_impl.
const ContextImplPtr &getContextImpl();
context_impl &getContextImpl();

/// \return the Adapter associated with the context of this event.
/// Should be called when this is not a Host Event.
const AdapterPtr &getAdapter();

/// Associate event with the context.
///
/// Provided UrContext inside ContextImplPtr must be associated
/// Provided UrContext inside Context must be associated
/// with the UrEvent object stored in this class
///
/// @param Context is a shared pointer to an instance of valid context_impl.
void setContextImpl(const ContextImplPtr &Context);
void setContextImpl(context_impl &Context);

/// Clear the event state
void setStateIncomplete();
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,

auto CreateNewEvent([&]() {
auto NewEvent = sycl::detail::event_impl::create_device_event(Queue);
NewEvent->setContextImpl(Queue.getContextImplPtr());
NewEvent->setContextImpl(Queue.getContextImpl());
NewEvent->setStateIncomplete();
return NewEvent;
});
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/queue_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ queue_impl::get_backend_info<info::device::backend_version>() const {
static event prepareSYCLEventAssociatedWithQueue(
const std::shared_ptr<detail::queue_impl> &QueueImpl) {
auto EventImpl = detail::event_impl::create_device_event(*QueueImpl);
EventImpl->setContextImpl(detail::getSyclObjImpl(QueueImpl->get_context()));
EventImpl->setContextImpl(QueueImpl->getContextImpl());
EventImpl->setStateIncomplete();
return detail::createSyclObjFromImpl<event>(EventImpl);
}
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ __SYCL_EXPORT void
addCounterInit(handler &CGH, std::shared_ptr<sycl::detail::queue_impl> &Queue,
std::shared_ptr<int> &Counter) {
auto EventImpl = detail::event_impl::create_device_event(*Queue);
EventImpl->setContextImpl(detail::getSyclObjImpl(Queue->get_context()));
EventImpl->setContextImpl(Queue->getContextImpl());
EventImpl->setStateIncomplete();
ur_event_handle_t UREvent = nullptr;
MemoryManager::fill_usm(Counter.get(), *Queue, sizeof(int), {0}, {},
Expand Down
18 changes: 8 additions & 10 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,10 +533,8 @@ void Command::waitForEvents(queue_impl *Queue,
RequiredEventsPerContext;

for (const EventImplPtr &Event : EventImpls) {
ContextImplPtr Context = Event->getContextImpl();
assert(Context.get() &&
"Only non-host events are expected to be waited for here");
RequiredEventsPerContext[Context.get()].push_back(Event);
context_impl &Context = Event->getContextImpl();
RequiredEventsPerContext[&Context].push_back(Event);
}

for (auto &CtxWithEvents : RequiredEventsPerContext) {
Expand Down Expand Up @@ -576,7 +574,7 @@ Command::Command(
MEvent->setSubmittedQueue(MWorkerQueue);
MEvent->setCommand(this);
if (MQueue)
MEvent->setContextImpl(MQueue->getContextImplPtr());
MEvent->setContextImpl(MQueue->getContextImpl());
MEvent->setStateIncomplete();
MEnqueueStatus = EnqueueResultT::SyclEnqueueReady;

Expand Down Expand Up @@ -781,9 +779,9 @@ Command *Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep,

Command *ConnectionCmd = nullptr;

ContextImplPtr DepEventContext = DepEvent->getContextImpl();
context_impl &DepEventContext = DepEvent->getContextImpl();
// If contexts don't match we'll connect them using host task
if (DepEventContext != WorkerContext && WorkerContext) {
if (&DepEventContext != WorkerContext.get() && WorkerContext) {
Scheduler::GraphBuilder &GB = Scheduler::getInstance().MGraphBuilder;
ConnectionCmd = GB.connectDepEvent(this, DepEvent, Dep, ToCleanUp);
} else
Expand Down Expand Up @@ -1298,7 +1296,7 @@ ur_result_t ReleaseCommand::enqueueImp() {

std::shared_ptr<event_impl> UnmapEventImpl =
event_impl::create_device_event(*Queue);
UnmapEventImpl->setContextImpl(Queue->getContextImplPtr());
UnmapEventImpl->setContextImpl(Queue->getContextImpl());
UnmapEventImpl->setStateIncomplete();
ur_event_handle_t UREvent = nullptr;

Expand Down Expand Up @@ -1516,7 +1514,7 @@ MemCpyCommand::MemCpyCommand(Requirement SrcReq,
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
MDstReq(std::move(DstReq)), MDstAllocaCmd(DstAllocaCmd) {
if (MSrcQueue) {
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
MEvent->setContextImpl(MSrcQueue->getContextImpl());
}

MWorkerQueue = !MQueue ? MSrcQueue : MQueue;
Expand Down Expand Up @@ -1689,7 +1687,7 @@ MemCpyCommandHost::MemCpyCommandHost(Requirement SrcReq,
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
if (MSrcQueue) {
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
MEvent->setContextImpl(MSrcQueue->getContextImpl());
}

MWorkerQueue = !MQueue ? MSrcQueue : MQueue;
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/scheduler/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1221,7 +1221,7 @@ void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) {
Command *Scheduler::GraphBuilder::connectDepEvent(
Command *const Cmd, const EventImplPtr &DepEvent, const DepDesc &Dep,
std::vector<Command *> &ToCleanUp) {
assert(Cmd->getWorkerContext() != DepEvent->getContextImpl());
assert(Cmd->getWorkerContext().get() != &DepEvent->getContextImpl());

// construct Host Task type command manually and make it depend on DepEvent
ExecCGCommand *ConnectCmd = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/scheduler/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ bool Scheduler::CheckEventReadiness(context_impl &Context,
return SyclEventImplPtr->isCompleted();
}
// Cross-context dependencies can't be passed to the backend directly.
if (SyclEventImplPtr->getContextImpl().get() != &Context)
if (&SyclEventImplPtr->getContextImpl() != &Context)
return false;

// A nullptr here means that the commmand does not produce a UR event or it
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ event handler::finalize() {
detail::queue_impl &Queue = impl->get_queue();
LastEventImpl->setQueue(Queue);
LastEventImpl->setWorkerQueue(Queue.weak_from_this());
LastEventImpl->setContextImpl(impl->get_context().shared_from_this());
LastEventImpl->setContextImpl(impl->get_context());
LastEventImpl->setStateIncomplete();
LastEventImpl->setSubmissionTime();

Expand Down
4 changes: 2 additions & 2 deletions sycl/unittests/scheduler/QueueFlushing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ TEST_F(SchedulerTest, QueueFlushing) {
access::mode::read_write};
std::shared_ptr<detail::event_impl> DepEvent =
detail::event_impl::create_device_event(QueueImplB);
DepEvent->setContextImpl(QueueImplB.getContextImplPtr());
DepEvent->setContextImpl(QueueImplB.getContextImpl());

ur_event_handle_t UREvent = mock::createDummyHandle<ur_event_handle_t>();

Expand All @@ -170,7 +170,7 @@ TEST_F(SchedulerTest, QueueFlushing) {
queue TempQueue{Ctx, default_selector_v};
detail::queue_impl &TempQueueImpl = *detail::getSyclObjImpl(TempQueue);
DepEvent = detail::event_impl::create_device_event(TempQueueImpl);
DepEvent->setContextImpl(TempQueueImpl.getContextImplPtr());
DepEvent->setContextImpl(TempQueueImpl.getContextImpl());

ur_event_handle_t UREvent = mock::createDummyHandle<ur_event_handle_t>();

Expand Down