9 files changed, +39 -10 lines changed

from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/aac_audiocaps/model/slam_model_aac.py:model_factory"
@@ -114,7 +118,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
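The substantive change in each of these configs is the annotation on `sharding_strategy`: it moves from `str` to the FSDP `ShardingStrategy` enum while the default stays the string `"NO_SHARD"`. Assuming the dataclasses are materialized through OmegaConf/Hydra structured configs (the training entry point in the last diff takes a `DictConfig`), an enum-typed field lets the config system validate a member name given as a string and coerce it to the enum, which would explain why the manual `getattr` conversion is commented out in the training script below. A minimal sketch of that coercion, assuming `omegaconf` is available; this is not code from the PR:

```python
from dataclasses import dataclass

from omegaconf import OmegaConf
from torch.distributed.fsdp import ShardingStrategy


@dataclass
class FSDPConfigSketch:
    # Hypothetical stand-in for the FSDPConfig field touched in this PR.
    sharding_strategy: ShardingStrategy = ShardingStrategy.NO_SHARD


cfg = OmegaConf.structured(FSDPConfigSketch)

# A value arriving from YAML or the CLI as a plain string is validated against
# the enum and converted to the corresponding member on assignment.
cfg.sharding_strategy = "FULL_SHARD"
assert cfg.sharding_strategy is ShardingStrategy.FULL_SHARD
```

If the dataclass is ever used directly, without going through such a config system, the string default would reach FSDP unconverted, so the name-to-enum step still has to happen somewhere.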

from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/asr_librispeech/model/slam_model_asr.py:model_factory"
@@ -108,7 +112,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False

from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/drcap_zeroshot_aac/model/slam_model_drcap.py:model_factory"
@@ -113,7 +117,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False

from dataclasses import dataclass, field
from typing import Optional, List
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/mala_asr_slidespeech/model/slam_model_mala_asr.py:model_factory"
@@ -109,7 +112,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False

from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
@dataclass
class ModelConfig:
    file: str = "examples/mc_musiccaps/model/slam_model_mir.py:model_factory"
@@ -112,7 +115,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False

from dataclasses import dataclass, field
from typing import Optional, List

+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/seld_spatialsoundqa/model/slam_model_seld.py:model_factory"
@@ -97,7 +100,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False

from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    llm_name: str = "vallex"
@@ -68,7 +72,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False

from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/vsr_LRS3/model/slam_model_vsr.py:model_factory"
@@ -115,7 +119,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False

@@ -159,8 +159,8 @@ def main(kwargs: DictConfig):
        if not train_config.use_peft and train_config.freeze_layers:

            freeze_transformer_layers(train_config.num_freeze_layers)
-        from torch.distributed.fsdp import ShardingStrategy
-        fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)
+        # from torch.distributed.fsdp import ShardingStrategy
+        # fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)
        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
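With the conversion above commented out, whatever sits in `fsdp_config.sharding_strategy` is handed to FSDP as-is, and FSDP expects a `ShardingStrategy` member rather than the string `"NO_SHARD"`. The removed lines performed that name-to-enum lookup; a small self-contained equivalent is sketched below as a hypothetical helper, not code from this PR:

```python
from torch.distributed.fsdp import ShardingStrategy


def resolve_sharding_strategy(value):
    """Return a ShardingStrategy member, accepting either the member itself or its name.

    Hypothetical helper mirroring the conversion commented out above; a config
    system such as OmegaConf can perform the same coercion when the field is
    annotated with the enum type.
    """
    if isinstance(value, ShardingStrategy):
        return value
    return ShardingStrategy[value]  # e.g. "NO_SHARD" -> ShardingStrategy.NO_SHARD


assert resolve_sharding_strategy("NO_SHARD") is ShardingStrategy.NO_SHARD
assert resolve_sharding_strategy(ShardingStrategy.FULL_SHARD) is ShardingStrategy.FULL_SHARD
```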