@@ -28,6 +28,7 @@ def __init__(
         device=torch.device("cpu"),
         device_gpt=torch.device("cpu"),
         logger=logging.getLogger(__name__),
+        enable_cache=False,
     ):
         super().__init__()

@@ -36,6 +37,8 @@ def __init__(
         self.device = device
         self.device_gpt = device_gpt

+        self.enable_cache = enable_cache
+
         self.generator = torch.Generator(device=device)

         self.num_vq = int(gpt_config["num_vq"])
@@ -142,7 +145,6 @@ def prepare(self, compile=False):
     class _GenerationInputs:
         position_ids: torch.Tensor
         cache_position: torch.Tensor
-        use_cache: bool
         input_ids: Optional[torch.Tensor] = None
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
         attention_mask: Optional[torch.Tensor] = None
@@ -167,7 +169,6 @@ def _prepare_generation_inputs(
         inputs_embeds: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        use_cache=True,
     ) -> _GenerationInputs:
         # With static cache, the `past_key_values` is None
         # TODO joao: standardize interface for the different Cache classes and remove of this if
@@ -230,8 +231,7 @@ def _prepare_generation_inputs(
                 and attention_mask is not None
                 and cache_length + input_ids.shape[1] > max_cache_length
             ):
-                start_pos = attention_mask.shape[1] - max_cache_length
-                attention_mask = attention_mask.narrow(1, start_pos, max_cache_length)
+                attention_mask = attention_mask.narrow(1, -max_cache_length, max_cache_length)

         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
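Note on the hunk above: the one-liner relies on torch.Tensor.narrow accepting a negative start, which counts back from the end of the dimension, so it keeps the same trailing max_cache_length columns that the removed two-line version selected. A small self-contained check (the shapes here are made up purely for illustration):

import torch

attention_mask = torch.ones(2, 10)   # hypothetical mask: batch of 2, sequence length 10
max_cache_length = 6                 # hypothetical cache budget

# Removed form: compute the starting column explicitly.
start_pos = attention_mask.shape[1] - max_cache_length
old = attention_mask.narrow(1, start_pos, max_cache_length)

# New form: a negative start indexes from the end of the dimension.
new = attention_mask.narrow(1, -max_cache_length, max_cache_length)

assert torch.equal(old, new)  # both keep the last max_cache_length columns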
@@ -258,7 +258,6 @@ def _prepare_generation_inputs(
         model_inputs = self._GenerationInputs(
             position_ids=position_ids,
             cache_position=cache_position,
-            use_cache=use_cache,
         )

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
@@ -399,7 +398,6 @@ def generate(
                 inputs_ids,
                 past_key_values,
                 attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
-                use_cache=not self.is_te_llama,
             )

             if i > 0:
@@ -423,7 +421,7 @@ def generate(
                 position_ids=model_input.position_ids,
                 past_key_values=model_input.past_key_values,
                 inputs_embeds=model_input.inputs_embeds,
-                use_cache=model_input.use_cache,
+                use_cache=not self.is_te_llama and self.enable_cache,
                 output_attentions=return_attn,
                 cache_position=model_input.cache_position,
             )
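Taken together, the commit stops threading use_cache through _GenerationInputs and instead derives it once per forward call from the new constructor flag. A minimal sketch of the resulting gating logic; the class below is a stand-in for illustration only, not the repository's API, though is_te_llama and enable_cache are the attribute names used in the diff:

class _CacheGateDemo:
    """Stand-in showing how use_cache is now resolved at the forward call."""

    def __init__(self, enable_cache=False, is_te_llama=False):
        self.enable_cache = enable_cache  # new flag added by this commit, off by default
        self.is_te_llama = is_te_llama    # the TE-Llama path already skipped the cache before

    def resolved_use_cache(self) -> bool:
        # Mirrors the expression now passed to the transformer forward call:
        #     use_cache = not self.is_te_llama and self.enable_cache
        return not self.is_te_llama and self.enable_cache


assert not _CacheGateDemo().resolved_use_cache()
assert _CacheGateDemo(enable_cache=True).resolved_use_cache()
assert not _CacheGateDemo(enable_cache=True, is_te_llama=True).resolved_use_cache()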