
Commit 0045773

Merge pull request #153 from nuaalixu/main
fix #92 for fsdp training
2 parents 38d8c66 + 832bf02

File tree: 9 files changed (+39, -10 lines changed)

examples/aac_audiocaps/aac_config.py

Lines changed: 5 additions & 1 deletion

@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/aac_audiocaps/model/slam_model_aac.py:model_factory"
@@ -114,7 +118,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
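
Note on the typed field: these config dataclasses are consumed through Hydra/OmegaConf (the finetune.py entry point below takes a DictConfig), and OmegaConf coerces a string into the member of an Enum-typed field, so the "NO_SHARD" default still resolves to ShardingStrategy.NO_SHARD at load time. A minimal sketch of that behavior, assuming structured-config loading and using a trimmed-down, hypothetical stand-in for the repo's FSDPConfig:

from dataclasses import dataclass

from omegaconf import OmegaConf
from torch.distributed.fsdp import ShardingStrategy


# Hypothetical, trimmed-down stand-in for the FSDPConfig dataclasses changed above.
@dataclass
class FSDPConfig:
    # The string default is coerced to ShardingStrategy.NO_SHARD when the dataclass
    # is loaded as an OmegaConf structured config.
    sharding_strategy: ShardingStrategy = "NO_SHARD"
    mixed_precision: bool = True


cfg = OmegaConf.structured(FSDPConfig)
print(cfg.sharding_strategy)                                # ShardingStrategy.NO_SHARD
print(isinstance(cfg.sharding_strategy, ShardingStrategy))  # True

# Overrides are validated against the enum as well; an unknown name raises a ValidationError.
cfg.sharding_strategy = "FULL_SHARD"
print(cfg.sharding_strategy)                                # ShardingStrategy.FULL_SHARD

The same coercion applies to command-line overrides, which is presumably why the manual getattr conversion in finetune.py (see the last diff below) is no longer needed.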

examples/asr_librispeech/asr_config.py

Lines changed: 5 additions & 1 deletion

@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/asr_librispeech/model/slam_model_asr.py:model_factory"
@@ -108,7 +112,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False

examples/drcap_zeroshot_aac/drcap_config.py

Lines changed: 5 additions & 1 deletion

@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/drcap_zeroshot_aac/model/slam_model_drcap.py:model_factory"
@@ -113,7 +117,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False

examples/mala_asr_slidespeech/mala_asr_config.py

Lines changed: 4 additions & 1 deletion

@@ -1,5 +1,8 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/mala_asr_slidespeech/model/slam_model_mala_asr.py:model_factory"
@@ -109,7 +112,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False

examples/mc_musiccaps/mir_config.py

Lines changed: 4 additions & 1 deletion

@@ -1,5 +1,8 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
 @dataclass
 class ModelConfig:
     file: str = "examples/mc_musiccaps/model/slam_model_mir.py:model_factory"
@@ -112,7 +115,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False

examples/seld_spatialsoundqa/seld_config.py

Lines changed: 4 additions & 1 deletion

@@ -1,6 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
 
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/seld_spatialsoundqa/model/slam_model_seld.py:model_factory"
@@ -97,7 +100,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False

examples/vallex/vallex_config.py

Lines changed: 5 additions & 1 deletion

@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     llm_name: str = "vallex"
@@ -68,7 +72,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False

examples/vsr_LRS3/vsr_config.py

Lines changed: 5 additions & 1 deletion

@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/vsr_LRS3/model/slam_model_vsr.py:model_factory"
@@ -115,7 +119,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False

src/slam_llm/pipeline/finetune.py

Lines changed: 2 additions & 2 deletions

@@ -159,8 +159,8 @@ def main(kwargs: DictConfig):
         if not train_config.use_peft and train_config.freeze_layers:
 
             freeze_transformer_layers(train_config.num_freeze_layers)
-        from torch.distributed.fsdp import ShardingStrategy
-        fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)
+        # from torch.distributed.fsdp import ShardingStrategy
+        # fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
 
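
With sharding_strategy typed as ShardingStrategy in the config dataclasses above, the value reaching this point should already be an enum member (under the structured-config assumption sketched earlier), so the commented-out string-to-enum lookup is redundant and would actually fail if left in. A small illustration, not code from the repo:

from torch.distributed.fsdp import ShardingStrategy

# Old behaviour: the config value was a plain string, so the enum was looked up by name.
strategy = getattr(ShardingStrategy, "NO_SHARD")
print(strategy)  # ShardingStrategy.NO_SHARD

# New behaviour: the config value is already a ShardingStrategy member, so repeating the
# lookup would raise "TypeError: attribute name must be string":
# getattr(ShardingStrategy, ShardingStrategy.NO_SHARD)

FSDP's sharding_strategy argument expects a ShardingStrategy member, so the config value can now be passed through unchanged.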
