Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))


- Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164))


### Fixed

- Fixed a missing device id for PyTorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
Expand Down
6 changes: 5 additions & 1 deletion src/lightning/fabric/utilities/device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
from torch.nn import Module
from typing_extensions import Self, override

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3


class _DeviceDtypeModuleMixin(Module):
__jit_unused_properties__: list[str] = ["device", "dtype"]

def __init__(self) -> None:
    """Initialize the module and record torch's current default dtype and device.

    ``_dtype`` is taken from ``torch.get_default_dtype()``. ``_device`` is taken from
    ``torch.get_default_device()`` when ``_TORCH_GREATER_EQUAL_2_3`` is truthy, and from the
    device of a freshly created empty tensor otherwise.
    """
    super().__init__()
    self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
    # Workarounds from the original pytorch issue:
    # https://github.com/pytorch/pytorch/issues/115333#issuecomment-1848449687
    # NOTE: no hard-coded ``torch.device("cpu")`` placeholder is assigned first; it would be
    # overwritten immediately by the line below.
    self._device = torch.get_default_device() if _TORCH_GREATER_EQUAL_2_3 else torch.empty(0).device

@property
def dtype(self) -> Union[str, torch.dtype]:
Expand Down
25 changes: 25 additions & 0 deletions tests/tests_fabric/utilities/test_device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from torch import nn as nn

from lightning.fabric.plugins.precision.utils import _DtypeContextManager
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from tests_fabric.helpers.runif import RunIf

Expand Down Expand Up @@ -50,6 +51,30 @@ def test_submodules_device_and_dtype(dst_device_str, dst_type):
assert model.dtype == model.module.module.dtype == dst_type


@pytest.mark.parametrize(
    "dst_device_str",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize(
    "dst_type",
    [
        torch.float,
        pytest.param(torch.half, marks=RunIf(mps=False)),
        pytest.param(torch.double, marks=RunIf(mps=False)),
    ],
)
def test_submodules_context_device_and_dtype(dst_device_str, dst_type):
    """Build ``TopModule`` under a dtype context and a device context and check what it reports.

    The module is constructed inside ``_DtypeContextManager(dst_type)`` and the ``torch.device``
    context for ``dst_device_str``; its ``device`` and ``dtype`` must equal those targets.
    """
    expected_device = torch.device(dst_device_str)
    with _DtypeContextManager(dst_type):
        with expected_device:
            module = TopModule()
    assert module.device == expected_device
    assert module.dtype == dst_type


@pytest.mark.parametrize(
"device",
[
Expand Down
16 changes: 16 additions & 0 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,22 @@ def test_init_module_context(monkeypatch):
strategy.tensor_init_context.reset_mock()


@pytest.mark.parametrize(
    ("target_device", "accelerator", "devices"),
    [
        ("cpu", "cpu", "auto"),
        pytest.param("cuda:0", "gpu", [0], marks=RunIf(min_cuda_gpus=1)),
        pytest.param("cuda:1", "gpu", [1], marks=RunIf(min_cuda_gpus=2)),
    ],
)
def test_init_module_device_type(target_device, accelerator, devices):
    """Test that a model created under ``Trainer.init_module()`` reports the expected device.

    For each accelerator/devices combination, a ``BoringModel`` is instantiated inside the
    ``trainer.init_module()`` context and its ``device`` attribute is compared against
    ``torch.device(target_device)``.
    """
    trainer = Trainer(accelerator=accelerator, devices=devices)
    with trainer.init_module():
        model = BoringModel()
        assert model.device == torch.device(target_device)


def test_expand_home_trainer():
"""Test that the dirpath gets expanded if it contains `~`."""
home_root = Path.home()
Expand Down
Loading