diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
index 12ad9d810327f..abe1ba46ac30b 100644
--- a/ggml/src/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan.cpp
@@ -148,6 +148,7 @@ struct vk_device_struct {
     vk::PhysicalDeviceProperties properties;
     std::string name;
     uint64_t max_memory_allocation_size;
+    uint32_t force_heap_index;
     bool fp16;
     vk::Device device;
     uint32_t vendor_id;
@@ -1008,9 +1009,12 @@ static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) {
     q.cmd_buffer_idx = 0;
 }
 
-static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
+static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags, uint32_t force_heap_index = UINT32_MAX) {
     for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
         vk::MemoryType memory_type = mem_props->memoryTypes[i];
+        if (force_heap_index != UINT32_MAX && memory_type.heapIndex != force_heap_index) {
+            continue;
+        }
         if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
             (flags & memory_type.propertyFlags) == flags &&
             mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
@@ -1053,11 +1057,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
 
     uint32_t memory_type_index = UINT32_MAX;
 
-    memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
+    memory_type_index = find_properties(&mem_props, &mem_req, req_flags, device->force_heap_index);
     buf->memory_property_flags = req_flags;
 
     if (memory_type_index == UINT32_MAX && fallback_flags) {
-        memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
+        memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags, device->force_heap_index);
         buf->memory_property_flags = fallback_flags;
     }
 
@@ -1851,6 +1855,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
         }
 
+        const char* GGML_VK_FORCE_HEAP_INDEX = getenv("GGML_VK_FORCE_HEAP_INDEX");
+
+        if (GGML_VK_FORCE_HEAP_INDEX != nullptr) {
+            device->force_heap_index = std::stoi(GGML_VK_FORCE_HEAP_INDEX);
+        } else {
+            device->force_heap_index = UINT32_MAX;
+        }
+
         device->vendor_id = device->properties.vendorID;
         device->subgroup_size = subgroup_props.subgroupSize;
         device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;