@@ -29,11 +29,15 @@ def create(rank_id: int):
2929 rank_id , start , stop )
3030
3131 @staticmethod
32- def set_iter ( iter_id : int ) -> bool :
32+ def should_record ( ) -> bool :
3333 if ExpertStatistic .expert_statistic_obj is not None :
34- return ExpertStatistic .expert_statistic_obj ._set_iter (iter_id )
35- else :
36- return False
34+ return ExpertStatistic .expert_statistic_obj ._should_record
35+ return False
36+
37+ @staticmethod
38+ def set_iter (iter_id : int ) -> None :
39+ if ExpertStatistic .expert_statistic_obj is not None :
40+ ExpertStatistic .expert_statistic_obj ._set_iter (iter_id )
3741
3842 @staticmethod
3943 def set_layer (layer_id : int ) -> None :
@@ -57,10 +61,10 @@ def __init__(self, rank_id: int, start: int, stop: int) -> None:
5761 self ._records = {}
5862
5963 @property
60- def should_record (self ) -> bool :
64+ def _should_record (self ) -> bool :
6165 return self .current_iter_id is not None and self .start <= self .current_iter_id < self .stop
6266
63- def _set_iter (self , iter_id : int ) -> bool :
67+ def _set_iter (self , iter_id : int ) -> None :
6468 self .current_iter_id = iter_id
6569 if iter_id == self .stop :
6670 logger .info (
@@ -74,14 +78,13 @@ def _set_iter(self, iter_id: int) -> bool:
7478 json .dump (self ._meta_info , f )
7579 safetensors .torch .save_file (
7680 self ._records , f"{ path } /rank{ self .rank_id } .safetensors" )
77- return self .should_record
7881
7982 def _set_layer (self , layer : int ) -> None :
8083 self .current_layer = layer
8184
8285 def _maybe_add_info (self , expert_count : int ,
8386 token_selected_experts : torch .Tensor ) -> None :
84- if not self .should_record :
87+ if not self ._should_record :
8588 return
8689
8790 if self ._meta_info is None :
0 commit comments