[Prototype] Test #351
Draft
wants to merge 1 commit into base: block_interface_config
4 changes: 2 additions & 2 deletions fast_llm/engine/config_utils/initialization.py
@@ -13,7 +13,7 @@
@config_class(registry=True)
class InitializationConfig(Config):
_abstract = True
has_initialization: typing.ClassVar[bool] = True
is_default: typing.ClassVar[bool] = False

@classmethod
def _from_dict(
@@ -35,7 +35,7 @@ def get_initializer(self) -> "Initializer":
class DefaultInitializationConfig(InitializationConfig):
# A placeholder indicating that the class default should be used instead.
_abstract = False
has_initialization = False
is_default = True


@config_class(dynamic_type={InitializationConfig: "fill"})
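Reviewer note: the rename also flips the flag's polarity (the placeholder class goes from `has_initialization = False` to `is_default = True`), so call sites now ask "is this the placeholder?" instead of "does this carry an initialization?". A minimal sketch of the intended fallback, using hypothetical stand-in classes rather than the actual Fast-LLM registry:

```python
# Hypothetical stand-ins (not the Fast-LLM classes) showing the is_default fallback.
import typing


class InitializationConfig:
    is_default: typing.ClassVar[bool] = False

    def get_initializer(self):
        raise NotImplementedError()


class DefaultInitializationConfig(InitializationConfig):
    # Placeholder: the owning parameter should fall back to its class default.
    is_default = True


class FillInitializationConfig(InitializationConfig):
    def __init__(self, value: float = 0.0):
        self.value = value

    def get_initializer(self):
        return lambda tensor: tensor.fill_(self.value)


def resolve_initializer(config: InitializationConfig, class_default):
    # The check enabled by the rename: "is this the placeholder?"
    # instead of the inverted "does this carry an initialization?".
    return class_default if config.is_default else config.get_initializer()
```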
2 changes: 1 addition & 1 deletion fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
@@ -16,7 +16,7 @@
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results
from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
from fast_llm.layers.transformer.rotary.config import NoRotaryConfig
from fast_llm.layers.attention.rotary.config import NoRotaryConfig

logger = logging.getLogger(__name__)

@@ -4,13 +4,13 @@

from fast_llm.core.distributed import set_generator
from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.autograd import wrap_forward_backward
from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs
from fast_llm.layers.block.block import BlockLayer
from fast_llm.layers.block.peft import TransformerSubLayerName
from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
from fast_llm.layers.transformer.config import AttentionConfig, AttentionDimNames, AttentionKwargs
from fast_llm.utils import get_lr_scale
from fast_llm.layers.block.config import BlockDimNames
from fast_llm.utils import div

try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa
@@ -50,80 +50,95 @@ class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]):
A self-attention layer.
"""

_QUERY_DIMS = (
AttentionDimNames.batch,
AttentionDimNames.sequence_q,
AttentionDimNames.composite_heads,
AttentionDimNames.kv_channels,
)
_KV_DIMS = (
AttentionDimNames.batch,
AttentionDimNames.sequence_q,
AttentionDimNames.head_groups,
AttentionDimNames.kv_channels,
)
_CONTEXT_DIMS = (
AttentionDimNames.batch,
AttentionDimNames.sequence_q,
AttentionDimNames.composite_dense,
)

def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str):
super().__init__(config, tensor_space, block_index, name)
self._config = config
self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config)

self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size
self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size
self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size
self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size
def __init__(
self,
config: ConfigType,
distributed_config: DistributedConfig,
hidden_dim: TensorDim,
block_index: int,
name: str,
):
super().__init__(config, distributed_config, hidden_dim, block_index, name)
self._use_flash_attention = self._config.do_use_flash_attention(distributed_config)

head_group_dim = TensorDim(
"head_groups", self._config.head_groups, self._parallel_dim if self.head_groups > 1 else None
)
group_heads_dim = TensorDim(
"group_heads",
div(self._config.num_attention_heads, self._config.head_groups),
None if self._config.head_groups > 1 else self._parallel_dim,
)
self._local_head_groups = head_group_dim.size
self._local_heads_per_group = group_heads_dim.size
self._local_heads = self._local_head_groups * self._local_heads_per_group
self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power)

hidden_dim = self._tensor_space[AttentionDimNames.hidden]
kv_channels_dim = TensorDim("kv_channels", self._config.kv_channels)
query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, kv_channels_dim))
key_value_dim = ConcatenatedTensorDim(
"key_value",
(
CompositeTensorDim("key", (group_heads_dim, kv_channels_dim)),
CompositeTensorDim("value", (group_heads_dim, kv_channels_dim)),
),
)
dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, kv_channels_dim))

layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None
attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale)
self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power)

lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None

# TODO: Merge the query and key-value computations? (harder with sequence parallel.)
self.query = OutputParallelLinear(
self.query = self._config.query_layer.get_layer(
hidden_dim,
self._tensor_space[AttentionDimNames.composite_query],
query_dim,
bias=self._config.add_qkv_bias,
weight_init_method=self._config.qkv_weight_initialization_method,
bias_init_method=self._config.qkv_bias_initialization_method,
sequence_parallel=self._sequence_parallel,
lr_scale=attention_lr_scale,
lr_scale=lr_scale,
peft=self._config.block.peft,
)
self.key_value = OutputParallelLinear(
# TODO: Separate.
self.key_value = self._config.key_layer.get_layer(
hidden_dim,
self._tensor_space[AttentionDimNames.composite_key_value],
bias=self._config.add_qkv_bias,
weight_init_method=self._config.qkv_weight_initialization_method,
bias_init_method=self._config.qkv_bias_initialization_method,
key_value_dim,
sequence_parallel=self._sequence_parallel,
lr_scale=attention_lr_scale,
lr_scale=lr_scale,
peft=self._config.block.peft,
)
self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward)

# Rotary embeddings.
self._rotary = self._config.rotary.build()
self._rotary = self._config.rotary.get_layer()

# Output.
self.dense = InputParallelLinear(
self._tensor_space[AttentionDimNames.composite_dense],
self.dense = self._config.dense_layer.get_layer(
dense_dim,
hidden_dim,
bias=self._config.add_dense_bias,
weight_init_method=self._config.dense_weight_initialization_method,
bias_init_method=self._config.dense_bias_initialization_method,
sequence_parallel=self._sequence_parallel,
lr_scale=attention_lr_scale,
peft=self._config.block.peft,
lr_scale=lr_scale,
)

# PEFT.
self.query = self._config.peft.apply_linear(self.query, TransformerSubLayerName.query)
self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value)
self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense)
if self._debug.enabled:
self._query_dims = (
BlockDimNames.batch,
BlockDimNames.sequence_q,
CompositeTensorDim("heads", (head_group_dim, group_heads_dim)),
kv_channels_dim,
)
self._kv_dims = (
BlockDimNames.batch,
BlockDimNames.sequence_q,
head_group_dim,
kv_channels_dim,
)
self._context_dims = (
BlockDimNames.batch,
BlockDimNames.sequence_q,
dense_dim,
)
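Reviewer note: a quick sanity check of the dimension bookkeeping above, in plain Python with made-up sizes instead of the real TensorDim / CompositeTensorDim / ConcatenatedTensorDim objects. Each of the `head_groups` key/value heads serves `group_heads` query heads, so the query and dense projections are `num_attention_heads * kv_channels` wide while the concatenated key/value projection is `2 * head_groups * kv_channels` wide:

```python
# Plain-integer sanity check of the projection widths implied by the dims above
# (sizes are illustrative, not taken from a real config).
num_attention_heads = 32
head_groups = 8                                    # number of key/value heads (GQA)
kv_channels = 128
group_heads = num_attention_heads // head_groups   # query heads per group -> 4

query_width = head_groups * group_heads * kv_channels  # == num_attention_heads * kv_channels
key_value_width = 2 * head_groups * kv_channels        # key and value concatenated
dense_width = head_groups * group_heads * kv_channels  # context projected back to hidden

assert query_width == num_attention_heads * kv_channels
print(query_width, key_value_width, dense_width)  # 4096 2048 4096
```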

def _attn_fused(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
@@ -133,16 +148,18 @@ def _attn_fused(
sk = key.size(1)

if self._local_head_groups == 1:
query = query.view(b, sq * self._local_heads, self._kv_channels)
query = query.view(b, sq * self._local_heads, self._config.kv_channels)
key = key.transpose(-1, -2)
else:
query = (
query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._kv_channels))
query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._config.kv_channels))
.transpose(1, 2)
.reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._kv_channels)
.reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._config.kv_channels)
)
key = key.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).movedim(1, 3).flatten(0, 1)
value = (
value.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).transpose(1, 2).flatten(0, 1)
)
key = key.unflatten(-1, (self._local_head_groups, self._kv_channels)).movedim(1, 3).flatten(0, 1)
value = value.unflatten(-1, (self._local_head_groups, self._kv_channels)).transpose(1, 2).flatten(0, 1)

attn_weights = torch.empty(
(b * self._local_head_groups, sq * self._local_heads_per_group, sk), device=query.device, dtype=query.dtype
@@ -169,7 +186,7 @@
return attn_output.view(b, sq, -1)
else:
return (
attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._kv_channels)
attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._config.kv_channels)
.transpose(1, 2)
.flatten(2)
)
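Reviewer note: the reshape trick in `_attn_fused` folds the head groups into the batch dimension so each group's `group_heads` query heads attend to that group's single key/value head through an ordinary batched matmul. A self-contained sketch with illustrative shapes (mask, dropout and the configurable softmax scale power are omitted; these are not the actual Fast-LLM call sites):

```python
import torch

# Illustrative sizes, not taken from a real config.
b, sq, sk = 2, 16, 16
head_groups, group_heads, kv_channels = 4, 3, 8

query = torch.randn(b, sq, head_groups * group_heads * kv_channels)
key = torch.randn(b, sk, head_groups * kv_channels)
value = torch.randn(b, sk, head_groups * kv_channels)

# Fold the head groups into the batch dimension.
q = (
    query.unflatten(-1, (head_groups, group_heads, kv_channels))
    .transpose(1, 2)
    .reshape(b * head_groups, sq * group_heads, kv_channels)
)
k = key.unflatten(-1, (head_groups, kv_channels)).movedim(1, 3).flatten(0, 1)  # (b*g, kv_channels, sk)
v = value.unflatten(-1, (head_groups, kv_channels)).transpose(1, 2).flatten(0, 1)  # (b*g, sk, kv_channels)

scores = torch.bmm(q, k) * kv_channels**-0.5  # (b*g, sq*group_heads, sk)
probs = torch.softmax(scores, dim=-1)  # causal mask omitted for brevity
context = torch.bmm(probs, v)  # (b*g, sq*group_heads, kv_channels)

# Unfold back to (b, sq, heads * kv_channels), matching the fused-path output.
context = context.view(b, head_groups, sq, group_heads, kv_channels).transpose(1, 2).flatten(2)
print(context.shape)  # torch.Size([2, 16, 96])
```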
@@ -181,7 +198,7 @@ def _query_key_value_forward(

handle = None

if self._head_groups == 1 and self._sequence_parallel:
if self._config.head_groups == 1 and self._sequence_parallel:
key_value, handle = gather_op(
key_value, group=self._tensor_space.distributed.tensor_group, dim=0, async_op=True
)
@@ -226,7 +243,7 @@ def _query_key_value_backward(
if handle:
handle.wait()

if self._head_groups == 1 and (group := self._tensor_space.distributed.tensor_group):
if self._config.head_groups == 1 and (group := self._tensor_space.distributed.tensor_group):
if self._sequence_parallel:
key_value_grad = reduce_scatter_op(key_value_grad, group=group, dim=0)
else:
@@ -281,15 +298,15 @@ def forward(
query = query.transpose(0, 1).contiguous()
key_value = key_value.transpose(0, 1).contiguous()

key, value = key_value.split(self._local_head_groups * self._kv_channels, dim=-1)
key, value = key_value.split(self._local_head_groups * self._config.kv_channels, dim=-1)

query = query.view(*query.shape[:2], self._local_heads, self._kv_channels)
key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels)
value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels)
query = query.view(*query.shape[:2], self._local_heads, self._config.kv_channels)
key = key.view(*key.shape[:2], self._local_head_groups, self._config.kv_channels)
value = value.view(*value.shape[:2], self._local_head_groups, self._config.kv_channels)

if self._debug.enabled:
self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs)
self._debug(key, "key_rotary_input", self._KV_DIMS, kwargs)
self._debug(query, "query_rotary_input", self._query_dims, kwargs)
self._debug(key, "key_rotary_input", self._kv_dims, kwargs)
query, key = self._rotary(query, key, kwargs)

window_size = self._decide_window_size()
@@ -337,10 +354,10 @@ def forward(
)

if self._debug.enabled:
self._debug(query, "query", self._QUERY_DIMS, kwargs)
self._debug(key, "key", self._KV_DIMS, kwargs)
self._debug(value, "value", self._KV_DIMS, kwargs)
self._debug(input_, "context", self._CONTEXT_DIMS, kwargs)
self._debug(query, "query", self._query_dims, kwargs)
self._debug(key, "key", self._kv_dims, kwargs)
self._debug(value, "value", self._kv_dims, kwargs)
self._debug(input_, "context", self._context_dims, kwargs)

if sequence_first:
# TODO: Optimize (is contiguous avoidable? Transpose dense output?)
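Reviewer note: after the fused projection, `forward` splits the concatenated key/value tensor and restores the per-head layout expected by the rotary embedding and the attention kernels. A small sketch with made-up local sizes (the real values come from the attention config and the tensor-parallel split; the sequence-first transposes are left out):

```python
import torch

# Made-up local sizes; the real ones come from the config and tensor parallelism.
b, sq = 2, 16
local_heads, local_head_groups, kv_channels = 12, 4, 8

query = torch.randn(b, sq, local_heads * kv_channels)
key_value = torch.randn(b, sq, 2 * local_head_groups * kv_channels)

# Split the concatenated key/value projection, then expose one head per dimension.
key, value = key_value.split(local_head_groups * kv_channels, dim=-1)
query = query.view(*query.shape[:2], local_heads, kv_channels)
key = key.view(*key.shape[:2], local_head_groups, kv_channels)
value = value.view(*value.shape[:2], local_head_groups, kv_channels)

print(query.shape, key.shape, value.shape)
# torch.Size([2, 16, 12, 8]) torch.Size([2, 16, 4, 8]) torch.Size([2, 16, 4, 8])
```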