Implement automatic NGL detection #6502

Draft: wants to merge 7 commits into master

10 changes: 8 additions & 2 deletions common/common.cpp
@@ -836,7 +836,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
invalid_param = true;
return true;
}
params.n_gpu_layers = std::stoi(argv[i]);
std::string argValue = argv[i];
if (argValue == "auto" || argValue == "a") {

Collaborator:

I agree it can be a breaking change, but I would prefer to have this approach as the default, i.e. if -ngl is not passed, automatically offload the maximum possible layers to VRAM.

Contributor Author:

Could be. If someone doesn't want that, they could simply pass -ngl 0 or just build without GPU support.

Collaborator:

Yes, but this is just my personal point of view. @ggerganov or @slaren would have a better global view.

params.n_gpu_layers = -2;
} else {
params.n_gpu_layers = std::stoi(argValue);
}
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
@@ -1407,6 +1412,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
if (llama_supports_gpu_offload()) {
printf(" -ngl N, --n-gpu-layers N\n");
printf(" number of layers to store in VRAM\n");
printf(" set to 'auto' or 'a' to estimate max automatically based on VRAM size\n");
printf(" -ngld N, --n-gpu-layers-draft N\n");
printf(" number of layers to store in VRAM for the draft model\n");
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
@@ -2480,7 +2486,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
fprintf(stream, "n_gpu_layers: %d # default: -1, auto: -2\n", params.n_gpu_layers);
fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs);
fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
2 changes: 1 addition & 1 deletion common/common.h
@@ -62,7 +62,7 @@ struct gpt_params {
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default, -2 - determine automatically)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
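
As a quick reference for the sentinel values introduced here, the sketch below mirrors the argument handling added to common.cpp: "auto" or "a" selects the -2 sentinel, anything else is taken as an explicit layer count, and the -1 default from common.h stays in place when the flag is absent. The helper name parse_n_gpu_layers is a hypothetical illustration, not part of this PR.

#include <cstdint>
#include <string>

// Hypothetical helper mirroring the parsing in gpt_params_find_arg:
// "auto"/"a" -> -2 (determine the maximum automatically from free VRAM),
// any other value -> explicit layer count (std::stoi throws on bad input,
// as in the original code path).
static int32_t parse_n_gpu_layers(const std::string & arg_value) {
    if (arg_value == "auto" || arg_value == "a") {
        return -2;
    }
    return std::stoi(arg_value);
}

With this mapping, -ngl auto requests automatic detection, -ngl N pins an explicit count, and leaving -ngl unset keeps the -1 default that the review thread above discusses changing.
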
5 changes: 5 additions & 0 deletions ggml-cuda.cu
@@ -2612,6 +2612,11 @@ GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, si
CUDA_CHECK(cudaMemGetInfo(free, total));
}

GGML_CALL void ggml_backend_cuda_get_free_device_memory(int device, size_t * free) {
size_t total;
ggml_backend_cuda_get_device_memory(device, free, &total);
}

GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
return false;
1 change: 1 addition & 0 deletions ggml-cuda.h
@@ -34,6 +34,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type
GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
GGML_API GGML_CALL void ggml_backend_cuda_get_free_device_memory(int device, size_t * free);

GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
11 changes: 11 additions & 0 deletions ggml-sycl.cpp
@@ -16022,6 +16022,17 @@ catch (sycl::exception const &exc) {
std::exit(1);
}

GGML_CALL void ggml_backend_sycl_get_free_device_memory(int device, size_t *free) try {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_free_device_memory\n");
size_t total;
ggml_backend_sycl_get_device_memory(device, free, &total);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}

////////////////////////////////////////////////////////////////////////////////

// backend interface
1 change: 1 addition & 0 deletions ggml-sycl.h
@@ -33,6 +33,7 @@ GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len);
GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description, size_t description_size);
GGML_API GGML_CALL int ggml_backend_sycl_get_device_count();
GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
GGML_API GGML_CALL void ggml_backend_sycl_get_free_device_memory(int device, size_t *free);
GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id);

// TODO: these are temporary
15 changes: 15 additions & 0 deletions ggml-vulkan.cpp
@@ -5781,6 +5781,21 @@ GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size
}
}

GGML_CALL void ggml_backend_vk_get_free_device_memory(int device, size_t * free) {
GGML_ASSERT(device < (int) vk_instance.device_indices.size());

// Core Vulkan has no free-memory query (that requires VK_EXT_memory_budget),
// so report the size of the first device-local heap as an upper bound.
*free = 0;

vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];

vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();

for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
*free = heap.size;
break;
}
}
}

// backend registry
GGML_CALL static ggml_backend_t ggml_backend_reg_vk_init(const char * params, void * user_data) {
ggml_backend_t vk_backend = ggml_backend_vk_init((int) (intptr_t) user_data);
1 change: 1 addition & 0 deletions ggml-vulkan.h
@@ -19,6 +19,7 @@ GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend);
GGML_API GGML_CALL int ggml_backend_vk_get_device_count(void);
GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
GGML_API GGML_CALL void ggml_backend_vk_get_free_device_memory(int device, size_t * free);

GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
122 changes: 122 additions & 0 deletions llama.cpp
@@ -1650,6 +1650,28 @@ static size_t llama_get_device_memory(int device) {
#endif
}

// TODO: implement for other backends to return free memory
static size_t llama_get_available_device_memory(int device) {
#if defined(GGML_USE_CUDA)
size_t free;
ggml_backend_cuda_get_free_device_memory(device, &free);
return free;
#elif defined(GGML_USE_SYCL)
size_t free;
ggml_backend_sycl_get_free_device_memory(device, &free);
return free;
#elif defined(GGML_USE_VULKAN)
size_t free;
ggml_backend_vk_get_free_device_memory(device, &free);
return free;
#else
return 1;
GGML_UNUSED(device);
#endif
}

//
// globals
//
@@ -3254,6 +3276,17 @@ struct llama_model_loader {
return cur;
}

size_t estimate_tensor_bytes(const std::string & name, const std::vector<int64_t> & ne, bool required = true) const {
const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);

if (cur == NULL) {
return 0;
}

return ggml_nbytes(cur);
}

struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, bool required = true) {
const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);

@@ -4329,6 +4362,90 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
}

static int llm_determine_max_ngl(const llama_model_loader & ml, const llama_model & model, const int main_gpu, enum llama_split_mode split_mode) {
const auto & hparams = model.hparams;

// could become negative - use signed size_t
ssize_t available_gpu_memory = 0;
int n_layer = hparams.n_layer;

if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
// veeery sketchy, there has to be a better way to do this
int device_count = llama_get_device_count();
for (int i = 0; i < device_count; ++i) {
available_gpu_memory += llama_get_available_device_memory(i);
}
} else {
available_gpu_memory = llama_get_available_device_memory(main_gpu);
}

// "avoid a scenario where an application ooms because llama.cpp only left 5 MB of VRAM" - https://github.com/ggerganov/llama.cpp/pull/6502#discussion_r1555060962
available_gpu_memory -= 50 * MiB;

// determine tensor sizes
size_t total_size = 0;
size_t extra_size = 0;
std::vector<size_t> layer_sizes(n_layer);
for (int i = 0; i < ml.n_tensors; ++i) {
const struct ggml_tensor * tensor = (const struct ggml_tensor *) ml.get_tensor_meta(i);
std::vector<int64_t> tensor_ne = { tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3] };
size_t tensor_size = ml.estimate_tensor_bytes(tensor->name, tensor_ne);
//LLAMA_LOG_INFO("%s: %-30s: %12ld\n", __func__, tensor->name, tensor_size);
// all layer specific tensor names have the prefix "blk."
if (strncmp(tensor->name, "blk.", 4) == 0) {
int layer_no;
sscanf(tensor->name+4, "%d", &layer_no);
layer_sizes[layer_no] += tensor_size;
} else {
extra_size += tensor_size;
}
total_size += tensor_size;
}

// TODO: get buffer size dynamically
size_t buf_size = 400 * MiB;
size_t buf_size_k = 200 * MiB;
size_t buf_size_v = 200 * MiB;

ssize_t buffer_size = buf_size + buf_size_k + buf_size_v;

// can/will be pretty big for large models
size_t n_ctx = hparams.n_ctx_train;

size_t ctx_size =
ggml_tensor_overhead()*(ml.n_tensors + 1) +
ggml_tensor_overhead()*hparams.n_expert*n_layer;

size_t context_size = n_ctx*ctx_size;

// Calculate the maximum number of layers that can fit into the available GPU memory
int max_ngl = 0;
ssize_t used_memory = extra_size+buffer_size+context_size;
for (int i = 0; i < n_layer; ++i) {
LLAMA_LOG_INFO("%s: layer %2d size: %12ld\n", __func__, i, layer_sizes[i]);
used_memory += layer_sizes[i];
if (used_memory > available_gpu_memory) {
break;
}
max_ngl++;
}

LLAMA_LOG_INFO("%s: extra size: %12ld\n", __func__, extra_size);
LLAMA_LOG_INFO("%s: ----------------------------------\n", __func__);
LLAMA_LOG_INFO("%s: total size: %12ld\n", __func__, total_size);
LLAMA_LOG_INFO("%s: available_gpu_memory: %12ld\n", __func__, available_gpu_memory);
LLAMA_LOG_INFO("%s: buffer size: %12ld\n", __func__, buffer_size);
LLAMA_LOG_INFO("%s: total context size: %12ld\n", __func__, context_size);
LLAMA_LOG_INFO("%s: used memory: %12ld\n", __func__, used_memory);
LLAMA_LOG_INFO("%s: ----------------------------------\n", __func__);
LLAMA_LOG_INFO("%s: tokens in context: %ld\n", __func__, n_ctx);
LLAMA_LOG_INFO("%s: per token context size: %ld\n", __func__, ctx_size);
LLAMA_LOG_INFO("%s: layer_count: %d\n", __func__, n_layer);
LLAMA_LOG_INFO("%s: max_ngl: %d\n", __func__, max_ngl);

return max_ngl;
}

// Returns false if cancelled by progress_callback
static bool llm_load_tensors(
llama_model_loader & ml,
@@ -4344,6 +4461,11 @@ static bool llm_load_tensors(

auto & hparams = model.hparams;

if (n_gpu_layers == -2) {
n_gpu_layers = llm_determine_max_ngl(ml, model, main_gpu, split_mode);
LLAMA_LOG_INFO("%s: automatically set n_gpu_layers to %d\n", __func__, n_gpu_layers);
}

model.split_mode = split_mode;
model.main_gpu = main_gpu;
model.n_gpu_layers = n_gpu_layers;
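
For intuition about the greedy fit inside llm_determine_max_ngl, the toy program below runs the same accumulate-and-break loop with made-up sizes. Every number, and the program itself, is an illustrative assumption rather than a value taken from this PR.

#include <cstdio>
#include <vector>

// Toy version of the layer-fitting loop: start from the fixed costs
// (non-layer tensors plus buffers), add whole layers, and stop as soon as
// the running total exceeds the available VRAM minus the safety margin.
int main() {
    const long long MiB = 1024LL * 1024LL;

    long long available = 6LL * 1024 * MiB - 50 * MiB;  // e.g. a 6 GiB card minus the 50 MiB margin
    long long fixed     = 500 * MiB + 400 * MiB;        // extra_size + buffer/context estimate (illustrative)
    std::vector<long long> layer_sizes(32, 210 * MiB);  // 32 layers of ~210 MiB each (illustrative)

    int max_ngl = 0;
    long long used = fixed;
    for (long long size : layer_sizes) {
        used += size;
        if (used > available) {
            break;
        }
        max_ngl++;
    }

    printf("layers that fit: %d of %zu\n", max_ngl, layer_sizes.size());  // prints 24 of 32 with these numbers
    return 0;
}

The same shape of calculation produces the max_ngl value that llm_load_tensors then uses when n_gpu_layers is set to the -2 sentinel.
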