Skip to content

Commit 1ac9df6

Browse files
committed
Merge remote-tracking branch 'lmcafee/lmcafee/dedupe-engine-coordinator' into tde/rl_4_out_of_4
2 parents b0e8bba + f3cf7b5 commit 1ac9df6

File tree

7 files changed

+229
-200
lines changed

7 files changed

+229
-200
lines changed

examples/inference/gpt/gpt_dynamic_inference_12b.sh

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
3333
: ${CUDA_GRAPH_SHARE_IO_BUFFERS=1}
3434

3535
# Miscellaneous.
36+
: ${USE_COORDINATOR=0}
3637
: ${ENGINE=dynamic}
3738
: ${EXTRA_ARGS=""}
3839
# NSIGHT_PREFIX=/path/to/nsight/profile
@@ -85,7 +86,7 @@ ARGS=" \
8586
"
8687

8788
# Cuda graphs.
88-
if [ "${CUDA_GRAPH_IMPL}" = "local" ]; then
89+
if [ "${NUM_CUDA_GRAPHS}" != "0" ]; then
8990
ARGS+=" \
9091
--cuda-graph-impl local \
9192
--inference-dynamic-batching-num-cuda-graphs ${NUM_CUDA_GRAPHS} \
@@ -108,7 +109,12 @@ else
108109
fi
109110

110111
# Command.
111-
CMD="python -m examples.inference.gpt.gpt_${ENGINE}_inference ${ARGS}"
112+
if [[ "${USE_COORDINATOR}" == "0" ]]; then
113+
CMD="python -m examples.inference.gpt.gpt_${ENGINE}_inference ${ARGS}"
114+
else
115+
CMD="python -um examples.inference.gpt.gpt_${ENGINE}_inference_with_coordinator ${ARGS}"
116+
fi
117+
112118
if [[ -v NSIGHT_PREFIX ]]; then
113119
CMD="nsys profile -s none -t nvtx,cuda --cudabacktrace=all --cuda-graph-trace=node --python-backtrace=cuda --wait all -o ${NSIGHT_PREFIX} --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop ${CMD}"
114120
fi

examples/inference/gpt/gpt_dynamic_inference_357m.sh

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
3434
: ${CUDA_GRAPH_SHARE_IO_BUFFERS=1}
3535

3636
# Miscellaneous.
37+
: ${USE_COORDINATOR=0}
3738
: ${ENGINE=dynamic}
3839
: ${EXTRA_ARGS=""}
3940
# NSIGHT_PREFIX=/path/to/nsight/profile
@@ -71,7 +72,7 @@ ARGS=" \
7172
"
7273

7374
# Cuda graphs.
74-
if [ "${CUDA_GRAPH_IMPL}" = "local" ]; then
75+
if [ "${NUM_CUDA_GRAPHS}" != "0" ]; then
7576
ARGS+=" \
7677
--cuda-graph-impl local \
7778
--inference-dynamic-batching-num-cuda-graphs ${NUM_CUDA_GRAPHS} \
@@ -94,7 +95,12 @@ else
9495
fi
9596

9697
# Command.
97-
CMD="python -m examples.inference.gpt.gpt_${ENGINE}_inference ${ARGS}"
98+
if [[ "${USE_COORDINATOR}" == "0" ]]; then
99+
CMD="python -m examples.inference.gpt.gpt_${ENGINE}_inference ${ARGS}"
100+
else
101+
CMD="python -um examples.inference.gpt.gpt_${ENGINE}_inference_with_coordinator ${ARGS}"
102+
fi
103+
98104
if [[ -v NSIGHT_PREFIX ]]; then
99105
CMD="nsys profile -s none -t nvtx,cuda --cudabacktrace=all --cuda-graph-trace=node --python-backtrace=cuda --wait all -o ${NSIGHT_PREFIX} --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop ${CMD}"
100106
fi

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,10 @@ def __init__(
5959
self, request_id: Optional[int], message: Optional[str] = None, *, is_transient: bool = True
6060
):
6161
request_str = '--' if request_id is None else str(request_id)
62-
message = "" if message is None else f" | {message}"
63-
super().__init__(f"request {request_str}{message}")
62+
_message = "" if message is None else f" | {message}"
63+
super().__init__(f"request {request_str}{_message}")
64+
self.request_id = request_id
65+
self.message = message
6466
self.is_transient = is_transient
6567

6668

@@ -102,6 +104,50 @@ def __init__(self, max_request_count, active_request_count):
102104
)
103105

104106

107+
class ContextErrorFactory:
    """Factory class for serializing/deserializing context errors.

    Converts ``ContextOverflowError`` instances to and from plain dicts so
    they can be passed across process boundaries as JSON-friendly data.
    """

    @classmethod
    def serialize(cls, error: ContextOverflowError) -> dict:
        """Serialize error.

        Args:
            error (ContextOverflowError): Error.

        Returns:
            (dict) Serialized error data.
        """
        assert isinstance(error, ContextOverflowError)
        payload = {"type": type(error).__name__}
        for attr in ("request_id", "message", "is_transient"):
            payload[attr] = getattr(error, attr)
        return payload

    @classmethod
    def deserialize(cls, obj: dict) -> ContextOverflowError:
        """Deserialize error.

        Args:
            obj (dict): Serialized error data.

        Returns:
            (ContextOverflowError) Deserialized error.
        """
        registry = {
            error_cls.__name__: error_cls
            for error_cls in (
                ContextOverflowError,
                RequestOverflowError,
                TokenOverflowError,
                MaxSequenceLengthOverflowError,
                BlockOverflowError,
                ActiveRequestCountOverflowError,
            )
        }
        kwargs = dict(obj)
        target_cls = registry[kwargs.pop("type")]
        # Subclasses may define different __init__ signatures (e.g.
        # ActiveRequestCountOverflowError), so build via the base class
        # and then rebrand the instance to the concrete type.
        error = ContextOverflowError(**kwargs)
        error.__class__ = target_cls  # TODO(@lmcafee): better/safer alternative?
        return error
105151
class WarmupEngineMode(Enum):
106152
"""Enumeration for warmup engine modes used during cuda graph capture."""
107153

0 commit comments

Comments (0)