diff --git a/LICENSE b/LICENSE index 8ebf500497..96bbfb9f76 100644 --- a/LICENSE +++ b/LICENSE @@ -270,3 +270,34 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +-------------------------------------------------------------------------------- +LICENSE FOR +Meta Platforms, Inc. and affiliates. + +BSD License + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Meta nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py b/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py index 7869002fff..cf295b5b7e 100644 --- a/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py +++ b/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py @@ -201,10 +201,6 @@ async def main( args_defaults={'no_load_rng': True, 'no_load_optim': True}, ) - # Start Nsight profiler. - if os.environ.get("NSIGHT_PREFIX"): - torch.cuda.cudart().cudaProfilerStart() - args = get_args() tokenizer = get_tokenizer() @@ -251,6 +247,10 @@ async def main( print(setup_prefix) print("~~~") + # Start Nsight profiler. + if os.environ.get("NSIGHT_PREFIX"): + torch.cuda.cudart().cudaProfilerStart() + asyncio.run( main( engine, @@ -258,3 +258,7 @@ async def main( args.inference_coordinator_port, ) ) + + # Stop Nsight profiler. 
+ if os.environ.get("NSIGHT_PREFIX"): + torch.cuda.cudart().cudaProfilerStop() diff --git a/mamba_builders.py b/mamba_builders.py index 0ccfc29b86..a39dfb84b3 100644 --- a/mamba_builders.py +++ b/mamba_builders.py @@ -6,7 +6,7 @@ from megatron.core.transformer.spec_utils import import_module from megatron.training import print_rank_0 from megatron.training.arguments import core_transformer_config_from_args - +from megatron.core.models.mamba.mamba_layer_specs import mamba_inference_stack_spec def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None): print_rank_0('building MAMBA model ...') @@ -14,7 +14,9 @@ def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None): config = core_transformer_config_from_args(args, TransformerConfig) assert args.use_legacy_models is False, "Mamba only supported in Mcore!" - if args.spec is not None: + if config.transformer_impl == "inference_optimized": + mamba_stack_spec = mamba_inference_stack_spec + elif args.spec is not None: mamba_stack_spec = import_module(args.spec) else: raise ValueError("You must provide a valid Mamba layer spec via --spec") diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 59082ea7dc..637092b468 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -170,6 +170,9 @@ def _calculate_cuda_graph_token_counts( # Make sure divisible by TP size cuda_graph_step_size = math.ceil(cuda_graph_step_size / tp_size) * tp_size + # round down cuda graph max tokens to be multiple of TP size + cuda_graph_max_tokens = (cuda_graph_max_tokens // tp_size) * tp_size + # Cuda graph token counts. if num_cuda_graphs == 1: cuda_graph_token_counts = [cuda_graph_max_tokens] diff --git a/megatron/core/inference/communication/torch_symm_triton/__init__.py b/megatron/core/inference/communication/torch_symm_triton/__init__.py new file mode 100644 index 0000000000..17e42a6776 --- /dev/null +++ b/megatron/core/inference/communication/torch_symm_triton/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +from .collectives import multimem_all_gather, multimem_reduce_scatter diff --git a/megatron/core/inference/communication/torch_symm_triton/barrier.py b/megatron/core/inference/communication/torch_symm_triton/barrier.py new file mode 100644 index 0000000000..d26b094828 --- /dev/null +++ b/megatron/core/inference/communication/torch_symm_triton/barrier.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
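# --- Illustrative sketch (not part of this patch) ---------------------------
# The batch_dimensions_utils.py hunk above rounds the CUDA-graph step size *up*
# to a multiple of the tensor-parallel size and the CUDA-graph max token count
# *down* to a multiple of the TP size, so every captured graph processes a
# token count that divides evenly across TP ranks. A minimal sketch of that
# arithmetic, with hypothetical input values:
import math

def round_token_counts(cuda_graph_step_size: int, cuda_graph_max_tokens: int, tp_size: int):
    # Round the step size up so each graph-size increment stays TP-divisible.
    step = math.ceil(cuda_graph_step_size / tp_size) * tp_size
    # Round the max token count down so the largest graph is also TP-divisible.
    max_tokens = (cuda_graph_max_tokens // tp_size) * tp_size
    return step, max_tokens

# Assumed example: step 100 -> 104 and max 1000 -> 1000 for tp_size=8.
print(round_token_counts(100, 1000, 8))  # (104, 1000)
# -----------------------------------------------------------------------------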
+ +# Adapted from: https://github.com/meta-pytorch/kraken.git + +from unittest.mock import MagicMock + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl +except ImportError: + triton = MagicMock() + tl = MagicMock() + triton.jit = null_decorator + +from .utils import get_flat_bid, get_flat_tid + + +@triton.jit +def _send_signal(addrs, sem: tl.constexpr): + tl.inline_asm_elementwise( + f""" + {{ + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + send_signal: + atom.global.{sem}.sys.cas.b32 %tmp32_0, [$1], 0, 1; + setp.eq.u32 %p0, %tmp32_0, 0; + @!%p0 bra send_signal; + }} + """, + "=r, l", + [addrs], + dtype=addrs.dtype, + is_pure=False, + pack=1, + ) + + +@triton.jit +def _wait_signal(addrs, sem: tl.constexpr): + tl.inline_asm_elementwise( + f""" + {{ + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + wait_signal: + atom.global.sys.{sem}.cas.b32 %tmp32_0, [$1], 1, 0; + setp.eq.u32 %p0, %tmp32_0, 1; + @!%p0 bra wait_signal; + }} + """, + "=r, l", + [addrs], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + + +@triton.jit +def symm_mem_sync( + signal_pad_ptrs, + block_id, + rank: tl.constexpr, + world_size: tl.constexpr, + hasPreviousMemAccess: tl.constexpr = False, + hasSubsequentMemAccess: tl.constexpr = False, +): + """ + Synchronizes blocks with matching block_id across participating devices. + + Note: the function itself is not a system level barrier/fence. It is a + building block for expressing different synchronization patterns. + + Pattern 0: Ensures that all writes to symm_mem buffers from previous + kernels across all devices are visible to the current kernel: + + symm_mem_sync(..., hasPreviousMemAccess=False, hasSubsequentMemAccess=True) + + Pattern 1: Ensures that all writes to symm_mem buffers from the current + block are visible to all remote blocks with matching blockIdx: + + symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=True) + + Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe + for writing by subsequent kernels across all devices. + + symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=False) + + CUDA graph friendliness: + + This barrier operates through atomic operations on a zero-filled signal + pad, which resets to a zero-filled state after each successful + synchronization. This design eliminates the need for incrementing a + flag from host. + """ + if block_id is None: + block_id = get_flat_bid() + flat_tid = get_flat_tid() + + remote_ranks = tl.arange(0, world_size) + signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64)) + remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(tl.pointer_type(tl.uint32)) + send_addrs = remote_signal_pad_addrs + block_id * world_size + rank + + local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(tl.pointer_type(tl.uint32)) + wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks + + if flat_tid < world_size: + _send_signal(send_addrs, "release" if hasPreviousMemAccess else "relaxed") + _wait_signal(wait_addrs, "acquire" if hasSubsequentMemAccess else "relaxed") diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py new file mode 100644 index 0000000000..4bc4dbde42 --- /dev/null +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -0,0 +1,231 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
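# --- Illustrative sketch (not part of this patch) ---------------------------
# symm_mem_sync in barrier.py above treats the signal pad as a 2-D grid of
# 32-bit flags indexed by [block_id][rank]: each of the first world_size
# threads on rank r signals slot [block_id][r] in a peer's pad, then spins on
# slot [block_id][peer] in its own pad. A plain-Python sketch of that
# addressing, using hypothetical rank/world_size/block_id values:
def signal_pad_offsets(rank: int, world_size: int, block_id: int):
    # Offset written into every remote rank's pad (one slot per sending rank).
    send_offsets = [block_id * world_size + rank for _ in range(world_size)]
    # Offsets polled in the local pad (one slot per expected sender).
    wait_offsets = [block_id * world_size + peer for peer in range(world_size)]
    return send_offsets, wait_offsets

# With rank=1, world_size=4, block_id=2: write slot 9 on every peer,
# wait on local slots 8..11.
print(signal_pad_offsets(1, 4, 2))  # ([9, 9, 9, 9], [8, 9, 10, 11])
# -----------------------------------------------------------------------------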
+ +from unittest.mock import MagicMock + +import torch + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True +except ImportError: + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + HAVE_TRITON = False +try: + from torch._C._distributed_c10d import _SymmetricMemory +except ImportError: + _SymmetricMemory = MagicMock() + +from .barrier import symm_mem_sync +from .multimem_asm import ld_128, st_128 +from .utils import get_flat_tid, sync_threads + + +@triton.jit +def _multimem_all_gather_kernel( + local_ptr, + multicast_ptr, + signal_pad_ptrs, + numel, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): + """ + Triton kernel to perform multicast all-gather over nvlink using multimem instructions. + """ + # an all-gather is simply a multicast store operation + # we only need a barrier at the end to ensure visibility of writes + + pid = tl.program_id(axis=0) + tid = get_flat_tid() + + # From this point on, we pretend each element is 128-bit + numel = numel // NUMEL_PER_THREAD + numel_per_rank = tl.cdiv(numel, WORLD_SIZE) + block_start = pid * BLOCK_SIZE + + while block_start < numel_per_rank: + offsets = block_start + tid + mask = offsets < numel_per_rank + + # Each pointer points to a 128-bit bit pack + # RANK * numel_per_rank -> brings us to the start of our rank's segment + # offsets -> brings us to the right offset within our rank's segment + multicast_ptrs = ( + multicast_ptr.to(tl.pointer_type(tl.uint64)) + (RANK * numel_per_rank + offsets) * 2 + ) + local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + offsets * 2 + (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) + st_128(multicast_ptrs, x, y, z, w, mask=mask, multicast_op=True) + + block_start += tl.num_programs(axis=0) * BLOCK_SIZE + + sync_threads() + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=True, + hasSubsequentMemAccess=True, + ) + + +def multimem_all_gather( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> torch.Tensor: + """ + Calls a multicast all-gather triton kernel on the given tensor. + Output tensor must be a symmetric memory buffer. + Input tensor can be a regular torch tensor + Arguments: + output_tensor: torch.Tensor - output tensor to be all-gathered into + input_tensor: torch.Tensor - input tensor to be all-gathered from + symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for output_tensor + Returns: + torch.Tensor - all-gathered tensor, which is output_tensor + """ + assert HAVE_TRITON, "Triton is required for multimem all-gather." + + config = { + "max_num_blocks": kwargs.get("max_num_blocks", 24), + "num_warps": kwargs.get("num_warps", 32), + "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), + } + assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + numel_per_thread = 128 // (input_tensor.element_size() * 8) + + assert ( + output_tensor.numel() % numel_per_thread == 0 + ), "The number of elements must be 128-bit aligned." 
+ + num_threads = triton.cdiv(output_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) + num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) + + _multimem_all_gather_kernel[(num_blocks, 1, 1)]( + input_tensor.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel=output_tensor.numel(), + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) + + return output_tensor + + +@triton.jit +def _multimem_reduce_scatter_kernel( + local_ptr, + multicast_ptr, + signal_pad_ptrs, + numel, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): + """ + Triton kernel to perform multicast reduce-scatter over nvlink using multimem instructions. + """ + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=False, + hasSubsequentMemAccess=False, + ) + sync_threads() + + pid = tl.program_id(axis=0) + tid = get_flat_tid() + + # From this point on, we pretend each element is 128-bit + numel = numel // NUMEL_PER_THREAD + numel_per_rank = tl.cdiv(numel, WORLD_SIZE) + block_start = pid * BLOCK_SIZE + + while block_start < numel_per_rank: + offsets = block_start + tid + mask = offsets < numel_per_rank + + # Each pointer points to a 128-bit bit pack + multicast_ptrs = ( + multicast_ptr.to(tl.pointer_type(tl.uint64)) + (RANK * numel_per_rank + offsets) * 2 + ) + local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + offsets * 2 + (x, y, z, w) = ld_128(multicast_ptrs, mask=mask, multicast_op=True) + st_128(local_ptrs, x, y, z, w, mask=mask, multicast_op=False) + + block_start += tl.num_programs(axis=0) * BLOCK_SIZE + + +def multimem_reduce_scatter( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> torch.Tensor: + """ + Calls a multicast reduce-scatter triton kernel on the given tensor. + Input tensor must be a symmetric memory buffer. + Output tensor can be a regular torch tensor + Arguments: + output_tensor: torch.Tensor - output tensor to be reduce-scattered into + input_tensor: torch.Tensor - input tensor to be reduce-scattered from + symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for input_tensor + **kwargs: Additional keyword arguments for kernel configuration: + max_num_blocks (int, optional): The maximum number of blocks to launch. + num_warps (int, optional): The number of warps per block. + BLOCK_SIZE (int, optional): The BLOCK_SIZE parameter for the kernel. + Returns: + torch.Tensor - reduce-scattered tensor, which is output_tensor + """ + + assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." + + config = { + "max_num_blocks": kwargs.get("max_num_blocks", 24), + "num_warps": kwargs.get("num_warps", 32), + "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), + } + + assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + numel_per_thread = 128 // (output_tensor.element_size() * 8) + + assert ( + input_tensor.numel() % numel_per_thread == 0 + ), "The number of elements must be 128-bit aligned." 
+ + num_threads = triton.cdiv(input_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) + num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) + + _multimem_reduce_scatter_kernel[(num_blocks, 1, 1)]( + output_tensor.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel=input_tensor.numel(), + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) + + return output_tensor diff --git a/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py b/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py new file mode 100644 index 0000000000..cf85ce57f6 --- /dev/null +++ b/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Adapted from https://github.com/yifuwang/symm-mem-recipes.git + +from unittest.mock import MagicMock + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl +except ImportError: + triton = MagicMock() + tl = MagicMock() + triton.jit = null_decorator + + +@triton.jit +def ld_128(ptr, mask, multicast_op: tl.constexpr): + """ + Loads 128 bits (8 x bf16) from memory into registers. + + This function abstracts two distinct hardware behaviors based on `multicast_op`: + + 1. **Standard Load (`multicast_op=False`)**: + - **Semantics:** Local Global Memory Load. + - **Action:** Reads 128 bits from `ptr` in global memory into the local register file. + - **Use Case:** Standard tensor processing. + + 2. **Multicast Reduce-Load (`multicast_op=True`)**: + - **Semantics:** "Pull" Reduction over NVLink. + - **Action:** Simultaneously reads 128 bits from the *same* address across all peer GPUs + in the multicast group, sums them (add reduction), and loads the result into the + local register file. + - **Hardware:** Uses `multimem.ld_reduce` (Hopper+). + - **Use Case:** The "Reduce" step in collective operations. + + Args: + ptr: Memory pointer to the source buffer. + mask: Boolean predicate. If False, the operation is skipped (no-op). + multicast_op (tl.constexpr): Toggles between standard load (False) + and multicast-reduce (True). + + Returns: + Four 32-bit registers (tl.uint32), representing 128 bits of loaded data. + Note: When interpreting as bf16, this equates to 8 values (2 per register). + """ + # PTX Assembly Logic: + # 1. @$5: Predication. Only execute if argument 5 (mask) is True (1). + # 2. Opcode Selection: + # - 'multimem.ld_reduce...add.v4.bf16x2': Hardware-accelerated reduction across peers. + # - 'ld.global...v4.u32': Standard 128-bit memory read. + # 3. Operands: + # - {$0, $1, $2, $3}: Destination registers (Output). + # - [$4]: Source memory address (Input). 
+ if multicast_op: + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $5, 1; + @%p0 bra end; + multimem.ld_reduce.relaxed.sys.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4]; + end: + } + """, + "=r,=r,=r,=r,l,r", + args=[ptr, mask.to(tl.int32)], + dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + else: + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $5, 1; + @%p0 bra end; + ld.global.v4.u32 {$0, $1, $2, $3}, [$4]; + end: + } + """, + "=r,=r,=r,=r,l,r", + args=[ptr, mask.to(tl.int32)], + dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + + +@triton.jit +def st_128(ptr, x, y, z, w, mask, multicast_op): + """ + Stores 128 bits (8 x bf16) from registers to memory. + + This function abstracts two distinct hardware behaviors based on `multicast_op`: + + 1. **Standard Store (`multicast_op=False`)**: + - **Semantics:** Local Global Memory Store. + - **Action:** Writes 128 bits from local registers to `ptr` in global memory. + + 2. **Multicast Store (`multicast_op=True`)**: + - **Semantics:** "Push" Broadcast over NVLink. + - **Action:** Writes 128 bits from local registers to the `ptr` address in + the global memory of **all** peer GPUs in the multicast group simultaneously. + - **Hardware:** Uses `multimem.st` (Hopper+). + - **Use Case:** The "Broadcast" or "All-Gather" step in collective operations. + + Args: + ptr: Memory pointer to the destination buffer. + x, y, z, w: Four 32-bit registers containing the data to store. + mask: Boolean predicate. If False, the store is skipped. + multicast_op (tl.constexpr): Toggles between standard store (False) + and multicast broadcast (True). + """ + # PTX Assembly Logic: + # 1. @$6: Predication. Only execute if argument 6 (mask) is True. + # 2. Opcode Selection: + # - 'multimem.st...v4.f32': Broadcasts data to all peers. + # (Note: .f32 type used for bit-movement, equivalent to .u32 for storage). + # - 'st.global...v4.u32': Standard 128-bit memory write. + # 3. Operands: + # - [$1]: Destination memory address. + # - {$2, $3, $4, $5}: Source registers containing data. + if multicast_op: + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $6, 1; + @%p0 bra end; + multimem.st.relaxed.sys.global.v4.f32 [$1], {$2, $3, $4, $5}; + end: + } + """, + "=r,l,r,r,r,r,r", + args=[ptr, x, y, z, w, mask.to(tl.int32)], + dtype=(tl.uint32), + is_pure=False, + pack=1, + ) + else: + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $6, 1; + @%p0 bra end; + st.global.v4.f32 [$1], {$2, $3, $4, $5}; + end: + } + """, + "=r,l,r,r,r,r,r", + args=[ptr, x, y, z, w, mask.to(tl.int32)], + dtype=(tl.uint32), + is_pure=False, + pack=1, + ) diff --git a/megatron/core/inference/communication/torch_symm_triton/utils.py b/megatron/core/inference/communication/torch_symm_triton/utils.py new file mode 100644 index 0000000000..785481dfba --- /dev/null +++ b/megatron/core/inference/communication/torch_symm_triton/utils.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Adapted from: https://github.com/meta-pytorch/kraken.git + +from unittest.mock import MagicMock + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl +except ImportError: + triton = MagicMock() + tl = MagicMock() + triton.jit = null_decorator + + +@triton.jit +def get_tid(): + """ + Returns the thread IDs in x, y, z dimensions. 
+ """ + return tl.inline_asm_elementwise( + """ + mov.u32 $0, %tid.x; + mov.u32 $1, %tid.y; + mov.u32 $2, %tid.z; + """, + "=r,=r,=r", + [], + dtype=(tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + + +@triton.jit +def get_ntid(): + """ + Returns the number of threads in x, y, z dimensions. + """ + return tl.inline_asm_elementwise( + """ + mov.u32 $0, %ntid.x; + mov.u32 $1, %ntid.y; + mov.u32 $2, %ntid.z; + """, + "=r,=r,=r", + [], + dtype=(tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + + +@triton.jit +def get_flat_tid(): + """ + Calculates a unique, one-dimensional ID for each thread within its thread block. + """ + tid_x, tid_y, tid_z = get_tid() + ntid_x, ntid_y, _ = get_ntid() + return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x + + +@triton.jit +def get_flat_bid(): + """ + Calculates a unique, one-dimensional ID for each block within the grid.""" + return ( + tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0) + + tl.program_id(1) * tl.num_programs(0) + + tl.program_id(0) + ) + + +@triton.jit +def sync_threads(): + """ + Synchronize all threads within a block. + """ + tl.inline_asm_elementwise("bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1) diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index bfe38c2bbc..f83275ed9c 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -12,6 +12,10 @@ from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules from megatron.core.ssm.mlp_layer import MLPLayer +from megatron.core.tensor_parallel import ( + InferenceLayerNormColumnParallelLinear, + InferenceRowParallelLinear, +) from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.mlp import MLP, MLPSubmodules @@ -82,3 +86,63 @@ ), ), ) + +mamba_inference_stack_spec = ModuleSpec( + module=MambaStack, + submodules=MambaStackSubmodules( + mamba_layer=ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + mixer=ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=InferenceLayerNormColumnParallelLinear, + out_proj=InferenceRowParallelLinear, + ), + ), + mamba_bda=get_bias_dropout_add, + ), + ), + # Started with spec from gpt_layer_specs.py (with MLP removed) + # Using the TE spec because we had problems getting the non-TE spec + # working + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=InferenceLayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=InferenceRowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), + # Started with spec from gpt_layer_specs.py + # Using the TE spec because we had problems getting the non-TE spec + # working + mlp_layer=ModuleSpec( + module=MLPLayer, + submodules=TransformerLayerSubmodules( + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=InferenceLayerNormColumnParallelLinear, + linear_fc2=InferenceRowParallelLinear, + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + moe_layer=ModuleSpec( + # TODO (rwaleffe): change this to be an "MoELayer" to work with CudaGraphs? 
+ module=TransformerLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add + ), + ), + ), +) diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 1916bfff07..e7c2136224 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -11,7 +11,7 @@ import numpy as np import torch -from .utils import GlobalMemoryBuffer, is_torch_min_version +from .utils import GlobalMemoryBuffer, GlobalSymmetricMemoryBuffer, is_torch_min_version logger = logging.getLogger(__name__) @@ -132,6 +132,9 @@ # Memory buffers to avoid dynamic memory allocation _GLOBAL_MEMORY_BUFFER = None +# Global symmetric memory buffer for inference +_GLOBAL_SYMMETRIC_MEMORY_BUFFER = None + # List of all process groups # Used for updating the timeout for all process groups # None represents the default process group @@ -1285,6 +1288,9 @@ def initialize_model_parallel( # we could stick it there _set_global_memory_buffer() + # initialize global symmetric memory buffer + _set_global_symmetric_memory_buffer() + def is_initialized(): """Useful for code segments that may be accessed with or without mpu initialization""" @@ -1927,18 +1933,43 @@ def _set_global_memory_buffer(): _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() +def _set_global_symmetric_memory_buffer(): + """Initialize global buffer.""" + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER + assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER is None, "global memory buffer is already initialized" + + _GLOBAL_SYMMETRIC_MEMORY_BUFFER = GlobalSymmetricMemoryBuffer( + size_in_mb=256, # todo: set from an argument? + process_group=get_tensor_model_parallel_group(), + ) + + def get_global_memory_buffer(): """Return the global GlobalMemoryBuffer object""" assert _GLOBAL_MEMORY_BUFFER is not None, "global memory buffer is not initialized" return _GLOBAL_MEMORY_BUFFER +def get_global_symmetric_memory_buffer(): + """Return the global GlobalSymmetricMemoryBuffer object""" + assert ( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER is not None + ), "global symmetric memory buffer is not initialized" + return _GLOBAL_SYMMETRIC_MEMORY_BUFFER + + def destroy_global_memory_buffer(): """Sets the global memory buffer to None""" global _GLOBAL_MEMORY_BUFFER _GLOBAL_MEMORY_BUFFER = None +def destroy_global_symmetric_memory_buffer(): + """Sets the global symmetric memory buffer to None""" + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER + _GLOBAL_SYMMETRIC_MEMORY_BUFFER = None + + def get_all_ranks(): """Get caller's rank in tensor-model-parallel, data-parallel, context-parallel, pipeline-model-parallel and expert-model-parallel groups.""" @@ -2014,6 +2045,9 @@ def destroy_model_parallel(): global _GLOBAL_MEMORY_BUFFER _GLOBAL_MEMORY_BUFFER = None + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER + _GLOBAL_SYMMETRIC_MEMORY_BUFFER = None + global _DATA_PARALLEL_GROUP_GLOO if ( _DATA_PARALLEL_GROUP_GLOO is not None diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py index afa53bdc6e..98cc5efec8 100644 --- a/megatron/core/tensor_parallel/__init__.py +++ b/megatron/core/tensor_parallel/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
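# --- Illustrative sketch (not part of this patch) ---------------------------
# How the new global symmetric-memory buffer from the parallel_state.py hunk
# above is consumed by the inference layers further below: maybe_get_tensor()
# returns a dict with a "tensor" view into the symmetric buffer plus the
# rendezvous "handle", or None entries when symmetric memory is unavailable,
# so callers can fall back to the regular NCCL collectives. This sketch
# assumes the patched tree is installed; shapes and dtype are hypothetical.
import torch
from megatron.core.parallel_state import get_global_symmetric_memory_buffer

def gather_buffer_or_none(local: torch.Tensor, tp_size: int):
    dims = list(local.size())
    dims[0] *= tp_size  # gathered output spans all TP ranks along dim 0
    buf = get_global_symmetric_memory_buffer().maybe_get_tensor(dims, dtype=local.dtype)
    if buf["handle"] is None:
        return None  # caller reverts to gather_along_first_dim (NCCL)
    return buf  # {"tensor": symmetric output view, "handle": _SymmetricMemory}
# -----------------------------------------------------------------------------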
from .cross_entropy import vocab_parallel_cross_entropy from .data import broadcast_data +from .inference_layers import InferenceLayerNormColumnParallelLinear, InferenceRowParallelLinear from .layers import ( ColumnParallelLinear, RowParallelLinear, diff --git a/megatron/core/tensor_parallel/inference_layers.py b/megatron/core/tensor_parallel/inference_layers.py index 05f7b88d09..ddba196104 100644 --- a/megatron/core/tensor_parallel/inference_layers.py +++ b/megatron/core/tensor_parallel/inference_layers.py @@ -10,7 +10,12 @@ TELayerNormColumnParallelLinear, TERowParallelLinear, ) +from megatron.core.inference.communication.torch_symm_triton import ( + multimem_all_gather, + multimem_reduce_scatter, +) from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.parallel_state import get_global_symmetric_memory_buffer from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -85,14 +90,45 @@ def __init__( config.sequence_parallel ), "--transformer-impl=inference_optimized requires --sequence-parallel" + def _all_gather(self, x: torch.Tensor) -> None: + """ + Attempt an NVLS all-gather into symmetric memory. If not possible, + revert to torch dist (NCCL) all-gather. + """ + if self.tp_size == 1: + return x + + # 1. check if bf16 + is_bf16 = x.dtype == torch.bfloat16 + # 2. check if hopper or newer + is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 + # 3. attempt to ask for symmetric memory + symm_mem_buffer_dims = list(x.size()) + symm_mem_buffer_dims[0] *= self.tp_size + symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( + symm_mem_buffer_dims, dtype=x.dtype + ) + has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None + can_use_custom_nvls_collectives = ( + is_bf16 and is_hopper_or_newer and has_enough_symmetric_memory + ) + + if can_use_custom_nvls_collectives: + # do multimem all gather + multimem_all_gather(symm_mem_buffer["tensor"], x, symm_mem_buffer["handle"]) + return symm_mem_buffer["tensor"] + else: + # revert to torch dist (NCCL) all gather + x, _ = gather_along_first_dim(x, process_group=self.tp_group) + return x + @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass. """ x = _te_rms_norm_kernel(x=x, weight=self.layer_norm_weight, eps=self.eps) - if self.tp_size > 1: - x, _ = gather_along_first_dim(x, process_group=self.tp_group) + x = self._all_gather(x) x = torch.matmul(x, self.weight.t()) return x, None @@ -140,12 +176,51 @@ def __init__( config.sequence_parallel ), "--transformer-impl=inference_optimized requires --sequence-parallel" + def _matmul_reduce_scatter(self, x): + """ + Multiplies x by the weight matrix and performs a reduce-scatter. + It will first try to write the matmul output to symmetric memory + and perform an NVLS multicast reduce-scatter. If that is not possible, + it will revert to torch.dist (NCCL) reduce-scatter. + """ + # 1. check if bf16 + is_bf16 = x.dtype == torch.bfloat16 + # 2. check if hopper + is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 + # 3. 
attempt to ask for symmetric memory + symm_mem_buffer_dims = list(x.size()) + symm_mem_buffer_dims[-1] = self.weight.size(0) + symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( + symm_mem_buffer_dims, dtype=x.dtype + ) + has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None + can_use_custom_nvls_collectives = ( + is_bf16 and is_hopper_or_newer and has_enough_symmetric_memory + ) + if can_use_custom_nvls_collectives: + # Write output of matmul directly onto the symmetric memory buffer + torch.matmul(x, self.weight.t(), out=symm_mem_buffer["tensor"]) + x = symm_mem_buffer["tensor"] + # perform nvls reduce-scatter + output_dims = list(x.size()) + output_dims[0] = x.size(0) // self.tp_size + output = torch.empty(output_dims, dtype=x.dtype, device=x.device) + multimem_reduce_scatter(output, x, symm_mem_buffer["handle"]) + return output + else: + # revert to torch dist (NCCL) reduce-scatter + x = torch.matmul(x, self.weight.t()) + x, _ = reduce_scatter_along_first_dim(x, tp_group=self.tp_group) + return x + @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass. """ - x = torch.matmul(x, self.weight.t()) - if self.tp_size > 1: - x, _ = reduce_scatter_along_first_dim(x, tp_group=self.tp_group) - return x, None + if self.tp_size == 1: + x = torch.matmul(x, self.weight.t()) + return x, None + else: + x = self._matmul_reduce_scatter(x) + return x, None diff --git a/megatron/core/utils.py b/megatron/core/utils.py index c4129bd5d2..0d6f417293 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -29,6 +29,20 @@ import numpy import torch +try: + import torch.distributed._symmetric_memory as symm_mem + + HAVE_TORCH_SYMM_MEM = True +except ImportError: + HAVE_TORCH_SYMM_MEM = False + +try: + import triton # pylint: disable=unused-import + + HAVE_TRITON = True +except ImportError: + HAVE_TRITON = False + from megatron.core import config from megatron.core.package_info import __version__ as mcore_version @@ -616,6 +630,65 @@ def get_tensor(self, tensor_shape, dtype, name, mem_alloc_context: Optional[Call return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) +class GlobalSymmetricMemoryBuffer: + """ + Global symmetric memory buffer used in inference. + This buffer is used by mcore-inference's low-latency + NVLS all-gather and reduce-scatter collectives. + """ + + def __init__(self, size_in_mb, process_group): + if not HAVE_TORCH_SYMM_MEM or not HAVE_TRITON: + # This should be hit if the user is running an older + # version of torch, or if they do not have triton + # installed. + self.symm_buffer = None + self.symm_mem_hdl = None + else: + numel = int(size_in_mb * 1024 * 1024) # size in bytes + try: + symm_mem.enable_symm_mem_for_group(process_group.group_name) + self.symm_buffer = symm_mem.empty(numel, dtype=torch.uint8, device='cuda') + self.symm_mem_hdl = symm_mem.rendezvous(self.symm_buffer, process_group) + except RuntimeError as e: + # If symmetric memory initialization fails, set buffer and handle to None + # This should happen if the process group is not contained within NVlink + self.symm_buffer = None + self.symm_mem_hdl = None + + def _can_allocate(self, numel, dtype) -> bool: + """ + Returns whether enough symmetric memory is available + for the given tensor shape and dtype. 
+ """ + if self.symm_mem_hdl is None: + return False + size_of_dtype = torch.tensor([], dtype=dtype).element_size() + required_len = numel * size_of_dtype + return required_len <= self.symm_buffer.numel() + + def _allocate(self, numel, dtype) -> torch.Tensor: + """ + Allocates a sub-tensor from the self.symm_buffer for the given numel and dtype""" + required_bytes = numel * torch.tensor([], dtype=dtype).element_size() + return self.symm_buffer[0:required_bytes].view(dtype).view(numel) + + def maybe_get_tensor(self, tensor_shape, dtype): + """ + Returns (potentially) a sub-tensor from the self.symm_buffer for the given shape. + If enough symmetric memory is not available, returns None. + """ + if self.symm_mem_hdl is None: + return {"tensor": None, "handle": None} + numel = reduce(operator.mul, tensor_shape, 1) + if not self._can_allocate(numel, dtype): + return {"tensor": None, "handle": None} + return { + "tensor": self._allocate(numel, dtype).view(*tensor_shape), + "handle": self.symm_mem_hdl, + } + + def _kernel_make_viewless_tensor(inp, requires_grad): """Make a viewless tensor. diff --git a/megatron/training/training.py b/megatron/training/training.py index eb7a903561..c3d68e2860 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -93,6 +93,7 @@ from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper from megatron.core.parallel_state import ( destroy_global_memory_buffer, + destroy_global_symmetric_memory_buffer, destroy_model_parallel, update_pg_timeout ) @@ -145,6 +146,7 @@ def destroy_global_state(): destroy_global_vars() destroy_num_microbatches_calculator() destroy_global_memory_buffer() + destroy_global_symmetric_memory_buffer() destroy_model_parallel() destroy_rerun_state_machine() diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/golden_values_dev_dgx_h100.json index 4ebaf72f5e..944863ce00 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/golden_values_dev_dgx_h100.json @@ -34,125 +34,125 @@ 1278, 2362 ], - "latency": 42.63835311005823, + "latency": 23.35220137424767, "logprobs": [ - -9.358713150024414, - -2.724055767059326, - -4.5792131423950195, - -1.4844143390655518, - -0.6546129584312439, - -1.7303215265274048, - -2.4795279502868652, - -2.0776171684265137, - -2.4553134441375732, - -6.219150066375732, - -1.566371202468872, - -3.486889362335205, - -4.418787479400635, - -3.8580172061920166, - -2.0664010047912598, - -1.843908667564392, - -3.744598627090454, - -6.82543420791626, - -0.2880207300186157, - -0.9257857799530029, - -6.612694263458252, - -7.218401908874512, - -12.827808380126953, - -2.1861495971679688, - -3.8218231201171875, - -0.5008565187454224, - -4.383245468139648, - -0.06934759020805359, - -0.09667497128248215, - -3.2640299797058105, - -10.102912902832031, - -1.1498218774795532, - -5.979549407958984, - -5.0192108154296875, - -3.8367133140563965, - -2.581653356552124, - -3.4087462425231934, - -5.545716285705566, - -1.6541939973831177, - -5.547749996185303, - -12.21850872039795, - -12.582784652709961, - -0.09534379839897156, - -2.522055149078369, - -1.4054086208343506, - -2.8758127689361572, - -1.1866405010223389, - -0.005799253936856985, - 
-3.3871712684631348, - -13.193516731262207, - -4.389392852783203, - -2.520228862762451, - -6.023908615112305, - -0.7408540844917297, - -0.04526234790682793, - -1.5508661270141602, - -1.1332746744155884, - -5.653256416320801, - -0.4028852581977844, - -4.9457244873046875, - -0.618165135383606, - -0.6616490483283997, - -2.36385178565979, - -13.6455078125, - -0.08668932318687439, - -3.5266754627227783, - -1.3801541328430176, - -6.351947784423828, - -0.5434023141860962, - -3.5673093795776367, - -0.871107816696167, - -1.618450403213501, - -5.378700256347656, - -17.17119026184082, - -6.662005424499512, - -0.9221409559249878, - -4.141905784606934, - -1.2047083377838135, - -2.227570056915283, - -1.7645721435546875, - -0.21892313659191132, - -9.296550750732422, - -0.11995092779397964, - -7.402207851409912, - -2.512965679168701, - -4.100971221923828, - -3.580245018005371, - -1.9462040662765503, - -2.347074031829834, - -1.5288957357406616, - -2.4033043384552, - -1.7311294078826904, - -1.1686863899230957, - -2.938558340072632, - -0.5278136730194092, - -0.4748117923736572, - -1.749883770942688, - -0.8397680521011353, - -0.4109693169593811, - -0.9552587270736694, - -1.5238327980041504, - -0.4656376838684082, - -1.6448218822479248, - -0.5414345264434814, - -1.2422380447387695, - -1.1426063776016235, - -0.002245525596663356, - -1.252556562423706, - -0.007873333990573883, - -0.7185167670249939, - -0.7521701455116272, - -0.042445242404937744, - -0.8852499723434448, - -0.02266514115035534, - -2.0951969623565674, - -1.348037838935852, - -0.8296748399734497 + -9.35879135131836, + -2.7352774143218994, + -4.542932987213135, + -1.4809632301330566, + -0.6577711701393127, + -1.7310287952423096, + -2.5016393661499023, + -2.054267168045044, + -2.4450795650482178, + -6.180659294128418, + -1.568453073501587, + -3.404385805130005, + -4.357839584350586, + -3.9313418865203857, + -2.001478672027588, + -1.8802878856658936, + -3.8159995079040527, + -6.879362106323242, + -0.28638726472854614, + -0.9805830717086792, + -6.659268856048584, + -7.184902667999268, + -12.831036567687988, + -2.2628769874572754, + -3.80989933013916, + -0.5026318430900574, + -4.312714576721191, + -0.06652869284152985, + -0.10383106768131256, + -3.221609354019165, + -10.062438011169434, + -1.19387686252594, + -5.972838401794434, + -5.059903621673584, + -3.794962167739868, + -2.58512020111084, + -3.407836675643921, + -5.576328277587891, + -1.6389069557189941, + -5.498246669769287, + -12.218515396118164, + -12.583944320678711, + -0.09274326264858246, + -2.500924587249756, + -1.370800256729126, + -2.858417510986328, + -1.1951555013656616, + -0.006517108529806137, + -3.3397316932678223, + -13.183527946472168, + -4.315248966217041, + -2.4844048023223877, + -6.052038192749023, + -0.7679911851882935, + -0.05106499418616295, + -1.5119061470031738, + -1.148835301399231, + -5.648500442504883, + -0.42955976724624634, + -4.942170143127441, + -0.6178378462791443, + -0.7215086221694946, + -2.4680683612823486, + -13.656073570251465, + -0.09046748280525208, + -3.528261184692383, + -1.3840829133987427, + -6.3916826248168945, + -0.590160071849823, + -3.512652635574341, + -0.8600459694862366, + -1.6373299360275269, + -5.384238243103027, + -17.205631256103516, + -6.648115634918213, + -0.890762984752655, + -4.155974388122559, + -1.1969019174575806, + -2.251375675201416, + -1.7827272415161133, + -0.21727021038532257, + -9.323517799377441, + -0.11923929303884506, + -7.317551136016846, + -2.5149247646331787, + -4.099612236022949, + -3.5964670181274414, + -1.9214924573898315, + 
-2.305270195007324, + -1.5137361288070679, + -2.3835322856903076, + -1.7124545574188232, + -1.1756497621536255, + -3.0433411598205566, + -0.5281094312667847, + -0.4586932063102722, + -1.7248739004135132, + -0.8336725831031799, + -0.4110657572746277, + -0.9216307401657104, + -1.4833365678787231, + -0.4625704288482666, + -1.636054277420044, + -0.5516311526298523, + -1.2232449054718018, + -1.2100636959075928, + -0.002353756921365857, + -1.1664479970932007, + -0.007350543048232794, + -0.7310623526573181, + -0.7930303812026978, + -0.049882158637046814, + -0.8908950686454773, + -0.019804010167717934, + -2.044306755065918, + -1.3121578693389893, + -0.8065381050109863 ] } -} +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml index 551ba8115c..ddb560715e 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml @@ -42,9 +42,6 @@ MODEL_ARGS: --top_k: 1 --return-log-probs: true --num-tokens-to-generate: 30 - --inference-dynamic-batching-max-requests-override: 8 # hardcode decode padding tokens to 7 for reproducibility - --inference-dynamic-batching-buffer-guaranteed-fraction: 0 - --inference-dynamic-batching-buffer-overflow-factor: 0.2 --inference-dynamic-batching-buffer-size-gb: 20 --dist-ckpt-strictness: log_unexpected --inference-ckpt-non-strict: true # To handle the extra_state errors
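# --- Illustrative sketch (not part of this patch) ---------------------------
# The eligibility check shared by _all_gather and _matmul_reduce_scatter in
# the inference_layers.py hunk above: the custom NVLS (multimem) collectives
# are used only when the activation is bf16, the GPU is Hopper or newer
# (compute capability major >= 9), and the global symmetric-memory buffer is
# large enough; otherwise the code reverts to the NCCL-based collectives.
# Condensed here with a hypothetical helper name for clarity.
import torch

def can_use_nvls(x: torch.Tensor, symm_mem_buffer: dict) -> bool:
    is_bf16 = x.dtype == torch.bfloat16
    is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9
    has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None
    return is_bf16 and is_hopper_or_newer and has_enough_symmetric_memory
# -----------------------------------------------------------------------------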