|
62 | 62 | logger = logging.getLogger(__name__)
|
63 | 63 |
|
64 | 64 |
|
65 |
| -def describe_endpoint( |
| 65 | +def describe_endpoint_entry( |
66 | 66 | endpoint_name: str, endpoints: Optional[List[ListEndpointEntry]] = None
|
67 |
| -) -> Endpoint: |
| 67 | +) -> ListEndpointEntry: |
68 | 68 | if not endpoints:
|
69 |
| - logger.info("Fetching endpoint details from CAII REST API") |
70 |
| - domain = settings.caii_domain |
71 |
| - headers = build_auth_headers() |
72 |
| - describe_url = f"https://{domain}/api/v1alpha1/describeEndpoint" |
73 |
| - desc_json = {"name": endpoint_name, "namespace": DEFAULT_NAMESPACE} |
74 |
| - |
75 |
| - response = requests.post(describe_url, headers=headers, json=desc_json) |
76 |
| - raise_for_http_error(response) |
77 |
| - return Endpoint(**body_to_json(response)) |
| 69 | + endpoint = describe_endpoint(endpoint_name) |
| 70 | + return ListEndpointEntry(**endpoint.model_dump()) |
78 | 71 |
|
79 | 72 | logger.info("Fetching endpoint details from cached list of endpoints")
|
80 | 73 | for endpoint in endpoints:
|
81 | 74 | if endpoint.name == endpoint_name:
|
82 |
| - return Endpoint(**endpoint.model_dump()) |
| 75 | + return ListEndpointEntry(**endpoint.model_dump()) |
83 | 76 |
|
84 | 77 | raise HTTPException(
|
85 | 78 | status_code=404, detail=f"Endpoint '{endpoint_name}' not found."
|
86 | 79 | )
|
87 | 80 |
|
88 | 81 |
|
| 82 | +def describe_endpoint( |
| 83 | + endpoint_name: str, |
| 84 | +) -> Endpoint: |
| 85 | + logger.info("Fetching endpoint details from CAII REST API") |
| 86 | + domain = settings.caii_domain |
| 87 | + headers = build_auth_headers() |
| 88 | + describe_url = f"https://{domain}/api/v1alpha1/describeEndpoint" |
| 89 | + desc_json = {"name": endpoint_name, "namespace": DEFAULT_NAMESPACE} |
| 90 | + |
| 91 | + response = requests.post(describe_url, headers=headers, json=desc_json) |
| 92 | + raise_for_http_error(response) |
| 93 | + return Endpoint(**body_to_json(response)) |
| 94 | + |
| 95 | + |
89 | 96 | def list_endpoints() -> list[ListEndpointEntry]:
|
90 | 97 | try:
|
91 | 98 | import cmlapi
|
@@ -123,7 +130,7 @@ def list_endpoints() -> list[ListEndpointEntry]:
|
123 | 130 |
|
124 | 131 |
|
125 | 132 | def get_reranking_model(model_name: str, top_n: int) -> BaseNodePostprocessor:
|
126 |
| - endpoint = describe_endpoint(endpoint_name=model_name) |
| 133 | + endpoint = describe_endpoint_entry(endpoint_name=model_name) |
127 | 134 | token = get_caii_access_token()
|
128 | 135 | return CaiiRerankingModel(
|
129 | 136 | model=endpoint.model_name,
|
@@ -167,7 +174,7 @@ def get_llm(
|
167 | 174 |
|
168 | 175 | def get_embedding_model(model_name: str) -> BaseEmbedding:
|
169 | 176 | endpoint_name = model_name
|
170 |
| - endpoint = describe_endpoint(endpoint_name=endpoint_name) |
| 177 | + endpoint = describe_endpoint_entry(endpoint_name=endpoint_name) |
171 | 178 |
|
172 | 179 | if os.path.exists("/etc/ssl/certs/ca-certificates.crt"):
|
173 | 180 | http_client = httpx.Client(verify="/etc/ssl/certs/ca-certificates.crt")
|
@@ -223,7 +230,10 @@ def get_caii_embedding_models() -> List[ModelResponse]:
|
223 | 230 | def get_models_with_task(task_type: str) -> List[Endpoint]:
|
224 | 231 | endpoints = list_endpoints()
|
225 | 232 | endpoint_details = list(
|
226 |
| - map(lambda endpoint: describe_endpoint(endpoint.name, endpoints), endpoints) |
| 233 | + map( |
| 234 | + lambda endpoint: describe_endpoint_entry(endpoint.name, endpoints), |
| 235 | + endpoints, |
| 236 | + ) |
227 | 237 | )
|
228 | 238 | llm_endpoints = list(
|
229 | 239 | filter(
|
|
0 commit comments