Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions litellm/llms/vertex_ai/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,16 @@ def get_vertex_location_from_url(url: str) -> Optional[str]:
return match.group(1) if match else None


def get_vertex_model_id_from_url(url: str) -> Optional[str]:
    """
    Extract the model id segment from a Vertex AI request url.

    Matches the ``.../models/${MODEL_ID}`` portion of urls such as:
    `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`

    Returns None when the url contains no ``/models/`` segment.
    """
    # Stop at "/" or ":" so a trailing action suffix (":streamGenerateContent")
    # is not captured as part of the model id.
    model_match = re.search(r"/models/([^/:]+)", url)
    if model_match is None:
        return None
    return model_match.group(1)


def replace_project_and_location_in_route(
requested_route: str, vertex_project: str, vertex_location: str
) -> str:
Expand Down Expand Up @@ -696,6 +706,15 @@ def construct_target_url(
if "cachedContent" in requested_route:
vertex_version = "v1beta1"

# Check if the requested route starts with a version
# e.g. /v1beta1/publishers/google/models/gemini-3-pro-preview:streamGenerateContent
if requested_route.startswith("/v1/"):
vertex_version = "v1"
requested_route = requested_route.replace("/v1/", "/", 1)
elif requested_route.startswith("/v1beta1/"):
vertex_version = "v1beta1"
requested_route = requested_route.replace("/v1beta1/", "/", 1)

base_requested_route = "{}/projects/{}/locations/{}".format(
vertex_version, vertex_project, vertex_location
)
Expand Down
16 changes: 16 additions & 0 deletions litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,7 @@ async def _base_vertex_proxy_route(
from litellm.llms.vertex_ai.common_utils import (
construct_target_url,
get_vertex_location_from_url,
get_vertex_model_id_from_url,
get_vertex_project_id_from_url,
)

Expand Down Expand Up @@ -1576,6 +1577,21 @@ async def _base_vertex_proxy_route(
vertex_location=vertex_location,
)

if vertex_project is None or vertex_location is None:
# Check if model is in router config
model_id = get_vertex_model_id_from_url(endpoint)
if model_id:
from litellm.proxy.proxy_server import llm_router

if llm_router:
deployments = llm_router.get_model_list(model_name=model_id)
if deployments:
deployment = deployments[0]
litellm_params = deployment.get("litellm_params", {})
if litellm_params.get("use_in_pass_through"):
vertex_project = litellm_params.get("vertex_project")
vertex_location = litellm_params.get("vertex_location")

vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
Expand Down
57 changes: 50 additions & 7 deletions tests/test_litellm/llms/vertex_ai/test_vertex_ai_common_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import sys
from typing import Any, Dict
from unittest.mock import MagicMock, call, patch
from unittest.mock import patch

import pytest

Expand All @@ -11,7 +10,6 @@
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path

import litellm
from litellm.llms.vertex_ai.common_utils import (
_get_vertex_url,
convert_anyof_null_to_nullable,
Expand Down Expand Up @@ -798,9 +796,54 @@ def test_fix_enum_empty_strings():
assert "mobile" in enum_values
assert "tablet" in enum_values

# 3. Other properties preserved
assert input_schema["properties"]["user_agent_type"]["type"] == "string"
assert input_schema["properties"]["user_agent_type"]["description"] == "Device type for user agent"

def test_get_vertex_model_id_from_url():
    """Test get_vertex_model_id_from_url with various URLs"""
    from litellm.llms.vertex_ai.common_utils import get_vertex_model_id_from_url

    # A well-formed Vertex AI generateContent url yields the model segment.
    valid_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    assert get_vertex_model_id_from_url(valid_url) == "gemini-pro"

    # A url with no /models/ segment yields None.
    assert get_vertex_model_id_from_url("https://invalid-url.com") is None


def test_construct_target_url_with_version_prefix():
    """Test construct_target_url with version prefixes"""
    from litellm.llms.vertex_ai.common_utils import construct_target_url

    vertex_project = "test-project"
    vertex_location = "us-central1"
    base_url = "https://us-central1-aiplatform.googleapis.com"

    # Each requested route carries an explicit API version prefix that
    # construct_target_url must honor when building the final url.
    cases = [
        (
            "/v1/publishers/google/models/gemini-pro:streamGenerateContent",
            "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent",
        ),
        (
            "/v1beta1/publishers/google/models/gemini-pro:streamGenerateContent",
            "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent",
        ),
    ]

    for requested_route, expected_url in cases:
        target_url = construct_target_url(
            base_url=base_url,
            requested_route=requested_route,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
        )
        assert str(target_url) == expected_url


def test_fix_enum_types():
Expand Down Expand Up @@ -862,7 +905,7 @@ def test_fix_enum_types():
"truncateMode": {
"enum": ["auto", "none", "start", "end"], # Kept - string type
"type": "string",
"description": "How to truncate content"
"description": "How to truncate content",
},
"maxLength": { # enum removed
"type": "integer",
Expand Down
Loading