Block interface: refactor #339

Draft · wants to merge 19 commits into base: tp_mamba
2 changes: 1 addition & 1 deletion Megatron-LM
30 changes: 15 additions & 15 deletions docs/developer_guide/conversion.md
@@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define:

```python
def _create_weight_converters(self) -> list[WeightConverter]:
converters = []
# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
num_layers = self._model.config.base_model.transformer.num_layers

# A simple renaming example, for the word embeddings.
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))

# We usually want to loop dynamically over layers
for i in range(num_layers):
# A `SplitWeightConverter` example, splitting a weight in two.
converters.append(SplitWeightConverter(
f"layers.{i + 1}.weight",
(f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
))
return converters
converters = []
# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
num_layers = self._model.config.base_model.transformer.num_layers

# A simple renaming example, for the word embeddings.
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))

# We usually want to loop dynamically over layers
for i in range(num_layers):
# A `SplitWeightConverter` example, splitting a weight in two.
converters.append(SplitWeightConverter(
f"layers.{i + 1}.weight",
(f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
))
return converters
```

And that's it! We're ready to use the new checkpoint format in Fast-LLM.
22 changes: 22 additions & 0 deletions fast_llm/config.py
@@ -1028,6 +1028,28 @@ def __init__(self, config: ConfigType, *args, **kwargs):
# Handle multiple inheritance.
super().__init__(*args, **kwargs)

def __init_subclass__(cls):
# Automatically set `config_class` based on the bound type.
# Make sure `ConfigType` is bound and respects class hierarchy.
try:
config_class = None
for base in types.get_original_bases(cls):
if hasattr(base, "__origin__") and issubclass(base.__origin__, Configurable):
for arg in base.__args__:
if arg.__name__ == "ConfigType":
if config_class is None:
config_class = arg.__bound__
else:
assert arg.__bound__ is config_class
assert config_class is not None
except Exception as e:
raise TypeError(
f"Could not determine the configuration class for the configurable class {cls.__name__}: {e.args}. "
"Please make sure to declare in the format "
f"`class {cls.__name__}[ConfigType: ConfigClass](BaseConfigurable[ConfigType])`.] "
)
cls.config_class = config_class

@property
def config(self) -> ConfigType:
return self._config
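For context, here is a minimal, self-contained sketch (not part of the diff) of the pattern this new `__init_subclass__` hook supports: the bound of the PEP 695 `ConfigType` parameter is picked up automatically, which is why explicit `config_class` class variables are deleted throughout the rest of this PR. `AwesomeConfig` and `AwesomeRunner` are hypothetical names, and this stripped-down `Configurable` only mimics the real one in `fast_llm/config.py`.

```python
import types
import typing


class Config:  # stand-in for fast_llm.config.Config
    pass


class Configurable[ConfigType: Config]:
    config_class: typing.ClassVar[type[Config]] = Config

    def __init_subclass__(cls):
        # Same idea as the hook above: read the `ConfigType` bound from the
        # original (unsubscripted) bases and store it on the subclass.
        for base in types.get_original_bases(cls):
            if hasattr(base, "__origin__") and issubclass(base.__origin__, Configurable):
                for arg in base.__args__:
                    if getattr(arg, "__name__", None) == "ConfigType":
                        cls.config_class = arg.__bound__


class AwesomeConfig(Config):
    pass


class AwesomeRunner[ConfigType: AwesomeConfig](Configurable[ConfigType]):
    pass


assert AwesomeRunner.config_class is AwesomeConfig  # inferred from the bound, no ClassVar needed
```

With this in place, subclasses such as `DatasetPreparator`, `Evaluator`, `Distributed`, and `FastLLMModel` can drop their `config_class: typing.ClassVar[...]` declarations, as the remaining diffs in this PR show.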
2 changes: 0 additions & 2 deletions fast_llm/data/preparator/config.py
@@ -19,8 +19,6 @@ def _get_runnable(self) -> typing.Callable[[], None]:


class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[ConfigType], abc.ABC):
config_class: typing.ClassVar[type[DatasetPreparatorConfig]] = DatasetPreparatorConfig

@abc.abstractmethod
def run(self) -> None:
raise NotImplementedError
2 changes: 0 additions & 2 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -33,8 +33,6 @@


class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]):
config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig

_tokenizer: Tokenizer
_data_type: DataType
_text_column: str
58 changes: 26 additions & 32 deletions fast_llm/engine/base_model/base_model.py
@@ -7,7 +7,6 @@

from fast_llm.config import Configurable
from fast_llm.engine.base_model.config import BaseModelConfig
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.tensor import ParameterMeta, TensorMeta
@@ -20,11 +19,18 @@
class Module(torch.nn.Module, abc.ABC):
""" """

def forward(self, input_, kwargs):
"""
Run a forward pass for the module, with autograd support.
"""
raise NotImplementedError()
_is_setup: bool = False
_distributed: Distributed

def __init__(self, distributed_config: DistributedConfig):
self._distributed_config = distributed_config
super().__init__()

def setup(self, distributed: Distributed) -> None:
assert not self._is_setup
distributed.check_config(self._distributed_config)
self._distributed = distributed
self._is_setup = True


class Layer(Module):
@@ -39,9 +45,9 @@ def forward(


class Sequential(Layer):
def __init__(self, layers: list[Layer]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def __init__(self, distributed_config: DistributedConfig):
super().__init__(distributed_config)
self.layers = torch.nn.ModuleList(self.get_layers())

def __getitem__(self, item):
return self.layers[item]
@@ -59,6 +65,15 @@ def forward(
input_ = layer(input_, kwargs, losses, metrics)
return input_

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass

def setup(self, distributed: Distributed) -> None:
super().setup(distributed)
for layer in self.layers:
layer.setup(distributed)


@dataclasses.dataclass()
class LossDef:
@@ -71,29 +86,14 @@ class LossDef:
dtype: torch.dtype = torch.float32


class SequentialLayers(Sequential, abc.ABC):
# Small class defined to fix the MRO of BaseModel.__init__
def __init__(self):
super().__init__(self.get_layers())

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass


class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC):
config_class: typing.ClassVar[type[BaseModelConfig]] = BaseModelConfig
_is_setup: bool = False
class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential):

def __init__(
self,
config: BaseModelConfig,
distributed_config: DistributedConfig,
):
self._tensor_space: TensorSpace = TensorSpace(distributed_config)
config.setup_tensor_space(self._tensor_space)

super().__init__(config)
super().__init__(config, distributed_config)

for key, value in self.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
@@ -104,12 +104,6 @@ def __init__(
# TODO: Add basic handling (preprocessor) in this class.
self._reference_models: dict[str, "InferenceRunner"] = {}

def setup(self, distributed: Distributed) -> None:
assert not self._is_setup
distributed.check_config(self._tensor_space.distributed_config)
self._tensor_space.setup(distributed)
self._is_setup = True

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass
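To summarize the refactor above in usage terms, here is a rough sketch of the new contract: modules receive the `DistributedConfig` at construction, `Sequential` builds its children from `get_layers()`, and `setup(distributed)` is propagated down the stack, replacing the deleted `SequentialLayers` shim and the `TensorSpace`-based setup in `BaseModel`. The `DemoLayer`/`DemoStack` names are invented, and the `forward` signature is inferred from the call in `Sequential.forward`, so treat this as an illustration rather than repository code.

```python
from fast_llm.engine.base_model.base_model import Layer, Sequential


class DemoLayer(Layer):
    # Signature assumed from how `Sequential.forward` calls its layers.
    def forward(self, input_, kwargs, losses=None, metrics=None):
        return input_  # placeholder: a real layer would run its computation here


class DemoStack(Sequential):
    def get_layers(self) -> list[Layer]:
        # Invoked by `Sequential.__init__`, which now takes only the distributed config.
        return [DemoLayer(self._distributed_config) for _ in range(2)]


# Intended wiring, per the diff (construction of the config/distributed objects omitted):
#   stack = DemoStack(distributed_config)   # layers are built via `get_layers()`
#   stack.setup(distributed)                # `setup` is forwarded to every layer
```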
7 changes: 2 additions & 5 deletions fast_llm/engine/base_model/config.py
@@ -6,7 +6,7 @@
from fast_llm.utils import compare_nested, log

if typing.TYPE_CHECKING:
from fast_llm.engine.config_utils.tensor_space import TensorSpace
import torch


@config_class()
Expand All @@ -18,9 +18,6 @@ class BaseModelConfig(Config):

_abstract = True

def setup_tensor_space(self, tensor_space: "TensorSpace") -> None:
raise NotImplementedError()

def compare_architecture(
self,
model_config: typing.Self,
@@ -64,5 +61,5 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
pass

@abc.abstractmethod
def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
pass
57 changes: 57 additions & 0 deletions fast_llm/engine/config_utils/initialization.py
@@ -0,0 +1,57 @@
import abc
import typing

if typing.TYPE_CHECKING:
import torch

from fast_llm.tensor import ParameterMeta


class Initializer(abc.ABC):
@abc.abstractmethod
def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None:
pass

requires_global_initialization = False


class LambdaInitializer(Initializer):
def __init__(
self,
init_method: typing.Callable[["ParameterMeta", "torch.Tensor", "torch.Generator"], None],
requires_global_initialization: bool = False,
) -> None:
self._init_method = init_method
self.requires_global_initialization = requires_global_initialization

def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None:
return self._init_method(meta, tensor, generator)


def init_fill_(value: float) -> LambdaInitializer:
def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa
tensor.fill_(value)

return LambdaInitializer(init_)


init_zeros_ = init_fill_(0.0)
init_ones_ = init_fill_(1.0)


def init_normal_(
mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None
) -> LambdaInitializer:
def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa
tensor = tensor.normal_(mean, std, generator=generator)
if min_val is not None or max_val is not None:
tensor.clamp_(min=min_val, max=max_val)

return LambdaInitializer(init_)


def init_uniform_centered_(scale: float, mean: float = 0.0) -> LambdaInitializer:
def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa
tensor.uniform_(mean - scale, mean + scale, generator=generator)

return LambdaInitializer(init_)
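A brief usage sketch for these helpers. The `meta` argument is unused by these particular initializers, so `None` is passed below purely for illustration; in the engine they would presumably be handed a real `ParameterMeta`.

```python
import torch

from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_

generator = torch.Generator().manual_seed(0)
weight = torch.empty(16, 8)
bias = torch.empty(8)

# `init_normal_` returns a `LambdaInitializer`; calling it fills the tensor in place
# and clamps to the optional bounds.
init_normal_(std=0.02, min_val=-0.04, max_val=0.04)(None, weight, generator)

# `init_zeros_` is a ready-made `LambdaInitializer` built from `init_fill_(0.0)`.
init_zeros_(None, bias, generator)
```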
@@ -2,14 +2,13 @@
import math
import typing

from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim
from fast_llm.engine.distributed.config import DistributedDim
from fast_llm.utils import Assert, div

if typing.TYPE_CHECKING:
import torch

from fast_llm.core.distributed import ProcessGroup
from fast_llm.engine.distributed.distributed import Distributed

logger = logging.getLogger(__name__)

@@ -219,49 +218,4 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F
)


class DefaultDimNames:
# Scalar
scalar = "scalar"


class TensorSpace:
_is_setup: bool = False
_distributed: "Distributed"

def __init__(self, distributed_config: DistributedConfig):
self._distributed_config = distributed_config
self._tensor_dims: dict[str, TensorDim] = {}
self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1))

def setup(self, distributed: "Distributed") -> None:
assert not self._is_setup
if distributed.config is not self._distributed_config:
distributed.config.compare(self._distributed_config, ValueError)
self._is_setup = True
self._distributed = distributed

@property
def distributed_config(self) -> DistributedConfig:
return self._distributed_config

@property
def distributed(self) -> "Distributed":
assert self._is_setup
return self._distributed

def add_tensor_dim(self, tensor_dim: TensorDim) -> None:
if tensor_dim.name in self._tensor_dims:
Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name])
else:
if tensor_dim.parallel_dim is not None:
assert (
tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims
), tensor_dim.parallel_dim.name
Assert.eq(
tensor_dim.parallel_dim.__dict__,
self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__,
)
self._tensor_dims[tensor_dim.name] = tensor_dim

def __getitem__(self, name: str) -> TensorDim:
return self._tensor_dims[name]
scalar_dim = TensorDim("scalar", 1)
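With `TensorSpace` gone, dimensions appear to be created directly as `TensorDim` objects, with the module-level `scalar_dim` standing in for the old `DefaultDimNames.scalar`. A minimal sketch; the import path is a guess from the pre-refactor `tensor_space` module (this file's header is not visible in the captured diff), and `hidden_dim` is a hypothetical name.

```python
# Import path assumed; the renamed/edited module's header is not shown in this view.
from fast_llm.engine.config_utils.tensor_space import TensorDim, scalar_dim

# A plain dimension, created where it is needed instead of registered in a TensorSpace.
hidden_dim = TensorDim("hidden", 4096)
print(scalar_dim, hidden_dim)
```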
3 changes: 0 additions & 3 deletions fast_llm/engine/distributed/distributed.py
@@ -1,6 +1,5 @@
import datetime
import logging
import typing

import torch
import torch.distributed
@@ -146,8 +145,6 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):
TODO: Clarify cpu support.
"""

config_class: typing.ClassVar[type[DistributedConfig]] = DistributedConfig

def __init__(self, config: DistributedConfig, use_cpu: bool = False):
super().__init__(config)
assert self._config.reference_config is None
4 changes: 0 additions & 4 deletions fast_llm/engine/evaluation/evaluator.py
@@ -44,8 +44,6 @@ class EvaluatorSamplingParameters:


class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC):
config_class: typing.ClassVar[type[EvaluatorConfig]] = EvaluatorConfig

_is_setup: bool = False

def __init__(
@@ -96,8 +94,6 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None:


class LossEvaluator[ConfigType: LossEvaluatorConfig](Evaluator[ConfigType]):
config_class: typing.ClassVar[type[LossEvaluatorConfig]] = LossEvaluatorConfig

def setup(
self,
distributed: Distributed,
2 changes: 0 additions & 2 deletions fast_llm/engine/evaluation/lm_eval/evaluator.py
@@ -25,8 +25,6 @@


class LmEvalEvaluator[ConfigType: LmEvalEvaluatorConfig](Evaluator[ConfigType]):
config_class: typing.ClassVar[type[LmEvalEvaluatorConfig]] = LmEvalEvaluatorConfig

_hf_model: "HuggingfaceBaseModelForCausalLM" = None
_flm_wrapper: "FastLLMLmEvalWrapper" = None

1 change: 0 additions & 1 deletion fast_llm/engine/multi_stage/fast_llm_model.py
@@ -14,7 +14,6 @@


class FastLLMModel[ConfigType: FastLLMModelConfig](MultiStageModel[ConfigType]):
config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig
_is_loaded: bool = False

def save_checkpoint(