Skip to content

Commit 409f079

Browse files
authored
[pjrt] ensure client destruction on process exit (#1999)
`torch_xla` doesn't call `PJRT_Client_Destroy` properly, which means that we are not closing the devices properly. Recently, this started causing hangs on `n300` boards on subsequent test runs. This PR introduces a global singleton object which ensures that the client instance is properly destroyed on process shutdown. The singleton serves as a fallback mechanism if the framework doesn't call `PJRT_Client_Destroy` - as is the case with `torch_xla`. Additionally, optimizer sub-meshes are now closed after each compilation; keeping them open between compilations was previously needed to avoid hangs, but now it causes them. The mechanism for persisting the optimizer submesh is left in the code base, so that we can experiment with it if needed. Obviously, we need to dig deeper into these issues to fix them properly. NOTE: this does not solve the case when the process terminates abruptly, e.g. on `SIGSEGV` (segmentation fault). For this, ideally we would want a fix on the `tt-metal` side. Closes #1824
1 parent 21df083 commit 409f079

File tree

3 files changed

+86
-10
lines changed

3 files changed

+86
-10
lines changed

pjrt_implementation/inc/api/client_instance.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,32 @@ namespace module_builder {
3434
class ModuleBuilder;
3535
}
3636

37+
// Singleton class that wraps the PJRT Client Instance.
38+
// Ensures that we properly destroy the client instance on process termination.
39+
//
40+
// NOTE: This is needed since `torch_xla` implementation doesn't call
41+
// `PJRT_Client_Destroy` API properly and `tt-metal` currently cannot recover
42+
// (on n300 boards) if we do not properly close all previously opened devices -
43+
// which is done on client destruction.
44+
//
45+
// NOTE: This serves only as a fallback option if `PJRT_Client_Destroy` is not
46+
// called by the framework.
47+
class GlobalClientInstanceSingleton {
48+
public:
49+
static ClientInstance *getClientInstance();
50+
static PJRT_Error *initClient();
51+
static void destroyClient();
52+
53+
private:
54+
GlobalClientInstanceSingleton(std::unique_ptr<ClientInstance> client_instance)
55+
: m_client_instance(std::move(client_instance)) {}
56+
57+
bool isInitialized() const { return m_client_instance != nullptr; }
58+
59+
static GlobalClientInstanceSingleton &getInstance();
60+
std::unique_ptr<ClientInstance> m_client_instance;
61+
};
62+
3763
// Represents PJRT_Client structure and the functionality around it.
3864
class ClientInstance {
3965

@@ -96,6 +122,9 @@ class ClientInstance {
96122
tt::runtime::Device
97123
getOrCreateOptimizerSubmesh(const std::vector<uint32_t> &target_mesh_shape);
98124

125+
// Closes the currently opened optimizer submesh device, if any.
126+
void closeOptimizerSubmesh();
127+
99128
// Compiles given mlir program.
100129
tt_pjrt_status compileMlirProgram(
101130
const PJRT_Program *mlir_program,

pjrt_implementation/src/api/client_instance.cc

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,37 @@
3737

3838
namespace tt::pjrt {
3939

40+
// Builds and initializes a fresh `ClientInstance`; on success hands ownership
// to the global singleton and returns nullptr. On failure the error from
// `ClientInstance::initialize` is returned and the singleton is left untouched.
PJRT_Error *GlobalClientInstanceSingleton::initClient() {
  auto new_client = std::make_unique<ClientInstance>();

  if (PJRT_Error *init_error = new_client->initialize()) {
    return init_error;
  }

  // Move-assigning into the holder replaces (and thereby destroys) any client
  // that was stored previously.
  getInstance().m_client_instance = std::move(new_client);

  return nullptr;
}
52+
53+
void GlobalClientInstanceSingleton::destroyClient() {
54+
GlobalClientInstanceSingleton &singleton_instance = getInstance();
55+
if (singleton_instance.isInitialized()) {
56+
singleton_instance.m_client_instance.reset();
57+
}
58+
}
59+
60+
// Returns the single process-wide holder object. It is a function-local
// static that starts out without a client instance.
GlobalClientInstanceSingleton &GlobalClientInstanceSingleton::getInstance() {
  static GlobalClientInstanceSingleton s_holder(nullptr);
  return s_holder;
}
65+
66+
// Returns a non-owning pointer to the client instance held by the global
// singleton, or nullptr when none is held.
ClientInstance *GlobalClientInstanceSingleton::getClientInstance() {
  return getInstance().m_client_instance.get();
}
70+
4071
ClientInstance::ClientInstance()
4172
: m_system_descriptor(nullptr),
4273
m_module_builder(std::make_unique<module_builder::ModuleBuilder>()),
@@ -58,9 +89,7 @@ ClientInstance::ClientInstance()
5889
ClientInstance::~ClientInstance() {
5990
DLOG_F(LOG_DEBUG, "ClientInstance::~ClientInstance");
6091

61-
if (m_optimizer_submesh.has_value()) {
62-
tt::runtime::releaseSubMeshDevice(*m_optimizer_submesh);
63-
}
92+
closeOptimizerSubmesh();
6493

6594
if (m_parent_mesh.has_value()) {
6695
tt::runtime::closeMeshDevice(*m_parent_mesh);
@@ -446,6 +475,13 @@ ClientInstance::openMeshDevice(const std::vector<uint32_t> &mesh_shape) {
446475
return tt::runtime::openMeshDevice(options);
447476
}
448477

478+
void ClientInstance::closeOptimizerSubmesh() {
479+
if (m_optimizer_submesh.has_value()) {
480+
tt::runtime::releaseSubMeshDevice(*m_optimizer_submesh);
481+
m_optimizer_submesh.reset();
482+
}
483+
}
484+
449485
tt::runtime::Device ClientInstance::getOrCreateOptimizerSubmesh(
450486
const std::vector<uint32_t> &target_mesh_shape) {
451487

@@ -487,23 +523,29 @@ PJRT_Error *onClientCreate(PJRT_Client_Create_Args *args) {
487523
args->create_options[i].name);
488524
}
489525

490-
std::unique_ptr<ClientInstance> client = std::make_unique<ClientInstance>();
491-
PJRT_Error *error = client->initialize();
526+
PJRT_Error *error = GlobalClientInstanceSingleton::initClient();
527+
492528
if (error) {
529+
DLOG_F(ERROR, "Failed to initialize PJRT client instance");
493530
return error;
494531
}
495532

496-
// Successful return.
497-
args->client = reinterpret_cast<PJRT_Client *>(client.release());
533+
ClientInstance *client_instance =
534+
GlobalClientInstanceSingleton::getClientInstance();
535+
assert(client_instance != nullptr);
536+
args->client = reinterpret_cast<PJRT_Client *>(client_instance);
498537

499538
return nullptr;
500539
}
501540

502541
// Handles the `PJRT_Client_Destroy` API call by destroying the client
// instance owned by `GlobalClientInstanceSingleton`.
//
// The client passed in `args` is expected to be the instance held by the
// global singleton. Always returns nullptr.
PJRT_Error *onClientDestroy(PJRT_Client_Destroy_Args *args) {
  DLOG_F(LOG_DEBUG, "ClientInstance::PJRT_Client_Destroy");

  ClientInstance *client_instance = ClientInstance::unwrap(args->client);
  ClientInstance *global_client_instance =
      GlobalClientInstanceSingleton::getClientInstance();

  // Validate at runtime instead of relying on `assert` alone: with NDEBUG the
  // assert disappears, which would leave both locals unused and let a
  // mismatched handle silently tear down the live global client.
  if (client_instance != global_client_instance) {
    DLOG_F(ERROR, "PJRT_Client_Destroy called with a client that is not the "
                  "global client instance; ignoring the request");
    // Keep the hard failure in debug builds.
    assert(client_instance == global_client_instance);
    return nullptr;
  }

  GlobalClientInstanceSingleton::destroyClient();
  return nullptr;
}
509551

pjrt_implementation/src/api/module_builder/module_builder.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,12 @@ tt_pjrt_status ModuleBuilder::convertFromTTIRToTTNN(
861861
enableVerboseIRPrinting(ttir_to_ttnn_pm);
862862

863863
// Run the pass manager.
864-
if (mlir::failed(ttir_to_ttnn_pm.run(mlir_module.get()))) {
864+
mlir::LogicalResult mlir_result = ttir_to_ttnn_pm.run(mlir_module.get());
865+
866+
// Close the optimizer submesh now that the compilation is complete.
867+
client_instance->closeOptimizerSubmesh();
868+
869+
if (mlir::failed(mlir_result)) {
865870
DLOG_F(ERROR, "Failed to convert from TTIR to TTNN module");
866871
return tt_pjrt_status::kInternal;
867872
}

0 commit comments

Comments
 (0)