Skip to content

Commit 26038d3

Browse files
committed
Update run utils
Signed-off-by: Mamta Singh <[email protected]>
1 parent b400ff2 commit 26038d3

File tree

1 file changed

+1
-10
lines changed

1 file changed

+1
-10
lines changed

QEfficient/utils/run_utils.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,7 @@ def run_hf_model_on_pytorch(self, model_hf):
103103
:numpy.ndarray: Generated output tokens
104104
"""
105105
input_ids = self.input_handler.tokenizer.encode(self.input_handler.prompt[0], return_tensors="pt")
106-
107-
input_ids_len = len(input_ids[0])
108-
109-
for _ in range(self.gen_len):
110-
outputs = model_hf(input_ids)
111-
logits = outputs.logits[:, -1, :]
112-
predicted_token_id = torch.argmax(logits, dim=-1)
113-
input_ids = torch.cat([input_ids, predicted_token_id.unsqueeze(1)], dim=-1)
114-
115-
generated_ids = input_ids[0][input_ids_len:].detach().numpy()
106+
generated_ids = model_hf.generate(input_ids, max_new_tokens=self.gen_len, do_sample=False)[0][len(input_ids[0]):]
116107
generated_text = self.input_handler.tokenizer.decode(generated_ids, skip_special_tokens=True)
117108
print("Original HF Model Outputs (Torch CPU): \n")
118109
print("Prompt:", repr(self.input_handler.prompt))

0 commit comments

Comments
 (0)