
Commit ba938fa

fix(ci): unexpected keyword argument streaming (#42102)

* debug(ci): run `pwd` to check what we're working with
* fix(ci): `ls -lR`
* fix(ci): remove working directory which should not be there?
* fix(cb): make sure memory is freed when calling `stop`
* fix(ci): effectively clear cache
* fix(ci): reduce memory safety margin
* refactor(cb): add fixme note on default safety margin value

1 parent 6744ebe commit ba938fa

File tree: 4 files changed (+9, -8 lines)

.github/workflows/benchmark.yml

Lines changed: 0 additions & 1 deletion

```diff
@@ -40,7 +40,6 @@ jobs:
         run: python3 -m pip install -r benchmark_v2/requirements.txt kernels

       - name: Reinstall transformers in edit mode (remove the one installed during docker image build)
-        working-directory: /transformers
         run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]"

       - name: Run benchmark
```

benchmark_v2/framework/benchmark_runner.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -117,8 +117,6 @@ def flush_memory():
     # Clear CUDA cache
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
     gc.collect()
```
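For reference, the helper after this change would look roughly like the following (a sketch assuming `torch` and `gc` are the module's relevant imports, as the diff context suggests). The dropped `reset_*` calls only zero the allocator's bookkeeping counters and free no memory, so removing them presumably keeps the peak-memory readings intact between flushes while `empty_cache()` still does the actual releasing:

```python
import gc
import torch

def flush_memory():
    # Clear CUDA cache: release unused cached blocks back to the driver.
    # The reset_max_memory_allocated()/reset_peak_memory_stats() calls were
    # dropped; they free nothing and would presumably wipe the peak-memory
    # counters the benchmark later reports.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()
```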

src/transformers/generation/continuous_batching/cache.py

Lines changed: 7 additions & 5 deletions

```diff
@@ -189,7 +189,9 @@ def __init__(
         num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
             num_blocks=getattr(generation_config, "num_blocks", None),
             max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
-            max_memory_percent=getattr(generation_config, "max_memory", 0.9),
+            max_memory_percent=getattr(
+                generation_config, "max_memory", 0.8
+            ),  # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
             cache_dtype=self.dtype,
         )

@@ -414,7 +416,7 @@ def infer_num_blocks_and_max_batch_tokens(
         self,
         num_blocks: Optional[int] = None,
         max_batch_tokens: Optional[int] = None,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float = 0.8,  # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
         cache_dtype: torch.dtype = torch.float16,
     ) -> tuple[int, int]:
         """Determine optimal number of blocks and maximum number of tokens per batch based on available memory and
@@ -454,7 +456,7 @@ def infer_num_blocks_and_max_batch_tokens(

     def compute_num_blocks_and_max_batch_tokens(
         self,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float,
         cache_dtype: torch.dtype = torch.float16,
         m: float = 0.01,
     ) -> tuple[int, int]:
@@ -503,7 +505,7 @@ def compute_num_blocks_and_max_batch_tokens(
     def compute_max_batch_tokens(
         self,
         num_blocks: int,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float,
         cache_dtype: torch.dtype = torch.float16,
     ) -> int:
         """Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
@@ -531,7 +533,7 @@ def compute_max_batch_tokens(
     def compute_num_blocks(
         self,
         max_batch_tokens: int,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float,
         cache_dtype: torch.dtype = torch.float16,
     ) -> int:
         """Calculate number of cache blocks N given a fixed maximum token per token M. The formula for N is given by:
```

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -826,6 +826,8 @@ def stop(self, block: bool = True, timeout: Optional[float] = None) -> None:
         if block:
             self.join(stop_trigger_time, timeout)

+        self.batch_processor = None
+
     def join(self, stop_trigger_time: float, timeout: Optional[float] = None) -> None:
         """Wait for the background thread to finish.
```
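This is the "make sure memory is freed when calling `stop`" fix: as long as the manager holds the last reference to the batch processor, every tensor it owns stays reachable and the allocator can never reclaim it. A minimal self-contained sketch of the principle (the `Manager` class here is illustrative, not the library's API):

```python
import gc
import torch

class Manager:
    """Illustrative stand-in for the continuous-batching manager."""

    def __init__(self) -> None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Stand-in for the batch processor that owns the paged KV cache.
        self.batch_processor = torch.empty(1024, 1024, device=device)

    def stop(self) -> None:
        # Without this, the manager keeps the processor (and its tensors)
        # reachable indefinitely, so their memory is never released.
        self.batch_processor = None

manager = Manager()
manager.stop()
gc.collect()  # the processor's tensors are now unreachable and collectable
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # return the freed cache blocks to the driver
```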
