feat(serve.py): add api_path parameter to cli options to allow custom API endpoint configuration #2080

Open · wants to merge 6 commits into base: main
64 changes: 55 additions & 9 deletions litgpt/deploy/serve.py
@@ -31,11 +31,13 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ) -> None:
         if not _LITSERVE_AVAILABLE:
             raise ImportError(str(_LITSERVE_AVAILABLE))

-        super().__init__()
+        super().__init__(api_path=api_path)
+
         self.checkpoint_dir = checkpoint_dir
         self.quantize = quantize
         self.precision = precision
@@ -61,12 +63,11 @@ def setup(self, device: str) -> None:
             accelerator=accelerator,
             quantize=self.quantize,
             precision=self.precision,
-            generate_strategy="sequential" if self.devices is not None and self.devices > 1 else None,
+            generate_strategy=("sequential" if self.devices is not None and self.devices > 1 else None),
         )
         print("Model successfully initialized.", file=sys.stderr)

     def decode_request(self, request: Dict[str, Any]) -> Any:
-        # Convert the request payload to your model input.
         prompt = str(request["prompt"])
         return prompt

@@ -82,15 +83,30 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )

     def setup(self, device: str):
         super().setup(device)

     def predict(self, inputs: str) -> Any:
         output = self.llm.generate(
-            inputs, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, max_new_tokens=self.max_new_tokens
+            inputs,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            max_new_tokens=self.max_new_tokens,
         )
         return output

@@ -110,14 +126,24 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )

     def setup(self, device: str):
         super().setup(device)

     def predict(self, inputs: torch.Tensor) -> Any:
-        # Run the model on the input and return the output.
         yield from self.llm.generate(
             inputs,
             temperature=self.temperature,
@@ -143,8 +169,19 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )

     def setup(self, device: str):
         super().setup(device)
@@ -178,7 +215,12 @@ def predict(self, inputs: str, context: dict) -> Any:

         # Run the model on the input and return the output.
         yield from self.llm.generate(
-            inputs, temperature=temperature, top_k=self.top_k, top_p=top_p, max_new_tokens=max_new_tokens, stream=True
+            inputs,
+            temperature=temperature,
+            top_k=self.top_k,
+            top_p=top_p,
+            max_new_tokens=max_new_tokens,
+            stream=True,
         )


@@ -196,6 +238,7 @@ def run_server(
     stream: bool = False,
     openai_spec: bool = False,
     access_token: Optional[str] = None,
+    api_path: Optional[str] = "/predict",
 ) -> None:
     """Serve a LitGPT model using LitServe.

@@ -237,11 +280,13 @@
             `/v1/chat/completions` endpoints that work with the OpenAI SDK and other OpenAI-compatible clients,
             making it easy to integrate with existing applications that use the OpenAI API.
         access_token: Optional API token to access models with restrictions.
+        api_path: The custom API path for the endpoint (e.g., "/my_api/classify").
     """
     checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
     pprint(locals())

     api_class = OpenAISpecLitAPI if openai_spec else StreamLitAPI if stream else SimpleLitAPI
+
     server = LitServer(
         api_class(
             checkpoint_dir=checkpoint_dir,
@@ -252,6 +297,7 @@
             top_p=top_p,
             max_new_tokens=max_new_tokens,
             devices=devices,
+            api_path=api_path,
         ),
         spec=OpenAISpec() if openai_spec else None,
         accelerator=accelerator,
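A quick usage sketch of the new option (not part of the diff): the snippets below assume the `api_path` flag is exposed by `litgpt serve` and `run_server` exactly as added above, that the server listens on LitServe's default port 8000, and that the checkpoint name and prompt are purely illustrative.

# Server side: roughly equivalent to `litgpt serve microsoft/phi-2 --api_path "/my_api/classify"`,
# assuming the CLI flag maps 1:1 to the run_server parameter added in this PR.
from litgpt.deploy.serve import run_server

run_server(
    checkpoint_dir="microsoft/phi-2",  # illustrative model; auto_download_checkpoint resolves the name
    api_path="/my_api/classify",       # custom endpoint instead of the default "/predict"
)

The request payload is unchanged, since decode_request still reads the "prompt" field; only the URL path moves off "/predict":

import requests

response = requests.post(
    "http://127.0.0.1:8000/my_api/classify",  # assumes the LitServe default port 8000
    json={"prompt": "What do llamas eat?"},
    timeout=60,
)
print(response.json())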