diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 6f5ae85ca9a6..d6fb2bd047a0 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -3369,6 +3369,24 @@ Example::
   auto processGroupXCCL =
       intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupXCCL>(
           module, "ProcessGroupXCCL", backend)
+          .def(
+              py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
+                          int rank,
+                          int size,
+                          c10::intrusive_ptr<::c10d::ProcessGroupXCCL::Options>
+                              options) {
+                // gil_scoped_release is not safe as a call_guard in init.
+                // https://github.com/pybind/pybind11/issues/5473
+                py::gil_scoped_release nogil{};
+
+                return c10::make_intrusive<::c10d::ProcessGroupXCCL>(
+                    store, rank, size, std::move(options));
+              }),
+              py::arg("store"),
+              py::arg("rank"),
+              py::arg("size"),
+              py::arg("options"),
+              R"(Create a new ProcessGroupXCCL instance.)")
           .def(
               py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
                           int rank,
@@ -3377,12 +3395,50 @@ Example::
                 // https://github.com/pybind/pybind11/issues/5473
                 py::gil_scoped_release nogil{};
 
+                auto options = ::c10d::ProcessGroupXCCL::Options::create();
+                options->is_high_priority_stream = false;
                 return c10::make_intrusive<::c10d::ProcessGroupXCCL>(
-                    store, rank, size);
+                    store, rank, size, options);
               }),
               py::arg("store"),
               py::arg("rank"),
-              py::arg("size"));
+              py::arg("size"),
+              R"(Create a new ProcessGroupXCCL instance.)")
+          .def(
+              "comm_split_count",
+              &::c10d::ProcessGroupXCCL::getCommSplitCounter)
+          .def_property_readonly(
+              "options",
+              &::c10d::ProcessGroupXCCL::getOptions,
+              R"(Return the options used to create this ProcessGroupXCCL instance.)")
+          .def_property(
+              "bound_device_id",
+              &::c10d::ProcessGroupXCCL::getBoundDeviceId,
+              &::c10d::ProcessGroupXCCL::setBoundDeviceId,
+              R"(Return the bound device id.)")
+          .def(
+              "perform_nocolor_split",
+              &::c10d::ProcessGroupXCCL::performNocolorSplit)
+          .def(
+              "_is_initialized",
+              &::c10d::ProcessGroupXCCL::isInitialized,
+              py::call_guard<py::gil_scoped_release>());
+  intrusive_ptr_class_<::c10d::ProcessGroupXCCL::Options>(
+      processGroupXCCL, "Options", backendOptions)
+      .def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
+      .def_readwrite("config", &::c10d::ProcessGroupXCCL::Options::config)
+      .def_readwrite(
+          "is_high_priority_stream",
+          &::c10d::ProcessGroupXCCL::Options::is_high_priority_stream)
+      .def_readwrite(
+          "split_from", &::c10d::ProcessGroupXCCL::Options::split_from)
+      .def_readwrite(
+          "split_color", &::c10d::ProcessGroupXCCL::Options::split_color)
+      .def_readwrite(
+          "global_ranks_in_group",
+          &::c10d::ProcessGroupXCCL::Options::global_ranks_in_group)
+      .def_readwrite(
+          "group_name", &::c10d::ProcessGroupXCCL::Options::group_name);
 #endif
 
 #ifdef USE_C10D_UCC
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 866658515a74..bfbf1ee73436 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -2033,8 +2033,18 @@ def _new_process_group_helper(
         elif backend_str == Backend.XCCL:
             if not is_xccl_available():
                 raise RuntimeError("Distributed package doesn't have XCCL built in")
+            if backend_options is not None:
+                assert isinstance(backend_options, ProcessGroupXCCL.Options), (
+                    "Expected backend_options argument to be of type ProcessGroupXCCL.Options"
+                )
+            else:
+                # default backend_options for XCCL
+                backend_options = ProcessGroupXCCL.Options()
+                backend_options.is_high_priority_stream = False
+            backend_options.global_ranks_in_group = global_ranks_in_group
+            backend_options.group_name = group_name
             backend_class = ProcessGroupXCCL(
-                backend_prefix_store, group_rank, group_size
+                backend_prefix_store, group_rank, group_size, backend_options
             )
             backend_type = ProcessGroup.BackendType.XCCL
         else:
@@ -5042,7 +5052,7 @@ def split_group(
         )
         parent_group_rank = parent_global_to_group_ranks[global_rank]
 
-    parent_backend = parent_pg._get_backend(torch.device("cuda"))
+    parent_backend = parent_pg._get_backend(device_id)
 
     # if the parent backend does not support splitting, raise error
     # currently this API only support NCCL backend
@@ -5123,6 +5133,15 @@ def split_group(
         backend_class = ProcessGroupNCCL(
             prefix_store, group_rank, len(my_group), pg_options
         )
+    elif parent_backend_str == Backend.XCCL:
+        backend_type = ProcessGroup.BackendType.XCCL
+        if not isinstance(pg_options, ProcessGroupXCCL.Options):
+            raise RuntimeError(
+                "Expected pg_options argument to be of type ProcessGroupXCCL.Options"
+            )
+        backend_class = ProcessGroupXCCL(
+            prefix_store, group_rank, len(my_group), pg_options
+        )
     else:
         assert parent_backend_str.upper() in Backend._plugins, (
             f"Unknown c10d backend type {parent_backend_str.upper()}"
@@ -5143,7 +5162,9 @@ def split_group(
     pg._set_default_backend(backend_type)
     backend_class._set_sequence_number_for_group()
 
-    pg._register_backend(torch.device("cuda"), backend_type, backend_class)
+    pg._register_backend(
+        torch.accelerator.current_accelerator(), backend_type, backend_class
+    )
 
     # set group_name and group_desc to backend
     assert group_name is not None
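
Usage sketch (illustrative, not part of the diff): with these bindings in place, the new options path can be exercised from Python roughly as below. Assumptions not established by this diff: an XCCL-enabled build (`torch.distributed.is_xccl_available()` is true), one XPU per rank on a single node, env:// rendezvous variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) set by the launcher, and `ProcessGroupXCCL` being reachable as `torch.distributed.ProcessGroupXCCL`, mirroring the NCCL convention.

    import os

    import torch
    import torch.distributed as dist

    # Hypothetical single-node launch; RANK comes from the launcher.
    rank = int(os.environ["RANK"])

    # Options class exposed by the intrusive_ptr_class_ binding above.
    opts = dist.ProcessGroupXCCL.Options()
    opts.is_high_priority_stream = True  # def_readwrite field from the diff

    dist.init_process_group(
        backend="xccl",
        pg_options=opts,  # forwarded to ProcessGroupXCCL by _new_process_group_helper
        device_id=torch.device("xpu", rank),  # bound device, later used by split_group
    )

    # The new read-only `options` property reflects what the backend was built with.
    backend = dist.group.WORLD._get_backend(torch.device("xpu"))
    assert backend.options.is_high_priority_stream

    # split_group now reaches the new XCCL branch; pg_options must be a
    # ProcessGroupXCCL.Options instance or the added isinstance check raises.
    sub_pg = dist.split_group(
        split_ranks=[list(range(dist.get_world_size()))],
        pg_options=dist.ProcessGroupXCCL.Options(),
    )

    dist.destroy_process_group()

If `pg_options` is omitted at init time, the Python helper (and the C++ fallback constructor) build a default Options with is_high_priority_stream=False, so existing callers are unaffected.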