diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py index cdee37935..5e02d6d2e 100644 --- a/fast_llm/engine/config_utils/initialization.py +++ b/fast_llm/engine/config_utils/initialization.py @@ -13,7 +13,7 @@ @config_class(registry=True) class InitializationConfig(Config): _abstract = True - has_initialization: typing.ClassVar[bool] = True + is_default: typing.ClassVar[bool] = False @classmethod def _from_dict( @@ -35,7 +35,7 @@ def get_initializer(self) -> "Initializer": class DefaultInitializationConfig(InitializationConfig): # A placeholder indicating that the class default should be used instead. _abstract = False - has_initialization = False + is_default = True @config_class(dynamic_type={InitializationConfig: "fill"}) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 8f4dffedf..439d1da2e 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -16,7 +16,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.attention.rotary.config import NoRotaryConfig logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/__init__.py b/fast_llm/layers/attention/__init__.py similarity index 100% rename from fast_llm/layers/transformer/__init__.py rename to fast_llm/layers/attention/__init__.py diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/attention/attention.py similarity index 73% rename from fast_llm/layers/transformer/attention.py rename to fast_llm/layers/attention/attention.py index 0bea58d9a..f41b48971 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -4,13 +4,13 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.peft import TransformerSubLayerName -from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionConfig, AttentionDimNames, AttentionKwargs -from fast_llm.utils import get_lr_scale +from fast_llm.layers.block.config import BlockDimNames +from fast_llm.utils import div try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -50,80 +50,95 @@ class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): A self-attention layer. 
""" - _QUERY_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_heads, - AttentionDimNames.kv_channels, - ) - _KV_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.head_groups, - AttentionDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_dense, - ) - - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) - self._config = config - self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) - - self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size - self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): + super().__init__(config, distributed_config, hidden_dim, block_index, name) + self._use_flash_attention = self._config.do_use_flash_attention(distributed_config) + + head_group_dim = TensorDim( + "head_groups", self._config.head_groups, self._parallel_dim if self.head_groups > 1 else None + ) + group_heads_dim = TensorDim( + "group_heads", + div(self._config.num_attention_heads, self._config.head_groups), + None if self.head_groups > 1 else self._parallel_dim, + ) + self._local_head_groups = head_group_dim.size + self._local_heads_per_group = group_heads_dim.size self._local_heads = self._local_head_groups * self._local_heads_per_group - self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space[AttentionDimNames.hidden] + kv_channels_dim = TensorDim("kv_channels", self._config.kv_channels) + query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, kv_channels_dim)) + key_value_dim = ConcatenatedTensorDim( + "key_value", + ( + CompositeTensorDim("key", (group_heads_dim, kv_channels_dim)), + CompositeTensorDim("value", (group_heads_dim, kv_channels_dim)), + ), + ) + dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, kv_channels_dim)) - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) + + lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None # TODO: Merge the query and key-value computations? (harder with sequence parallel.) - self.query = OutputParallelLinear( + self.query = self._config.query_layer.get_layer( hidden_dim, - self._tensor_space[AttentionDimNames.composite_query], + query_dim, bias=self._config.add_qkv_bias, weight_init_method=self._config.qkv_weight_initialization_method, bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, + peft=self._config.block.peft, ) - self.key_value = OutputParallelLinear( + # TODO: Separate. 
+        self.key_value = self._config.key_layer.get_layer(
             hidden_dim,
-            self._tensor_space[AttentionDimNames.composite_key_value],
-            bias=self._config.add_qkv_bias,
-            weight_init_method=self._config.qkv_weight_initialization_method,
-            bias_init_method=self._config.qkv_bias_initialization_method,
+            key_value_dim,
             sequence_parallel=self._sequence_parallel,
-            lr_scale=attention_lr_scale,
+            lr_scale=lr_scale,
+            peft=self._config.block.peft,
         )
         self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward)
 
         # Rotary embeddings.
-        self._rotary = self._config.rotary.build()
+        self._rotary = self._config.rotary.get_layer()
 
         # Output.
-        self.dense = InputParallelLinear(
-            self._tensor_space[AttentionDimNames.composite_dense],
+        self.dense = self._config.dense_layer.get_layer(
+            dense_dim,
             hidden_dim,
-            bias=self._config.add_dense_bias,
-            weight_init_method=self._config.dense_weight_initialization_method,
-            bias_init_method=self._config.dense_bias_initialization_method,
             sequence_parallel=self._sequence_parallel,
-            lr_scale=attention_lr_scale,
+            peft=self._config.block.peft,
+            lr_scale=lr_scale,
         )
 
-        # PEFT.
-        self.query = self._config.peft.apply_linear(self.query, TransformerSubLayerName.query)
-        self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value)
-        self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense)
+        if self._debug.enabled:
+            self._query_dims = (
+                BlockDimNames.batch,
+                BlockDimNames.sequence_q,
+                CompositeTensorDim("heads", (head_group_dim, group_heads_dim)),
+                kv_channels_dim,
+            )
+            self._kv_dims = (
+                BlockDimNames.batch,
+                BlockDimNames.sequence_q,
+                head_group_dim,
+                kv_channels_dim,
+            )
+            self._context_dims = (
+                BlockDimNames.batch,
+                BlockDimNames.sequence_q,
+                dense_dim,
+            )
 
     def _attn_fused(
         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
@@ -133,16 +148,18 @@ def _attn_fused(
         sk = key.size(1)
 
         if self._local_head_groups == 1:
-            query = query.view(b, sq * self._local_heads, self._kv_channels)
+            query = query.view(b, sq * self._local_heads, self._config.kv_channels)
             key = key.transpose(-1, -2)
         else:
             query = (
-                query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._kv_channels))
+                query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._config.kv_channels))
                 .transpose(1, 2)
-                .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._kv_channels)
+                .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._config.kv_channels)
+            )
+            key = key.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).movedim(1, 3).flatten(0, 1)
+            value = (
+                value.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).transpose(1, 2).flatten(0, 1)
             )
-            key = key.unflatten(-1, (self._local_head_groups, self._kv_channels)).movedim(1, 3).flatten(0, 1)
-            value = value.unflatten(-1, (self._local_head_groups, self._kv_channels)).transpose(1, 2).flatten(0, 1)
 
         attn_weights = torch.empty(
             (b * self._local_head_groups, sq * self._local_heads_per_group, sk), device=query.device, dtype=query.dtype
@@ -169,7 +186,7 @@ def _attn_fused(
             return attn_output.view(b, sq, -1)
         else:
             return (
-                attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._kv_channels)
+                attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._config.kv_channels)
                 .transpose(1, 2)
                 .flatten(2)
             )
@@ -181,7 
+198,7 @@ def _query_key_value_forward(
 
         handle = None
 
-        if self._head_groups == 1 and self._sequence_parallel:
+        if self._config.head_groups == 1 and self._sequence_parallel:
             key_value, handle = gather_op(
                 key_value, group=self._tensor_space.distributed.tensor_group, dim=0, async_op=True
             )
@@ -226,7 +243,7 @@ def _query_key_value_backward(
         if handle:
             handle.wait()
 
-        if self._head_groups == 1 and (group := self._tensor_space.distributed.tensor_group):
+        if self._config.head_groups == 1 and (group := self._tensor_space.distributed.tensor_group):
             if self._sequence_parallel:
                 key_value_grad = reduce_scatter_op(key_value_grad, group=group, dim=0)
             else:
@@ -281,15 +298,15 @@ def forward(
             query = query.transpose(0, 1).contiguous()
             key_value = key_value.transpose(0, 1).contiguous()
 
-        key, value = key_value.split(self._local_head_groups * self._kv_channels, dim=-1)
+        key, value = key_value.split(self._local_head_groups * self._config.kv_channels, dim=-1)
 
-        query = query.view(*query.shape[:2], self._local_heads, self._kv_channels)
-        key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels)
-        value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels)
+        query = query.view(*query.shape[:2], self._local_heads, self._config.kv_channels)
+        key = key.view(*key.shape[:2], self._local_head_groups, self._config.kv_channels)
+        value = value.view(*value.shape[:2], self._local_head_groups, self._config.kv_channels)
 
         if self._debug.enabled:
-            self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs)
-            self._debug(key, "key_rotary_input", self._KV_DIMS, kwargs)
+            self._debug(query, "query_rotary_input", self._query_dims, kwargs)
+            self._debug(key, "key_rotary_input", self._kv_dims, kwargs)
         query, key = self._rotary(query, key, kwargs)
 
         window_size = self._decide_window_size()
@@ -337,10 +354,10 @@ def forward(
         )
 
         if self._debug.enabled:
-            self._debug(query, "query", self._QUERY_DIMS, kwargs)
-            self._debug(key, "key", self._KV_DIMS, kwargs)
-            self._debug(value, "value", self._KV_DIMS, kwargs)
-            self._debug(input_, "context", self._CONTEXT_DIMS, kwargs)
+            self._debug(query, "query", self._query_dims, kwargs)
+            self._debug(key, "key", self._kv_dims, kwargs)
+            self._debug(value, "value", self._kv_dims, kwargs)
+            self._debug(input_, "context", self._context_dims, kwargs)
 
         if sequence_first:
             # TODO: Optimize (is contiguous avoidable? Transpose dense output?)
diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py
new file mode 100644
index 000000000..0d64ccbf8
--- /dev/null
+++ b/fast_llm/layers/attention/config.py
@@ -0,0 +1,166 @@
+import functools
+import logging
+import typing
+import warnings
+
+from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
+from fast_llm.engine.config_utils.data_type import DataType
+from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.functional.config import TritonConfig
+from fast_llm.layers.attention.rotary.config import RotaryConfig
+from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs, MixerConfig
+from fast_llm.layers.common.linear.config import LinearConfig
+from fast_llm.utils import Assert, div
+
+logger = logging.getLogger(__name__)
+
+
+class AttentionDimNames(BlockDimNames):
+    # A set of common tensor dim names packed into a namespace.
+ # Self-attention dimensions + # head_groups = "head_groups" + # group_heads = "group_heads" + # key_and_value = "key_value" + # kv_channels = "kv_channels" + # composite_heads = "composite_heads" + # composite_query = "composite_query" + # composite_key_value = "composite_key_value" + # composite_dense = "composite_dense" + pass + + +class AttentionKwargs(BlockKwargs): + rotary_freq_q = "rotary_freq_q" + rotary_freq_k = "rotary_freq_k" + attention_mask = "attention_mask" + attention_mask_value = "attention_mask_value" + cu_seqlens_q = "cu_seqlens_q" + cu_seqlens_k = "cu_seqlens_k" + max_seqlen_q = "max_seqlen_q" + max_seqlen_k = "max_seqlen_k" + # TODO: Review these + presents = "presents" + past_key_values = "past_key_values" + + +@config_class(dynamic_type={MixerConfig: "attention"}) +class AttentionConfig(MixerConfig): + _abstract = False + + # Needed for backward compatibility. TODO: remove + module_name: typing.ClassVar[str] = "attn" + + # TODO: Review names + query_layer: LinearConfig = Field( + desc="Configuration for the query layer.", + hint=FieldHint.architecture, + ) + key_layer: LinearConfig = Field( + desc="Configuration for the key layer.", + hint=FieldHint.architecture, + ) + value_layer: LinearConfig = Field( + desc="Configuration for the value layer.", + hint=FieldHint.architecture, + ) + dense_layer: LinearConfig = Field( + desc="Initialization configuration for the dense layer.", + hint=FieldHint.feature, + ) + rotary: RotaryConfig = Field( + desc="Configuration for the rotary positional embeddings.", + hint=FieldHint.architecture, + ) + num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) + head_groups: int = Field( + default=1, + desc="Number of head group for grouped query attention.", + doc="Set to 1 for multi-query attention, `num_attention_heads` for multi-head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + kv_channels: int = Field( + default=None, + desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + attention_dropout: float = Field( + default=0.0, + desc="Dropout applied to the attention intermediate states.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + # Use flash attention if possible (fp16 or bf16) + use_flash_attention: bool = Field( + default=True, desc="Enable Flash Attention if possible.", hint=FieldHint.optional + ) + window_size: int | None = Field( + default=None, + desc="Size of the attention sliding window. Warning: this parameter is not part of the architecture and must be redefined when loading a pretrained model.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + max_window_layers: int | None = Field( + default=None, + desc="The number of layers that use SWA (Sliding Window Attention). 
The bottom layers use SWA while the top use full attention.",
+        hint=FieldHint.optional,
+        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
+    )
+    attention_lr_scale: float | None = Field(
+        default=None,
+        desc="Custom learning rate scale for the Attention projection weights.",
+        doc="Can be used in muP to scale the Attention learning rate by 1/width_factor",
+        hint=FieldHint.feature,
+        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
+    )
+    attention_softmax_scale_power: float = Field(
+        default=0.5,
+        desc="The scaling power to apply to kv_channels in the attention calculation. "
+        " Under Standard Parameterization (SP): defaults to 0.5. "
+        " Under muP (if scaling kv_channels size): use 1. "
+        " Under muP (if scaling number of heads instead of kv_channels): use 0.5.",
+        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
+    )
+
+    def _validate(self) -> None:
+        with self._set_implicit_default():
+            # TODO: hidden_size not yet validated.
+            if self.kv_channels is None:
+                self.kv_channels = div(self.block.hidden_size, self.num_attention_heads)
+        # TODO: Block variables as defaults?
+        for layer, scale, apply_peft in zip(
+            (self.query_layer, self.key_layer, self.value_layer, self.dense_layer),
+            (1, 1, 1, 2 * max(self.block.num_blocks, 1)),
+            (True, False, True, False),
+        ):
+            layer.default = LinearConfig(
+                bias=True,
+                weight_initialization=init_normal_(0, (self.block.hidden_size * scale) ** -0.5),
+                bias_initialization=init_zeros_,
+                lr_scale=None,
+                apply_peft=apply_peft,
+            )
+        super()._validate()
+
+        if not TritonConfig.TRITON_ENABLED:
+            warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.")
+
+        Assert.multiple(self.num_attention_heads, self.head_groups)
+
+    @functools.cached_property
+    def projection_size(self):
+        assert self._validated
+        return self.num_attention_heads * self.kv_channels
+
+    def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
+        return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)
+
+
+@config_class()
+# TODO: Remove
+class TransformerConfig(BlockConfig):
+    pass
diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/attention/preprocessing.py
similarity index 98%
rename from fast_llm/layers/transformer/preprocessing.py
rename to fast_llm/layers/attention/preprocessing.py
index 16e5811e6..78f123438 100644
--- a/fast_llm/layers/transformer/preprocessing.py
+++ b/fast_llm/layers/attention/preprocessing.py
@@ -5,7 +5,7 @@
 
 from fast_llm.engine.base_model.config import Preprocessor
 from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace
-from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs
+from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs
 from fast_llm.tensor import TensorMeta
 
 logger = logging.getLogger(__name__)
diff --git a/fast_llm/layers/transformer/rotary/__init__.py b/fast_llm/layers/attention/rotary/__init__.py
similarity index 100%
rename from fast_llm/layers/transformer/rotary/__init__.py
rename to fast_llm/layers/attention/rotary/__init__.py
diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/attention/rotary/config.py
similarity index 90%
rename from fast_llm/layers/transformer/rotary/config.py
rename to fast_llm/layers/attention/rotary/config.py
index 748f2af28..8c4007ef1 100644
--- a/fast_llm/layers/transformer/rotary/config.py
+++ b/fast_llm/layers/attention/rotary/config.py
@@ -10,7 +10,7 @@ from 
fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary @config_class(registry=True) @@ -29,7 +29,7 @@ def _from_dict( return NoRotaryConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) - def build(self, tensor_space: TensorSpace | None = None) -> "Rotary": + def get_layer(self, tensor_space: TensorSpace | None = None) -> "Rotary": return self._get_configurable_class()(self, tensor_space) @classmethod @@ -44,7 +44,7 @@ class NoRotaryConfig(RotaryConfig): @classmethod def _get_configurable_class(self) -> "type[NoRotary]": - from fast_llm.layers.transformer.rotary.rotary import NoRotary + from fast_llm.layers.attention.rotary.rotary import NoRotary return NoRotary @@ -75,7 +75,7 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") def _get_configurable_class(self) -> "type[DefaultRotary]": - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary return DefaultRotary @@ -97,7 +97,7 @@ def _validate(self) -> None: Assert.gt(self.high_frequency_factor, self.low_frequency_factor) def _get_configurable_class(self) -> "type[Llama3Rotary]": - from fast_llm.layers.transformer.rotary.rotary import Llama3Rotary + from fast_llm.layers.attention.rotary.rotary import Llama3Rotary return Llama3Rotary @@ -137,6 +137,6 @@ def _validate(self) -> None: super()._validate() def _get_configurable_class(self) -> "type[YarnRotary]": - from fast_llm.layers.transformer.rotary.rotary import YarnRotary + from fast_llm.layers.attention.rotary.rotary import YarnRotary return YarnRotary diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py similarity index 98% rename from fast_llm/layers/transformer/rotary/rotary.py rename to fast_llm/layers/attention/rotary/rotary.py index ebb629aa1..953f6cf8f 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -8,8 +8,8 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs -from fast_llm.layers.transformer.rotary.config import ( +from fast_llm.layers.attention.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.attention.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, @@ -83,7 +83,7 @@ def __init__( self._tensor_space = tensor_space if self._tensor_space is not None: self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[AttentionDimNames.value_head_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 070f5dc67..eb1b502a2 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -10,8 +10,10 @@ from fast_llm.engine.base_model.base_model import Layer from 
fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs, BlockLayerConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -20,8 +22,9 @@ class DebugLayer: # TODO: Move elsewhere? - def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False): - self._tensor_space = tensor_space + _distributed: Distributed + + def __init__(self, name: str, debug_level: int = 0, debug_memory: bool = False): self._name = name self._debug_level = debug_level self._debug_memory = debug_memory @@ -37,14 +40,17 @@ def _get_meta( ( dim if isinstance(dim, TensorDim) - else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim] + else hidden_dims[dim] if dim in hidden_dims else TensorDim(dim, tensor.size(i)) ) - for dim in dims + for i, dim in enumerate(dims) ), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) + def setup(self, distributed: Distributed): + self._distributed = distributed + @functools.cached_property def enabled(self) -> bool: return self._debug_level > 0 or self._debug_memory @@ -70,7 +76,7 @@ def __call__[ tensor, level=self._debug_level, meta=self._get_meta(tensor, name, dims, kwargs), - distributed=self._tensor_space.distributed, + distributed=self._distributed, global_=global_, log_fn=log_fn, scale=scale, @@ -81,7 +87,7 @@ def __call__[ tensor, level=self._debug_level, meta=self._get_meta(tensor, name + " grad", dims, kwargs), - distributed=self._tensor_space.distributed, + distributed=self._distributed, global_=global_, log_fn=log_fn, scale=scale, @@ -94,23 +100,26 @@ class BlockLayerBase[ConfigType: BaseModelConfig](Configurable[ConfigType], torc """ def __init__( - self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str, block_config: BlockConfig + self, + config: ConfigType, + distributed_config: DistributedConfig, + block_config: BlockConfig, + block_index: int, + name: str, ): super().__init__(config) - self._tensor_space = tensor_space self._block_index = block_index + self._distributed_config = distributed_config + self._sequence_parallel: bool = self._distributed_config.sequence_tensor_parallel self._name = name - self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug = DebugLayer( - tensor_space, self._name, block_config.debug_transformer, block_config.debug_transformer_memory, ) - # @property - # def name(self) -> str: - # return self._name + def setup(self, distributed: Distributed): + self._debug.setup(distributed) class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType], torch.nn.Module): @@ -118,8 +127,16 @@ class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType], torch Base class for mixer and MLP modules. 
""" - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name, config.block) + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): + super().__init__(config, distributed_config, config.block, block_index, name) + self._hidden_dim = hidden_dim @abc.abstractmethod def forward( @@ -138,14 +155,18 @@ class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): """ def __init__( - self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str, return_input: bool = False + self, + config: ConfigType, + hidden_dim: TensorDim, + distributed_config: DistributedConfig, + block_index: int, + name: str, + return_input: bool = False, ): - super().__init__(config, tensor_space, block_index, name, config) + super().__init__(config, distributed_config, config, block_index, name) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - hidden_dim = self._tensor_space[BlockDimNames.hidden] - # Note, layer_lr_scale does not impact the norms - # TODO: add a separate norm_lr_scale + # Note, layer_lr_scale does not impact the norms (TODO: Address?) self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) @@ -153,9 +174,18 @@ def __init__( setattr( self, self._config.mixer.module_name, - self._config.mixer.get_layer(self._tensor_space, self._block_index, f"{self._name} mixer"), + self._config.mixer.get_layer( + hidden_dim, self._distributed_config, self._block_index, f"{self._name} mixer" + ), ) - self.mlp = self._config.mlp.get_layer(self._tensor_space, self._block_index, f"{self._name} mlp") + self.mlp = self._config.mlp.get_layer( + hidden_dim, self._distributed_config, self._block_index, f"{self._name} mlp" + ) + + def setup(self, distributed: Distributed): + super().setup(distributed) + getattr(self, self._config.mixer.module_name).setup(distributed) + self.mlp.setup(distributed) @torch.compile def _bias_dropout_add( @@ -177,11 +207,7 @@ def forward( if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self._name} output", dtype=input_.dtype) - generator = ( - self._tensor_space.distributed.tp_generator - if self._tensor_space.distributed_config.sequence_tensor_parallel - else self._tensor_space.distributed.pp_generator - ) + generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator if self._debug.enabled: self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) fw_input = input_ diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 680a122eb..a4ea6e2f9 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,11 +1,11 @@ -import enum import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.peft import TransformerPeftConfig -from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.common.normalization.config import 
NormalizationConfig
+from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
@@ -38,13 +38,6 @@ class BlockKwargs:
     grad_output = "grad_output"
 
 
-class AddLinearBiasChoices(str, enum.Enum):
-    # TODO: Review
-    nowhere = "nowhere"
-    everywhere = "everywhere"
-    only_attn_qkv = "only_attn_qkv"
-
-
 @config_class()
 class BlockLayerConfig(BaseModelConfig):
     """
@@ -58,8 +51,14 @@ class BlockLayerConfig(BaseModelConfig):
     def layer_class(self) -> "type[BlockLayer]":
         raise NotImplementedError()
 
-    def get_layer(self, tensor_space: TensorSpace, block_index: int, name: str) -> "BlockLayer":
-        return self.layer_class(self, tensor_space, block_index, name)
+    def get_layer(
+        self,
+        hidden_dim: TensorDim,
+        distributed_config: DistributedConfig,
+        block_index: int,
+        name: str,
+    ) -> "BlockLayer":
+        return self.layer_class(self, distributed_config, hidden_dim, block_index, name)
 
 
 @config_class(registry=True)
@@ -82,7 +81,7 @@ def _from_dict(
         flat: bool = False,
     ) -> typing.Self:
         if cls is MixerConfig and cls.get_subclass(default.get("type")) is None:
-            from fast_llm.layers.transformer.config import AttentionConfig
+            from fast_llm.layers.attention.config import AttentionConfig
 
             # Default subclass.
             return AttentionConfig._from_dict(default, strict, flat)
@@ -124,15 +123,16 @@ class BlockConfig(BaseModelConfig):
         desc="Configuration for the MLP.",
         hint=FieldHint.architecture,
     )
-    # TODO: Review names
+    # TODO: Allow separate initializations?
     normalization: NormalizationConfig = Field(
         desc="Configuration for the normalization layers architecture.",
         hint=FieldHint.architecture,
     )
-    peft: TransformerPeftConfig = Field(
+    peft: PeftConfig = Field(
         desc="Configuration for the parameter-efficient fine tuning.",
         hint=FieldHint.architecture,
     )
+    # TODO: Review names
     hidden_dropout: float = Field(
         default=0.0,
         desc="Dropout applied to the residual connections.",
         hint=FieldHint.feature,
@@ -150,9 +150,9 @@ class BlockConfig(BaseModelConfig):
         desc="Log the memory usage after each operation in a transformer layer..",
         hint=FieldHint.logging,
     )
-    add_linear_biases: bool | AddLinearBiasChoices = Field(
+    add_linear_biases: bool = Field(
         default=True,
-        desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.",
+        desc="Whether to add biases to linear layers. 
May be overridden in individual layer configs.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index bde775a27..ce002dada 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,13 +1,13 @@ import enum -import functools import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ +from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.layers.block.config import BlockLayerConfig +from fast_llm.layers.common.linear.config import LinearConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -15,18 +15,19 @@ class MLPDimNames: + pass # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" + # mlp = "mlp" + # gate_and_up = "gate_and_up" + # composite_gated_mlp = "composite_gated_mlp" + # experts = "experts" + # top_experts = "top_experts" + # shared_experts = "shared_experts" + # unshared_experts = "unshared_experts" + # composite_expert_mlp = "composite_expert_mlp" + # composite_gated_expert_mlp = "composite_gated_expert_mlp" + # composite_shared_expert_mlp = "composite_shared_expert_mlp" + # composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" class MLPLossNames: @@ -130,26 +131,17 @@ class MLPConfig(BlockLayerConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) - layer_1_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the first mlp layer weights. Default: normal(0, hidden_size**-0.5).", - hint=FieldHint.feature, - ) - layer_1_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the first mlp layer biases. Default: fill with zeros.", - hint=FieldHint.feature, - ) - layer_2_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the second mlp layer weights." - " Default: normal((2 * num_blocks * hidden_size)**-0.5)", - hint=FieldHint.feature, + layer_1: LinearConfig = Field( + desc="Configuration for the first MLP layer.", + hint=FieldHint.architecture, ) - layer_2_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the second mlp layer biases. Default: fill with zeros.", - hint=FieldHint.feature, + layer_2: LinearConfig = Field( + desc="Configuration for the second MLP layer.", + hint=FieldHint.architecture, ) - router_weight_initialization: InitializationConfig = Field( + router: LinearConfig = Field( # TODO: Improve default? - desc="Initialization configuration for the MoE router weight. 
Default: normal(0, hidden_size**-0.5).", + desc="Configuration for the MoE router.", hint=FieldHint.feature, ) @@ -164,18 +156,20 @@ def layer_class(self) -> "type[MLPBase]": return MLP - @property - def add_bias(self) -> bool: - from fast_llm.layers.block.config import AddLinearBiasChoices - - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - def _validate(self) -> None: assert hasattr(self, "block") + for layer, bias, scale in zip( + (self.layer_1, self.layer_2, self.router), + (self.block.add_linear_biases, self.block.add_linear_biases, False), + (1, max(self.block.num_blocks, 1), 1), + ): + layer.default = LinearConfig( + bias=bias, + weight_initialization=init_normal_(0, (self.block.hidden_size * scale) ** -0.5), + bias_initialization=init_zeros_, + apply_peft=False, + ) + with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu @@ -198,45 +192,6 @@ def _validate(self) -> None: elif self.mlp_lr_scale is not None: Assert.geq(self.mlp_lr_scale, 0) - if self.layer_1_bias_initialization.has_initialization or self.layer_2_bias_initialization.has_initialization: - assert self.add_bias - - @functools.cached_property - def layer_1_weight_initialization_method(self) -> Initializer: - if self.layer_1_weight_initialization.has_initialization: - return self.layer_1_weight_initialization.get_initializer() - else: - return init_normal_(0, self.block.hidden_size**-0.5) - - @functools.cached_property - def layer_1_bias_initialization_method(self) -> Initializer: - if self.layer_1_bias_initialization.has_initialization: - return self.layer_1_bias_initialization.get_initializer() - else: - return init_zeros_ - - @functools.cached_property - def layer_2_weight_initialization_method(self) -> Initializer: - if self.layer_2_weight_initialization.has_initialization: - return self.layer_2_weight_initialization.get_initializer() - else: - return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1)) - - @functools.cached_property - def layer_2_bias_initialization_method(self) -> Initializer: - if self.layer_2_bias_initialization.has_initialization: - return self.layer_2_bias_initialization.get_initializer() - else: - return init_zeros_ - - @functools.cached_property - def router_weight_initialization_method(self) -> Initializer: - if self.router_weight_initialization.has_initialization: - assert self.add_bias - return self.router_weight_initialization.get_initializer() - else: - return init_zeros_ - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -250,16 +205,5 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: CompositeTensorDim(MLPDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) ) tensor_space.add_tensor_dim(TensorDim(MLPDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(MLPDimNames.unshared_experts, self.num_unshared_experts)) - - # shared_experts - if self.num_shared_experts: - tensor_space.add_tensor_dim( - shared_experts := TensorDim(MLPDimNames.shared_experts, self.num_shared_experts) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) - ) - tensor_space.add_tensor_dim( - 
CompositeTensorDim(MLPDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp)) - ) + # composite_gated_expert_mlp + # composite_expert_mlp diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index f401371a4..bf4148fc0 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -4,15 +4,15 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.layers.common.linear import Linear -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -31,41 +31,54 @@ class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _group: ProcessGroup - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, block_index, name) + super().__init__(config, distributed_config, hidden_dim, block_index, name) layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None - router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) - - self.router = Linear( - tensor_space[BlockDimNames.hidden], - tensor_space[MLPDimNames.unshared_experts], - bias=False, - weight_init_method=init_normal_( - std=self._config.init_method_std, - min_val=self._config.init_method_min, - max_val=self._config.init_method_max, - ), - lr_scale=router_lr_scale, + + self.router = self._config.router.get_layer( + hidden_dim, + TensorDim("router_experts", self._config.num_unshared_experts), + lr_scale=layer_lr_scale, ) dropless_moe = self._config.dropless_moe - if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: + if dropless_moe and self._sequence_parallel: warnings.warn( "Dropless MoE not supported for sequence-tensor-parallel, falling back to looped implementation." 
) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped + if self._debug.enabled: + self._top_expert_dim = TensorDim("top_experts", self._config.num_experts_per_token) + + def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: + intermediate_1_dim, intermediate_2_dim = super()._get_intermediate_dims() + experts_dim = TensorDim("experts", self._config.num_experts) + return ( + CompositeTensorDim("moe_intermediate_1", (experts_dim, intermediate_1_dim)), + CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), + ) + def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: - self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) + self._debug( + logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs + ) # Apply z_loss if applicable if self._config.expert_z_loss_coefficient > 0.0: @@ -96,12 +109,12 @@ def forward( if self._debug.enabled: # To log all ranks set `global_=False` self._debug( - scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs + scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs ) self._debug( top_experts, "Router top experts", - kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), + kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs, ) diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 0716bf777..2fdf2d92f 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,59 +2,61 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockDimNames -from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames -from fast_llm.layers.block.peft import TransformerSubLayerName -from fast_llm.layers.common.linear import LinearBase -from fast_llm.utils import get_lr_scale +from fast_llm.layers.block.mlp.config import MLPConfig class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): + super().__init__(config, distributed_config, hidden_dim, block_index, name) + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() - hidden_dim = self._tensor_space[BlockDimNames.hidden] - self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale 
= self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None
-        lr_scale = (
-            tuple(self._config.mlp_lr_scale)
-            if isinstance(self._config.mlp_lr_scale, list)
-            else self._config.mlp_lr_scale
-        )
-        lr_scale = get_lr_scale(lr_scale, layer_lr_scale)
+        lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None
 
         # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size)
-        self.layer_1 = LinearBase(
+        self.layer_1 = self._config.layer_1.get_layer(
             hidden_dim,
-            self._tensor_space[MLPDimNames.composite_gated_expert_mlp],
-            bias=self._config.add_bias,
-            weight_init_method=self._config.layer_1_weight_initialization_method,
-            bias_init_method=self._config.layer_1_bias_initialization_method,
+            intermediate_1_dim,
+            sequence_parallel=self._sequence_parallel,
             lr_scale=lr_scale,
+            peft=self._config.block.peft,
         )
-        self.layer_2 = LinearBase(
-            self._intermediate_dim,
+        self.layer_2 = self._config.layer_2.get_layer(
+            intermediate_2_dim,
             hidden_dim,
-            bias=self._config.add_bias,
-            weight_init_method=self._config.layer_2_weight_initialization_method,
-            bias_init_method=self._config.layer_2_bias_initialization_method,
-            auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1,
+            auto_bias_grad_accumulation=self._parallel_dim.size > 1,
             transposed_weight=True,
             lr_scale=lr_scale,
+            peft=self._config.block.peft,
         )
 
-        # PEFT.
-        self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1)
-        self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2)
+    def _get_intermediate_dims(self):
+        intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim)
+        if self._config.gated:
+            intermediate_1_dim = ConcatenatedTensorDim("gate_and_up", (intermediate_2_dim, intermediate_2_dim))
+        else:
+            intermediate_1_dim = intermediate_2_dim
+        return intermediate_1_dim, intermediate_2_dim
 
 
 class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]):
+
     def forward(
         self,
         input_: torch.Tensor,
@@ -62,7 +64,6 @@ def forward(
         losses: dict[str, typing.Any] | None = None,
         metrics: dict[str, typing.Any] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        parallel_group = self._intermediate_dim.parallel_group
         return (
             mlp_autograd(
                 input_,
@@ -70,14 +71,14 @@ def forward(
                 self.layer_1.weight,
                 self.layer_1.bias,
                 self.layer_2.weight,
-                None if parallel_group else self.layer_2.bias,
+                None if self._parallel_dim.group else self.layer_2.bias,
                 gated=self._config.gated,
                 activation_type=self._config.activation_type,
-                group=parallel_group,
+                group=self._parallel_dim.group,
                 sequence_parallel=self._sequence_parallel,
                 training=self.training,
                 recompute_level=self._config.mlp_recompute_level,
                 transposed_layer_2_weight=self.layer_2.transposed_weight,
             ),
-            self.layer_2.bias if parallel_group else None,
+            self.layer_2.bias if self._parallel_dim.group else None,
         )
diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py
deleted file mode 100644
index 66bc675ed..000000000
--- a/fast_llm/layers/block/peft.py
+++ /dev/null
@@ -1,128 +0,0 @@
-"""
-TODO: Generalize beyond transformers. 
-""" - -import abc -import enum -import typing - -from fast_llm.config import Field, FieldHint, config_class -from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, PeftConfig -from fast_llm.utils import div - -if typing.TYPE_CHECKING: - import torch - - from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.tensor import ParameterMeta - - -class TransformerSubLayerName(str, enum.Enum): - # TODO: Use this to replace AddLinearBiasChoices. - query = "query" - key = "key" - value_ = "value" - key_value = "key_value" - dense = "dense" - mlp_1 = "mlp_1" - mlp_2 = "mlp_2" - - -@config_class(registry=True) -class TransformerPeftConfig(PeftConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - pass - - @abc.abstractmethod - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - pass - - @abc.abstractmethod - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - pass - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. - return TransformerNoPeftConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class(dynamic_type={TransformerPeftConfig: "none"}) -class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): - _abstract = False - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - return super().apply_linear(linear) - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - return parameter - - -@config_class(dynamic_type={TransformerPeftConfig: "lora"}) -class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): - layers: list[TransformerSubLayerName] = Field( - default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), - desc="The layers on which to apply LoRA.", - hint=FieldHint.feature, - ) - freeze_others: bool = Field( - default=True, - desc="Whether to freeze other layers during training.", - ) - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - if layer_type is None or self.layers is None or layer_type in self.layers: - if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) - elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False - return linear - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.freeze_others: - for parameter in module.parameters(): - parameter.requires_grad = False - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.freeze_others: - parameter.requires_grad = False - return parameter - - def _validate(self) -> None: - super()._validate() - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. 
- raise NotImplementedError("LoRA not supported for MLP.") - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). - raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." - ) diff --git a/fast_llm/layers/common/linear/__init__.py b/fast_llm/layers/common/linear/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py new file mode 100644 index 000000000..e7bd26df2 --- /dev/null +++ b/fast_llm/layers/common/linear/config.py @@ -0,0 +1,88 @@ +import typing + +from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.engine.config_utils.initialization import InitializationConfig +from fast_llm.layers.common.peft.config import PeftConfig + +if typing.TYPE_CHECKING: + from fast_llm.engine.config_utils.tensor_space import TensorDim + from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, LinearLike, OutputParallelLinear + + +@config_class() +class LinearConfig(Config): + bias: bool = Field( + default=None, + desc="Use bias.", + hint=FieldHint.architecture, + ) + weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the weight.", + hint=FieldHint.feature, + ) + bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the bias.", + hint=FieldHint.feature, + ) + lr_scale: float | None | tuple[float | None, ...] = (None,) + apply_peft: bool = Field( + default=None, + desc="Apply peft on this layer if defined. Otherwise, treat the layer as a non-peft layer (may be frozen).", + hint=FieldHint.feature, + ) + # Fixed defaults don't make sense because each parent layer uses its own. + # Instead, we use this variable to set defaults dynamically. + # This can either be a constant, + # or may point to another config, ex. to set a default for all layers in a model. 
+    default: "LinearConfig" = Field(init=False)
+
+    def _validate(self) -> None:
+        if hasattr(self, "default"):
+            self.default.validate()
+            with self._set_implicit_default():
+                if self.bias is None:
+                    self.bias = self.default.bias
+                if self.weight_initialization.is_default:
+                    self.weight_initialization = self.default.weight_initialization
+                if self.bias_initialization.is_default:
+                    self.bias_initialization = self.default.bias_initialization
+                if self.lr_scale is None:
+                    self.lr_scale = self.default.lr_scale
+                if self.apply_peft is None:
+                    self.apply_peft = self.default.apply_peft
+        if None in (self.bias, self.weight_initialization, self.bias_initialization, self.lr_scale, self.apply_peft):
+            raise ValueError("Missing default values for linear layer.")
+
+        super()._validate()
+
+    def get_layer(
+        self,
+        in_dim: TensorDim,
+        out_dim: TensorDim,
+        *,
+        sequence_parallel: bool = False,
+        transposed_weight: bool = False,
+        auto_bias_grad_accumulation: bool = False,
+        lr_scale: float | None | tuple[float | None, ...],
+        peft: PeftConfig | None = None,
+    ) -> "LinearLike":
+        from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear
+
+        if in_dim.parallel_dim is not None:
+            assert out_dim.parallel_dim is None
+            cls = InputParallelLinear
+        elif out_dim.parallel_dim is not None:
+            cls = OutputParallelLinear
+        else:
+            assert not sequence_parallel
+            cls = Linear
+        out = cls(
+            self,
+            in_dim,
+            out_dim,
+            transposed_weight=transposed_weight,
+            sequence_parallel=sequence_parallel,
+            auto_bias_grad_accumulation=auto_bias_grad_accumulation,
+            lr_scale=lr_scale,
+        )
+        if peft is not None:
+            out = peft.apply_linear(out, self.apply_peft)
+        return out
diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear/linear.py
similarity index 81%
rename from fast_llm/layers/common/linear.py
rename to fast_llm/layers/common/linear/linear.py
index 740b4847c..faa01b7fa 100644
--- a/fast_llm/layers/common/linear.py
+++ b/fast_llm/layers/common/linear/linear.py
@@ -3,7 +3,7 @@
 
 import torch
 
-from fast_llm.engine.config_utils.initialization import init_zeros_
+from fast_llm.config import Configurable
 from fast_llm.engine.config_utils.tensor_space import TensorDim
 from fast_llm.functional.autograd import wrap_forward_backward
 from fast_llm.functional.linear import (
@@ -15,7 +15,9 @@
     output_parallel_linear_backward,
     output_parallel_linear_forward,
 )
+from fast_llm.layers.common.linear.config import LinearConfig
 from fast_llm.tensor import ParameterMeta
+from fast_llm.utils import combine_lr_scales
 
 logger = logging.getLogger(__name__)
 
@@ -35,38 +37,38 @@ def backward(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tens
         raise NotImplementedError()
 
 
-class LinearBase(LinearLike):
+class LinearBase(Configurable[LinearConfig], LinearLike):
     """
     A base module for linear layers holding weights and biases.
     """
 
     def __init__(
         self,
+        config: LinearConfig,
         in_dim: TensorDim,
         out_dim: TensorDim,
         *,
-        bias=True,
-        weight_init_method,
-        bias_init_method=init_zeros_,
         transposed_weight: bool = False,
+        sequence_parallel: bool = False,
         auto_bias_grad_accumulation: bool = False,
         lr_scale: float | None | tuple[float | None, ...] 
= None, ): - super().__init__() + super().__init__(config) self._transposed_weight = transposed_weight + self._sequence_parallel = sequence_parallel self._in_dim = in_dim self._out_dim = out_dim - self._weight_init_method = weight_init_method + lr_scale = combine_lr_scales(self._config.lr_scale, lr_scale) self.weight = ParameterMeta.from_dims( (self._in_dim, self._out_dim) if self._transposed_weight else (self._out_dim, self._in_dim), - init_method=weight_init_method, + init_method=self._config.weight_initialization, auto_grad_accumulation=False, lr_scale=lr_scale, ) - if bias: + if self._config.bias: self.bias = ParameterMeta.from_dims( (self._out_dim,), - init_method=bias_init_method, + init_method=self._config.bias_initialization, weight_decay=False, auto_grad_accumulation=auto_bias_grad_accumulation, lr_scale=lr_scale, @@ -86,24 +88,25 @@ class Linear(LinearBase): def __init__( self, + config: LinearConfig, in_dim: TensorDim, out_dim: TensorDim, *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, transposed_weight: bool = False, + sequence_parallel: bool = False, + auto_bias_grad_accumulation: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): assert not in_dim.is_parallel assert not out_dim.is_parallel + assert not sequence_parallel super().__init__( + config, in_dim, out_dim, - bias=bias, - weight_init_method=weight_init_method, - bias_init_method=bias_init_method, transposed_weight=transposed_weight, + sequence_parallel=sequence_parallel, + auto_bias_grad_accumulation=auto_bias_grad_accumulation, lr_scale=lr_scale, ) @@ -123,26 +126,24 @@ class OutputParallelLinear(LinearBase): def __init__( self, + config: LinearConfig, in_dim: TensorDim, out_dim: TensorDim, *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, transposed_weight: bool = False, sequence_parallel: bool = False, + auto_bias_grad_accumulation: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size - self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( + config, in_dim, out_dim, - bias=bias, - weight_init_method=weight_init_method, - bias_init_method=bias_init_method, transposed_weight=transposed_weight, + sequence_parallel=sequence_parallel and self._group_size > 1, + auto_bias_grad_accumulation=auto_bias_grad_accumulation, lr_scale=lr_scale, ) @@ -167,28 +168,24 @@ class InputParallelLinear(LinearBase): def __init__( self, + config: LinearConfig, in_dim: TensorDim, out_dim: TensorDim, *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, - sequence_parallel: bool = False, transposed_weight: bool = False, + sequence_parallel: bool = False, + auto_bias_grad_accumulation: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size - self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( + config, in_dim, out_dim, - bias=bias, - weight_init_method=weight_init_method, - bias_init_method=bias_init_method, transposed_weight=transposed_weight, - # Tensor-parallel bias is computed in _bias_dropout_grad. 
- auto_bias_grad_accumulation=self._group_size > 1, + sequence_parallel=sequence_parallel and self._group_size > 1, + auto_bias_grad_accumulation=auto_bias_grad_accumulation, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/common/normalization/__init__.py b/fast_llm/layers/common/normalization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/normalization/config.py similarity index 62% rename from fast_llm/layers/common/config.py rename to fast_llm/layers/common/normalization/config.py index 8483dc573..d2d494dba 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -1,16 +1,15 @@ import abc import enum -import functools import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_ones_, init_zeros_ +from fast_llm.engine.config_utils.initialization import InitializationConfig from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.common.linear import LinearBase, LinearLike from fast_llm.layers.common.normalization import Normalization @@ -28,15 +27,23 @@ class NormalizationImplementation(str, enum.Enum): @config_class(registry=True) class NormalizationConfig(BaseModelConfig): - pass + lr_scale: float | None = None @property @abc.abstractmethod def module_class(self) -> type["Normalization"]: pass - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "Normalization": - return self.module_class(self, hidden_dim, lr_scale) + def get_layer( + self, + hidden_dim: "TensorDim", + lr_scale: float | None = None, + peft: PeftConfig | None = None, + ) -> "Normalization": + out = self.module_class(self, hidden_dim, lr_scale) + if peft: + out = peft.apply_other(out) + return out @classmethod def _from_dict( @@ -51,13 +58,22 @@ def _from_dict( return super()._from_dict(default, strict=strict, flat=flat) +@config_class(dynamic_type={NormalizationConfig: "default"}) +class DefaultNormalizationConfig(NormalizationConfig): + _abstract = False + + @property + def module_class(self) -> type["Normalization"]: + raise NotImplementedError() + + @config_class(dynamic_type={NormalizationConfig: "none"}) class NoNormalizationConfig(NormalizationConfig): _abstract = False @property def module_class(self) -> type["Normalization"]: - from fast_llm.layers.common.normalization import NoNormalization + from fast_llm.layers.common.normalization.normalization import NoNormalization return NoNormalization @@ -104,12 +120,12 @@ def _from_dict( cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") return super()._from_dict(default, strict, flat) - @functools.cached_property - def weight_initialization_method(self) -> Initializer: - if self.weight_initialization.has_initialization: - return self.weight_initialization.get_initializer() - else: - return init_ones_ + # @functools.cached_property + # def weight_initialization_method(self) -> Initializer: + # if self.weight_initialization.is_default: + # return self.weight_initialization.get_initializer() + # else: + # return init_ones_ @config_class(dynamic_type={NormalizationConfig: "layer_norm"}) @@ -120,16 +136,16 @@ class
LayerNormalizationConfig(LayerNormalizationBaseConfig): hint=FieldHint.feature, ) - @functools.cached_property - def bias_initialization_method(self) -> Initializer: - if self.bias_initialization.has_initialization: - return self.bias_initialization.get_initializer() - else: - return init_zeros_ + # @functools.cached_property + # def bias_initialization_method(self) -> Initializer: + # if self.bias_initialization.is_default: + # return self.bias_initialization.get_initializer() + # else: + # return init_zeros_ @property def module_class(self): - from fast_llm.layers.common.normalization import LayerNormalization + from fast_llm.layers.common.normalization.normalization import LayerNormalization return LayerNormalization @@ -140,56 +156,6 @@ class RMSNormalizationConfig(LayerNormalizationBaseConfig): @property def module_class(self): - from fast_llm.layers.common.normalization import RMSNormalization + from fast_llm.layers.common.normalization.normalization import RMSNormalization return RMSNormalization - - -@config_class() -class PeftConfig(BaseModelConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - pass - - -@config_class() -class NoPeftConfig(PeftConfig): - _abstract = False - - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - return linear - - -@config_class() -class LoRAConfig(PeftConfig): - _abstract = False - - rank: int = Field( - default=8, - desc="The LoRA rank, i.e. the size of the intermediate dimension.", - hint=FieldHint.stability, - ) - alpha: float = Field( - default=8.0, - desc="The LoRA scaling parameter.", - hint=FieldHint.stability, - ) - dropout: float = Field( - default=0.0, - desc="Dropout rate for LoRA.", - hint=FieldHint.stability, - ) - - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - from fast_llm.layers.common.peft import lora_linear - - # TODO: Init method? 
- return lora_linear( - linear, - linear.weight.param_init_method, - linear.weight.param_init_method, - self.rank, - self.alpha, - self.dropout, - **kwargs, - ) diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization/normalization.py similarity index 98% rename from fast_llm/layers/common/normalization.py rename to fast_llm/layers/common/normalization/normalization.py index cedfd2294..8b5a61c8b 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -7,7 +7,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.config import ( +from fast_llm.layers.common.normalization.config import ( LayerNormalizationConfig, NoNormalizationConfig, NormalizationConfig, @@ -15,7 +15,7 @@ RMSNormalizationConfig, ) from fast_llm.tensor import ParameterMeta, accumulate_gradient -from fast_llm.utils import Assert +from fast_llm.utils import Assert, combine_lr_scales try: import fused_layer_norm_cuda # noqa @@ -156,7 +156,7 @@ def __init__( ): super().__init__(config) self._hidden_dim = hidden_dim - self._lr_scale = lr_scale + self._lr_scale = combine_lr_scales(self._config.lr_scale, lr_scale) assert not self._hidden_dim.is_parallel @abc.abstractmethod diff --git a/fast_llm/layers/common/peft/__init__.py b/fast_llm/layers/common/peft/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py new file mode 100644 index 000000000..69f6c7577 --- /dev/null +++ b/fast_llm/layers/common/peft/config.py @@ -0,0 +1,82 @@ +import typing + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelConfig + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.layers.common.linear.linear import LinearBase, LinearLike + from fast_llm.layers.common.normalization.normalization import Normalization + from fast_llm.tensor import ParameterMeta + + +@config_class(registry=True) +class PeftConfig(BaseModelConfig): + def apply_linear(self, module: "LinearBase", enabled: bool) -> "LinearLike": + return self.apply_other(module) + + def apply_normalization(self, module: "Normalization") -> "Normalization": + return self.apply_other(module) + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + for parameter in module.parameters(): + self.apply_weight(parameter) + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + return parameter + + +@config_class(dynamic_type={PeftConfig: "none"}) +class NoPeftConfig(PeftConfig): + _abstract = False + + +@config_class(dynamic_type={PeftConfig: "lora"}) +class LoRAConfig(PeftConfig): + _abstract = False + + rank: int = Field( + default=8, + desc="The LoRA rank, i.e. 
the size of the intermediate dimension.", + hint=FieldHint.stability, + ) + alpha: float = Field( + default=8.0, + desc="The LoRA scaling parameter.", + hint=FieldHint.stability, + ) + dropout: float = Field( + default=0.0, + desc="Dropout rate for LoRA.", + hint=FieldHint.stability, + ) + freeze_others: bool = Field( + default=True, + desc="Whether to freeze other layers during training.", + ) + + def apply_linear(self, module: "LinearBase", enabled: bool) -> "LinearLike": + if not enabled: + return self.apply_other(module) + + from fast_llm.layers.common.linear.linear import InputParallelLinear + from fast_llm.layers.common.peft.peft import lora_linear + + if isinstance(module, InputParallelLinear): + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for InputParallelLinear.") + + # TODO: Init method? + return lora_linear( + module, + self.rank, + self.alpha, + self.dropout, + ) + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + if self.freeze_others: + parameter.requires_grad = False + return parameter diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft/peft.py similarity index 68% rename from fast_llm/layers/common/peft.py rename to fast_llm/layers/common/peft/peft.py index 08f3e535b..c4cd2a26e 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft/peft.py @@ -4,25 +4,24 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.common.linear import Linear, LinearBase +from fast_llm.layers.common.linear.config import LinearConfig +from fast_llm.layers.common.linear.linear import Linear, LinearBase def lora_linear( - layer: LinearBase, - init_method_0, - init_method_1, + module: LinearBase, rank: int, alpha: float, dropout: float = 0.0, out_channel_begin: int | None = None, out_channel_end: int | None = None, ): - layer.weight.requires_grad = False - in_dim = layer._in_dim + module.weight.requires_grad = False + in_dim = module._in_dim assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: in_dim = TensorDim(in_dim.name, in_dim.global_size) - out_dim = layer._out_dim + out_dim = module._out_dim assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: out_dim = TensorDim(out_dim.name, out_dim.global_size) @@ -35,28 +34,25 @@ def lora_linear( out_dim = TensorDim(out_dim.name, out_channel_end - out_channel_begin) middle_dim = TensorDim("lora_middle", rank) + config = LinearConfig.from_dict(module.config, {"bias": False, "lr_scale": module.weight.lr_scale}) - layer.lora_0 = Linear( + module.lora_0 = Linear( + config, in_dim, middle_dim, - bias=False, - weight_init_method=init_method_0, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + transposed_weight=module.transposed_weight, ) - layer.lora_1 = Linear( + module.lora_1 = Linear( + config, middle_dim, out_dim, - bias=False, - weight_init_method=init_method_1, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + transposed_weight=module.transposed_weight, ) # TODO: Implement proper backward pass. 
- layer.lora_0.weight.auto_grad_accumulation = True - layer.lora_1.weight.auto_grad_accumulation = True + module.lora_0.weight.auto_grad_accumulation = True + module.lora_1.weight.auto_grad_accumulation = True - old_forward = layer._forward + old_forward = module._forward def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: # TODO: torch compile? @@ -66,8 +62,8 @@ def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor if isinstance(output, tuple): layer_out, tp_bias = output[0] assert tp_bias is None - lora_out = (alpha / rank) * layer.lora_1( - layer.lora_0(torch.dropout(input_, dropout, layer.training) if dropout > 0.0 else input_) + lora_out = (alpha / rank) * module.lora_1( + module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) ) if out_channel_begin is None: output = output + lora_out @@ -83,8 +79,8 @@ def backward( output.backward(grad_output) return input_.grad - layer._forward = wrap_forward_backward(forward_only, backward) - layer.forward_only = forward_only - layer.backward = backward + module._forward = wrap_forward_backward(forward_only, backward) + module.forward_only = forward_only + module.backward = backward - return layer + return module diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 943c64d01..c5e1563e6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,8 +1,6 @@ -import functools - from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_ +from fast_llm.engine.config_utils.initialization import InitializationConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl @@ -215,9 +213,9 @@ def _validate(self) -> None: Assert.eq( len(self.transformer.per_layer_lr_scale), self.transformer.num_blocks + self.prediction_heads - 1 + 1 ) - if self.output_weight_initialization.has_initialization: + if not self.output_weight_initialization.is_default: assert self.use_absolute_position_embeddings - if self.output_weight_initialization.has_initialization: + if not self.output_weight_initialization.is_default: assert not self.tie_word_embeddings def setup_tensor_space(self, tensor_space: TensorSpace) -> None: @@ -237,23 +235,23 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: def use_absolute_position_embeddings(self) -> int: return self.absolute_position_embeddings is not None - @functools.cached_property - def word_embedding_weight_initialization_method(self) -> Initializer: - if self.word_embedding_weight_initialization.has_initialization: - return self.word_embedding_weight_initialization.get_initializer() - else: - return init_normal_(self.transformer.hidden_size**-0.5) + # @functools.cached_property + # def word_embedding_weight_initialization_method(self) -> Initializer: + # if self.word_embedding_weight_initialization.is_default: + # return self.word_embedding_weight_initialization.get_initializer() + # else: + # return init_normal_(self.transformer.hidden_size**-0.5) - @functools.cached_property - def position_embedding_weight_initialization_method(self) -> Initializer: - if 
self.position_embedding_weight_initialization.has_initialization: - return self.position_embedding_weight_initialization.get_initializer() - else: - return init_normal_(self.transformer.hidden_size**-0.5) + # @functools.cached_property + # def position_embedding_weight_initialization_method(self) -> Initializer: + # if self.position_embedding_weight_initialization.is_default: + # return self.position_embedding_weight_initialization.get_initializer() + # else: + # return init_normal_(self.transformer.hidden_size**-0.5) - @functools.cached_property - def output_weight_initialization_method(self) -> Initializer: - if self.output_weight_initialization.has_initialization: - return self.output_weight_initialization.get_initializer() - else: - return init_normal_(self.transformer.hidden_size**-0.5) + # @functools.cached_property + # def output_weight_initialization_method(self) -> Initializer: + # if self.output_weight_initialization.is_default: + # return self.output_weight_initialization.get_initializer() + # else: + # return init_normal_(self.transformer.hidden_size**-0.5) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 364dc745e..6e1ef4bb5 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -64,7 +64,9 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_dis else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) + self.final_norm = self._config.transformer.normalization.get_layer( + hidden_dim, peft=self._config.transformer.peft + ) self._logits_scale_factor = self._config.logits_scale_factor self._language_model_loss_factor = self._config.language_model_loss_factor self._distillation_loss_factor = self._config.distillation_loss_factor @@ -102,7 +104,6 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_dis self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) # PEFT. - self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) if hasattr(self, "output_weights"): self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index dec0675b9..dd7e9fbe5 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -5,33 +5,33 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.attention.config import MixerDimNames, setup_mqa_dims from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.initialization import Initializer -class SSMDimNames(BlockDimNames): +class SSMDimNames(MixerDimNames): # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
- state = "ssm_state" # State dimension (N), aka head size / num channels - head_dim = "ssm_head_dim" - head_groups = "ssm_head_groups" - group_heads = "ssm_group_heads" + # state = "ssm_state" # State dimension (N), aka head size / num channels + # head_dim = "ssm_head_dim" + # head_groups = "ssm_head_groups" + # group_heads = "ssm_group_heads" convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers dt_rank = "ssm_dt_rank" # Composite dimensions - composite_heads = "ssm_composite_heads" - composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" - composite_head_groups_and_state = "ssm_composite_head_groups_and_state" + # composite_heads = "ssm_composite_heads" + # composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" + # composite_head_groups_and_state = "ssm_composite_head_groups_and_state" # Concatenated dimensions concatenated_convolution = "ssm_concatenated_convolution" concatenated_x_projection = "ssm_x_concatenated_x_projection" - concatenated_inner_projection = "ssm_concatenated_inner_projection" + # concatenated_inner_projection = "ssm_concatenated_inner_projection" class SSMBlockType(enum.StrEnum): @@ -109,7 +109,7 @@ class SSMConfig(Config): desc="Number of QK heads for Mamba2 blocks.", hint=FieldHint.architecture, ) - # heads [DiscreteMamba2]# TODO: Remove? (redundant) + # heads [DiscreteMamba2] n_v_heads: int = Field( default=32, desc="Number of V heads for Mamba2 blocks.", @@ -208,6 +208,34 @@ def _validate(self) -> None: def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + if block_type == SSMBlockType.mamba: + setup_mqa_dims( + tensor_space, + div(self.d_inner, self.state_size), + div(self.d_inner, self.state_size), + self.state_size, + self.state_size, + ) + + elif block_type == SSMBlockType.mamba2: + setup_mqa_dims( + tensor_space, + div(self.d_inner, self.state_size), + div(self.d_xb, self.state_size), + self.state_size, + self.state_size, + ) + + elif block_type == SSMBlockType.mamba2_discrete: + # TODO: Use different variables? + num_heads = self.n_v_heads + num_head_groups = self.n_qk_heads + setup_mqa_dims( + tensor_space, self.n_v_heads, self.n_qk_heads, div(self.d_inner, self.n_v_heads), self.state_size + ) + else: + raise NotImplementedError(block_type) + # Head groups are configured differently depending on the block type. 
if block_type == SSMBlockType.mamba: num_heads = div(self.d_inner, self.state_size) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 61291f845..7a8fffef2 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -9,11 +9,11 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear +from fast_llm.layers.common.linear.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import get_lr_scale +from fast_llm.utils import combine_lr_scales logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ def __init__( ) self._config: SSMConfig = config layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._config.mamba_lr_scale, layer_lr_scale) inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] hidden_dim = tensor_space[SSMDimNames.hidden] diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b6626e893..f090f2216 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -8,11 +8,11 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear +from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, div, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -69,7 +69,9 @@ def __init__( layer_lr_scale: float | None = ( block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None ) - lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale: float | tuple[float | None, ...] 
| None = combine_lr_scales( + self._config.mamba_lr_scale, layer_lr_scale + ) inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 0dcc29f0b..a249b25ab 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -9,10 +9,10 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear import Linear +from fast_llm.layers.common.linear.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -79,7 +79,7 @@ def __init__( inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] hidden_dim = tensor_space[SSMDimNames.hidden] layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._config.mamba_lr_scale, layer_lr_scale) # TODO: Backward compatibility? # TODO: lr_scale? diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py deleted file mode 100644 index 89d7a2e3b..000000000 --- a/fast_llm/layers/transformer/block.py +++ /dev/null @@ -1,22 +0,0 @@ -import logging -import typing - -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.block.block import Block, BlockLayer -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig - -logger = logging.getLogger(__name__) - - -class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): - _name = "Transformer layer" - # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "self_attn" - _config: TransformerConfig - - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): - super().__init__(config, tensor_space, block_index, return_input) - - def _create_mixer(self) -> BlockLayer: - return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py deleted file mode 100644 index e8c319b0f..000000000 --- a/fast_llm/layers/transformer/config.py +++ /dev/null @@ -1,227 +0,0 @@ -import functools -import logging -import typing -import warnings - -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs, MixerConfig -from fast_llm.layers.transformer.rotary.config import RotaryConfig -from fast_llm.utils import 
Assert, div - -logger = logging.getLogger(__name__) - - -class AttentionDimNames(BlockDimNames): - # A set of common tensor dim names packed into a namespace. - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - - -class AttentionKwargs(BlockKwargs): - rotary_freq_q = "rotary_freq_q" - rotary_freq_k = "rotary_freq_k" - attention_mask = "attention_mask" - attention_mask_value = "attention_mask_value" - cu_seqlens_q = "cu_seqlens_q" - cu_seqlens_k = "cu_seqlens_k" - max_seqlen_q = "max_seqlen_q" - max_seqlen_k = "max_seqlen_k" - # TODO: Review these - presents = "presents" - past_key_values = "past_key_values" - - -@config_class(dynamic_type={MixerConfig: "attention"}) -class AttentionConfig(MixerConfig): - _abstract = False - - # Needed for backward compatibility. TODO: remove - module_name: typing.ClassVar[str] = "attn" - - # TODO: Review names - rotary: RotaryConfig = Field( - desc="Configuration for the rotary positional embeddings.", - hint=FieldHint.architecture, - ) - num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) - head_groups: int = Field( - default=1, - desc="Number of head group for grouped query attention.", - doc="Set to 1 for multi-query attention, `num_attention_heads` for multi-head.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - kv_channels: int = Field( - default=None, - desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - attention_dropout: float = Field( - default=0.0, - desc="Dropout applied to the attention intermediate states.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - # Use flash attention if possible (fp16 or bf16) - use_flash_attention: bool = Field( - default=True, desc="Enable Flash Attention if possible.", hint=FieldHint.optional - ) - window_size: int | None = Field( - default=None, - desc="Size of the attention sliding window. Warning: this parameter is not part of the architecture and must be redefined when loading a pretrained model.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - max_window_layers: int | None = Field( - default=None, - desc="The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.", - hint=FieldHint.optional, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate scale for the Attention projection weights.", - doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_softmax_scale_power: float = Field( - default=0.5, - desc="The scaling power to apply to kv_channel in the attention calculation. " - " Under Standard Parameterization (SP): default to 0.5. " - " Under muP (if scaling kv_channels size): use 1. 
" - " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - qkv_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the query, key and value layer weights." - " Default: normal(std=hidden_size**-0.5)", - hint=FieldHint.feature, - ) - qkv_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the query, key and value layer biases. Default: fill with zeros.", - hint=FieldHint.feature, - ) - dense_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the dense layer weight." - " Default: normal(std=(2 * num_blocks * hidden_size)**-0.5)", - hint=FieldHint.feature, - ) - dense_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the dense layer biases. Default: fill with zeros.", - hint=FieldHint.feature, - ) - - def _validate(self) -> None: - with self._set_implicit_default(): - # TODO: hidden_size not yet validated. - if self.kv_channels is None: - self.kv_channels = div(self.block.hidden_size, self.num_attention_heads) - - super()._validate() - - if not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - - Assert.multiple(self.num_attention_heads, self.head_groups) - if self.qkv_bias_initialization.has_initialization: - assert self.add_qkv_bias - if self.dense_bias_initialization.has_initialization: - assert self.add_dense_bias - - @functools.cached_property - def projection_size(self): - assert self._validated - return self.num_attention_heads * self.kv_channels - - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - # Needed for multiple inheritance. 
- super().setup_tensor_space(tensor_space) # Noqa - - tensor_space.add_tensor_dim( - head_groups := TensorDim( - AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None - ) - ) - tensor_space.add_tensor_dim( - group_heads := TensorDim( - AttentionDimNames.group_heads, - div(self.num_attention_heads, self.head_groups), - None if self.head_groups > 1 else tensor, - ) - ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) - ) - - @functools.cached_property - def add_qkv_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - return self.block.add_linear_biases != AddLinearBiasChoices.nowhere - - @functools.cached_property - def add_dense_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - return self.block.add_linear_biases == AddLinearBiasChoices.everywhere - - @functools.cached_property - def qkv_weight_initialization_method(self) -> Initializer: - if self.qkv_weight_initialization.has_initialization: - return self.qkv_weight_initialization.get_initializer() - else: - return init_normal_(0, self.block.hidden_size**-0.5) - - @functools.cached_property - def qkv_bias_initialization_method(self) -> Initializer: - if self.qkv_bias_initialization.has_initialization: - return self.qkv_bias_initialization.get_initializer() - else: - return init_zeros_ - - @functools.cached_property - def dense_weight_initialization_method(self) -> Initializer: - if self.dense_weight_initialization.has_initialization: - return self.dense_weight_initialization.get_initializer() - else: - return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1)) - - @functools.cached_property - def dense_bias_initialization_method(self) -> Initializer: - if self.dense_bias_initialization.has_initialization: - return self.dense_bias_initialization.get_initializer() - else: - return init_zeros_ - - -@config_class() -# TODO: Remove -class TransformerConfig(BlockConfig): - pass diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py deleted file mode 100644 index 9f8732f85..000000000 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ /dev/null @@ -1,68 +0,0 @@ -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.tensor import TensorMeta - - -class RotaryEmbeddingPreprocessor(Preprocessor): - _scalar_dim: TensorDim - _kv_channels_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor - _mask_value: torch.Tensor 
- _tensor_cache_max_sequence_length: int = -1 - - def __init__( - self, - config: DefaultRotaryConfig, - tensor_space: TensorSpace, - ): - self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k - ] - kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_q, - ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_k, - ) - - def _create_tensors(self, sequence_length: int) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - self._rotary_embedding_frequencies = self._config.get_frequencies( - sequence_length, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index ea56b7b5a..19fdf4c52 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -6,8 +6,8 @@ from fast_llm.engine.base_model.base_model import Layer, LossDef from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.layers.block.block import Block from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.model import GPTBaseModel, GPTModel @@ -28,7 +28,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerBlock( + Block( self._config.transformer, self._tensor_space, block_index=i + 1, diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 0ef970db2..eda8e47a0 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -24,11 +24,11 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import RoutingType -from 
fast_llm.layers.common.config import LayerNormalizationConfig -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.common.normalization.config import LayerNormalizationConfig from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, @@ -191,7 +191,7 @@ def _create_weight_converters( def _create_transformer_layer_converters( self, fast_llm_layer_name: str, hf_layer_name: str, ignore_export: bool = False ) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] names_bias_cls = [ @@ -341,7 +341,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_bias @@ -458,7 +458,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", @@ -526,7 +526,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", @@ -636,7 +636,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + transformer_config: BlockConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 4e3f258fc..2f99ae4c3 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner diff --git a/fast_llm/models/gpt/megatron.py 
b/fast_llm/models/gpt/megatron.py index 20ed8e828..f9ef0e85d 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,7 +1,7 @@ import typing -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig +from fast_llm.layers.block.config import BlockConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -13,7 +13,7 @@ def get_init_megatron( - meta: "ParameterMeta", config: TransformerConfig + meta: "ParameterMeta", config: BlockConfig ) -> typing.Callable[["torch.Tensor", "Distributed"], None]: def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) @@ -50,7 +50,7 @@ def set_megatron_distributed_seeds(config: "DistributedConfig") -> None: def _init_attention_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: BlockConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": # Megatron combines q and kv and inverts the initialization order of qkv and dense layers. # It also always treats the tensors as tensor-parallel and uses a different rotary embedding format. @@ -94,7 +94,7 @@ def _init_attention_megatron( raise NotImplementedError(meta.tensor_name) if isinstance(config.rotary, DefaultRotaryConfig) and config.rotary.complex_format: - from fast_llm.layers.transformer.rotary.config import convert_rotary_real_to_complex + from fast_llm.layers.attention.rotary.config import convert_rotary_real_to_complex # Megatron uses (2, kv_channels/2) for the complex split; we use (kv_channels/2, 2). # TODO: Avoid unnecessarily changing the value and dense tensors. 
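# --- Illustrative sketch (not part of the patch) -----------------------------
# The megatron.py comment above notes that Megatron packs the rotary "complex" channels
# as (2, kv_channels / 2) while Fast-LLM uses (kv_channels / 2, 2). A minimal sketch of
# what that reordering amounts to, using plain reshapes. The helper name below is
# hypothetical and only illustrates the layout change; it is not the repository's
# convert_rotary_real_to_complex implementation.
import torch


def megatron_to_fast_llm_channel_order(weight: torch.Tensor, kv_channels: int) -> torch.Tensor:
    # Megatron order: all first-half channels, then all second-half channels -> view as (2, kv_channels // 2).
    # Fast-LLM order: interleaved pairs -> view as (kv_channels // 2, 2).
    *leading, channels = weight.shape
    assert channels == kv_channels
    return weight.reshape(*leading, 2, kv_channels // 2).transpose(-1, -2).reshape(*leading, kv_channels)


# With kv_channels = 4, the Megatron layout [x0, x1, y0, y1] becomes [x0, y0, x1, y1]:
print(megatron_to_fast_llm_channel_order(torch.tensor([[0.0, 1.0, 2.0, 3.0]]), 4))  # tensor([[0., 2., 1., 3.]])
# -----------------------------------------------------------------------------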
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 796c34756..6d9bf1676 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,14 +10,14 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.attention.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor +from fast_llm.layers.block.block import Block from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs -from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -48,7 +48,7 @@ def __init__( self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. - self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) + self._preprocessors.append(self._config.transformer.rotary.get_layer(self._tensor_space)) if self._use_flash_attention: self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) else: @@ -62,7 +62,7 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerBlock( + Block( self._config.transformer, self._tensor_space, # TODO MTP: which index? 
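# --- Illustrative sketch (not part of the patch) -----------------------------
# The lr-scale handling in the layers changed above now goes through the new
# combine_lr_scales helper added to fast_llm/utils.py further down in this diff.
# A small usage sketch, assuming the implementation shown in that hunk: None entries
# are dropped, scalars are multiplied together, and tuples (e.g. per-expert scales)
# are combined element-wise.
from fast_llm.utils import combine_lr_scales

print(combine_lr_scales(None, None))        # None (nothing to combine)
print(combine_lr_scales(0.5, None, 2.0))    # 1.0 (product of the non-None scalars)
print(combine_lr_scales((1.0, None), 0.5))  # [0.5, 0.5] (element-wise combination with the scalar)
# -----------------------------------------------------------------------------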
@@ -85,7 +85,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerBlock( + Block( self._config.transformer, self._tensor_space, block_index=i + 1, @@ -330,7 +330,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerBlock]: + def transformer_layers(self) -> list[Block]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 11d888eaf..2fc2aeaec 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -21,7 +21,7 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import RMSNormalizationConfig +from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 32fbdad9b..a096d0cd0 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -3,10 +3,10 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import Block from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.block import SSMBlock -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -41,7 +41,7 @@ def get_output_layers(self) -> list[Layer]: for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerBlock( + Block( self._config.transformer, self._tensor_space, block_index=len(self._config.hybrid_block_layout), diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 58285d408..f7f5e9663 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -348,22 +348,29 @@ def check_equal_nested(config_a, config_b): raise ValueError("\n".join(errors)) -def get_lr_scale( - lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None -) -> float | None | tuple[float | None, ...]: - """ - Combine module and layer lr_scale. - If one is None, return the other. - """ - if lr_scale is None: - return layer_lr_scale - if layer_lr_scale is None: - return lr_scale - if isinstance(lr_scale, float): - return lr_scale * layer_lr_scale - if isinstance(lr_scale, tuple): - return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) - raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") +def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): + # Remove `None` entries. + lr_scales = [lr_scale for lr_scale in lr_scales if lr_scale is not None] + if not lr_scales: + # Everything is None + return None + tuple_length = None + # Check if we have tuples, and determine the length. 
+ for lr_scale in lr_scales: + if isinstance(lr_scale, tuple): + if tuple_length is None: + tuple_length = len(lr_scale) + else: + assert len(lr_scale) == tuple_length + if tuple_length is None: + # No tuple: simple product. + return math.prod(lr_scales) + else: + # Tuple(s): use recursion. + return [ + combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) + for i in range(tuple_length) + ] class Interrupter: diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index e61f72244..a2f09b40d 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -23,8 +23,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from fast_llm.functional.triton.rotary import triton_rotary_ from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import ( +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.rotary.rotary import ( apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, @@ -92,7 +92,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y1 = apply_rotary_embeddings( x, DefaultRotaryConfig(triton=False) - .build() + .get_layer() ._get_frequencies( sequence_length, kv_channels, @@ -103,7 +103,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, kv_channels, 3), - DefaultRotaryConfig(triton=True).build()._get_frequencies(sequence_length, kv_channels, device="cuda"), + DefaultRotaryConfig(triton=True).get_layer()._get_frequencies(sequence_length, kv_channels, device="cuda"), ), kv_channels, 3, diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 8c33aed4d..380ab0550 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -6,10 +6,10 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda diff --git a/tests/test_attention.py b/tests/test_attention.py index 534e3800e..6ae34e730 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -5,9 +5,10 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig -from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor +from fast_llm.layers.attention.attention import Attention +from 
fast_llm.layers.attention.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor +from fast_llm.layers.block.config import BlockConfig from fast_llm.utils import Assert @@ -16,22 +17,22 @@ def test_decide_window_size(): attention._decide_window_size = Attention._decide_window_size.__get__(attention) # Attach real method # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) - attention._config = TransformerConfig(window_size=512, max_window_layers=2) + attention._config = BlockConfig(window_size=512, max_window_layers=2) attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) - attention._config = TransformerConfig(window_size=512, max_window_layers=2) + attention._config = BlockConfig(window_size=512, max_window_layers=2) attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) - attention._config = TransformerConfig(window_size=512, max_window_layers=None) + attention._config = BlockConfig(window_size=512, max_window_layers=None) assert attention._decide_window_size() == 512 def test_attention_constructor(): - transformer_conf = TransformerConfig( + transformer_conf = BlockConfig( num_layers=2, num_attention_heads=2, hidden_size=16, @@ -63,7 +64,7 @@ def test_varlen_preprocessor(): ] micro_sequence_length = 12 sequence_length = 36 - transformer_cfg = TransformerConfig( + transformer_cfg = BlockConfig( num_layers=2, num_attention_heads=2, hidden_size=16, diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 802833eb2..75f2564c7 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,12 +1,12 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.block.mlp.mlp import MLP -from fast_llm.layers.transformer.config import TransformerConfig def test_mlp_constructor(): - transformer_conf = TransformerConfig( + transformer_conf = BlockConfig( num_layers=2, num_attention_heads=2, hidden_size=16, @@ -19,7 +19,7 @@ def test_mlp_constructor(): def test_moe_mlp_constructor(): - transformer_conf = TransformerConfig( + transformer_conf = BlockConfig( num_layers=2, num_attention_heads=2, hidden_size=16, num_experts=2, add_linear_biases=False ) distributed_config = DistributedConfig() diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 0639ec7ed..65530f561 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,8 +3,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.block import SSMBlock -from fast_llm.layers.transformer.block import TransformerBlock +from fast_llm.layers.block.block import Block from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup @@ -41,7 +40,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 + sum(p.numel() for p in 
layer.mlp.parameters()) if isinstance(layer, Block) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip(
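
Note on the `get_lr_scale` -> `combine_lr_scales` change in `fast_llm/utils.py`: the new helper is variadic, drops `None` entries, multiplies scalar scales together, and recurses element-wise when any argument is a tuple. A minimal usage sketch follows, assuming the `fast_llm` package from this branch is importable; the expected values are derived directly from the implementation shown above.

from fast_llm.utils import combine_lr_scales

# All entries None (or no arguments at all): no scaling is applied.
assert combine_lr_scales(None, None) is None

# Scalars multiply together after None entries are dropped.
assert combine_lr_scales(0.5, None, 2.0) == 1.0

# Mixing a scalar with a tuple-valued scale broadcasts the scalar over the
# tuple elements; note the tuple case currently comes back as a list.
assert combine_lr_scales(0.5, (1.0, None, 2.0)) == [0.5, 0.5, 1.0]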