// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved
#pragma once
#include "kernels/hybrid_ep_backend_configs.hpp"
#include "kernels/hybrid_ep_backend.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>
#include <torch/torch.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>
#include <algorithm>

inline std::string type_to_string(TOKEN_DATA_TYPE token_data_type) {
  switch (token_data_type) {
  case TOKEN_DATA_TYPE::UINT16:
    return "uint16_t";
  case TOKEN_DATA_TYPE::UINT8:
    return "uint8_t";
  default:
    return "unknown";
  }
}

union MemHandleInner {
  cudaIpcMemHandle_t cuda_ipc_mem_handle;
  CUmemFabricHandle cu_mem_fabric_handle;
};

struct MemHandle {
  MemHandleInner inner;
  size_t size;
};

// Utility function to get token data type size
inline size_t get_token_data_type_size(TOKEN_DATA_TYPE data_type) {
  switch (data_type) {
  case TOKEN_DATA_TYPE::UINT8:
    return sizeof(uint8_t);
  case TOKEN_DATA_TYPE::UINT16:
    return sizeof(uint16_t);
  default:
    throw std::runtime_error("Invalid token data type: " + std::to_string(static_cast<int>(data_type)));
  }
}

// Round up the allocation size to the fabric granularity.
inline size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
  size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
  if (size == 0) size = granularity;
  return size;
}
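
// Worked example of the rounding above, assuming a 2 MiB fabric granularity
// (the actual granularity is queried from the driver at allocation time):
//   get_size_align_to_granularity(5u << 20, 2u << 20) == 6u << 20  // 5 MiB -> 6 MiB
//   get_size_align_to_granularity(0,        2u << 20) == 2u << 20  // never returns 0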

// Device memory allocator, allocates local device memory. Supports both plain cudaMalloc and the fabric allocator.
inline void device_mem_malloc(void** ptr, size_t size_raw, bool enable_fabric) {
  if (enable_fabric) {
    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));

    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
    prop.location.id = device;

    size_t granularity = 0;
    CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));

    size_t size = get_size_align_to_granularity(size_raw, granularity);

    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemCreate(&handle, size, &prop, 0));

    CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0));
    CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
    CUmemAccessDesc access_desc = {};
    access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access_desc.location.id = device;
    access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &access_desc, 1));
  } else {
    CUDA_CHECK(cudaMalloc(ptr, size_raw));
  }
}
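
// Usage sketch (illustrative only; `use_fabric` and `num_bytes` are hypothetical
// caller-side values, and the fabric path additionally requires a fabric-capable
// driver and hardware):
//
//   void* buf = nullptr;
//   bool use_fabric = true;
//   size_t num_bytes = size_t(256) << 20;      // 256 MiB payload
//   device_mem_malloc(&buf, num_bytes, use_fabric);
//   // ... use buf, optionally export it to peers via get_device_mem_handle ...
//   device_mem_free(buf, use_fabric);          // must match the allocation path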

// Get a shareable memory handle for local device memory so that remote ranks can access it. Supports both IPC handles and fabric handles.
inline void get_device_mem_handle(MemHandle* mem_handle, void* ptr, bool enable_fabric) {
  size_t size = 0;
  CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

  mem_handle->size = size;

  if (enable_fabric) {
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
    CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
  } else {
    CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
  }
}
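
// Sketch of one plausible way to ship the handle to other ranks (names here are
// illustrative; the actual exchange is done by HybridEpBuffer::exchange_ipc_address
// in the implementation, not shown in this header): copy the raw MemHandle bytes
// into a CPU byte tensor so it can be all-gathered through a torch process group.
//
//   MemHandle handle{};
//   get_device_mem_handle(&handle, buf, use_fabric);
//   torch::Tensor handle_bytes =
//       torch::empty({(int64_t)sizeof(MemHandle)}, torch::kUInt8);
//   std::memcpy(handle_bytes.data_ptr(), &handle, sizeof(MemHandle));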

// Open a shareable memory handle received from a remote rank and map it for local device access. Supports both IPC handles and fabric handles.
inline void open_device_mem_handle(void** ptr, MemHandle* mem_handle, bool enable_fabric) {
  if (enable_fabric) {
    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));
    size_t size = mem_handle->size;

    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));

    CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0));
    CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
    CUmemAccessDesc access_desc = {};
    access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access_desc.location.id = device;
    access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &access_desc, 1));
  } else {
    CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
  }
}

// Close and unmap a shareable memory handle imported from a remote rank. Supports both IPC handles and fabric handles.
inline void close_device_mem_handle(void* ptr, bool enable_fabric) {
  if (enable_fabric) {
    size_t size = 0;
    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
  } else {
    CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
  }
}

// Free local device memory allocated by device_mem_malloc.
inline void device_mem_free(void* ptr, bool enable_fabric) {
  if (enable_fabric) {
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

    size_t size = 0;
    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemRelease(handle));
  } else {
    CUDA_CHECK(cudaFree(ptr));
  }
}
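
// End-to-end lifecycle sketch across two ranks (illustrative only; how the
// MemHandle travels between ranks is out of scope here and assumed to happen
// elsewhere, e.g. via a process-group all-gather):
//
//   // Rank A: allocate and export.
//   void* local_buf = nullptr;
//   device_mem_malloc(&local_buf, num_bytes, use_fabric);
//   MemHandle h{};
//   get_device_mem_handle(&h, local_buf, use_fabric);
//   // ... send `h` to rank B ...
//
//   // Rank B: import, use, then release.
//   void* peer_buf = nullptr;
//   open_device_mem_handle(&peer_buf, &h, use_fabric);
//   // ... kernels on rank B can now read/write peer_buf ...
//   close_device_mem_handle(peer_buf, use_fabric);
//
//   // Rank A: free the underlying allocation once no peer still maps it.
//   device_mem_free(local_buf, use_fabric);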

class HybridEpBuffer {
public:
  HybridEpBuffer(HybridEpConfigInstance config, int local_rank, int node_rank,
                 int num_of_ranks_per_node);
  ~HybridEpBuffer();

  // Exchange IPC addresses using C++ distributed communication
  void exchange_ipc_address(pybind11::object process_group);

  void update_num_of_tokens_per_rank(int num_of_tokens_per_rank);

  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
             torch::Tensor>
  metadata_preprocessing(torch::Tensor routing_map, int64_t node_rank,
                         int64_t local_rank);

  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
  dispatch(torch::Tensor hidden, c10::optional<torch::Tensor> probs,
           c10::optional<torch::Tensor> scaling_factor,
           torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map,
           torch::Tensor attn_to_rdma_map, int64_t num_of_tokens_for_experts,
           bool with_probs);

  std::tuple<torch::Tensor, torch::Tensor>
  combine(torch::Tensor hidden, c10::optional<torch::Tensor> probs,
          torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map,
          torch::Tensor attn_to_rdma_map, bool with_probs);

private:
  void allocate_buffer();
  void allocate_buffer_for_preprocessing();
  void allocate_buffer_for_dispatch();
  void allocate_buffer_for_combine();
  void open_handles_from_other_ranks(std::vector<torch::Tensor> dispatch_handles,
                                     std::vector<torch::Tensor> combine_handles);

  HybridEpConfigInstance config;
  int rank;
  int group_size;
  int local_rank;
  int node_rank;
  int num_of_ranks_per_node;

  int64_t max_num_of_tokens_for_experts;

  hybrid_ep::tmp_state_t *preprocessing_tmp;

  struct DispatchBuffers {
    TOKEN_DATA_TYPE data_type;
    void *expert_output_token;
    void **expert_output_token_all_ranks;
    float *expert_output_prob;
    float **expert_output_prob_all_ranks;
    float *expert_output_scaling_factor;
    float **expert_output_scaling_factor_all_ranks;
    void *rdma_inter_node_group_token;
    float *rdma_inter_node_group_prob;
    float *rdma_inter_node_group_scaling_factor;
    uint64_t *rdma_inter_node_group_flags;
    uint32_t *intra_node_write_completion_flags;
    uint64_t *expected_rdma_flag_value;
    uint32_t *expected_intra_node_flag_value;
  } dispatch_buffers;

  torch::Tensor dispatch_memory_handles;

  struct CombineBuffers {
    uint16_t *expert_input_token;
    uint16_t **expert_input_token_all_ranks;
    float *expert_input_prob;
    float **expert_input_prob_all_ranks;
    uint16_t *rdma_intra_node_red_token;
    float *rdma_intra_node_red_prob;
    uint16_t *rdma_inter_node_group_token;
    float *rdma_inter_node_group_prob;
    uint64_t *rdma_inter_node_group_flags;
    uint32_t *intra_node_write_completion_flags;
    uint64_t *expected_rdma_flag_value;
    uint32_t *expected_intra_node_flag_value;
  } combine_buffers;

  torch::Tensor combine_memory_handles;
};
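
// Typical call order (sketch only; `config`, `pg`, and the tensor arguments are
// placeholders for values produced by the caller, and the exact semantics of the
// returned tuples are defined in the implementation, not in this header):
//
//   HybridEpBuffer buffer(config, /*local_rank=*/0, /*node_rank=*/0,
//                         /*num_of_ranks_per_node=*/8);
//   buffer.exchange_ipc_address(pg);  // share buffer handles across ranks
//   auto meta = buffer.metadata_preprocessing(routing_map, node_rank, local_rank);
//   auto dispatched = buffer.dispatch(hidden, probs, scaling_factor,
//                                     sparse_to_dense_map, rdma_to_attn_map,
//                                     attn_to_rdma_map, num_of_tokens_for_experts,
//                                     /*with_probs=*/true);
//   // ... expert computation on the dispatched tokens ...
//   auto combined = buffer.combine(expert_output, expert_probs,
//                                  sparse_to_dense_map, rdma_to_attn_map,
//                                  attn_to_rdma_map, /*with_probs=*/true);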