diff --git a/exllamav2/module.py b/exllamav2/module.py
index d5daef22..49da109b 100644
--- a/exllamav2/module.py
+++ b/exllamav2/module.py
@@ -89,11 +89,13 @@ def load_multi(self,
 
         for v, ks in submap_i.items():
             stfile = STFile.open(v, keymap = self.model.config.arch.keymap)
-            for k in ks:
-                if measure:
+            if measure:
+                for k in ks:
                     size += stfile.measure(key + "." + k)
-                else:
-                    tensors[k] = stfile.get_tensor(key + "." + k, device = self.device() if not cpu else "cpu")
+            else:
+                loaded = stfile.get_tensors([key + "." + k for k in ks], device = self.device() if not cpu else "cpu")
+                for k in ks:
+                    tensors[k] = loaded[key + "." + k]
 
         return size if measure else tensors
 
diff --git a/exllamav2/stloader.py b/exllamav2/stloader.py
index 43a60823..8293aca8 100644
--- a/exllamav2/stloader.py
+++ b/exllamav2/stloader.py
@@ -165,4 +165,48 @@ def get_tensor(
         )
         if out_dtype:
             tensor = tensor.to(out_dtype)
-        return tensor
\ No newline at end of file
+        return tensor
+
+    def get_tensors(
+        self,
+        keys: list,
+        device,
+        out_dtypes = None
+    ) -> dict:
+        """
+        Batch load multiple tensors from file.
+
+        :param keys:
+            List of tensor names
+
+        :param device:
+            Target device
+
+        :param out_dtypes:
+            Optional list of output dtypes (or None for each)
+
+        :return:
+            dict of {key: tensor}
+        """
+        tensors = {}
+        if out_dtypes is None:
+            out_dtypes = [None] * len(keys)
+        for key, out_dtype in zip(keys, out_dtypes):
+            h = self.header[key]
+            dtype, esize = convert_dtype(h["dtype"])
+            beg, end = h["data_offsets"]
+            size = end - beg
+            shape = h["shape"]
+            tensor = torch.empty(shape, dtype = dtype, device = device)
+            torch.cuda.synchronize()
+            assert tensor.is_contiguous(), "Non-contiguous tensor"
+            ext_c.stloader_read(
+                self.filename,
+                beg + self.header_size,
+                size,
+                tensor
+            )
+            if out_dtype:
+                tensor = tensor.to(out_dtype)
+            tensors[key] = tensor
+        return tensors