1 change: 1 addition & 0 deletions megatron/core/optimizer/optimizer.py
@@ -3,6 +3,7 @@
"""Megatron optimizer."""

import copy
import logging
import math
import warnings
from abc import ABC, abstractmethod
8 changes: 7 additions & 1 deletion megatron/core/utils.py
@@ -2088,13 +2088,19 @@ def maybe_cat(a, b, dim=0, *, required=False):
return xs[0] if len(xs) == 1 else torch.cat(xs, dim=dim)


_ASYNC_IO_LOOP: asyncio.AbstractEventLoop | None = None


def get_asyncio_loop(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.AbstractEventLoop:
"""Creates an asyncio loop if necessary and then returns the current asyncio loop."""
global _ASYNC_IO_LOOP
if loop is None:
if _ASYNC_IO_LOOP is not None:
return _ASYNC_IO_LOOP
try:
loop = asyncio.get_running_loop()
except RuntimeError as e:
loop = asyncio.new_event_loop()
_ASYNC_IO_LOOP = loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
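
As a reference for the caching behavior added above, here is a minimal usage sketch. It assumes `megatron.core.utils` is importable; the `_ping` coroutine and `run_some_coroutine` helper are hypothetical names invented for illustration and are not part of this change.

```python
import asyncio

from megatron.core.utils import get_asyncio_loop


async def _ping():
    # Hypothetical coroutine, used only to exercise the loop in this sketch.
    await asyncio.sleep(0)
    return "pong"


def run_some_coroutine():
    # When called outside a running event loop, get_asyncio_loop() now creates a
    # loop once, caches it in the module-level _ASYNC_IO_LOOP, and hands back the
    # same object on every later call instead of building a fresh loop each time.
    loop = get_asyncio_loop()
    assert loop is get_asyncio_loop()  # second call returns the cached loop
    return loop.run_until_complete(_ping())
```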

84 changes: 48 additions & 36 deletions megatron/rl/inference/megatron.py
@@ -5,10 +5,11 @@
from argparse import Namespace

from pydantic import PrivateAttr
import torch.distributed as dist

from megatron.core import parallel_state
from megatron.core.inference.inference_client import InferenceClient
from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
from megatron.core.inference.coordinator import DynamicEngineCoordinator
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
@@ -27,6 +28,7 @@
from megatron.core.transformer.module import MegatronModule
from megatron.core.utils import get_mamba_inference_state_config_from_model, log_single_rank
from megatron.training.global_vars import get_args, get_tokenizer
from megatron.training import get_wandb_writer

from ..inference.inference_interface import (
ChatInferenceInterface,
@@ -103,9 +105,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
"""
tokenizer = get_tokenizer()

num_cuda_graphs = None
if args.enable_cuda_graph:
num_cuda_graphs = args.inference_dynamic_batching_num_cuda_graphs
enable_cuda_graph = args.cuda_graph_impl == "local"

mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)

@@ -118,17 +118,23 @@
args.num_query_groups if args.group_query_attention else args.num_attention_heads
),
max_sequence_length=args.inference_max_seq_length,
num_cuda_graphs=num_cuda_graphs,
num_cuda_graphs=(
args.inference_dynamic_batching_num_cuda_graphs
if enable_cuda_graph
else None
),
block_size_tokens=args.inference_dynamic_batching_block_size,
buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
buffer_guaranteed_fraction=args.inference_dynamic_batching_buffer_guaranteed_fraction,
chunk_size_tokens=args.inference_dynamic_batching_chunk_size,
buffer_overflow_factor=args.inference_dynamic_batching_buffer_overflow_factor,
max_requests_override=args.inference_dynamic_batching_max_requests_override,
max_tokens_override=args.inference_dynamic_batching_max_tokens_override,
max_tokens=args.inference_dynamic_batching_max_tokens,
tensor_model_parallel_size=args.tensor_model_parallel_size,
materialize_only_last_token_logits=True,
unified_memory_kvcache=args.inference_dynamic_batching_unified_memory_kvcache,
mamba_inference_state_config=mamba_inference_state_config,
cache_mla_latent=args.multi_latent_attention and args.cache_mla_latents,
kv_lora_rank=args.kv_lora_rank if args.multi_latent_attention else None,
qk_pos_emb_head_dim=args.qk_pos_emb_head_dim,
use_cuda_graphs_for_non_decode_steps=not args.decode_only_cuda_graphs,
use_flashinfer_fused_rope=None,
unified_memory_level=args.inference_dynamic_batching_unified_memory_level,
metrics_writer=metrics_writer,
)

@@ -145,7 +151,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
return DynamicInferenceEngine(
controller=text_generation_controller,
context=inference_context,
enable_cuda_graph=args.enable_cuda_graph,
enable_cuda_graph=enable_cuda_graph,
random_seed=args.seed,
inference_logging_step_interval=inference_logging_step_interval,
)
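
To make the intent of the interleaved old and new lines above easier to follow, here is a compact restatement of the CUDA-graph gating after this change, using only argument names that appear in the diff. The `resolve_cuda_graph_settings` helper is a hypothetical name; treat this as a paraphrase of the change, not a copy of the final code.

```python
from argparse import Namespace


def resolve_cuda_graph_settings(args: Namespace) -> tuple[bool, int | None]:
    # Before this change, num_cuda_graphs was populated whenever
    # args.enable_cuda_graph was set. Now CUDA graphs are used only when the
    # "local" cuda-graph implementation is selected, and the count is gated
    # inline when the DynamicInferenceContext is built.
    enable_cuda_graph = args.cuda_graph_impl == "local"
    num_cuda_graphs = (
        args.inference_dynamic_batching_num_cuda_graphs if enable_cuda_graph else None
    )
    return enable_cuda_graph, num_cuda_graphs
```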
@@ -154,9 +160,8 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
class MegatronLocal(InferenceServer, ReturnsTokens, ReturnsRaw):
"""Interface to use MCoreEngine directly as an inference engine."""

_coordinator: DynamicEngineCoordinator = PrivateAttr(None)
_engine_task: asyncio.Task = PrivateAttr(None)
_kill_engine: bool = PrivateAttr(False)
_client: InferenceClient = PrivateAttr(None)
_inference_engine: DynamicInferenceEngine = PrivateAttr(None)

async def base_generate(self, request: InferenceRequest):

@@ -169,24 +174,27 @@ async def base_generate(self, request: InferenceRequest):
isinstance(p, str) for p in request.prompt
), "MegatronLocal only supports string prompts."

assert self._client is not None, "Client is not initialized"

tokenizer = get_tokenizer()

sampling_params = SamplingParams(
num_tokens_to_generate=request.generation_args.max_tokens or 1024,
num_tokens_to_generate=None,
num_tokens_total=request.generation_args.max_tokens,
temperature=request.generation_args.temperature or 1.0,
top_k=request.generation_args.top_k or 0,
top_p=request.generation_args.top_p or 0.0,
termination_id=self._coordinator.engine.controller.tokenizer.eod,
termination_id=self._inference_engine.controller.tokenizer.eod,
return_log_probs=True,
skip_prompt_log_probs=True,
add_BOS=tokenizer.bos is not None,
)
request_ids = [
self._coordinator.schedule_request(prompt=prompt, sampling_params=sampling_params)
requests = [
self._client.add_request(prompt=prompt, sampling_params=sampling_params)
for prompt in request.prompt
]
responses = await asyncio.gather(
*[self._coordinator.get_response(id) for id in request_ids]
*requests
)
return [
InferenceResponse(
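
The scheduling path above changed from coordinator calls to client calls. The sketch below isolates that flow using only the client methods visible in this diff; the `generate_batch` helper and its arguments are hypothetical names for illustration and assume an already started `InferenceClient`.

```python
import asyncio


async def generate_batch(client, prompts, sampling_params):
    # Each prompt is handed to the InferenceClient; add_request() returns an
    # awaitable response, replacing the previous schedule_request()/get_response()
    # pair on DynamicEngineCoordinator.
    pending = [
        client.add_request(prompt=prompt, sampling_params=sampling_params)
        for prompt in prompts
    ]
    # Await all responses concurrently, mirroring base_generate() above.
    return await asyncio.gather(*pending)
```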
@@ -224,28 +232,32 @@ async def launch(cls, model: GPTModel, **kwargs):
"wandb module is available. Inference logging will be disabled.")

inference_engine: DynamicInferenceEngine = get_dynamic_inference_engine(args, model, inference_logging_step_interval, metrics_writer)
coordinator = DynamicEngineCoordinator(
inference_engine,
inference_max_requests=inference_engine.context.max_requests,
log_level=0,
)
await inference_engine.start_listening_to_data_parallel_coordinator(inference_coordinator_port=41521, launch_inference_coordinator=True)
if dist.get_rank() == 0:
# TODO: We have to do this only on the rank 0 process, should be fixed in the future when we have support for multiple inference clients. !2278
client = InferenceClient(inference_coordinator_port=41521)
await client.start()
else:
client = None
launched_server = cls(**kwargs)
launched_server._coordinator = coordinator

loop = asyncio.get_running_loop()

coordinator.startup(loop)
launched_server._client = client
launched_server._inference_engine = inference_engine

return launched_server

async def kill(self):
await self._coordinator.shutdown()
if dist.get_rank() == 0:
await self._client.stop_engines()
await self._inference_engine.stopped.wait()

async def suspend(self):
await self._coordinator.suspend_engine()

def resume(self):
self._coordinator.resume_engine()

if dist.get_rank() == 0:
self._client.pause_engines()
#await self._inference_engine.paused.wait()
Contributor
remove?

Contributor
I believe we will need these uncommented in the final version. Right Teo?

async def resume(self):
if dist.get_rank() == 0:
self._client.unpause_engines()
#await self._inference_engine.running.wait()
Contributor
remove?

class MegatronChatLocal(ChatInferenceInterface, MegatronLocal): ...
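
Putting the lifecycle pieces from launch(), suspend(), resume(), and kill() together, here is a minimal sketch of the rank-0 client flow introduced in this change. It reuses only calls and the port number that appear in the diff; the `run_inference_session` wrapper is a hypothetical name for illustration, not part of the PR.

```python
import torch.distributed as dist

from megatron.core.inference.inference_client import InferenceClient


async def run_inference_session(inference_engine):
    # Attach the dynamic engine to the data-parallel coordinator, as in launch().
    await inference_engine.start_listening_to_data_parallel_coordinator(
        inference_coordinator_port=41521, launch_inference_coordinator=True
    )

    # Only rank 0 owns an InferenceClient for now (see the TODO referencing !2278).
    if dist.get_rank() == 0:
        client = InferenceClient(inference_coordinator_port=41521)
        await client.start()

        client.pause_engines()      # what MegatronLocal.suspend() issues
        client.unpause_engines()    # what MegatronLocal.resume() issues

        await client.stop_engines() # what MegatronLocal.kill() issues
        await inference_engine.stopped.wait()  # wait for shutdown, as in kill()
```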