diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py
index 99c1bcf70..66176ee0f 100644
--- a/fast_llm/engine/config_utils/tensor_space.py
+++ b/fast_llm/engine/config_utils/tensor_space.py
@@ -1,3 +1,4 @@
+import logging
 import math
 import typing
 
@@ -5,11 +6,25 @@ from fast_llm.utils import Assert, div
 
 if typing.TYPE_CHECKING:
+    import torch
+
     from fast_llm.core.distributed import ProcessGroup
     from fast_llm.engine.distributed.distributed import Distributed
 
+logger = logging.getLogger(__name__)
+
 
 class TensorDim:
+    """
+    Describes a simple, atomic dimension of a tensor and its size.
+    The dimension may be parallelized along a distributed dimension `parallel_dim`,
+    in which case its actual (local) `size` will differ from its `global_size`.
+
+    `TensorDim`s are used to represent the metadata of tensors through `TensorMeta`.
+
+    This class also serves as a base for more complex tensor dimensions.
+    """
+
     def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None):
         # TODO: Handle None for unknown sizes?
         self._name = name
@@ -19,11 +34,11 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed
 
     def __repr__(self) -> str:
         return (
-            f"TensorDim("
+            f"{type(self).__name__}("
             f"name={self._name},"
             f" size={self._size},"
             f" global_size={self._global_size},"
-            f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}"
+            f" parallel_dim={self._parallel_dim}"
             f")"
         )
 
@@ -38,83 +53,254 @@ def name(self) -> str:
     def size(self) -> int:
         return self._size
 
-    @property
-    def expanded_shape(self) -> tuple[int, ...]:
-        return (self._size,)
-
-    @property
-    def ndim(self) -> int:
-        return 1
-
     @property
     def global_size(self) -> int:
         return self._global_size
 
     @property
-    def global_expanded_shape(self) -> tuple[int, ...]:
-        return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,)
+    def is_parallel(self) -> bool:
+        return self._parallel_dim is not None and self._parallel_dim.size > 1
 
     @property
     def parallel_dim(self) -> DistributedDim | None:
+        # TODO: Make more flexible for derived classes?
         return self._parallel_dim
 
-    @property
-    def parallel_dim_index(self) -> int | None:
-        return None if self._parallel_dim is None else 0
-
     @property
     def parallel_group(self) -> "ProcessGroup|None":
+        # TODO: Make more flexible for derived classes?
         return None if self._parallel_dim is None else self._parallel_dim.group
 
     def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self:
-        assert self.parallel_dim is not None
+        """
+        Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`,
+        but the local size remains the same.
+
+        Used in `TensorMeta.replace_tensor_parallel_dim`.
+        """
+        assert self.is_parallel
         return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim)
 
+    def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor":
+        """
+        Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`.
+        If the dimension is parallelized, this amounts to gathering along dimension `dim`
+        and parallel dimension `parallel_dim`; otherwise the input tensor is returned.
+        The method needs to be called by all members of the parallel group, each using its appropriate local slice.
+
+        Used in `TensorMeta.local_to_global`,
+        which iterates over the tensor dimensions to fully reconstruct the global tensor.
+ """ + if self.is_parallel: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`. + Unlike `local_to_global`, this method does not need to be called from a distributed setting. + Instead, entries from other ranks are populated with `fill_value`. + + Used in`TensorMeta.local_to_global_partial`, + which iterates over the tensor dimensions to fully reconstruct the global tensor. + """ + if self.is_parallel: + output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. + If the dimension is parallel, this amounts to taking the `rank`th chunk of size `size` along dimension `dim` + and parallel dimension `self.parallel_dim`, otherwise return the input tensor. + + Used in`TensorMeta.local_to_global`, + which iterates over the tensor dimensions to fully reconstruct the local tensor. + """ + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? - parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + """ + A composite tensor dimension that represent multiple dimensions flattened into ones. + Typically happens for flattened view or higher-dimensional tensors, or tensors that can be expanded as such. + If one of the composed dimensions -- other than the first one -- is parallelized, + this is **not** equivalent to an atomic `TensorDim` of the same size, + as the relation between local and global tensors is different. + + At most one of the sub-dimensions may be parallelized. TODO: Allow for more than one? + """ + + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.parallel_dim is not None: + assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim + + super().__init__( + name=name, + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims + + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. 
+ """ + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. + """ + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`, + populating other ranks with `fill_value`. + """ + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. + """ + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + +class ConcatenatedTensorDim(TensorDim): + """ + A complex tensor dimension that results from concatenating tensors. + + All sub-dimensions should have the same `parallel_dim` (may be None). TODO: Allow for more complex scenarios? + """ + + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? + Assert.is_(tensor_dim.parallel_dim, parallel_dim) super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + self._tensor_dims = tensor_dims + + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. + """ + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) ) - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. 
+ """ + import torch - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`, + populating other ranks with `fill_value`. + """ + import torch - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. + """ + if self.is_parallel and expand: + raise NotImplementedError() + import torch - def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +333,19 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + self._tensor_dims[tensor_dim.name] = tensor_dim - def get_tensor_dim(self, name: str) -> TensorDim: + def __getitem__(self, name: str) -> 
TensorDim: return self._tensor_dims[name] diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5b44bf14b..be15cd37a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight( where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index, expand=True) - # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. - # In that case, we work with a separate tensor to be copied back into `buffer_index`. - try: - buffer_index_flat = buffer_index.view(-1) - is_view = True - except RuntimeError: - buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) - is_view = False - - # Copy the shard indices at their respective positions in the flat buffer index. - buffer_index_flat[ + # Create an empty local index to hold the local shard indices. + buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device) + + # Copy the shard indices at their respective positions in the buffer index. + buffer_index.flatten()[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # If needed, copy the flat buffer index back into the index. - if not is_view: - buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) - - return index + # Create a global index from the local one. + return parameter_meta.local_to_global_partial(buffer_index, -1) def copy_shard_overlaps( self, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 9a8ce2092..3218a1963 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,8 +185,9 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate slice if applicable. 
global_param = parameter.new_empty(global_shape, device=self._distributed.device) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac689..07dadbc22 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a5..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beaef..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e51..08f3e535b 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." 
in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 7036a1e97..f6f43d199 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -46,10 +46,10 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - vocab_dim = tensor_space.get_tensor_dim( + hidden_dim = tensor_space[TransformerDimNames.hidden] + vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -66,7 +66,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), + (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 21bf3bbd0..210cad644 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -61,7 +61,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=init_normal_( @@ -338,9 +338,9 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) + ] dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) dims[sequence_index] = ( diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index d719bef3d..c8d53a789 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -28,7 +28,7 @@ def __init__( assert 
config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c0ae7e781..6012f74a7 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -62,14 +62,14 @@ def __init__( mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) + td_inner = tensor_space[SSMDimNames.inner_dim] + td_state = tensor_space[SSMDimNames.state_dim] + td_model = tensor_space[SSMDimNames.model_dim] + td_conv = tensor_space[SSMDimNames.conv_dim] + td_n_qk_heads = tensor_space[SSMDimNames.qk_heads] + td_n_v_heads = tensor_space[SSMDimNames.v_heads] + td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] + td_inner_proj = tensor_space[SSMDimNames.inner_proj_discrete_mamba2] self.d_model = td_model.size self.d_inner = td_inner.size @@ -88,7 +88,7 @@ def __init__( td_model, td_inner_proj, bias=bias, - weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.z_bias = ( @@ -103,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_conv, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 @@ -126,7 +126,7 @@ def __init__( 
td_inner, td_model, bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 74c212add..9dfad8462 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_fill_, init_kaiming_, init_ones_, init_uniform_ from fast_llm.utils import get_lr_scale try: @@ -80,13 +80,13 @@ def __init__( self.config.mamba_lr_scale, layer_lr_scale ) - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) + td_inner: TensorDim = tensor_space[SSMDimNames.inner_dim] + td_state: TensorDim = tensor_space[SSMDimNames.state_dim] + td_model: TensorDim = tensor_space[SSMDimNames.model_dim] + tdt_rank: TensorDim = tensor_space[SSMDimNames.dt_rank] + td_xb: TensorDim = tensor_space[SSMDimNames.x_proj_dim_2] + td_inner_proj: TensorDim = tensor_space[SSMDimNames.inner_proj_mamba2] + td_conv_kernel: TensorDim = tensor_space[SSMDimNames.conv_kernel_size] self.repeat_kv_before_conv = config.repeat_kv_before_conv @@ -98,7 +98,7 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), @@ -111,7 +111,7 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_xb, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), @@ -131,14 +131,14 @@ def __init__( td_model, td_inner_proj, bias=bias, - weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.dt_in_proj = Linear( td_model, tdt_rank, bias=config.add_bias_linear, - weight_init_method=kaiming_init_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), lr_scale=mamba_layer_lr_scale, ) # Initialize special dt projection to preserve variance at initialization @@ -185,7 +185,7 @@ def __init__( td_inner, td_model, bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), ) def forward(self, hidden_states, kwargs): diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py 
index 4493332ce..5e0ae786e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import get_lr_scale try: @@ -75,15 +75,13 @@ def __init__( self.config: SSMConfig = config # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + td_inner = tensor_space[SSMDimNames.inner_dim] + td_inner_proj = tensor_space[SSMDimNames.inner_proj_mamba] # TensorDim("D_inner_2", self.d_inner * 2) + tdt_rank = tensor_space[SSMDimNames.dt_rank] + td_x_proj = tensor_space[SSMDimNames.x_proj_dim] + td_state = tensor_space[SSMDimNames.state_dim] + td_model = tensor_space[SSMDimNames.model_dim] + td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] self.d_conv = td_conv_kernel.size self.d_inner = td_inner.size self.d_state = td_state.size @@ -94,12 +92,12 @@ def __init__( self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + init_method=init_kaiming_(td_model.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), + (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), + init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) @@ -111,7 +109,7 @@ def __init__( self.x_proj = Linear( td_inner, td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, ) @@ -120,7 +118,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), + init_method=init_kaiming_(tdt_rank.size), lr_scale=mamba_layer_lr_scale, ) @@ -151,7 +149,7 @@ def __init__( td_inner, td_model, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. 
- weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 174e19588..c59b191af 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -91,14 +91,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size + self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -106,7 +106,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space[TransformerDimNames.composite_query], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -115,7 +115,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space[TransformerDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -129,7 +129,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. 
self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space[TransformerDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 73f83ccf5..4fd2844d5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -63,8 +63,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space.get_tensor_dim(TransformerDimNames.hidden), - tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), + tensor_space[TransformerDimNames.hidden], + tensor_space[TransformerDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -255,7 +255,7 @@ def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space.get_tensor_dim(dim_name),), + kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index efe0c5cc5..101d97ef3 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -30,8 +30,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space[TransformerDimNames.hidden] + self._intermediate_dim = tensor_space[TransformerDimNames.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -46,7 +46,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space[TransformerDimNames.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index dc3ddeb52..3f0e14eb7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -28,7 +28,7 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py 
b/fast_llm/layers/transformer/rotary/preprocessing.py index cc83dae02..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -25,8 +25,8 @@ def __init__( self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..17b18a1ca 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -82,8 +82,8 @@ def __init__( super().__init__(config, tensor_space) self._tensor_space = tensor_space if self._tensor_space is not None: - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index d08db9a94..75d06f268 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -48,7 +48,7 @@ def _get_meta( } return TensorMeta.from_dims( tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] for dim_name in dim_names ), tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", @@ -97,7 +97,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index e7379e61e..20ed8e828 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -14,8 +14,8 @@ def get_init_megatron( meta: "ParameterMeta", config: TransformerConfig -) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: - def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): +) -> typing.Callable[["torch.Tensor", "Distributed"], None]: + def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. 
@@ -29,11 +29,11 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: - tensor_ = _init_transposed_mlp_weight_megatron(config, meta, tensor, distributed) + tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. return meta.param_init_method(meta, tensor, distributed.tp_init_generator) - return tensor.copy_(tensor_.reshape_as(tensor)) + tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -58,9 +58,9 @@ def _init_attention_megatron( generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - dense_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + dense_tensor_ := tensor.new_empty( config.kv_channels * config.num_attention_heads, config.hidden_size, ), @@ -68,9 +68,9 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.num_attention_heads, config.head_groups) - qkv_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + qkv_tensor_ := tensor.new_empty( config.head_groups, heads_per_group + 2, config.kv_channels, @@ -110,18 +110,19 @@ def _init_position_embeddings_megatron( # Megatron initializes the position embeddings on cpu twice. assert meta.param_init_method is not None generator = distributed.default_cpu_generator - tensor_ = meta.param_init_method(meta, torch.empty(tensor.shape, dtype=tensor.dtype), generator) - return meta.param_init_method(meta, tensor_, generator) + meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + meta.param_init_method(meta, tensor_, generator) + return tensor_ def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron never transposes the mlp layer 2 weight. assert meta.param_init_method is not None - tensor_ = meta.param_init_method(meta, torch.empty_like(tensor), distributed.tp_init_generator) + meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -132,8 +133,8 @@ def _init_moe_router_megatron( # Megatron initializes the router on cpu. 
assert meta.param_init_method is not None - tensor_ = meta.param_init_method( - meta, torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator + meta.param_init_method( + meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 4c1eab46f..49a5dcbd3 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -155,7 +155,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6d..c17df9d0c 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,17 +1,21 @@ +import abc import functools +import logging import math import typing import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class _SafeTensorSliceMeta(type): def __instancecheck__(self, instance) -> bool: @@ -146,7 +150,7 @@ def from_tensor_space( reductions: tuple[tuple[str, ReduceOp], ...] = (), **kwargs: typing.Any, ) -> typing.Self: - dims = tuple(tensor_space.get_tensor_dim(dim_name) for dim_name in dim_names) + dims = tuple(tensor_space[dim_name] for dim_name in dim_names) if reductions: # kwarg not available for ParameterMeta, so we only provide if necessary. kwargs["reductions"] = tuple( @@ -158,22 +162,23 @@ def from_tensor_space( def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global( - self, - tensor: torch.Tensor, - *, - distributed: Distributed, - ) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + """ + Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. 
# TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -182,28 +187,44 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank - def global_to_local( - self, - tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. - expand: bool = False, - ) -> torch.Tensor: + def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int = -1) -> torch.Tensor: """ - Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. + Construct a tensor of shape `self.global_shape` that contains its local slice at the appropriate location, + i.e. for which `self.global_to_local(self.local_to_global_partial(tensor)) == tensor`. + Other entries are filled with `fill_value`. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) + assert not self._reductions + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) + + Assert.eq(tensor.shape, self.global_shape) + return tensor + + def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tensor: + """ + Select the local slice of a global tensor. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. """ # Take a trivial slice to convert safetensor slices. - tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim) - return tensor_ if expand else tensor_.reshape(self.shape) + Assert.eq(tensor.shape, self.shape) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -219,8 +240,12 @@ def validate(self, tensor: torch.Tensor, device: torch.device | None = None) -> return validate_tensor(tensor, self, device) def replace_tensor_parallel_dim(self, distributed_dim: DistributedDim) -> "TensorMeta": - # Replace the tensor-parallel `DistributedDim` in `meta`. 
-        # Note: This will turn `ParameterMeta` into `TensorMeta`
+        """
+        Replace the tensor-parallel `DistributedDim` in `meta`, preserving the local size.
+        Required for advanced tensor manipulations,
+        e.g. turning tensor-parallel slices of a tensor into slices of a different tensor-parallel size.
+        Note: This will turn `ParameterMeta` into `TensorMeta`
+        """
         if not self.is_tensor_parallel:
             return self
         dims = list(self.dims)
@@ -237,7 +262,7 @@ def __init__(
         *,
         tensor_name: str = "",
         dims: tuple[TensorDim, ...],
-        init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None,
+        init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None,
         weight_decay: bool = True,
         # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization.
         lr_scale: float | None | tuple[float | None, ...] = None,
@@ -247,7 +272,11 @@ def __init__(
         allow_no_grad: bool = False,
     ):
         super().__init__(data, tensor_name=tensor_name, dims=dims)
-        self.param_init_method = init_method
+        if init_method is not None and not isinstance(init_method, Initializer):
+            # Support non-wrapped callables for convenience.
+            assert callable(init_method)
+            init_method = LambdaInitializer(init_method)
+        self.param_init_method: Initializer | None = init_method
         self.param_weight_decay = weight_decay
         self._is_param = True
         self.param_grad_is_zero = False
@@ -272,7 +301,7 @@ def __new__(
         *,
         tensor_name: str = "",
         dims: tuple[TensorDim, ...],
-        init_method: typing.Callable,
+        init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None",
         weight_decay: bool = True,
         lr_scale: float | None | tuple[float | None, ...] = None,
         allow_sequence_tensor_parallel: bool = True,
@@ -293,12 +322,20 @@ def __repr__(self, *, tensor_contents=()) -> str:
 
     def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None:
         assert self.param_init_method is not None
-        if distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init:
+        if (
+            distributed.config.tensor_parallel == 1
+            or distributed.config.reproducible_init
+            or self.param_init_method.requires_global_initialization
+        ):
             generator = distributed.pp_init_generator
         else:
             generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator
         self.param_init_method(self, tensor, generator)
 
+    @property
+    def requires_global_initialization(self) -> bool:
+        return self.param_init_method.requires_global_initialization
+
     def save(self) -> dict[str, typing.Any]:
         return {
             "name": self.tensor_name,
@@ -330,11 +367,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None:
         triton_add(grad, param.grad_buffer, out=param.grad_buffer)  # noqa
 
 
-def init_fill_(value) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]:
-    def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator):  # noqa
-        return tensor.fill_(value)
+class Initializer(abc.ABC):
+    @abc.abstractmethod
+    def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None:
+        pass
+
+    requires_global_initialization = False
+
+
+class LambdaInitializer(Initializer):
+    def __init__(
+        self,
+        init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None],
+        requires_global_initialization: bool = False,
+    ) -> None:
+        self._init_method = init_method
+        self.requires_global_initialization = requires_global_initialization
+
+
def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + - return init_ +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -342,30 +400,35 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def kaiming_init_(d_in): +def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + - return init_ +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + )
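
For reference, a minimal standalone sketch (plain PyTorch, not part of the diff) of the round-trip invariant documented for `TensorMeta.local_to_global_partial` and `global_to_local`, namely `global_to_local(local_to_global_partial(x)) == x`, for a single dimension split across ranks. The helper names and the explicit `rank`/`world_size` arguments are illustrative assumptions, not Fast-LLM APIs:

import torch

def local_to_global_partial(local: torch.Tensor, rank: int, world_size: int, dim: int = 0, fill_value: int = -1) -> torch.Tensor:
    # Place the local slice at its rank's position along `dim`; entries of other ranks hold `fill_value`.
    output = local.new_full((*local.shape[:dim], world_size, *local.shape[dim:]), fill_value)
    output.narrow(dim, rank, 1).copy_(local.unsqueeze(dim))
    return output.flatten(dim, dim + 1)

def global_to_local(global_tensor: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    # Take the rank-th chunk of the (possibly partially filled) global tensor along `dim`.
    return global_tensor.chunk(world_size, dim)[rank]

local = torch.arange(6).reshape(3, 2)  # this rank's local slice
partial = local_to_global_partial(local, rank=1, world_size=2)  # shape (6, 2); rows 0-2 hold the fill value
assert torch.equal(global_to_local(partial, rank=1, world_size=2), local)  # round trip recovers the local slice

This is the same pattern `_get_parameter_shard_indices_in_full_weight` in fsdp.py now relies on: it fills a local buffer index and calls `parameter_meta.local_to_global_partial(buffer_index, -1)` instead of building the full global index tensor by hand.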