Skip to content

Commit b189e18

Browse files
committed
Add CAII endpoint description retrieval and update related models
1 parent 4bcbddd commit b189e18

File tree

5 files changed

+68
-22
lines changed

5 files changed

+68
-22
lines changed

llm-service/app/routers/index/models/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141

4242
from .... import exceptions
4343
from ....services import models
44-
from ....services.caii.types import ModelResponse
44+
from ....services.caii.caii import describe_endpoint, build_model_response
45+
from ....services.caii.types import ModelResponse, Endpoint
4546

4647
router = APIRouter(prefix="/models", tags=["Models"])
4748

@@ -72,6 +73,17 @@ def get_model() -> models.ModelSource:
7273
return models.get_model_source()
7374

7475

76+
@router.get(path="/caii/endpoint/{endpoint_name}", summary="Get CAII endpoint details.")
77+
@exceptions.propagates
78+
def get_endpoint_description(endpoint_name: str) -> ModelResponse:
79+
"""
80+
Get the details of a specific CAII endpoint by its name.
81+
"""
82+
endpoint = describe_endpoint(endpoint_name)
83+
model_response = build_model_response(endpoint)
84+
return model_response
85+
86+
7587
@router.get("/llm/{model_name}/test", summary="Test LLM Inference model.")
7688
@exceptions.propagates
7789
def llm_model_test(model_name: str) -> Literal["ok"]:

llm-service/app/services/caii/CaiiEmbeddingModel.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,23 @@
4242
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
4343
from pydantic import Field
4444

45-
from .types import Endpoint
45+
from .types import Endpoint, ListEndpointEntry
4646
from .utils import build_auth_headers
4747

4848

4949
class CaiiEmbeddingModel(BaseEmbedding):
50-
endpoint: Endpoint = Field(
51-
Endpoint, description="The endpoint to use for embeddings"
50+
endpoint: ListEndpointEntry = Field(
51+
ListEndpointEntry, description="The endpoint to use for embeddings"
52+
)
53+
http_client: httpx.Client = Field(
54+
httpx.Client, description="The http client to use for requests"
5255
)
53-
http_client: httpx.Client = Field(httpx.Client, description="The http client to use for requests")
5456

55-
def __init__(self, endpoint: Endpoint, http_client: httpx.Client | None = httpx.Client()):
57+
def __init__(
58+
self,
59+
endpoint: ListEndpointEntry,
60+
http_client: httpx.Client | None = httpx.Client(),
61+
):
5662
super().__init__()
5763
self.endpoint = endpoint
5864
self.http_client = http_client or httpx.Client()
@@ -86,7 +92,9 @@ def _get_embedding(self, query: str, input_type: str) -> Embedding:
8692
def make_embedding_request(self, body: str) -> Any:
8793
headers = build_auth_headers()
8894
headers["Content-Type"] = "application/json"
89-
response = self.http_client.post(url=self.endpoint.url, content=body, headers=headers)
95+
response = self.http_client.post(
96+
url=self.endpoint.url, content=body, headers=headers
97+
)
9098
res = response.content
9199
json_response = res.decode("utf-8")
92100
structured_response = json.loads(json_response)

llm-service/app/services/caii/caii.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,37 @@
6262
logger = logging.getLogger(__name__)
6363

6464

65-
def describe_endpoint(
65+
def describe_endpoint_entry(
6666
endpoint_name: str, endpoints: Optional[List[ListEndpointEntry]] = None
67-
) -> Endpoint:
67+
) -> ListEndpointEntry:
6868
if not endpoints:
69-
logger.info("Fetching endpoint details from CAII REST API")
70-
domain = settings.caii_domain
71-
headers = build_auth_headers()
72-
describe_url = f"https://{domain}/api/v1alpha1/describeEndpoint"
73-
desc_json = {"name": endpoint_name, "namespace": DEFAULT_NAMESPACE}
74-
75-
response = requests.post(describe_url, headers=headers, json=desc_json)
76-
raise_for_http_error(response)
77-
return Endpoint(**body_to_json(response))
69+
endpoint = describe_endpoint(endpoint_name)
70+
return ListEndpointEntry(**endpoint.model_dump())
7871

7972
logger.info("Fetching endpoint details from cached list of endpoints")
8073
for endpoint in endpoints:
8174
if endpoint.name == endpoint_name:
82-
return Endpoint(**endpoint.model_dump())
75+
return ListEndpointEntry(**endpoint.model_dump())
8376

8477
raise HTTPException(
8578
status_code=404, detail=f"Endpoint '{endpoint_name}' not found."
8679
)
8780

8881

82+
def describe_endpoint(
83+
endpoint_name: str,
84+
) -> Endpoint:
85+
logger.info("Fetching endpoint details from CAII REST API")
86+
domain = settings.caii_domain
87+
headers = build_auth_headers()
88+
describe_url = f"https://{domain}/api/v1alpha1/describeEndpoint"
89+
desc_json = {"name": endpoint_name, "namespace": DEFAULT_NAMESPACE}
90+
91+
response = requests.post(describe_url, headers=headers, json=desc_json)
92+
raise_for_http_error(response)
93+
return Endpoint(**body_to_json(response))
94+
95+
8996
def list_endpoints() -> list[ListEndpointEntry]:
9097
try:
9198
import cmlapi
@@ -123,7 +130,7 @@ def list_endpoints() -> list[ListEndpointEntry]:
123130

124131

125132
def get_reranking_model(model_name: str, top_n: int) -> BaseNodePostprocessor:
126-
endpoint = describe_endpoint(endpoint_name=model_name)
133+
endpoint = describe_endpoint_entry(endpoint_name=model_name)
127134
token = get_caii_access_token()
128135
return CaiiRerankingModel(
129136
model=endpoint.model_name,
@@ -167,7 +174,7 @@ def get_llm(
167174

168175
def get_embedding_model(model_name: str) -> BaseEmbedding:
169176
endpoint_name = model_name
170-
endpoint = describe_endpoint(endpoint_name=endpoint_name)
177+
endpoint = describe_endpoint_entry(endpoint_name=endpoint_name)
171178

172179
if os.path.exists("/etc/ssl/certs/ca-certificates.crt"):
173180
http_client = httpx.Client(verify="/etc/ssl/certs/ca-certificates.crt")
@@ -223,7 +230,10 @@ def get_caii_embedding_models() -> List[ModelResponse]:
223230
def get_models_with_task(task_type: str) -> List[Endpoint]:
224231
endpoints = list_endpoints()
225232
endpoint_details = list(
226-
map(lambda endpoint: describe_endpoint(endpoint.name, endpoints), endpoints)
233+
map(
234+
lambda endpoint: describe_endpoint_entry(endpoint.name, endpoints),
235+
endpoints,
236+
)
227237
)
228238
llm_endpoints = list(
229239
filter(

ui/src/api/modelsApi.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,21 @@ const getModelSource = async (): Promise<ModelSource> => {
147147
return await getRequest(`${llmServicePath}/models/model_source`);
148148
};
149149

150+
export const useGetCAIIModelStatus = (endpoint_name: string) => {
151+
return useQuery({
152+
queryKey: [QueryKeys.getCAIIModelStatus, { endpoint_name }],
153+
queryFn: async () => {
154+
return await getCAIIModelStatus(endpoint_name);
155+
},
156+
});
157+
};
158+
159+
const getCAIIModelStatus = async (endpoint_name: string): Promise<Model> => {
160+
return await getRequest(
161+
`${llmServicePath}/models/caii/endpoint/${endpoint_name}`,
162+
);
163+
};
164+
150165
export const useTestLlmModel = ({
151166
onSuccess,
152167
onError,

ui/src/api/utils.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ export enum QueryKeys {
114114
"getAmpConfig" = "getAmpConfig",
115115
"getTools" = "getTools",
116116
"getPollingAmpConfig" = "getPollingAmpConfig",
117+
"getCAIIModelStatus" = "getCAIIModelStatus",
117118
}
118119

119120
export const commonHeaders = {

0 commit comments

Comments
 (0)