diff --git a/README.md b/README.md
index adb59e1..a59f84c 100755
--- a/README.md
+++ b/README.md
@@ -133,6 +133,18 @@ parser.add_argument('--compile', action='store_true') # if true then model is co
 parser.add_argument('--rmsnorm_eps', default=1e-5, type=float) # used by the llama model
 parser.add_argument('--multiple_of', default=256, type=int) # used by the llama model make SwiGLU hidden layer size multiple of large power of 2
 parser.add_argument('--n_kv_head', default=None, type=int) # for Adam-mini
+parser.add_argument('--moe', action='store_true')
+parser.add_argument('--moe_routing', default='standard_gating', type=str, choices=['standard_gating', 'expert_choice'],)
+parser.add_argument('--moe_num_experts', default=8, type=int)
+parser.add_argument('--capacity_factor', default=2.0, type=float) # only used for expert choice routing
+parser.add_argument('--moe_num_shared_experts', default=0, type=int) # deepseek routing, experts that are always active
+parser.add_argument('--moe_router_loss', default='load_balancing_z_loss', type=str, choices=['entropy', 'load_balancing_only', 'load_balancing_z_loss'],)
+parser.add_argument('--moe_num_experts_per_tok', default=2, type=int)
+parser.add_argument('--moe_entropy_loss_factor', default=0.01, type=float)
+parser.add_argument('--moe_aux_loss_factor', default=0.1, type=float)
+parser.add_argument('--moe_z_loss_factor', default=0.01, type=float)
+parser.add_argument('--moe_softmax_order', type=str, default='topk_softmax', choices=['softmax_topk', 'topk_softmax'],)
+parser.add_argument('--plot_router_logits', action='store_true')
 # Checkpointing
 parser.add_argument('--results_base_folder', default='./exps', type=str)
 parser.add_argument('--permanent_ckpt_interval', default=0, type=int)
diff --git a/src/config/base.py b/src/config/base.py
index b7e77d2..dcde7cd 100644
--- a/src/config/base.py
+++ b/src/config/base.py
@@ -260,5 +260,36 @@ def parse_args(base_parser, args, namespace):
     parser.add_argument("--bias", default=False, type=bool)
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--mlp_dim_exp_factor", default=1.0, type=float)
+    parser.add_argument("--moe", action="store_true")
+    parser.add_argument(
+        "--moe_routing",
+        default="standard_gating",
+        type=str,
+        choices=["standard_gating", "expert_choice"],
+    )
+    parser.add_argument("--moe_num_experts", default=8, type=int)
+    parser.add_argument(  # only used for expert choice routing
+        "--capacity_factor", default=2.0, type=float
+    )
+    parser.add_argument(  # deepseek routing, experts that are always active
+        "--moe_num_shared_experts", default=0, type=int
+    )
+    parser.add_argument(
+        "--moe_router_loss",
+        default="load_balancing_z_loss",
+        type=str,
+        choices=["entropy", "load_balancing_only", "load_balancing_z_loss"],
+    )
+    parser.add_argument("--moe_num_experts_per_tok", default=2, type=int)
+    parser.add_argument("--moe_entropy_loss_factor", default=0.01, type=float)
+    parser.add_argument("--moe_aux_loss_factor", default=0.1, type=float)
+    parser.add_argument("--moe_z_loss_factor", default=0.01, type=float)
+    parser.add_argument(
+        "--moe_softmax_order",
+        type=str,
+        default="topk_softmax",
+        choices=["softmax_topk", "topk_softmax"],
+    )
+    parser.add_argument("--plot_router_logits", action="store_true")

     return parser.parse_args(args, namespace)
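
# Illustrative sketch, not part of the patch: how the new MoE flags relate to each
# other, using a hypothetical argparse.Namespace filled with the defaults added above.
# Which flags are read depends on the routing mode: token-choice ("standard_gating")
# uses --moe_num_experts_per_tok, expert choice uses --capacity_factor, and
# --moe_num_shared_experts adds DeepSeek-style always-active experts on top.
from argparse import Namespace

moe_defaults = Namespace(
    moe=True,                                  # enable MoE feed-forward blocks
    moe_routing="standard_gating",             # or "expert_choice"
    moe_num_experts=8,
    moe_num_experts_per_tok=2,                 # top-k for standard gating
    capacity_factor=2.0,                       # only used by expert choice routing
    moe_num_shared_experts=0,                  # experts that are always active
    moe_router_loss="load_balancing_z_loss",   # which auxiliary losses to train with
    moe_aux_loss_factor=0.1,
    moe_z_loss_factor=0.01,
    moe_entropy_loss_factor=0.01,
    moe_softmax_order="topk_softmax",          # or "softmax_topk"
    plot_router_logits=False,                  # log routing visualizations to wandb
)
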
diff --git a/src/main.py b/src/main.py
index 410fb8e..b38c29d 100755
--- a/src/main.py
+++ b/src/main.py
@@ -735,6 +735,7 @@ def get_exp_name(
         "device",
         "adema_beta3_warmup",
         "adema_alpha_warmup",
+        "plot_router_logits",
     ],
 ):
     # Get the default values
@@ -747,6 +748,8 @@ def get_exp_name(
     for key in key_args:
         if hasattr(args, key):
             value = getattr(args, key)
+            if key == "model" and hasattr(args, "moe") and args.moe:
+                value = f"moe_{value}"
             prefix_parts.append(f"{key}-{value}")

     prefix = "_".join(prefix_parts)
diff --git a/src/models/base.py b/src/models/base.py
index 3f2dba3..ee7b7f5 100755
--- a/src/models/base.py
+++ b/src/models/base.py
@@ -14,6 +14,9 @@
 import torch.nn as nn
 from torch.nn import functional as F

+from models.moe import (ExpertChoiceMoE, MoE, entropy_reg, load_balancing_loss,
+                        router_z_loss)
+

 class LayerNorm(nn.Module):
     """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
@@ -113,7 +116,7 @@ def forward(self, x):
         x = self.activation(x)
         x = self.c_proj(x)
         x = self.dropout(x)
-        return x
+        return x, {}


 class Block(nn.Module):
@@ -124,20 +127,32 @@ def __init__(self, config):
         self.parallel = config.parallel_block
         if not self.parallel:
             self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
-        self.mlp = MLP(config)
+        if config.moe:
+            if config.moe_routing == "standard_gating":
+                self.mlp = MoE(config, MLP)
+            elif config.moe_routing == "expert_choice":
+                self.mlp = ExpertChoiceMoE(config, MLP)
+            elif config.moe_routing == "soft_moe":
+                self.mlp = SoftMoE(config, MLP)
+            elif config.moe_routing == "tree":
+                self.mlp = TreeRouter(config, MLP)
+            else:
+                raise ValueError(f"Unknown routing: {config.moe_routing}")
+        else:
+            self.mlp = MLP(config)

     def forward(self, x, *args, **kwargs):
         if self.parallel:
             # from GPT-J 6B https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L299
             x_ln = self.ln_1(x, *args, **kwargs)
             x_attn = self.attn(x_ln)
-            x_ffn = self.mlp(x_ln)
+            x_ffn, logits_and_experts = self.mlp(x_ln)
             x = x + x_attn + x_ffn
         else:
             x = x + self.attn(self.ln_1(x, *args, **kwargs))
-            x_ = self.mlp(self.ln_2(x, *args, **kwargs))
+            x_, logits_and_experts = self.mlp(self.ln_2(x, *args, **kwargs))
             x = x + x_
-        return x
+        return x, logits_and_experts


 class GPTBase(nn.Module):
@@ -177,6 +192,37 @@ def __init__(self, config):
                     mean=0.0,
                     std=self.config.init_std / math.sqrt(2 * config.n_layer),
                 )
+            if pn.endswith("router.weight"):
+                # special scaled init for the MoE router
+                with torch.no_grad():
+                    dim = 1 if config.moe_routing == "standard_gating" else 0
+                    std = p.std()
+                    p.div_(p.sum(dim=dim, keepdim=True))
+                    p.mul_(std / p.std())
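
# Illustrative sketch, not part of the patch: what the router-weight init above does,
# shown on a standalone tensor. For standard gating the router weight has shape
# [num_experts, n_embd]; each row (dim=1) is normalized to sum to 1, and the matrix
# is then rescaled so its overall std matches the pre-normalization std.
import torch

w = torch.randn(8, 64) * 0.02        # stand-in for router.weight
std = w.std()
w = w / w.sum(dim=1, keepdim=True)   # each expert's row sums to 1
w = w * (std / w.std())              # restore the original scale
print(w.std().item(), std.item())    # equal by construction
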
+
+    def get_router_losses(self, logits, selected_experts, eval=False):
+        # logits: (b * seq_len, n_experts)
+        # selected_experts: (b * seq_len, topk)
+        if eval:  # eval mode, compute all losses
+            return {
+                "moe_entropy_loss": entropy_reg(logits),
+                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
+                "moe_z_loss": router_z_loss(logits),
+            }
+        if self.config.moe_router_loss == "entropy":
+            return {
+                "moe_entropy_loss": entropy_reg(logits),
+            }
+        elif self.config.moe_router_loss == "load_balancing_only":
+            return {
+                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
+            }
+        elif self.config.moe_router_loss == "load_balancing_z_loss":
+            return {
+                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
+                "moe_z_loss": router_z_loss(logits),
+            }
+        return {}

     def get_num_params(self, non_embedding=True):
         """
@@ -198,7 +244,7 @@ def _init_weights(self, module):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)

-    def forward(self, idx, targets=None, get_logits=False):
+    def forward(self, idx, targets=None, get_logits=False, moe=False):
         device = idx.device
         b, t = idx.size()
         assert (
@@ -214,17 +260,42 @@ def forward(self, idx, targets=None, get_logits=False):
         )  # position embeddings of shape (1, t, n_embd)
         x = self.transformer.drop(tok_emb + pos_emb)

+        # router logits is a list for each layer's routing, each of shape (b * seq_len, n_experts)
+        router_logits = []
+        # experts is a list for each layer's selected experts, shape (b * seq_len, topk)
+        experts = []
+
+        # forward pass through all the transformer blocks
         for block in self.transformer.h:
-            x = block(x)
+            x, logits_and_experts = block(x)
+            if len(logits_and_experts) > 0:
+                router_logits.append(logits_and_experts["router_logits"])
+                experts.append(logits_and_experts["selected_experts"])
         x = self.transformer.ln_f(x)

+        # aux_losses is a dict with keys for different auxiliary losses
+        aux_losses = {}
+
         if targets is not None:
             # if we are given some desired targets also calculate the loss
             logits = self.lm_head(x)
             loss = F.cross_entropy(
                 logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
             )
+            if moe and self.config.moe_routing == "standard_gating":
+                # calculate the router losses per layer
+                for logit, expert_choice in zip(router_logits, experts):
+                    router_losses = self.get_router_losses(
+                        logit, expert_choice, eval=not self.training
+                    )
+                    for k, v in router_losses.items():
+                        aux_losses[k] = aux_losses.get(k, 0.0) + v
+                        if self.training:
+                            loss += (
+                                v
+                                * getattr(self.config, k + "_factor")
+                                / self.config.n_layer
+                            )
         else:
             # inference-time mini-optimization: only forward the lm_head on the very last position
@@ -233,9 +304,14 @@ def forward(self, idx, targets=None, get_logits=False):
             )  # note: using list [-1] to preserve the time dim
             loss = None

         logits = logits if get_logits else None
+        router_logits = (
+            torch.stack(router_logits, dim=0) if len(router_logits) > 0 else None
+        )
         return {
             "logits": logits,
             "loss": loss,
+            "aux_losses": aux_losses,
+            "router_logits": router_logits,
         }

     def crop_sequence_length(self, sequence_length):
@@ -250,9 +326,60 @@ def crop_sequence_length(self, sequence_length):
         for block in self.transformer.h:
             block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length]

+    def convert_dense_to_sparse(self, state_dict):
+        """
+        Convert the dense model to a sparse (MoE) model.
+        """
+        state_to_load = {}
+        for k, v in state_dict.items():
+            vals = k.split(".")
+            print(vals)
+            if len(vals) >= 5 and vals[4] == "mlp":
+                # for layer i, go from '_orig_mod.transformer.h.i.mlp.c_fc.weight' to
+                # '_orig_mod.transformer.h.i.mlp.experts.e.c_fc.weight'
+                for e in range(self.config.moe_num_experts):
+                    state_to_load[
+                        ".".join(vals[1:5] + ["experts", str(e)] + vals[5:])
+                    ] = v
+                # add router weight from already initialized weights above
+                state_to_load[".".join(vals[1:5] + ["router", "weight"])] = (
+                    self.transformer.h[int(vals[3])].mlp.router.weight
+                )
+            else:
+                state_to_load[".".join(k.split(".")[1:])] = v
+        return state_to_load
+
+    def convert_n_dense_to_sparse(self, state_dicts):
+        """
+        Convert n dense model checkpoints (one per expert) to a single sparse (MoE) model.
+        """
+        assert (
+            len(state_dicts) == self.config.moe_num_experts
+        ), f"len(state_dicts)={len(state_dicts)} != {self.config.moe_num_experts}."
+        state_to_load = {}
+        for e in range(self.config.moe_num_experts):
+            state_dict = state_dicts[e]
+            for k, v in state_dict.items():
+                vals = k.split(".")
+                print(vals)
+                if len(vals) >= 5 and vals[4] == "mlp":
+                    # for layer i, go from '_orig_mod.transformer.h.i.mlp.c_fc.weight' to
+                    # '_orig_mod.transformer.h.i.mlp.experts.e.c_fc.weight'
+                    state_to_load[
+                        ".".join(vals[1:5] + ["experts", str(e)] + vals[5:])
+                    ] = v
+                    # add router weight from already initialized weights above
+                    state_to_load[".".join(vals[1:5] + ["router", "weight"])] = (
+                        self.transformer.h[int(vals[3])].mlp.router.weight
+                    )
+                else:
+                    state_to_load[".".join(k.split(".")[1:])] = v
+        return state_to_load
+
     def from_pretrained(
         self,
         model_path,
+        from_dense: bool = True,
     ):
         paths = model_path.split(",")
         if len(paths) == 1:
             loaded_state = torch.load(
                 str(path + "/ckpt.pt"),
                 map_location=torch.device(self.config.device),
             )
             state_to_load = loaded_state["model"]
-            # load the sparse model
-            state_to_load = {
-                ".".join(k.split(".")[1:]): v  # drop _orig_mod from keys
-                for k, v in state_to_load.items()
-            }
+            if self.config.moe and from_dense:
+                # load the dense model and convert to sparse
+                state_to_load = self.convert_dense_to_sparse(state_to_load)
+            else:
+                # load the sparse model
+                state_to_load = {
+                    ".".join(k.split(".")[1:]): v  # drop _orig_mod from keys
+                    for k, v in state_to_load.items()
+                }
+        else:
+            loaded_states = []
+            for path in paths:
+                loaded_state = torch.load(
+                    str(path + "/ckpt.pt"),
+                    map_location=torch.device(self.config.device),
+                )
+                loaded_states.append(loaded_state["model"])
+            if self.config.moe and from_dense:
+                # load the dense model and convert to sparse
+                print(f"Loading from {len(paths)} dense models.")
+                state_to_load = self.convert_n_dense_to_sparse(loaded_states)
+            else:
+                raise NotImplementedError("Multiple paths -> load from dense.")
+
         super().load_state_dict(state_to_load)

     def get_parameter_group_specs(self):
         """
diff --git a/src/models/llama.py b/src/models/llama.py
index e6aaec6..9ece968 100644
--- a/src/models/llama.py
+++ b/src/models/llama.py
@@ -11,6 +11,7 @@
 from torch.nn import functional as F

 from models.base import CausalSelfAttention, GPTBase
+from models.moe import MoE


 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
@@ -89,7 +90,8 @@ def __init__(self, config):
         self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)

     def forward(self, x):
-        return self.c_proj(nn.functional.silu(self.w1(x)) * self.w2(x))
+        # tuple form because of aux loss from MoE
+        return self.c_proj(nn.functional.silu(self.w1(x)) * self.w2(x)), {}


 class LlamaAttention(CausalSelfAttention):
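
# Illustrative sketch, not part of the patch: the MLPs now return an (output, dict)
# pair so that dense and MoE feed-forward layers expose the same interface to the
# transformer blocks. A dummy module standing in for either variant:
import torch
import torch.nn as nn

class DummyFFN(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        # dense layers return an empty dict; MoE layers return
        # {"router_logits": ..., "selected_experts": ...}
        return self.fc(x), {}

x = torch.randn(2, 4, 8)
out, routing_info = DummyFFN()(x)  # the same unpacking works for dense and MoE
print(out.shape, routing_info)
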
@@ -141,13 +143,17 @@ def __init__(self, config):
         self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
         self.attn = LlamaAttention(config)
         self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
-        self.mlp = LlamaMLP(config)
+
+        if config.moe:
+            self.mlp = MoE(config, LlamaMLP)
+        else:
+            self.mlp = LlamaMLP(config)

     def forward(self, x, freqs_cis):
         x = x + self.attn(self.ln_1(x), freqs_cis)
-        x_ = self.mlp(self.ln_2(x))
+        x_, logits_and_experts = self.mlp(self.ln_2(x))
         x = x + x_
-        return x
+        return x, logits_and_experts


 class Llama(GPTBase):
@@ -199,7 +205,7 @@ def get_num_params(self, non_embedding=True):
         n_params = sum(p.numel() for p in self.parameters())
         return n_params

-    def forward(self, idx, targets=None, get_logits=False):
+    def forward(self, idx, targets=None, get_logits=False, moe=False):
         device = idx.device
         b, t = idx.size()
         assert (
@@ -214,16 +220,40 @@ def forward(self, idx, targets=None, get_logits=False):
         x = self.transformer.drop(tok_emb)
         freqs_cis = self.freqs_cis.to(x.device)[pos]

+        # router logits is a list for each layer's routing, each of shape (b * seq_len, n_experts)
+        router_logits = []
+        # experts is a list for each layer's selected experts, shape (b * seq_len, topk)
+        experts = []
+
         for block in self.transformer.h:
-            x = block(x, freqs_cis=freqs_cis)
+            x, logits_and_experts = block(x, freqs_cis=freqs_cis)
+            if len(logits_and_experts) > 0:
+                router_logits.append(logits_and_experts["router_logits"])
+                experts.append(logits_and_experts["selected_experts"])
         x = self.transformer.ln_f(x)

+        # aux_losses is a dict with keys for different auxiliary losses
+        aux_losses = {}
         if targets is not None:
             # if we are given some desired targets also calculate the loss
             logits = self.lm_head(x)
             loss = F.cross_entropy(
                 logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
             )
+            if moe and self.config.moe_routing == "standard_gating":
+                # calculate the router losses per layer
+                for logit, expert_choice in zip(router_logits, experts):
+                    router_losses = self.get_router_losses(
+                        logit, expert_choice, eval=not self.training
+                    )
+                    for k, v in router_losses.items():
+                        aux_losses[k] = aux_losses.get(k, 0.0) + v
+                        if self.training:
+                            loss += (
+                                v
+                                * getattr(self.config, k + "_factor")
+                                / self.config.n_layer
+                            )
         else:
             # inference-time mini-optimization: only forward the lm_head on the very last position
             logits = self.lm_head(
@@ -233,7 +263,13 @@
         logits = logits if get_logits else None

+        router_logits = (
+            torch.stack(router_logits, dim=0) if len(router_logits) > 0 else None
+        )
+
         return {
             "logits": logits,
             "loss": loss,
+            "aux_losses": aux_losses,
+            "router_logits": router_logits,
         }
diff --git a/src/models/moe.py b/src/models/moe.py
new file mode 100644
index 0000000..ed8e31e
--- /dev/null
+++ b/src/models/moe.py
@@ -0,0 +1,244 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def log_mean(x, dim):
+    return torch.logsumexp(x, dim=dim) - torch.log(
+        torch.tensor(x.shape[dim], dtype=torch.float32)
+    )
+
+
+def entropy_reg(logits: torch.Tensor, mean_over_batch: bool = True):
+    """Entropy regularization for the router."""
+
+    entropy_l = lambda l: -(l * l.exp()).sum(-1)
+    # softmax over experts
+    # logits: [batch_size * sequence_length, num_experts]
+    logprobs = F.log_softmax(logits, dim=-1)
+    if mean_over_batch:
+        # take mean probability over batch
+        logprobs = log_mean(logprobs, 0)
+
+    return -entropy_l(logprobs).mean()
+
+
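
# Illustrative sketch, not part of the patch: entropy_reg (defined directly above)
# returns the negative entropy of the batch-averaged routing distribution, so adding it
# to the training loss (scaled by --moe_entropy_loss_factor) pushes the average routing
# towards uniform. With 4 experts the value ranges from -log(4) up to roughly 0.
import torch

uniform_logits = torch.zeros(16, 4)      # 16 tokens, 4 experts, identical logits
collapsed_logits = torch.full((16, 4), -10.0)
collapsed_logits[:, 0] = 10.0            # every token strongly prefers expert 0
print(entropy_reg(uniform_logits))       # ~ -1.386 = -log(4)
print(entropy_reg(collapsed_logits))     # ~ 0.0
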
+# two losses below are adapted from
+# https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/routing.py
+def load_balancing_loss_(router_probs, expert_indices) -> float:
+    """Computes auxiliary load balancing loss as in Switch Transformer.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
+    implements the loss function presented in equations (4) - (6). It aims to
+    penalize those cases where the routing between experts is unbalanced.
+
+    Args:
+        router_probs: Probability assigned to each expert per token. Shape:
+            [batch_size * sequence_length, num_experts].
+        expert_indices: [batch_size * sequence_length, num_selected_experts]
+            indices identifying the top num_selected_experts for a given token.
+
+    Returns:
+        The auxiliary loss.
+    """
+    # num_token = batch_size * sequence_length
+    num_token, num_experts = router_probs.shape
+
+    # Shape: [batch_size * sequence_length, num_selected_experts, num_experts].
+    expert_mask = F.one_hot(expert_indices, num_experts)
+    # For a given token, determine if it was routed to a given expert.
+    # Shape: [batch_size * sequence_length, num_experts]
+    expert_mask, _ = torch.max(expert_mask, dim=-2)
+
+    # shape [num_experts]
+    tokens_per_expert = torch.mean(expert_mask, dim=0, dtype=torch.float32)
+    # shape [num_experts]
+    router_prob_per_expert = torch.mean(router_probs, dtype=torch.float32, dim=0)
+    return (
+        torch.mean(
+            tokens_per_expert * router_prob_per_expert,
+            dtype=torch.float32,
+        )
+        * num_experts
+    )
+
+
+def load_balancing_loss(logits, expert_indices) -> float:
+    """Computes auxiliary load balancing loss as in Switch Transformer.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
+    implements the loss function presented in equations (4) - (6). It aims to
+    penalize those cases where the routing between experts is unbalanced.
+
+    Args:
+        logits: logits assigned to each expert per token. Shape:
+            [batch_size * sequence_length, num_experts].
+        expert_indices: [batch_size * sequence_length, num_selected_experts]
+            indices identifying the top num_selected_experts for a given token.
+
+    Returns:
+        The auxiliary loss.
+    """
+    # num_token = batch_size * sequence_length
+    num_token, num_experts = logits.shape
+
+    # Shape: [batch_size * sequence_length, num_selected_experts, num_experts].
+    expert_mask = F.one_hot(expert_indices, num_experts)
+    # For a given token, determine if it was routed to a given expert.
+    # Shape: [batch_size * sequence_length, num_experts]
+    expert_mask, _ = torch.max(expert_mask, dim=-2)
+
+    # shape [num_experts]
+    tokens_per_expert = torch.mean(expert_mask, dim=0, dtype=torch.float32)
+
+    # compute router probability per expert in log space for numerical stability
+    logprobs = F.log_softmax(logits, dim=-1)
+    # take mean probability over batch
+    # shape [num_experts]
+    logprobs = log_mean(logprobs, dim=0)
+    router_prob_per_expert = torch.exp(logprobs)
+    return (
+        torch.mean(  # mean over experts
+            tokens_per_expert * router_prob_per_expert,
+            dtype=torch.float32,
+        )
+        * num_experts
+    )
+
+
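
# Illustrative sketch, not part of the patch: a quick sanity check of the
# load-balancing loss defined above. Perfectly balanced top-1 routing over 4 experts
# gives k / num_experts = 0.25, while collapsing every token onto one expert pushes
# the value towards 1.
import torch

num_tokens, num_experts = 8, 4
balanced_logits = torch.zeros(num_tokens, num_experts)
balanced_experts = torch.arange(num_tokens).remainder(num_experts).unsqueeze(-1)
collapsed_logits = torch.zeros(num_tokens, num_experts)
collapsed_logits[:, 0] = 10.0
collapsed_experts = torch.zeros(num_tokens, 1, dtype=torch.long)
print(load_balancing_loss(balanced_logits, balanced_experts))    # 0.25
print(load_balancing_loss(collapsed_logits, collapsed_experts))  # ~ 1.0
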
+ """ + num_tokens, _ = router_logits.shape + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = log_z**2 + return torch.sum(z_loss, dtype=torch.float32) / (num_tokens) + + +class MoE(nn.Module): + def __init__(self, config, mlp): + super().__init__() + assert config.moe_num_experts > 0 + self.experts = nn.ModuleList( + [mlp(config=config) for _ in range(config.moe_num_experts)] + ) + self.n_shared_experts = config.moe_num_shared_experts + self.router = nn.Linear( + config.n_embd, config.moe_num_experts - self.n_shared_experts, bias=False + ) + self.top_k = config.moe_num_experts_per_tok + self.softmax_order = config.moe_softmax_order + + def forward(self, inputs: torch.Tensor): + # [batch_size * sequence_length, n_embd] + inputs_squashed = inputs.view(-1, inputs.shape[-1]) + # [batch_size * sequence_length, num_experts] + router_logits = self.router(inputs_squashed) + + # note that selected experts will be the same for all orders: + # softmax doesnt change top-k, but the weights are different + if self.softmax_order == "softmax_topk": + all_probs = F.softmax(router_logits, dim=1, dtype=torch.float32) + weights, selected_experts = torch.topk(all_probs, self.top_k) + elif self.softmax_order == "topk_softmax": + weights, selected_experts = torch.topk(router_logits, self.top_k) + weights = F.softmax(weights, dim=-1, dtype=torch.float32) + else: + raise ValueError(f"Unknown softmax_order: {self.softmax_order}") + + results = torch.zeros_like(inputs_squashed) + for i, expert in enumerate(self.experts): + if i < self.n_shared_experts: + # always activate shared experts + output, _ = expert(inputs_squashed) + results += output + else: + batch_idx, nth_expert = torch.where( + selected_experts == i - self.n_shared_experts + ) + output, _ = expert(inputs_squashed[batch_idx]) + results[batch_idx] += weights[batch_idx, nth_expert, None] * output + return results.view_as(inputs), { + "router_logits": router_logits, + "selected_experts": selected_experts, + } + + +class ExpertChoiceMoE(nn.Module): + def __init__(self, config, mlp): + super().__init__() + assert config.moe_num_experts > 0 + self.n_experts = config.moe_num_experts + self.experts = nn.ModuleList( + [mlp(config=config) for _ in range(config.moe_num_experts)] + ) + self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False) + self.capacity_factor = config.capacity_factor + self.softmax_order = config.moe_softmax_order + self.top_k = int( + self.capacity_factor + * config.batch_size + * config.sequence_length + / config.moe_num_experts + ) + + def forward(self, inputs: torch.Tensor): + # [batch_size * sequence_length, n_embd] + inputs_squashed = inputs.view(-1, inputs.shape[-1]) + num_tokens = inputs_squashed.shape[0] + top_k = min(self.top_k, int(self.capacity_factor * num_tokens / self.n_experts)) + # [batch_size * sequence_length, num_experts] + router_logits = self.router(inputs_squashed) + + # note that selected experts will be the same for all orders: + # softmax doesnt change top-k, but the weights are different + if self.softmax_order == "softmax_topk": + all_probs = F.softmax(router_logits, dim=1, dtype=torch.float32) + # selection over tokens! 
+class ExpertChoiceMoE(nn.Module):
+    def __init__(self, config, mlp):
+        super().__init__()
+        assert config.moe_num_experts > 0
+        self.n_experts = config.moe_num_experts
+        self.experts = nn.ModuleList(
+            [mlp(config=config) for _ in range(config.moe_num_experts)]
+        )
+        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
+        self.capacity_factor = config.capacity_factor
+        self.softmax_order = config.moe_softmax_order
+        self.top_k = int(
+            self.capacity_factor
+            * config.batch_size
+            * config.sequence_length
+            / config.moe_num_experts
+        )
+
+    def forward(self, inputs: torch.Tensor):
+        # [batch_size * sequence_length, n_embd]
+        inputs_squashed = inputs.view(-1, inputs.shape[-1])
+        num_tokens = inputs_squashed.shape[0]
+        top_k = min(self.top_k, int(self.capacity_factor * num_tokens / self.n_experts))
+        # [batch_size * sequence_length, num_experts]
+        router_logits = self.router(inputs_squashed)
+
+        # note that selected experts will be the same for all orders:
+        # softmax doesn't change top-k, but the weights are different
+        if self.softmax_order == "softmax_topk":
+            all_probs = F.softmax(router_logits, dim=1, dtype=torch.float32)
+            # selection over tokens!
+            # weights and selected tokens: [num_experts, top_k]
+            weights, selected_tokens = torch.topk(all_probs.T, top_k)
+        elif self.softmax_order == "topk_softmax":
+            # weights and selected tokens: [num_experts, top_k]
+            weights, selected_tokens = torch.topk(router_logits.T, top_k)
+            weights = F.softmax(weights, dim=0, dtype=torch.float32)
+        else:
+            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")
+
+        """ this is the full parallel version with einsum """
+        # [num_experts, top_k, num_tokens]
+        # P = F.one_hot(selected_tokens, num_tokens).type_as(inputs_squashed)
+        # # [num_experts, top_k, n_embd]
+        # x_in = torch.matmul(P, inputs_squashed)
+        # # [num_experts, num_tokens, n_embd]
+        # experts_out = torch.stack(
+        #     [expert(x)[0] for expert, x in zip(self.experts, x_in)], dim=0
+        # )
+        # results = torch.einsum("ijl,ij,ijd->ld", P, weights, experts_out)
+
+        """ this is the loop version """
+        # loop through the experts here because the fully parallel version above
+        # grows memory too quickly
+        results = torch.zeros_like(inputs_squashed)
+        for i, expert in enumerate(self.experts):
+            # [top_k]
+            batch_idx = selected_tokens[i]
+            # [top_k, n_embd]
+            output, _ = expert(inputs_squashed[batch_idx])
+            results[batch_idx] += weights[i, :, None] * output
+        return results.view_as(inputs), {
+            "router_logits": router_logits,
+            "selected_experts": selected_tokens,
+        }
diff --git a/src/optim/base.py b/src/optim/base.py
index c1143a8..e4b5ef8 100755
--- a/src/optim/base.py
+++ b/src/optim/base.py
@@ -15,7 +15,7 @@
                     wd_wsd_schedule)
 from .utils import (eval, get_batch, get_parameter_norms, load_checkpoint,
                     load_worker_state, log_prodigy_lr, save_checkpoint,
-                    save_worker_state)
+                    save_worker_state, visualize_routing)


 def train(
@@ -130,7 +130,7 @@ def train(
             microstep_idx=microstep_idx,
             gradient_accumulation_steps=cfg.acc_steps,
         ):
-            outputs = model(x, targets=y)
+            outputs = model(x, targets=y, moe=cfg.moe)

         loss = outputs["loss"] / cfg.acc_steps
         loss.backward()
@@ -227,6 +227,9 @@ def train(
             and distributed_backend.is_master_process()  # Only log on master rank
         ):
             train_loss = loss.detach().cpu().item() * cfg.acc_steps
+            train_aux_losses = {
+                f"train/{k}": v for k, v in outputs["aux_losses"].items()
+            }

             current_lrs = [param_group["lr"] for param_group in opt.param_groups]

@@ -253,6 +256,7 @@ def train(
                 "mean_grad_norm": (
                     torch.tensor(grad_norms).mean().item() if grad_norms else 0
                 ),
+                **train_aux_losses,
             }

             if cfg.weight_decay_scheduler:
@@ -261,8 +265,6 @@ def train(
             if cfg.opt == "prodigy":
                 wandb_logs["effective_lr"] = prodigy_efective_lrs[0]

-            # log the L2 norm of the parameters
-            # works in a single gpu setting
             if cfg.log_parameter_norms:
                 raw_model = distributed_backend.get_raw_model(model)
                 model_norm = get_parameter_norms(raw_model, order=cfg.norm_order)
@@ -303,12 +305,14 @@
     # to make sure we start from the beginning of the validation set,
     # i.e. repeat the same batches
     val_reader.set_step(0)
-    val_acc, val_loss, val_perplexity = eval(
+    val_acc, val_loss, val_perplexity, val_aux_losses, router_logits = eval(
         model,
         val_reader,
         cfg.device,
         max_num_batches=max_num_batches,
         ctx=type_ctx,
+        moe=cfg.moe,
+        get_router_logits=cfg.moe and cfg.plot_router_logits,
         cfg=cfg,
     )
@@ -327,6 +331,7 @@
             "final-val/loss": val_loss,
             "final-val/perplexity": val_perplexity,
             "final-val/acc": val_acc,
+            **val_aux_losses,
         }
     else:
         logs = {
@@ -335,7 +340,11 @@
             "val/loss": val_loss,
             "val/perplexity": val_perplexity,
             "val/acc": val_acc,
+            **val_aux_losses,
         }
+        if cfg.moe and cfg.plot_router_logits:
+            routing_logs = visualize_routing(router_logits, cfg)
+            logs = {**logs, **routing_logs}

     wandb.log(logs)

     if cfg.eval_seq_prefix != "none" and (
diff --git a/src/optim/utils.py b/src/optim/utils.py
index a47c644..200b8b4 100755
--- a/src/optim/utils.py
+++ b/src/optim/utils.py
@@ -6,6 +6,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
+import wandb


 def get_batch(datareader, device="cpu"):
@@ -27,26 +28,61 @@ def eval(
     device="cpu",
     max_num_batches=24,
     ctx=nullcontext(),
+    moe=False,
+    get_router_logits=False,
     cfg=None,
 ):
     assert model.training == False

-    loss_list_val, acc_list = [], []
+    loss_list_val, acc_list, loss_list_aux_val = [], [], {}
+    router_logits = []

     for idx in range(max_num_batches):
         x, y = get_batch(reader, device=device)
         with ctx:
-            outputs = model(x, targets=y, get_logits=True)
+            outputs = model(x, targets=y, get_logits=True, moe=moe)
         val_loss = outputs["loss"]

         loss_list_val.append(val_loss)
         acc_list.append((outputs["logits"].argmax(-1) == y).float().mean())

+        # auxiliary losses are optional
+        for k, v in outputs["aux_losses"].items():
+            loss_list_aux_val[k] = loss_list_aux_val.get(k, [])
+            loss_list_aux_val[k].append(v)
+
+        # router logits for MoE visualization
+        if get_router_logits:
+            # shape [layers, batch_size * sequence_length, num_experts]
+            logits = outputs["router_logits"]
+            # collected into shape [max_batches, layers, batch_size * sequence_length, num_experts]
+            router_logits.append(logits)
+
     val_acc = torch.stack(acc_list).mean().item()
     val_loss = torch.stack(loss_list_val).mean().item()
     val_perplexity = 2.71828**val_loss
+    val_aux_losses = {
+        f"val/{k}": torch.stack(v).mean().item() for k, v in loss_list_aux_val.items()
+    }
+
+    if get_router_logits:
+        # filter out router logits that are not of the expected shape (this happens when
+        # the last batch in the dataloader has a different batch size than the others)
+        if cfg:
+            intended_size = cfg.batch_size * cfg.sequence_length
+        else:
+            intended_size = x.shape[0] * x.shape[1]
+        # shape [batches - 1, layers, batch_size * sequence_length, num_experts]
+        router_logits = (
+            torch.stack(
+                [rl for rl in router_logits if rl.shape[1] == intended_size],
+                dim=0,
+            )
+            .detach()
+            .cpu()
+        )

-    return val_acc, val_loss, val_perplexity
+    return val_acc, val_loss, val_perplexity, val_aux_losses, router_logits


 @torch.no_grad()
@@ -217,3 +253,38 @@ def log_prodigy_lr(opt):
         effective_lrs.append(effective_lr)

     return effective_lrs
+
+
+def visualize_routing(router_logits, extra_args):
+    # router_logits: [batches, layers, batch_size * sequence_length, num_experts]
+    logs = {}
+
+    n_layers = extra_args.n_layer
+    num_experts = extra_args.moe_num_experts
+    num_experts_per_tok = extra_args.moe_num_experts_per_tok
+
+    # histogram over all logits to see distribution
+    logs["router/logits"] = wandb.Histogram(
+        router_logits.type(torch.float32).flatten().cpu().numpy()
+    )
+
+    # distribution over experts for layer 0, layer n/2, n-1
+    for layer in [0, n_layers // 2, n_layers - 1]:
+        router_logits_layer = router_logits[:, layer]
+        # shape [batches, batch_size * sequence_length, num_experts_per_tok]
+        weights, selected_experts = torch.topk(
+            router_logits_layer, num_experts_per_tok, dim=-1
+        )
+        # shape [batches, batch_size * sequence_length, num_experts_per_tok, num_experts]
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+        # For a given token, determine if it was routed to a given expert.
+        # Shape: [batches, batch_size * sequence_length, num_experts]
+        expert_mask, _ = torch.max(expert_mask, dim=-2)
+        # shape [num_experts]
+        tokens_per_expert = torch.mean(expert_mask, dim=(0, 1), dtype=torch.float32)
+        layer_token_routing = {
+            f"router/layer_{layer}_expert_{i}_selection": tokens_per_expert[i].item()
+            for i in range(num_experts)
+        }
+        logs.update(layer_token_routing)
+    return logs
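
# Illustrative sketch, not part of the patch: how the pieces above fit together at
# train time. For each MoE layer, the selected router losses are added to the
# cross-entropy loss, scaled by the matching --moe_*_factor flag and divided by
# n_layer, mirroring the forward passes in src/models/base.py and src/models/llama.py.
# All names below are local to the sketch; the import assumes src/ is on the path.
from types import SimpleNamespace
import torch
from models.moe import load_balancing_loss, router_z_loss

cfg = SimpleNamespace(n_layer=2, moe_aux_loss_factor=0.1, moe_z_loss_factor=0.01)
loss = torch.tensor(3.2)                        # stand-in for F.cross_entropy(...)
for _ in range(cfg.n_layer):                    # one routing decision per MoE layer
    logits = torch.randn(32, 8)                 # [batch * seq_len, num_experts]
    selected = torch.topk(logits, k=2).indices  # [batch * seq_len, top_k]
    layer_losses = {
        "moe_aux_loss": load_balancing_loss(logits, selected),
        "moe_z_loss": router_z_loss(logits),
    }
    for name, value in layer_losses.items():
        loss = loss + value * getattr(cfg, name + "_factor") / cfg.n_layer
print(loss)
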