@@ -206,23 +206,19 @@ def test_prepare_and_convert_on_llm(self, force_not_import_ipex):
         model_name = "facebook/opt-125m"
         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        model_config = model.config
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
         # example_inputs = (input_ids,)
-        # model = export(model, example_inputs=example_inputs)
+        # model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)
+        attention_mask = inputs.attention_mask
+        input_ids = inputs.input_ids
+
+
+        from transformers.integrations.executorch import export_with_dynamic_cache
         from transformers import DynamicCache
-        example_inputs = {
-            "input_ids": input_ids,
-            "attention_mask": None,
-            "past_key_values": DynamicCache(),
-            "use_cache": True,
-        }
-        with torch.no_grad():
-            ep = torch.export.export_for_training(
-                model,
-                (),
-                example_inputs,
-                strict=False,
-            )
+        ep = export_with_dynamic_cache(model, input_ids, attention_mask)
         model = ep.module()
         model._exported = True
         model.dynamic_shapes = None
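Aside: the removed block built `example_inputs` by hand, passed `attention_mask=None`, and called `torch.export.export_for_training` directly; the new code delegates all of that to `export_with_dynamic_cache` from `transformers.integrations.executorch`. A rough sketch of what the helper stands in for, reconstructed from the deleted lines above; the actual transformers internals may differ by version:

```python
# Sketch only: approximates the hand-rolled export the diff deletes.
# Not the real transformers implementation of export_with_dynamic_cache.
import torch
from transformers import DynamicCache


def export_with_dynamic_cache_sketch(model, input_ids, attention_mask):
    example_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,  # the deleted code passed None here
        "past_key_values": DynamicCache(config=model.config),
        "use_cache": True,
    }
    with torch.no_grad():
        # Same export entry point the deleted code called.
        return torch.export.export_for_training(model, (), example_kwargs, strict=False)
```

Letting transformers own the cache wiring means the test no longer has to track changes in the export API or in how `DynamicCache` expects to be constructed.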
@@ -232,15 +228,25 @@ def test_prepare_and_convert_on_llm(self, force_not_import_ipex):
         prepare_model = prepare(model, quant_config)
         # calibrate
         for i in range(2):
-            prepare_model(**example_inputs)
+            prepare_model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                past_key_values=DynamicCache(config=model_config),
+                use_cache=True,
+            )
         # convert
         converted_model = convert(prepare_model)
         # inference
         from torch._inductor import config

         config.freezing = True
         opt_model = torch.compile(converted_model)
-        out = opt_model(**example_inputs)
+        out = opt_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=DynamicCache(config=model_config),
+            use_cache=True,
+        )
         assert out.logits is not None

     @staticmethod
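Put together, the post-diff calibration flow looks roughly like the standalone sketch below. Assumptions: `prepare`, `convert`, and `get_default_static_config` come from Intel Neural Compressor's `neural_compressor.torch.quantization` (the test's actual `quant_config` is built outside the shown hunks), and `export_with_dynamic_cache` plus `DynamicCache(config=...)` require a transformers release that ships `transformers.integrations.executorch`.

```python
# Standalone sketch of the updated test flow. The neural_compressor import
# path and get_default_static_config are assumptions; the diff does not show
# where quant_config is constructed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
from transformers.integrations.executorch import export_with_dynamic_cache

from neural_compressor.torch.quantization import convert, get_default_static_config, prepare

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name)
model_config = model.config
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# Export with a DynamicCache so the KV cache is captured in the traced graph.
ep = export_with_dynamic_cache(model, inputs.input_ids, inputs.attention_mask)
model = ep.module()
model._exported = True  # flags the PT2E path, as in the test
model.dynamic_shapes = None

quant_config = get_default_static_config()  # assumed; defined outside the hunks
prepare_model = prepare(model, quant_config)

# Calibrate: each pass gets a fresh DynamicCache built from the model config.
for _ in range(2):
    prepare_model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        past_key_values=DynamicCache(config=model_config),
        use_cache=True,
    )

converted_model = convert(prepare_model)

# Inference through torch.compile with inductor freezing enabled.
from torch._inductor import config

config.freezing = True
opt_model = torch.compile(converted_model)
out = opt_model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    past_key_values=DynamicCache(config=model_config),
    use_cache=True,
)
assert out.logits is not None
```

Constructing a fresh `DynamicCache` per call matters: reusing one cache across calibration passes would keep appending KV entries, so the observers would see shapes that drift from the single-prompt case the test intends to measure.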