diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 31a19e148..8f7009784 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -34,6 +34,8 @@ class GPTBatch: sequence_lengths: list[torch.Tensor] | None = None images: list[torch.Tensor] | None = None image_positions: list[torch.Tensor] | None = None + audio: list[torch.Tensor] | None = None + audio_positions: list[torch.Tensor] | None = None def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: @@ -54,16 +56,34 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_images.append([]) batch_image_positions = [] for sample in batch: - if sample.image_positions is not None: + if sample.image_positions is not None and len(sample.image_positions) > 0: batch_image_positions.append(torch.from_numpy(sample.image_positions)) else: batch_image_positions.append([]) + + has_audio = False + batch_audio = [] + for sample in batch: + if sample.audio is not None and sample.audio_positions is not None: + batch_audio.append([torch.from_numpy(audio) for audio in sample.audio]) + has_audio = True + else: + batch_audio.append(None) + batch_audio_positions = [] + for sample in batch: + if sample.audio_positions is not None: + batch_audio_positions.append(torch.from_numpy(sample.audio_positions)) + else: + batch_audio_positions.append([]) + return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, images=batch_images if has_images else None, image_positions=batch_image_positions if has_images else None, + audio=batch_audio if has_audio else None, + audio_positions=batch_audio_positions if has_audio else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index bb3ff717a..357623b11 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -76,8 +76,13 @@ class GPTSamplingParameters(SamplingParameters): cross_document_attention: bool = True patch_size: int | None = None image_size: int | None = None + aud_downsampling_k: int | None = None + aud_padding_duration: int | None = None + aud_sampling_rate: int | None = None image_break_token: int | None = None image_end_token: int | None = None + audio_start_token: int | None = None + audio_end_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 @@ -204,6 +209,11 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of pixels in the dataset.", hint=FieldHint.optional, ) + num_audio: int | None = Field( + default=None, + desc="Expected number of audio in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 56c4c8927..a2bd9977a 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -44,8 +44,12 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. 
- doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else [] + doc_sizes, im_sizes, aud_sizes = self._dataset.get_document_sizes() + return ( + doc_sizes[self._begin : self._end], + im_sizes[self._begin : self._end] if im_sizes else [], + aud_sizes[self._begin : self._end] if aud_sizes else [], + ) def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 703809417..c47d3cf64 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -5,12 +5,13 @@ import numpy as np import PIL.Image +import torchaudio +import soundfile as sf from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -50,7 +51,7 @@ def _init( with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 4: self._has_images = struct.unpack("= 5: + self._has_audio = struct.unpack("= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._image_lengths = [] - self._image_positions = [] images_seen = 0 for n_images in self._n_images: self._image_lengths.append( @@ -139,15 +141,50 @@ def _init( ) ) images_seen += n_images + offset = offset + self._n_images.nbytes + 3 * self._n_images.sum() * np.dtype(np.int32).itemsize + self._audio_lengths = [] # list of arrays + self._audio_positions = [] # list of arrays + if self._has_audio and self._version >= 5: + self._n_audio = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + audio_seen = 0 + + offset = offset + self._n_audio.nbytes + for n_audio in self._n_audio: + self._audio_lengths.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_audio, + offset=offset + audio_seen * np.dtype(np.int32).itemsize, + ) + ) + # self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._audio_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_audio, + offset=offset + + self._n_audio.sum() * np.dtype(np.int32).itemsize + + audio_seen * np.dtype(np.int32).itemsize, + ) + ) + audio_seen += n_audio self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) + # TODO Soham: fix num_tokens to include images. 
Get total number of image pixels from index file and assign + # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) + + # TODO Toby: Add audio num tokens check self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) - if num_pixels is not None: - assert self._num_pixels == num_pixels - if num_tokens is not None: - assert self._num_tokens == num_tokens + # if num_pixels is not None: + # assert self._num_pixels == num_pixels + # if num_tokens is not None: + # assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) @@ -221,12 +258,37 @@ def get( count=self._image_lengths[idx].prod(initial=3), offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, ) - images = [] start = 0 for image_length in self._image_lengths[idx]: n_pixels = image_length.prod(initial=3) images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels + + audio = [] + audio_positions = None + if self._has_audio: + audio_positions = self._audio_positions[idx] + # increment offset by documents and images + aud_offset = ( + self._pointers[idx] + + offset * np.dtype(self._dtype).itemsize + + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + ) + + if self._has_images and len(self._image_lengths) > 0: + aud_offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize + all_audio = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.float32), + count=self._audio_lengths[idx].sum(), + offset=aud_offset, + ) + start = 0 + for audio_length in self._audio_lengths[idx]: + audio.append(all_audio[start : start + audio_length]) + start += audio_length + + # TODO Soham: return loss_masking_spans sample_spans = None if use_loss_masking_spans and self._spans is not None: sample_spans = self._spans[idx] @@ -235,28 +297,43 @@ def get( ] sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - if images: - image_idx = 0 - for span in sample_spans: - additional_tokens = 0 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - while image_position >= span[0] and image_position <= span[1]: - image_tokens = get_num_image_tokens( - get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), - patch_size, - image_break=image_break, - image_end=image_end, - ) - additional_tokens += image_tokens - image_idx += 1 - image_position = ( - image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - ) - span[1] += additional_tokens + # if images: + # image_idx = 0 + # for span in sample_spans: + # additional_tokens = 0 + # image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + # while image_position >= span[0] and image_position <= span[1]: + # image_tokens = get_num_image_tokens( + # get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), + # patch_size, + # image_break=image_break, + # ) + # additional_tokens += image_tokens + # image_idx += 1 + # image_position = ( + # image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + # ) + # span[1] += additional_tokens + # if audio: + # audio_idx = 0 + # for span 
in sample_spans: + # additional_tokens = 0 + # audio_position = audio_positions[audio_idx] if audio_idx < len(audio_positions) else float("inf") + # while audio_position >= span[0] and audio_position <= span[1]: + # audio_tokens = ... + # additional_tokens += audio_tokens + # audio_idx += 1 + # audio_position = ( + # audio_positions[audio_idx] if audio_idx < len(audio_positions) else float("inf") + # ) + # span[1] += additional_tokens + return GPTSample( token_ids=token_ids, images=images, image_positions=image_positions, + audio=audio, + audio_positions=audio_positions, loss_masking_spans=sample_spans, ) @@ -275,16 +352,28 @@ def num_tokens(self) -> int: def has_images(self) -> bool: return self._has_images + @property + def has_audio(self) -> bool: + return self._has_audio + def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes, self._image_lengths + return self._document_sizes, self._image_lengths, self._audio_lengths def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] + # return self._document_sizes[index].item() + ( + # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) + # if self._has_images + # else 0 + # ) + docsize = self._document_sizes[index].item() + imagesize = self._image_lengths[index] if self._has_images else [] + audiosize = self._audio_lengths[index] if self._has_audio else [] + return docsize, imagesize, audiosize @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -296,6 +385,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP image_lengths = [] im_positions = [] total_images = 0 + n_audio = [] + audio_lengths = [] + aud_positions = [] + total_audio = 0 pointers = [] offset = 0 # number of spans for each document @@ -319,6 +412,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) total_im_size = 0 + total_aud_size = 0 if document.images: n_images.append(len(document.images)) total_images += len(document.images) @@ -334,6 +428,21 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) + if document.audio is not None: + num_audio = 0 + for audio in document.audio: + # audio_arr, _ = torchaudio.load(io.BytesIO(audio["bytes"])) + audio_arr, _ = sf.read(io.BytesIO(audio["bytes"])) + audio_arr = audio_arr.astype(np.float32) + if len(audio_arr) > 0: + num_audio += 1 + audio_lengths.append(len(audio_arr)) + bin_stream.write(audio_arr.tobytes(order="C")) + total_aud_size += audio_arr.size + n_audio.append(num_audio) + total_audio += num_audio + if num_audio > 0: + aud_positions += document.audio_positions # Update metadata doc_length = len(document.token_ids) @@ -342,7 +451,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - offset += doc_length * np.dtype(dtype).itemsize + 
total_im_size * np.dtype(np.uint8).itemsize + offset += ( + doc_length * np.dtype(dtype).itemsize + + total_im_size * np.dtype(np.uint8).itemsize + + total_aud_size * np.dtype(np.float32).itemsize + ) num_documents += 1 # Finalize metadata arrays @@ -356,10 +469,21 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if total_images: n_images = np.array(n_images, dtype=np.int32) + image_lengths = np.stack(image_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - image_lengths = np.stack(image_lengths, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) + image_lengths = np.array([]) + im_positions = np.array([]) + + if total_audio: + n_audio = np.array(n_audio, dtype=np.int32) + audio_lengths = np.array(audio_lengths, dtype=np.int32) + aud_positions = np.array(aud_positions, dtype=np.int32) + else: + n_audio = np.array([]) + audio_lengths = np.array([]) + aud_positions = np.array([]) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: @@ -367,13 +491,15 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Indicates the version # Version 2 onwards optionally add loss-masking spans # Version 4 onwards optionally add images - idx_stream.write(struct.pack(" 0 else 0)) # Placeholder flag for preference spans idx_stream.write(struct.pack(" 0 else 0)) + # Flag to indicate whether audio is present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, image_sizes, audio_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) + image_token_sizes = torch.zeros_like(document_sizes).to(self._device) if image_sizes: image_token_sizes = [] for i, sizes in enumerate(image_sizes): @@ -156,8 +164,24 @@ def _sample(self) -> None: else: image_token_sizes = torch.zeros_like(document_sizes) + audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) + long_audio_filter = torch.zeros_like(document_sizes, dtype=torch.bool) # longer than audio padding + for i, sizes in enumerate(audio_sizes): + audio_token_size_arr, to_filter = get_num_audio_tokens( + sizes, + self._parameters.aud_padding_duration, + self._parameters.aud_sampling_rate, + self._parameters.aud_downsampling_k, + self._parameters.audio_start_token, + self._parameters.audio_end_token, + ) + audio_token_sizes[i] = audio_token_size_arr.sum() + long_audio_filter[i] = to_filter + documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() + tokens_per_epoch = ( + document_sizes.sum().item() + image_token_sizes.sum().item() + audio_token_sizes.sum().item() + ) # Calculate basic stats. if not self._truncate_documents: @@ -165,14 +189,31 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." 
) - long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1 + long_docs_filter = ( + document_sizes + image_token_sizes + +audio_token_sizes > self._parameters.sequence_length + 1 + ) ignored_documents = long_docs_filter.sum() if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item() + ignored_audio_samples = sum(long_audio_filter) + if ignored_audio_samples: + log_main_rank( + f" > {ignored_audio_samples}/{documents_per_epoch} samples contain audio longer than {self._parameters.aud_padding_duration} seconds and will be ignored.", + log_fn=logger.warning, + ) + long_docs_filter = long_docs_filter | long_audio_filter + tokens_per_epoch = ( + ( + document_sizes[~long_docs_filter] + + image_token_sizes[~long_docs_filter] + + audio_token_sizes[~long_docs_filter] + ) + .sum() + .item() + ) if tokens_per_epoch == 0: raise RuntimeError( f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." @@ -202,7 +243,7 @@ def _sample(self) -> None: shuffled_documents = documents_per_epoch * shuffled_epochs unshuffled_epochs = num_epochs - shuffled_epochs - yaml_data = { + yaml_data = { # TODO Toby: add audio "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, @@ -296,7 +337,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes + image_token_sizes, + document_sizes + image_token_sizes + audio_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), @@ -306,7 +347,7 @@ def _sample(self) -> None: unshuffled_tokens = 0 if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens + yaml_data["unshuffled_tokens"] = unshuffled_tokens * unshuffled_epochs self._load_yaml_data(yaml_data) if self._yaml_path is not None: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) @@ -321,7 +362,14 @@ def _sample(self) -> None: ) ] + image_token_sizes[ - document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ] + + audio_token_sizes[ + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) ], offset=self._unshuffled_tokens, # TODO: Allowing for max 100% extra tokens for padding, is that enough? @@ -414,8 +462,10 @@ def __getitem__(self, index: int) -> typing.Any: token_ids = [] loss_masking_spans = [] images = [] + audio = [] image_positions = [] - image_tokens_added = 0 + audio_positions = [] + mm_tokens_added = 0 text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. 
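To make the budgeting above concrete, here is a minimal sketch of how one document's effective size is computed once audio is counted; the clip lengths, text length, and hyperparameters are illustrative, and the constants 160 (mel hop length) and 2 (convolution stride) mirror get_num_audio_tokens further down in this diff. A document is dropped when this total exceeds sequence_length + 1, or when any of its clips is longer than aud_padding_duration seconds.

import numpy as np

# Illustrative values: 30 s padding at 16 kHz, downsampling k = 5, start/end tokens enabled.
clip_samples = np.array([240000, 480000])        # two clips: 15 s and 30 s of raw audio
padded = np.full_like(clip_samples, 30 * 16000)  # every kept clip is padded to aud_padding_duration
mel_frames = padded // 160                       # hop length 160 -> 3000 frames per clip
audio_tokens = mel_frames // (2 * 5) + 2         # conv stride 2, downsampling k, start + end token -> 302 per clip
document_size = 1200 + 0 + int(audio_tokens.sum())  # text + image + audio tokens = 1804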
@@ -424,7 +474,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - text_size, image_lengths = self._indexed_dataset.get_document_size(document_index) + text_size, image_lengths, audio_lengths = self._indexed_dataset.get_document_size(document_index) resized_image_lengths = [ get_resize_dims( @@ -445,20 +495,36 @@ def __getitem__(self, index: int) -> typing.Any: for image_length in resized_image_lengths ] image_tokens = sum(image_sizes) - document_size = text_size + image_tokens + + audio_token_size_arr, _ = get_num_audio_tokens( + audio_lengths, + self._parameters.aud_padding_duration, + self._parameters.aud_sampling_rate, + self._parameters.aud_downsampling_k, + self._parameters.audio_start_token, + self._parameters.audio_end_token, + ) + audio_tokens = int(audio_token_size_arr.sum()) + + document_size = text_size + image_tokens + audio_tokens if not self._truncate_documents: + # Document too long, ignore if document_size > self._parameters.sequence_length + 1: - # Document too long, ignore document_sampling_index += 1 continue + + # Where are we currently in the sample? tokens_in_sample = token_count % (self._parameters.sequence_length + 1) if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count >= token_start: # Add padding tokens to current sample token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) Assert.eq(token_count + padding_size, token_end) break else: @@ -480,20 +546,45 @@ def __getitem__(self, index: int) -> typing.Any: # image_end=self._parameters.image_end_token is not None, ) start_pos = 0 - if sample.image_positions: - for idx, im_position in enumerate(sample.image_positions): + + # add tokens and multi modal padding placeholders + # multimodal_positions = np.concatenate( + # [ + # arr.astype(np.int32) + # for arr in (sample.image_positions, sample.audio_positions) + # if arr is not None + # ] + # ) or np.array([], dtype=np.int32) + # multimodal_positions.sort() + + multimodal_positions = [] + if sample.image_positions is not None: + multimodal_positions.extend( + [(pos, "image", idx) for idx, pos in enumerate(sample.image_positions)] + ) + if sample.audio_positions is not None: + multimodal_positions.extend( + [(pos, "audio", idx) for idx, pos in enumerate(sample.audio_positions)] + ) + + token_ids_per_sample = [] + special_mm_tok_loss_masking_spans = np.empty((0, 2), dtype=np.int32) + multimodal_positions.sort(key=lambda x: x[0]) + for global_idx, (mm_position, mm_type, source_idx) in enumerate(multimodal_positions): + # Add placeholders for image and audio tokens + token_ids_per_sample.append(sample.token_ids[start_pos:mm_position]) + text_tokens_added += len(token_ids_per_sample[-1]) + if mm_type == "image": # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens - token_ids.append(sample.token_ids[start_pos:im_position]) - text_tokens_added += len(token_ids[-1]) - image_positions.append(text_tokens_added + image_tokens_added) + image_positions.append(text_tokens_added + mm_tokens_added) if self._parameters.image_break_token is not None: - height, width = resized_image_lengths[idx] + height, width =
resized_image_lengths[source_idx] num_patches_h = div(height, self._parameters.patch_size) num_patches_w = div(width, self._parameters.patch_size) # Create image token placeholder array - image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + image_token_array = np.full((image_sizes[source_idx],), -100, dtype=np.int64) # Add break tokens after each row except the last row for row in range(num_patches_h - 1): @@ -506,28 +597,129 @@ def __getitem__(self, index: int) -> typing.Any: else: image_token_array[last_row_position] = self._parameters.image_break_token else: - image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + image_token_array = np.full((image_sizes[source_idx],), -100, dtype=np.int64) if self._parameters.image_end_token is not None: image_token_array[-1] = self._parameters.image_end_token - token_ids.append(image_token_array) - image_tokens_added += image_sizes[idx] - start_pos = im_position - token_ids.append(sample.token_ids[start_pos:]) - text_tokens_added += len(token_ids[-1]) + token_ids_per_sample.append(image_token_array) + mm_tokens_added += image_sizes[source_idx] + elif mm_type == "audio": + # audio_pos = sum(t.size for t in token_ids) # includes mm tokens added already + # compute audio position + start_token_offset = int(self._parameters.audio_start_token is not None) + audio_pos = text_tokens_added + mm_tokens_added + start_token_offset + audio_positions.append(audio_pos) + + # compute number of special tokens + num_audio_special_tokens = int(self._parameters.audio_start_token is not None) + int( + self._parameters.audio_end_token is not None + ) + + # add start tokens + if self._parameters.audio_start_token is not None: + token_ids_per_sample.append(np.array([self._parameters.audio_start_token])) + # add to loss masking spans + special_mm_tok_loss_masking_spans = np.append( + special_mm_tok_loss_masking_spans, [[audio_pos - 1, audio_pos - 1]], axis=0 + ) + # sample.loss_masking_spans = np.append(sample.loss_masking_spans, [[audio_pos-1, audio_pos-1]], axis=0) + + # add audio pad tokens + num_audio_pad_tokens = audio_token_size_arr[source_idx] + num_audio_pad_tokens -= num_audio_special_tokens # ignore start/end tokens for padding + audio_padding_tokens = np.full((num_audio_pad_tokens,), -100, dtype=np.int64) + token_ids_per_sample.append(audio_padding_tokens) + + # add end token + if self._parameters.audio_end_token is not None: + token_ids_per_sample.append(np.array([self._parameters.audio_end_token])) + # add to loss masking spans + special_mm_tok_loss_masking_spans = np.append( + special_mm_tok_loss_masking_spans, + [[audio_pos + num_audio_pad_tokens, audio_pos + num_audio_pad_tokens]], + axis=0, + ) + # sample.loss_masking_spans = np.append(sample.loss_masking_spans, [[audio_pos + num_audio_pad_tokens, audio_pos + num_audio_pad_tokens]], axis=0) + + # update mm tokens added + mm_tokens_added += num_audio_special_tokens + num_audio_pad_tokens + start_pos = mm_position + + # add remaining text tokens + token_ids_per_sample.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids_per_sample[-1]) + + token_ids.append(np.concatenate(token_ids_per_sample)) if sample.images: images.append(sample.images) else: images.append([]) + if sample.audio: + # audio.append(self.apply_audio_padding(sample.audio)) + audio.append(sample.audio) + else: + audio.append([]) + if self._parameters.use_loss_masking_spans: + mm_idx = 0 + mm_tokens_before_span = 0 + + # sort by start of span + sample.loss_masking_spans = 
sample.loss_masking_spans[sample.loss_masking_spans[:, 0].argsort()] for loss_masking_span in sample.loss_masking_spans: + mm_tokens_within_span = 0 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), None, None) + ) + + # increment mm_idx until span is reached, track mm tokens before span + while mm_position < loss_masking_span[0]: + if mm_type == "image": + num_mm_tokens = image_sizes[source_idx] + elif mm_type == "audio": + num_mm_tokens = audio_token_size_arr[source_idx] + mm_tokens_before_span += num_mm_tokens + mm_idx += 1 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), None, None) + ) + + # get all multimodal positions within span + while mm_position >= loss_masking_span[0] and mm_position <= loss_masking_span[1]: + if mm_type == "image": + num_mm_tokens = image_sizes[source_idx] + elif mm_type == "audio": + num_mm_tokens = audio_token_size_arr[source_idx] + mm_tokens_within_span += num_mm_tokens + mm_idx += 1 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), None, None) + ) + loss_masking_span[0] += mm_tokens_before_span # increment by all mm tokens before span + loss_masking_span[1] += mm_tokens_before_span + mm_tokens_within_span + mm_tokens_before_span += mm_tokens_within_span + span = np.clip( - loss_masking_span + token_count - token_start, + loss_masking_span + int(token_count) - int(token_start), 0, self._parameters.sequence_length + self._parameters.extra_tokens, ) - if span[1] > span[0]: + if span[1] >= span[0]: loss_masking_spans.append(span) + for span in special_mm_tok_loss_masking_spans: + # span = np.clip( + # loss_masking_span + token_count - token_start, + # 0, + # self._parameters.sequence_length + self._parameters.extra_tokens, + # ) + if span[1] >= span[0]: + loss_masking_spans.append(span) # Go to the next document.
document_sampling_index += 1 token_count += document_size @@ -543,16 +735,24 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) + images = [im for img_list in images for im in img_list] if images else None image_positions = np.array(image_positions) if image_positions else None - Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) + audio = [aud for aud_list in audio for aud in aud_list] if audio else None # flatten + audio_positions = np.array(audio_positions) if audio_positions else None + # Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) + + # # TODO: Toby remove/comment after testing (for testing only first sample) + # loss_masking_spans = np.append(loss_masking_spans, [[sequence_lengths[0], sequence_lengths[:-1].sum()]], axis=0) return GPTSample( token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths, images=images, image_positions=image_positions, + audio=audio if audio is not None and len(audio) > 0 else None, + audio_positions=audio_positions, ) @property diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 86bc080fe..53df3add1 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -35,18 +35,16 @@ def __len__(self) -> int: def __getitem__(self, idx) -> typing.Any: start_time = time.perf_counter() - try: - sample = self._dataset[idx] - sample_time = (time.perf_counter() - start_time) * 1000 - if sample_time > self._data_sample_warn_time_ms: - logger.warning( - f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" - ) - return sample - - except Exception: - logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") - raise + # try: + sample = self._dataset[idx] + sample_time = (time.perf_counter() - start_time) * 1000 + if sample_time > self._data_sample_warn_time_ms: + logger.warning(f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load") + return sample + + # except Exception as e: + # logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") + # raise @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2e9243807..f4e722dcd 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -68,6 +68,12 @@ class GPTHuggingfaceDatasetConfig(Config): images: None | str = Field( default=None, desc="Field containing images relevant to a document", hint=FieldHint.optional ) + audio_positions: None | str = Field( + default=None, desc="Field containing audio positions within a document", hint=FieldHint.optional + ) + audio: None | str = Field( + default=None, desc="Field containing audio relevant to a document", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." 
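For context, the two new dataset fields above are meant to be set alongside the existing image fields when preparing a Hugging Face dataset that stores raw audio next to the text. A hypothetical configuration (the dataset path and column names are assumptions, not part of this diff) could look like:

from fast_llm.data.preparator.gpt_memmap.config import GPTHuggingfaceDatasetConfig

dataset_config = GPTHuggingfaceDatasetConfig(
    path="my_org/my_audio_corpus",       # hypothetical Hugging Face dataset
    field="text",
    audio="audio",                       # column holding the raw audio clips
    audio_positions="audio_positions",   # column holding per-clip character offsets into the text
)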
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index fa46ee92e..888e1b634 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -45,7 +45,11 @@ def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[t pass def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans, image_token_positions = map( + # input_ids = [ + # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) + # for text in batch[self._config.dataset.field] + # ] + input_ids, token_spans, image_token_positions, audio_token_positions = map( list, zip( *[ @@ -53,17 +57,20 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), np.array(image_token_positions, dtype=np.int32), + np.array(audio_token_positions, dtype=np.int32), ) - for input_ids, token_spans, image_token_positions in [ + for input_ids, token_spans, image_token_positions, audio_token_positions in [ self._tokenizer.tokenize( text, loss_mask_spans, im_char_positions, + aud_char_positions, ) - for text, loss_mask_spans, im_char_positions in zip( + for text, loss_mask_spans, im_char_positions, aud_char_positions in zip( batch[self._config.dataset.field], batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + batch.get(self._config.dataset.audio_positions, itertools.repeat(None)), ) ] ] @@ -77,52 +84,59 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ width, height = im.size num_pixels[idx] += width * height * 3 + num_audio = [0] * len(input_ids) + for idx, audio_lst in enumerate(batch.get(self._config.dataset.audio, [])): + for audio in audio_lst: + num_audio[idx] += len(audio) + return { "input_ids": input_ids, "image_positions": image_token_positions, + "audio_positions": audio_token_positions, "token_spans": token_spans, "num_tokens": num_tokens, "num_pixels": num_pixels, + "num_audio": num_audio, } - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans, images, image_token_positions = map( - list, - zip( - *[ - ( - np.array(input_ids, dtype=self._data_type.numpy), - np.array(token_spans, dtype=np.int32).reshape(-1, 2), - np.array(images, dtype=np.uint8), - np.array(image_token_positions, dtype=np.int32), - ) - for input_ids, token_spans, images, image_token_positions in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip( - batch[self._config.dataset.field], - batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), - batch.get(self._config.dataset.images, itertools.repeat(None)), - batch.get(self._config.dataset.image_positions, itertools.repeat(None)), - ) - ] - ] - ), - ) - num_tokens = [len(x) for x in input_ids] - num_pixels = [0] * len(input_ids) - for idx, images in enumerate(images): - for bytes_im in images: - with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: - width, height = im.size - num_pixels[idx] += width * height * 3 - return { - "input_ids": input_ids, - "token_spans": token_spans, - "images": images, - "image_positions": image_token_positions, - "num_tokens": num_tokens, - "num_pixels": num_pixels, - } + # def 
_tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + # input_ids, token_spans, images, image_token_positions = map( + # list, + # zip( + # *[ + # ( + # np.array(input_ids, dtype=self._data_type.numpy), + # np.array(token_spans, dtype=np.int32).reshape(-1, 2), + # np.array(images, dtype=np.uint8), + # np.array(image_token_positions, dtype=np.int32), + # ) + # for input_ids, token_spans, images, image_token_positions in [ + # self._tokenizer.tokenize_with_spans(text, char_spans) + # for text, char_spans in zip( + # batch[self._config.dataset.field], + # batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), + # batch.get(self._config.dataset.images, itertools.repeat(None)), + # batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + # ) + # ] + # ] + # ), + # ) + # num_tokens = [len(x) for x in input_ids] + # num_pixels = [0] * len(input_ids) + # for idx, images in enumerate(images): + # for bytes_im in images: + # with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + # width, height = im.size + # num_pixels[idx] += width * height * 3 + # return { + # "input_ids": input_ids, + # "token_spans": token_spans, + # "images": images, + # "image_positions": image_token_positions, + # "num_tokens": num_tokens, + # "num_pixels": num_pixels, + # } def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: shard_idx, shard_dataset = args @@ -140,6 +154,8 @@ def _document_generator(): ), item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, + item[self._config.dataset.audio] if self._config.dataset.audio else None, + item[self._config.dataset.audio_positions] if self._config.dataset.audio_positions else None, ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -151,19 +167,24 @@ def _document_generator(): "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), + "num_audio": sum(doc["num_audio"] for doc in shard_dataset), } ) def _load_dataset(self) -> datasets.Dataset: - dataset = datasets.load_dataset( - path=self._config.dataset.path, - name=self._config.dataset.config_name, - data_dir=self._config.dataset.data_directory, - data_files=self._config.dataset.data_files, - split=self._config.dataset.split, - num_proc=self._config.loading_workers, - trust_remote_code=self._config.dataset.trust_remote_code, - ) + try: + dataset = datasets.load_dataset( + path=self._config.dataset.path, + name=self._config.dataset.config_name, + data_dir=self._config.dataset.data_directory, + data_files=self._config.dataset.data_files, + split=self._config.dataset.split, + num_proc=self._config.loading_workers, + trust_remote_code=self._config.dataset.trust_remote_code, + ) + except: + # backup if dataset is saved in arrow format (can we auto-detect this?) 
+ dataset = datasets.load_from_disk(dataset_path=self._config.dataset.data_directory) assert isinstance(dataset, datasets.Dataset) return dataset @@ -271,6 +292,8 @@ def run(self) -> None: # decoding bytes to images is slow and should be done only when needed if self._config.dataset.images is not None: dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) + if self._config.dataset.audio is not None: + dataset = dataset.cast_column("audio", datasets.Sequence(datasets.Audio(decode=False))) # Tokenize the dataset in parallel tokenized_dataset = dataset.map( @@ -278,6 +301,7 @@ def run(self) -> None: batched=True, num_proc=self._config.tokenize_workers, desc="Tokenizing batches", + load_from_cache_file=False # TODO Toby: remove ) # Calculate total number of tokens @@ -287,7 +311,15 @@ def run(self) -> None: if self._config.dataset.images else 0 ) + total_audio = ( + sum(tqdm.tqdm(tokenized_dataset["num_audio"], desc="Counting audio", unit="audio")) + if self._config.dataset.audio + else 0 + ) total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize + total_tokens += total_audio * np.float32().itemsize // np.dtype(self._data_type.numpy).itemsize + + tokenized_dataset = tokenized_dataset.shuffle(seed=42) # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 1cbc1ec56..e37b0e6d0 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,76 +42,121 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize( + self, text: str, char_spans=None, image_positions=None, audio_positions=None + ) -> tuple[list[int], list[tuple[int, int]]]: """ Tokenize the input text and return the tokenized input_ids and if provided, token spans and image positions. 
""" - if not image_positions: - image_positions = [] - if not char_spans: - char_spans = [] + image_positions = image_positions or [] + audio_positions = audio_positions or [] + char_spans = char_spans or [] - image_idx = 0 + if len(set(image_positions).intersection(audio_positions)) > 0: + raise ValueError("Image and audio can not have the same position.") + multimodal_positions = sorted(image_positions + audio_positions) + + mm_idx = 0 char_pos = 0 token_ids = [] image_token_positions = [] + audio_token_positions = [] token_spans = [] beginning_of_text = True - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + multimodal_position = multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else float("inf") for start, end in char_spans: - while image_position <= start: - tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + # tokenize text, compute mm token position before span + while multimodal_position <= start: + # tokenize text before mm position + tokenized_text = self._tokenize(text[char_pos:multimodal_position], begin=beginning_of_text, end=False) beginning_of_text = False token_ids.extend(tokenized_text) - image_token_positions.append(len(token_ids)) - image_idx += 1 - char_pos = image_position - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + + # update mm token positions + multimodal_type = "image" if multimodal_position in image_positions else "audio" + if multimodal_type == "image": + image_token_positions.append(len(token_ids)) + else: + audio_token_positions.append(len(token_ids)) + + # updates + mm_idx += 1 + char_pos = multimodal_position + multimodal_position = ( + multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else float("inf") + ) + + # tokenize remaining text before span if char_pos < start: self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) beginning_of_text = False token_ids.extend(tokenized_text) + char_pos = start - len(token_ids) span_length = 0 token_start = len(token_ids) - while image_position <= end: - tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + + # tokenize text, compute mm token position within span + while multimodal_position <= end: + # tokenize text before mm position + tokenized_text = self._tokenize(text[char_pos:multimodal_position], begin=beginning_of_text, end=False) beginning_of_text = False token_ids.extend(tokenized_text) - image_token_positions.append(len(token_ids)) + + # update mm token positions + multimodal_type = "image" if multimodal_position in image_positions else "audio" + if multimodal_type == "image": + image_token_positions.append(len(token_ids)) + else: + audio_token_positions.append(len(token_ids)) + + # updates span_length += len(tokenized_text) - char_pos = image_position - image_idx += 1 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + char_pos = multimodal_position + mm_idx += 1 + multimodal_position = ( + multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else float("inf") + ) + + # tokenize remaining text until end of span if char_pos < end: if end >= len(text) - 1: tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=True) - beginning_of_text = False - token_ids.extend(tokenized_text) - span_length += len(tokenized_text) - char_pos = end + 1 else: tokenized_text = 
self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=False) - beginning_of_text = False - token_ids.extend(tokenized_text) - span_length += len(tokenized_text) - char_pos = end + 1 + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + char_pos = end + 1 + + # update token spans token_spans.append((token_start, token_start + span_length - 1)) - while image_position <= len(text): - image_position = image_positions[image_idx] - tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + # tokenize text, compute mm token position after all spans + while multimodal_position <= len(text): + # tokenize text before mm position + multimodal_position = multimodal_positions[mm_idx] + tokenized_text = self._tokenize(text[char_pos:multimodal_position], begin=beginning_of_text, end=False) beginning_of_text = False token_ids.extend(tokenized_text) - image_token_positions.append(len(token_ids)) - char_pos = image_position - image_idx += 1 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + + # update mm token positions + multimodal_type = "image" if multimodal_position in image_positions else "audio" + if multimodal_type == "image": + image_token_positions.append(len(token_ids)) + else: + audio_token_positions.append(len(token_ids)) + + # updates + char_pos = multimodal_position + mm_idx += 1 + multimodal_position = multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else float("inf") + + # tokenize text after all spans tokenized_text = self._tokenize(text[char_pos:], begin=beginning_of_text, end=True) token_ids.extend(tokenized_text) - return token_ids, token_spans, image_token_positions + return token_ids, token_spans, image_token_positions, audio_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 204abdf1c..f2acf9b60 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,6 +55,11 @@ class BatchConfig(Config): desc="Maximum image height and width", hint=FieldHint.optional, ) + aud_padding_duration: int = Field( + default=-1, + desc="Audio padding duration in seconds.", + hint=FieldHint.feature, + ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/layers/audio_encoder/adapter.py b/fast_llm/layers/audio_encoder/adapter.py new file mode 100644 index 000000000..bc4f8f00f --- /dev/null +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -0,0 +1,87 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.tensor import TensorMeta, init_normal_ + + +class AudioAdapter(Layer): + """ + Audio adapter layer for the LLM.
+ """ + + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): + super().__init__() + audio_hidden_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels) + input_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_input) + self._activation_type = config.adapter_activation_type + self._use_adapter_bias = config.adapter_bias + self.lr_scale = config.adapter_lr_scale + + self.norm_1 = config.transformer.normalization.get_layer(audio_hidden_dim) + self.norm_1.lr_scale = self.lr_scale + self.norm_2 = config.transformer.normalization.get_layer( + tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size) + ) + self.norm_2.lr_scale = self.lr_scale + + # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? + self.layer_1 = Linear( + input_dim, + tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size), + bias=self._use_adapter_bias, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + lr_scale=self.lr_scale, + ) + self.layer_2 = Linear( + tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size), + tensor_space.get_tensor_dim(TransformerDimNames.hidden), + bias=self._use_adapter_bias, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + lr_scale=self.lr_scale, + ) + + self.aud_downsampling_k = config.aud_downsampling_k + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Audio adapter output", + dtype=input_.dtype, + ) + input_ = self.norm_1(input_) + batch_size, seq_len, dim = input_.size() + + # Check if sequence length is divisible by downsampling rate. + if seq_len % self.aud_downsampling_k != 0: + # If not divisible, trim the end of the sequence. + trimmed_seq_len = seq_len - (seq_len % self.aud_downsampling_k) + input_ = input_[:, :trimmed_seq_len, :] + seq_len = trimmed_seq_len + + # Reshape: group every k frames together (concatenate along feature dimension). 
+ new_seq_len = seq_len // self.aud_downsampling_k + input_ = input_.contiguous().view(batch_size, new_seq_len, dim * self.aud_downsampling_k) + layer1_res = torch_mlp_activation( + input_=self.layer_1(input_), gated=False, activation_type=self._activation_type + ) + torch.manual_seed(0) # TODO Toby: remove after debugging + layer1_res_dropout = torch.nn.functional.dropout(layer1_res, 0.1) + layer1_res_norm = self.norm_2(layer1_res_dropout) + layer2_res = self.layer_2(layer1_res_norm) + return layer2_res diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py new file mode 100644 index 000000000..95665901e --- /dev/null +++ b/fast_llm/layers/audio_encoder/config.py @@ -0,0 +1,159 @@ +import enum + +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.transformer.config import AudioTransformerConfig +from fast_llm.utils import Assert + + +class AudioEncoderDimNames: + in_channels = "audio_in_channels" + out_channels = "audio_out_channels" + kernel_size = "audio_kernel_size" + adapter_input = "audio_adapter_input" + adapter_size = "audio_adapter_size" + audio_channels = "audio_kv_channels" + max_source_positions = "audio_max_source_positions" + + +class AudioEncoderKwargs: + audio = "audio" + audio_mel = "audio_mel" + audio_positions = "audio_positions" + + kv_channels = "audio_kv_channels" # TODO: check this + out_channels = "audio_out_channels" + hidden_dims = "audio_hidden_dims" + + # TODO: used for backup attention + sequence_length = "audio_sequence_length" + sequence_k_dim = "audio_sequence_k_dim" + sequence_q_dim = "audio_sequence_q_dim" + + +class AudioEncoderType(str, enum.Enum): + none = "none" + whisper = "whisper" + + +@config_class() +class AudioEncoderConfig(BaseModelConfig): + _abstract = False + + type: AudioEncoderType = Field( + default=AudioEncoderType.none, + desc="Type of the audio encoder. Choices: none, whisper.", + hint=FieldHint.architecture, + ) + transformer: AudioTransformerConfig = Field( + default_factory=AudioTransformerConfig, + desc="Configuration for the audio transformer architecture.", + hint=FieldHint.core, + ) + + # encoder configs + conv_bias: bool = Field( + default=True, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) + encoder_dropout: float = Field( + default=0.0, + desc="Dropout for encoder.", + hint=FieldHint.core, + ) + kernel_size: int = Field( + default=3, + desc="Encoder convolution layer kernel size.", + hint=FieldHint.core, + ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + pos_emb_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the position embedding layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + + # adapter configs + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.core, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", + hint=FieldHint.core, + ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter layer.", + hint=FieldHint.optional, + ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + + # audio configs + num_mel_bins: int = Field( + default=80, + desc="Number of bins for mel spectogram.", + hint=FieldHint.core, + ) + aud_downsampling_k: int = Field( + default=5, + desc="Audio downsampling k parameter.", + hint=FieldHint.feature, + ) + aud_sampling_rate: int = Field( + default=16000, + desc="Audio sampling rate to use.", + hint=FieldHint.feature, + ) + + # audio start/end tokens + audio_start_token: int | None = Field( + default=None, + desc="Token id for audio start.", + hint=FieldHint.optional, + ) + audio_end_token: int | None = Field( + default=None, + desc="Token id for audio end.", + hint=FieldHint.optional, + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels, self.num_mel_bins)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.kernel_size, self.kernel_size)) + tensor_space.add_tensor_dim( + TensorDim(AudioEncoderDimNames.adapter_input, self.transformer.hidden_size * self.aud_downsampling_k) + ) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim( + TensorDim(AudioEncoderDimNames.max_source_positions, 1500) + ) # TODO: configure later + + tensor_space.add_tensor_dim( + TensorDim( + AudioEncoderDimNames.audio_channels, + self.transformer.hidden_size // self.transformer.num_attention_heads, + ) + ) + self.transformer.setup_tensor_space(tensor_space) + + @property + def enabled(self) -> bool: + return self.type != AudioEncoderType.none diff --git a/fast_llm/layers/audio_encoder/encoder.py b/fast_llm/layers/audio_encoder/encoder.py new file mode 100644 index 000000000..b35cc1740 --- /dev/null +++ b/fast_llm/layers/audio_encoder/encoder.py @@ -0,0 +1,93 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames +from fast_llm.layers.transformer.config import AudioTransformerKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class AudioConv(Layer): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): + super().__init__() + self._tensor_space = tensor_space + self.dropout_p = config.encoder_dropout + self._conv_lr_scale = config.conv_lr_scale + self._pos_emb_lr_scale = config.pos_emb_lr_scale + + self.conv1_weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), + ), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, + ) + self.conv1_stride = 1 # TODO Toby: parameterize? 
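+ # Note: conv1 (stride 1) preserves the number of mel frames, while conv2 (stride 2) halves it;
+ # this is the factor of 2 that get_num_audio_tokens applies before dividing by the downsampling k.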
+ + self.conv2_weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), + ), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, + ) + self.conv2_stride = 2 # TODO Toby: parameterize? + + if config.conv_bias: + self.conv1_bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, + ) + self.conv2_bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, + ) + else: + self.conv1_bias = None + self.conv2_bias = None + + self.positional_embeddings = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.max_source_positions), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + ), + init_method=init_normal_(), + lr_scale=self._pos_emb_lr_scale, + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + hidden_dims = kwargs[AudioTransformerKwargs.hidden_dims] # TODO: check seq q + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims(hidden_dims, tensor_name="audio conv output", dtype=input_.dtype) + + # TODO Toby: check how to best cast dtype + input_ = input_.to(self.conv1_weight.dtype) + + input_ = torch.nn.functional.conv1d( + input_, self.conv1_weight, self.conv1_bias, stride=self.conv1_stride, padding=1 + ) + input_ = torch.nn.functional.gelu(input_) + input_ = torch.nn.functional.conv1d( + input_, self.conv2_weight, self.conv2_bias, stride=self.conv2_stride, padding=1 + ) + input_ = torch.nn.functional.gelu(input_) + + audio_embeddings = input_.permute(0, 2, 1) + audio_embeddings = audio_embeddings + self.positional_embeddings + audio_embeddings = torch.nn.functional.dropout(audio_embeddings, p=self.dropout_p, training=self.training) + + return audio_embeddings.contiguous() diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py new file mode 100644 index 000000000..21262fe92 --- /dev/null +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -0,0 +1,153 @@ +import math +import typing + +import numpy as np +import torch +from transformers import WhisperFeatureExtractor + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderKwargs + + +def get_num_audio_tokens( + sizes, aud_padding_duration, aud_sampling_rate, aud_downsampling_k, audio_start_token, audio_end_token +): + if len(sizes) == 0: # sample has no audio + return np.array(sizes), False + to_filter = False + # account for padding + if aud_padding_duration > 0: + raw_audio_seq_length = aud_padding_duration * aud_sampling_rate + sizes = sizes.copy() # original is read-only + to_filter = bool(np.any(sizes > raw_audio_seq_length)) # filter sample where any audio is too long + sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount + + # account for mel spectogram, convolution, downsampling k + audio_token_size_arr = sizes // 160 # default hop length TODO Toby: check divisible? 
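+ # e.g. a clip padded to 30 s at 16 kHz has 30 * 16000 = 480000 samples -> 480000 // 160 = 3000 mel frames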
+ audio_token_size_arr = audio_token_size_arr // ( + 2 * aud_downsampling_k + ) # convolution (2 stride) * downsampling TODO Toby: make configurable convolution + + if audio_start_token is not None: + audio_token_size_arr += 1 + if audio_end_token is not None: + audio_token_size_arr += 1 + return audio_token_size_arr, to_filter + + +def apply_audio_padding(audio, aud_padding_duration, aud_sampling_rate): + if len(audio) == 0: + return audio + # TODO Toby: check 2d + padded_audio = [] + if aud_padding_duration > 0: + raw_audio_seq_length = aud_padding_duration * aud_sampling_rate + for aud in audio: + padded = np.pad(aud, (0, raw_audio_seq_length - len(aud)), mode="constant", constant_values=0) + padded_audio.append(padded) + return padded_audio + else: + return audio + + +class AudioPreprocessor(Preprocessor): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + self.feature_extractor = WhisperFeatureExtractor(sampling_rate=self._config.aud_sampling_rate) + + # self.mel_transform = MelSpectrogram( + # sample_rate=self._config.aud_sampling_rate, + # n_fft=400, + # win_length=400, + # hop_length=160, + # n_mels=80, + # f_min=0.0, + # f_max=8000.0, + # mel_scale="slaney", + # norm="slaney", + # center=True, + # power=2.0, + # ) + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + # kwargs[AudioEncoderKwargs.audio_mel_meta] = TensorMeta.from_dims( + # ( + # TensorDim( + # VisionTransformerDimNames.batch, + # kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + # ), + # TensorDim(VisionEncoderDimNames.in_channels, 3), + # TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + # TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + # ), + # dtype=self._distributed_config.training_dtype.torch, + # ) + pass + + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: + # check if audio is in batch + audio_mel = [] + if AudioEncoderKwargs.audio in kwargs: + print("Preprocessing Contains Audio") + audio_raw = kwargs[AudioEncoderKwargs.audio] + flattened_audio = [ + audio_arr for sequence in audio_raw for audio_arr in sequence + ] # flatten in the batch dimension + print("Preprocessing Flattened Audio: ", flattened_audio) + + for audio in flattened_audio: + audio_mel.append( + self.feature_extractor( + audio, + sampling_rate=self._config.aud_sampling_rate, + return_tensors="pt", + max_length=30 * self._config.aud_sampling_rate, + device=self._tensor_space.distributed.device, + )["input_features"] + ) + audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) + curr_size = audio_mel.size(0) + else: + print("Preprocessing No Audio") + audio_mel = torch.tensor(audio_mel, dtype=torch.float32) + curr_size = 0 + + print("Preprocessing Audio Mel Raw: ", audio_mel) + + # compute max pad + max_pad = math.ceil( + kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // self._config.aud_downsampling_k) + ) + max_pad = 1 + max_pad = max(max_pad, curr_size) + + # add padding + padding_size = max_pad - curr_size + if padding_size > 0: + padding = torch.zeros( + padding_size, + self.feature_extractor.feature_size, + self.feature_extractor.nb_max_frames, + dtype=audio_mel.dtype, + device=audio_mel.device, + ) + audio_mel = torch.cat((audio_mel, padding), dim=0) + + print("Preprocessing Audio Mel Final: ", audio_mel) + 
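Review note: putting get_num_audio_tokens and the feature-extractor padding together, the per-clip token budget works out as below. A worked example using the config defaults, with aud_padding_duration assumed to be 30 s (it comes from the batch config and is not fixed by this diff):

    # Assumed: aud_padding_duration = 30, aud_sampling_rate = 16000, aud_downsampling_k = 5,
    # and both audio_start_token and audio_end_token set.
    raw_audio_seq_length = 30 * 16000          # 480_000 samples; every clip is padded to this length
    mel_frames = raw_audio_seq_length // 160   # 3000 frames with the 160-sample (10 ms) hop
    audio_tokens = mel_frames // (2 * 5)       # 300: stride-2 conv, then k=5 frame downsampling
    audio_tokens += 2                          # +1 audio_start_token, +1 audio_end_token -> 302

Note that a sample containing any clip longer than raw_audio_seq_length is flagged for filtering rather than truncated.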
+ # move to device + audio_mel = audio_mel.to(self._tensor_space.distributed.device) + kwargs[AudioEncoderKwargs.audio_mel] = audio_mel + + # # set attention mask # TODO Toby: fix backup attention + # sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size + # sequence_q = kwargs[self._transformer_kwargs.sequence_q_dim].size + # kwargs[self._transformer_kwargs.attention_mask] = self._mask[ + # None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + # ] + # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value + # audio_mel = torch.rand(len(flattened_audio), 80, 3000) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index cdb27d9ef..8ba066cb3 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,6 +5,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert @@ -51,6 +52,11 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) + audio_encoder: AudioEncoderConfig = Field( + default_factory=AudioEncoderConfig, + desc="Configuration for the audio encoder that transforms audio into embeddings.", + hint=FieldHint.optional, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -176,6 +182,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) if self.vision_encoder.enabled: self.vision_encoder.setup_tensor_space(tensor_space) + if self.audio_encoder.enabled: + self.audio_encoder.setup_tensor_space(tensor_space) @property def num_absolute_position_embeddings(self) -> int: diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 7f09347bf..d137de5ec 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -5,6 +5,7 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather, reduce_forward, split from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs @@ -34,6 +35,7 @@ def _forward( position_ids: torch.Tensor | None, image_positions: list[torch.Tensor] | None, image_sizes: list[list[tuple[int, int]]] | None, + audio_positions: list[torch.Tensor] | None, ) -> torch.Tensor: """ Forward pass for the multi-modal embedding layer. 
@@ -61,6 +63,7 @@ def _forward( embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) + # TODO: Toby implement audio for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): @@ -155,6 +158,13 @@ def _forward( # Move to the next image in the input tensor image_embedding_offset += num_patches + audio_position_idx = 0 + for sample_idx, positions in enumerate(audio_positions): + for position in positions: + num_audio_tokens = input_.shape[1] # TODO: Toby better way to get this? + embeddings[sample_idx, position : position + num_audio_tokens] = input_[audio_position_idx] + audio_position_idx += 1 + if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) with set_generator( @@ -178,9 +188,11 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) + # TODO: How do we support both Audio and Vision? position_ids = kwargs.get(LanguageModelKwargs.position_ids) - image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) - image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes, []) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions, []) + audio_positions = kwargs.get(AudioEncoderKwargs.audio_positions, []) tokens = kwargs.get(LanguageModelKwargs.tokens) - return self._forward(input_, tokens, position_ids, image_positions, image_sizes) + return self._forward(input_, tokens, position_ids, image_positions, image_sizes, audio_positions) diff --git a/fast_llm/layers/transformer/audio_transformer.py b/fast_llm/layers/transformer/audio_transformer.py new file mode 100644 index 000000000..f0fb6d17f --- /dev/null +++ b/fast_llm/layers/transformer/audio_transformer.py @@ -0,0 +1,40 @@ +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.transformer.config import AudioTransformerDimNames, AudioTransformerKwargs, TransformerConfig +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.tensor import TensorMeta + + +class AudioTransformerLayer(TransformerLayer): + """ + A audio transformer layer to encode image patches + """ + + def __init__( + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, return_input) + + hidden_dim = self._tensor_space.get_tensor_dim(AudioTransformerDimNames.hidden) + + # use regular layernorm (not rms norm) + self.norm_1 = self._config.normalization.get_layer(hidden_dim) + self.norm_2 = self._config.normalization.get_layer(hidden_dim) + + self.norm_1 = self._config.peft.apply_other(self.norm_1) + self.norm_2 = self._config.peft.apply_other(self.norm_2) + + @property + def name(self) -> str: + return f"Audio transformer layer {self._layer_index}" + + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): + dims = kwargs[AudioTransformerKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index b8d153672..45d911a67 100644 --- 
a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -71,6 +71,10 @@ class VisionTransformerDimNames(BaseTransformerDimNames, prefix="image_encoder") pass +class AudioTransformerDimNames(BaseTransformerDimNames, prefix="audio_encoder"): + pass + + class BaseTransformerKwargs: _kwargs_attributes = { "rotary_freq_q": "rotary_freq_q", @@ -110,6 +114,10 @@ class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): patch_position_ids = "patch_position_ids" +class AudioTransformerKwargs(BaseTransformerKwargs, prefix="audio_encoder"): + pass + + class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" @@ -127,6 +135,7 @@ class RotaryEmbeddingType(str, enum.Enum): class TransformerType(str, enum.Enum): lm_decoder = "lm_decoder" image_encoder = "image_encoder" + audio_encoder = "audio_encoder" @config_class() @@ -217,6 +226,15 @@ def _transformer_kwargs(self) -> VisionTransformerKwargs: return VisionTransformerKwargs +# @config_class() +# class AudioRotaryConfig(RotaryConfig): +# type: RotaryEmbeddingType = Field( +# default=RotaryEmbeddingType.none, +# desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", +# hint=FieldHint.feature, +# ) + + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -308,7 +326,7 @@ class TransformerConfig(BaseModelConfig): _abstract = False transformer_type: TransformerType = Field( default=TransformerType.lm_decoder, - desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", hint=FieldHint.architecture, ) normalization: NormalizationConfig = Field( @@ -710,7 +728,17 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) -> None: + if type == "vision": + # TODO Soham: better way to get around circular imports? Maybe add a type class variable to TransformerConfig? + pass + + elif type == "audio": + pass + + else: + pass + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension @@ -809,7 +837,7 @@ class VisionTransformerConfig(TransformerConfig): transformer_type: TransformerType = FieldUpdate( default=TransformerType.image_encoder, - desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", hint=FieldHint.architecture, ) causal: bool = FieldUpdate( @@ -830,3 +858,39 @@ def _transformer_kwargs(self) -> VisionTransformerKwargs: @property def _transformer_dim_names(self) -> VisionTransformerDimNames: return VisionTransformerDimNames + + +@config_class() +class AudioTransformerConfig(TransformerConfig): + """ + Configuration for the Audio Transformer model. + """ + + transformer_type: TransformerType = FieldUpdate( + default=TransformerType.audio_encoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", + hint=FieldHint.architecture, + ) + causal: bool = FieldUpdate( + default=False, + desc="Use causal attention. 
Turn this off only for bidirectional attention e.g., in Audio Transformer.", + hint=FieldHint.feature, + ) + gated: bool = FieldUpdate( + default=False, + desc="MLP gating.", + hint=FieldHint.feature, + ) + # rotary: AudioRotaryConfig = FieldUpdate( + # default_factory=AudioRotaryConfig, + # desc="Configuration for the rotary positional embeddings.", + # hint=FieldHint.feature, + # ) + + @property + def _transformer_kwargs(self) -> AudioTransformerKwargs: + return AudioTransformerKwargs + + @property + def _transformer_dim_names(self) -> AudioTransformerDimNames: + return AudioTransformerDimNames diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index af1a53f68..1b436eba3 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -280,9 +280,9 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + self._create_tensors(kwargs[self._transformer_kwargs.sequence_length]) + sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size + sequence_q = kwargs[self._transformer_kwargs.sequence_q_dim].size kwargs[self._transformer_kwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index d7d32221d..ae6fc6ad8 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -68,6 +68,16 @@ class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "pixtral" +class WhisperGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "whisper" + + +class AyraAudioModelGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "ayra_audio" + audio_name: typing.ClassVar[str] = "whisper" + text_name: typing.ClassVar[str] = "llama" + + @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -150,7 +160,9 @@ class GPTModelConfig(FastLLMModelConfig): MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, LlavaGPTHuggingfaceCheckpointFormat, + WhisperGPTHuggingfaceCheckpointFormat, PixtralGPTHuggingfaceCheckpointFormat, + AyraAudioModelGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 95bbebde2..568c78080 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -27,10 +27,12 @@ from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.audio_encoder.config import AudioEncoderType from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( + AyraAudioModelGPTHuggingfaceCheckpointFormat, GPTBaseModelConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, @@ -41,6 +43,7 @@ 
PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, + WhisperGPTHuggingfaceCheckpointFormat, ) from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig from fast_llm.models.gpt.model import GPTModel @@ -564,6 +567,225 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class WhisperHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = WhisperGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + # set default layernorm + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.layer_norm + ), + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["WhisperForConditionalGeneration"] + ), + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=AudioEncoderType.whisper), + # make transformer noncasual + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", + ), + ), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", + ), + ), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", + ), + ), + export_names=(("encoder_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "head_groups", + ), + ), + export_names=(("encoder_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", + ), + ), + export_names=(("encoder_ffn_dim",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.none + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=False), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=True), + RenameParamConverter( + fast_llm_names=(("num_mel_bins",),), + export_names=(("num_mel_bins",),), + ), + ] + + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + # return [ + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.weight", f"{hf_prefix}fc1.weight"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.bias", f"{hf_prefix}fc1.bias"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}fc2.weight"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.bias", f"{hf_prefix}fc2.bias"), + # ] + transformer_config = self._model.config.base_model.audio_encoder.transformer + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}fc1", + transformer_config.add_mlp_bias, + WeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}fc2", + 
transformer_config.add_mlp_bias, + MLPLayer2Converter, + ), + ] + + def _create_audio_transformer_layer_converters( + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" + ) -> list[WeightConverter]: + # Vision transformer layer + transformer_config = self._model.config.base_model.audio_encoder.transformer + norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + name_bias_cls = [ + # Self-attn + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", + ( + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.k_proj", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.v_proj", + ), + transformer_config.add_attn_qkv_bias, # TODO Toby: add permanent fix for key bias + KeyValueWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.out_proj", + transformer_config.add_attn_dense_bias, + WeightConverter, + ), + # Norm + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn_layer_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}layers.{transformer_layer_index}.final_layer_norm", + norm_bias, + WeightConverter, + ), + ] + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + hf_prefix, + use_bias, + cls, + ) + # MLP + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}layers.{transformer_layer_index}.", + ) + return converters + + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + + # audio encoder conv + converters += [ + WeightConverter(f"layers.{offset}.conv1_weight", f"{hf_base_prefix}conv1.weight"), + WeightConverter(f"layers.{offset}.conv2_weight", f"{hf_base_prefix}conv2.weight"), + ] + + if self._model.config.base_model.audio_encoder.conv_bias: + converters += [ + WeightConverter(f"layers.{offset}.conv1_bias", f"{hf_base_prefix}conv1.bias"), + WeightConverter(f"layers.{offset}.conv2_bias", f"{hf_base_prefix}conv2.bias"), + ] + + # position embedding + converters.append( + WeightConverter(f"layers.{offset}.positional_embeddings", f"{hf_base_prefix}embed_positions.weight") + ) + + # transformer encoder layers + num_layers = self._model.config.base_model.audio_encoder.transformer.num_layers + for i in range(num_layers): + converters += self._create_audio_transformer_layer_converters(i, offset + 1, hf_base_prefix) + + offset = offset + num_layers + 1 + + # add final layernorm + if self._model.config.base_model.audio_encoder.transformer.normalization.type == NormalizationType.layer_norm: + converters += [ + WeightConverter(f"layers.{offset}.norm_1.weight", f"{hf_base_prefix}layer_norm.weight"), + WeightConverter(f"layers.{offset}.norm_2.weight", "encoder_projector.layer_norm.weight"), + WeightConverter(f"layers.{offset}.norm_1.bias", f"{hf_base_prefix}layer_norm.bias"), + WeightConverter(f"layers.{offset}.norm_2.bias", "encoder_projector.layer_norm.bias"), + ] 
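Review note: for orientation, the converters built so far map Fast-LLM's flat layer indices onto the standard HF Whisper encoder key layout. A sketch of the expected name pairs for a hypothetical 2-layer audio transformer with offset=0 and hf_base_prefix="encoder." (the values are illustrative, not exhaustive):

    expected_mapping = {
        "layers.0.conv1_weight": "encoder.conv1.weight",
        "layers.0.positional_embeddings": "encoder.embed_positions.weight",
        "layers.1.self_attn.query.weight": "encoder.layers.0.self_attn.q_proj.weight",
        "layers.1.self_attn.key_value.weight": ("encoder.layers.0.self_attn.k_proj.weight",
                                                "encoder.layers.0.self_attn.v_proj.weight"),  # fused in Fast-LLM
        "layers.1.self_attn.dense.weight": "encoder.layers.0.self_attn.out_proj.weight",
        "layers.1.norm_1.weight": "encoder.layers.0.self_attn_layer_norm.weight",
        "layers.1.norm_2.weight": "encoder.layers.0.final_layer_norm.weight",
        "layers.1.mlp.layer_1.weight": "encoder.layers.0.fc1.weight",
        "layers.1.mlp.layer_2.weight": "encoder.layers.0.fc2.weight",
        # layers.2.* maps to encoder.layers.1.* the same way, then (offset + num_layers + 1 = 3):
        "layers.3.norm_1.weight": "encoder.layer_norm.weight",
        "layers.3.norm_2.weight": "encoder_projector.layer_norm.weight",
    }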
+ + # multimodal projector + converters.extend( + [ + WeightConverter(f"layers.{offset}.layer_1.weight", "encoder_projector.linear1.weight"), + WeightConverter(f"layers.{offset}.layer_2.weight", "encoder_projector.linear2.weight"), + ] + ) + if self._model.config.base_model.audio_encoder.adapter_bias: + converters.extend( + [ + WeightConverter(f"layers.{offset}.layer_1.bias", "encoder_projector.linear1.bias"), + WeightConverter(f"layers.{offset}.layer_2.bias", "encoder_projector.linear2.bias"), + ] + ) + + return converters + + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.audio_encoder.transformer.num_layers + 2 + + class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -769,6 +991,148 @@ def num_layers(self) -> int: return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 +class AyraAudioModelHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = AyraAudioModelGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + cfg_dict = cls._load_config(config.path) + kwargs = {} + if "text_config" in cfg_dict: + text_kwargs = cls._import_config(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "audio_config" in cfg_dict: + audio_kwargs = cls._import_config(cfg_dict["audio_config"]) + audio_kwargs = {tuple(["audio_encoder"] + list(key)): value for key, value in audio_kwargs.items()} + kwargs.update(audio_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "audio_config")} + ) + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["AyraAudioModel"]), + # projector + MappedConfigParamConverter( + fast_llm_names=(("audio_encoder", "adapter_activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "adapter_size"),), + export_names=(("adapter_size",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "audio_encoder", + "aud_downsampling_k", + ), + ), + export_names=(("encoder_projector_ds_rate",),), + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in handler_cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + 
values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + # TODO Toby: implement for audio + exported_config = {} + audio_handler_class = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.audio_name) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + for converter in audio_handler_class._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("audio_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("audio_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + audio_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.audio_name) + audio_handler = audio_handler_cls(self._model) # TODO Toby: are we calling this twice? 
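Review note: the three converter loops in _export_config above produce a nested HF-style config, with audio-encoder parameters under "audio_config", language-model parameters under "text_config", and the projector settings at the top level. A hedged sketch of the resulting structure; the field values below are placeholders, not taken from a real checkpoint:

    exported = {
        "architectures": ["AyraAudioModel"],
        "activation_function": "relu",        # adapter activation (placeholder)
        "adapter_size": 2048,                 # placeholder
        "encoder_projector_ds_rate": 5,       # maps to audio_encoder.aud_downsampling_k
        "audio_config": {
            "architectures": ["WhisperForConditionalGeneration"],
            "d_model": 768,                   # placeholder
            "num_hidden_layers": 12,          # placeholder
            "encoder_attention_heads": 12,    # placeholder
            "num_mel_bins": 80,
        },
        "text_config": {
            # llama-style fields emitted by the text handler (hidden_size, num_hidden_layers, ...)
        },
    }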
+ converters = audio_handler._create_weight_converters(hf_base_prefix="encoder.", offset=0) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="llm.", offset=audio_handler.num_layers) + ) + return converters + + class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -1035,6 +1399,7 @@ class AutoGPTHuggingfaceCheckpointHandler( MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + WhisperGPTHuggingfaceCheckpointFormat.name: WhisperHuggingfaceCheckpointHandler, PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, - # MultiModalGPTHuggingfaceCheckpointFormat.name: MultiModalHuggingfaceCheckpointHandler + AyraAudioModelGPTHuggingfaceCheckpointFormat.name: AyraAudioModelHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 45cf4a4fe..48f5760b6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,12 +10,19 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.audio_encoder.adapter import AudioAdapter +from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs +from fast_llm.layers.audio_encoder.encoder import AudioConv +from fast_llm.layers.audio_encoder.preprocessing import AudioPreprocessor from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding +from fast_llm.layers.transformer.audio_transformer import AudioTransformerLayer from fast_llm.layers.transformer.config import ( + AudioTransformerDimNames, + AudioTransformerKwargs, RoutingType, TransformerDimNames, TransformerKwargs, @@ -84,6 +91,8 @@ def __init__( self._preprocessors.append( RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) + if self._config.audio_encoder.enabled: + self._preprocessors.append(AudioPreprocessor(self._config.audio_encoder, self._tensor_space)) def get_output_layers(self) -> list[Layer]: layers = [] @@ -122,12 +131,33 @@ def get_vision_layers(self) -> list[Layer]: MultiModalEmbedding(self._config, self._tensor_space), ] + def get_audio_layers(self) -> list[Layer]: + audio_conv = AudioConv(self._config.audio_encoder, self._tensor_space) + audio_layers = [ + AudioTransformerLayer(self._config.audio_encoder.transformer, self._tensor_space, layer_index=idx + 1) + for idx in range(self._config.audio_encoder.transformer.num_layers) + ] + return [ + audio_conv, + *audio_layers, + AudioAdapter(self._config.audio_encoder, self._tensor_space), + 
MultiModalEmbedding(self._config, self._tensor_space), + ] + + def get_multimodal_layers(self) -> list[Layer]: + if self._config.vision_encoder.enabled: + return self.get_vision_layers() + elif self._config.audio_encoder.enabled: + return self.get_audio_layers() + else: + assert False + def get_layers(self) -> list[Layer]: return [ *( [LanguageModelEmbedding(self._config, self._tensor_space)] - if not self._config.vision_encoder.enabled - else self.get_vision_layers() + if not self._config.vision_encoder.enabled and not self._config.audio_encoder.enabled + else self.get_multimodal_layers() ), *[ TransformerLayer( @@ -189,6 +219,18 @@ def preprocess_meta( else: vision_kwargs = {} + if self._config.audio_encoder.enabled: + audio_kwargs = { + AudioEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( + AudioTransformerDimNames.kv_channels + ).size, + AudioEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( + AudioEncoderKwargs.out_channels + ).size, + } + else: + audio_kwargs = {} + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) @@ -244,6 +286,22 @@ def preprocess_meta( } ) + if self._config.audio_encoder.enabled: + audio_hidden_dim = self._tensor_space.get_tensor_dim(AudioTransformerDimNames.hidden) + audio_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, audio_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, audio_hidden_dim) + ) + audio_kwargs.update( + { + AudioTransformerKwargs.hidden_dims: audio_hidden_dims, + AudioTransformerKwargs.sequence_length: 1500, # TODO: Toby Parameterize + AudioTransformerKwargs.sequence_k_dim: 1500, + AudioTransformerKwargs.sequence_q_dim: 1500, + } + ) + common_kwargs = { LanguageModelKwargs.phase: phase, TransformerKwargs.sequence_first: sequence_first, @@ -253,6 +311,7 @@ def preprocess_meta( TransformerKwargs.micro_batch_size: micro_batch_size, } common_kwargs.update(vision_kwargs) + common_kwargs.update(audio_kwargs) sequence_k_pasts = range( sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, @@ -385,7 +444,7 @@ def preprocess( if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() - for i, spans in enumerate(batch.loss_masking_spans): + for idx, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue valid_spans = spans[ @@ -398,9 +457,9 @@ def preprocess( loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, i] = False + loss_mask[start : end + 1, idx] = False else: - loss_mask[i, start : end + 1] = False + loss_mask[idx, start : end + 1] = False if self._config.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) @@ -425,10 +484,22 @@ def preprocess( ) kwargs[LanguageModelKwargs.tokens] = tokens + if self._config.audio_encoder.enabled: + if batch.audio is not None: + kwargs[AudioEncoderKwargs.audio] = [ + [aud.to(device="cpu", dtype=torch.float32, non_blocking=True) for aud in audio] + for audio in batch.audio + ] + kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions + kwargs[LanguageModelKwargs.tokens] = tokens + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) - if image_patches is not None: + 
audio_mel = kwargs.get(AudioEncoderKwargs.audio_mel, None) + if audio_mel is not None: + preprocessed.append((audio_mel, kwargs)) + elif image_patches is not None: preprocessed.append((image_patches, kwargs)) else: preprocessed.append((tokens, kwargs)) @@ -447,6 +518,8 @@ def transformer_layers(self) -> list[TransformerLayer]: def embedding_layer_index(self) -> int: if self._config.vision_encoder.enabled: return self._config.vision_encoder.transformer.num_layers + 2 + elif self._config.audio_encoder.enabled: + return self._config.audio_encoder.transformer.num_layers + 2 else: return 0 diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index a4f0b0b42..b4a3036fe 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -41,6 +41,16 @@ def _get_sampling_parameters( "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, } ) + if self._config.model.base_model.audio_encoder.enabled: + parameters.update( + { + "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, + "aud_padding_duration": self._config.batch.aud_padding_duration, + "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, + "audio_start_token": self._config.model.base_model.audio_encoder.audio_start_token, + "audio_end_token": self._config.model.base_model.audio_encoder.audio_end_token, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]:
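Review note: as a final sanity check on the layer indexing introduced here, the audio path in model.py and the checkpoint converters have to agree on where the multi-modal embedding sits. A small sketch, assuming a hypothetical 12-layer audio transformer:

    # get_audio_layers() produces, in order:
    #   index 0          AudioConv
    #   indices 1..12    AudioTransformerLayer (num_layers = 12)
    #   index 13         AudioAdapter
    #   index 14         MultiModalEmbedding
    audio_num_layers = 12
    embedding_layer_index = audio_num_layers + 2        # 14, matching GPTBaseModel.embedding_layer_index

    # WhisperHuggingfaceCheckpointHandler.num_layers counts the conv and projector layers on top:
    whisper_handler_num_layers = audio_num_layers + 2   # 14
    # The Ayra handler therefore maps the text ("llm.") weights starting at fast-llm layer offset 14,
    # so the word embeddings land on the MultiModalEmbedding layer and the decoder layers follow it.

Any mismatch between these two "+ 2" offsets would silently shift every text-side weight by one layer during conversion, so it seems worth keeping both derived from the same config value.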