11
11
#include < tvm/runtime/device_api.h>
12
12
#include < tvm/runtime/disco/builtin.h>
13
13
#include < tvm/runtime/disco/disco_worker.h>
14
- #include < tvm/runtime/vm/ndarray_cache_support .h>
14
+ #include < tvm/runtime/vm/tensor_cache_support .h>
15
15
16
16
#include < chrono>
17
17
#include < filesystem>
@@ -31,7 +31,7 @@ namespace llm {
31
31
namespace multi_gpu {
32
32
33
33
using tvm::Device;
34
- using tvm::runtime::vm::NDArrayCacheMetadata ;
34
+ using tvm::runtime::vm::TensorCacheMetadata ;
35
35
using namespace tvm ::runtime;
36
36
using tvm::ffi::Array;
37
37
using tvm::ffi::Function;
@@ -76,11 +76,11 @@ class PreprocessorPool {
76
76
}
77
77
}
78
78
79
- NDArray Apply (NDArray param, const ModelMetadata::Param& param_info) const {
79
+ Tensor Apply (Tensor param, const ModelMetadata::Param& param_info) const {
80
80
for (const ModelMetadata::Param::Preproc& preproc : param_info.preprocs ) {
81
81
const std::string& func_name = preproc.func_name ;
82
- NDArray param_in = param;
83
- param = NDArray ::Empty (preproc.out_shape , preproc.out_dtype , param->device );
82
+ Tensor param_in = param;
83
+ param = Tensor ::Empty (preproc.out_shape , preproc.out_dtype , param->device );
84
84
ICHECK (preproc_funcs.count (func_name));
85
85
DLTensor dl_param_in = *param_in.operator ->();
86
86
DLTensor dl_param = *param.operator ->();
@@ -94,19 +94,19 @@ class PreprocessorPool {
94
94
};
95
95
96
96
// Pairs a parameter's metadata record with the cache file record that
// contains it, so the presharded loader can locate the parameter's raw
// bytes inside the tensor cache without re-scanning all file records.
struct ParamInfo {
  // Cache file that holds this parameter (non-owning; points into the
  // TensorCacheMetadata loaded for the model).
  const TensorCacheMetadata::FileRecord* file;
  // Record describing the parameter within `file` (non-owning).
  const TensorCacheMetadata::FileRecord::ParamRecord* param;
};
100
100
101
- NDArray RecvFromGlobalWorker0 (Device device, const ModelMetadata::Param& param_info) {
101
+ Tensor RecvFromGlobalWorker0 (Device device, const ModelMetadata::Param& param_info) {
102
102
Shape shape = param_info.preprocs .empty () ? param_info.shape : param_info.preprocs [0 ].in_shape ;
103
- NDArray result = NDArray ::Empty (shape, param_info.dtype , device);
103
+ Tensor result = Tensor ::Empty (shape, param_info.dtype , device);
104
104
RecvFromWorker0 (result);
105
105
return result;
106
106
}
107
107
108
- NDArray BroadcastOrShardAndScatter (NDArray param, const ModelMetadata::Param& param_info,
109
- int num_shards, const PreprocessorPool& preprocs) {
108
+ Tensor BroadcastOrShardAndScatter (Tensor param, const ModelMetadata::Param& param_info,
109
+ int num_shards, const PreprocessorPool& preprocs) {
110
110
bool needs_sharding = !param_info.preprocs .empty ();
111
111
if (!needs_sharding) {
112
112
BroadcastFromWorker0 (param, /* in_group=*/ true , param);
@@ -119,22 +119,22 @@ NDArray BroadcastOrShardAndScatter(NDArray param, const ModelMetadata::Param& pa
119
119
<< " ValueError: The first dimension of the output shape must be equal to the "
120
120
<< " number of shards, but got: " << shape << " and num_shards = " << num_shards;
121
121
param = preprocs.Apply (param, param_info);
122
- NDArray result = NDArray ::Empty (Shape (shape.begin () + 1 , shape.end ()), dtype, device);
122
+ Tensor result = Tensor ::Empty (Shape (shape.begin () + 1 , shape.end ()), dtype, device);
123
123
ScatterFromWorker0 (param, /* in_group=*/ true , result);
124
124
return result;
125
125
}
126
126
127
- NDArray ReceiveBroadcastedOrSharded (Device device, const ModelMetadata::Param& param_info,
128
- int num_shards) {
127
+ Tensor ReceiveBroadcastedOrSharded (Device device, const ModelMetadata::Param& param_info,
128
+ int num_shards) {
129
129
bool needs_sharding = !param_info.preprocs .empty ();
130
- NDArray result;
130
+ Tensor result;
131
131
if (needs_sharding) {
132
132
Shape shape = param_info.preprocs .back ().out_shape ;
133
133
DataType dtype = param_info.preprocs .back ().out_dtype ;
134
- result = NDArray ::Empty (Shape (shape.begin () + 1 , shape.end ()), dtype, device);
134
+ result = Tensor ::Empty (Shape (shape.begin () + 1 , shape.end ()), dtype, device);
135
135
ScatterFromWorker0 (std::nullopt, /* in_group=*/ true , result);
136
136
} else {
137
- result = NDArray ::Empty (param_info.shape , param_info.dtype , device);
137
+ result = Tensor ::Empty (param_info.shape , param_info.dtype , device);
138
138
BroadcastFromWorker0 (result, /* in_group=*/ true , result);
139
139
}
140
140
return result;
@@ -147,8 +147,8 @@ std::string FormatDuration(DurationType duration) {
147
147
return os.str ();
148
148
}
149
149
150
- Array<Optional<NDArray >> LoadMultiGPU (const std::string& model_path, Module vm_module,
151
- const std::string& model_config_str) {
150
+ Array<Optional<Tensor >> LoadMultiGPU (const std::string& model_path, Module vm_module,
151
+ const std::string& model_config_str) {
152
152
DiscoWorker* worker = DiscoWorker::ThreadLocal ();
153
153
Device device = worker->default_device ;
154
154
int worker_id = worker->worker_id ;
@@ -157,7 +157,7 @@ Array<Optional<NDArray>> LoadMultiGPU(const std::string& model_path, Module vm_m
157
157
int group_id = worker_id / group_size;
158
158
LOG (INFO) << " [Worker #" << worker_id << " ] Loading model to device: " << device;
159
159
// Step 0. Initialize metadata and paths
160
- NDArrayCacheMetadata ndarray_cache_metadata = NDArrayCacheMetadata ::Load (model_path);
160
+ TensorCacheMetadata tensor_cache_metadata = TensorCacheMetadata ::Load (model_path);
161
161
picojson::value model_config;
162
162
picojson::parse (model_config, model_config_str);
163
163
ModelMetadata model_metadata =
@@ -175,14 +175,14 @@ Array<Optional<NDArray>> LoadMultiGPU(const std::string& model_path, Module vm_m
175
175
param_name2info[param.name ] = param;
176
176
}
177
177
// Step 2. Load, preprocess and shard all the parameters
178
- std::unordered_map<std::string, NDArray > sharded_params;
178
+ std::unordered_map<std::string, Tensor > sharded_params;
179
179
if (worker_id == 0 ) {
180
180
DurationType time_loading (0 );
181
181
DurationType time_preproc (0 );
182
182
ProgressBar progress_bar (model_metadata.params .size ());
183
183
LOG (INFO) << " Loading parameters..." ;
184
- for (const NDArrayCacheMetadata ::FileRecord& record : ndarray_cache_metadata .records ) {
185
- Array<NDArray > loaded_params;
184
+ for (const TensorCacheMetadata ::FileRecord& record : tensor_cache_metadata .records ) {
185
+ Array<Tensor > loaded_params;
186
186
{
187
187
RangeTimer _ (&time_loading);
188
188
std::string raw_data_buffer;
@@ -212,7 +212,7 @@ Array<Optional<NDArray>> LoadMultiGPU(const std::string& model_path, Module vm_m
212
212
<< " Loading " << FormatDuration (time_loading) << " Preprocessing "
213
213
<< FormatDuration (time_preproc) << " ." ;
214
214
} else {
215
- for (const NDArrayCacheMetadata ::FileRecord& record : ndarray_cache_metadata .records ) {
215
+ for (const TensorCacheMetadata ::FileRecord& record : tensor_cache_metadata .records ) {
216
216
for (size_t i = 0 ; i < record.records .size (); ++i) {
217
217
const std::string& param_name = record.records [i].name ;
218
218
const ModelMetadata::Param& param_info = param_name2info.at (param_name);
@@ -225,7 +225,7 @@ Array<Optional<NDArray>> LoadMultiGPU(const std::string& model_path, Module vm_m
225
225
if (worker_id % group_size == 0 ) {
226
226
// The worker is the first worker of its worker group (while not the first worker group).
227
227
// Receive the full parameter from the global worker 0.
228
- NDArray full_param = RecvFromGlobalWorker0 (device, param_info);
228
+ Tensor full_param = RecvFromGlobalWorker0 (device, param_info);
229
229
// Broadcast or shard-scatter this parameter to all workers in its worker group.
230
230
sharded_params[param_name] =
231
231
BroadcastOrShardAndScatter (full_param, param_info, num_shards, preprocs);
@@ -239,17 +239,17 @@ Array<Optional<NDArray>> LoadMultiGPU(const std::string& model_path, Module vm_m
239
239
}
240
240
241
241
// Step 3. Reorder the sharded parameters according to the order in model_metadata
242
- Array<Optional<NDArray >> shards;
242
+ Array<Optional<Tensor >> shards;
243
243
shards.reserve (model_metadata.params .size ());
244
244
for (const ModelMetadata::Param& param : model_metadata.params ) {
245
245
const auto & it = sharded_params.find (param.name );
246
- shards.push_back (it == sharded_params.end () ? Optional<NDArray >() : it->second );
246
+ shards.push_back (it == sharded_params.end () ? Optional<Tensor >() : it->second );
247
247
}
248
248
return shards;
249
249
}
250
250
251
- Array<Optional<NDArray >> LoadMultiGPUPresharded (const std::string& model_path, Module vm_module,
252
- const std::string& model_config_str) {
251
+ Array<Optional<Tensor >> LoadMultiGPUPresharded (const std::string& model_path, Module vm_module,
252
+ const std::string& model_config_str) {
253
253
DiscoWorker* worker = DiscoWorker::ThreadLocal ();
254
254
Device device = worker->default_device ;
255
255
int worker_id = worker->worker_id ;
@@ -259,22 +259,22 @@ Array<Optional<NDArray>> LoadMultiGPUPresharded(const std::string& model_path, M
259
259
int local_worker_id = worker_id % group_size;
260
260
LOG (INFO) << " [Worker #" << worker_id << " ] Loading model to device: " << device;
261
261
// Step 0. Initialize metadata and paths
262
- NDArrayCacheMetadata ndarray_cache_metadata = NDArrayCacheMetadata ::Load (model_path);
262
+ TensorCacheMetadata tensor_cache_metadata = TensorCacheMetadata ::Load (model_path);
263
263
picojson::value model_config;
264
264
picojson::parse (model_config, model_config_str);
265
265
ModelMetadata model_metadata =
266
266
ModelMetadata::FromModule (vm_module, model_config.get <picojson::object>());
267
267
268
268
std::unordered_map<std::string, ParamInfo> param_info_map;
269
- for (const NDArrayCacheMetadata ::FileRecord& file_record : ndarray_cache_metadata .records ) {
270
- for (const NDArrayCacheMetadata ::FileRecord::ParamRecord& param_record : file_record.records ) {
269
+ for (const TensorCacheMetadata ::FileRecord& file_record : tensor_cache_metadata .records ) {
270
+ for (const TensorCacheMetadata ::FileRecord::ParamRecord& param_record : file_record.records ) {
271
271
const std::string& param_name = param_record.name ;
272
272
param_info_map[param_name] = ParamInfo{&file_record, ¶m_record};
273
273
}
274
274
}
275
275
276
- Array<Optional<NDArray >> params;
277
- const NDArrayCacheMetadata ::FileRecord* current_file_;
276
+ Array<Optional<Tensor >> params;
277
+ const TensorCacheMetadata ::FileRecord* current_file_;
278
278
std::string current_file_stream_;
279
279
params.reserve (model_metadata.params .size ());
280
280
DurationType time_loading (0 );
@@ -283,7 +283,7 @@ Array<Optional<NDArray>> LoadMultiGPUPresharded(const std::string& model_path, M
283
283
if (std::find (param.pipeline_stages .begin (), param.pipeline_stages .end (), group_id) ==
284
284
param.pipeline_stages .end ()) {
285
285
// This worker group doesn't need to hold a copy of this parameter.
286
- params.push_back (Optional<NDArray >());
286
+ params.push_back (Optional<Tensor >());
287
287
continue ;
288
288
}
289
289
bool needs_sharding = !param.preprocs .empty ();
@@ -295,8 +295,8 @@ Array<Optional<NDArray>> LoadMultiGPUPresharded(const std::string& model_path, M
295
295
auto it = param_info_map.find (param_name);
296
296
CHECK (it != param_info_map.end ()) << " ValueError: Cannot find parameter: " << param_name;
297
297
const ParamInfo& param_info = (*it).second ;
298
- const NDArrayCacheMetadata ::FileRecord::ParamRecord* param_record = param_info.param ;
299
- const NDArrayCacheMetadata ::FileRecord* file_record = param_info.file ;
298
+ const TensorCacheMetadata ::FileRecord::ParamRecord* param_record = param_info.param ;
299
+ const TensorCacheMetadata ::FileRecord* file_record = param_info.file ;
300
300
301
301
if (file_record != current_file_) {
302
302
current_file_ = file_record;
0 commit comments