File tree: 1 file changed (+1 −10 lines changed)

@@ -103,16 +103,7 @@ def run_hf_model_on_pytorch(self, model_hf):
         :numpy.ndarray: Generated output tokens
         """
         input_ids = self.input_handler.tokenizer.encode(self.input_handler.prompt[0], return_tensors="pt")
-
-        input_ids_len = len(input_ids[0])
-
-        for _ in range(self.gen_len):
-            outputs = model_hf(input_ids)
-            logits = outputs.logits[:, -1, :]
-            predicted_token_id = torch.argmax(logits, dim=-1)
-            input_ids = torch.cat([input_ids, predicted_token_id.unsqueeze(1)], dim=-1)
-
-        generated_ids = input_ids[0][input_ids_len:].detach().numpy()
+        generated_ids = model_hf.generate(input_ids, max_new_tokens=self.gen_len, do_sample=False)[0][len(input_ids[0]):]

         generated_text = self.input_handler.tokenizer.decode(generated_ids, skip_special_tokens=True)
         print("Original HF Model Outputs (Torch CPU): \n ")
         print("Prompt:", repr(self.input_handler.prompt))
You can’t perform that action at this time.
0 commit comments