diff --git a/README.md b/README.md index c15da48..5414082 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,18 @@ Assembling the best SotA AI techniques into a unified model https://twitter.com/winglian/status/1778675583817326842 +## Easy Start + +Use the official Nvidia/PyTorch docker container @ `nvcr.io/nvidia/pytorch:24.03-py3` + +```bash +git clone https://github.com/OpenAccess-AI-Collective/voltronformers.git +cd voltronformers +pip install -e . +accelerate launch train.py +``` + + # References ## BitNet diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d3be3fd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "voltronformers" +dynamic = ["version"] +requires-python = ">= 3.10" +dependencies = [ + "accelerate", + "addict", + "bitnet", + "schedulefree", + "bitsandbytes", + "datasets", + "einops", + "flash-attn", + "mosaicml-streaming", + "numba", + "numpy", + "safetensors", + "wandb", + "tqdm", + "transformers==4.39.3", + "zstandard", + "denseformer @ git+https://github.com/epfml/DenseFormer.git@main", +] +maintainers = [ + {name="Wing Lian", email="wing.lian@gmail.com"}, +] +description = "voltronformers: Assembling the best SotA AI techniques into a unified model" + +[project.optional-dependencies] +dev = [ + "tox", + "pre-commit", + "black", + "mypy", + "pytest", +] diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/voltronformer/__init__.py b/src/voltronformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/voltronformer/bitlinear/__init__.py b/src/voltronformer/bitlinear/__init__.py new file mode 100644 index 0000000..81e90cf --- /dev/null +++ b/src/voltronformer/bitlinear/__init__.py @@ -0,0 +1,3 @@ +# from .cg123 import BitLinear +from .official import BitLinear +from .attention import scaled_dot_product_gqa \ No newline at end of file diff --git a/src/voltronformer/bitlinear/attention.py
"""Scaled dot-product attention with grouped-query (GQA) support.

NOTE(review): rewritten with pure torch ops (permute/reshape/einsum) instead of
einops; behavior is unchanged except for the scaling bug fix noted inline.
"""
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def scaled_dot_product_gqa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout: float = 0.0,
    scale: Optional[float] = None,
    mask: Optional[Tensor] = None,
    is_causal: Optional[bool] = None,
    need_weights: bool = False,
    average_attn_weights: bool = False,
    force_grouped: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Scaled dot product attention with support for grouped queries.

    Einstein notation:
        - b: batch size
        - n / s: sequence length
        - h: number of heads
        - g: number of groups
        - d: dimension of query/key/value

    Args:
        query: Query tensor of shape (b, n, h, d).
        key: Key tensor of shape (b, s, h, d).
        value: Value tensor of shape (b, s, h, d).
        dropout: Dropout probability applied to the attention weights.
        scale: Divisor applied to the query (default: d_query ** 0.5).
        mask: Boolean mask of shape (b, s), (b, n, s) or (b, h, n, s); False
            positions are masked out. Mutually exclusive with `is_causal`.
        is_causal: If True, apply a lower-triangular (causal) mask.
        need_weights: If True, also return the attention weights.
        average_attn_weights: If True, average returned weights over heads.
        force_grouped: Apply the grouped-query path even when the number of
            query and key/value heads is equal.

    Returns:
        2-tuple of:
        - Attention output with shape (b, n, h, d).
        - Attention weights with shape (b, n, s, h) (or (b, s, h) when
          averaged), only when `need_weights` is True; otherwise None.

    Raises:
        ValueError: on inconsistent shapes or when both `mask` and
            `is_causal` are given.
    """
    if (mask is not None) and (is_causal is not None):
        raise ValueError(
            "Only one of 'mask' and 'is_causal' should be provided, but got both."
        )
    if not query.ndim == key.ndim == value.ndim == 4:
        raise ValueError(
            f"Expected query, key, and value to be 4-dimensional, but got shapes "
            f"{query.shape}, {key.shape}, and {value.shape}."
        )

    # Move the sequence dimension to axis 2: (b, n, h, d) -> (b, h, n, d).
    # This makes the attention contractions below much faster.
    query = query.permute(0, 2, 1, 3)
    key = key.permute(0, 2, 1, 3)
    value = value.permute(0, 2, 1, 3)

    bq, hq, nq, dq = query.shape
    bk, hk, nk, dk = key.shape
    bv, hv, nv, dv = value.shape
    if not (bq == bk == bv and dq == dk == dv):
        raise ValueError(
            "Expected query, key, and value to have the same batch size (dim=0) and "
            f"embedding dimension (dim=3), but got query: {query.shape}, "
            f"key: {key.shape}, and value: {value.shape}."
        )
    if (hk != hv) or (nk != nv):
        raise ValueError(
            "Expected key and value to have the same size in dimensions 1 and 2, but "
            f"got key: {key.shape} and value: {value.shape}."
        )
    if hq % hk != 0:
        raise ValueError(
            "Expected query heads to be a multiple of key/value heads, but got "
            f"query: {query.shape} and key/value: {key.shape}."
        )

    if scale is None:
        scale = query.size(-1) ** 0.5
    # BUG FIX: the scale must be applied exactly once. The previous version
    # divided the query here AND divided the similarity by `scale` again inside
    # the softmax, effectively scaling the logits by d instead of sqrt(d).
    query = query / scale

    num_head_groups = hq // hk
    if num_head_groups > 1 or force_grouped:
        # Split the query heads into `num_head_groups` chunks:
        # (b, (h g), n, d) -> (b, g, h, n, d), then contract over g and d so
        # each group of query heads shares its key/value heads.
        grouped = query.reshape(bq, hk, num_head_groups, nq, dq).transpose(1, 2)
        similarity = torch.einsum("bghnd,bhsd->bhns", grouped, key)
    else:
        # Equal head counts: standard scaled dot-product similarity.
        similarity = torch.einsum("bhnd,bhsd->bhns", query, key)

    if is_causal:
        # Lower-triangular mask prevents attending to future tokens.
        mask = torch.ones((bq, nq, nk), device=query.device, dtype=torch.bool).tril_()

    if mask is not None:
        # Broadcast the mask over heads (and, for 2D masks, over query rows).
        if mask.ndim == 2:
            mask = mask[:, None, None, :]
        elif mask.ndim == 3:
            mask = mask[:, None, :, :]
        # Masked logits get the most negative finite value so they contribute
        # ~0 to the softmax.
        similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)

    attention = F.softmax(similarity, dim=-1, dtype=torch.float32).to(dtype=query.dtype)
    if dropout > 0.0:
        attention = F.dropout(attention, p=dropout)

    # Weighted sum over the key/value sequence, then restore (b, n, h, d).
    out = torch.einsum("bhns,bhsd->bhnd", attention, value)
    out = out.permute(0, 2, 1, 3)

    attn_weights: Optional[Tensor] = None
    if need_weights:
        # (b, h, n, s) -> (b, n, s, h) to mirror the output layout.
        attn_weights = attention.permute(0, 2, 3, 1)
        if average_attn_weights:
            attn_weights = attn_weights.mean(dim=1)

    return out, attn_weights
"The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" + +References: +- https://arxiv.org/abs/2310.11453 +- https://arxiv.org/abs/2402.17764 +""" + +#!/usr/bin/env python3 +# Copyright (C) 2024 Charles O. Goddard + +import math +from typing import NamedTuple, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _ste(x: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: + """Straight-through estimator.""" + return x0 + (x - x0).detach() + + +@torch.compile() +def _quantize( + x: Optional[torch.Tensor], is_input: bool, num_groups: int, eps: float +) -> Tuple[torch.Tensor, torch.Tensor]: + if x is None: + return None, None + + x0 = x + if is_input: + # split last dimension into num_groups + x = x.view(list(x.shape[:-1]) + [num_groups, -1]) + scale_factor = x.abs().max(dim=-1, keepdim=True).values + else: + # first dimension is output features, so split that + x = x.view([num_groups, -1] + list(x.shape[1:])) + scale_factor = x.abs().mean(dim=list(range(1, len(x.shape))), keepdim=True) + + x_scaled = x / (scale_factor + eps) + if is_input: + x_q = (x_scaled * 127).clamp(-127, 127).to(torch.int8) + else: + x_q = x_scaled.round().clamp(-1, 1).to(torch.int8) + + # adjust scale_factor to match shape returned for input + scale_factor = scale_factor.view(1, 1, num_groups, 1) + + return _ste(x_q, x_scaled).view_as(x0), scale_factor + + +class QuantizedWeights(NamedTuple): + """Quantized weight and optional bias tensor for BitLinear.""" + + w_q: torch.Tensor + bias_q: Optional[torch.Tensor] + beta: torch.Tensor + + +@torch.compile() +def _quantize_weights( + weight: torch.Tensor, + bias: Optional[torch.Tensor], + num_groups: int, + eps: float, +) -> QuantizedWeights: + w_q, beta = _quantize(weight, is_input=False, num_groups=num_groups, eps=eps) + bias_q, _ = _quantize(bias, is_input=True, num_groups=num_groups, eps=eps) + # bias assumes the scale factor of weights + return QuantizedWeights(w_q=w_q, bias_q=bias_q, 
beta=beta) + + +def _pack_ternary(x: torch.Tensor) -> torch.Tensor: + """Pack ternary float tensor into int8 tensor. Uses ~1.6 bits per element.""" + + x_packed = torch.empty( + x.shape[:-1] + (math.ceil(x.shape[-1] / 5),), dtype=torch.int8 + ) + for i in range(0, x.shape[-1], 5): + chunk = x[..., i : i + 5].to(torch.int8).view(x.shape[:-1] + (1, 5)) + # -1 -> 0, 0 -> 1, 1 -> 2 + chunk = chunk + 1 + # store as base-3 number + chunk = ( + chunk + * torch.tensor([1, 3, 9, 27, 81], device=chunk.device, dtype=chunk.dtype) + ).sum(dim=-1) + x_packed[..., i // 5] = chunk + return x_packed + + +class BitLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + *args, + preserve_scale: bool = False, + num_groups: int = 1, + eps: float = 1e-7, + bias: bool = False, + **kwargs, + ): + if num_groups < 1: + raise ValueError("num_groups must be >= 1") + if num_groups > 1 and out_features % num_groups != 0: + raise ValueError("out_features must be divisible by num_groups") + + super().__init__(in_features, out_features, *args, bias=bias, **kwargs) + self.input_norm = nn.LayerNorm(self.in_features, elementwise_affine=False) + self.preserve_scale = preserve_scale + self.num_groups = num_groups + self.eps = eps + + @torch.compile() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.input_norm(x) + x_q, gamma = _quantize( + x, is_input=True, num_groups=self.num_groups, eps=self.eps + ) + w_q, bias_q, beta = _quantize_weights( + self.weight, self.bias, num_groups=self.num_groups, eps=self.eps + ) + + y = F.linear(x_q, w_q, bias_q) + y = y.to(x.dtype) / 127 + if self.preserve_scale: + y_grouped = y.view(list(y.shape[:-1]) + [self.num_groups, -1]) + y = (y_grouped * gamma * beta).reshape_as(y) + + return y + + +class BitConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + *args, + preserve_scale: bool = False, + eps: float = 1e-7, + bias: bool = False, + **kwargs, + ): + 
super().__init__( + in_channels, out_channels, kernel_size, *args, bias=bias, **kwargs + ) + self.input_norm = nn.GroupNorm(1, self.in_channels, affine=False) + self.preserve_scale = preserve_scale + self.eps = eps + + @torch.compile() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.input_norm(x) + x_q, gamma = _quantize(x, is_input=True, num_groups=1, eps=self.eps) + w_q, bias_q, beta = _quantize_weights( + self.weight, self.bias, num_groups=1, eps=self.eps + ) + + y = F.conv2d(x_q, w_q, bias_q, self.stride, self.padding, self.dilation) + y = y.to(x.dtype) / 127 + if self.preserve_scale: + y_grouped = y.view(list(y.shape[:-1]) + [1, -1]) + y = (y_grouped * gamma * beta).reshape_as(y) + + return y \ No newline at end of file diff --git a/src/voltronformer/bitlinear/official.py b/src/voltronformer/bitlinear/official.py new file mode 100644 index 0000000..03323ed --- /dev/null +++ b/src/voltronformer/bitlinear/official.py @@ -0,0 +1,52 @@ +"""Official implementation of the quantized bit-linear""" +import torch +from torch import nn + + +def weight_quant(weight, dtype=torch.float16): + weight = weight.bfloat16() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) / s + return result.to(dtype=dtype) + + +def activation_quant(x, num_bits=8): + dtype = x.dtype + x = x.float() + Qn = -2 ** (num_bits - 1) + Qp = 2 ** (num_bits - 1) - 1 + s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * s).round().clamp(Qn, Qp) / s + return result.type(dtype) + + +class BitLinear(nn.Linear): + + def __init__(self, + *args, + eps=1e-5, + weight_bits=1, + input_bits=8, + **kwargs + ): + super(BitLinear, self).__init__(*args, **kwargs) + """ + RMSNorm is placed outside BitLinear + """ + self.eps = eps + self.weight_bits = weight_bits + self.input_bits = input_bits + + def forward(self, input): + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + # 
Convert the uint8 weights to the input data type + fp_weight = self.weight.to(input.dtype) + + # seems silly, but this is done for the cuda graph's sake + quant_weight = fp_weight + (weight_quant(self.weight, dtype=input.dtype) - fp_weight).detach() + + out = nn.functional.linear(quant_input, quant_weight) + if not self.bias is None: + out += self.bias.view(1, -1).expand_as(out) + + return out \ No newline at end of file diff --git a/src/voltronformer/config.py b/src/voltronformer/config.py new file mode 100644 index 0000000..c951e5d --- /dev/null +++ b/src/voltronformer/config.py @@ -0,0 +1,70 @@ +from src.voltronformer.utils import DictDefault + + +def teeny(): + """50M parameters""" + return DictDefault({ + "hidden_size": 512, + "intermediate_size": 1408, + "rope_theta": 10_000, + "max_position_embeddings": 2048, + "num_attention_heads": 16, + "num_key_value_heads": 4, + "num_hidden_layers": 12, + "vocab_size": 32000, + "dwa_dilation": 4, + "dwa_period": 5, + "pad_token_id": 2, + "mod_every": 2, + "mod_capacity_factor": 0.125, + "rms_norm_eps": 0.000001, + "dwa": True, + "infini_attention": True, + "ia_segment_len": 512, # needs to be evenly divide max_position_embeddings * mod_capacity_factor + "ia_dim_key": 64, + "ia_dim_value": 64, + }) + + +def tiny(): + """300M parameters""" + return DictDefault({ + "hidden_size": 1024, + "intermediate_size": 2816, + "rope_theta": 10_000, + "max_position_embeddings": 4096, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 24, + "vocab_size": 32000, + "dwa_dilation": 4, + "dwa_period": 5, + "pad_token_id": 2, + "mod_every": 2, + "mod_capacity_factor": 0.125, + "rms_norm_eps": 0.000001, + "dwa": True, + "infini_attention": False, + }) + + +def small(): + """1.1B parameters""" + return DictDefault({ + "hidden_size": 2048, + "intermediate_size": 5632, + "rope_theta": 10_000, + "max_position_embeddings": 8192, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 24, + 
"vocab_size": 32000, + "dwa_dilation": 4, + "dwa_period": 5, + "pad_token_id": 2, + "mod_every": 2, + "mod_capacity_factor": 0.125, + "rms_norm_eps": 0.000001, + "dwa": True, + "infini_attention": False, + }) diff --git a/src/voltronformer/core.py b/src/voltronformer/core.py new file mode 100644 index 0000000..da162da --- /dev/null +++ b/src/voltronformer/core.py @@ -0,0 +1,4 @@ +try: + from .bitlinear import BitLinear as Linear +except ImportError: + from torch.nn import Linear diff --git a/src/voltronformer/infini_attention.py b/src/voltronformer/infini_attention.py new file mode 100644 index 0000000..abae3ba --- /dev/null +++ b/src/voltronformer/infini_attention.py @@ -0,0 +1,109 @@ +from typing import Optional + +import torch +from torch import nn + +from .core import Linear + + +# https://github.com/dingo-actual/infini-transformer/blob/main/infini_transformer/compressive_memory.py + +class CompressiveMemory(nn.Module): + """Implements the Compressive Transformer memory module.""" + def __init__(self, dim_input: int, dim_key: int, dim_value: int, num_heads: int, segment_len: int, update: str = "linear"): + """Initialize module. + + Args: + dim_input (int): Input dimension. + dim_key (int): Key dimension. + dim_value (int): Value dimension. + num_heads (int): Number of attention heads. + segment_len (int): Segment length (must be a factor of the input sequence length). + update (str, optional): Type of memory update rule to use ("linear" or "delta"). Defaults to "linear".
+ """ + super(CompressiveMemory, self).__init__() + + # Record input parameters + self.num_heads = num_heads + self.segment_len = segment_len + + self.dim_input = dim_input + self.dim_key = dim_key + self.dim_value = dim_value + + self.update = update + + # Projections for stacked SDP attention + self.proj_k = Linear(dim_input, num_heads * dim_key, bias=False) + self.proj_v = Linear(dim_input, num_heads * dim_value, bias=False) + self.proj_q = Linear(dim_input, num_heads * dim_key, bias=False) + + # Initialize betas for weighted average of dot-product and memory-based attention + self.betas = nn.Parameter(torch.randn(1, num_heads, 1, dim_value)) + + # Projection for output + self.proj_out = Linear(num_heads * dim_value, dim_input, bias=False) + + def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Applies Scaled Dot-Product Attention to the input tensor. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim_input). + Returns: + torch.Tensor: Output tensor of shape (batch_size, seq_len, dim_input). + """ + batch_size, seq_len, _ = x.shape + + n_seq, rem = divmod(seq_len, self.segment_len) + + if rem != 0: + raise ValueError(f"Sequence length must be divisible by segment length. seq_len: {seq_len} segment_len: {self.segment_len}") + + out = [] + + # Initialize mem and normalization + # !!! 
Initialization was never specified in the paper, so this is an educated guess + mem = torch.zeros(1, self.num_heads, self.dim_key, self.dim_value).to(device=x.device) + z = torch.zeros(1, self.num_heads, self.dim_value, 1).repeat(batch_size, 1, 1, 1).to(device=x.device) + + for ix in range(n_seq): + ix_lo = ix * self.segment_len + ix_hi = ix_lo + self.segment_len + + # Extract segment from input + x_seg = x[:, ix_lo:ix_hi, :] + + # Project the input tensor to get the key, value, and query tensors + k = self.proj_k(x_seg).unsqueeze(1).view((batch_size, self.num_heads, self.segment_len, self.dim_key)) + v = self.proj_v(x_seg).unsqueeze(1).view((batch_size, self.num_heads, self.segment_len, self.dim_value)) + q = self.proj_q(x_seg).unsqueeze(1).view((batch_size, self.num_heads, self.segment_len, self.dim_key)) + + # Pre-calculate sigma(q) for updating memory and calculating attention + sigma_q = (nn.functional.elu(q) + 1.0) # shape: (batch_size, num_heads, segment_len, dim_key) + + # Apply mem update + if self.update == "linear": + mem = mem + sigma_q.transpose(-2, -1) @ v + elif self.update == "delta": + sigma_k = nn.functional.elu(k) + 1.0 + mem = mem + sigma_q.transpose(-2, -1) @ (v - (sigma_k @ mem) / (sigma_k @ z)) + + # Apply normalization term update + z = z + (nn.functional.elu(k) + 1.0).sum(dim=-2, keepdim=True) + + # Apply SDP attention + att_dot = nn.functional.softmax(q @ k.transpose(-2, -1) / torch.sqrt(torch.tensor(self.dim_key)), dim=-1) @ v + + # Calculate normalized linear attention + att_mem = (sigma_q @ mem) / (sigma_q @ z) # shape: (batch_size, num_heads, segment_len, dim_value) + + # Calculate weighted average of dot-product and memory-based attention + att = nn.functional.sigmoid(self.betas) * att_mem + (1 - nn.functional.sigmoid(self.betas)) * att_dot + att = att.view((batch_size, self.segment_len, self.num_heads * self.dim_value)) + + # Append output to buffer + out.append(self.proj_out(att)) + + # Return concatenated full sequence from buffer 
+ return torch.concat(out, dim=1) diff --git a/src/voltronformer/kernels/activations.py b/src/voltronformer/kernels/activations.py new file mode 100644 index 0000000..65b1213 --- /dev/null +++ b/src/voltronformer/kernels/activations.py @@ -0,0 +1,16 @@ +import triton +import triton.language as tl + + +@triton.jit +def silu(x): + """ + SiLU activation function, also known as Swish-1. + """ + return x * tl.sigmoid(x) + + +@triton.jit +def silu_grad(x): + sigmoid_x = tl.sigmoid(x) + return sigmoid_x * (1.0 + x * (1.0 - sigmoid_x)) diff --git a/src/voltronformer/kernels/rms_norm.py b/src/voltronformer/kernels/rms_norm.py new file mode 100644 index 0000000..e7c2951 --- /dev/null +++ b/src/voltronformer/kernels/rms_norm.py @@ -0,0 +1,91 @@ +import torch +import triton +import triton.language as tl +from torch import nn + + +# from https://ai.lefebvre-sarrut.eu/2023/07/20/deep-dive-into-kernel-fusion-accelerating-inference-in-llama-v2/#openai-triton-rewriting +@triton.jit +def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr, + stride_x_batch, stride_x_m, stride_x_k, + stride_rms_w, + stride_out_batch, stride_out_m, stride_out_k, + N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr): + pid_batch = tl.program_id(0) + pid_m = tl.program_id(1) + + offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m + block_N = tl.arange(0, BLOCK_N_SIZE) + var = tl.zeros((BLOCK_N_SIZE,), tl.float32) + + # first loop over input tensor to compute the root mean of the square + for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE): + offs_n = block_n_start_idx + block_N + x_ptr_mask = offs_n < N_SIZE + # recompute address at each iteration + x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0) + var += tl.math.pow(x.to(tl.float32), 2) + + # we keep this reduction operation outside the loop for perf reasons + var = tl.sum(var, axis=0) / N_SIZE + rstd = tl.math.rsqrt(var + eps) + + # apply the normalization and multiply by RMS weights + for 
block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE): + offs_n = block_n_start_idx + block_N + x_ptr_mask = offs_n < N_SIZE + rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask) + + x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32) + x_hat = x * rstd + out = x_hat * rms_w + out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k + tl.store(output_ptr + out_off, out, mask=x_ptr_mask) + + +class RMSNorm(nn.Module): + """copied from torchtune""" + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fp32 = x.float() + x_normed = ( + x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) + ).type_as(x) + return x_normed * self.scale + + +"""not ready for use yet. 2X Faster, but not accurate""" +class RMSNormTriton(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Define the grid and block dimensions + N_SIZE = x.shape[-1] + BLOCK_N_SIZE = 512 # Adjust this value based on your requirements + + # Allocate output tensor + output = torch.empty_like(x) + + # Define the strides for input, scale, and output tensors + stride_x_batch, stride_x_m, stride_x_k = x.stride() + stride_rms_w = self.scale.stride(0) + stride_out_batch, stride_out_m, stride_out_k = output.stride() + + # Launch the Triton kernel + grid = lambda meta: (x.shape[0], x.shape[1]) + rmsnorm_triton[grid]( + x, self.scale, output, + stride_x_batch, stride_x_m, stride_x_k, + stride_rms_w, + stride_out_batch, stride_out_m, stride_out_k, + N_SIZE, self.eps, BLOCK_N_SIZE + ) + + return output \ No newline at end of file diff --git a/src/voltronformer/mod.py b/src/voltronformer/mod.py new file mode 100644 index 
"""
from https://github.com/epfml/llm-baselines/compare/main...mixture_of_depth
"""
import torch
from torch import nn


class MoDBlock(nn.Module):
    """Mixture-of-Depths wrapper.

    Scores every token with a 1-dim router, runs only the highest-scoring
    tokens through the wrapped block, and adds the router-weighted block
    output back onto the untouched residual stream.
    """

    def __init__(self, config, block_class):
        super().__init__()
        self.config = config
        # The wrapped transformer block; it is told it runs under MoD routing.
        self.block = block_class(config, is_mod_wrapped=True)
        # One scalar routing score per token.
        self.router = nn.Linear(config.hidden_size, 1, bias=False)
        self.capacity_factor = config.mod_capacity_factor
        self.top_k = int(self.capacity_factor * config.max_position_embeddings)

    def forward(self, x, position_ids, **kwargs):
        # x: [batch_size, sequence_length, n_embd]
        batch, seq_len, dim = x.shape

        # Inference-time optimization: the sequence can be shorter than the
        # training-time maximum, so cap the routed-token budget accordingly.
        budget = min(self.top_k, int(self.capacity_factor * seq_len))

        # Step 1: score tokens and keep the `budget` highest-scoring ones.
        scores = self.router(x)  # [batch, seq_len, 1]
        picked_scores, picked_idx = torch.topk(scores, budget, dim=1, sorted=False)
        # IMPORTANT: sort the indices so the routed tokens keep causal order
        # inside the wrapped block; carry the scores along with the sort.
        picked_idx, order = torch.sort(picked_idx, dim=1)
        picked_scores = torch.gather(picked_scores, dim=1, index=order)

        # Step 2: gather the routed tokens ([batch, budget, dim]) and process
        # the reduced-length sequence with the wrapped block.
        gather_idx = picked_idx.expand(-1, -1, dim)
        routed = torch.gather(x, 1, gather_idx)
        processed = self.block(routed, position_ids, **kwargs)

        # Step 3: scatter the router-weighted block output back onto the
        # residual stream; unrouted positions pass through unchanged.
        return torch.scatter_add(
            x,
            dim=1,
            index=gather_idx,
            src=processed * picked_scores,
        )
0000000..f10c75b --- /dev/null +++ b/src/voltronformer/model.py @@ -0,0 +1,390 @@ +import functools +from typing import Optional, Callable, Tuple + +import torch +from .bitlinear import BitLinear, scaled_dot_product_gqa + +from functorch.einops import rearrange +from torch import nn, Tensor +from denseformer import DWAModules +from torch.utils.checkpoint import checkpoint +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + +from .mod import MoDBlock +from .infini_attention import CompressiveMemory as InfiniAttention + +try: + from apex.normalization import FusedRMSNorm as RMSNorm +except ImportError: + from .kernels.rms_norm import RMSNorm + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class FeedForward(nn.Module): + def __init__(self, gate_proj: BitLinear, down_proj: BitLinear, up_proj: BitLinear): + super().__init__() + self.gate_proj = gate_proj + self.down_proj = down_proj + self.up_proj = up_proj + self.act_fn = nn.SiLU() + + def forward(self, x): + x = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + # FIXME layernorm??? + x = self.down_proj(x) + return x + + +def mlp(dim: int, hidden_dim: int) -> FeedForward: + """ + Build the MLP layer associated with the Llama model. + """ + gate_proj = BitLinear(dim, hidden_dim, bias=False) + down_proj = BitLinear(hidden_dim, dim, bias=False) + up_proj = BitLinear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) + + +# copied from https://github.com/kyegomez/BitNet/blob/main/bitnet/bit_attention.py +class LlamaBitMGQA(nn.Module): + """Multi-head grouped query attention (GQA) layer. + + Reference: + "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" + https://arxiv.org/pdf/2305.13245v1.pdf + + GQA is a variant of multihead attention (MHA) that uses fewer write heads + (key / value) than query heads. GQA can be viewed as a generalization of + multi-query attention (MQA), which uses a single write head. GQA and MQA give + significant speedups over standard MHA in decoder layers, with minimal loss in + accuracy. In the paper, GQA is shown to be more accurate than MQA, while still + having a significant speedup over MHA. + + NOTE: The original authors only benchmark GQA by adapting the T5 (XL or XXL) model + from MHA to GQA. As a result, they do not mention parameter initialization or + layer normalization strategies. 
I follow the best practices laid out in the + MAGNETO paper, which improves Transformer performance through better parameter + initialization and layer norm placement. See: + https://arxiv.org/pdf/2210.06423.pdf, Fig. 2 + """ + + def __init__( + self, + embed_dim: int, + query_heads: int = 8, + kv_heads: int = 4, + dropout: float = 0.1, + bias: bool = True, + layer_norm: bool = True, + layer_norm_eps: float = 1e-5, + gamma_init: float = 1.0, + linear_groups: int = 1, + *args, + max_position_embeddings=2048, + rope_theta=10_000, + **kwargs, + ): + super().__init__() + self.query_heads = query_heads + self.kv_heads = kv_heads + self.dropout = dropout + self.layer_norm = layer_norm + self.gamma_init = gamma_init + + if self.query_heads % self.kv_heads != 0: + raise ValueError( + f"query_heads ({query_heads}) must be divisible by " + f"kv_heads ({kv_heads})" + ) + elif (embed_dim % self.query_heads != 0) or (embed_dim % self.kv_heads != 0): + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"query_heads ({query_heads}) and kv_heads ({kv_heads})" + ) + + head_dim = embed_dim // query_heads + if not head_dim % 8 == 0: + raise ValueError( + f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8" + ) + if not head_dim <= 128: + raise ValueError( + f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128" + ) + + # Query projection layer is the same as in vanilla MHA. + self.q_proj = BitLinear( + embed_dim, + embed_dim, + bias=bias, + *args, + **kwargs, # device=device, dtype=dtype + ) + # Key/value projection layers have a smaller output dimension, so that + # the we have fewer key/value attention heads after reshaping. 
+ kv_embed_dim = embed_dim // query_heads * kv_heads + self.k_proj = BitLinear( + embed_dim, + kv_embed_dim, + bias=bias, + *args, + **kwargs, # device=device, dtype=dtype + ) + self.v_proj = BitLinear( + embed_dim, + kv_embed_dim, + bias=bias, + *args, + **kwargs, # device=device, dtype=dtype + ) + self.norm: Optional[nn.LayerNorm] = None + if layer_norm: + self.norm = nn.LayerNorm( + kv_embed_dim, + eps=layer_norm_eps, # device=device, dtype=dtype + ) + # Grouped attention output will have the same embedding dimension as the + # key/value Tensors. So the output projection layer needs to accept the + # same dimension (kv_embed_dim). + self.out_proj = BitLinear( + embed_dim, + embed_dim, + bias=bias, # device=device, dtype=dtype + ) + self.rotary_emb = LlamaRotaryEmbedding(head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta) + + self._reset_parameters() + + def _reset_parameters(self): + nn.init.xavier_normal_(self.q_proj.weight) + if self.q_proj.bias is not None: + nn.init.constant_(self.q_proj.bias, 0) + nn.init.xavier_normal_(self.k_proj.weight) + if self.k_proj.bias is not None: + nn.init.constant_(self.k_proj.bias, 0) + + # NOTE: We follow the initialization strategy from MAGNETO. See: + # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2 + # Gain (self.gamma_init) should be provided as a keyword argument when + # initializing the larger Transformer model, since it requires knowledge + # of the number of encoder/decoder layers in the model. 
+ + nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init) + if self.v_proj.bias is not None: + nn.init.constant_(self.v_proj.bias, 0) + nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0) + + def forward( + self, + x: Tensor, + position_ids: Optional[Tensor] = None, + need_weights: bool = False, + # attn_mask: Optional[Tensor] = None, + is_causal: bool = True, + average_attn_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + + # Input shape: (b, n, d) + q: Tensor = self.q_proj(x) + k: Tensor = self.k_proj(x) + v: Tensor = self.v_proj(x) + + # Unfold 'd' dimension into 'h' separate attention heads. + q = rearrange(q, "b n (h d) -> b h n d", h=self.query_heads) + k = rearrange(k, "b n (h d) -> b h n d", h=self.kv_heads) + v = rearrange(v, "b n (h d) -> b h n d", h=self.kv_heads) + + # Generate rotary embeddings + cos, sin = self.rotary_emb(x, position_ids) + + # Reshape cos and sin to match the shape of q and k + seq_len = q.shape[2] # Get the sequence length from q + cos = cos[:, :seq_len, :].unsqueeze(1) + sin = sin[:, :seq_len, :].unsqueeze(1) + + # Apply rotary position embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) + + # Adjust the dimensions of q, k, and v + q = q.view(-1, *q.shape[-3:]) + k = k.view(-1, *k.shape[-3:]) + v = v.view(-1, *v.shape[-3:]) + + # Apply attention, then fold 'h' attention heads back into 'd'. + output, attn_weights = scaled_dot_product_gqa( + query=q, + key=k, + value=v, + # TODO + # mask=attn_mask, + is_causal=is_causal, + need_weights=need_weights, + average_attn_weights=average_attn_weights, + force_grouped=False, + ) + + # Re-assemble all head outputs side-by-side. + # output = output.transpose(1, 2).contiguous().view(b, n, d) + output = rearrange(output, "b n h d -> b h (n d)") + + # Linear projection on attention outputs. 
+ output = self.out_proj(output) + + return output, attn_weights + + +class TransformerDecoderBlock(nn.Module): + + def __init__(self, config, is_mod_wrapped=False): + super().__init__() + if config.infini_attention: + if is_mod_wrapped: + seq_len = min(config.ia_segment_len, int(config.max_position_embeddings * config.mod_capacity_factor)) + else: + seq_len = config.ia_segment_len + self.attn = InfiniAttention( + config.hidden_size, + config.ia_dim_key, + config.ia_dim_value, + config.num_attention_heads, + seq_len, + update="linear", + ) + else: + self.attn = LlamaBitMGQA( + config.hidden_size, + config.num_attention_heads, + config.num_key_value_heads, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + bias=False, + layer_norm=False, + ) + + self.mlp = mlp(config.hidden_size, config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, x, position_ids): + residual = x + h = self.input_layernorm(x) + output = residual + self.attn(h, position_ids=position_ids)[0] + return residual + self.mlp(self.post_attention_layernorm(output)) + + +class CheckpointingMixin(nn.Module): + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": False} + + gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + + def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): + if hasattr(self, "gradient_checkpointing"): + self._gradient_checkpointing_func = gradient_checkpointing_func + self.gradient_checkpointing = enable + + +class Transformer(CheckpointingMixin): + 
supports_gradient_checkpointing = True + + def __init__(self, config): + super().__init__() + self.config = config + if config.dwa: + self.dwa_modules = DWAModules(config.num_hidden_layers, config.dwa_dilation, config.dwa_period) + self.wte = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + + self.h = nn.ModuleList([ + ( + MoDBlock(config, TransformerDecoderBlock) + if self.config.mod_every and i % self.config.mod_every == 0 + else TransformerDecoderBlock(config) + ) + for i in range(config.num_hidden_layers) + ]) + self.ln_f = RMSNorm(config.hidden_size, eps=1e-6) + self.gradient_checkpointing = False + + def forward(self, x): + inputs_embeds = self.wte(x) + past_seen_tokens = 0 + position_ids = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ).unsqueeze(0) + + hidden_states = inputs_embeds + if self.config.dwa: + self.dwa_modules.init_accumulators(hidden_states) + for i, decoder_layer in enumerate(self.h): + # gradient checkpointing + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + decoder_layer, + hidden_states, + position_ids, + ) + else: + hidden_states = decoder_layer(hidden_states, position_ids) + if self.config.dwa: + hidden_states = self.dwa_modules(hidden_states, block_idx=i) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class CausalLM(nn.Module): + def __init__(self, config): + super().__init__() + self.transformer = Transformer(config) + self.vocab_size = config.vocab_size + # should this use a BitLinear layer? + self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # tie weights + self.transformer.wte.weight = self.embed_out.weight + + def forward(self, x): + x = self.transformer(x) + logits = self.embed_out(x) + + return logits.float() + + def train(self, mode: bool = True): + """ + Override the default train() to enable gradient checkpointing. 
+ """ + if mode: + self.transformer.gradient_checkpointing_enable() + return super().train(mode) diff --git a/src/voltronformer/train/__init__.py b/src/voltronformer/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/voltronformer/train/collators.py b/src/voltronformer/train/collators.py new file mode 100644 index 0000000..24a5462 --- /dev/null +++ b/src/voltronformer/train/collators.py @@ -0,0 +1,155 @@ +""" +DataCollator to pad labels and position_ids for packed sequences +""" +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +from transformers import PreTrainedTokenizerBase +from transformers.utils import PaddingStrategy + +IGNORE_INDEX = -100 + + +@dataclass +class DataCollatorForSeq2Seq: + """ + Data collator that will dynamically pad the inputs received, as well as the labels and position_ids + + Args: + tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): + The tokenizer used for encoding the data. + model ([`PreTrainedModel`]): + The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to + prepare the *decoder_input_ids* + + This is useful when using *label_smoothing* to avoid calculating loss twice. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + + - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). 
+ max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + label_pad_token_id (`int`, *optional*, defaults to -100): + The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). + return_tensors (`str`): + The type of Tensor to return. Allowable values are "np", "pt" and "tf". + """ + + tokenizer: PreTrainedTokenizerBase + model: Optional[Any] = None + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + position_pad_token_id: int = 0 + return_tensors: str = "pt" + + def __call__(self, features, return_tensors=None): + labels = None + if return_tensors is None: + return_tensors = self.return_tensors + + for feature_name, pad_token_id in [ + ("labels", self.label_pad_token_id), + ("position_ids", self.position_pad_token_id), + ]: + feat = ( + [feature[feature_name] for feature in features] + if feature_name in features[0].keys() + else None + ) + labels = feat if feat and feature_name == "labels" else labels + # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the + # same length to return tensors. 
+ if feat is not None: + max_feature_length = max(len(l) for l in feat) # noqa: E741 + if self.pad_to_multiple_of is not None: + max_feature_length = ( + (max_feature_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + + padding_side = self.tokenizer.padding_side + for feature in features: + remainder = [pad_token_id] * ( + max_feature_length - len(feature[feature_name]) + ) + if isinstance(feature[feature_name], list): + feature[feature_name] = ( + feature[feature_name] + remainder + if padding_side == "right" + else remainder + feature[feature_name] + ) + elif padding_side == "right": + feature[feature_name] = np.concatenate( + [feature[feature_name], remainder] + ).astype(np.int64) + else: + feature[feature_name] = np.concatenate( + [remainder, feature[feature_name]] + ).astype(np.int64) + + features = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=return_tensors, + ) + + # prepare decoder_input_ids + if ( + labels is not None + and self.model is not None + and hasattr(self.model, "prepare_decoder_input_ids_from_labels") + ): + decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels( + labels=features["labels"] + ) + features["decoder_input_ids"] = decoder_input_ids + + return features + + +@dataclass +class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): + """ + Collator for multipack specific to the using the BatchSampler + """ + + def __init__(self, *args, multipack_attn=True, **kwargs): + super().__init__(*args, **kwargs) + self.multipack_attn = multipack_attn + + def __call__(self, features, return_tensors=None): + chunked_data = {} + for feature in features.keys(): + if feature == "length": + continue + if feature == "attention_mask": + if self.multipack_attn: + arrays = [ + (i + 1) * np.array(item[feature]) + for i, item in enumerate(features[feature]) + if feature in 
item + ] + else: + arrays = [(1) * np.array(item) for item in features[feature]] + chunked_data[feature] = np.concatenate(arrays) + else: + arrays = [np.array(item) for item in features[feature]] + chunked_data[feature] = np.concatenate(arrays) + features = [chunked_data] + return super().__call__(features, return_tensors=return_tensors) diff --git a/src/voltronformer/train/data.py b/src/voltronformer/train/data.py new file mode 100644 index 0000000..e80ff0f --- /dev/null +++ b/src/voltronformer/train/data.py @@ -0,0 +1,146 @@ +import functools +from collections import defaultdict +from queue import Queue +from threading import Thread +from typing import Callable, Dict, List + +import numpy as np +from datasets import Dataset +from torch.utils.data import RandomSampler, DataLoader + +from src.voltronformer.train.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq +from src.voltronformer.train.samplers import MultipackBatchSampler + + +def get_dataset_lengths(dataset: Dataset): + input_ids = dataset.data.column("input_ids") + lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) + return lengths + + +def wrap_pretraining_dataset( + dataset, + tokenizer, + ds_wrapper_fn, + max_tokens=2048, + batch_size=1, + buffer_size=10_000, +): + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding=True, + pad_to_multiple_of=max_tokens, + multipack_attn=False, + ) + encode = functools.partial( + encode_packed_pretraining, + collate_fn, + ds_wrapper_fn, + max_seq_length=max_tokens, + batch_size=batch_size, + ) + + # remove all the existing columns after mapping since they end up having + # a different length than the encoded/tokenized column + # this is empty during streaming/pretraining + remove_columns = [] + if dataset.features is None: + for first_row in dataset: + remove_columns = first_row.keys() + break + else: + remove_columns = dataset.features.keys() + + dataset = dataset.map( + encode, + 
batched=True, + batch_size=buffer_size, + remove_columns=remove_columns, + ) + return dataset + + +def drop_long_seq(sample, sequence_len=2048): + return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0 + + +def encode_packed_pretraining( + collate_fn, + ds_wrapper: Callable, + examples: Dict[str, List], + max_seq_length: int = 2048, + batch_size: int = 4, +) -> Dict[str, List]: + # pylint: disable=duplicate-code + # tokenize all the examples + # rows get split with stride (overlap) + train_dataset = Dataset.from_dict(examples) + train_dataset = train_dataset.map( + ds_wrapper, + batched=True, + remove_columns = list(train_dataset.features.keys()) + ) + + drop_long = functools.partial(drop_long_seq, sequence_len=max_seq_length) + train_dataset = train_dataset.filter( + drop_long, + num_proc=8, + ) + + sampler = MultipackBatchSampler( + RandomSampler(train_dataset), + batch_size=batch_size, + drop_last=True, + batch_max_len=max_seq_length, + lengths=get_dataset_lengths(train_dataset), + ) + + chunked_data = defaultdict(list) + + for batch in sampler: + for data in batch: + features = train_dataset[data] + if "num_truncated_tokens" in features: + del features["num_truncated_tokens"] + if "num_truncated_tokens" in features: + del features["num_truncated_tokens"] + if "overflow_to_sample_mapping" in features: + del features["overflow_to_sample_mapping"] + if "labels" not in features: + features["labels"] = features["input_ids"].copy() + collated_features = collate_fn(features) + + for feature in features.keys(): + if feature == "length": + continue + chunked_data[feature].append(collated_features[feature].squeeze(0)) + + return chunked_data + + +class QueuedDataLoader(DataLoader): + def __init__(self, *args, queue_len=1_000, **kwargs): + kwargs["persistent_workers"] = True + super().__init__(*args, **kwargs) + self.data_queue = Queue(maxsize=queue_len) + self.prefetch_thread = Thread(target=self.prefetch_data) + self.prefetch_thread.daemon = 
True + self.prefetch_thread.start() + + def prefetch_data(self): + for data in super().__iter__(): + self.data_queue.put(data) + self.data_queue.put(None) + + def __iter__(self): + return super().__iter__() + + def __next__(self): + if hasattr(self, 'data_queue'): + data = self.data_queue.get() + if data is None: + raise StopIteration + return data + else: + return self._iterator.__next__() diff --git a/src/voltronformer/train/samplers.py b/src/voltronformer/train/samplers.py new file mode 100644 index 0000000..be92f7f --- /dev/null +++ b/src/voltronformer/train/samplers.py @@ -0,0 +1,202 @@ +# pylint: skip-file +""" +Multipack Batch Sampler +""" +import logging +import math +import os +from typing import Any, Iterable, List, Union + +import numba +import numpy as np +from torch.utils.data import BatchSampler, Sampler + +LOG = logging.getLogger("multipack") + + +@numba.njit +def ffd_check(a: np.ndarray, c: int, n: int): + # First-fit-decreasing bin packing + # Check if a[] could fit in n bins with capacity c + # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing + + a = np.sort(a)[::-1] + bins = np.full((n,), c, dtype=a.dtype) + for size in a: + not_found = True + for idx in range(n): + if bins[idx] >= size: + bins[idx] -= size + not_found = False + break + + if not_found: + return False + + return True + + +@numba.njit +def ffd_with_result(a: np.ndarray, c: int, start_index: int): + # First-fit-decreasing bin packing (with result return) + + indices = np.argsort(a)[::-1] + a = a[indices] + + bins: List[Any] = [] + bins_result: List[Any] = [] + for a_id, size in enumerate(a): + add_new = True + for idx in range(len(bins)): + if bins[idx] >= size: + bins[idx] -= size + bins_result[idx].append(indices[a_id] + start_index) + add_new = False + break + + if add_new: + bins.append(c - size) + bins_result.append([indices[a_id] + start_index]) + + return bins_result + + +@numba.njit +def allocate( + lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: 
int, n: int +): + # Dynamic batch allocator, similar to Multifit + # https://en.wikipedia.org/wiki/Multifit_algorithm + # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) + + s = 0 + start_index = 0 + result = [] + + while True: + # binary search [l, r) + left = 1 + right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") + + while right - left > 1: + mid = (left + right) // 2 + if ffd_check(lengths[start_index : start_index + mid], c, n): + left = mid + else: + right = mid + + # use length l + batch = ffd_with_result( + lengths[start_index : start_index + left], c, start_index + ) + assert len(batch) <= n + if len(batch) < n: + break + + start_index += left + s = lengths_cumsum[start_index - 1] + + # add local rank + result.append(batch[rank]) + + return result, s, len(result) * c * n + + +class MultipackBatchSampler(BatchSampler): + """ + Batch Sampler class for multipack + """ + + def __init__( + self, + sampler: Union[Sampler[int], Iterable[int]], + batch_size: int, + drop_last: bool, + batch_max_len: int, + lengths: np.ndarray, + packing_efficiency_estimate: float = 1.0, + ): + super().__init__(sampler, batch_size, drop_last) + self.batch_size = batch_size + self.batch_max_len = batch_max_len + self.lengths: np.ndarray = lengths + self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 + + assert isinstance(self.lengths, np.ndarray) + + self.epoch = 0 + + # statistics + self.eff_total_used = 0 + self.eff_total_slots = 0 + + def set_epoch(self, epoch: int): + self.epoch = epoch + + def generate_batches(self, set_stats=False): + indices = [idx for idx in self.sampler] + + lengths = self.lengths[indices] + lengths_cumsum = np.cumsum(lengths) + + batches, total_used, total_slots = allocate( + lengths=lengths, + lengths_cumsum=lengths_cumsum, + rank=0, + c=self.batch_max_len, + n=1, + ) + + batches = [ + [ + [indices[b_idx] for b_idx in batch] + for batch in batches[i : i + self.batch_size] + ] + for i in range(0, 
len(batches), self.batch_size) + ] + + # statistics + if set_stats: + self.eff_total_used += total_used + self.eff_total_slots += total_slots + + return batches + + def __iter__(self): + batches = self.generate_batches(set_stats=True) + return iter(batches) + + def num_batches(self): + batches = self.generate_batches(set_stats=True) + return len(batches) + + def efficiency(self): + return self.eff_total_used / self.eff_total_slots + + def __len__(self): + self.num_batches() + return self._len_est() + + def _len_est(self): + world_size = int(os.getenv("WORLD_SIZE", "1")) + lengths_sum = np.sum(self.lengths) + lengths_sum_per_device = lengths_sum // world_size + LOG.info( + f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"total_num_tokens per device: {lengths_sum_per_device}" + ) + + # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler + return max( + 0, + ( + world_size + * math.floor( + 0.99 + * lengths_sum_per_device + / self.packing_efficiency_estimate + // (self.batch_max_len * self.batch_size) + ) + - 1 + ), + ) diff --git a/src/voltronformer/utils.py b/src/voltronformer/utils.py new file mode 100644 index 0000000..c1bc178 --- /dev/null +++ b/src/voltronformer/utils.py @@ -0,0 +1,72 @@ +import math +import os +from typing import Optional, Set, Type + +import torch + +from addict import Dict +from torch import nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, +) +from torch.distributed.fsdp.wrap import ModuleWrapPolicy + + +def set_activation_checkpointing( + model: nn.Module, auto_wrap_policy: Optional[Set[Type[nn.Module]]] = None, **kwargs +) -> None: + """Utility to setup activation checkpointing and wrap the model for checkpointing. + + Args: + model (nn.Module): Model to setup activation checkpointing. + auto_wrap_policy (Optional[Set[nn.Module]]): Policy to wrap module. 
+ **kwargs: additional arguments to pass to torch.distributed activation checkpointing. + """ + wrap_policy = ModuleWrapPolicy(auto_wrap_policy or set()) + apply_activation_checkpointing(model, auto_wrap_policy=wrap_policy, **kwargs) + + +def device_get_local_rank(): + """ + Returns the local rank of the current device. + """ + local_rank = int(os.getenv("LOCAL_RANK", 0)) + return local_rank + + +def device_get_cuda(): + rank = device_get_local_rank() + device = torch.device(type="cuda", index=rank) + torch.cuda.set_device(device) + return device + + +def get_cosine_schedule_with_min_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + min_lr_ratio: float, +): + # Warm up + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + # Cosine learning rate decay + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + scaling = 0.5 * (1.0 + math.cos(math.pi * progress)) + return (1 - min_lr_ratio) * scaling + min_lr_ratio + + +class DictDefault(Dict): + """ + A Dict that returns None instead of returning empty Dict for missing keys. 
+ """ + + def __missing__(self, key): + return None + + def __or__(self, other): + return DictDefault(super().__ror__(other)) diff --git a/train.py b/train.py new file mode 100644 index 0000000..fd75690 --- /dev/null +++ b/train.py @@ -0,0 +1,315 @@ +import functools +import math +import os +import tempfile +from dataclasses import dataclass +from typing import Optional + +import torch +import wandb +from accelerate import Accelerator, PartialState, DistributedDataParallelKwargs +from datasets import load_dataset +from safetensors.torch import save_model +from schedulefree import AdamWScheduleFree +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoTokenizer, DataCollatorForSeq2Seq +from transformers.trainer_pt_utils import distributed_concat + +from src.voltronformer.config import teeny, tiny, small +from src.voltronformer.model import CausalLM, TransformerDecoderBlock +from src.voltronformer.train.data import wrap_pretraining_dataset, QueuedDataLoader +from src.voltronformer.utils import device_get_cuda, device_get_local_rank, set_activation_checkpointing + + +state = PartialState() + +@dataclass +class TrainingArguments: + gradient_accumulation_steps: int = 1 + max_steps_per_epoch: Optional[int] = None + log_steps: int = 1 + adam_betas: tuple = (0.9, 0.95) + adam_epsilon: Optional[float] = 1e-8 + output_dir: Optional[str] = None + weight_decay: float = 0.0 + warmup_steps: Optional[int] = 1000 + per_gpu_train_batch_size: Optional[int] = 1 + save_steps: Optional[int] = 5_000 + max_sequence_length: Optional[int] = 8192 + learning_rate: float = 5e-5 + vocab_size: Optional[int] = None + max_grad_norm: Optional[float] = 1.0 + n_gpu: Optional[int] = None + bf16: Optional[bool] = False + + +class Trainer: + def __init__(self, model, args, dataloader, accelerator, activation_checkpointing=True): + self.args = args + self._model = model + if activation_checkpointing: + set_activation_checkpointing( + model, 
auto_wrap_policy={TransformerDecoderBlock} + ) + self.build_optimizer_and_scheduler() + + self._model, self.dataloader, self.optimizer = accelerator.prepare(self._model, dataloader, self.optimizer) + + self.device = device_get_cuda() + self.global_step = 0 + self.rank = device_get_local_rank() + + if accelerator.is_main_process: + report_config = self.args.__dict__ + report_config["model_num_parameters"] = self.model_num_parameters + + wandb.init( + project="voltronformer", + config=report_config, + ) + self.accelerator = accelerator + + @property + def model_num_parameters(self): + all_param = 0 + for _, param in self._model.named_parameters(): + num_params = param.numel() + all_param += num_params + + return all_param + + def build_optimizer_and_scheduler(self): + self.optimizer = AdamWScheduleFree(self._model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay, warmup_steps=self.args.weight_decay, eps=self.args.adam_epsilon, betas=self.args.adam_betas) + self.lr_scheduler = None + + def _loss_fn(self, logits, labels): + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1)) + return loss + + def save_checkpoint(self): + output_dir = self.args.output_dir if self.args.output_dir is not None else "." 
+ save_model(self._model, os.path.join(output_dir, f"model_{self.global_step}.safetensors")) + torch.save( + self._model.state_dict(), + os.path.join(output_dir, f"model_{self.global_step}.pt"), + ) + + def train(self): + self._model.train() + try: + self.optimizer.train() + except: + pass + self.train_loop() + + def train_loop(self): + tr_loss = torch.tensor(0.0).to(self.device) + total_batched_samples = 0 + for idx, batch in enumerate(pbar := tqdm(self.dataloader, disable=not (self.rank == 0))): + total_batched_samples += 1 + is_grad_accum_step = total_batched_samples % self.args.gradient_accumulation_steps == 0 + with self.accelerator.accumulate(self._model): + input_ids = batch["input_ids"].to(self.device) + if "labels" in batch.keys(): + labels = batch["labels"].to(self.device) + else: + labels = input_ids.clone() + + logits = self._model(input_ids) + + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, self.args.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + + # Compute loss + loss = self._loss_fn(shift_logits, shift_labels) + if self.args.n_gpu > 1: + loss = loss.mean() + self.accelerator.backward(loss) + mini_step_loss = loss.detach() / self.args.gradient_accumulation_steps + tr_loss += mini_step_loss + + if is_grad_accum_step: + grad_norm = self.accelerator.clip_grad_norm_(self._model.parameters(), self.args.max_grad_norm) + + self.optimizer.step() + if self.lr_scheduler: + self.lr_scheduler.step() + self._model.zero_grad() + + if self.accelerator.num_processes > 1: + tr_loss_scalar = distributed_concat(tr_loss).mean().item() + else: + tr_loss_scalar = tr_loss.mean().item() + tr_loss -= tr_loss + + perplexity = math.exp(tr_loss_scalar) + + self.global_step += 1 + + if self.global_step % self.args.log_steps == 0: + grad_norm = 
grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + if self.rank == 0: + pbar.set_description(f"Loss: {tr_loss_scalar} Global Step: {self.global_step} gradient_norm: {grad_norm}") + print(f"Loss: {tr_loss_scalar} Global Step: {self.global_step} gradient_norm: {grad_norm}") + try: + wandb.log({ + "training_loss": tr_loss_scalar, + "gradient_norm": grad_norm, + "global_step": self.global_step, + "perplexity": perplexity, + }, step=self.global_step) + except: + pass + self.accelerator.log({"training_loss": tr_loss_scalar, "gradient_norm": grad_norm}, step=self.global_step) + if self.global_step % self.args.save_steps == 0: + self.save_checkpoint() + # TODO Freeze DWA after ~5K-10K steps + + self.accelerator.end_training() + + +def get_redpajama_v1(): + return load_dataset("togethercomputer/RedPajama-Data-1T", "common_crawl", split="train", streaming=True), "text" + +def get_redpajama_v2(): + return load_dataset("togethercomputer/RedPajama-Data-V2", + name="default", + partition="head_middle", + snapshots=["2023-14"], + languages=["en"], + split="train", + trust_remote_code=True, + streaming=True, + ), "raw_content" + + +def get_ds(dispatch_batches): + """ + this is a janky workaround so it doesn't connect to the dataset server unnecessarily + when using dispatch_batches + """ + if state.is_main_process or not dispatch_batches: + return get_redpajama_v2() + else: + with tempfile.NamedTemporaryFile(mode="w+", delete=True) as f: + f.write("text\n") + f.write("lorem ipsum dolor sit amet\n") + # f.writelines(["text", "lorem ipsum dolor sit amet"]) + f.seek(0) + return load_dataset("csv", data_files={"train": f.name}, split="train"), "text" + + # load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True) + + +def main(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + config = teeny() + dispatch_batches = True + + ds, text_field = get_ds(dispatch_batches) + + args = TrainingArguments( + 
gradient_accumulation_steps=8, + max_steps_per_epoch=None, + log_steps=1, + adam_epsilon=0.00001, + output_dir="./out", + weight_decay=0.1, + warmup_steps=1000, + per_gpu_train_batch_size=10, + save_steps=1000, + max_sequence_length=config.max_position_embeddings, + learning_rate=1e-4, + vocab_size=config.vocab_size, + n_gpu=state.num_processes, + bf16=True, + ) + os.makedirs(args.output_dir, exist_ok=True) + + model = CausalLM(config) + # model = model.to(device_get_cuda()) + # tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base") + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + if not tokenizer.pad_token_id: + tokenizer.pad_token_id = tokenizer.eos_token_id + + def tokenize_function(examples, field="text", tokenizer=None): + outputs = tokenizer(examples[field], truncation=True, max_length=config.max_position_embeddings) + return outputs + + with state.main_process_first(): + ds_wrapper_partial = functools.partial( + tokenize_function, + tokenizer=tokenizer, + field=text_field, + ) + + train_dataset = wrap_pretraining_dataset( + ds, + tokenizer, + ds_wrapper_partial, + max_tokens=args.max_sequence_length, + batch_size=args.per_gpu_train_batch_size, + buffer_size=40_000, + ) + # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 + train_dataset = train_dataset.with_format("torch") + + kwargs_handlers =[] + # ddp kwargs with find_unused_parameters needed for RMSNormTriton + # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + # kwargs_handlers.append(ddp_kwargs) + + accelerator_kwargs = {} + if args.bf16: + accelerator_kwargs["mixed_precision"] = "bf16" + accelerator = Accelerator( + log_with=["wandb", "tensorboard"], + project_dir="./runs", + gradient_accumulation_steps=args.gradient_accumulation_steps, + dispatch_batches=dispatch_batches, + kwargs_handlers=kwargs_handlers, + **accelerator_kwargs, + ) + + 
dataloader_params = dict( + batch_size=args.per_gpu_train_batch_size, + num_workers=1, + pin_memory=True, + prefetch_factor=2_000, + drop_last=True, + collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, max_length=True), + ) + dataloader = DataLoader(train_dataset, **dataloader_params) + + ### float32 casting for improved accuracy + if args.bf16: + model = model.to(dtype=torch.bfloat16) + for name, module in model.named_modules(): + if "layernorm" in name or name == "ln_f": + module.to(torch.float32) + elif any(m in name for m in ["wte", "embed_out"]): + if hasattr(module, "weight"): + module.to(torch.float32) + elif "_proj" in name: + # module.to(torch.uint8) + # module.weight.to(torch.float8_e4m3fn) + pass + + trainer = Trainer(model, args, dataloader, accelerator, activation_checkpointing=True) + if state.is_main_process: + print(f"Total number of parameters: {trainer.model_num_parameters:_}") + trainer.train() + + +if __name__ == "__main__": + main() \ No newline at end of file