@@ -1889,9 +1889,6 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
 
         for req in scheduled_batch.generation_requests:
             if req.is_disagg_generation_transmission_complete:
-                print(
-                    "[PyExecutor::_prepare_disagg_gen_transmission_complete]: TRANSMISSION COMPLETE for request ID: ",
-                    req.py_request_id)
                 req.state = LlmRequestState.GENERATION_IN_PROGRESS
                 req.context_current_position = req.prompt_len
                 req.decoding_iter = 1
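Note on the hunk above: the deleted print() wrote to stdout on every request whose disaggregated-generation transmission completed. If that trace is ever wanted again, a minimal sketch routed through a module logger (the `logger` name is an assumption here, mirroring the logger.info referenced in the commented-out code removed further down) would be:

    # Hypothetical replacement, not part of this commit: with %-style args
    # the message is only formatted when debug logging is enabled.
    logger.debug(
        "[PyExecutor::_prepare_disagg_gen_transmission_complete]: "
        "TRANSMISSION COMPLETE for request ID: %s", req.py_request_id)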
@@ -1903,9 +1900,6 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
                 beam_width = req.sampling_config.beam_width
 
                 for beam in range(0, beam_width):
-                    print(
-                        f"[PyExecutor::_prepare_disagg_gen_transmission_complete]: Adding new token {torch.tensor(first_gen_tokens[beam])} for beam {beam}."
-                    )
                     req.add_new_token(first_gen_tokens[beam], beam)
 
     @nvtx_range("_recv_disagg_gen_cache")
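Similarly for the hunk above: the deleted print() built torch.tensor(first_gen_tokens[beam]) purely to format the message, once per beam per completed request. A hedged sketch of a cheaper trace, should one be needed (same `logger` assumption as above):

    # Hypothetical, not part of this commit: lazy %-formatting skips both
    # the string build and any tensor conversion unless debug is enabled.
    logger.debug("Adding new token %s for beam %d", first_gen_tokens[beam], beam)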
@@ -2001,24 +1995,12 @@ def _forward_step(self,
         )
         def forward(scheduled_requests, resource_manager, new_tensors_device,
                     gather_context_logits, cache_indirection_buffer):
-            # iter_begin = time.time()
-            result = self.model_engine.forward(
+            return self.model_engine.forward(
                 scheduled_requests,
                 resource_manager,
                 new_tensors_device,
                 gather_context_logits=gather_context_logits,
                 cache_indirection_buffer=cache_indirection_buffer)
-            # torch.cuda.synchronize()
-            # iter_end = time.time()
-            # iter_latency_ms = (iter_end - iter_begin) * 1e3
-            # if self.model_engine.iter_counter > 10 and self.dist.rank == 0:
-            #     logger.info(f"[PyExecutor::_forward_step] CUSTOM LOG: iter={self.model_engine.iter_counter}, "
-            #                 f"rank={self.dist.rank}, "
-            #                 f"active_requests={len(self.active_requests)}, "
-            #                 f"scheduled_generation_requests={len(scheduled_requests.generation_requests)}, "
-            #                 f"scheduled_batch_size={scheduled_requests.batch_size}, "
-            #                 f"iter_latency_ms={iter_latency_ms}ms")
-            return result
 
         try:
             gather_context_logits = any(
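Note on the hunk above: besides collapsing `result = ...` / `return result` into a direct return, it drops the commented-out per-iteration latency instrumentation. A self-contained sketch of what those comments measured, assuming a CUDA-backed engine exposing the same forward(...) call (the helper name is hypothetical):

    import time

    import torch

    def timed_forward(model_engine, *args, **kwargs):
        # Hypothetical helper mirroring the deleted comments: synchronize
        # after forward so the wall-clock interval also covers the GPU
        # work the step launched asynchronously.
        begin = time.time()
        result = model_engine.forward(*args, **kwargs)
        torch.cuda.synchronize()
        latency_ms = (time.time() - begin) * 1e3
        return result, latency_ms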
@@ -2085,8 +2067,7 @@ def _update_request_states_star_attention(
     @nvtx_range("_update_request_states")
     def _update_request_states(self, scheduled_requests: ScheduledRequests):
         cp_config = self.dist.cp_config
-        # note: helix parallelism uses the same logic as tp parallelism here
-        if 'cp_type' in cp_config and cp_config['cp_type'] != CpType.HELIX:
+        if 'cp_type' in cp_config:
             cp_type = cp_config['cp_type']
             if cp_type == CpType.STAR:
                 self._update_request_states_star_attention(scheduled_requests)
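Behavioral note on the final hunk: per the deleted comment, a configured HELIX cp_type previously fell through to the TP-style update path; with the simplified condition it now enters the cp_type dispatch (how HELIX is handled inside that dispatch is outside this hunk). The condition flip, illustrated:

    cp_config = {'cp_type': CpType.HELIX}
    # before: 'cp_type' in cp_config and cp_config['cp_type'] != CpType.HELIX  -> False
    # after:  'cp_type' in cp_config                                           -> True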