@@ -1051,47 +1051,46 @@ def prepare_context_mla_with_cached_kv(self,
10511051 def update_spec_dec_param (
10521052 self ,
10531053 is_spec_decoding_enabled ,
1054- is_spec_dec_tree ,
1055- is_spec_dec_dynamic_tree ,
1056- max_draft_tokens ,
1054+ spec_metadata ,
1055+ spec_tree_manager ,
1056+ max_draft_len ,
1057+ max_total_draft_tokens ,
10571058 spec_decoding_tensor : Optional ['SpecDecodingTensor' ] = None ,
10581059 ):
10591060
10601061 if spec_decoding_tensor is not None :
1061- spec_decoding_position_offsets = spec_decoding_tensor .position_offsets
1062- spec_decoding_packed_mask = spec_decoding_tensor .packed_mask
1063- spec_decoding_generation_lengths = spec_decoding_tensor .generation_lengths
1062+ spec_decoding_tensor .position_offsets
1063+ spec_decoding_tensor .packed_mask
1064+ spec_decoding_tensor .generation_lengths
10641065 else :
1065- spec_decoding_position_offsets = None
1066- spec_decoding_packed_mask = None
1067- spec_decoding_generation_lengths = None
1066+ pass
10681067 # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
10691068 self .is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version (
10701069 ) < 100
10711070
1071+ self .is_spec_dec_tree = False if spec_tree_manager is None else True
1072+ self .is_spec_dec_dynamic_tree = False if spec_tree_manager is None else spec_tree_manager .use_dynamic_tree
1073+
10721074 if get_sm_version () >= 100 :
1073- if is_spec_dec_tree or is_spec_dec_dynamic_tree :
1074- assert not is_spec_dec_tree , "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
1075- assert not is_spec_dec_dynamic_tree , "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
1075+ if self . is_spec_dec_tree or self . is_spec_dec_dynamic_tree :
1076+ assert not self . is_spec_dec_tree , "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
1077+ assert not self . is_spec_dec_dynamic_tree , "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
10761078
10771079 # use_spec_decoding defaults to is_spec_decoding_enabled; it may be changed at runtime by layers / requests
10781080 self .use_spec_decoding = self .is_spec_decoding_enabled
10791081
1080- self .is_spec_dec_tree = is_spec_dec_tree
1081- self .is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
1082-
10831082 # Parameters can be fixed and not changed during runtime if the
10841083 if self .is_spec_decoding_enabled :
10851084 self .spec_decoding_position_offsets = torch .empty (
1086- [self .max_num_requests , max_draft_tokens + 1 ],
1085+ [self .max_num_requests , max_total_draft_tokens + 1 ],
10871086 dtype = torch .int ,
10881087 device = 'cuda' ,
10891088 )
10901089
10911090 self .spec_decoding_packed_mask = torch .empty (
10921091 [
1093- self .max_num_requests , max_draft_tokens + 1 ,
1094- math .ceil ((max_draft_tokens + 1 ) / 32 )
1092+ self .max_num_requests , max_total_draft_tokens + 1 ,
1093+ math .ceil ((max_total_draft_tokens + 1 ) / 32 )
10951094 ],
10961095 dtype = torch .int ,
10971096 device = 'cuda' ,
@@ -1103,30 +1102,41 @@ def update_spec_dec_param(
11031102 device = 'cuda' ,
11041103 )
11051104
1106- if self .is_spec_dec_dynamic_tree :
1107- assert spec_decoding_position_offsets is not None , "spec_decoding_position_offsets is required for dynamic tree"
1108- assert spec_decoding_packed_mask is not None , "spec_decoding_packed_mask is required for dynamic tree"
1109- self .spec_decoding_position_offsets .copy_ (
1110- spec_decoding_position_offsets , non_blocking = True )
1111- self .spec_decoding_packed_mask .copy_ (spec_decoding_packed_mask ,
1112- non_blocking = True )
1113- if spec_decoding_generation_lengths is not None :
1114- self .spec_decoding_generation_lengths .copy_ (
1115- spec_decoding_generation_lengths , non_blocking = True )
1105+ # Prepare the spec-dec mask, position offset and generation length for a static tree or a dynamic tree.
1106+ # We only prepare the spec-dec mask, position offset and generation length for the target model here.
1107+ # For the drafter model, we will prepare them in the drafting loops.
1108+ is_target_model = not spec_metadata .is_draft_model
1109+ is_using_tree = self .is_spec_dec_tree or self .is_spec_dec_dynamic_tree
1110+ if is_target_model and is_using_tree :
1111+ assert spec_metadata .spec_dec_mode .is_eagle3 (
1112+ ), "Tree decoding is only supported for Eagle3 now"
1113+ # If is the dynamic tree
1114+ if self .is_spec_dec_dynamic_tree :
1115+ # TODO: add dynamic tree logic
1116+ assert False , "Dynamic tree is not supported yet"
1117+ # If is the static tree
11161118 else :
1117- self .generate_spec_decoding_generation_length (
1118- max_draft_tokens = max_draft_tokens )
1119+ self .spec_decoding_position_offsets [
1120+ :,
1121+ ].copy_ (spec_tree_manager .spec_dec_position_offsets [0 , :],
1122+ non_blocking = True )
1123+ self .spec_decoding_packed_mask [:, :, :].copy_ (
1124+ spec_tree_manager .spec_dec_packed_mask [0 , :, :],
1125+ non_blocking = True )
1126+ self .spec_decoding_generation_lengths [:].fill_ (
1127+ spec_tree_manager .max_total_draft_tokens + 1 )
11191128 else :
1129+ # Prepare for the linear-tree.
11201130 # Populate the mask that won't change during inference phase.
11211131 self .generate_spec_decoding_position_offsets (
1122- max_draft_tokens = max_draft_tokens )
1132+ max_total_draft_tokens = max_total_draft_tokens )
11231133 self .generate_spec_decoding_packed_mask (
1124- max_draft_tokens = max_draft_tokens )
1134+ max_total_draft_tokens = max_total_draft_tokens )
11251135 self .generate_spec_decoding_generation_length (
1126- max_draft_tokens = max_draft_tokens )
1136+ max_total_draft_tokens = max_total_draft_tokens )
11271137
1128- def generate_spec_decoding_position_offsets (self , max_draft_tokens ):
1129- position_offset = torch .arange (max_draft_tokens + 1 ,
1138+ def generate_spec_decoding_position_offsets (self , max_total_draft_tokens ):
1139+ position_offset = torch .arange (max_total_draft_tokens + 1 ,
11301140 dtype = torch .int ,
11311141 device = 'cpu' ,
11321142 pin_memory = True )
@@ -1135,15 +1145,17 @@ def generate_spec_decoding_position_offsets(self, max_draft_tokens):
11351145 self .spec_decoding_position_offsets .copy_ (position_offset ,
11361146 non_blocking = True )
11371147
def generate_spec_decoding_packed_mask(self, max_total_draft_tokens):
    """Write the packed mask for a linear draft chain into mask word 0.

    Position ``i`` (``0 <= i <= max_total_draft_tokens``) receives the value
    ``2**(i + 1) - 1``, i.e. the ``i + 1`` lowest bits set. The resulting
    1-D vector is copied into ``self.spec_decoding_packed_mask[:, :, 0]``.

    Args:
        max_total_draft_tokens: Draft-token count; must be below 32.

    Raises:
        AssertionError: If ``max_total_draft_tokens`` is 32 or more.
    """
    # TODO: fix this limitation
    assert max_total_draft_tokens < 32, "max_total_draft_tokens should be less than 32, will be fixed later"
    # Exponents 1..max_total_draft_tokens + 1, one per position.
    bit_counts = torch.arange(1, max_total_draft_tokens + 2)
    low_bits_set = torch.pow(2, bit_counts) - 1
    # Only word 0 of the last dimension is written here.
    self.spec_decoding_packed_mask[:, :, 0].copy_(low_bits_set,
                                                  non_blocking=True)
11431155
def generate_spec_decoding_generation_length(self, max_total_draft_tokens):
    """Set each request's generation length to ``max_total_draft_tokens + 1``.

    Builds a vector of ``self.max_num_requests`` identical entries and copies
    it into the first ``self.max_num_requests`` slots of
    ``self.spec_decoding_generation_lengths``.

    Args:
        max_total_draft_tokens: Draft-token count; one is added to it to get
            the stored length.
    """
    num_requests = self.max_num_requests
    tokens_per_step = max_total_draft_tokens + 1
    uniform_lengths = torch.full((num_requests, ), tokens_per_step)
    target_slots = self.spec_decoding_generation_lengths[:num_requests]
    target_slots.copy_(uniform_lengths, non_blocking=True)
11491161
0 commit comments