
Commit afeb173

refactor(cb): add fixme note on default safety margin value

1 parent: 7923cf6

1 file changed: 7 additions, 5 deletions
src/transformers/generation/continuous_batching/cache.py

@@ -189,7 +189,9 @@ def __init__(
         num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
             num_blocks=getattr(generation_config, "num_blocks", None),
             max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
-            max_memory_percent=getattr(generation_config, "max_memory", 0.8),
+            max_memory_percent=getattr(
+                generation_config, "max_memory", 0.8
+            ),  # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
             cache_dtype=self.dtype,
         )

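The call-site default matters because `generation_config` may simply lack the attribute. A minimal sketch of the `getattr` fallback pattern used above (the bare stand-in class is hypothetical, not the real transformers GenerationConfig):

    class Config:  # hypothetical stand-in for a generation config object
        pass

    config = Config()
    # Attribute absent -> the hard-coded 0.8 safety margin applies.
    print(getattr(config, "max_memory", 0.8))  # 0.8

    config.max_memory = 0.95
    # An explicit user setting overrides the default (and can reintroduce OOM risk).
    print(getattr(config, "max_memory", 0.8))  # 0.95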
@@ -414,7 +416,7 @@ def infer_num_blocks_and_max_batch_tokens(
         self,
         num_blocks: Optional[int] = None,
         max_batch_tokens: Optional[int] = None,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float = 0.8,  # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
         cache_dtype: torch.dtype = torch.float16,
     ) -> tuple[int, int]:
         """Determine optimal number of blocks and maximum number of tokens per batch based on available memory and
@@ -454,7 +456,7 @@ def infer_num_blocks_and_max_batch_tokens(
 
     def compute_num_blocks_and_max_batch_tokens(
         self,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float,
         cache_dtype: torch.dtype = torch.float16,
         m: float = 0.01,
     ) -> tuple[int, int]:
@@ -503,7 +505,7 @@ def compute_num_blocks_and_max_batch_tokens(
     def compute_max_batch_tokens(
         self,
         num_blocks: int,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float,
         cache_dtype: torch.dtype = torch.float16,
     ) -> int:
         """Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
@@ -531,7 +533,7 @@ def compute_max_batch_tokens(
     def compute_num_blocks(
         self,
         max_batch_tokens: int,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float,
         cache_dtype: torch.dtype = torch.float16,
     ) -> int:
         """Calculate number of cache blocks N given a fixed maximum number of tokens per batch M. The formula for N is given by:
