diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index e949a9fac5a9b..25499f40b729c 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   [#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057),
   [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))
 
+- Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164))
+
+
 ### Fixed
 
 - Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py
index ff5a0949e4207..527ed90203e46 100644
--- a/src/lightning/fabric/utilities/device_dtype_mixin.py
+++ b/src/lightning/fabric/utilities/device_dtype_mixin.py
@@ -18,6 +18,8 @@
 from torch.nn import Module
 from typing_extensions import Self, override
 
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
+
 
 class _DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__: list[str] = ["device", "dtype"]
@@ -25,7 +27,9 @@ class _DeviceDtypeModuleMixin(Module):
     def __init__(self) -> None:
         super().__init__()
         self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
-        self._device = torch.device("cpu")
+        # Workarounds from the original pytorch issue:
+        # https://github.com/pytorch/pytorch/issues/115333#issuecomment-1848449687
+        self._device = torch.get_default_device() if _TORCH_GREATER_EQUAL_2_3 else torch.empty(0).device
 
     @property
     def dtype(self) -> Union[str, torch.dtype]:
diff --git a/tests/tests_fabric/utilities/test_device_dtype_mixin.py b/tests/tests_fabric/utilities/test_device_dtype_mixin.py
index 1261ca5e0accb..bf24490b0cf32 100644
--- a/tests/tests_fabric/utilities/test_device_dtype_mixin.py
+++ b/tests/tests_fabric/utilities/test_device_dtype_mixin.py
@@ -2,6 +2,7 @@
 import torch
 from torch import nn as nn
 
+from lightning.fabric.plugins.precision.utils import _DtypeContextManager
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from tests_fabric.helpers.runif import RunIf
 
@@ -50,6 +51,30 @@ def test_submodules_device_and_dtype(dst_device_str, dst_type):
     assert model.dtype == model.module.module.dtype == dst_type
 
 
+@pytest.mark.parametrize(
+    "dst_device_str",
+    [
+        "cpu",
+        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("mps:0", marks=RunIf(mps=True)),
+    ],
+)
+@pytest.mark.parametrize(
+    "dst_type",
+    [
+        torch.float,
+        pytest.param(torch.half, marks=RunIf(mps=False)),
+        pytest.param(torch.double, marks=RunIf(mps=False)),
+    ],
+)
+def test_submodules_context_device_and_dtype(dst_device_str, dst_type):
+    dst_device = torch.device(dst_device_str)
+    with _DtypeContextManager(dst_type), dst_device:
+        model = TopModule()
+    assert model.device == dst_device
+    assert model.dtype == dst_type
+
+
 @pytest.mark.parametrize(
     "device",
     [
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index da79e2fdc411b..ea3c31a370fce 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -2107,6 +2107,22 @@ def test_init_module_context(monkeypatch):
     strategy.tensor_init_context.reset_mock()
 
 
+@pytest.mark.parametrize(
+    ("target_device", "accelerator", "devices"),
+    [
"cpu", "auto"), + pytest.param("cuda:0", "gpu", [0], marks=RunIf(min_cuda_gpus=1)), + pytest.param("cuda:1", "gpu", [1], marks=RunIf(min_cuda_gpus=2)), + ], +) +def test_init_module_device_type(target_device, accelerator, devices): + """Test that the strategy returns the context manager for initializing the module.""" + trainer = Trainer(accelerator=accelerator, devices=devices) + with trainer.init_module(): + model = BoringModel() + assert model.device == torch.device(target_device) + + def test_expand_home_trainer(): """Test that the dirpath gets expanded if it contains `~`.""" home_root = Path.home()