diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
index bab5324608..21f14fec65 100644
--- a/litgpt/deploy/serve.py
+++ b/litgpt/deploy/serve.py
@@ -31,11 +31,13 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ) -> None:
         if not _LITSERVE_AVAILABLE:
             raise ImportError(str(_LITSERVE_AVAILABLE))

-        super().__init__()
+        super().__init__(api_path=api_path)
+
         self.checkpoint_dir = checkpoint_dir
         self.quantize = quantize
         self.precision = precision
@@ -61,12 +63,11 @@ def setup(self, device: str) -> None:
             accelerator=accelerator,
             quantize=self.quantize,
             precision=self.precision,
-            generate_strategy="sequential" if self.devices is not None and self.devices > 1 else None,
+            generate_strategy=("sequential" if self.devices is not None and self.devices > 1 else None),
         )
         print("Model successfully initialized.", file=sys.stderr)

     def decode_request(self, request: Dict[str, Any]) -> Any:
-        # Convert the request payload to your model input.
         prompt = str(request["prompt"])
         return prompt

@@ -82,15 +83,30 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )

     def setup(self, device: str):
         super().setup(device)

     def predict(self, inputs: str) -> Any:
         output = self.llm.generate(
-            inputs, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, max_new_tokens=self.max_new_tokens
+            inputs,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            max_new_tokens=self.max_new_tokens,
         )
         return output

@@ -110,14 +126,24 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )

     def setup(self, device: str):
         super().setup(device)

     def predict(self, inputs: torch.Tensor) -> Any:
-        # Run the model on the input and return the output.
         yield from self.llm.generate(
             inputs,
             temperature=self.temperature,
@@ -143,8 +169,19 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )

     def setup(self, device: str):
         super().setup(device)
@@ -178,7 +215,12 @@ def predict(self, inputs: str, context: dict) -> Any:

         # Run the model on the input and return the output.
         yield from self.llm.generate(
-            inputs, temperature=temperature, top_k=self.top_k, top_p=top_p, max_new_tokens=max_new_tokens, stream=True
+            inputs,
+            temperature=temperature,
+            top_k=self.top_k,
+            top_p=top_p,
+            max_new_tokens=max_new_tokens,
+            stream=True,
         )


@@ -196,6 +238,7 @@ def run_server(
     stream: bool = False,
     openai_spec: bool = False,
     access_token: Optional[str] = None,
+    api_path: Optional[str] = "/predict",
 ) -> None:
     """Serve a LitGPT model using LitServe.

@@ -237,11 +280,13 @@ def run_server(
             `/v1/chat/completions` endpoints that work with the OpenAI SDK and other OpenAI-compatible clients,
             making it easy to integrate with existing applications that use the OpenAI API.
         access_token: Optional API token to access models with restrictions.
+        api_path: The custom API path for the endpoint (e.g., "/my_api/classify").
     """
     checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
     pprint(locals())

     api_class = OpenAISpecLitAPI if openai_spec else StreamLitAPI if stream else SimpleLitAPI
+
     server = LitServer(
         api_class(
             checkpoint_dir=checkpoint_dir,
@@ -252,6 +297,7 @@ def run_server(
             top_p=top_p,
             max_new_tokens=max_new_tokens,
             devices=devices,
+            api_path=api_path,
         ),
         spec=OpenAISpec() if openai_spec else None,
         accelerator=accelerator,
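
Usage note: with api_path, the endpoint can be mounted at a custom route instead of the default "/predict". A minimal sketch of the new flow (the model name, port, and prompt below are illustrative assumptions, not part of this patch):

    from litgpt.deploy.serve import run_server

    # Serve a checkpoint on a custom route; run_server() blocks while serving.
    run_server(
        checkpoint_dir="microsoft/phi-2",  # hypothetical checkpoint; auto-downloaded if missing
        api_path="/my_api/classify",
    )

A request then targets the custom route (assuming LitServe's default port, 8000); decode_request() above reads the "prompt" field from the JSON body:

    import requests

    # POST to the custom path configured via api_path.
    response = requests.post("http://127.0.0.1:8000/my_api/classify", json={"prompt": "Hello"})
    print(response.json())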