
Concatenated dim #336

Open · wants to merge 16 commits into base: main
308 changes: 246 additions & 62 deletions fast_llm/engine/config_utils/tensor_space.py

Large diffs are not rendered by default.

32 changes: 7 additions & 25 deletions fast_llm/engine/multi_stage/fsdp.py
@@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight(
where it is located in the shard if it exists, or -1 if it's not in the shard.
Used to determine the location of each entry in a different distributed configuration.
"""

# Create an empty index for the global parameter.
index = torch.full(
parameter_meta.global_shape,
-1,
dtype=torch.int64,
device=device,
)
# Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard
begin, end = self._get_parameter_range_in_shard(parameter_name)

buffer_index = parameter_meta.global_to_local(index, expand=True)
# Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible.
# In that case, we work with a separate tensor to be copied back into `buffer_index`.
try:
buffer_index_flat = buffer_index.view(-1)
is_view = True
except RuntimeError:
buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1)
is_view = False

# Copy the shard indices at their respective positions in the flat buffer index.
buffer_index_flat[
# Create an empty local index to hold the local shard indices.
buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device)

# Copy the shard indices at their respective positions in the buffer index.
buffer_index.flatten()[
self._index_buffer_to_param(
self._fsdp_dim.rank * self._shard_size, parameter_name
) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name)
].copy_(torch.arange(begin, end, dtype=torch.int64, device=device))

# If needed, copy the flat buffer index back into the index.
if not is_view:
buffer_index.copy_(buffer_index_flat.view_as(buffer_index))

return index
# Create a global index from the local one.
return parameter_meta.local_to_global_partial(buffer_index, -1)

def copy_shard_overlaps(
self,
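The rewrite above builds the index directly in local (shard) coordinates and maps it back to the global layout with `local_to_global_partial`, using `-1` as the fill value for entries that are not in this rank's shard, which removes the old view/copy-back workaround. The helper itself is defined in `tensor_space.py`, whose diff is not rendered here, so the sketch below only illustrates the idea for a parameter split along a single dimension; the signature and internals are assumptions.

```python
# Hypothetical sketch of a `local_to_global_partial`-style mapping for a
# parameter split along one dimension across `world_size` ranks. The real
# helper lives on the parameter meta in `tensor_space.py`; names, signature
# and behavior here are assumptions based on how it is called above.
import torch


def local_to_global_partial(
    local_tensor: torch.Tensor,  # this rank's slice of the parameter
    fill_value: int,             # value used outside this rank's slice (-1 above)
    dim: int,                    # dimension the parameter is split along
    rank: int,
    world_size: int,
) -> torch.Tensor:
    # Allocate a global-shaped tensor filled with `fill_value` ...
    global_shape = list(local_tensor.shape)
    global_shape[dim] *= world_size
    global_tensor = local_tensor.new_full(global_shape, fill_value)
    # ... and paste the local slice at this rank's offset.
    offset = rank * local_tensor.shape[dim]
    global_tensor.narrow(dim, offset, local_tensor.shape[dim]).copy_(local_tensor)
    return global_tensor


# Example: rank 1 of 2, parameter split along dim 0.
local = torch.arange(6).reshape(3, 2)
print(local_to_global_partial(local, -1, dim=0, rank=1, world_size=2))
```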
5 changes: 3 additions & 2 deletions fast_llm/engine/multi_stage/stage_base.py
@@ -185,8 +185,9 @@ def initialize_weights(self) -> None:
# Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device)
global_shape = meta.global_shape

if self._distributed_config.reproducible_init and (
global_shape.numel() != parameter.numel() or not self._mode.on_device
if meta.requires_global_initialization or (
self._distributed_config.reproducible_init
and (global_shape.numel() != parameter.numel() or not self._mode.on_device)
):
# Initialize all global weights on every gpu, then select the appropriate slice if applicable.
global_param = parameter.new_empty(global_shape, device=self._distributed.device)
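The added `meta.requires_global_initialization` check forces the initialize-globally-then-slice path even when `reproducible_init` is off, presumably for metas (such as concatenated dims) whose initializer must see the full tensor. A minimal sketch of that pattern, with the slice selection reduced to a plain `narrow`; the real code goes through the meta's global/local mapping and the distributed device:

```python
# Minimal sketch of "initialize globally, then take the local slice", assuming
# the parameter is split along dim 0 across `world_size` ranks. The real code
# uses the parameter meta's global-to-local mapping rather than narrow().
import torch


def init_globally_then_slice(
    global_shape: tuple[int, ...],
    rank: int,
    world_size: int,
    generator: torch.Generator,
) -> torch.Tensor:
    # Every rank draws the *same* global tensor from the same generator state,
    # which is what makes the result independent of the parallel layout.
    global_param = torch.empty(global_shape)
    global_param.normal_(generator=generator)
    # Select this rank's slice of the first dimension.
    local_size = global_shape[0] // world_size
    return global_param.narrow(0, rank * local_size, local_size).clone()


generator = torch.Generator().manual_seed(0)
local = init_globally_then_slice((8, 4), rank=1, world_size=2, generator=generator)
print(local.shape)  # torch.Size([4, 4])
```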
6 changes: 2 additions & 4 deletions fast_llm/layers/common/config.py
@@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig):
)

def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
from fast_llm.tensor import init_uniform_
from fast_llm.tensor import init_uniform_centered_

kwargs = {
"hidden_dim": hidden_dim,
@@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "
}
if self.initialization_range:
mean = 0 if self.zero_centered else 1
kwargs["weight_init_method"] = init_uniform_(
mean - self.initialization_range, mean + self.initialization_range
)
kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean)
return self.module_class(**kwargs)

@property
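The old call spelled the bounds out as `mean - range, mean + range`; the new `init_uniform_centered_` takes the half-width plus an optional `mean`. Its implementation is in the unrendered `fast_llm/tensor.py` changes, so the helper below is only an assumed equivalent:

```python
# Sketch of a centered-uniform initializer factory matching the call above:
# init_uniform_centered_(r, mean=m) ~ uniform on [m - r, m + r].
# The real helper lives in `fast_llm.tensor`; this re-implementation is an assumption.
import torch


def init_uniform_centered_(half_range: float, mean: float = 0.0):
    def init_(tensor: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
        # Fill in place with U(mean - half_range, mean + half_range).
        return tensor.uniform_(mean - half_range, mean + half_range, generator=generator)

    return init_


weight = torch.empty(1024)
init_uniform_centered_(0.02, mean=1.0)(weight)  # the zero_centered=False case above
```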
8 changes: 4 additions & 4 deletions fast_llm/layers/common/linear.py
@@ -94,8 +94,8 @@ def __init__(
transposed_weight: bool = False,
lr_scale: float | None | tuple[float | None, ...] = None,
):
assert in_dim.parallel_dim is None
assert out_dim.parallel_dim is None
assert not in_dim.is_parallel
assert not out_dim.is_parallel
super().__init__(
in_dim,
out_dim,
@@ -132,7 +132,7 @@ def __init__(
sequence_parallel: bool = False,
lr_scale: float | None | tuple[float | None, ...] = None,
):
assert in_dim.parallel_dim is None
assert not in_dim.is_parallel
self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size
self._sequence_parallel = sequence_parallel and self._group_size > 1
super().__init__(
@@ -176,7 +176,7 @@ def __init__(
transposed_weight: bool = False,
lr_scale: float | None | tuple[float | None, ...] = None,
):
assert out_dim.parallel_dim is None
assert not out_dim.is_parallel
self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size
self._sequence_parallel = sequence_parallel and self._group_size > 1
super().__init__(
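These assertions switch from `parallel_dim is None` to a new `is_parallel` property, which (judging from the `peft.py` hunk below, where an explicit `size == 1` special case disappears) also treats a parallel dim of size 1 as not parallel. The property is defined in the unrendered `tensor_space.py` diff; the sketch below is an assumed shape for it:

```python
# Assumed form of the new `TensorDim.is_parallel` property: a dim counts as
# parallel only if it has a parallel_dim whose group size is greater than 1.
import dataclasses


@dataclasses.dataclass
class ParallelDim:
    size: int  # number of ranks the dimension is split across


@dataclasses.dataclass
class TensorDim:
    name: str
    global_size: int
    parallel_dim: ParallelDim | None = None

    @property
    def is_parallel(self) -> bool:
        return self.parallel_dim is not None and self.parallel_dim.size > 1

    @property
    def size(self) -> int:
        # Local size: the global size divided across the parallel group, if any.
        return self.global_size // (self.parallel_dim.size if self.parallel_dim else 1)


assert not TensorDim("hidden", 4096).is_parallel
assert not TensorDim("hidden", 4096, ParallelDim(1)).is_parallel
assert TensorDim("vocab_tp", 131072, ParallelDim(4)).is_parallel
```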
4 changes: 2 additions & 2 deletions fast_llm/layers/common/normalization.py
@@ -158,7 +158,7 @@ def __init__(
lr_scale: float | None = None,
):
super().__init__()
assert hidden_dim.parallel_dim is None
assert not hidden_dim.is_parallel
self._eps = eps
self._zero_centered = zero_centered
if implementation == NormalizationImplementation.auto:
@@ -242,7 +242,7 @@ def __init__(
lr_scale: float | None = None,
):
super().__init__()
assert hidden_dim.parallel_dim is None
assert not hidden_dim.is_parallel
self._eps = eps
self._zero_centered = zero_centered
if implementation == NormalizationImplementation.auto:
4 changes: 2 additions & 2 deletions fast_llm/layers/common/peft.py
@@ -19,12 +19,12 @@ def lora_linear(
):
layer.weight.requires_grad = False
in_dim = layer._in_dim
assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism."
if in_dim.parallel_dim is not None:
assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism."
in_dim = TensorDim(in_dim.name, in_dim.global_size)
out_dim = layer._out_dim
assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism."
if out_dim.parallel_dim is not None:
assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism."
out_dim = TensorDim(out_dim.name, out_dim.global_size)
if out_channel_begin is not None or out_channel_end is not None:
if out_channel_begin is None:
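For context on why both dims must end up non-parallel here: `lora_linear` freezes the base weight and learns a low-rank update `B @ A`, scaled by `alpha / rank`, on top of it, and the adapter factors are not tensor-parallel. A generic LoRA wrapper sketch, independent of Fast-LLM's actual classes (module and argument names below are illustrative):

```python
# Generic LoRA wrapper sketch (not Fast-LLM's implementation): freeze the base
# weight and add a trainable low-rank update scaled by alpha / rank.
import torch
from torch import nn


class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # mirrors `layer.weight.requires_grad = False`
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=5**0.5)
        nn.init.zeros_(self.lora_b.weight)  # the update starts at zero
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


layer = LoRALinear(nn.Linear(1024, 1024), rank=8)
```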
8 changes: 4 additions & 4 deletions fast_llm/layers/language_model/embedding.py
@@ -46,10 +46,10 @@ def __init__(
self._dropout_p = config.transformer.hidden_dropout
self._use_absolute_position_embeddings = config.use_absolute_position_embeddings

hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden)
vocab_dim = tensor_space.get_tensor_dim(
hidden_dim = tensor_space[TransformerDimNames.hidden]
vocab_dim = tensor_space[
LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
)
]

if self._parallel_embeddings:
self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size
@@ -66,7 +66,7 @@ def __init__(
)
if self._use_absolute_position_embeddings:
self.position_embeddings_weight = ParameterMeta.from_dims(
(tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim),
(tensor_space[LanguageModelDimNames.position_embed], hidden_dim),
init_method=init_normal_(
std=config.init_method_std_embed,
min_val=config.init_method_min_embed,
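Throughout the PR, `tensor_space.get_tensor_dim(name)` becomes subscript access, `tensor_space[name]`. That presumably just means `TensorSpace` now implements `__getitem__`; the actual change is in the unrendered `tensor_space.py` diff, so this is a minimal assumed version:

```python
# Assumed minimal form of the TensorSpace indexing used above:
# `tensor_space[name]` as sugar for the old `get_tensor_dim(name)` lookup.
class TensorSpace:
    def __init__(self) -> None:
        self._tensor_dims: dict = {}

    def add_tensor_dim(self, dim) -> None:
        self._tensor_dims[dim.name] = dim

    def __getitem__(self, name: str):
        # Same behavior as the old accessor: raises KeyError on unknown names.
        return self._tensor_dims[name]

    # The old spelling could be kept as an alias during the transition.
    get_tensor_dim = __getitem__
```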
10 changes: 5 additions & 5 deletions fast_llm/layers/language_model/head.py
@@ -61,7 +61,7 @@ def __init__(
if self._cross_entropy_splits is not None and self._sequence_parallel:
assert not self._parallel_embeddings

hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
hidden_dim = self._tensor_space[TransformerDimNames.hidden]

self._loss_coefficient = (
config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0
@@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
if self._tie_word_embeddings or self._prediction_distance > 0:
return
# untie embedding weights
vocab_dim = self._tensor_space.get_tensor_dim(
vocab_dim = self._tensor_space[
LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
)
]
self.output_weights = ParameterMeta.from_dims(
(vocab_dim, hidden_dim),
init_method=init_normal_(
@@ -338,9 +338,9 @@ def _logits_cross_entropy_forward_backward(
logits_scale_factor=self._logits_scale_factor,
)
if self._debug_transformer and self._cross_entropy_splits is None:
vocab_dim = self._tensor_space.get_tensor_dim(
vocab_dim = self._tensor_space[
LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp
)
]
dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim]
sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first])
dims[sequence_index] = (
4 changes: 2 additions & 2 deletions fast_llm/layers/language_model/preprocessing.py
@@ -28,7 +28,7 @@ def __init__(
assert config.use_absolute_position_embeddings
self._tensor_space = tensor_space
self._distributed_config = self._tensor_space.distributed_config
self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar)
self._scalar_dim = self._tensor_space[DefaultDimNames.scalar]

def _create_tensors(self, sequence_length: int) -> None:
if sequence_length <= self._tensor_cache_max_sequence_length:
@@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace):
self._config = config
self._tensor_space = tensor_space
self._distributed_config = self._tensor_space.distributed_config
self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar)
self._scalar_dim = self._tensor_space[DefaultDimNames.scalar]

def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
return
24 changes: 12 additions & 12 deletions fast_llm/layers/ssm/discrete_mamba2.py
@@ -10,7 +10,7 @@
from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames
from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs
from fast_llm.layers.transformer.transformer import Mixer
from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_
from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_, init_zeros_
from fast_llm.utils import get_lr_scale

logger = logging.getLogger(__name__)
@@ -62,14 +62,14 @@ def __init__(
mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale)
logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}")

td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim)
td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim)
td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim)
td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim)
td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads)
td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads)
td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size)
td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2)
td_inner = tensor_space[SSMDimNames.inner_dim]
td_state = tensor_space[SSMDimNames.state_dim]
td_model = tensor_space[SSMDimNames.model_dim]
td_conv = tensor_space[SSMDimNames.conv_dim]
td_n_qk_heads = tensor_space[SSMDimNames.qk_heads]
td_n_v_heads = tensor_space[SSMDimNames.v_heads]
td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size]
td_inner_proj = tensor_space[SSMDimNames.inner_proj_discrete_mamba2]

self.d_model = td_model.size
self.d_inner = td_inner.size
@@ -88,7 +88,7 @@ def __init__(
td_model,
td_inner_proj,
bias=bias,
weight_init_method=kaiming_init_(td_model.size),
weight_init_method=init_kaiming_(td_model.size),
lr_scale=mamba_layer_lr_scale,
)
self.z_bias = (
@@ -103,7 +103,7 @@
)

self.conv1d_weight = ParameterMeta.from_dims(
(td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel),
(td_conv, tensor_space[DefaultDimNames.scalar], td_conv_kernel),
init_method=init_uniform_(
1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size)
), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67
@@ -126,7 +126,7 @@ def __init__(
td_inner,
td_model,
bias=bias,
weight_init_method=kaiming_init_(td_inner.size),
weight_init_method=init_kaiming_(td_inner.size),
lr_scale=mamba_layer_lr_scale,
)

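`kaiming_init_` is renamed to `init_kaiming_`, matching the `init_*_` naming of the other initializers imported above. Judging by the call sites (`init_kaiming_(td_model.size)`), it takes the fan-in of the layer; the exact distribution and gain are not shown in this diff, so the factory below is an assumption:

```python
# Sketch of an `init_kaiming_(fan_in)`-style factory consistent with the call
# sites above. The distribution and gain used by `fast_llm.tensor.init_kaiming_`
# are assumptions here (Kaiming-uniform with std = 1 / sqrt(fan_in)).
import math

import torch


def init_kaiming_(fan_in: int):
    bound = math.sqrt(3.0 / fan_in)  # uniform bound giving std = 1 / sqrt(fan_in)

    def init_(tensor: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
        return tensor.uniform_(-bound, bound, generator=generator)

    return init_


weight = torch.empty(512, 256)
init_kaiming_(256)(weight)  # fan_in matches the layer's input dim, as above
```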
26 changes: 13 additions & 13 deletions fast_llm/layers/ssm/mamba2.py
@@ -10,7 +10,7 @@
from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias
from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames
from fast_llm.layers.transformer.transformer import Mixer
from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_
from fast_llm.tensor import ParameterMeta, init_fill_, init_kaiming_, init_ones_, init_uniform_
from fast_llm.utils import get_lr_scale

try:
@@ -80,13 +80,13 @@ def __init__(
self.config.mamba_lr_scale, layer_lr_scale
)

td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim)
td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim)
td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim)
tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank)
td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2)
td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2)
td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size)
td_inner: TensorDim = tensor_space[SSMDimNames.inner_dim]
td_state: TensorDim = tensor_space[SSMDimNames.state_dim]
td_model: TensorDim = tensor_space[SSMDimNames.model_dim]
tdt_rank: TensorDim = tensor_space[SSMDimNames.dt_rank]
td_xb: TensorDim = tensor_space[SSMDimNames.x_proj_dim_2]
td_inner_proj: TensorDim = tensor_space[SSMDimNames.inner_proj_mamba2]
td_conv_kernel: TensorDim = tensor_space[SSMDimNames.conv_kernel_size]

self.repeat_kv_before_conv = config.repeat_kv_before_conv

@@ -98,7 +98,7 @@

if self.repeat_kv_before_conv:
self.conv1d_weight = ParameterMeta.from_dims(
(td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel),
(td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel),
init_method=init_uniform_(
-1 / math.sqrt(td_inner.size * td_conv_kernel.size),
1 / math.sqrt(td_inner.size * td_conv_kernel.size),
@@ -111,7 +111,7 @@
)
else:
self.conv1d_weight = ParameterMeta.from_dims(
(td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel),
(td_xb, tensor_space[DefaultDimNames.scalar], td_conv_kernel),
init_method=init_uniform_(
-1 / math.sqrt(td_xb.size * td_conv_kernel.size),
1 / math.sqrt(td_xb.size * td_conv_kernel.size),
@@ -131,14 +131,14 @@
td_model,
td_inner_proj,
bias=bias,
weight_init_method=kaiming_init_(td_model.size),
weight_init_method=init_kaiming_(td_model.size),
lr_scale=mamba_layer_lr_scale,
)
self.dt_in_proj = Linear(
td_model,
tdt_rank,
bias=config.add_bias_linear,
weight_init_method=kaiming_init_(transformer_config.hidden_size),
weight_init_method=init_kaiming_(transformer_config.hidden_size),
lr_scale=mamba_layer_lr_scale,
)
# Initialize special dt projection to preserve variance at initialization
@@ -185,7 +185,7 @@ def __init__(
td_inner,
td_model,
bias=bias,
weight_init_method=kaiming_init_(td_inner.size),
weight_init_method=init_kaiming_(td_inner.size),
)

def forward(self, hidden_states, kwargs):
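The `dt` projection initialized above follows the usual Mamba recipe: the weight is scaled to preserve variance, and the bias (from `init_dtprojbias`, imported at the top of this file) is typically the inverse softplus of a time step sampled log-uniformly in `[dt_min, dt_max]`. A sketch of that bias initialization as done in reference Mamba implementations; whether Fast-LLM's `init_dtprojbias` matches it exactly is an assumption:

```python
# Sketch of the standard Mamba dt-projection bias init: draw dt log-uniformly
# in [dt_min, dt_max], clamp, then store softplus^-1(dt) so that softplus(bias)
# recovers dt in the forward pass. Assumed to mirror `init_dtprojbias`.
import math

import torch


def init_dt_proj_bias_(
    bias: torch.Tensor,
    dt_min: float = 1e-3,
    dt_max: float = 1e-1,
    dt_init_floor: float = 1e-4,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    # Log-uniform sample of the time step.
    u = torch.rand(bias.shape, generator=generator)
    dt = torch.exp(u * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
    dt = dt.clamp(min=dt_init_floor)
    # Inverse softplus, log(exp(dt) - 1), written stably via expm1.
    inv_softplus_dt = dt + torch.log(-torch.expm1(-dt))
    with torch.no_grad():
        bias.copy_(inv_softplus_dt)
    return bias


bias = torch.empty(1536)  # e.g. an inner-dim-sized dt bias; size is illustrative
init_dt_proj_bias_(bias)
```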