Skip to content

Commit 28dbfee

Browse files
GdoongMathewBordabhimrazySkafteNicki
authored
Set _DeviceDtypeModuleMixin _device from torch's default device function (#21164)
* fix: set `_DeviceDtypeModuleMixin` _device from torch's default device function. * add torch 2.2.2 or below compatibility * add import for `_TORCH_GREATER_EQUAL_2_3 * fix: restore torch default dtype once test_submodules_context_device_and_dtype is finished. --------- Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Bhimraj Yadav <[email protected]> Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent 75a6bf7 commit 28dbfee

File tree

4 files changed

+49
-1
lines changed

4 files changed

+49
-1
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2626
[#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))
2727

2828

29+
- Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164))
30+
31+
2932
### Fixed
3033

3134
- Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))

src/lightning/fabric/utilities/device_dtype_mixin.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,18 @@
1818
from torch.nn import Module
1919
from typing_extensions import Self, override
2020

21+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
22+
2123

2224
class _DeviceDtypeModuleMixin(Module):
2325
__jit_unused_properties__: list[str] = ["device", "dtype"]
2426

2527
def __init__(self) -> None:
2628
super().__init__()
2729
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
28-
self._device = torch.device("cpu")
30+
# Workarounds from the original pytorch issue:
31+
# https://github.com/pytorch/pytorch/issues/115333#issuecomment-1848449687
32+
self._device = torch.get_default_device() if _TORCH_GREATER_EQUAL_2_3 else torch.empty(0).device
2933

3034
@property
3135
def dtype(self) -> Union[str, torch.dtype]:

tests/tests_fabric/utilities/test_device_dtype_mixin.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
from torch import nn as nn
44

5+
from lightning.fabric.plugins.precision.utils import _DtypeContextManager
56
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
67
from tests_fabric.helpers.runif import RunIf
78

@@ -50,6 +51,30 @@ def test_submodules_device_and_dtype(dst_device_str, dst_type):
5051
assert model.dtype == model.module.module.dtype == dst_type
5152

5253

54+
@pytest.mark.parametrize(
55+
"dst_device_str",
56+
[
57+
"cpu",
58+
pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
59+
pytest.param("mps:0", marks=RunIf(mps=True)),
60+
],
61+
)
62+
@pytest.mark.parametrize(
63+
"dst_type",
64+
[
65+
torch.float,
66+
pytest.param(torch.half, marks=RunIf(mps=False)),
67+
pytest.param(torch.double, marks=RunIf(mps=False)),
68+
],
69+
)
70+
def test_submodules_context_device_and_dtype(dst_device_str, dst_type):
71+
dst_device = torch.device(dst_device_str)
72+
with _DtypeContextManager(dst_type), dst_device:
73+
model = TopModule()
74+
assert model.device == dst_device
75+
assert model.dtype == dst_type
76+
77+
5378
@pytest.mark.parametrize(
5479
"device",
5580
[

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,6 +2107,22 @@ def test_init_module_context(monkeypatch):
21072107
strategy.tensor_init_context.reset_mock()
21082108

21092109

2110+
@pytest.mark.parametrize(
2111+
("target_device", "accelerator", "devices"),
2112+
[
2113+
("cpu", "cpu", "auto"),
2114+
pytest.param("cuda:0", "gpu", [0], marks=RunIf(min_cuda_gpus=1)),
2115+
pytest.param("cuda:1", "gpu", [1], marks=RunIf(min_cuda_gpus=2)),
2116+
],
2117+
)
2118+
def test_init_module_device_type(target_device, accelerator, devices):
2119+
"""Test that the strategy returns the context manager for initializing the module."""
2120+
trainer = Trainer(accelerator=accelerator, devices=devices)
2121+
with trainer.init_module():
2122+
model = BoringModel()
2123+
assert model.device == torch.device(target_device)
2124+
2125+
21102126
def test_expand_home_trainer():
21112127
"""Test that the dirpath gets expanded if it contains `~`."""
21122128
home_root = Path.home()

0 commit comments

Comments
 (0)