@@ -189,7 +189,9 @@ def __init__(
189189 num_blocks , max_batch_tokens = memory_handler .infer_num_blocks_and_max_batch_tokens (
190190 num_blocks = getattr (generation_config , "num_blocks" , None ),
191191 max_batch_tokens = getattr (generation_config , "max_batch_tokens" , None ),
192- max_memory_percent = getattr (generation_config , "max_memory" , 0.8 ),
192+ max_memory_percent = getattr (
193+ generation_config , "max_memory" , 0.8
194+ ), # FIXME: it seems we overcommit memory; lowered from 0.9, which caused OOMs in our benchmarking CI
193195 cache_dtype = self .dtype ,
194196 )
195197
@@ -414,7 +416,7 @@ def infer_num_blocks_and_max_batch_tokens(
414416 self ,
415417 num_blocks : Optional [int ] = None ,
416418 max_batch_tokens : Optional [int ] = None ,
417- max_memory_percent : float = 0.9 ,
419+ max_memory_percent : float = 0.8 , # FIXME: it seems we overcommit memory; lowered from 0.9, which caused OOMs in our benchmarking CI
418420 cache_dtype : torch .dtype = torch .float16 ,
419421 ) -> tuple [int , int ]:
420422 """Determine optimal number of blocks and maximum number of tokens per batch based on available memory and
@@ -454,7 +456,7 @@ def infer_num_blocks_and_max_batch_tokens(
454456
455457 def compute_num_blocks_and_max_batch_tokens (
456458 self ,
457- max_memory_percent : float = 0.9 ,
459+ max_memory_percent : float ,
458460 cache_dtype : torch .dtype = torch .float16 ,
459461 m : float = 0.01 ,
460462 ) -> tuple [int , int ]:
@@ -503,7 +505,7 @@ def compute_num_blocks_and_max_batch_tokens(
503505 def compute_max_batch_tokens (
504506 self ,
505507 num_blocks : int ,
506- max_memory_percent : float = 0.9 ,
508+ max_memory_percent : float ,
507509 cache_dtype : torch .dtype = torch .float16 ,
508510 ) -> int :
509511 """Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
@@ -531,7 +533,7 @@ def compute_max_batch_tokens(
531533 def compute_num_blocks (
532534 self ,
533535 max_batch_tokens : int ,
534- max_memory_percent : float = 0.9 ,
536+ max_memory_percent : float ,
535537 cache_dtype : torch .dtype = torch .float16 ,
536538 ) -> int :
537539 """Calculate number of cache blocks N given a fixed maximum token per token M. The formula for N is given by:
0 commit comments