From f43824fca07b5b1fb2daa6b1fc1bcf7e31a8d0bc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Apr 2024 22:53:20 -0400 Subject: [PATCH 01/42] wip for v0 for testing --- pyproject.toml | 33 ++++++ requirements.txt | 6 + src/__init__.py | 0 src/train.py | 172 ++++++++++++++++++++++++++++ src/voltronformer/__init__.py | 0 src/voltronformer/config.py | 26 +++++ src/voltronformer/core.py | 4 + src/voltronformer/mod.py | 50 ++++++++ src/voltronformer/model.py | 138 ++++++++++++++++++++++ src/voltronformer/train/__init__.py | 0 src/voltronformer/train/data.py | 96 ++++++++++++++++ src/voltronformer/utils.py | 38 ++++++ 12 files changed, 563 insertions(+) create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 src/__init__.py create mode 100644 src/train.py create mode 100644 src/voltronformer/__init__.py create mode 100644 src/voltronformer/config.py create mode 100644 src/voltronformer/core.py create mode 100644 src/voltronformer/mod.py create mode 100644 src/voltronformer/model.py create mode 100644 src/voltronformer/train/__init__.py create mode 100644 src/voltronformer/train/data.py create mode 100644 src/voltronformer/utils.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6db06c3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[project] +name = "voltronformers" +dynamic = ["version"] +requires-python = ">= 3.10" +dependencies = [ + "schedulefree", + "bitsandbytes", + "datasets", + "einops", + "flash-attn", + "wandb", + "tqdm", + "transformers==4.49.3", + "torch==2.2.1", + "axolotl @ git+https://github.com/openaccess-ai-collective/axolotl.git@main", +] +maintainers = [ + {name="Wing Lian", email="wing.lian@gmail.com"}, +] +description = "voltronformers: Assembling the best SotA AI techniques into a unified model" + +[project.optional-dependencies] +dev = [ + "tox", + "pre-commit", + "black", + "mypy", + "pytest", +] + +[build-system] +requires = ["flit_core<4"] +build-backend = "flit_core.buildapi" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4263a26 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +schedulefree +wandb +torch==2.2.1 +transformers==4.39.3 +datasets +axolotl @ git+https://github.com/openaccess-ai-collective/axolotl.git@main \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..e6c04d5 --- /dev/null +++ b/src/train.py @@ -0,0 +1,172 @@ +import functools +import os +from dataclasses import dataclass +from functools import partial +from typing import Optional + +import torch +import wandb +from accelerate import Accelerator, PartialState +from datasets import load_dataset +from schedulefree import AdamWScheduleFree +from torch.utils.data import DataLoader, RandomSampler +from tqdm import tqdm +from transformers import AutoTokenizer + +from src.voltronformer.config import tiny +from src.voltronformer.model import CausalLM +from src.voltronformer.train.data import wrap_pretraining_dataset +from src.voltronformer.utils import device_get_cuda, device_get_local_rank, get_cosine_schedule_with_min_lr_lambda + + +@dataclass +class TrainingArguments: + gradient_accumulation_steps: int = 1 + max_steps_per_epoch: Optional[int] = None + log_steps: int = 1 + output_dir: Optional[str] = None + weight_decay: float = 0.0 + warmup_steps: Optional[int] = 1000 + per_gpu_train_batch_size: Optional[int] = 1 + save_steps: Optional[int] = 
5_000 + max_sequence_length: Optional[int] = 8192 + +class Trainer: + def __init__(self, model, args, dataloader, accelerator): + self.args = args + self._model = model + self.build_optimizer_and_scheduler() + self._model, self.dataloader, self.optimizer = accelerator.prepare(self._model, dataloader, self.optimizer) + + self.device = device_get_cuda() + self.global_step = 0 + self.rank = device_get_local_rank() + wandb.init() + self.accelerator = accelerator + + def build_optimizer_and_scheduler(self): + self.optimizer = AdamWScheduleFree(self._model.parameters(), lr=self.args.learning, weight_decay=self.args.weight_decay, warmup_steps=self.args.weight_decay) + self.lr_scheduler = None + + def _loss_fn(self, logits, labels): + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1)) + return loss + + def save_checkpoint(self): + output_dir = self.args.output_dir if self.args.output_dir is not None else "." + torch.save( + self._model.state_dict(), + os.path.join(output_dir, f"model_{self.global_step}.pt"), + ) + + def train(self): + self._model.train() + try: + self.optimizer.train() + except: + pass + + def train_loop(self, dataloader, rank): + for idx, batch in enumerate(pbar := tqdm(dataloader, disable=not (rank == 0))): + if ( + self.args.max_steps_per_epoch is not None + and (idx // self.args.gradient_accumulation_steps) + == self.args.max_steps_per_epoch + ): + break + + input_ids = batch["input_ids"].to(self.device) + labels = batch["labels"].to(self.device) + + logits = self._model(input_ids) + + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + + # Compute loss + loss = self._loss_fn(logits, labels) + + if ( + self.global_step * self.args.log_steps == 0 + and self.rank == 0 + ): + pbar.set_description(f"Loss: {loss.item()}") + wandb.log({"loss": loss.item(), "global_step": self.global_step}) + + loss = loss / self.args.gradient_accumulation_steps + loss.backward() + + if (idx + 1) % self.args.gradient_accumulation_steps == 0: + self.optimizer.step() + if self.lr_scheduler: + self.lr_scheduler.step() + self.optimizer.zero_grad(set_to_none=True) + self.global_step += 1 + + if self.global_step % self.args.save_steps == 0: + self.save_checkpoint() + + +def main(): + state = PartialState() + + ds = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True) + args = TrainingArguments( + gradient_accumulation_steps=1, + max_steps_per_epoch=None, + log_steps=1, + output_dir="./out", + weight_decay=0.0, + warmup_steps=1000, + per_gpu_train_batch_size=1, + save_steps=10000, + ) + os.makedirs(args.output_dir, exist_ok=True) + + config = tiny() + model = CausalLM(config) + dataloader = DataLoader(ds) + tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base") + + def tokenize_function(examples, tokenizer=None): + outputs = tokenizer(examples["text"], truncation=True, max_length=None) + return outputs + + with state.main_process_first(): + ds_wrapper_partial = functools.partial( + tokenize_function, + tokenizer=tokenizer, + remove_columns=["text", "meta"], + ) + + train_dataset = wrap_pretraining_dataset( + ds, + tokenizer, + ds_wrapper_partial, + max_tokens=args.max_sequence_length, + batch_size=args.per_gpu_train_batch_size, + buffer_size=10_000, + ) + # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 + train_dataset = 
train_dataset.with_format("torch") + + accelerator = Accelerator() + + dataloader_params = dict( + sampler=RandomSampler(train_dataset), + batch_size=args.per_gpu_train_batch_size, + num_workers=8, + pin_memory=True, + drop_last=True, + collate_fn=None, + ) + dataloader = DataLoader(train_dataset, **dataloader_params) + + trainer = Trainer(model, args, dataloader, accelerator) + trainer.train_loop(dataloader, rank=0) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/voltronformer/__init__.py b/src/voltronformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py new file mode 100644 index 0000000..d2e0e1b --- /dev/null +++ b/src/voltronformer/config.py @@ -0,0 +1,26 @@ +def tiny(): + return { + "hidden_size": 1024, + "intermediate_size": 2816, + "max_position_embeddings": 4096, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 24, + "vocab_size": 100352, + "dwa_dilation": 4, + "dwa_period": 5, + "pad_token_id": 100277, # dbrx <{|pad|}> + "mod_every": 2, + "mod_capacity_factor": 0.125, + } + + +def medium(): + return { + "hidden_size": 1024, + "intermediate_size": 2816, + + "max_position_embeddings": 32768, + + "vocab_size": 100352 + } \ No newline at end of file diff --git a/src/voltronformer/core.py b/src/voltronformer/core.py new file mode 100644 index 0000000..2cd0d5e --- /dev/null +++ b/src/voltronformer/core.py @@ -0,0 +1,4 @@ +try: + from bitnet.bitlinear import BitLinear as Linear +except ImportError: + from torch.nn import Linear diff --git a/src/voltronformer/mod.py b/src/voltronformer/mod.py new file mode 100644 index 0000000..ac3a397 --- /dev/null +++ b/src/voltronformer/mod.py @@ -0,0 +1,50 @@ +""" +from https://github.com/epfml/llm-baselines/compare/main...mixture_of_depth +""" +import torch +from torch import nn + + +class MoDBlock(nn.Module): + def __init__(self, config, block_class): + super().__init__() + self.config = config + self.block = block_class(config) + self.router = nn.Linear(config.hidden_size, 1, bias=False) + self.capacity_factor = config.mod_capacity_factor + self.top_k =int(self.capacity_factor * config.max_position_embeddings) + + def forward(self, x, **kwargs): + # [batch_size, sequence_length, n_embd] + B, T, C = x.shape + # inference time optimization: sequence length can + # be smaller than seq len during training + top_k = min(self.top_k, int(self.capacity_factor * T)) + + """STEP 1: get logits and top_k tokens""" + # [batch_size, sequence_length, 1] + router_logits = self.router(x) + # weights and selected tokens: [batch_size, top_k, 1] + weights, selected_tokens = torch.topk(router_logits, top_k, dim=1, sorted=False) + # IMPORTANT: need to sort indices to keep causal order for those tokens that + # are processed in a block + selected_tokens, index = torch.sort(selected_tokens, dim=1) + weights = torch.gather(weights, dim=1, index=index) + + """STEP 2: expand indices to process batches with _reduced_ seqlen""" + # We need to expand indices' dimensions from + # [batch_size, top_k, 1] to [batch_size, top_k, n_embd] for gathering + indices_expanded = selected_tokens.expand(-1, -1, C) + # [batch_size, top_k, n_embd] + top_k_tokens = torch.gather(x, 1, indices_expanded) + top_k_tokens_processed = self.block(top_k_tokens, **kwargs) + + """STEP 3: combine results""" + x = torch.scatter_add( + x, + dim=1, + index=indices_expanded, + src=top_k_tokens_processed * weights, + ) + + return x \ No newline at end of file diff 
--git a/src/voltronformer/model.py b/src/voltronformer/model.py new file mode 100644 index 0000000..122b701 --- /dev/null +++ b/src/voltronformer/model.py @@ -0,0 +1,138 @@ +import functools +from typing import List, Optional, Callable + +import torch +from torch import nn +from denseformer import DWAModules +from torch.utils.checkpoint import checkpoint + +from .core import Linear +from .mod import MoDBlock + + +class FeedForward(nn.Module): + def __init__(self, gate_proj: Linear, down_proj: Linear, up_proj: Linear): + super().__init__() + self.gate_proj = gate_proj + self.down_proj = down_proj + self.up_proj = up_proj + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Attention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + + +class RMSNorm(nn.Module): + """copied from torchtune""" + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fp32 = x.float() + x_normed = ( + x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) + ).type_as(x) + return x_normed * self.scale + +def mlp(dim: int, hidden_dim: int) -> FeedForward: + """ + Build the MLP layer associated with the Llama model. + """ + gate_proj = Linear(dim, hidden_dim, bias=False) + down_proj = Linear(hidden_dim, dim, bias=False) + up_proj = Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) + + +class TransformerDecoderBlock(nn.Module): + + def __init__(self, config): + super().__init__() + self.attn = Attention(config.hidden_size, config.num_attention_heads) + self.mlp = mlp(config.hidden_size, config.intermediate_size) + + + +class CheckpointingMixin(nn.Module): + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": False} + + gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + + def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): + if hasattr(self, "gradient_checkpointing"): + self._gradient_checkpointing_func = gradient_checkpointing_func + self.gradient_checkpointing = enable + + +class Transformer(CheckpointingMixin): + supports_gradient_checkpointing = True + + def __init__(self, config): + super().__init__() + self.config = config + self.dwa_modules = DWAModules(config.num_hidden_layers, config.dwa_dilation, config.dwa_period) + self.wte = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + + self.h = nn.ModuleList([ + ( + MoDBlock(config, TransformerDecoderBlock) + if i % self.config.mod_every == 0 + else TransformerDecoderBlock(config) + ) + for i in range(config.num_hidden_layers) + ]) + self.ln_f = RMSNorm(config.hidden_size, eps=1e-6) + self.gradient_checkpointing = False + + def forward(self, x): + x = self.wte(x) + self.dwa_modules.init_accumulators(x) + for i, decoder_layer in enumerate(self.h): + # gradient checkpointing + if self.gradient_checkpointing and self.training: + x = self._gradient_checkpointing_func( + decoder_layer, + x, + ) + else: + 
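+                # checkpointing disabled (or eval mode): run the block eagerly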
x = decoder_layer(x) + x = self.dwa_modules(x, block_idx=i) + x = self.ln_f(x) + return x + + +class CausalLM(nn.Module): + def __init__(self, config): + super().__init__() + self.transformer = Transformer(config) + self.vocab_size = config.vocab_size + # should this use a BitLinear layer? + self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # tie weights + self.transformer.wte.weight = self.embed_out.weight + + def forward(self, x): + x = self.transformer(x) + logits = self.embed_out(x) + + return logits.float() + + def train(self, mode: bool = True): + """ + Override the default train() to enable gradient checkpointing. + """ + if mode: + self.transformer.gradient_checkpointing_enable() + return super().train(mode) diff --git a/src/voltronformer/train/__init__.py b/src/voltronformer/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py new file mode 100644 index 0000000..649d1a1 --- /dev/null +++ b/src/voltronformer/train/data.py @@ -0,0 +1,96 @@ +import functools +from collections import defaultdict +from typing import Callable, Dict, List, Optional + +from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths +from datasets import Dataset +from torch.utils.data import RandomSampler + + +def wrap_pretraining_dataset( + dataset, + tokenizer, + ds_wrapper_fn, + max_tokens=2048, + batch_size=1, + buffer_size=10_000, +): + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding=True, + pad_to_multiple_of=max_tokens, + multipack_attn=False, + ) + encode = functools.partial( + encode_packed_pretraining, + collate_fn, + ds_wrapper_fn, + max_seq_length=max_tokens, + batch_size=batch_size, + multipack_attn=False, + ) + + # remove all the existing columns after mapping since they end up having + # a different length than the encoded/tokenized column + # this is empty during streaming/pretraining + remove_columns = [] + if dataset.features is None: + for first_row in dataset: + remove_columns = first_row.keys() + break + else: + remove_columns = dataset.features.keys() + + dataset = dataset.map( + encode, + batched=True, + batch_size=buffer_size, + remove_columns=remove_columns, + ) + return dataset + + +def encode_packed_pretraining( + collate_fn, + ds_wrapper: Callable, + examples: Dict[str, List], + max_seq_length: int = 2048, + batch_size: int = 4, + multipack_attn: Optional[bool] = False, +) -> Dict[str, List]: + # pylint: disable=duplicate-code + # tokenize all the examples + # rows get split with stride (overlap) + train_dataset = ds_wrapper(Dataset.from_dict(examples))[0] + + sampler = MultipackBatchSampler( + RandomSampler(train_dataset), + batch_size=1, + drop_last=True, + batch_max_len=batch_size * max_seq_length, + lengths=get_dataset_lengths(train_dataset), + ) + + chunked_data = defaultdict(list) + + for batch in sampler: + for data in batch: + features = train_dataset[data] + if "num_truncated_tokens" in features: + del features["num_truncated_tokens"] + if "num_truncated_tokens" in features: + del features["num_truncated_tokens"] + if "overflow_to_sample_mapping" in features: + del features["overflow_to_sample_mapping"] + if "labels" not in features: + features["labels"] = features["input_ids"].copy() + collated_features = collate_fn(features) + + for feature in features.keys(): + if feature == "length": + continue + 
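+                    # squeeze(0) drops the leading batch dimension added by the
+                    # collator, leaving one packed sequence tensor per feature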
chunked_data[feature].append(collated_features[feature].squeeze(0)) + + return chunked_data diff --git a/src/voltronformer/utils.py b/src/voltronformer/utils.py new file mode 100644 index 0000000..ad1c860 --- /dev/null +++ b/src/voltronformer/utils.py @@ -0,0 +1,38 @@ +import math +import os + +import torch + + +def device_get_local_rank(): + """ + Returns the local rank of the current device. + """ + local_rank = int(os.getenv("LOCAL_RANK", 0)) + return local_rank + + +def device_get_cuda(): + rank = device_get_local_rank() + device = torch.device(type="cuda", index=rank) + torch.cuda.set_device(device) + return device + + +def get_cosine_schedule_with_min_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + min_lr_ratio: float, +): + # Warm up + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + # Cosine learning rate decay + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + scaling = 0.5 * (1.0 + math.cos(math.pi * progress)) + return (1 - min_lr_ratio) * scaling + min_lr_ratio From b7cd6eab4d776ff74b110417e5e006652fb263b3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Apr 2024 22:59:40 -0400 Subject: [PATCH 02/42] mini fixes --- src/train.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/train.py b/src/train.py index e6c04d5..39971b8 100644 --- a/src/train.py +++ b/src/train.py @@ -44,6 +44,15 @@ def __init__(self, model, args, dataloader, accelerator): wandb.init() self.accelerator = accelerator + @property + def model_num_parameters(self): + all_param = 0 + for _, param in self._model.named_parameters(): + num_params = param.numel() + all_param += num_params + + return all_param + def build_optimizer_and_scheduler(self): self.optimizer = AdamWScheduleFree(self._model.parameters(), lr=self.args.learning, weight_decay=self.args.weight_decay, warmup_steps=self.args.weight_decay) self.lr_scheduler = None @@ -60,12 +69,13 @@ def save_checkpoint(self): os.path.join(output_dir, f"model_{self.global_step}.pt"), ) - def train(self): + def train(self, dataloader, rank): self._model.train() try: self.optimizer.train() except: pass + self.train_loop(dataloader, rank) def train_loop(self, dataloader, rank): for idx, batch in enumerate(pbar := tqdm(dataloader, disable=not (rank == 0))): @@ -128,7 +138,6 @@ def main(): config = tiny() model = CausalLM(config) - dataloader = DataLoader(ds) tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base") def tokenize_function(examples, tokenizer=None): @@ -166,6 +175,7 @@ def tokenize_function(examples, tokenizer=None): dataloader = DataLoader(train_dataset, **dataloader_params) trainer = Trainer(model, args, dataloader, accelerator) + print("Total number of parameters: ", trainer.model_num_parameters) trainer.train_loop(dataloader, rank=0) if __name__ == "__main__": From 9b565253e0ef0998db5072d4ed1a502403d632bd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Apr 2024 23:24:52 -0400 Subject: [PATCH 03/42] fix install and train --- pyproject.toml | 8 ++------ src/train.py => train.py | 8 ++++++-- 2 files changed, 8 insertions(+), 8 deletions(-) rename src/train.py => train.py (96%) diff --git a/pyproject.toml b/pyproject.toml index 6db06c3..012ff9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,9 +10,9 @@ dependencies = [ "flash-attn", "wandb", "tqdm", - "transformers==4.49.3", - "torch==2.2.1", + "transformers==4.39.3", "axolotl @ 
git+https://github.com/openaccess-ai-collective/axolotl.git@main", + "denseformer @ git+https://github.com/epfml/DenseFormer.git@main", ] maintainers = [ {name="Wing Lian", email="wing.lian@gmail.com"}, @@ -27,7 +27,3 @@ dev = [ "mypy", "pytest", ] - -[build-system] -requires = ["flit_core<4"] -build-backend = "flit_core.buildapi" diff --git a/src/train.py b/train.py similarity index 96% rename from src/train.py rename to train.py index 39971b8..0b4edc6 100644 --- a/src/train.py +++ b/train.py @@ -41,7 +41,7 @@ def __init__(self, model, args, dataloader, accelerator): self.device = device_get_cuda() self.global_step = 0 self.rank = device_get_local_rank() - wandb.init() + wandb.init(project="voltronformer") self.accelerator = accelerator @property @@ -87,7 +87,10 @@ def train_loop(self, dataloader, rank): break input_ids = batch["input_ids"].to(self.device) - labels = batch["labels"].to(self.device) + if "labels" in batch.keys(): + labels = batch["labels"].to(self.device) + else: + labels = input_ids.clone() logits = self._model(input_ids) @@ -178,5 +181,6 @@ def tokenize_function(examples, tokenizer=None): print("Total number of parameters: ", trainer.model_num_parameters) trainer.train_loop(dataloader, rank=0) + if __name__ == "__main__": main() \ No newline at end of file From a357c7fcc25e34761b9ff1668652a28f938c951b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Apr 2024 23:38:31 -0400 Subject: [PATCH 04/42] fix config and dataset --- src/voltronformer/config.py | 10 ++++++---- train.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index d2e0e1b..80d8261 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -1,5 +1,7 @@ +from axolotl.utils.dict import DictDefault + def tiny(): - return { + return DictDefault({ "hidden_size": 1024, "intermediate_size": 2816, "max_position_embeddings": 4096, @@ -12,15 +14,15 @@ def tiny(): "pad_token_id": 100277, # dbrx <{|pad|}> "mod_every": 2, "mod_capacity_factor": 0.125, - } + }) def medium(): - return { + return DictDefault({ "hidden_size": 1024, "intermediate_size": 2816, "max_position_embeddings": 32768, "vocab_size": 100352 - } \ No newline at end of file + }) \ No newline at end of file diff --git a/train.py b/train.py index 0b4edc6..1b8f834 100644 --- a/train.py +++ b/train.py @@ -123,10 +123,21 @@ def train_loop(self, dataloader, rank): self.save_checkpoint() +def get_ds(): + return load_dataset("togethercomputer/RedPajama-Data-V2", + name="default", + partition="head_middle", + snapshots=["2023-14"], + languages=["en"], + split="train", + streaming=True, + ) + # load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True) + def main(): state = PartialState() - ds = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True) + ds = get_ds() args = TrainingArguments( gradient_accumulation_steps=1, max_steps_per_epoch=None, From 1bfbdb5651e79ef5a8c5f7182947e65c73359784 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Apr 2024 23:42:45 -0400 Subject: [PATCH 05/42] fix lr --- train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 1b8f834..9a1762b 100644 --- a/train.py +++ b/train.py @@ -30,6 +30,7 @@ class TrainingArguments: per_gpu_train_batch_size: Optional[int] = 1 save_steps: Optional[int] = 5_000 max_sequence_length: Optional[int] = 8192 + learning_rate: float = 5e-5 class Trainer: def __init__(self, model, args, dataloader, accelerator): 
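AdamWScheduleFree folds warmup and iterate averaging into the optimizer itself, so no external LR scheduler is needed, but it has to be switched between train and eval modes together with the model (which is why Trainer.train() wraps optimizer.train() above). A minimal usage sketch, assuming the schedulefree API:

    from schedulefree import AdamWScheduleFree

    optimizer = AdamWScheduleFree(model.parameters(), lr=5e-5,
                                  weight_decay=0.0, warmup_steps=1000)

    model.train(); optimizer.train()   # switch both before training steps
    # ... forward / backward / optimizer.step() ...
    model.eval(); optimizer.eval()     # switch back before eval or checkpointing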
@@ -54,7 +55,7 @@ def model_num_parameters(self): return all_param def build_optimizer_and_scheduler(self): - self.optimizer = AdamWScheduleFree(self._model.parameters(), lr=self.args.learning, weight_decay=self.args.weight_decay, warmup_steps=self.args.weight_decay) + self.optimizer = AdamWScheduleFree(self._model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay, warmup_steps=self.args.weight_decay) self.lr_scheduler = None def _loss_fn(self, logits, labels): @@ -179,7 +180,7 @@ def tokenize_function(examples, tokenizer=None): accelerator = Accelerator() dataloader_params = dict( - sampler=RandomSampler(train_dataset), + sampler=None, batch_size=args.per_gpu_train_batch_size, num_workers=8, pin_memory=True, From 072b74a3b203ca6b7d07a7cfefc0c5d92678f665 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Apr 2024 23:48:10 -0400 Subject: [PATCH 06/42] fix text field of dataset --- train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 9a1762b..e55ea39 100644 --- a/train.py +++ b/train.py @@ -132,13 +132,13 @@ def get_ds(): languages=["en"], split="train", streaming=True, - ) + ), "raw_content" # load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True) def main(): state = PartialState() - ds = get_ds() + ds, text_field = get_ds() args = TrainingArguments( gradient_accumulation_steps=1, max_steps_per_epoch=None, @@ -155,15 +155,15 @@ def main(): model = CausalLM(config) tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base") - def tokenize_function(examples, tokenizer=None): - outputs = tokenizer(examples["text"], truncation=True, max_length=None) + def tokenize_function(examples, field="text", tokenizer=None): + outputs = tokenizer(examples[field], truncation=True, max_length=None) return outputs with state.main_process_first(): ds_wrapper_partial = functools.partial( tokenize_function, tokenizer=tokenizer, - remove_columns=["text", "meta"], + field=text_field, ) train_dataset = wrap_pretraining_dataset( @@ -190,7 +190,7 @@ def tokenize_function(examples, tokenizer=None): dataloader = DataLoader(train_dataset, **dataloader_params) trainer = Trainer(model, args, dataloader, accelerator) - print("Total number of parameters: ", trainer.model_num_parameters) + print(f"Total number of parameters: {trainer.model_num_parameters:_}") trainer.train_loop(dataloader, rank=0) From 38a872d688a67593b75fdab61f4c41bdfc042bf5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Apr 2024 23:57:02 -0400 Subject: [PATCH 07/42] improve data handling --- src/voltronformer/train/data.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py index 649d1a1..d45fdca 100644 --- a/src/voltronformer/train/data.py +++ b/src/voltronformer/train/data.py @@ -2,12 +2,19 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional +import numpy as np from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq -from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths +from axolotl.utils.samplers import MultipackBatchSampler from datasets import Dataset from torch.utils.data import RandomSampler +def get_dataset_lengths(dataset): + input_ids = dataset.data.column("input_ids") + lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) + return lengths + + def wrap_pretraining_dataset( dataset, tokenizer, @@ -29,7 +36,6 @@ def 
wrap_pretraining_dataset( ds_wrapper_fn, max_seq_length=max_tokens, batch_size=batch_size, - multipack_attn=False, ) # remove all the existing columns after mapping since they end up having @@ -58,7 +64,6 @@ def encode_packed_pretraining( examples: Dict[str, List], max_seq_length: int = 2048, batch_size: int = 4, - multipack_attn: Optional[bool] = False, ) -> Dict[str, List]: # pylint: disable=duplicate-code # tokenize all the examples From 6c0e92b15385190b813e6863ce30e2dc0ea030e8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 00:40:31 -0400 Subject: [PATCH 08/42] flesh out the model w/ attn --- pyproject.toml | 1 + src/voltronformer/config.py | 19 +++++++ src/voltronformer/model.py | 88 ++++++++++++++++++++++++++++----- src/voltronformer/train/data.py | 2 +- train.py | 4 +- 5 files changed, 101 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 012ff9e..a64aafa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ name = "voltronformers" dynamic = ["version"] requires-python = ">= 3.10" dependencies = [ + "bitnet", "schedulefree", "bitsandbytes", "datasets", diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index 80d8261..a6af841 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -1,6 +1,25 @@ from axolotl.utils.dict import DictDefault def tiny(): + return DictDefault({ + "hidden_size": 768, + "intermediate_size": 2112, + "rope_theta": 10_000, + "max_position_embeddings": 2048, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 24, + "vocab_size": 100352, + "dwa_dilation": 4, + "dwa_period": 5, + "pad_token_id": 100277, # dbrx <{|pad|}> + "mod_every": 2, + "mod_capacity_factor": 0.125, + "rms_norm_eps": 0.000001, + }) + + +def small(): return DictDefault({ "hidden_size": 1024, "intermediate_size": 2816, diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 122b701..329c29d 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -1,10 +1,13 @@ import functools -from typing import List, Optional, Callable +from typing import List, Optional, Callable, Tuple import torch -from torch import nn +from bitnet.bit_attention import scaled_dot_product_gqa, BitMGQA +from functorch.einops import rearrange +from torch import nn, Tensor from denseformer import DWAModules from torch.utils.checkpoint import checkpoint +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from .core import Linear from .mod import MoDBlock @@ -53,13 +56,69 @@ def mlp(dim: int, hidden_dim: int) -> FeedForward: return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) +class LlamaBitMGQA(BitMGQA): + def __init__(self, embed_dim, query_heads, max_position_embeddings=2048, rope_theta=10_000, *args, **kwargs): + self.head_dim = embed_dim // query_heads + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta) + super().__init__(embed_dim, query_heads, *args, **kwargs) + + def forward( + self, + x: Tensor, + need_weights: bool = False, + # attn_mask: Optional[Tensor] = None, + is_causal: bool = False, + average_attn_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + # Input shape: (b, n, d) + q: Tensor = self.q_proj(x) + k: Tensor = self.k_proj(x) + v: Tensor = self.v_proj(x) + + # Unfold 'd' dimension into 'h' separate attention heads. 
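+        # e.g. with the tiny config at this point (hidden_size=768, 32 query
+        # heads, 8 kv heads): head_dim = 768 // 32 = 24, so q becomes
+        # (b, n, 32, 24), while k and v (projected to 8 * 24 = 192 features
+        # by BitMGQA's grouped projections) become (b, n, 8, 24)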
+ q = rearrange(q, "b n (h d) -> b n h d", h=self.query_heads) + k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads) + v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads) + # Apply attention, then fold 'h' attention heads back into 'd'. + output, attn_weights = scaled_dot_product_gqa( + query=q, + key=k, + value=v, + # TODO + # mask=attn_mask, + is_causal=is_causal, + need_weights=need_weights, + average_attn_weights=average_attn_weights, + force_grouped=False, + ) + output = rearrange(output, "b n h d -> b n (h d)") + + # NOTE: This is different from 'nn.MultiheadAttention'! We follow the MAGNETO + # architecture (https://arxiv.org/pdf/2210.06423.pdf), which applies an extra + # layer norm before the linear output projection. The cross-attention layer in + # the MAGNETO decoder does not include this layer norm, so users have the + # option to disable it (layer_norm=False). + if self.layer_norm: + assert self.norm is not None + output = self.norm(output) + # Linear projection on attention outputs. + output = self.out_proj(output) + + return output, attn_weights + + class TransformerDecoderBlock(nn.Module): def __init__(self, config): super().__init__() - self.attn = Attention(config.hidden_size, config.num_attention_heads) + self.attn = BitMGQA(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta) self.mlp = mlp(config.hidden_size, config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def forward(self, x, position_ids): + output, _ = x + self.attn(self.input_layernorm(x), position_ids=position_ids) + return x + self.mlp(self.post_attention_layernorm(x)) class CheckpointingMixin(nn.Module): @@ -97,20 +156,27 @@ def __init__(self, config): self.gradient_checkpointing = False def forward(self, x): - x = self.wte(x) - self.dwa_modules.init_accumulators(x) + inputs_embeds = self.wte(x) + past_seen_tokens = 0 + position_ids = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ).unsqueeze(0) + + hidden_states = inputs_embeds + self.dwa_modules.init_accumulators(hidden_states) for i, decoder_layer in enumerate(self.h): # gradient checkpointing if self.gradient_checkpointing and self.training: - x = self._gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( decoder_layer, - x, + hidden_states, + position_ids, ) else: - x = decoder_layer(x) - x = self.dwa_modules(x, block_idx=i) - x = self.ln_f(x) - return x + hidden_states = decoder_layer(hidden_states, position_ids) + hidden_states = self.dwa_modules(hidden_states, block_idx=i) + hidden_states = self.ln_f(hidden_states) + return hidden_states class CausalLM(nn.Module): diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py index d45fdca..e690a99 100644 --- a/src/voltronformer/train/data.py +++ b/src/voltronformer/train/data.py @@ -10,7 +10,7 @@ def get_dataset_lengths(dataset): - input_ids = dataset.data.column("input_ids") + input_ids = dataset.column("input_ids") lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) return lengths diff --git a/train.py b/train.py index e55ea39..874b032 100644 --- a/train.py +++ b/train.py @@ -137,6 +137,7 @@ def get_ds(): def main(): state = PartialState() + config = tiny() ds, text_field = get_ds() args = TrainingArguments( @@ -148,10 
+149,11 @@ def main(): warmup_steps=1000, per_gpu_train_batch_size=1, save_steps=10000, + max_sequence_length=config.max_position_embeddings, + learning_rate=5e-5, ) os.makedirs(args.output_dir, exist_ok=True) - config = tiny() model = CausalLM(config) tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base") From 91c4fa2a68b8ccd55d3059c79fbb0f443079cb14 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 00:45:10 -0400 Subject: [PATCH 09/42] fix args/kwargs ordering --- src/voltronformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 329c29d..8295990 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -57,7 +57,7 @@ def mlp(dim: int, hidden_dim: int) -> FeedForward: class LlamaBitMGQA(BitMGQA): - def __init__(self, embed_dim, query_heads, max_position_embeddings=2048, rope_theta=10_000, *args, **kwargs): + def __init__(self, embed_dim, query_heads, *args, max_position_embeddings=2048, rope_theta=10_000, **kwargs): self.head_dim = embed_dim // query_heads self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta) super().__init__(embed_dim, query_heads, *args, **kwargs) From c298db14a3d3115db7b542148acc750fb6843e58 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 00:45:50 -0400 Subject: [PATCH 10/42] use LlamaBitMGQA --- src/voltronformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 8295990..1569340 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -111,7 +111,7 @@ class TransformerDecoderBlock(nn.Module): def __init__(self, config): super().__init__() - self.attn = BitMGQA(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta) + self.attn = LlamaBitMGQA(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta) self.mlp = mlp(config.hidden_size, config.intermediate_size) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) From b5487fdd71efbe2783765657a6389445c060a1c0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 00:46:50 -0400 Subject: [PATCH 11/42] fix order of init for module --- src/voltronformer/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 1569340..2d30400 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -58,9 +58,9 @@ def mlp(dim: int, hidden_dim: int) -> FeedForward: class LlamaBitMGQA(BitMGQA): def __init__(self, embed_dim, query_heads, *args, max_position_embeddings=2048, rope_theta=10_000, **kwargs): + super().__init__(embed_dim, query_heads, *args, **kwargs) self.head_dim = embed_dim // query_heads self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta) - super().__init__(embed_dim, query_heads, *args, **kwargs) def forward( self, From a7854a21df44be2768d5f2f0abfdf9b277aae64d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 00:50:22 -0400 Subject: [PATCH 12/42] make tinier and fix dataset map --- src/voltronformer/config.py | 2 +- 
src/voltronformer/train/data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index a6af841..913e4fa 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -8,7 +8,7 @@ def tiny(): "max_position_embeddings": 2048, "num_attention_heads": 32, "num_key_value_heads": 8, - "num_hidden_layers": 24, + "num_hidden_layers": 16, "vocab_size": 100352, "dwa_dilation": 4, "dwa_period": 5, diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py index e690a99..76f7cf6 100644 --- a/src/voltronformer/train/data.py +++ b/src/voltronformer/train/data.py @@ -68,7 +68,7 @@ def encode_packed_pretraining( # pylint: disable=duplicate-code # tokenize all the examples # rows get split with stride (overlap) - train_dataset = ds_wrapper(Dataset.from_dict(examples))[0] + train_dataset = Dataset.from_dict(examples).map(ds_wrapper, batched=True) sampler = MultipackBatchSampler( RandomSampler(train_dataset), From bc4ce8afada7384e7afa7ced4ed531107ed454c1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 00:52:11 -0400 Subject: [PATCH 13/42] fix back to use dataset.data.columns --- src/voltronformer/train/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py index 76f7cf6..5e0f674 100644 --- a/src/voltronformer/train/data.py +++ b/src/voltronformer/train/data.py @@ -9,8 +9,8 @@ from torch.utils.data import RandomSampler -def get_dataset_lengths(dataset): - input_ids = dataset.column("input_ids") +def get_dataset_lengths(dataset: Dataset): + input_ids = dataset.data.column("input_ids") lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) return lengths From 122316e18835370d2a4823c06ba958e61ae10391 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 00:56:04 -0400 Subject: [PATCH 14/42] make sure to remove extra columns --- src/voltronformer/train/data.py | 7 ++++++- train.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py index 5e0f674..2bd7f09 100644 --- a/src/voltronformer/train/data.py +++ b/src/voltronformer/train/data.py @@ -68,7 +68,12 @@ def encode_packed_pretraining( # pylint: disable=duplicate-code # tokenize all the examples # rows get split with stride (overlap) - train_dataset = Dataset.from_dict(examples).map(ds_wrapper, batched=True) + train_dataset = Dataset.from_dict(examples) + train_dataset = train_dataset.map( + ds_wrapper, + batched=True, + remove_columns = list(train_dataset.features.keys()) + ) sampler = MultipackBatchSampler( RandomSampler(train_dataset), diff --git a/train.py b/train.py index 874b032..d270d72 100644 --- a/train.py +++ b/train.py @@ -141,13 +141,13 @@ def main(): ds, text_field = get_ds() args = TrainingArguments( - gradient_accumulation_steps=1, + gradient_accumulation_steps=16, max_steps_per_epoch=None, log_steps=1, output_dir="./out", weight_decay=0.0, warmup_steps=1000, - per_gpu_train_batch_size=1, + per_gpu_train_batch_size=8, save_steps=10000, max_sequence_length=config.max_position_embeddings, learning_rate=5e-5, From a1c7b2a7b687a2d416283fbe2c5d1cce9140852d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 01:02:20 -0400 Subject: [PATCH 15/42] fix data loop and make tinier --- src/voltronformer/config.py | 8 ++++---- train.py | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git 
a/src/voltronformer/config.py b/src/voltronformer/config.py index 913e4fa..a3c292c 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -2,12 +2,12 @@ def tiny(): return DictDefault({ - "hidden_size": 768, - "intermediate_size": 2112, + "hidden_size": 512, + "intermediate_size": 1408, "rope_theta": 10_000, "max_position_embeddings": 2048, - "num_attention_heads": 32, - "num_key_value_heads": 8, + "num_attention_heads": 16, + "num_key_value_heads": 4, "num_hidden_layers": 16, "vocab_size": 100352, "dwa_dilation": 4, diff --git a/train.py b/train.py index d270d72..9642af5 100644 --- a/train.py +++ b/train.py @@ -70,16 +70,16 @@ def save_checkpoint(self): os.path.join(output_dir, f"model_{self.global_step}.pt"), ) - def train(self, dataloader, rank): + def train(self): self._model.train() try: self.optimizer.train() except: pass - self.train_loop(dataloader, rank) + self.train_loop() - def train_loop(self, dataloader, rank): - for idx, batch in enumerate(pbar := tqdm(dataloader, disable=not (rank == 0))): + def train_loop(self): + for idx, batch in enumerate(pbar := tqdm(self.dataloader, disable=not (self.rank == 0))): if ( self.args.max_steps_per_epoch is not None and (idx // self.args.gradient_accumulation_steps) @@ -182,7 +182,6 @@ def tokenize_function(examples, field="text", tokenizer=None): accelerator = Accelerator() dataloader_params = dict( - sampler=None, batch_size=args.per_gpu_train_batch_size, num_workers=8, pin_memory=True, @@ -193,7 +192,7 @@ def tokenize_function(examples, field="text", tokenizer=None): trainer = Trainer(model, args, dataloader, accelerator) print(f"Total number of parameters: {trainer.model_num_parameters:_}") - trainer.train_loop(dataloader, rank=0) + trainer.train_loop() if __name__ == "__main__": From 0d76c4e642f4d8a7ac9289c7f7d3a2e0bcc558f4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 01:08:30 -0400 Subject: [PATCH 16/42] use generic collator to pad equally --- pyproject.toml | 1 + requirements.txt | 6 ------ train.py | 7 +++---- 3 files changed, 4 insertions(+), 10 deletions(-) delete mode 100644 requirements.txt diff --git a/pyproject.toml b/pyproject.toml index a64aafa..aba13b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ name = "voltronformers" dynamic = ["version"] requires-python = ">= 3.10" dependencies = [ + "accelerate", "bitnet", "schedulefree", "bitsandbytes", diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 4263a26..0000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -schedulefree -wandb -torch==2.2.1 -transformers==4.39.3 -datasets -axolotl @ git+https://github.com/openaccess-ai-collective/axolotl.git@main \ No newline at end of file diff --git a/train.py b/train.py index 9642af5..eb419aa 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,6 @@ import functools import os from dataclasses import dataclass -from functools import partial from typing import Optional import torch @@ -9,9 +8,9 @@ from accelerate import Accelerator, PartialState from datasets import load_dataset from schedulefree import AdamWScheduleFree -from torch.utils.data import DataLoader, RandomSampler +from torch.utils.data import DataLoader from tqdm import tqdm -from transformers import AutoTokenizer +from transformers import AutoTokenizer, DataCollatorForSeq2Seq from src.voltronformer.config import tiny from src.voltronformer.model import CausalLM @@ -186,7 +185,7 @@ def tokenize_function(examples, field="text", tokenizer=None): num_workers=8, pin_memory=True, 
drop_last=True, - collate_fn=None, + collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, max_length=True), ) dataloader = DataLoader(train_dataset, **dataloader_params) From 91d4a119e48cdd0abc6956cae956dc3fa7c4ccf4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 01:12:01 -0400 Subject: [PATCH 17/42] accont for position_ids in mod block --- src/voltronformer/mod.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/voltronformer/mod.py b/src/voltronformer/mod.py index ac3a397..53e1a97 100644 --- a/src/voltronformer/mod.py +++ b/src/voltronformer/mod.py @@ -14,7 +14,7 @@ def __init__(self, config, block_class): self.capacity_factor = config.mod_capacity_factor self.top_k =int(self.capacity_factor * config.max_position_embeddings) - def forward(self, x, **kwargs): + def forward(self, x, position_ids, **kwargs): # [batch_size, sequence_length, n_embd] B, T, C = x.shape # inference time optimization: sequence length can @@ -37,7 +37,7 @@ def forward(self, x, **kwargs): indices_expanded = selected_tokens.expand(-1, -1, C) # [batch_size, top_k, n_embd] top_k_tokens = torch.gather(x, 1, indices_expanded) - top_k_tokens_processed = self.block(top_k_tokens, **kwargs) + top_k_tokens_processed = self.block(top_k_tokens, position_ids, **kwargs) """STEP 3: combine results""" x = torch.scatter_add( From 81e18b9532654949bda9facbd05af4b9638ba713 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 01:19:55 -0400 Subject: [PATCH 18/42] flesh out rotary embeddigs --- src/voltronformer/model.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 2d30400..c7f31cf 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -13,6 +13,38 @@ from .mod import MoDBlock +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + class FeedForward(nn.Module): def __init__(self, gate_proj: Linear, down_proj: Linear, up_proj: Linear): super().__init__() @@ -65,6 +97,7 @@ def __init__(self, embed_dim, query_heads, *args, max_position_embeddings=2048, def forward( self, x: Tensor, + position_ids: Optional[Tensor] = None, need_weights: bool = False, # attn_mask: Optional[Tensor] = None, is_causal: bool = False, @@ -79,6 +112,10 @@ def forward( q = rearrange(q, "b n (h d) -> b n h d", h=self.query_heads) k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads) v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads) + + # Apply rotary embedding. + q, k = apply_rotary_pos_emb(q, k, *self.rotary_emb(v, position_ids)) + # Apply attention, then fold 'h' attention heads back into 'd'. output, attn_weights = scaled_dot_product_gqa( query=q, From bc5ac078c88e51413348095258f35c34a91b80cb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 13:44:16 -0400 Subject: [PATCH 19/42] misc fixes --- src/voltronformer/bitlinear/__init__.py | 2 + src/voltronformer/bitlinear/cg123.py | 172 ++++++++++++++++++++++++ src/voltronformer/bitlinear/official.py | 48 +++++++ src/voltronformer/config.py | 31 +++-- src/voltronformer/core.py | 42 +++++- src/voltronformer/model.py | 76 +++++++---- src/voltronformer/train/data.py | 14 +- src/voltronformer/utils.py | 21 +++ train.py | 37 +++-- 9 files changed, 391 insertions(+), 52 deletions(-) create mode 100644 src/voltronformer/bitlinear/__init__.py create mode 100644 src/voltronformer/bitlinear/cg123.py create mode 100644 src/voltronformer/bitlinear/official.py diff --git a/src/voltronformer/bitlinear/__init__.py b/src/voltronformer/bitlinear/__init__.py new file mode 100644 index 0000000..e2f9675 --- /dev/null +++ b/src/voltronformer/bitlinear/__init__.py @@ -0,0 +1,2 @@ +# from .cg123 import BitLinear +from .official import BitLinear diff --git a/src/voltronformer/bitlinear/cg123.py b/src/voltronformer/bitlinear/cg123.py new file mode 100644 index 0000000..85e9009 --- /dev/null +++ b/src/voltronformer/bitlinear/cg123.py @@ -0,0 +1,172 @@ +""" +Implementation of the BitLinear layer described in the papers: + +1. "BitNet: Scaling 1-bit Transformers for Large Language Models" +2. "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" + +References: +- https://arxiv.org/abs/2310.11453 +- https://arxiv.org/abs/2402.17764 +""" + +#!/usr/bin/env python3 +# Copyright (C) 2024 Charles O. 
Goddard + +import math +from typing import NamedTuple, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _ste(x: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: + """Straight-through estimator.""" + return x0 + (x - x0).detach() + + +@torch.compile() +def _quantize( + x: Optional[torch.Tensor], is_input: bool, num_groups: int, eps: float +) -> Tuple[torch.Tensor, torch.Tensor]: + if x is None: + return None, None + + x0 = x + if is_input: + # split last dimension into num_groups + x = x.view(list(x.shape[:-1]) + [num_groups, -1]) + scale_factor = x.abs().max(dim=-1, keepdim=True).values + else: + # first dimension is output features, so split that + x = x.view([num_groups, -1] + list(x.shape[1:])) + scale_factor = x.abs().mean(dim=list(range(1, len(x.shape))), keepdim=True) + + x_scaled = x / (scale_factor + eps) + if is_input: + x_q = (x_scaled * 127).clamp(-127, 127).to(torch.int8) + else: + x_q = x_scaled.round().clamp(-1, 1).to(torch.int8) + + # adjust scale_factor to match shape returned for input + scale_factor = scale_factor.view(1, 1, num_groups, 1) + + return _ste(x_q, x_scaled).view_as(x0), scale_factor + + +class QuantizedWeights(NamedTuple): + """Quantized weight and optional bias tensor for BitLinear.""" + + w_q: torch.Tensor + bias_q: Optional[torch.Tensor] + beta: torch.Tensor + + +@torch.compile() +def _quantize_weights( + weight: torch.Tensor, + bias: Optional[torch.Tensor], + num_groups: int, + eps: float, +) -> QuantizedWeights: + w_q, beta = _quantize(weight, is_input=False, num_groups=num_groups, eps=eps) + bias_q, _ = _quantize(bias, is_input=True, num_groups=num_groups, eps=eps) + # bias assumes the scale factor of weights + return QuantizedWeights(w_q=w_q, bias_q=bias_q, beta=beta) + + +def _pack_ternary(x: torch.Tensor) -> torch.Tensor: + """Pack ternary float tensor into int8 tensor. 
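+    Groups of five ternary digits are encoded as a single base-3 value
+    (3**5 = 243, which fits in one byte).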
Uses ~1.6 bits per element.""" + + x_packed = torch.empty( + x.shape[:-1] + (math.ceil(x.shape[-1] / 5)), dtype=torch.int8 + ) + for i in range(0, x.shape[-1], 5): + chunk = x[..., i : i + 5].to(torch.int8).view(x.shape[:-1] + (1, 5)) + # -1 -> 0, 0 -> 1, 1 -> 2 + chunk = chunk + 1 + # store as base-3 number + chunk = ( + chunk + * torch.tensor([1, 3, 9, 27, 81], device=chunk.device, dtype=chunk.dtype) + ).sum(dim=-1) + x_packed[..., i // 5] = chunk + return x_packed + + +class BitLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + *args, + preserve_scale: bool = False, + num_groups: int = 1, + eps: float = 1e-7, + bias: bool = False, + **kwargs, + ): + if num_groups < 1: + raise ValueError("num_groups must be >= 1") + if num_groups > 1 and out_features % num_groups != 0: + raise ValueError("out_features must be divisible by num_groups") + + super().__init__(in_features, out_features, *args, bias=bias, **kwargs) + self.input_norm = nn.LayerNorm(self.in_features, elementwise_affine=False) + self.preserve_scale = preserve_scale + self.num_groups = num_groups + self.eps = eps + + @torch.compile() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.input_norm(x) + x_q, gamma = _quantize( + x, is_input=True, num_groups=self.num_groups, eps=self.eps + ) + w_q, bias_q, beta = _quantize_weights( + self.weight, self.bias, num_groups=self.num_groups, eps=self.eps + ) + + y = F.linear(x_q, w_q, bias_q) + y = y.to(x.dtype) / 127 + if self.preserve_scale: + y_grouped = y.view(list(y.shape[:-1]) + [self.num_groups, -1]) + y = (y_grouped * gamma * beta).reshape_as(y) + + return y + + +class BitConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + *args, + preserve_scale: bool = False, + eps: float = 1e-7, + bias: bool = False, + **kwargs, + ): + super().__init__( + in_channels, out_channels, kernel_size, *args, bias=bias, **kwargs + ) + self.input_norm = nn.GroupNorm(1, self.in_channels, affine=False) + self.preserve_scale = preserve_scale + self.eps = eps + + @torch.compile() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.input_norm(x) + x_q, gamma = _quantize(x, is_input=True, num_groups=1, eps=self.eps) + w_q, bias_q, beta = _quantize_weights( + self.weight, self.bias, num_groups=1, eps=self.eps + ) + + y = F.conv2d(x_q, w_q, bias_q, self.stride, self.padding, self.dilation) + y = y.to(x.dtype) / 127 + if self.preserve_scale: + y_grouped = y.view(list(y.shape[:-1]) + [1, -1]) + y = (y_grouped * gamma * beta).reshape_as(y) + + return y \ No newline at end of file diff --git a/src/voltronformer/bitlinear/official.py b/src/voltronformer/bitlinear/official.py new file mode 100644 index 0000000..059b3cc --- /dev/null +++ b/src/voltronformer/bitlinear/official.py @@ -0,0 +1,48 @@ +import math +import torch +from torch import nn + + +def weight_quant(weight, num_bits=1): + dtype = weight.dtype + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) / s + return result.type(dtype) + + +def activation_quant(x, num_bits=8): + dtype = x.dtype + x = x.float() + Qn = -2 ** (num_bits - 1) + Qp = 2 ** (num_bits - 1) - 1 + s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * s).round().clamp(Qn, Qp) / s + return result.type(dtype) + + +class BitLinear(nn.Linear): + + def __init__(self, + *kargs, + weight_bits=1, + input_bits=8, + **kwargs + ): + super(BitLinear, 
self).__init__(*kargs, **kwargs) + """ + RMSNorm is placed outside BitLinear + """ + self.weight_bits = weight_bits + self.input_bits = input_bits + + def forward(self, input): + + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() + + out = nn.functional.linear(quant_input, quant_weight) + if not self.bias is None: + out += self.bias.view(1, -1).expand_as(out) + + return out \ No newline at end of file diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index a3c292c..e947d3b 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -1,5 +1,6 @@ from axolotl.utils.dict import DictDefault + def tiny(): return DictDefault({ "hidden_size": 512, @@ -8,11 +9,11 @@ def tiny(): "max_position_embeddings": 2048, "num_attention_heads": 16, "num_key_value_heads": 4, - "num_hidden_layers": 16, - "vocab_size": 100352, + "num_hidden_layers": 12, + "vocab_size": 32000, "dwa_dilation": 4, "dwa_period": 5, - "pad_token_id": 100277, # dbrx <{|pad|}> + "pad_token_id": 2, "mod_every": 2, "mod_capacity_factor": 0.125, "rms_norm_eps": 0.000001, @@ -27,21 +28,29 @@ def small(): "num_attention_heads": 32, "num_key_value_heads": 8, "num_hidden_layers": 24, - "vocab_size": 100352, + "vocab_size": 32000, "dwa_dilation": 4, "dwa_period": 5, - "pad_token_id": 100277, # dbrx <{|pad|}> + "pad_token_id": 2, "mod_every": 2, "mod_capacity_factor": 0.125, + "rms_norm_eps": 0.000001, }) def medium(): return DictDefault({ - "hidden_size": 1024, - "intermediate_size": 2816, - - "max_position_embeddings": 32768, - - "vocab_size": 100352 + "hidden_size": 4096, + "intermediate_size": 11264, + "max_position_embeddings": 8192, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 24, + "vocab_size": 32000, + "dwa_dilation": 4, + "dwa_period": 5, + "pad_token_id": 2, + "mod_every": 2, + "mod_capacity_factor": 0.125, + "rms_norm_eps": 0.000001, }) \ No newline at end of file diff --git a/src/voltronformer/core.py b/src/voltronformer/core.py index 2cd0d5e..ff02821 100644 --- a/src/voltronformer/core.py +++ b/src/voltronformer/core.py @@ -1,4 +1,38 @@ -try: - from bitnet.bitlinear import BitLinear as Linear -except ImportError: - from torch.nn import Linear +from bitnet.bitlinear import activation_quant, weight_quant +from torch import Tensor, nn +import torch.nn.functional as F + +class Linear(nn.Linear): + """ + Custom linear layer with bit quantization. + + Args: + dim (int): The input dimension of the layer. + training (bool, optional): Whether the layer is in training mode or not. Defaults to False. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Attributes: + dim (int): The input dimension of the layer. + + """ + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the BitLinear layer. + + Args: + x (Tensor): The input tensor. + + Returns: + Tensor: The output tensor. 
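+
+        Note (illustrative):
+            The straight-through estimator below evaluates the layer on
+            quantized weights and activations, while gradients flow to the
+            full-precision tensors: the backward pass treats
+            (activation_quant(x) - x) and (weight_quant(w) - w) as constants.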
+ + """ + w = self.weight + x_norm = x + + # STE using detach + x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach() + w_quant = w + (weight_quant(w) - w).detach() + y = F.linear(x_quant, w_quant) + return y diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index c7f31cf..9fc6825 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -2,6 +2,10 @@ from typing import List, Optional, Callable, Tuple import torch +import bitnet.bit_attention +from .bitlinear import BitLinear +bitnet.bit_attention.BitLinear = BitLinear + from bitnet.bit_attention import scaled_dot_product_gqa, BitMGQA from functorch.einops import rearrange from torch import nn, Tensor @@ -46,7 +50,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): class FeedForward(nn.Module): - def __init__(self, gate_proj: Linear, down_proj: Linear, up_proj: Linear): + def __init__(self, gate_proj: BitLinear, down_proj: BitLinear, up_proj: BitLinear): super().__init__() self.gate_proj = gate_proj self.down_proj = down_proj @@ -54,7 +58,10 @@ def __init__(self, gate_proj: Linear, down_proj: Linear, up_proj: Linear): self.act_fn = nn.SiLU() def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + x = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + # FIXME layernorm??? + x = self.down_proj(x) + return x class Attention(nn.Module): @@ -78,31 +85,42 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ).type_as(x) return x_normed * self.scale + def mlp(dim: int, hidden_dim: int) -> FeedForward: """ Build the MLP layer associated with the Llama model. """ - gate_proj = Linear(dim, hidden_dim, bias=False) - down_proj = Linear(hidden_dim, dim, bias=False) - up_proj = Linear(dim, hidden_dim, bias=False) + gate_proj = BitLinear(dim, hidden_dim, bias=False) + down_proj = BitLinear(hidden_dim, dim, bias=False) + up_proj = BitLinear(dim, hidden_dim, bias=False) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) class LlamaBitMGQA(BitMGQA): - def __init__(self, embed_dim, query_heads, *args, max_position_embeddings=2048, rope_theta=10_000, **kwargs): - super().__init__(embed_dim, query_heads, *args, **kwargs) + def __init__(self, embed_dim, query_heads=8, kv_heads=4, dropout=0.1, bias=True, *args, max_position_embeddings=2048, rope_theta=10_000, **kwargs): + super().__init__(embed_dim, query_heads, kv_heads, dropout, bias, *args, **kwargs) self.head_dim = embed_dim // query_heads self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta) - def forward( + # rebuild the out_proj + self.out_proj = BitLinear( + embed_dim, # this is incorrect upstream + embed_dim, + bias=bias, # device=device, dtype=dtype + ) + self._reset_parameters() + + +def forward( self, x: Tensor, position_ids: Optional[Tensor] = None, need_weights: bool = False, # attn_mask: Optional[Tensor] = None, - is_causal: bool = False, + is_causal: bool = True, average_attn_weights: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: + # Input shape: (b, n, d) q: Tensor = self.q_proj(x) k: Tensor = self.k_proj(x) @@ -113,8 +131,21 @@ def forward( k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads) v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads) - # Apply rotary embedding. 
- q, k = apply_rotary_pos_emb(q, k, *self.rotary_emb(v, position_ids)) + # Generate rotary embeddings + cos, sin = self.rotary_emb(x, position_ids) + + # Reshape cos and sin to match the shape of q and k + seq_len = q.shape[2] # Get the sequence length from q + cos = cos[:, :seq_len, :].unsqueeze(1) + sin = sin[:, :seq_len, :].unsqueeze(1) + + # Apply rotary position embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) + + # Adjust the dimensions of q, k, and v + q = q.view(-1, *q.shape[-3:]) + k = k.view(-1, *k.shape[-3:]) + v = v.view(-1, *v.shape[-3:]) # Apply attention, then fold 'h' attention heads back into 'd'. output, attn_weights = scaled_dot_product_gqa( @@ -128,16 +159,11 @@ def forward( average_attn_weights=average_attn_weights, force_grouped=False, ) - output = rearrange(output, "b n h d -> b n (h d)") - - # NOTE: This is different from 'nn.MultiheadAttention'! We follow the MAGNETO - # architecture (https://arxiv.org/pdf/2210.06423.pdf), which applies an extra - # layer norm before the linear output projection. The cross-attention layer in - # the MAGNETO decoder does not include this layer norm, so users have the - # option to disable it (layer_norm=False). - if self.layer_norm: - assert self.norm is not None - output = self.norm(output) + + # Re-assemble all head outputs side-by-side. + # output = output.transpose(1, 2).contiguous().view(b, n, d) + output = rearrange(output, "b n h d -> b h (n d)") + # Linear projection on attention outputs. output = self.out_proj(output) @@ -148,14 +174,16 @@ class TransformerDecoderBlock(nn.Module): def __init__(self, config): super().__init__() - self.attn = LlamaBitMGQA(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta) + self.attn = LlamaBitMGQA(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, bias=False, layer_norm=False) self.mlp = mlp(config.hidden_size, config.intermediate_size) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward(self, x, position_ids): - output, _ = x + self.attn(self.input_layernorm(x), position_ids=position_ids) - return x + self.mlp(self.post_attention_layernorm(x)) + residual = x + h = self.input_layernorm(x) + output = residual + self.attn(h, position_ids=position_ids)[0] + return residual + self.mlp(self.post_attention_layernorm(output)) class CheckpointingMixin(nn.Module): diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py index 2bd7f09..7da3fd1 100644 --- a/src/voltronformer/train/data.py +++ b/src/voltronformer/train/data.py @@ -58,6 +58,10 @@ def wrap_pretraining_dataset( return dataset +def drop_long_seq(sample, sequence_len=2048): + return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0 + + def encode_packed_pretraining( collate_fn, ds_wrapper: Callable, @@ -75,11 +79,17 @@ def encode_packed_pretraining( remove_columns = list(train_dataset.features.keys()) ) + drop_long = functools.partial(drop_long_seq, sequence_len=max_seq_length) + train_dataset = train_dataset.filter( + drop_long, + num_proc=8, + ) + sampler = MultipackBatchSampler( RandomSampler(train_dataset), - batch_size=1, + batch_size=batch_size, drop_last=True, - batch_max_len=batch_size * max_seq_length, + 
batch_max_len=max_seq_length, lengths=get_dataset_lengths(train_dataset), ) diff --git a/src/voltronformer/utils.py b/src/voltronformer/utils.py index ad1c860..851fab3 100644 --- a/src/voltronformer/utils.py +++ b/src/voltronformer/utils.py @@ -1,8 +1,29 @@ import math import os +from typing import Optional, Set, Type import torch +from torch import nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, +) +from torch.distributed.fsdp.wrap import ModuleWrapPolicy + + +def set_activation_checkpointing( + model: nn.Module, auto_wrap_policy: Optional[Set[Type[nn.Module]]] = None, **kwargs +) -> None: + """Utility to setup activation checkpointing and wrap the model for checkpointing. + + Args: + model (nn.Module): Model to setup activation checkpointing. + auto_wrap_policy (Optional[Set[nn.Module]]): Policy to wrap module. + **kwargs: additional arguments to pass to torch.distributed activation checkpointing. + """ + wrap_policy = ModuleWrapPolicy(auto_wrap_policy or set()) + apply_activation_checkpointing(model, auto_wrap_policy=wrap_policy, **kwargs) + def device_get_local_rank(): """ diff --git a/train.py b/train.py index eb419aa..67770e2 100644 --- a/train.py +++ b/train.py @@ -4,18 +4,20 @@ from typing import Optional import torch +import torch.nn.functional as F import wandb from accelerate import Accelerator, PartialState from datasets import load_dataset +from einops import rearrange from schedulefree import AdamWScheduleFree from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoTokenizer, DataCollatorForSeq2Seq from src.voltronformer.config import tiny -from src.voltronformer.model import CausalLM +from src.voltronformer.model import CausalLM, TransformerDecoderBlock from src.voltronformer.train.data import wrap_pretraining_dataset -from src.voltronformer.utils import device_get_cuda, device_get_local_rank, get_cosine_schedule_with_min_lr_lambda +from src.voltronformer.utils import device_get_cuda, device_get_local_rank, set_activation_checkpointing @dataclass @@ -30,11 +32,17 @@ class TrainingArguments: save_steps: Optional[int] = 5_000 max_sequence_length: Optional[int] = 8192 learning_rate: float = 5e-5 + vocab_size: Optional[int] = None + class Trainer: - def __init__(self, model, args, dataloader, accelerator): + def __init__(self, model, args, dataloader, accelerator, activation_checkpointing=True): self.args = args self._model = model + if activation_checkpointing: + set_activation_checkpointing( + model, auto_wrap_policy={TransformerDecoderBlock} + ) self.build_optimizer_and_scheduler() self._model, self.dataloader, self.optimizer = accelerator.prepare(self._model, dataloader, self.optimizer) @@ -59,7 +67,7 @@ def build_optimizer_and_scheduler(self): def _loss_fn(self, logits, labels): loss_fct = torch.nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1)) + loss = loss_fct(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1)) return loss def save_checkpoint(self): @@ -95,12 +103,17 @@ def train_loop(self): logits = self._model(input_ids) # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, self.args.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = 
shift_labels.to(shift_logits.device) + + # logits = logits.transpose(1, 2) # Compute loss - loss = self._loss_fn(logits, labels) + loss = self._loss_fn(shift_logits, shift_labels) if ( self.global_step * self.args.log_steps == 0 @@ -149,11 +162,13 @@ def main(): per_gpu_train_batch_size=8, save_steps=10000, max_sequence_length=config.max_position_embeddings, - learning_rate=5e-5, + learning_rate=5e-4, + vocab_size=config.vocab_size, ) os.makedirs(args.output_dir, exist_ok=True) model = CausalLM(config) + model = model.to(device_get_cuda()) tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base") def tokenize_function(examples, field="text", tokenizer=None): @@ -189,9 +204,9 @@ def tokenize_function(examples, field="text", tokenizer=None): ) dataloader = DataLoader(train_dataset, **dataloader_params) - trainer = Trainer(model, args, dataloader, accelerator) + trainer = Trainer(model, args, dataloader, accelerator, activation_checkpointing=False) print(f"Total number of parameters: {trainer.model_num_parameters:_}") - trainer.train_loop() + trainer.train() if __name__ == "__main__": From c5a4c7118cab13f9ad80217cb137e77292a15678 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 13:49:32 -0400 Subject: [PATCH 20/42] fix tokenizer and activation checkpointing --- src/voltronformer/config.py | 1 + train.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index e947d3b..c0f6254 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -2,6 +2,7 @@ def tiny(): + """50M parameters""" return DictDefault({ "hidden_size": 512, "intermediate_size": 1408, diff --git a/train.py b/train.py index 67770e2..a6592c3 100644 --- a/train.py +++ b/train.py @@ -169,7 +169,10 @@ def main(): model = CausalLM(config) model = model.to(device_get_cuda()) - tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base") + # tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base") + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + if not tokenizer.pad_token_id: + tokenizer.pad_token_id = tokenizer.eos_token_id def tokenize_function(examples, field="text", tokenizer=None): outputs = tokenizer(examples[field], truncation=True, max_length=None) @@ -204,7 +207,7 @@ def tokenize_function(examples, field="text", tokenizer=None): ) dataloader = DataLoader(train_dataset, **dataloader_params) - trainer = Trainer(model, args, dataloader, accelerator, activation_checkpointing=False) + trainer = Trainer(model, args, dataloader, accelerator, activation_checkpointing=True) print(f"Total number of parameters: {trainer.model_num_parameters:_}") trainer.train() From c493b65f99cbe64a41414e5f3094300991f2dd95 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 14:14:41 -0400 Subject: [PATCH 21/42] more fixes --- src/voltronformer/config.py | 2 +- src/voltronformer/core.py | 42 ++--------- src/voltronformer/model.py | 140 ++++++++++++++++++++++++++++++++---- 3 files changed, 133 insertions(+), 51 deletions(-) diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index c0f6254..6eadbd0 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -7,7 +7,7 @@ def tiny(): "hidden_size": 512, "intermediate_size": 1408, "rope_theta": 10_000, - "max_position_embeddings": 2048, + "max_position_embeddings": 1024, "num_attention_heads": 16, "num_key_value_heads": 4, "num_hidden_layers": 12, diff --git a/src/voltronformer/core.py 
b/src/voltronformer/core.py index ff02821..da162da 100644 --- a/src/voltronformer/core.py +++ b/src/voltronformer/core.py @@ -1,38 +1,4 @@ -from bitnet.bitlinear import activation_quant, weight_quant -from torch import Tensor, nn -import torch.nn.functional as F - -class Linear(nn.Linear): - """ - Custom linear layer with bit quantization. - - Args: - dim (int): The input dimension of the layer. - training (bool, optional): Whether the layer is in training mode or not. Defaults to False. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Attributes: - dim (int): The input dimension of the layer. - - """ - - def forward(self, x: Tensor) -> Tensor: - """ - Forward pass of the BitLinear layer. - - Args: - x (Tensor): The input tensor. - - Returns: - Tensor: The output tensor. - - """ - w = self.weight - x_norm = x - - # STE using detach - x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach() - w_quant = w + (weight_quant(w) - w).detach() - y = F.linear(x_quant, w_quant) - return y +try: + from .bitlinear import BitLinear as Linear +except ImportError: + from torch.nn import Linear diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 9fc6825..8313a6a 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -96,22 +96,138 @@ def mlp(dim: int, hidden_dim: int) -> FeedForward: return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) -class LlamaBitMGQA(BitMGQA): - def __init__(self, embed_dim, query_heads=8, kv_heads=4, dropout=0.1, bias=True, *args, max_position_embeddings=2048, rope_theta=10_000, **kwargs): - super().__init__(embed_dim, query_heads, kv_heads, dropout, bias, *args, **kwargs) - self.head_dim = embed_dim // query_heads - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta) +# copied from https://github.com/kyegomez/BitNet/blob/main/bitnet/bit_attention.py +class LlamaBitMGQA(nn.Module): + """Multi-head grouped query attention (GQA) layer. + + Reference: + "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" + https://arxiv.org/pdf/2305.13245v1.pdf + + GQA is a variant of multihead attention (MHA) that uses fewer write heads + (key / value) than query heads. GQA can be viewed as a generalization of + multi-query attention (MQA), which uses a single write head. GQA and MQA give + significant speedups over standard MHA in decoder layers, with minimal loss in + accuracy. In the paper, GQA is shown to be more accurate than MQA, while still + having a significant speedup over MHA. + + NOTE: The original authors only benchmark GQA by adapting the T5 (XL or XXL) model + from MHA to GQA. As a result, they do not mention parameter initialization or + layer normalization strategies. I follow the best practices laid out in the + MAGNETO paper, which improves Transformer performance through better parameter + initialization and layer norm placement. See: + https://arxiv.org/pdf/2210.06423.pdf, Fig. 
2 + """ + + def __init__( + self, + embed_dim: int, + query_heads: int = 8, + kv_heads: int = 4, + dropout: float = 0.1, + bias: bool = True, + layer_norm: bool = True, + layer_norm_eps: float = 1e-5, + gamma_init: float = 1.0, + linear_groups: int = 1, + *args, + max_position_embeddings=2048, + rope_theta=10_000, + **kwargs, + ): + super().__init__() + self.query_heads = query_heads + self.kv_heads = kv_heads + self.dropout = dropout + self.layer_norm = layer_norm + self.gamma_init = gamma_init + + if self.query_heads % self.kv_heads != 0: + raise ValueError( + f"query_heads ({query_heads}) must be divisible by " + f"kv_heads ({kv_heads})" + ) + elif (embed_dim % self.query_heads != 0) or (embed_dim % self.kv_heads != 0): + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"query_heads ({query_heads}) and kv_heads ({kv_heads})" + ) - # rebuild the out_proj + head_dim = embed_dim // query_heads + if not head_dim % 8 == 0: + raise ValueError( + f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8" + ) + if not head_dim <= 128: + raise ValueError( + f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128" + ) + + # Query projection layer is the same as in vanilla MHA. + self.q_proj = BitLinear( + embed_dim, + embed_dim, + bias=bias, + *args, + **kwargs, # device=device, dtype=dtype + ) + # Key/value projection layers have a smaller output dimension, so that + # the we have fewer key/value attention heads after reshaping. + kv_embed_dim = embed_dim // query_heads * kv_heads + self.k_proj = BitLinear( + embed_dim, + kv_embed_dim, + bias=bias, + *args, + **kwargs, # device=device, dtype=dtype + ) + self.v_proj = BitLinear( + embed_dim, + kv_embed_dim, + bias=bias, + *args, + **kwargs, # device=device, dtype=dtype + ) + self.norm: Optional[nn.LayerNorm] = None + if layer_norm: + self.norm = nn.LayerNorm( + kv_embed_dim, + eps=layer_norm_eps, # device=device, dtype=dtype + ) + # Grouped attention output will have the same embedding dimension as the + # key/value Tensors. So the output projection layer needs to accept the + # same dimension (kv_embed_dim). self.out_proj = BitLinear( - embed_dim, # this is incorrect upstream + embed_dim, embed_dim, bias=bias, # device=device, dtype=dtype ) - self._reset_parameters() + self.rotary_emb = LlamaRotaryEmbedding(head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta) + self._reset_parameters() -def forward( + def _reset_parameters(self): + nn.init.xavier_normal_(self.q_proj.weight) + if self.q_proj.bias is not None: + nn.init.constant_(self.q_proj.bias, 0) + nn.init.xavier_normal_(self.k_proj.weight) + if self.k_proj.bias is not None: + nn.init.constant_(self.k_proj.bias, 0) + + # NOTE: We follow the initialization strategy from MAGNETO. See: + # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2 + # Gain (self.gamma_init) should be provided as a keyword argument when + # initializing the larger Transformer model, since it requires knowledge + # of the number of encoder/decoder layers in the model. + + nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init) + if self.v_proj.bias is not None: + nn.init.constant_(self.v_proj.bias, 0) + nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0) + + def forward( self, x: Tensor, position_ids: Optional[Tensor] = None, @@ -127,9 +243,9 @@ def forward( v: Tensor = self.v_proj(x) # Unfold 'd' dimension into 'h' separate attention heads. 
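+        # Illustrative shapes for the layout change below, assuming
+        # hidden_size=512, 16 query heads (head_dim=32), 4 kv heads, and a
+        # batch of 1 x 128 tokens: q goes (1, 128, 512) -> (1, 16, 128, 32)
+        # and k/v go (1, 128, 128) -> (1, 4, 128, 32).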
- q = rearrange(q, "b n (h d) -> b n h d", h=self.query_heads) - k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads) - v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads) + q = rearrange(q, "b n (h d) -> b h n d", h=self.query_heads) + k = rearrange(k, "b n (h d) -> b h n d", h=self.kv_heads) + v = rearrange(v, "b n (h d) -> b h n d", h=self.kv_heads) # Generate rotary embeddings cos, sin = self.rotary_emb(x, position_ids) From 8aa8d81080c44f2e9b898a2fe0e29c831912c069 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 14:22:12 -0400 Subject: [PATCH 22/42] remove hard dependencies from axolotl --- pyproject.toml | 4 +- src/voltronformer/config.py | 2 +- src/voltronformer/train/collators.py | 155 ++++++++++++++++++++ src/voltronformer/train/data.py | 7 +- src/voltronformer/train/samplers.py | 202 +++++++++++++++++++++++++++ src/voltronformer/utils.py | 13 ++ 6 files changed, 378 insertions(+), 5 deletions(-) create mode 100644 src/voltronformer/train/collators.py create mode 100644 src/voltronformer/train/samplers.py diff --git a/pyproject.toml b/pyproject.toml index aba13b7..42f405b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,16 +4,18 @@ dynamic = ["version"] requires-python = ">= 3.10" dependencies = [ "accelerate", + "addict", "bitnet", "schedulefree", "bitsandbytes", "datasets", "einops", "flash-attn", + "numba", + "numpy", "wandb", "tqdm", "transformers==4.39.3", - "axolotl @ git+https://github.com/openaccess-ai-collective/axolotl.git@main", "denseformer @ git+https://github.com/epfml/DenseFormer.git@main", ] maintainers = [ diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index 6eadbd0..2e96f2e 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -1,4 +1,4 @@ -from axolotl.utils.dict import DictDefault +from src.voltronformer.utils import DictDefault def tiny(): diff --git a/src/voltronformer/train/collators.py b/src/voltronformer/train/collators.py new file mode 100644 index 0000000..24a5462 --- /dev/null +++ b/src/voltronformer/train/collators.py @@ -0,0 +1,155 @@ +""" +DataCollator to pad labels and position_ids for packed sequences +""" +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +from transformers import PreTrainedTokenizerBase +from transformers.utils import PaddingStrategy + +IGNORE_INDEX = -100 + + +@dataclass +class DataCollatorForSeq2Seq: + """ + Data collator that will dynamically pad the inputs received, as well as the labels and position_ids + + Args: + tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): + The tokenizer used for encoding the data. + model ([`PreTrainedModel`]): + The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to + prepare the *decoder_input_ids* + + This is useful when using *label_smoothing* to avoid calculating loss twice. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + + - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). 
+ max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + label_pad_token_id (`int`, *optional*, defaults to -100): + The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). + return_tensors (`str`): + The type of Tensor to return. Allowable values are "np", "pt" and "tf". + """ + + tokenizer: PreTrainedTokenizerBase + model: Optional[Any] = None + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + position_pad_token_id: int = 0 + return_tensors: str = "pt" + + def __call__(self, features, return_tensors=None): + labels = None + if return_tensors is None: + return_tensors = self.return_tensors + + for feature_name, pad_token_id in [ + ("labels", self.label_pad_token_id), + ("position_ids", self.position_pad_token_id), + ]: + feat = ( + [feature[feature_name] for feature in features] + if feature_name in features[0].keys() + else None + ) + labels = feat if feat and feature_name == "labels" else labels + # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the + # same length to return tensors. + if feat is not None: + max_feature_length = max(len(l) for l in feat) # noqa: E741 + if self.pad_to_multiple_of is not None: + max_feature_length = ( + (max_feature_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + + padding_side = self.tokenizer.padding_side + for feature in features: + remainder = [pad_token_id] * ( + max_feature_length - len(feature[feature_name]) + ) + if isinstance(feature[feature_name], list): + feature[feature_name] = ( + feature[feature_name] + remainder + if padding_side == "right" + else remainder + feature[feature_name] + ) + elif padding_side == "right": + feature[feature_name] = np.concatenate( + [feature[feature_name], remainder] + ).astype(np.int64) + else: + feature[feature_name] = np.concatenate( + [remainder, feature[feature_name]] + ).astype(np.int64) + + features = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=return_tensors, + ) + + # prepare decoder_input_ids + if ( + labels is not None + and self.model is not None + and hasattr(self.model, "prepare_decoder_input_ids_from_labels") + ): + decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels( + labels=features["labels"] + ) + features["decoder_input_ids"] = decoder_input_ids + + return features + + +@dataclass +class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): + """ + Collator for multipack specific to the using the BatchSampler + """ + + def __init__(self, *args, multipack_attn=True, **kwargs): + super().__init__(*args, **kwargs) + self.multipack_attn = multipack_attn + + def __call__(self, features, return_tensors=None): + chunked_data = {} + for feature in features.keys(): + if feature == "length": + continue + if feature == "attention_mask": + if self.multipack_attn: + arrays = [ + (i + 1) * np.array(item[feature]) + for i, item in enumerate(features[feature]) + if feature in item + ] + else: + arrays = [(1) * 
np.array(item) for item in features[feature]] + chunked_data[feature] = np.concatenate(arrays) + else: + arrays = [np.array(item) for item in features[feature]] + chunked_data[feature] = np.concatenate(arrays) + features = [chunked_data] + return super().__call__(features, return_tensors=return_tensors) diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py index 7da3fd1..3ec0fab 100644 --- a/src/voltronformer/train/data.py +++ b/src/voltronformer/train/data.py @@ -1,13 +1,14 @@ import functools from collections import defaultdict -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List import numpy as np -from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq -from axolotl.utils.samplers import MultipackBatchSampler from datasets import Dataset from torch.utils.data import RandomSampler +from src.voltronformer.train.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq +from src.voltronformer.train.samplers import MultipackBatchSampler + def get_dataset_lengths(dataset: Dataset): input_ids = dataset.data.column("input_ids") diff --git a/src/voltronformer/train/samplers.py b/src/voltronformer/train/samplers.py new file mode 100644 index 0000000..be92f7f --- /dev/null +++ b/src/voltronformer/train/samplers.py @@ -0,0 +1,202 @@ +# pylint: skip-file +""" +Multipack Batch Sampler +""" +import logging +import math +import os +from typing import Any, Iterable, List, Union + +import numba +import numpy as np +from torch.utils.data import BatchSampler, Sampler + +LOG = logging.getLogger("multipack") + + +@numba.njit +def ffd_check(a: np.ndarray, c: int, n: int): + # First-fit-decreasing bin packing + # Check if a[] could fit in n bins with capacity c + # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing + + a = np.sort(a)[::-1] + bins = np.full((n,), c, dtype=a.dtype) + for size in a: + not_found = True + for idx in range(n): + if bins[idx] >= size: + bins[idx] -= size + not_found = False + break + + if not_found: + return False + + return True + + +@numba.njit +def ffd_with_result(a: np.ndarray, c: int, start_index: int): + # First-fit-decreasing bin packing (with result return) + + indices = np.argsort(a)[::-1] + a = a[indices] + + bins: List[Any] = [] + bins_result: List[Any] = [] + for a_id, size in enumerate(a): + add_new = True + for idx in range(len(bins)): + if bins[idx] >= size: + bins[idx] -= size + bins_result[idx].append(indices[a_id] + start_index) + add_new = False + break + + if add_new: + bins.append(c - size) + bins_result.append([indices[a_id] + start_index]) + + return bins_result + + +@numba.njit +def allocate( + lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int +): + # Dynamic batch allocator, similar to Multifit + # https://en.wikipedia.org/wiki/Multifit_algorithm + # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) + + s = 0 + start_index = 0 + result = [] + + while True: + # binary search [l, r) + left = 1 + right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") + + while right - left > 1: + mid = (left + right) // 2 + if ffd_check(lengths[start_index : start_index + mid], c, n): + left = mid + else: + right = mid + + # use length l + batch = ffd_with_result( + lengths[start_index : start_index + left], c, start_index + ) + assert len(batch) <= n + if len(batch) < n: + break + + start_index += left + s = lengths_cumsum[start_index - 1] + + # add local rank + result.append(batch[rank]) + + return result, s, 
len(result) * c * n + + +class MultipackBatchSampler(BatchSampler): + """ + Batch Sampler class for multipack + """ + + def __init__( + self, + sampler: Union[Sampler[int], Iterable[int]], + batch_size: int, + drop_last: bool, + batch_max_len: int, + lengths: np.ndarray, + packing_efficiency_estimate: float = 1.0, + ): + super().__init__(sampler, batch_size, drop_last) + self.batch_size = batch_size + self.batch_max_len = batch_max_len + self.lengths: np.ndarray = lengths + self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 + + assert isinstance(self.lengths, np.ndarray) + + self.epoch = 0 + + # statistics + self.eff_total_used = 0 + self.eff_total_slots = 0 + + def set_epoch(self, epoch: int): + self.epoch = epoch + + def generate_batches(self, set_stats=False): + indices = [idx for idx in self.sampler] + + lengths = self.lengths[indices] + lengths_cumsum = np.cumsum(lengths) + + batches, total_used, total_slots = allocate( + lengths=lengths, + lengths_cumsum=lengths_cumsum, + rank=0, + c=self.batch_max_len, + n=1, + ) + + batches = [ + [ + [indices[b_idx] for b_idx in batch] + for batch in batches[i : i + self.batch_size] + ] + for i in range(0, len(batches), self.batch_size) + ] + + # statistics + if set_stats: + self.eff_total_used += total_used + self.eff_total_slots += total_slots + + return batches + + def __iter__(self): + batches = self.generate_batches(set_stats=True) + return iter(batches) + + def num_batches(self): + batches = self.generate_batches(set_stats=True) + return len(batches) + + def efficiency(self): + return self.eff_total_used / self.eff_total_slots + + def __len__(self): + self.num_batches() + return self._len_est() + + def _len_est(self): + world_size = int(os.getenv("WORLD_SIZE", "1")) + lengths_sum = np.sum(self.lengths) + lengths_sum_per_device = lengths_sum // world_size + LOG.info( + f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"total_num_tokens per device: {lengths_sum_per_device}" + ) + + # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler + return max( + 0, + ( + world_size + * math.floor( + 0.99 + * lengths_sum_per_device + / self.packing_efficiency_estimate + // (self.batch_max_len * self.batch_size) + ) + - 1 + ), + ) diff --git a/src/voltronformer/utils.py b/src/voltronformer/utils.py index 851fab3..c1bc178 100644 --- a/src/voltronformer/utils.py +++ b/src/voltronformer/utils.py @@ -4,6 +4,7 @@ import torch +from addict import Dict from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing, @@ -57,3 +58,15 @@ def get_cosine_schedule_with_min_lr_lambda( ) scaling = 0.5 * (1.0 + math.cos(math.pi * progress)) return (1 - min_lr_ratio) * scaling + min_lr_ratio + + +class DictDefault(Dict): + """ + A Dict that returns None instead of returning empty Dict for missing keys. 
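+
+    Example (illustrative):
+        >>> cfg = DictDefault({"hidden_size": 512})
+        >>> cfg.vocab_size is None
+        True
+        >>> (cfg | DictDefault({"vocab_size": 32000})).vocab_size
+        32000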
+ """ + + def __missing__(self, key): + return None + + def __or__(self, other): + return DictDefault(super().__ror__(other)) From e6edf93791a4e43851a96f311844b7547531cb67 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 14:26:13 -0400 Subject: [PATCH 23/42] remove more hard deps --- src/voltronformer/bitlinear/__init__.py | 1 + src/voltronformer/bitlinear/attention.py | 143 +++++++++++++++++++++++ src/voltronformer/model.py | 8 +- train.py | 2 - 4 files changed, 146 insertions(+), 8 deletions(-) create mode 100644 src/voltronformer/bitlinear/attention.py diff --git a/src/voltronformer/bitlinear/__init__.py b/src/voltronformer/bitlinear/__init__.py index e2f9675..81e90cf 100644 --- a/src/voltronformer/bitlinear/__init__.py +++ b/src/voltronformer/bitlinear/__init__.py @@ -1,2 +1,3 @@ # from .cg123 import BitLinear from .official import BitLinear +from .attention import scaled_dot_product_gqa \ No newline at end of file diff --git a/src/voltronformer/bitlinear/attention.py b/src/voltronformer/bitlinear/attention.py new file mode 100644 index 0000000..dfd6646 --- /dev/null +++ b/src/voltronformer/bitlinear/attention.py @@ -0,0 +1,143 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import einsum, rearrange +from torch import Tensor + + +def scaled_dot_product_gqa( + query: Tensor, + key: Tensor, + value: Tensor, + dropout: float = 0.0, + scale: Optional[float] = None, + mask: Optional[Tensor] = None, + is_causal: Optional[bool] = None, + need_weights: bool = False, + average_attn_weights: bool = False, + force_grouped: bool = False, +): + """Scaled dot product attention with support for grouped queries. + + Einstein notation: + - b: batch size + - n / s: sequence length + - h: number of heads + - g: number of groups + - d: dimension of query/key/value + + Args: + query: Query tensor of shape (b, n, h, d) + key: Key tensor of shape (b, s, h, d) + value: Value tensor of shape (b, s, h, d) + dropout: Dropout probability (default: 0.0) + scale: Scale factor for query (default: d_query ** 0.5) + mask: Mask tensor of shape (b, n, s) or (b, s). If 'ndim == 2', the mask is + applied to all 'n' rows of the attention matrix. (default: None) + force_grouped: If True, apply grouped-query attention even if the number of + heads is equal for query, key, and value. (default: False) + + Returns: + 2-tuple of: + - Attention output with shape (b, n, h, d) + - (Optional) Attention weights with shape (b, h, n, s). Only returned if + 'need_weights' is True. + """ + if (mask is not None) and (is_causal is not None): + raise ValueError( + "Only one of 'mask' and 'is_causal' should be provided, but got both." + ) + elif not query.ndim == key.ndim == value.ndim == 4: + raise ValueError( + f"Expected query, key, and value to be 4-dimensional, but got shapes " + f"{query.shape}, {key.shape}, and {value.shape}." + ) + + # Move sequence length dimension to axis 2. + # This makes the attention operations below *much* faster. + query = rearrange(query, "b n h d -> b h n d") + key = rearrange(key, "b s h d -> b h s d") + value = rearrange(value, "b s h d -> b h s d") + + bq, hq, nq, dq = query.shape + bk, hk, nk, dk = key.shape + bv, hv, nv, dv = value.shape + if not (bq == bk == bv and dq == dk == dv): + raise ValueError( + "Expected query, key, and value to have the same batch size (dim=0) and " + f"embedding dimension (dim=3), but got query: {query.shape}, " + f"key: {key.shape}, and value: {value.shape}." 
+        )
+    elif (hk != hv) or (nk != nv):
+        raise ValueError(
+            "Expected key and value to have the same size in dimensions 1 and 2, but "
+            f"got key: {key.shape} and value: {value.shape}."
+        )
+    elif hq % hk != 0:
+        raise ValueError(
+            "Expected query heads to be a multiple of key/value heads, but got "
+            f"query: {query.shape} and key/value: {key.shape}."
+        )
+
+    if scale is None:
+        scale = query.size(-1) ** 0.5
+    query = query / scale
+
+    num_head_groups = hq // hk
+    if num_head_groups > 1 or force_grouped:
+        # Separate the query heads into 'num_head_groups' chunks, and fold the group
+        # dimension into the batch dimension. This allows us to compute the attention
+        # for each head in parallel, then sum over all of the groups at the end.
+        query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
+        similarity = einsum(query, key, "b g h n d, b h s d -> b h n s")
+    else:
+        # If the number of query/key heads is equal, we can skip grouping the queries,
+        # and just use the standard scaled dot-product attention.
+        similarity = einsum(query, key, "b h n d, b h s d -> b h n s")
+
+    if is_causal:
+        # Mask out the upper triangular portion of the attention matrix. This prevents
+        # the model from attending to tokens in the future.
+        mask = torch.ones(
+            (bq, nq, nk),
+            device=query.device,
+            dtype=torch.bool,
+        ).tril_()
+
+    if mask is not None:
+        # Expand mask to match the shape of the attention matrix.
+        # If mask is 2D, assume that it is applied to the key/value sequence dimension.
+        # Else if mask is 3D, assume that it is applied to the query/key/value sequence
+        # dimension for all attention heads.
+        #
+        # Users could also provide a 4D mask, which is applied to the query/key/value
+        # sequence dimension for each attention head (though I don't have a particular
+        # use case in mind for that).
+        if mask.ndim == 2:
+            mask = rearrange(mask, "b s -> b () () s")
+        elif mask.ndim == 3:
+            mask = rearrange(mask, "b n s -> b () n s")
+        # Mask similarity values by setting them to negative infinity. This guarantees
+        # that they will not contribute to the softmax computation below.
+        similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)
+
+    # 'query' was already divided by 'scale' above, so the similarity scores are
+    # scaled; dividing again here would double-apply the scaling.
+    attention = F.softmax(similarity, dim=-1)
+    if dropout > 0.0:
+        attention = F.dropout(attention, p=dropout)
+
+    # Apply attention matrix to the value Tensor.
+    out = einsum(attention, value, "b h n s, b h s d -> b h n d")
+    # Move head dimension back to axis 2
+    out = rearrange(out, "b h n d -> b n h d")
+
+    attn_weights: Optional[Tensor] = None
+    if need_weights:
+        # Move the sequence dimensions back to positions 1, 2. Move the head dimension
+        # to position 3. This more closely matches the return shape of the attention
+        # output: (b, n, h, d).
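+        # Illustrative shapes, assuming b=1, n=s=128, 16 query heads and 4 kv
+        # heads: the grouped 'attention' above is (1, 4, 128, 128), so
+        # 'attn_weights' below becomes (1, 128, 128, 4).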
+ attn_weights = rearrange(attention, "b h n s -> b n s h") + if average_attn_weights: + attn_weights = attn_weights.mean(dim=1) + + return out, attn_weights \ No newline at end of file diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 8313a6a..9684c23 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -1,19 +1,15 @@ import functools -from typing import List, Optional, Callable, Tuple +from typing import Optional, Callable, Tuple import torch -import bitnet.bit_attention -from .bitlinear import BitLinear -bitnet.bit_attention.BitLinear = BitLinear +from .bitlinear import BitLinear, scaled_dot_product_gqa -from bitnet.bit_attention import scaled_dot_product_gqa, BitMGQA from functorch.einops import rearrange from torch import nn, Tensor from denseformer import DWAModules from torch.utils.checkpoint import checkpoint from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding -from .core import Linear from .mod import MoDBlock diff --git a/train.py b/train.py index a6592c3..2ce75e2 100644 --- a/train.py +++ b/train.py @@ -4,11 +4,9 @@ from typing import Optional import torch -import torch.nn.functional as F import wandb from accelerate import Accelerator, PartialState from datasets import load_dataset -from einops import rearrange from schedulefree import AdamWScheduleFree from torch.utils.data import DataLoader from tqdm import tqdm From 270150f93652a0f622a1bf7757ca416edcfdad79 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 14:28:48 -0400 Subject: [PATCH 24/42] re-enable DWA again --- src/voltronformer/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index 2e96f2e..c514300 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -18,6 +18,7 @@ def tiny(): "mod_every": 2, "mod_capacity_factor": 0.125, "rms_norm_eps": 0.000001, + "dwa": True, }) From 49cc04aa6f9bd643bfeb5489f3e34dcdd4bf4f9b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 14:38:59 -0400 Subject: [PATCH 25/42] actually check for dwa --- src/voltronformer/model.py | 9 ++++++--- train.py | 16 ++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 9684c23..7774e96 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -318,7 +318,8 @@ class Transformer(CheckpointingMixin): def __init__(self, config): super().__init__() self.config = config - self.dwa_modules = DWAModules(config.num_hidden_layers, config.dwa_dilation, config.dwa_period) + if config.dwa: + self.dwa_modules = DWAModules(config.num_hidden_layers, config.dwa_dilation, config.dwa_period) self.wte = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.h = nn.ModuleList([ @@ -340,7 +341,8 @@ def forward(self, x): ).unsqueeze(0) hidden_states = inputs_embeds - self.dwa_modules.init_accumulators(hidden_states) + if self.config.dwa: + self.dwa_modules.init_accumulators(hidden_states) for i, decoder_layer in enumerate(self.h): # gradient checkpointing if self.gradient_checkpointing and self.training: @@ -351,7 +353,8 @@ def forward(self, x): ) else: hidden_states = decoder_layer(hidden_states, position_ids) - hidden_states = self.dwa_modules(hidden_states, block_idx=i) + if self.config.dwa: + hidden_states = self.dwa_modules(hidden_states, block_idx=i) hidden_states = self.ln_f(hidden_states) return hidden_states diff --git a/train.py b/train.py 
index 2ce75e2..db7ea0d 100644 --- a/train.py +++ b/train.py @@ -136,13 +136,13 @@ def train_loop(self): def get_ds(): return load_dataset("togethercomputer/RedPajama-Data-V2", - name="default", - partition="head_middle", - snapshots=["2023-14"], - languages=["en"], - split="train", - streaming=True, - ), "raw_content" + name="default", + partition="head_middle", + snapshots=["2023-14"], + languages=["en"], + split="train", + streaming=True, + ), "raw_content" # load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True) def main(): @@ -173,7 +173,7 @@ def main(): tokenizer.pad_token_id = tokenizer.eos_token_id def tokenize_function(examples, field="text", tokenizer=None): - outputs = tokenizer(examples[field], truncation=True, max_length=None) + outputs = tokenizer(examples[field], truncation=True, max_length=config.max_position_embeddings) return outputs with state.main_process_first(): From 4437672fb99c7b2c8c19250cfbbfc483f893f85a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 14:55:47 -0400 Subject: [PATCH 26/42] wandb on main rank only --- train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index db7ea0d..ca05b3d 100644 --- a/train.py +++ b/train.py @@ -47,7 +47,8 @@ def __init__(self, model, args, dataloader, accelerator, activation_checkpointin self.device = device_get_cuda() self.global_step = 0 self.rank = device_get_local_rank() - wandb.init(project="voltronformer") + if accelerator.is_main_process: + wandb.init(project="voltronformer") self.accelerator = accelerator @property From c2f804c2525d20b9eb4cd8d315281dac1b08433a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 15:58:26 -0400 Subject: [PATCH 27/42] fix modulo for log steps --- train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index ca05b3d..9fb70d9 100644 --- a/train.py +++ b/train.py @@ -42,6 +42,7 @@ def __init__(self, model, args, dataloader, accelerator, activation_checkpointin model, auto_wrap_policy={TransformerDecoderBlock} ) self.build_optimizer_and_scheduler() + self._model, self.dataloader, self.optimizer = accelerator.prepare(self._model, dataloader, self.optimizer) self.device = device_get_cuda() @@ -115,7 +116,7 @@ def train_loop(self): loss = self._loss_fn(shift_logits, shift_labels) if ( - self.global_step * self.args.log_steps == 0 + self.global_step % self.args.log_steps == 0 and self.rank == 0 ): pbar.set_description(f"Loss: {loss.item()}") From 88f25a97c9b66e64bb1de3fba336ece493070100 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 18:04:25 -0400 Subject: [PATCH 28/42] attempt to use accelerator loop --- train.py | 81 +++++++++++++++++++++++++++----------------------------- 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/train.py b/train.py index 9fb70d9..a9c9b86 100644 --- a/train.py +++ b/train.py @@ -87,53 +87,50 @@ def train(self): def train_loop(self): for idx, batch in enumerate(pbar := tqdm(self.dataloader, disable=not (self.rank == 0))): - if ( - self.args.max_steps_per_epoch is not None - and (idx // self.args.gradient_accumulation_steps) - == self.args.max_steps_per_epoch - ): - break - - input_ids = batch["input_ids"].to(self.device) - if "labels" in batch.keys(): - labels = batch["labels"].to(self.device) - else: - labels = input_ids.clone() - - logits = self._model(input_ids) - - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_logits = 
shift_logits.view(-1, self.args.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - - # logits = logits.transpose(1, 2) - - # Compute loss - loss = self._loss_fn(shift_logits, shift_labels) - - if ( - self.global_step % self.args.log_steps == 0 - and self.rank == 0 - ): - pbar.set_description(f"Loss: {loss.item()}") - wandb.log({"loss": loss.item(), "global_step": self.global_step}) - - loss = loss / self.args.gradient_accumulation_steps - loss.backward() - - if (idx + 1) % self.args.gradient_accumulation_steps == 0: + # if ( + # self.args.max_steps_per_epoch is not None + # and (idx // self.args.gradient_accumulation_steps) + # == self.args.max_steps_per_epoch + # ): + # break + + with self.accelerator.accumulate(self._model): + input_ids = batch["input_ids"].to(self.device) + if "labels" in batch.keys(): + labels = batch["labels"].to(self.device) + else: + labels = input_ids.clone() + + logits = self._model(input_ids) + + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, self.args.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + + # Compute loss + loss = self._loss_fn(shift_logits, shift_labels) + + if ( + self.global_step % self.args.log_steps == 0 + and self.rank == 0 + ): + pbar.set_description(f"Loss: {loss.item()}") + wandb.log({"loss": loss.item(), "global_step": self.global_step}) + + self.accelerator.backward(loss) self.optimizer.step() + self._model.zero_grad() if self.lr_scheduler: self.lr_scheduler.step() self.optimizer.zero_grad(set_to_none=True) self.global_step += 1 - if self.global_step % self.args.save_steps == 0: - self.save_checkpoint() + if self.global_step % self.args.save_steps == 0: + self.save_checkpoint() def get_ds(): @@ -196,7 +193,7 @@ def tokenize_function(examples, field="text", tokenizer=None): # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") - accelerator = Accelerator() + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) dataloader_params = dict( batch_size=args.per_gpu_train_batch_size, From a10c31a545e4cd8c37a9ed77f5bf8e3cf7fbb9fc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Apr 2024 18:19:22 -0400 Subject: [PATCH 29/42] update configuration --- src/voltronformer/config.py | 2 +- train.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index c514300..0ee6e81 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -7,7 +7,7 @@ def tiny(): "hidden_size": 512, "intermediate_size": 1408, "rope_theta": 10_000, - "max_position_embeddings": 1024, + "max_position_embeddings": 2048, "num_attention_heads": 16, "num_key_value_heads": 4, "num_hidden_layers": 12, diff --git a/train.py b/train.py index a9c9b86..7dd77a9 100644 --- a/train.py +++ b/train.py @@ -152,14 +152,14 @@ def main(): args = TrainingArguments( gradient_accumulation_steps=16, max_steps_per_epoch=None, - log_steps=1, + log_steps=10, output_dir="./out", - weight_decay=0.0, + weight_decay=0.1, warmup_steps=1000, - per_gpu_train_batch_size=8, - save_steps=10000, + per_gpu_train_batch_size=24, + save_steps=1000, 
max_sequence_length=config.max_position_embeddings, - learning_rate=5e-4, + learning_rate=1e-3, vocab_size=config.vocab_size, ) os.makedirs(args.output_dir, exist_ok=True) From 8239c6eb86ceed98b3a09607bdc04c741c848abf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Apr 2024 07:17:47 -0400 Subject: [PATCH 30/42] wip rms norm --- pyproject.toml | 1 + src/voltronformer/config.py | 8 ++- src/voltronformer/kernels/rms_norm.py | 91 +++++++++++++++++++++++++++ src/voltronformer/model.py | 17 +---- train.py | 14 ++++- 5 files changed, 111 insertions(+), 20 deletions(-) create mode 100644 src/voltronformer/kernels/rms_norm.py diff --git a/pyproject.toml b/pyproject.toml index 42f405b..239228d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "flash-attn", "numba", "numpy", + "safetensors, "wandb", "tqdm", "transformers==4.39.3", diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index 0ee6e81..d1cbcda 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -26,6 +26,7 @@ def small(): return DictDefault({ "hidden_size": 1024, "intermediate_size": 2816, + "rope_theta": 10_000, "max_position_embeddings": 4096, "num_attention_heads": 32, "num_key_value_heads": 8, @@ -37,6 +38,7 @@ def small(): "mod_every": 2, "mod_capacity_factor": 0.125, "rms_norm_eps": 0.000001, + "dwa": True, }) @@ -44,10 +46,11 @@ def medium(): return DictDefault({ "hidden_size": 4096, "intermediate_size": 11264, + "rope_theta": 10_000, "max_position_embeddings": 8192, "num_attention_heads": 32, "num_key_value_heads": 8, - "num_hidden_layers": 24, + "num_hidden_layers": 32, "vocab_size": 32000, "dwa_dilation": 4, "dwa_period": 5, @@ -55,4 +58,5 @@ def medium(): "mod_every": 2, "mod_capacity_factor": 0.125, "rms_norm_eps": 0.000001, - }) \ No newline at end of file + "dwa": True, + }) diff --git a/src/voltronformer/kernels/rms_norm.py b/src/voltronformer/kernels/rms_norm.py new file mode 100644 index 0000000..e7c2951 --- /dev/null +++ b/src/voltronformer/kernels/rms_norm.py @@ -0,0 +1,91 @@ +import torch +import triton +import triton.language as tl +from torch import nn + + +# from https://ai.lefebvre-sarrut.eu/2023/07/20/deep-dive-into-kernel-fusion-accelerating-inference-in-llama-v2/#openai-triton-rewriting +@triton.jit +def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr, + stride_x_batch, stride_x_m, stride_x_k, + stride_rms_w, + stride_out_batch, stride_out_m, stride_out_k, + N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr): + pid_batch = tl.program_id(0) + pid_m = tl.program_id(1) + + offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m + block_N = tl.arange(0, BLOCK_N_SIZE) + var = tl.zeros((BLOCK_N_SIZE,), tl.float32) + + # first loop over input tensor to compute the root mean of the square + for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE): + offs_n = block_n_start_idx + block_N + x_ptr_mask = offs_n < N_SIZE + # recompute address at each iteration + x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0) + var += tl.math.pow(x.to(tl.float32), 2) + + # we keep this reduction operation outside the loop for perf reasons + var = tl.sum(var, axis=0) / N_SIZE + rstd = tl.math.rsqrt(var + eps) + + # apply the normalization and multiply by RMS weights + for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE): + offs_n = block_n_start_idx + block_N + x_ptr_mask = offs_n < N_SIZE + rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask) + + x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, 
mask=x_ptr_mask, other=0.0).to(tl.float32) + x_hat = x * rstd + out = x_hat * rms_w + out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k + tl.store(output_ptr + out_off, out, mask=x_ptr_mask) + + +class RMSNorm(nn.Module): + """copied from torchtune""" + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fp32 = x.float() + x_normed = ( + x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) + ).type_as(x) + return x_normed * self.scale + + +"""not ready for use yet. 2X Faster, but not accurate""" +class RMSNormTriton(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Define the grid and block dimensions + N_SIZE = x.shape[-1] + BLOCK_N_SIZE = 512 # Adjust this value based on your requirements + + # Allocate output tensor + output = torch.empty_like(x) + + # Define the strides for input, scale, and output tensors + stride_x_batch, stride_x_m, stride_x_k = x.stride() + stride_rms_w = self.scale.stride(0) + stride_out_batch, stride_out_m, stride_out_k = output.stride() + + # Launch the Triton kernel + grid = lambda meta: (x.shape[0], x.shape[1]) + rmsnorm_triton[grid]( + x, self.scale, output, + stride_x_batch, stride_x_m, stride_x_k, + stride_rms_w, + stride_out_batch, stride_out_m, stride_out_k, + N_SIZE, self.eps, BLOCK_N_SIZE + ) + + return output \ No newline at end of file diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 7774e96..17a4a6b 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from .mod import MoDBlock - +from .kernels.rms_norm import RMSNorm def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -67,21 +67,6 @@ def __init__(self, hidden_size: int, num_heads: int): self.num_heads = num_heads -class RMSNorm(nn.Module): - """copied from torchtune""" - def __init__(self, dim: int, eps: float = 1e-6) -> None: - super().__init__() - self.eps = eps - self.scale = nn.Parameter(torch.ones(dim)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_fp32 = x.float() - x_normed = ( - x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) - ).type_as(x) - return x_normed * self.scale - - def mlp(dim: int, hidden_dim: int) -> FeedForward: """ Build the MLP layer associated with the Llama model. diff --git a/train.py b/train.py index 7dd77a9..fa5ea00 100644 --- a/train.py +++ b/train.py @@ -5,8 +5,9 @@ import torch import wandb -from accelerate import Accelerator, PartialState +from accelerate import Accelerator, PartialState, DistributedDataParallelKwargs from datasets import load_dataset +from safetensors.torch import save_model from schedulefree import AdamWScheduleFree from torch.utils.data import DataLoader from tqdm import tqdm @@ -72,6 +73,7 @@ def _loss_fn(self, logits, labels): def save_checkpoint(self): output_dir = self.args.output_dir if self.args.output_dir is not None else "." 
+ save_model(self._model, os.path.join(output_dir, f"model_{self.global_step}.safetensors")) torch.save( self._model.state_dict(), os.path.join(output_dir, f"model_{self.global_step}.pt"), @@ -193,12 +195,20 @@ def tokenize_function(examples, field="text", tokenizer=None): # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + # ddp kwargs with find_unused_parameters needed for triton rmsnorm + # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + mixed_precision="bf16", + log_with="wandb", + gradient_accumulation_steps=args.gradient_accumulation_steps, + # kwargs_handlers=[ddp_kwargs], + ) dataloader_params = dict( batch_size=args.per_gpu_train_batch_size, num_workers=8, pin_memory=True, + prefetch_factor=4, drop_last=True, collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, max_length=True), ) From 2b2f332fb48d050fe96ca5bc76d454eb5698709f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Apr 2024 08:14:04 -0400 Subject: [PATCH 31/42] use apex rms norm optim --- pyproject.toml | 2 +- src/voltronformer/config.py | 1 + src/voltronformer/model.py | 5 ++++- train.py | 3 ++- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 239228d..39c4aeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "flash-attn", "numba", "numpy", - "safetensors, + "safetensors", "wandb", "tqdm", "transformers==4.39.3", diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index d1cbcda..4efe5d7 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -23,6 +23,7 @@ def tiny(): def small(): + """300M parameters""" return DictDefault({ "hidden_size": 1024, "intermediate_size": 2816, diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 17a4a6b..89906f9 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -11,7 +11,10 @@ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from .mod import MoDBlock -from .kernels.rms_norm import RMSNorm +try: + from apex.normalization import FusedRMSNorm as RMSNorm +except ImportError: + from .kernels.rms_norm import RMSNorm def rotate_half(x): """Rotates half the hidden dims of the input.""" diff --git a/train.py b/train.py index fa5ea00..ca02df7 100644 --- a/train.py +++ b/train.py @@ -215,7 +215,8 @@ def tokenize_function(examples, field="text", tokenizer=None): dataloader = DataLoader(train_dataset, **dataloader_params) trainer = Trainer(model, args, dataloader, accelerator, activation_checkpointing=True) - print(f"Total number of parameters: {trainer.model_num_parameters:_}") + if state.is_main_process: + print(f"Total number of parameters: {trainer.model_num_parameters:_}") trainer.train() From cba6e66957cd7304ba615c3d0758b6f069d53752 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Apr 2024 16:07:00 -0400 Subject: [PATCH 32/42] queued dataloader and gradient norm --- pyproject.toml | 2 ++ src/voltronformer/train/data.py | 31 +++++++++++++++++++- train.py | 52 +++++++++++++++++++-------------- 3 files changed, 62 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 39c4aeb..d3be3fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,12 +11,14 @@ dependencies = [ "datasets", 
"einops", "flash-attn", + "mosaicml-streaming", "numba", "numpy", "safetensors", "wandb", "tqdm", "transformers==4.39.3", + "zstandard", "denseformer @ git+https://github.com/epfml/DenseFormer.git@main", ] maintainers = [ diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py index 3ec0fab..e80ff0f 100644 --- a/src/voltronformer/train/data.py +++ b/src/voltronformer/train/data.py @@ -1,10 +1,12 @@ import functools from collections import defaultdict +from queue import Queue +from threading import Thread from typing import Callable, Dict, List import numpy as np from datasets import Dataset -from torch.utils.data import RandomSampler +from torch.utils.data import RandomSampler, DataLoader from src.voltronformer.train.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from src.voltronformer.train.samplers import MultipackBatchSampler @@ -115,3 +117,30 @@ def encode_packed_pretraining( chunked_data[feature].append(collated_features[feature].squeeze(0)) return chunked_data + + +class QueuedDataLoader(DataLoader): + def __init__(self, *args, queue_len=1_000, **kwargs): + kwargs["persistent_workers"] = True + super().__init__(*args, **kwargs) + self.data_queue = Queue(maxsize=queue_len) + self.prefetch_thread = Thread(target=self.prefetch_data) + self.prefetch_thread.daemon = True + self.prefetch_thread.start() + + def prefetch_data(self): + for data in super().__iter__(): + self.data_queue.put(data) + self.data_queue.put(None) + + def __iter__(self): + return super().__iter__() + + def __next__(self): + if hasattr(self, 'data_queue'): + data = self.data_queue.get() + if data is None: + raise StopIteration + return data + else: + return self._iterator.__next__() diff --git a/train.py b/train.py index ca02df7..fb6c31d 100644 --- a/train.py +++ b/train.py @@ -15,7 +15,7 @@ from src.voltronformer.config import tiny from src.voltronformer.model import CausalLM, TransformerDecoderBlock -from src.voltronformer.train.data import wrap_pretraining_dataset +from src.voltronformer.train.data import wrap_pretraining_dataset, QueuedDataLoader from src.voltronformer.utils import device_get_cuda, device_get_local_rank, set_activation_checkpointing @@ -32,6 +32,7 @@ class TrainingArguments: max_sequence_length: Optional[int] = 8192 learning_rate: float = 5e-5 vocab_size: Optional[int] = None + max_grad_norm: Optional[float] = 1.0 class Trainer: @@ -89,13 +90,6 @@ def train(self): def train_loop(self): for idx, batch in enumerate(pbar := tqdm(self.dataloader, disable=not (self.rank == 0))): - # if ( - # self.args.max_steps_per_epoch is not None - # and (idx // self.args.gradient_accumulation_steps) - # == self.args.max_steps_per_epoch - # ): - # break - with self.accelerator.accumulate(self._model): input_ids = batch["input_ids"].to(self.device) if "labels" in batch.keys(): @@ -115,27 +109,33 @@ def train_loop(self): # Compute loss loss = self._loss_fn(shift_logits, shift_labels) - - if ( - self.global_step % self.args.log_steps == 0 - and self.rank == 0 - ): - pbar.set_description(f"Loss: {loss.item()}") - wandb.log({"loss": loss.item(), "global_step": self.global_step}) - self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + grad_norm = self.accelerator.clip_grad_norm_(self._model.parameters(), self.args.max_grad_norm) self.optimizer.step() self._model.zero_grad() if self.lr_scheduler: self.lr_scheduler.step() + self.optimizer.zero_grad(set_to_none=True) self.global_step += 1 + if self.global_step % self.args.log_steps == 0: + if self.rank == 0: + 
pbar.set_description(f"Loss: {loss.item()} Global Step: {self.global_step}") + grad_norm = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + self.accelerator.log({"training_loss": loss.item(), "gradient_norm": grad_norm}, step=self.global_step) + if self.global_step % self.args.save_steps == 0: self.save_checkpoint() + self.accelerator.end_training() -def get_ds(): + +def get_redpajama_v1(): + return load_dataset("togethercomputer/RedPajama-Data-1T", "common_crawl", split="train", streaming=True), "text" + +def get_redpajama_v2(): return load_dataset("togethercomputer/RedPajama-Data-V2", name="default", partition="head_middle", @@ -144,10 +144,18 @@ def get_ds(): split="train", streaming=True, ), "raw_content" + +def get_ds(): + return get_redpajama_v2() + # load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True) + def main(): state = PartialState() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + config = tiny() ds, text_field = get_ds() @@ -198,21 +206,21 @@ def tokenize_function(examples, field="text", tokenizer=None): # ddp kwargs with find_unused_parameters needed for triton rmsnorm # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( - mixed_precision="bf16", - log_with="wandb", + # mixed_precision="bf16", + log_with=["wandb", "tensorboard"], gradient_accumulation_steps=args.gradient_accumulation_steps, # kwargs_handlers=[ddp_kwargs], ) dataloader_params = dict( batch_size=args.per_gpu_train_batch_size, - num_workers=8, + num_workers=1, pin_memory=True, - prefetch_factor=4, + prefetch_factor=8, drop_last=True, collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, max_length=True), ) - dataloader = DataLoader(train_dataset, **dataloader_params) + dataloader = QueuedDataLoader(train_dataset, queue_len=10_000, **dataloader_params) trainer = Trainer(model, args, dataloader, accelerator, activation_checkpointing=True) if state.is_main_process: From ffafd7a2db47cd24fd68b2b7431e4f927e7b1371 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Apr 2024 19:21:11 -0400 Subject: [PATCH 33/42] fixes for loss calc, grad accum, dataloader for dispatch_batches --- train.py | 102 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 71 insertions(+), 31 deletions(-) diff --git a/train.py b/train.py index fb6c31d..24a07f2 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,6 @@ import functools import os +import tempfile from dataclasses import dataclass from typing import Optional @@ -12,18 +13,22 @@ from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoTokenizer, DataCollatorForSeq2Seq +from transformers.trainer_pt_utils import distributed_concat -from src.voltronformer.config import tiny +from src.voltronformer.config import tiny, small from src.voltronformer.model import CausalLM, TransformerDecoderBlock from src.voltronformer.train.data import wrap_pretraining_dataset, QueuedDataLoader from src.voltronformer.utils import device_get_cuda, device_get_local_rank, set_activation_checkpointing +state = PartialState() + @dataclass class TrainingArguments: gradient_accumulation_steps: int = 1 max_steps_per_epoch: Optional[int] = None log_steps: int = 1 + adam_epsilon: Optional[float] = 1e-8 output_dir: Optional[str] = None weight_decay: float = 0.0 warmup_steps: Optional[int] = 1000 @@ -33,6 +38,7 @@ class TrainingArguments: learning_rate: float = 5e-5 vocab_size: Optional[int] = None max_grad_norm: Optional[float] 
= 1.0 + n_gpu: Optional[int] = None class Trainer: @@ -64,7 +70,7 @@ def model_num_parameters(self): return all_param def build_optimizer_and_scheduler(self): - self.optimizer = AdamWScheduleFree(self._model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay, warmup_steps=self.args.weight_decay) + self.optimizer = AdamWScheduleFree(self._model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay, warmup_steps=self.args.weight_decay, eps=self.args.adam_epsilon) self.lr_scheduler = None def _loss_fn(self, logits, labels): @@ -89,7 +95,11 @@ def train(self): self.train_loop() def train_loop(self): + tr_loss = torch.tensor(0.0).to(self.device) + total_batched_samples = 0 for idx, batch in enumerate(pbar := tqdm(self.dataloader, disable=not (self.rank == 0))): + total_batched_samples += 1 + is_grad_accum_step = total_batched_samples % self.args.gradient_accumulation_steps == 0 with self.accelerator.accumulate(self._model): input_ids = batch["input_ids"].to(self.device) if "labels" in batch.keys(): @@ -109,25 +119,37 @@ def train_loop(self): # Compute loss loss = self._loss_fn(shift_logits, shift_labels) + if self.args.n_gpu > 1: + loss = loss.mean() self.accelerator.backward(loss) - if self.accelerator.sync_gradients: - grad_norm = self.accelerator.clip_grad_norm_(self._model.parameters(), self.args.max_grad_norm) - self.optimizer.step() - self._model.zero_grad() - if self.lr_scheduler: - self.lr_scheduler.step() - - self.optimizer.zero_grad(set_to_none=True) - self.global_step += 1 + mini_step_loss = loss.detach() / self.args.gradient_accumulation_steps + tr_loss += mini_step_loss - if self.global_step % self.args.log_steps == 0: - if self.rank == 0: - pbar.set_description(f"Loss: {loss.item()} Global Step: {self.global_step}") - grad_norm = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm - self.accelerator.log({"training_loss": loss.item(), "gradient_norm": grad_norm}, step=self.global_step) + if is_grad_accum_step: + grad_norm = self.accelerator.clip_grad_norm_(self._model.parameters(), self.args.max_grad_norm) - if self.global_step % self.args.save_steps == 0: - self.save_checkpoint() + self.optimizer.step() + if self.lr_scheduler: + self.lr_scheduler.step() + self._model.zero_grad() + + if self.accelerator.num_processes > 1: + tr_loss_scalar = distributed_concat(tr_loss).mean().item() + else: + tr_loss_scalar = tr_loss.mean().item() + tr_loss -= tr_loss + + self.global_step += 1 + + if self.global_step % self.args.log_steps == 0: + grad_norm = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + if self.rank == 0: + pbar.set_description(f"Loss: {tr_loss_scalar} Global Step: {self.global_step} gradient_norm: {grad_norm}") + print(f"Loss: {tr_loss_scalar} Global Step: {self.global_step} gradient_norm: {grad_norm}") + wandb.log({"training_loss": tr_loss_scalar, "gradient_norm": grad_norm, "global_step": self.global_step}, step=self.global_step) + self.accelerator.log({"training_loss": tr_loss_scalar, "gradient_norm": grad_norm}, step=self.global_step) + if self.global_step % self.args.save_steps == 0: + self.save_checkpoint() self.accelerator.end_training() @@ -145,37 +167,53 @@ def get_redpajama_v2(): streaming=True, ), "raw_content" -def get_ds(): - return get_redpajama_v2() + +def get_ds(dispatch_batches): + """ + this is a janky workaround so it doesn't connect to the dataset server unnecessarily + when using dispatch_batches + """ + if state.is_main_process or not 
dispatch_batches: + return get_redpajama_v2() + else: + with tempfile.NamedTemporaryFile(mode="w+", delete=True) as f: + f.write("text\n") + f.write("lorem ipsum dolor sit amet\n") + # f.writelines(["text", "lorem ipsum dolor sit amet"]) + f.seek(0) + return load_dataset("csv", data_files={"train": f.name}, split="train"), "text" # load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True) def main(): - state = PartialState() torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - config = tiny() + config = small() + dispatch_batches = True + + ds, text_field = get_ds(dispatch_batches) - ds, text_field = get_ds() args = TrainingArguments( - gradient_accumulation_steps=16, + gradient_accumulation_steps=8, max_steps_per_epoch=None, - log_steps=10, + log_steps=1, + adam_epsilon=0.00001, output_dir="./out", weight_decay=0.1, warmup_steps=1000, - per_gpu_train_batch_size=24, + per_gpu_train_batch_size=10, save_steps=1000, max_sequence_length=config.max_position_embeddings, - learning_rate=1e-3, + learning_rate=1e-4, vocab_size=config.vocab_size, + n_gpu=state.num_processes, ) os.makedirs(args.output_dir, exist_ok=True) model = CausalLM(config) - model = model.to(device_get_cuda()) + # model = model.to(device_get_cuda()) # tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base") tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") if not tokenizer.pad_token_id: @@ -198,7 +236,7 @@ def tokenize_function(examples, field="text", tokenizer=None): ds_wrapper_partial, max_tokens=args.max_sequence_length, batch_size=args.per_gpu_train_batch_size, - buffer_size=10_000, + buffer_size=100_000, ) # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") @@ -208,7 +246,9 @@ def tokenize_function(examples, field="text", tokenizer=None): accelerator = Accelerator( # mixed_precision="bf16", log_with=["wandb", "tensorboard"], + project_dir="./runs", gradient_accumulation_steps=args.gradient_accumulation_steps, + dispatch_batches=dispatch_batches, # kwargs_handlers=[ddp_kwargs], ) @@ -216,11 +256,11 @@ def tokenize_function(examples, field="text", tokenizer=None): batch_size=args.per_gpu_train_batch_size, num_workers=1, pin_memory=True, - prefetch_factor=8, + prefetch_factor=1_000, drop_last=True, collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, max_length=True), ) - dataloader = QueuedDataLoader(train_dataset, queue_len=10_000, **dataloader_params) + dataloader = DataLoader(train_dataset, **dataloader_params) trainer = Trainer(model, args, dataloader, accelerator, activation_checkpointing=True) if state.is_main_process: From 0255ffe12e0076f704e2c0368aef6736294f31fe Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Apr 2024 20:50:34 -0400 Subject: [PATCH 34/42] tweak size names --- src/voltronformer/config.py | 13 +++++++------ train.py | 10 +++++++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index 4efe5d7..ff2cf90 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -1,7 +1,7 @@ from src.voltronformer.utils import DictDefault -def tiny(): +def teeny(): """50M parameters""" return DictDefault({ "hidden_size": 512, @@ -22,7 +22,7 @@ def tiny(): }) -def small(): +def tiny(): """300M parameters""" return DictDefault({ "hidden_size": 1024, @@ -43,15 +43,16 @@ def small(): }) -def medium(): +def 
small(): + """1.1B parameters""" return DictDefault({ - "hidden_size": 4096, - "intermediate_size": 11264, + "hidden_size": 2048, + "intermediate_size": 5632, "rope_theta": 10_000, "max_position_embeddings": 8192, "num_attention_heads": 32, "num_key_value_heads": 8, - "num_hidden_layers": 32, + "num_hidden_layers": 24, "vocab_size": 32000, "dwa_dilation": 4, "dwa_period": 5, diff --git a/train.py b/train.py index 24a07f2..9bfefbf 100644 --- a/train.py +++ b/train.py @@ -15,7 +15,7 @@ from transformers import AutoTokenizer, DataCollatorForSeq2Seq from transformers.trainer_pt_utils import distributed_concat -from src.voltronformer.config import tiny, small +from src.voltronformer.config import teeny, tiny, small from src.voltronformer.model import CausalLM, TransformerDecoderBlock from src.voltronformer.train.data import wrap_pretraining_dataset, QueuedDataLoader from src.voltronformer.utils import device_get_cuda, device_get_local_rank, set_activation_checkpointing @@ -146,7 +146,10 @@ def train_loop(self): if self.rank == 0: pbar.set_description(f"Loss: {tr_loss_scalar} Global Step: {self.global_step} gradient_norm: {grad_norm}") print(f"Loss: {tr_loss_scalar} Global Step: {self.global_step} gradient_norm: {grad_norm}") - wandb.log({"training_loss": tr_loss_scalar, "gradient_norm": grad_norm, "global_step": self.global_step}, step=self.global_step) + try: + wandb.log({"training_loss": tr_loss_scalar, "gradient_norm": grad_norm, "global_step": self.global_step}, step=self.global_step) + except: + pass self.accelerator.log({"training_loss": tr_loss_scalar, "gradient_norm": grad_norm}, step=self.global_step) if self.global_step % self.args.save_steps == 0: self.save_checkpoint() @@ -164,6 +167,7 @@ def get_redpajama_v2(): snapshots=["2023-14"], languages=["en"], split="train", + trust_remote_code=True, streaming=True, ), "raw_content" @@ -236,7 +240,7 @@ def tokenize_function(examples, field="text", tokenizer=None): ds_wrapper_partial, max_tokens=args.max_sequence_length, batch_size=args.per_gpu_train_batch_size, - buffer_size=100_000, + buffer_size=40_000, ) # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") From d289d98b1c20e217baaf186a90759dd7766cb43c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 14 Apr 2024 22:42:17 -0400 Subject: [PATCH 35/42] upcast/downcast --- src/voltronformer/bitlinear/attention.py | 2 +- src/voltronformer/bitlinear/official.py | 14 ++++++++------ train.py | 23 +++++++++++++++++++++-- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/voltronformer/bitlinear/attention.py b/src/voltronformer/bitlinear/attention.py index dfd6646..e8c5883 100644 --- a/src/voltronformer/bitlinear/attention.py +++ b/src/voltronformer/bitlinear/attention.py @@ -122,7 +122,7 @@ def scaled_dot_product_gqa( # that they will not contribute to the softmax computation below. 
similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)
- attention = F.softmax(similarity / scale, dim=-1)
+ attention = F.softmax(similarity / scale, dim=-1, dtype=torch.float32).to(dtype=query.dtype)
 if dropout > 0.0:
 attention = F.dropout(attention, p=dropout)
diff --git a/src/voltronformer/bitlinear/official.py b/src/voltronformer/bitlinear/official.py
index 059b3cc..1217900 100644
--- a/src/voltronformer/bitlinear/official.py
+++ b/src/voltronformer/bitlinear/official.py
@@ -3,12 +3,11 @@
 from torch import nn
-def weight_quant(weight, num_bits=1):
- dtype = weight.dtype
- weight = weight.float()
+def weight_quant(weight, dtype=torch.float16):
+ weight = weight.bfloat16()
 s = 1 / weight.abs().mean().clamp(min=1e-5)
 result = (weight * s).round().clamp(-1, 1) / s
- return result.type(dtype)
+ return result.to(dtype=dtype)
 def activation_quant(x, num_bits=8):
@@ -37,9 +36,13 @@ def __init__(self,
 self.input_bits = input_bits
 def forward(self, input):
- quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
- quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()
+ # Convert the weights to the input data type
+ fp_weight = self.weight.to(input.dtype)
+
+ # seems silly, but this is done for the cuda graph's sake
+ quant_weight = fp_weight + (weight_quant(self.weight, dtype=input.dtype) - fp_weight).detach()
+ quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
 out = nn.functional.linear(quant_input, quant_weight)
 if not self.bias is None:
diff --git a/train.py b/train.py
index 9bfefbf..8be8c19 100644
--- a/train.py
+++ b/train.py
@@ -39,6 +39,7 @@ class TrainingArguments:
 vocab_size: Optional[int] = None
 max_grad_norm: Optional[float] = 1.0
 n_gpu: Optional[int] = None
+ bf16: Optional[bool] = False
 class Trainer:
@@ -213,6 +214,7 @@ def main():
 learning_rate=1e-4,
 vocab_size=config.vocab_size,
 n_gpu=state.num_processes,
+ bf16=True,
 )
 os.makedirs(args.output_dir, exist_ok=True)
@@ -247,25 +249,42 @@ def tokenize_function(examples, field="text", tokenizer=None):
 # ddp kwargs with find_unused_parameters needed for triton rmsnorm
 # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator_kwargs = {}
+ if args.bf16:
+ accelerator_kwargs["mixed_precision"] = "bf16"
 accelerator = Accelerator(
- # mixed_precision="bf16",
+ mixed_precision="bf16",
 log_with=["wandb", "tensorboard"],
 project_dir="./runs",
 gradient_accumulation_steps=args.gradient_accumulation_steps,
 dispatch_batches=dispatch_batches,
 # kwargs_handlers=[ddp_kwargs],
+ **accelerator_kwargs,
 )
 dataloader_params = dict(
 batch_size=args.per_gpu_train_batch_size,
 num_workers=1,
 pin_memory=True,
- prefetch_factor=1_000,
+ prefetch_factor=2_000,
 drop_last=True,
 collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, max_length=True),
 )
 dataloader = DataLoader(train_dataset, **dataloader_params)
+ ### float32 casting for improved accuracy
+ if args.bf16:
+ model = model.to(dtype=torch.bfloat16)
+ for name, module in model.named_modules():
+ if "layernorm" in name or name == "ln_f":
+ module.to(torch.float32)
+ elif any(m in name for m in ["wte", "embed_out"]):
+ if hasattr(module, "weight"):
+ module.to(torch.float32)
+ elif "_proj" in name:
+ # module.to(torch.uint8)
+ module.weight.to(torch.float8_e4m3fn)
 trainer = Trainer(model, args, dataloader, accelerator, activation_checkpointing=True)
 if state.is_main_process:
 print(f"Total number of parameters: {trainer.model_num_parameters:_}")

From 84d755ab4d3b8c57d7d1661108147aa3f165479f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: 
Mon, 15 Apr 2024 11:01:11 -0400 Subject: [PATCH 36/42] integrate infini-attention --- src/voltronformer/config.py | 3 + src/voltronformer/infini_attention.py | 107 +++++++++++++++++++++++ src/voltronformer/kernels/activations.py | 16 ++++ src/voltronformer/model.py | 31 +++++-- 4 files changed, 149 insertions(+), 8 deletions(-) create mode 100644 src/voltronformer/infini_attention.py create mode 100644 src/voltronformer/kernels/activations.py diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index ff2cf90..e21279b 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -19,6 +19,7 @@ def teeny(): "mod_capacity_factor": 0.125, "rms_norm_eps": 0.000001, "dwa": True, + "infini_attention": True, }) @@ -40,6 +41,7 @@ def tiny(): "mod_capacity_factor": 0.125, "rms_norm_eps": 0.000001, "dwa": True, + "infini_attention": False, }) @@ -61,4 +63,5 @@ def small(): "mod_capacity_factor": 0.125, "rms_norm_eps": 0.000001, "dwa": True, + "infini_attention": False, }) diff --git a/src/voltronformer/infini_attention.py b/src/voltronformer/infini_attention.py new file mode 100644 index 0000000..b121570 --- /dev/null +++ b/src/voltronformer/infini_attention.py @@ -0,0 +1,107 @@ +import torch +from torch import nn + +from .core import Linear + + +# https://github.com/dingo-actual/infini-transformer/blob/main/infini_transformer/compressive_memory.py + +class CompressiveMemory(nn.Module): + """Implements the Compressive Transformer memory module.""" + def __init__(self, dim_input: int, dim_key: int, dim_value: int, num_heads: int, segment_len: int, update: str = "delta"): + """Initialize module. + + Args: + dim_input (int): Input dimension. + dim_key (int): Key dimension. + dim_value (int): Value dimension. + num_heads (int): Number of attention heads. + segment_len (int): Segment length (must be a factor of the input sequence length). + update (str, optional): Type of memory update rule to use ("linear" or "delta"). Defaults to "delta". + """ + super(CompressiveMemory, self).__init__() + + # Record input parameters + self.num_heads = num_heads + self.segment_len = segment_len + + self.dim_input = dim_input + self.dim_key = dim_key + self.dim_value = dim_value + + self.update = update + + # Projections for stacked SDP attention + self.proj_k = Linear(dim_input, num_heads * dim_key, bias=False) + self.proj_v = Linear(dim_input, num_heads * dim_value, bias=False) + self.proj_q = Linear(dim_input, num_heads * dim_key, bias=False) + + # Initialize betas for weighted average of dot-product and memory-based attention + self.betas = nn.Parameter(torch.randn(1, num_heads, 1, dim_value)) + + # Projection for output + self.proj_out = nn.Linear(num_heads * dim_value, dim_input, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Applies Scaled Dot-Product Attention to the input tensor. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim_input). + Returns: + torch.Tensor: Output tensor of shape (batch_size, seq_len, dim_input). + """ + batch_size, seq_len, _ = x.shape + + n_seq, rem = divmod(seq_len, self.segment_len) + + if rem != 0: + raise ValueError("Sequence length must be divisible by segment length.") + + out = [] + + # Initialize mem and normalization + # !!! 
Initialization was never specified in the paper, so this is an educated guess + mem = torch.zeros(1, self.num_heads, self.dim_key, self.dim_value) + z = torch.zeros(1, self.num_heads, self.dim_value, 1).repeat(batch_size, 1, 1, 1) + + for ix in range(n_seq): + ix_lo = ix * self.segment_len + ix_hi = ix_lo + self.segment_len + + # Extract segment from input + x_seg = x[:, ix_lo:ix_hi, :] + + # Project the input tensor to get the key, value, and query tensors + k = self.proj_k(x_seg).unsqueeze(1).view((batch_size, self.num_heads, self.segment_len, self.dim_key)) + v = self.proj_v(x_seg).unsqueeze(1).view((batch_size, self.num_heads, self.segment_len, self.dim_value)) + q = self.proj_q(x_seg).unsqueeze(1).view((batch_size, self.num_heads, self.segment_len, self.dim_key)) + + # Pre-calculate sigma(q) for updating memory and calculating attention + sigma_q = (nn.functional.elu(q) + 1.0) # shape: (batch_size, num_heads, segment_len, dim_key) + + # Apply mem update + if self.update == "linear": + mem = mem + sigma_q.transpose(-2, -1) @ v + elif self.update == "delta": + sigma_k = nn.functional.elu(k) + 1.0 + mem = mem + sigma_q.transpose(-2, -1) @ (v - (sigma_k @ mem) / (sigma_k @ z)) + + # Apply normalization term update + z = z + (nn.functional.elu(k) + 1.0).sum(dim=-2, keepdim=True) + + # Apply SDP attention + att_dot = nn.functional.softmax(q @ k.transpose(-2, -1) / torch.sqrt(torch.tensor(self.dim_key)), dim=-1) @ v + + # Calculate normalized linear attention + att_mem = (sigma_q @ mem) / (sigma_q @ z) # shape: (batch_size, num_heads, segment_len, dim_value) + + # Calculate weighted average of dot-product and memory-based attention + att = nn.functional.sigmoid(self.betas) * att_mem + (1 - nn.functional.sigmoid(self.betas)) * att_dot + att = att.view((batch_size, self.segment_len, self.num_heads * self.dim_value)) + + # Append output to buffer + out.append(self.proj_out(att)) + + # Return concatenated full sequence from buffer + return torch.concat(out, dim=1) diff --git a/src/voltronformer/kernels/activations.py b/src/voltronformer/kernels/activations.py new file mode 100644 index 0000000..65b1213 --- /dev/null +++ b/src/voltronformer/kernels/activations.py @@ -0,0 +1,16 @@ +import triton +import triton.language as tl + + +@triton.jit +def silu(x): + """ + SiLU activation function, also known as Swish-1. + """ + return x * tl.sigmoid(x) + + +@triton.jit +def silu_grad(x): + sigmoid_x = tl.sigmoid(x) + return sigmoid_x * (1.0 + x * (1.0 - sigmoid_x)) diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 89906f9..8c32532 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -11,6 +11,8 @@ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from .mod import MoDBlock +from .infini_attention import CompressiveMemory as InfiniAttention + try: from apex.normalization import FusedRMSNorm as RMSNorm except ImportError: @@ -63,13 +65,6 @@ def forward(self, x): return x -class Attention(nn.Module): - def __init__(self, hidden_size: int, num_heads: int): - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - - def mlp(dim: int, hidden_dim: int) -> FeedForward: """ Build the MLP layer associated with the Llama model. 
@@ -274,7 +269,27 @@ class TransformerDecoderBlock(nn.Module): def __init__(self, config): super().__init__() - self.attn = LlamaBitMGQA(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, bias=False, layer_norm=False) + if config.infini_attention: + SEGMENT_LEN = 2048 + self.attn = InfiniAttention( + config.hidden_size, + config.num_key_value_heads, + config.num_key_value_heads, + config.num_attention_heads, + SEGMENT_LEN, + update="delta", + ) + else: + self.attn = LlamaBitMGQA( + config.hidden_size, + config.num_attention_heads, + config.num_key_value_heads, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + bias=False, + layer_norm=False, + ) + self.mlp = mlp(config.hidden_size, config.intermediate_size) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) From 665530b40fe825fc87d6bdf5338b54d9b3540825 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 15 Apr 2024 11:09:20 -0400 Subject: [PATCH 37/42] handle position_id if passed, throw it on the floor --- src/voltronformer/infini_attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/voltronformer/infini_attention.py b/src/voltronformer/infini_attention.py index b121570..7ee5e62 100644 --- a/src/voltronformer/infini_attention.py +++ b/src/voltronformer/infini_attention.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from torch import nn @@ -42,7 +44,7 @@ def __init__(self, dim_input: int, dim_key: int, dim_value: int, num_heads: int, # Projection for output self.proj_out = nn.Linear(num_heads * dim_value, dim_input, bias=False) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor: """ Applies Scaled Dot-Product Attention to the input tensor. 
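A quick shape-level smoke test for `CompressiveMemory` as it stands after this patch. This is a minimal sketch, not part of the series: the sizes are illustrative, and it assumes `dim_key == dim_value` (the broadcasting of the `z` normalizer only lines up when the two match). It uses `update="linear"` because the `"delta"` rule divides by the zero-initialized `sigma_k @ z` normalizer on the first segment and yields NaNs, which is plausibly why a later patch in the series flips the default to `"linear"`.

```python
import torch

from src.voltronformer.infini_attention import CompressiveMemory

# illustrative sizes, not the project's configs; dim_key == dim_value is assumed
attn = CompressiveMemory(
    dim_input=512,
    dim_key=64,
    dim_value=64,
    num_heads=8,
    segment_len=256,
    update="linear",  # "delta" hits 0/0 on the first segment with zero-initialized mem/z
)

x = torch.randn(2, 1024, 512)     # (batch, seq_len, dim_input); 1024 = 4 segments of 256
out = attn(x, position_ids=None)  # position_ids is accepted and dropped, per this patch
assert out.shape == x.shape       # memory is rebuilt from zeros on every forward pass
```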
From 2439688aa375a81b5a4f885a9b09f12d0095c502 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 15 Apr 2024 11:20:54 -0400 Subject: [PATCH 38/42] match infini-attention segment len to mixture of depth --- src/voltronformer/config.py | 1 + src/voltronformer/model.py | 3 +-- train.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index e21279b..afc906b 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -20,6 +20,7 @@ def teeny(): "rms_norm_eps": 0.000001, "dwa": True, "infini_attention": True, + "ia_segment_len": 256, # accounts for max_position_embeddings * mod_capacity_factor }) diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 8c32532..4166a38 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -270,13 +270,12 @@ class TransformerDecoderBlock(nn.Module): def __init__(self, config): super().__init__() if config.infini_attention: - SEGMENT_LEN = 2048 self.attn = InfiniAttention( config.hidden_size, config.num_key_value_heads, config.num_key_value_heads, config.num_attention_heads, - SEGMENT_LEN, + config.ia_segment_len, update="delta", ) else: diff --git a/train.py b/train.py index 8be8c19..57a9610 100644 --- a/train.py +++ b/train.py @@ -195,7 +195,7 @@ def main(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - config = small() + config = teeny() dispatch_batches = True ds, text_field = get_ds(dispatch_batches) @@ -253,7 +253,6 @@ def tokenize_function(examples, field="text", tokenizer=None): if args.bf16: accelerator_kwargs["mixed_precision"] = "bf16" accelerator = Accelerator( - mixed_precision="bf16", log_with=["wandb", "tensorboard"], project_dir="./runs", gradient_accumulation_steps=args.gradient_accumulation_steps, From a77e6ae93d9f02e1c0aa2484b9d90f5eeb728eff Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 15 Apr 2024 11:34:01 -0400 Subject: [PATCH 39/42] misc fixes for integrations --- src/voltronformer/infini_attention.py | 8 ++++---- src/voltronformer/mod.py | 2 +- src/voltronformer/model.py | 10 +++++++--- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/voltronformer/infini_attention.py b/src/voltronformer/infini_attention.py index 7ee5e62..76b0474 100644 --- a/src/voltronformer/infini_attention.py +++ b/src/voltronformer/infini_attention.py @@ -42,7 +42,7 @@ def __init__(self, dim_input: int, dim_key: int, dim_value: int, num_heads: int, self.betas = nn.Parameter(torch.randn(1, num_heads, 1, dim_value)) # Projection for output - self.proj_out = nn.Linear(num_heads * dim_value, dim_input, bias=False) + self.proj_out = Linear(num_heads * dim_value, dim_input, bias=False) def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -58,14 +58,14 @@ def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) n_seq, rem = divmod(seq_len, self.segment_len) if rem != 0: - raise ValueError("Sequence length must be divisible by segment length.") + raise ValueError(f"Sequence length must be divisible by segment length. seq_len: {seq_len} segment_len: {self.segment_len}") out = [] # Initialize mem and normalization # !!! 
Initialization was never specified in the paper, so this is an educated guess - mem = torch.zeros(1, self.num_heads, self.dim_key, self.dim_value) - z = torch.zeros(1, self.num_heads, self.dim_value, 1).repeat(batch_size, 1, 1, 1) + mem = torch.zeros(1, self.num_heads, self.dim_key, self.dim_value).to(device=x.device) + z = torch.zeros(1, self.num_heads, self.dim_value, 1).repeat(batch_size, 1, 1, 1).to(device=x.device) for ix in range(n_seq): ix_lo = ix * self.segment_len diff --git a/src/voltronformer/mod.py b/src/voltronformer/mod.py index 53e1a97..2c88d45 100644 --- a/src/voltronformer/mod.py +++ b/src/voltronformer/mod.py @@ -9,7 +9,7 @@ class MoDBlock(nn.Module): def __init__(self, config, block_class): super().__init__() self.config = config - self.block = block_class(config) + self.block = block_class(config, is_mod_wrapped=True) self.router = nn.Linear(config.hidden_size, 1, bias=False) self.capacity_factor = config.mod_capacity_factor self.top_k =int(self.capacity_factor * config.max_position_embeddings) diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py index 4166a38..28e2b41 100644 --- a/src/voltronformer/model.py +++ b/src/voltronformer/model.py @@ -267,15 +267,19 @@ def forward( class TransformerDecoderBlock(nn.Module): - def __init__(self, config): + def __init__(self, config, is_mod_wrapped=False): super().__init__() if config.infini_attention: + if is_mod_wrapped: + seq_len = min(config.ia_segment_len, config.max_position_embeddings * config.mod_capacity_factor) + else: + seq_len = config.ia_segment_len self.attn = InfiniAttention( config.hidden_size, config.num_key_value_heads, config.num_key_value_heads, config.num_attention_heads, - config.ia_segment_len, + seq_len, update="delta", ) else: @@ -327,7 +331,7 @@ def __init__(self, config): self.h = nn.ModuleList([ ( MoDBlock(config, TransformerDecoderBlock) - if i % self.config.mod_every == 0 + if self.config.mod_every and i % self.config.mod_every == 0 else TransformerDecoderBlock(config) ) for i in range(config.num_hidden_layers) From 6a49def054a54716903d5c524ee93ee96dff7f1a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 15 Apr 2024 17:14:36 -0400 Subject: [PATCH 40/42] make infini-attention work --- src/voltronformer/bitlinear/official.py | 8 +++++--- src/voltronformer/config.py | 2 +- src/voltronformer/infini_attention.py | 2 +- src/voltronformer/model.py | 4 ++-- train.py | 3 ++- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/voltronformer/bitlinear/official.py b/src/voltronformer/bitlinear/official.py index 1217900..03323ed 100644 --- a/src/voltronformer/bitlinear/official.py +++ b/src/voltronformer/bitlinear/official.py @@ -1,4 +1,4 @@ -import math +"""Official implementation of the quantized bit-linear""" import torch from torch import nn @@ -23,15 +23,17 @@ def activation_quant(x, num_bits=8): class BitLinear(nn.Linear): def __init__(self, - *kargs, + *args, + eps=1e-5, weight_bits=1, input_bits=8, **kwargs ): - super(BitLinear, self).__init__(*kargs, **kwargs) + super(BitLinear, self).__init__(*args, **kwargs) """ RMSNorm is placed outside BitLinear """ + self.eps = eps self.weight_bits = weight_bits self.input_bits = input_bits diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py index afc906b..5344c6d 100644 --- a/src/voltronformer/config.py +++ b/src/voltronformer/config.py @@ -20,7 +20,7 @@ def teeny(): "rms_norm_eps": 0.000001, "dwa": True, "infini_attention": True, - "ia_segment_len": 256, # accounts for max_position_embeddings * 
mod_capacity_factor
+ "ia_segment_len": 512 # needs to evenly divide max_position_embeddings * mod_capacity_factor
 })
diff --git a/src/voltronformer/infini_attention.py b/src/voltronformer/infini_attention.py
index 76b0474..abae3ba 100644
--- a/src/voltronformer/infini_attention.py
+++ b/src/voltronformer/infini_attention.py
@@ -10,7 +10,7 @@
 class CompressiveMemory(nn.Module):
 """Implements the Compressive Transformer memory module."""
- def __init__(self, dim_input: int, dim_key: int, dim_value: int, num_heads: int, segment_len: int, update: str = "delta"):
+ def __init__(self, dim_input: int, dim_key: int, dim_value: int, num_heads: int, segment_len: int, update: str = "linear"):
 """Initialize module.
 Args:
diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py
index 28e2b41..df374fb 100644
--- a/src/voltronformer/model.py
+++ b/src/voltronformer/model.py
@@ -271,7 +271,7 @@ def __init__(self, config, is_mod_wrapped=False):
 super().__init__()
 if config.infini_attention:
 if is_mod_wrapped:
- seq_len = min(config.ia_segment_len, config.max_position_embeddings * config.mod_capacity_factor)
+ seq_len = min(config.ia_segment_len, int(config.max_position_embeddings * config.mod_capacity_factor))
 else:
 seq_len = config.ia_segment_len
 self.attn = InfiniAttention(
@@ -280,7 +280,7 @@ def __init__(self, config, is_mod_wrapped=False):
 config.num_key_value_heads,
 config.num_attention_heads,
 seq_len,
- update="delta",
+ update="linear",
 )
 else:
diff --git a/train.py b/train.py
index 57a9610..56a7bc3 100644
--- a/train.py
+++ b/train.py
@@ -282,7 +282,8 @@ def tokenize_function(examples, field="text", tokenizer=None):
 module.to(torch.float32)
 elif "_proj" in name:
 # module.to(torch.uint8)
- module.weight.to(torch.float8_e4m3fn)
+ # module.weight.to(torch.float8_e4m3fn)
+ pass

From 555087a2a05e79377231686007b43b0284e5866c Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 15 Apr 2024 19:29:33 -0400
Subject: [PATCH 41/42] fix dimensions passed to infini-attention
---
 src/voltronformer/config.py | 4 +++-
 src/voltronformer/model.py | 4 ++--
 train.py | 29 ++++++++++++++++++++++++-----
 3 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py
index 5344c6d..c951e5d 100644
--- a/src/voltronformer/config.py
+++ b/src/voltronformer/config.py
@@ -20,7 +20,9 @@ def teeny():
 "rms_norm_eps": 0.000001,
 "dwa": True,
 "infini_attention": True,
- "ia_segment_len": 512 # needs to evenly divide max_position_embeddings * mod_capacity_factor
+ "ia_segment_len": 512, # needs to evenly divide max_position_embeddings * mod_capacity_factor
+ "ia_dim_key": 64,
+ "ia_dim_value": 64,
 })
diff --git a/src/voltronformer/model.py b/src/voltronformer/model.py
index df374fb..f10c75b 100644
--- a/src/voltronformer/model.py
+++ b/src/voltronformer/model.py
@@ -276,8 +276,8 @@ def __init__(self, config, is_mod_wrapped=False):
 seq_len = config.ia_segment_len
 self.attn = InfiniAttention(
 config.hidden_size,
- config.num_key_value_heads,
- config.num_key_value_heads,
+ config.ia_dim_key,
+ config.ia_dim_value,
 config.num_attention_heads,
 seq_len,
 update="linear",
 )
diff --git a/train.py b/train.py
index 56a7bc3..b1ea605 100644
--- a/train.py
+++ b/train.py
@@ -40,6 +40,7 @@ class TrainingArguments:
 max_grad_norm: Optional[float] = 1.0
 n_gpu: Optional[int] = None
 bf16: Optional[bool] = False
+ adam_betas: tuple = (0.9, 0.95)
 class Trainer:
@@ -57,8 +58,15 @@ def __init__(self, model, args, dataloader, accelerator, activation_checkpointin
 self.device = device_get_cuda()
 self.global_step = 0
 self.rank = device_get_local_rank()
+
 if accelerator.is_main_process:
- wandb.init(project="voltronformer")
+ report_config = self.args.__dict__
+ report_config["model_num_parameters"] = self.model_num_parameters
+
+ wandb.init(
+ project="voltronformer",
+ config=report_config,
+ )
 self.accelerator = accelerator
 @property
@@ -71,7 +79,7 @@ def model_num_parameters(self):
 return all_param
 def build_optimizer_and_scheduler(self):
- self.optimizer = AdamWScheduleFree(self._model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay, warmup_steps=self.args.weight_decay, eps=self.args.adam_epsilon)
+ self.optimizer = AdamWScheduleFree(self._model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay, warmup_steps=self.args.warmup_steps, eps=self.args.adam_epsilon, betas=self.args.adam_betas)
 self.lr_scheduler = None
 def _loss_fn(self, logits, labels):
@@ -140,6 +148,8 @@ def train_loop(self):
 tr_loss_scalar = tr_loss.mean().item()
 tr_loss -= tr_loss
+ perplexity = torch.exp(tr_loss_scalar)
+
 self.global_step += 1
@@ -148,12 +158,18 @@ def train_loop(self):
 pbar.set_description(f"Loss: {tr_loss_scalar} Global Step: {self.global_step} gradient_norm: {grad_norm}")
 print(f"Loss: {tr_loss_scalar} Global Step: {self.global_step} gradient_norm: {grad_norm}")
 try:
- wandb.log({"training_loss": tr_loss_scalar, "gradient_norm": grad_norm, "global_step": self.global_step}, step=self.global_step)
+ wandb.log({
+ "training_loss": tr_loss_scalar,
+ "gradient_norm": grad_norm,
+ "global_step": self.global_step,
+ "perplexity": perplexity,
+ }, step=self.global_step)
 except:
 pass
 self.accelerator.log({"training_loss": tr_loss_scalar, "gradient_norm": grad_norm}, step=self.global_step)
 if self.global_step % self.args.save_steps == 0:
 self.save_checkpoint()
+ # TODO Freeze DWA after ~5K-10K steps
 self.accelerator.end_training()
@@ -247,8 +263,11 @@ def tokenize_function(examples, field="text", tokenizer=None):
 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
 train_dataset = train_dataset.with_format("torch")
- # ddp kwargs with find_unused_parameters needed for triton rmsnorm
+ kwargs_handlers = []
+ # ddp kwargs with find_unused_parameters needed for RMSNormTriton
 # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ # kwargs_handlers.append(ddp_kwargs)
+
 accelerator_kwargs = {}
 if args.bf16:
 accelerator_kwargs["mixed_precision"] = "bf16"
@@ -257,7 +276,7 @@ def tokenize_function(examples, field="text", tokenizer=None):
 project_dir="./runs",
 gradient_accumulation_steps=args.gradient_accumulation_steps,
 dispatch_batches=dispatch_batches,
- # kwargs_handlers=[ddp_kwargs],
+ kwargs_handlers=kwargs_handlers,
 **accelerator_kwargs,
 )

From b117faa8bb3b2e359e37c1d7e807760e764baaa9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 15 Apr 2024 21:14:52 -0400
Subject: [PATCH 42/42] fix perplexity calculation and add quick instructions
---
 README.md | 12 ++++++++++++
 train.py | 5 +++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index c15da48..5414082 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,18 @@ Assembling the best SotA AI techniques into a unified model
https://twitter.com/winglian/status/1778675583817326842
+## Easy Start
+
+Use the official NVIDIA PyTorch Docker container @ `nvcr.io/nvidia/pytorch:24.03-py3`
+
+```bash
+git clone https://github.com/OpenAccess-AI-Collective/voltronformers.git
+cd voltronformers
+pip install -e .
+accelerate launch train.py
+```
+
+
 # References
 ## BitNet
diff --git a/train.py b/train.py
index b1ea605..fd75690 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,5 @@
 import functools
+import math
 import os
 import tempfile
 from dataclasses import dataclass
@@ -28,6 +29,7 @@ class TrainingArguments:
 gradient_accumulation_steps: int = 1
 max_steps_per_epoch: Optional[int] = None
 log_steps: int = 1
+ adam_betas: tuple = (0.9, 0.95)
 adam_epsilon: Optional[float] = 1e-8
 output_dir: Optional[str] = None
 weight_decay: float = 0.0
@@ -40,7 +42,6 @@ class TrainingArguments:
 max_grad_norm: Optional[float] = 1.0
 n_gpu: Optional[int] = None
 bf16: Optional[bool] = False
- adam_betas: tuple = (0.9, 0.95)
 class Trainer:
@@ -148,7 +149,7 @@ def train_loop(self):
 tr_loss_scalar = tr_loss.mean().item()
 tr_loss -= tr_loss
- perplexity = torch.exp(tr_loss_scalar)
+ perplexity = math.exp(tr_loss_scalar)
 self.global_step += 1
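As a sanity check on the final fix: `tr_loss_scalar` is already a Python float when perplexity is computed, so `math.exp` is the correct call; `torch.exp` expects a tensor and raises a `TypeError` on a plain float. A minimal sketch of the loss-to-perplexity relationship, with toy shapes standing in for the flattened `(batch*seq, vocab_size)` layout used in `train_loop`:

```python
import math

import torch
import torch.nn.functional as F

# toy stand-ins for the flattened shift_logits / shift_labels in train_loop
logits = torch.randn(8, 32000)          # (batch*seq, vocab_size)
labels = torch.randint(0, 32000, (8,))  # (batch*seq,)

loss = F.cross_entropy(logits, labels)  # mean per-token cross-entropy

# perplexity = exp(mean cross-entropy); math.exp works on the reduced scalar
perplexity = math.exp(loss.item())
print(f"loss={loss.item():.4f} perplexity={perplexity:.2f}")
```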