Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def __init__(
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
self.use_fp8 = use_fp8
self.param_origin_shape = {}
for name, param in module.named_parameters():
self.param_origin_shape[name] = param.shape

shardformer = ShardFormer(shard_config)
if custom_policy is not None:
Expand Down
Loading
Loading