Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightly_train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from lightly_train._commands.predict_task import predict_semantic_segmentation
from lightly_train._commands.train import pretrain, train
from lightly_train._commands.train_task import (
train_depth_estimation,
train_image_classification,
train_image_classification_multihead,
train_instance_segmentation,
Expand Down Expand Up @@ -74,6 +75,7 @@
"ModelPart",
"predict_semantic_segmentation",
"pretrain",
"train_depth_estimation",
"train_image_classification",
"train_image_classification_multihead",
"train_instance_segmentation",
Expand Down
174 changes: 174 additions & 0 deletions src/lightly_train/_commands/train_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from lightly_train._data.coco_object_detection_dataset import (
COCOObjectDetectionDataArgs,
)
from lightly_train._data.depth_estimation_dataset import (
DepthEstimationDataArgs,
)
from lightly_train._data.image_classification_dataset import (
ImageClassificationMulticlassDataArgs,
ImageClassificationMultilabelDataArgs,
Expand Down Expand Up @@ -74,6 +77,171 @@
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def train_depth_estimation(
    *,
    out: PathLike,
    data: dict[str, Any] | str,
    model: str,
    steps: int | Literal["auto"] = "auto",
    batch_size: int | Literal["auto"] = "auto",
    num_workers: int | Literal["auto"] = "auto",
    devices: int | str | list[int] = "auto",
    num_nodes: int = 1,
    resume_interrupted: bool = False,
    checkpoint: PathLike | None = None,
    overwrite: bool = False,
    accelerator: str = "auto",
    strategy: str = "auto",
    precision: _PRECISION_INPUT = "bf16-mixed",
    float32_matmul_precision: Literal["auto", "highest", "high", "medium"] = "auto",
    seed: int = 0,
    logger_args: dict[str, Any] | None = None,
    model_args: dict[str, Any] | None = None,
    transform_args: dict[str, Any] | None = None,
    metric_args: dict[str, Any] | None = None,
    loader_args: dict[str, Any] | None = None,
    save_checkpoint_args: dict[str, Any] | None = None,
    torch_compile_args: dict[str, Any] | None = None,
    gradient_accumulation_steps: int | Literal["auto"] = "auto",
) -> None:
    """Train a depth estimation model.

    Fine-tunes a Depth Anything V3 model by distilling depth and sky pseudo-labels
    (e.g. generated with the V3 ViT-L teacher) into a smaller backbone.

    All arguments are keyword-only.

    The training process can be monitored with TensorBoard:

    .. code-block:: bash

        tensorboard --logdir out

    After training, the last model checkpoint is saved in the out directory to:
    ``out/checkpoints/last.ckpt`` and also exported to
    ``out/exported_models/exported_last.pt``.

    Args:
        out:
            The output directory where the model checkpoints and logs are saved.
        data:
            The dataset configuration or path to a YAML file with the configuration.
            Each split points to a directory of RGB images and directories of depth and
            sky pseudo-labels (``.npy`` files matched to each image by filename stem)::

                data={
                    "train": {"images": ".../train/images",
                              "depth": ".../train/depth",
                              "sky": ".../train/sky"},
                    "val": {"images": ".../val/images",
                            "depth": ".../val/depth",
                            "sky": ".../val/sky"},
                }

            Depth pixels with value ``<= 0`` are treated as invalid and ignored by the
            loss and metrics.
        model:
            The model to train. For example, "dinov2/dav3-relative-small",
            "dinov3/dav3-relative-tiny", "dinov3/dav3-relative-tiny-plus", or a
            path to a local model checkpoint.

            If you want to resume training from an interrupted or crashed run, use the
            ``resume_interrupted`` parameter.
        steps:
            The number of training steps.
        batch_size:
            Global batch size. The batch size per device/GPU is inferred from this value
            and the number of devices and nodes.
        num_workers:
            Number of workers for the dataloader per device/GPU. 'auto' automatically
            sets the number of workers based on the available CPU cores.
        devices:
            Number of devices/GPUs for training. 'auto' automatically selects all
            available devices. The device type is determined by the ``accelerator``
            parameter.
        num_nodes:
            Number of nodes for distributed training.
        resume_interrupted:
            Set this to True if you want to resume training from an **interrupted or
            crashed** training run. This will pick up exactly where the training left
            off, including the optimizer state and the current step.

            - You must use the same ``out`` directory as the interrupted run.
            - You must **NOT** change any training parameters (e.g., learning rate, batch size, data, etc.).
            - This is intended for continuing the same run without modification.
        checkpoint:
            Use this parameter to further fine-tune a model from a previous fine-tuned
            checkpoint. The checkpoint must be a path to a checkpoint file, for example
            "checkpoints/model.ckpt". This will only load the model weights from the
            previous run. All other training state (e.g. optimizer state, epochs) from
            the previous run is not loaded.

            This option is equivalent to setting ``model="<path_to_checkpoint>"``.

            If you want to resume training from an interrupted or crashed run, use the
            ``resume_interrupted`` parameter instead.
        overwrite:
            Overwrite the output directory if it already exists. Warning, this might
            overwrite existing files in the directory!
        accelerator:
            Hardware accelerator. Can be one of ['cpu', 'gpu', 'mps', 'auto'].
            'auto' will automatically select the best accelerator available.
        strategy:
            Training strategy. For example 'ddp' or 'auto'. 'auto' automatically
            selects the best strategy available.
        precision:
            Training precision. Select '16-mixed' for mixed 16-bit precision, '32-true'
            for full 32-bit precision, or 'bf16-mixed' for mixed bfloat16 precision.
        float32_matmul_precision:
            Precision for float32 matrix multiplication. Can be one of ['auto',
            'highest', 'high', 'medium']. See https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
            for more information.
        seed:
            Random seed for reproducibility.
        logger_args:
            Logger arguments. Either None or a dictionary of logger names to either
            None or a dictionary of logger arguments. None uses the default loggers.
            To disable a logger, set it to None: ``logger_args={"tensorboard": None}``.
        model_args:
            Model training arguments. Either None or a dictionary of model arguments.
        transform_args:
            Transform arguments. Either None or a dictionary of transform arguments.
            The image size and normalization parameters can be set with
            ``transform_args={"image_size": (height, width), "normalize": {"mean": (r, g, b), "std": (r, g, b)}}``.
        metric_args:
            Metric arguments. Either None or a dictionary of metric arguments.
            Set ``metric_args={"train": True}`` to also compute depth metrics on the
            training data. Set ``metric_args={"watch_metric": "val_metric/rmse"}`` to
            configure the metric used to select the best checkpoint.
        loader_args:
            Arguments for the PyTorch DataLoader. Should only be used in special cases
            as default values are automatically set. Prefer to use the `batch_size` and
            `num_workers` arguments instead. For details, see:
            https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        save_checkpoint_args:
            Arguments to configure the saving of checkpoints. The checkpoint frequency
            can be set with ``save_checkpoint_args={"save_every_num_steps": 100}``.
        torch_compile_args:
            Arguments to configure model compilation with torch.compile. The arguments
            are directly passed to torch.compile. Set
            ``torch_compile_args={"disable": True}`` to disable it if you encounter any
            issues.
        gradient_accumulation_steps:
            Number of gradient accumulation steps. 'auto' automatically enables
            gradient accumulation when batch_size is smaller than the model's default
            batch size, using ``max(1, default_batch_size // batch_size)`` steps to
            keep the effective batch size and learning rate close to the model defaults.
            Set to 1 to explicitly disable gradient accumulation.
    """
    # Report the start of the run to the tracker before handing over to the
    # shared task-training entry point.
    tracker.track_training_started(
        task_type="depth_estimation",
        model=model,
        method="depth_anything",
        batch_size=batch_size,
        devices=devices,
        steps=steps,
    )
    # NOTE: locals() forwards every keyword argument of this function as-is.
    # Do not bind any additional local variable above this line, otherwise it
    # would be passed to _train_task as an unexpected keyword argument.
    return _train_task(config_cls=DepthEstimationTrainTaskConfig, **locals())


def train_image_classification(
*,
out: PathLike,
Expand Down Expand Up @@ -1974,6 +2142,7 @@ class TrainTaskConfig(PydanticConfig):
data: TaskDataArgs
model: str
task: Literal[
"depth_estimation",
"image_classification",
"image_classification_multihead",
"instance_segmentation",
Expand Down Expand Up @@ -2017,6 +2186,11 @@ def _load_yaml_if_path(cls, v: Any) -> Any:
)


class DepthEstimationTrainTaskConfig(TrainTaskConfig):
    # Train-task configuration for depth estimation. It narrows the generic
    # ``data`` and ``task`` fields of TrainTaskConfig to their depth-estimation
    # specific types and is the ``config_cls`` handed to ``_train_task`` by
    # ``train_depth_estimation``.

    # Dataset arguments for the depth estimation task.
    data: DepthEstimationDataArgs
    # Task identifier; the only accepted value (and the default) is
    # "depth_estimation".
    task: Literal["depth_estimation"] = "depth_estimation"


class ImageClassificationMulticlassTrainTaskConfig(TrainTaskConfig):
data: ImageClassificationMulticlassDataArgs
task: Literal["image_classification"] = "image_classification"
Expand Down
9 changes: 9 additions & 0 deletions src/lightly_train/_commands/train_task_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
)
from lightly_train._task_checkpoint import TaskSaveCheckpointArgs
from lightly_train._task_models import task_model_helpers
from lightly_train._task_models.depth_estimation.train_model import (
DepthEstimationTrain,
)
from lightly_train._task_models.dinov2_eomt_instance_segmentation.train_model import (
DINOv2EoMTInstanceSegmentationTrain,
)
Expand Down Expand Up @@ -116,6 +119,7 @@


TASK_TRAIN_MODEL_CLASSES: list[type[TrainModel]] = [
DepthEstimationTrain,
ImageClassificationTrain,
ImageClassificationMultiheadTrain,
DINOv2EoMTInstanceSegmentationTrain,
Expand All @@ -133,6 +137,11 @@

# TODO(Thomas, 10/25): Create a type for the metrics.
TASK_TO_METRICS: dict[str, dict[str, str]] = {
"depth_estimation": {
"val_metric/rmse": "Val RMSE",
"val_metric/abs_rel": "Val AbsRel",
"val_metric/delta1": "Val delta<1.25",
},
"instance_segmentation": {
"val_metric/map": "Val mAP@0.5:0.95",
"val_metric/map_50": "Val mAP@0.5",
Expand Down
Loading
Loading