From 255cb996f4e55e08a343d90c510d78f109a66a61 Mon Sep 17 00:00:00 2001
From: Botir Khaltaev
Date: Sat, 21 Jun 2025 02:09:05 +0100
Subject: [PATCH 1/4] feat(serve.py): add api_path parameter to BaseLitAPI and derived classes to allow custom API endpoint configuration

refactor(serve.py): improve code readability by formatting long lines and restructuring some method calls for clarity
style(serve.py): format code for improved readability by removing unnecessary line breaks and aligning parameters
fix(serve.py): resolve api_path to default value when not using OpenAI spec or stream to ensure server functionality
fix(serve.py): set default api_path to "/predict" for run_server function to ensure consistent behavior when no path is provided
refactor(serve.py): simplify api_path handling by removing unnecessary resolved_api_path variable
---
 litgpt/deploy/serve.py | 70 ++++++++++++++++++++++++++++++++----------
 1 file changed, 54 insertions(+), 16 deletions(-)

diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
index 44aa0b576f..a3ee86e1b4 100644
--- a/litgpt/deploy/serve.py
+++ b/litgpt/deploy/serve.py
@@ -1,4 +1,3 @@
-# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 import json
 import sys
 from pathlib import Path
@@ -31,11 +30,10 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ) -> None:
-        if not _LITSERVE_AVAILABLE:
-            raise ImportError(str(_LITSERVE_AVAILABLE))
+        super().__init__(api_path=api_path)
 
-        super().__init__()
         self.checkpoint_dir = checkpoint_dir
         self.quantize = quantize
         self.precision = precision
@@ -61,12 +59,11 @@ def setup(self, device: str) -> None:
             accelerator=accelerator,
             quantize=self.quantize,
             precision=self.precision,
-            generate_strategy="sequential" if self.devices is not None and self.devices > 1 else None,
+            generate_strategy=("sequential" if self.devices is not None and self.devices > 1 else None),
         )
         print("Model successfully initialized.", file=sys.stderr)
 
     def decode_request(self, request: Dict[str, Any]) -> Any:
-        # Convert the request payload to your model input.
         prompt = str(request["prompt"])
         return prompt
 
@@ -82,20 +79,34 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )
 
     def setup(self, device: str):
         super().setup(device)
 
     def predict(self, inputs: str) -> Any:
         output = self.llm.generate(
-            inputs, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, max_new_tokens=self.max_new_tokens
+            inputs,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            max_new_tokens=self.max_new_tokens,
         )
         return output
 
     def encode_response(self, output: str) -> Dict[str, Any]:
-        # Convert the model output to a response payload.
         return {"output": output}
 
 
@@ -110,14 +121,24 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )
 
     def setup(self, device: str):
         super().setup(device)
 
     def predict(self, inputs: torch.Tensor) -> Any:
-        # Run the model on the input and return the output.
         yield from self.llm.generate(
             inputs,
             temperature=self.temperature,
@@ -143,8 +164,19 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )
 
     def setup(self, device: str):
         super().setup(device)
@@ -167,18 +199,20 @@ def setup(self, device: str):
         self.template = Template(self.chat_template)
 
     def decode_request(self, request: "ChatCompletionRequest") -> Any:
-        # Apply chat template to request messages
         return self.template.render(messages=request.messages)
 
     def predict(self, inputs: str, context: dict) -> Any:
-        # Extract parameters from context with fallback to instance attributes
         temperature = context.get("temperature") or self.temperature
         top_p = context.get("top_p", self.top_p) or self.top_p
         max_new_tokens = context.get("max_completion_tokens") or self.max_new_tokens
 
-        # Run the model on the input and return the output.
         yield from self.llm.generate(
-            inputs, temperature=temperature, top_k=self.top_k, top_p=top_p, max_new_tokens=max_new_tokens, stream=True
+            inputs,
+            temperature=temperature,
+            top_k=self.top_k,
+            top_p=top_p,
+            max_new_tokens=max_new_tokens,
+            stream=True,
         )
 
 
@@ -196,6 +230,7 @@ def run_server(
     stream: bool = False,
     openai_spec: bool = False,
    access_token: Optional[str] = None,
+    api_path: Optional[str] = "/predict",
 ) -> None:
     """Serve a LitGPT model using LitServe.
 
@@ -235,11 +270,13 @@ def run_server(
         stream: Whether to stream the responses.
         openai_spec: Whether to use the OpenAISpec.
         access_token: Optional API token to access models with restrictions.
+        api_path: The custom API path for the endpoint (e.g., "/my_api/classify").
     """
     checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
     pprint(locals())
 
     api_class = OpenAISpecLitAPI if openai_spec else StreamLitAPI if stream else SimpleLitAPI
+
     server = LitServer(
         api_class(
             checkpoint_dir=checkpoint_dir,
@@ -250,6 +287,7 @@ def run_server(
             top_p=top_p,
             max_new_tokens=max_new_tokens,
             devices=devices,
+            api_path=api_path,
         ),
         spec=OpenAISpec() if openai_spec else None,
         accelerator=accelerator,

From 143544ee884600ac89a878de8c6bccc257d2096f Mon Sep 17 00:00:00 2001
From: Botir Khaltaev
Date: Tue, 22 Jul 2025 15:30:02 +0100
Subject: [PATCH 2/4] chore: add back what llm removed

---
 litgpt/deploy/serve.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
index 3ea4870ffb..fa03aa531a 100644
--- a/litgpt/deploy/serve.py
+++ b/litgpt/deploy/serve.py
@@ -1,3 +1,4 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 import json
 import sys
 from pathlib import Path
@@ -32,6 +33,9 @@ def __init__(
         devices: int = 1,
         api_path: Optional[str] = None,
     ) -> None:
+        if not _LITSERVE_AVAILABLE:
+            raise ImportError(str(_LITSERVE_AVAILABLE))
+
         super().__init__(api_path=api_path)
 
         self.checkpoint_dir = checkpoint_dir

From e93dc0d66a9798f8a9d50302d6fb1aae34652395 Mon Sep 17 00:00:00 2001
From: Botir Khaltaev
Date: Tue, 22 Jul 2025 15:31:23 +0100
Subject: [PATCH 3/4] chore: add back comments

---
 litgpt/deploy/serve.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
index fa03aa531a..9afd8f9ec4 100644
--- a/litgpt/deploy/serve.py
+++ b/litgpt/deploy/serve.py
@@ -203,13 +203,17 @@ def setup(self, device: str):
         self.template = Template(self.chat_template)
 
     def decode_request(self, request: "ChatCompletionRequest") -> Any:
+        # Apply chat template to request messages
         return self.template.render(messages=request.messages)
 
     def predict(self, inputs: str, context: dict) -> Any:
+        # Extract parameters from context with fallback to instance attributes
         temperature = context.get("temperature") or self.temperature
         top_p = context.get("top_p", self.top_p) or self.top_p
         max_new_tokens = context.get("max_completion_tokens") or self.max_new_tokens
 
+        # Run the model on the input and return the output.
+
         yield from self.llm.generate(
             inputs,
             temperature=temperature,

From 87b5c5def00baf1701c64d0d5a5437ab5c8a73b8 Mon Sep 17 00:00:00 2001
From: Botir Khaltaev
Date: Tue, 22 Jul 2025 15:32:00 +0100
Subject: [PATCH 4/4] chore: add back comments

---
 litgpt/deploy/serve.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
index 9afd8f9ec4..21f14fec65 100644
--- a/litgpt/deploy/serve.py
+++ b/litgpt/deploy/serve.py
@@ -111,6 +111,7 @@ def predict(self, inputs: str) -> Any:
         return output
 
     def encode_response(self, output: str) -> Dict[str, Any]:
+        # Convert the model output to a response payload.
         return {"output": output}
 
 
@@ -213,7 +214,6 @@ def predict(self, inputs: str, context: dict) -> Any:
         max_new_tokens = context.get("max_completion_tokens") or self.max_new_tokens
 
         # Run the model on the input and return the output.
-
         yield from self.llm.generate(
             inputs,
             temperature=temperature,
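
A minimal usage sketch of the api_path parameter introduced by this series (illustrative, not part of the patches): it assumes a server already started with a custom endpoint, for example run_server("microsoft/phi-2", api_path="/my_api/classify"), with the checkpoint name, port 8000, and prompt below being placeholder assumptions.

# Client-side sketch, assuming a server launched with api_path="/my_api/classify"
# and listening on the default host/port (127.0.0.1:8000 is an assumption here).
import requests

# SimpleLitAPI.decode_request reads the "prompt" field and encode_response wraps
# the generated text in {"output": ...}, so a plain JSON POST is sufficient.
response = requests.post(
    "http://127.0.0.1:8000/my_api/classify",
    json={"prompt": "What do llamas eat?"},
)
response.raise_for_status()
print(response.json()["output"])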