From 38b56ab27d0209b767a82f76e9760d862043179c Mon Sep 17 00:00:00 2001 From: KUNJ1311 Date: Wed, 19 Nov 2025 11:27:26 +0530 Subject: [PATCH 1/4] feat: filter /v1/models endpoint by virtual key provider_configs --- docs/features/governance/virtual-keys.mdx | 30 +- tests/governance/README.md | 14 + tests/governance/test_virtual_keys_access.py | 479 ++++++++++++++++++ transports/bifrost-http/handlers/inference.go | 82 ++- transports/bifrost-http/server/server.go | 10 +- 5 files changed, 605 insertions(+), 10 deletions(-) create mode 100644 tests/governance/test_virtual_keys_access.py diff --git a/docs/features/governance/virtual-keys.mdx b/docs/features/governance/virtual-keys.mdx index bee781f83..135ef24e0 100644 --- a/docs/features/governance/virtual-keys.mdx +++ b/docs/features/governance/virtual-keys.mdx @@ -16,7 +16,7 @@ Virtual Keys are the primary governance entity in Bifrost. Users and application You can also use `Authorization` and `x-api-key` headers to pass direct keys to the provider. Read more about it in [Direct Key Bypass](../keys-management#direct-key-bypass). **Key Features:** -- **Access Control** - Model and provider filtering +- **Access Control** - Model and provider filtering (applies to `/v1/models` endpoint and inference requests) - **Cost Management** - Independent budgets (checked along with team/customer budgets if attached) - **Rate Limiting** - Token and request-based throttling (VK-level only) - **Key Restrictions** - Limit VK to specific provider API keys (if configured, VK can only use those keys) @@ -499,6 +499,34 @@ curl -X POST http://localhost:8080/v1/chat/completions \ }' ``` +The virtual key can also filter which models are returned by the `/v1/models` endpoint: + +```bash +# List models allowed for a specific virtual key +curl -X GET http://localhost:8080/v1/models \ + -H "x-bf-vk: vk-engineering-main" +``` + +**How it works:** +- **With `x-bf-vk` header**: Returns only models specified in the virtual key's `provider_configs` +- **Without `x-bf-vk` header**: Returns all available models +- **Empty `provider_configs`**: Returns all models (no filtering) +- **Empty `allowed_models` for a provider**: Returns all models from that provider + +**Example:** +```json +{ + "provider_configs": [ + { + "provider": "openai", + "allowed_models": ["gpt-4o", "gpt-4o-mini"], + "weight": 1.0 + } + ] +} +// GET /v1/models with x-bf-vk header returns only openai/gpt-4o and openai/gpt-4o-mini +``` + By default governance is optional, meaning that if the `x-bf-vk` header is not present, the request will be allowed but without any governance checks/routing. But you can make it mandatory by enforcing the governance header. diff --git a/tests/governance/README.md b/tests/governance/README.md index 1cbc0f988..d619d57aa 100644 --- a/tests/governance/README.md +++ b/tests/governance/README.md @@ -48,6 +48,14 @@ This test suite provides extensive coverage of the Bifrost governance system inc - Reset functionality - Debug and health endpoints +5. **`test_virtual_keys_access.py`** - Virtual Key Model Filtering + - `/v1/models` endpoint filtering based on virtual key `provider_configs` + - Specific models via `allowed_models` list + - Empty `allowed_models` returns all models from that provider + - Empty `provider_configs` returns all models + - Multiple provider configurations + - Invalid and inactive virtual key handling + ### Configuration Files - **`conftest.py`** - Test fixtures, utilities, and configuration @@ -162,6 +170,7 @@ The test suite uses pytest markers for categorization: - `@pytest.mark.concurrency` - Concurrency tests - `@pytest.mark.slow` - Slow running tests (>5s) - `@pytest.mark.smoke` - Quick smoke tests +- `@pytest.mark.access_control` - Access control and filtering tests ## API Endpoints Tested @@ -195,6 +204,11 @@ The test suite uses pytest markers for categorization: ### Integration Endpoints - `POST /v1/chat/completions` - Chat completion with governance headers +- `GET /v1/models` - List available models with optional virtual key filtering + +#### Model Filtering (`/v1/models`) +- **With `x-bf-vk` header**: Returns models specified in the virtual key's `provider_configs` +- **Without `x-bf-vk` header**: Returns all available models ## Test Data and Schemas diff --git a/tests/governance/test_virtual_keys_access.py b/tests/governance/test_virtual_keys_access.py new file mode 100644 index 000000000..e3d0d8489 --- /dev/null +++ b/tests/governance/test_virtual_keys_access.py @@ -0,0 +1,479 @@ +""" +Tests for /v1/models endpoint filtering based on virtual key provider_configs. + +Tests cover: +- Filtering by specific allowed_models +- Empty allowed_models (all models from provider) +- Empty provider_configs (all models) +- Non-existent providers +- Multiple providers +- Invalid and inactive virtual keys +""" + +import pytest +import requests +from conftest import BIFROST_BASE_URL, assert_response_success, generate_unique_name + + +class TestModelsFiltering: + """Test /v1/models endpoint filtering based on virtual key provider configs""" + + @pytest.mark.virtual_keys + @pytest.mark.integration + @pytest.mark.smoke + def test_list_models_with_vk_filters_by_provider_config( + self, governance_client, cleanup_tracker + ): + """Test that /v1/models filters models when VK has provider_configs with specific allowed_models""" + response = requests.get(f"{BIFROST_BASE_URL}/v1/models") + assert_response_success(response, 200) + all_models = response.json().get("data", []) + + providers_map = {} + for model in all_models: + if "/" in model["id"]: + provider, model_name = model["id"].split("/", 1) + if provider not in providers_map: + providers_map[provider] = [] + providers_map[provider].append(model_name) + + if not providers_map or not any( + len(models) >= 2 for models in providers_map.values() + ): + pytest.skip("Need at least one provider with 2+ models") + + test_provider = next( + p for p, models in providers_map.items() if len(models) >= 2 + ) + selected_models = providers_map[test_provider][:2] + + vk_data = { + "name": generate_unique_name("Test VK Specific Models"), + "provider_configs": [ + { + "provider": test_provider, + "allowed_models": selected_models, + "weight": 1.0, + } + ], + } + response = governance_client.create_virtual_key(vk_data) + assert_response_success(response, 200) + vk = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + assert vk["provider_configs"] is not None + assert len(vk["provider_configs"]) == 1 + assert vk["provider_configs"][0]["provider"] == test_provider + assert set(vk["provider_configs"][0]["allowed_models"]) == set(selected_models) + + response = requests.get( + f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]} + ) + assert_response_success(response, 200) + + filtered_models = response.json().get("data", []) + filtered_ids = [m["id"] for m in filtered_models] + expected_ids = [f"{test_provider}/{m}" for m in selected_models] + + assert len(filtered_models) == len(selected_models), ( + f"Expected {len(selected_models)} models, got {len(filtered_models)}" + ) + assert set(filtered_ids) == set(expected_ids), ( + f"Model IDs mismatch. Expected: {expected_ids}, Got: {filtered_ids}" + ) + + provider_models = [ + m["id"] for m in all_models if m["id"].startswith(f"{test_provider}/") + ] + excluded = [m for m in provider_models if m not in expected_ids] + for excluded_id in excluded: + assert excluded_id not in filtered_ids + + for model in filtered_models: + assert "id" in model + assert "created" in model + + @pytest.mark.virtual_keys + @pytest.mark.integration + def test_list_models_with_empty_allowed_models_returns_all_provider_models( + self, governance_client, cleanup_tracker + ): + """Test that empty allowed_models returns all models from that provider""" + response = requests.get(f"{BIFROST_BASE_URL}/v1/models") + assert_response_success(response, 200) + all_models = response.json().get("data", []) + + providers = list( + set([m["id"].split("/")[0] for m in all_models if "/" in m["id"]]) + ) + if not providers: + pytest.skip("No providers found") + + test_provider = providers[0] + provider_model_ids = [ + m["id"] for m in all_models if m["id"].startswith(f"{test_provider}/") + ] + + vk_data = { + "name": generate_unique_name("Test VK All Provider Models"), + "provider_configs": [ + { + "provider": test_provider, + "allowed_models": [], + "weight": 1.0, + } + ], + } + response = governance_client.create_virtual_key(vk_data) + assert_response_success(response, 200) + vk = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + assert vk["provider_configs"] is not None + assert len(vk["provider_configs"]) == 1 + assert vk["provider_configs"][0]["provider"] == test_provider + assert vk["provider_configs"][0]["allowed_models"] == [] + + response = requests.get( + f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]} + ) + assert_response_success(response, 200) + + filtered_models = response.json().get("data", []) + filtered_ids = [m["id"] for m in filtered_models] + + assert len(filtered_models) == len(provider_model_ids), ( + f"Expected {len(provider_model_ids)} models, got {len(filtered_models)}" + ) + assert set(filtered_ids) == set(provider_model_ids) + + other_provider_models = [ + m["id"] for m in all_models if not m["id"].startswith(f"{test_provider}/") + ] + for other_model_id in other_provider_models: + assert other_model_id not in filtered_ids + + @pytest.mark.virtual_keys + @pytest.mark.integration + def test_list_models_without_provider_configs_returns_all( + self, governance_client, cleanup_tracker + ): + """Test that VK with no provider_configs returns all models""" + response = requests.get(f"{BIFROST_BASE_URL}/v1/models") + assert_response_success(response, 200) + all_models = response.json().get("data", []) + all_model_ids = [m["id"] for m in all_models] + + if len(all_models) == 0: + pytest.skip("No models available") + + vk_data = { + "name": generate_unique_name("Test VK Unrestricted"), + "provider_configs": [], + } + response = governance_client.create_virtual_key(vk_data) + assert_response_success(response, 200) + vk = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + assert vk["provider_configs"] is not None + assert len(vk["provider_configs"]) == 0 + + response = requests.get( + f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]} + ) + assert_response_success(response, 200) + + filtered_models = response.json().get("data", []) + filtered_ids = [m["id"] for m in filtered_models] + + assert len(filtered_models) == len(all_models), ( + f"Expected {len(all_models)} models, got {len(filtered_models)}" + ) + assert set(filtered_ids) == set(all_model_ids) + + @pytest.mark.virtual_keys + @pytest.mark.integration + def test_list_models_with_nonexistent_provider_returns_empty( + self, governance_client, cleanup_tracker + ): + """Test that VK with non-existent provider returns empty list""" + vk_data = { + "name": generate_unique_name("Test VK Nonexistent Provider"), + "provider_configs": [ + { + "provider": "nonexistent-provider-xyz-123", + "allowed_models": [], + "weight": 1.0, + } + ], + } + response = governance_client.create_virtual_key(vk_data) + assert_response_success(response, 200) + vk = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + assert vk["provider_configs"] is not None + assert len(vk["provider_configs"]) == 1 + assert vk["provider_configs"][0]["provider"] == "nonexistent-provider-xyz-123" + + response = requests.get( + f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]} + ) + assert_response_success(response, 200) + + filtered_models = response.json().get("data", []) + assert len(filtered_models) == 0 + + @pytest.mark.virtual_keys + @pytest.mark.integration + def test_list_models_with_multiple_providers( + self, governance_client, cleanup_tracker + ): + """Test that VK with multiple providers returns models from all configured providers""" + response = requests.get(f"{BIFROST_BASE_URL}/v1/models") + assert_response_success(response, 200) + all_models = response.json().get("data", []) + + available_providers = list( + set([m["id"].split("/")[0] for m in all_models if "/" in m["id"]]) + ) + + if len(available_providers) < 2: + pytest.skip(f"Need at least 2 providers, have: {available_providers}") + + provider1 = available_providers[0] + provider2 = available_providers[1] + + provider1_models = [ + m["id"].split("/")[1] + for m in all_models + if m["id"].startswith(f"{provider1}/") + ] + provider2_models = [ + m["id"].split("/")[1] + for m in all_models + if m["id"].startswith(f"{provider2}/") + ] + + if len(provider1_models) == 0 or len(provider2_models) == 0: + pytest.skip("Each provider needs at least one model") + + selected_model1 = provider1_models[0] + selected_model2 = provider2_models[0] + + vk_data = { + "name": generate_unique_name("Test VK Multi Provider"), + "provider_configs": [ + { + "provider": provider1, + "allowed_models": [selected_model1], + "weight": 1.0, + }, + { + "provider": provider2, + "allowed_models": [selected_model2], + "weight": 1.0, + }, + ], + } + response = governance_client.create_virtual_key(vk_data) + assert_response_success(response, 200) + vk = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + assert vk["provider_configs"] is not None + assert len(vk["provider_configs"]) == 2 + vk_providers = [pc["provider"] for pc in vk["provider_configs"]] + assert provider1 in vk_providers + assert provider2 in vk_providers + + response = requests.get( + f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]} + ) + assert_response_success(response, 200) + + filtered_models = response.json().get("data", []) + filtered_ids = [m["id"] for m in filtered_models] + expected_ids = [ + f"{provider1}/{selected_model1}", + f"{provider2}/{selected_model2}", + ] + + assert len(filtered_models) == 2, ( + f"Expected 2 models, got {len(filtered_models)}" + ) + assert set(filtered_ids) == set(expected_ids) + + provider1_found = any( + m["id"].startswith(f"{provider1}/") for m in filtered_models + ) + provider2_found = any( + m["id"].startswith(f"{provider2}/") for m in filtered_models + ) + assert provider1_found + assert provider2_found + + @pytest.mark.virtual_keys + @pytest.mark.integration + def test_list_models_without_vk_header_returns_all(self, governance_client): + """Test that requesting models without VK header returns all models""" + response = requests.get(f"{BIFROST_BASE_URL}/v1/models") + assert_response_success(response, 200) + + all_models = response.json().get("data", []) + assert len(all_models) > 0 + + @pytest.mark.virtual_keys + @pytest.mark.integration + def test_list_models_with_invalid_vk_header(self, governance_client): + """Test that invalid VK header returns error or all models without filtering""" + response_without_vk = requests.get(f"{BIFROST_BASE_URL}/v1/models") + assert_response_success(response_without_vk, 200) + all_models = response_without_vk.json().get("data", []) + all_model_ids = set([m["id"] for m in all_models]) + + response = requests.get( + f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": "invalid-vk-value-xyz"} + ) + + assert response.status_code in [200, 401, 403] + + if response.status_code == 200: + models = response.json().get("data", []) + model_ids = set([m["id"] for m in models]) + assert isinstance(models, list) + assert model_ids == all_model_ids, "Invalid VK should not filter models" + + @pytest.mark.virtual_keys + @pytest.mark.integration + def test_list_models_with_inactive_vk(self, governance_client, cleanup_tracker): + """Test that inactive VK is rejected or returns all models without filtering""" + response = requests.get(f"{BIFROST_BASE_URL}/v1/models") + assert_response_success(response, 200) + all_models = response.json().get("data", []) + all_model_ids = set([m["id"] for m in all_models]) + + if len(all_models) == 0: + pytest.skip("No models available") + + vk_data = { + "name": generate_unique_name("Test VK Inactive"), + "is_active": False, + "provider_configs": [ + { + "provider": "test-provider", + "allowed_models": ["test-model"], + "weight": 1.0, + } + ], + } + response = governance_client.create_virtual_key(vk_data) + assert_response_success(response, 200) + vk = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + if vk["is_active"] is True: + update_response = governance_client.update_virtual_key( + vk["id"], {"is_active": False} + ) + if update_response.status_code == 200: + vk = update_response.json()["virtual_key"] + else: + pytest.skip("Cannot deactivate VK") + + assert vk["is_active"] is False + + response = requests.get( + f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]} + ) + + assert response.status_code in [200, 401, 403] + + if response.status_code == 200: + filtered_models = response.json().get("data", []) + filtered_ids = set([m["id"] for m in filtered_models]) + + assert filtered_ids == all_model_ids, ( + f"Inactive VK should not filter models. Expected {len(all_models)} models, " + f"got {len(filtered_models)}" + ) + + @pytest.mark.virtual_keys + @pytest.mark.integration + def test_list_models_with_mixed_provider_configs( + self, governance_client, cleanup_tracker + ): + """Test VK with one provider having specific models and another with all models""" + response = requests.get(f"{BIFROST_BASE_URL}/v1/models") + assert_response_success(response, 200) + all_models = response.json().get("data", []) + + providers = list( + set([m["id"].split("/")[0] for m in all_models if "/" in m["id"]]) + ) + + if len(providers) < 2: + pytest.skip(f"Need at least 2 providers, have: {providers}") + + provider1 = providers[0] + provider2 = providers[1] + + provider1_models = [ + m["id"].split("/")[1] + for m in all_models + if m["id"].startswith(f"{provider1}/") + ] + provider2_all_models = [ + m["id"] for m in all_models if m["id"].startswith(f"{provider2}/") + ] + + if len(provider1_models) < 2 or len(provider2_all_models) == 0: + pytest.skip("Need sufficient models") + + selected_provider1_model = provider1_models[0] + + vk_data = { + "name": generate_unique_name("Test VK Mixed Configs"), + "provider_configs": [ + { + "provider": provider1, + "allowed_models": [selected_provider1_model], + "weight": 1.0, + }, + { + "provider": provider2, + "allowed_models": [], + "weight": 1.0, + }, + ], + } + response = governance_client.create_virtual_key(vk_data) + assert_response_success(response, 200) + vk = response.json()["virtual_key"] + cleanup_tracker.add_virtual_key(vk["id"]) + + response = requests.get( + f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]} + ) + assert_response_success(response, 200) + + filtered_models = response.json().get("data", []) + filtered_ids = [m["id"] for m in filtered_models] + + expected_count = 1 + len(provider2_all_models) + assert len(filtered_models) == expected_count, ( + f"Expected {expected_count} models, got {len(filtered_models)}" + ) + + provider1_filtered = [ + m["id"] for m in filtered_models if m["id"].startswith(f"{provider1}/") + ] + assert len(provider1_filtered) == 1 + assert f"{provider1}/{selected_provider1_model}" in provider1_filtered + + provider2_filtered = [ + m["id"] for m in filtered_models if m["id"].startswith(f"{provider2}/") + ] + assert set(provider2_filtered) == set(provider2_all_models) diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index c029f85d6..4b64f32e4 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -18,23 +18,26 @@ import ( "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) // CompletionHandler manages HTTP requests for completion operations type CompletionHandler struct { - client *bifrost.Bifrost - handlerStore lib.HandlerStore - config *lib.Config + client *bifrost.Bifrost + handlerStore lib.HandlerStore + config *lib.Config + governanceStore *governance.GovernanceStore } // NewInferenceHandler creates a new completion handler instance -func NewInferenceHandler(client *bifrost.Bifrost, config *lib.Config) *CompletionHandler { +func NewInferenceHandler(client *bifrost.Bifrost, config *lib.Config, governanceStore *governance.GovernanceStore) *CompletionHandler { return &CompletionHandler{ - client: client, - handlerStore: config, - config: config, + client: client, + handlerStore: config, + config: config, + governanceStore: governanceStore, } } @@ -255,6 +258,66 @@ func extractExtraParams(data []byte, knownFields map[string]bool) (map[string]in return extraParams, nil } +// FilterModelsByVirtualKey filters models based on virtual key governance rules. +func filterModelsByVirtualKey(governanceStore *governance.GovernanceStore, bifrostCtx *context.Context, models []schemas.Model) []schemas.Model { + if governanceStore == nil || bifrostCtx == nil { + return models + } + + vkValue := (*bifrostCtx).Value(schemas.BifrostContextKeyVirtualKey) + if vkValue == nil { + return models + } + + vkString, ok := vkValue.(string) + if !ok || vkString == "" { + return models + } + + vk, exists := governanceStore.GetVirtualKey(vkString) + if !exists || !vk.IsActive { + return models + } + + filteredModels := []schemas.Model{} + + for _, model := range models { + modelProvider, modelName := schemas.ParseModelString(model.ID, "") + + providerAllowed := len(vk.ProviderConfigs) == 0 + var allowedModelsForProvider []string + + for _, pc := range vk.ProviderConfigs { + if pc.Provider == string(modelProvider) { + providerAllowed = true + allowedModelsForProvider = pc.AllowedModels + break + } + } + + if !providerAllowed { + continue + } + + if len(allowedModelsForProvider) > 0 { + modelAllowed := false + for _, allowedModel := range allowedModelsForProvider { + if allowedModel == modelName || allowedModel == model.ID { + modelAllowed = true + break + } + } + if !modelAllowed { + continue + } + } + + filteredModels = append(filteredModels, model) + } + + return filteredModels +} + const ( // Maximum file size (25MB) MaxFileSize = 25 * 1024 * 1024 @@ -339,6 +402,11 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { return } + // Filter models based on virtual key + if h.config.ClientConfig.EnableGovernance { + resp.Data = filterModelsByVirtualKey(h.governanceStore, bifrostCtx, resp.Data) + } + // Add pricing data to the response if len(resp.Data) > 0 && h.config.PricingManager != nil { for i, modelEntry := range resp.Data { diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 15d478f7a..d52e0ff75 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -842,7 +842,13 @@ func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, name string) error // RegisterInferenceRoutes initializes the routes for the inference handler func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlewares ...lib.BifrostHTTPMiddleware) error { - inferenceHandler := handlers.NewInferenceHandler(s.Client, s.Config) + var governanceStore *governance.GovernanceStore + governancePlugin, _ := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + if governancePlugin != nil { + governanceStore = governancePlugin.GetGovernanceStore() + } + + inferenceHandler := handlers.NewInferenceHandler(s.Client, s.Config, governanceStore) integrationHandler := handlers.NewIntegrationHandler(s.Client, s.Config) integrationHandler.RegisterRoutes(s.Router, middlewares...) inferenceHandler.RegisterRoutes(s.Router, middlewares...) @@ -1077,7 +1083,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { s.Server = &fasthttp.Server{ Handler: handlers.CorsMiddleware(s.Config)(handlers.TransportInterceptorMiddleware(s.Config)(s.Router.Handler)), MaxRequestBodySize: s.Config.ClientConfig.MaxRequestBodySizeMB * 1024 * 1024, - ReadBufferSize: 1024 * 16, // 16kb + ReadBufferSize: 1024 * 16, // 16kb } return nil } From d4e870d1965351abe7d2c7d36f358e3365dd14d0 Mon Sep 17 00:00:00 2001 From: KUNJ1311 Date: Wed, 19 Nov 2025 11:41:42 +0530 Subject: [PATCH 2/4] fix: remove unused parameters and variables in test_virtual_keys_access - Remove unused governance_client fixture from tests that don't need it - Remove unused filtered_ids variable in test_list_models_with_mixed_provider_configs --- tests/governance/test_virtual_keys_access.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/governance/test_virtual_keys_access.py b/tests/governance/test_virtual_keys_access.py index e3d0d8489..981c5926a 100644 --- a/tests/governance/test_virtual_keys_access.py +++ b/tests/governance/test_virtual_keys_access.py @@ -317,7 +317,7 @@ def test_list_models_with_multiple_providers( @pytest.mark.virtual_keys @pytest.mark.integration - def test_list_models_without_vk_header_returns_all(self, governance_client): + def test_list_models_without_vk_header_returns_all(self): """Test that requesting models without VK header returns all models""" response = requests.get(f"{BIFROST_BASE_URL}/v1/models") assert_response_success(response, 200) @@ -327,7 +327,7 @@ def test_list_models_without_vk_header_returns_all(self, governance_client): @pytest.mark.virtual_keys @pytest.mark.integration - def test_list_models_with_invalid_vk_header(self, governance_client): + def test_list_models_with_invalid_vk_header(self): """Test that invalid VK header returns error or all models without filtering""" response_without_vk = requests.get(f"{BIFROST_BASE_URL}/v1/models") assert_response_success(response_without_vk, 200) @@ -460,7 +460,6 @@ def test_list_models_with_mixed_provider_configs( assert_response_success(response, 200) filtered_models = response.json().get("data", []) - filtered_ids = [m["id"] for m in filtered_models] expected_count = 1 + len(provider2_all_models) assert len(filtered_models) == expected_count, ( From 76beab828322412554e7b7dde316b100360f3a80 Mon Sep 17 00:00:00 2001 From: KUNJ1311 Date: Thu, 20 Nov 2025 19:08:33 +0530 Subject: [PATCH 3/4] refactor: optimize /v1/models filtering to fetch only allowed providers --- core/bifrost.go | 49 ++++--- transports/bifrost-http/handlers/inference.go | 128 +++++++++++++----- 2 files changed, 124 insertions(+), 53 deletions(-) diff --git a/core/bifrost.go b/core/bifrost.go index 16b71dde2..d30d97db3 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -297,24 +297,12 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr return response, nil } -// ListAllModels lists all models from all configured providers. -// It accumulates responses from all providers with a limit of 1000 per provider to get all results. -func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +// ListModelsFromProviders fetches models concurrently from specified providers +func (bifrost *Bifrost) ListModelsFromProviders(ctx context.Context, providers []schemas.ModelProvider, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if request == nil { request = &schemas.BifrostListModelsRequest{} } - providerKeys, err := bifrost.GetConfiguredProviders() - if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: &schemas.ErrorField{ - Message: err.Error(), - Error: err, - }, - } - } - startTime := time.Now() // Result structure for collecting provider responses @@ -323,11 +311,11 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr err *schemas.BifrostError } - results := make(chan providerResult, len(providerKeys)) + results := make(chan providerResult, len(providers)) var wg sync.WaitGroup - // Launch concurrent requests for all providers - for _, providerKey := range providerKeys { + // Launch concurrent requests for specified providers + for _, providerKey := range providers { if strings.TrimSpace(string(providerKey)) == "" { continue } @@ -427,6 +415,33 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr }, } + return response, nil +} + +// ListAllModels lists all models from all configured providers. +// It accumulates responses from all providers with a limit of 1000 per provider to get all results. +func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + if request == nil { + request = &schemas.BifrostListModelsRequest{} + } + + providerKeys, err := bifrost.GetConfiguredProviders() + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: err.Error(), + Error: err, + }, + } + } + + // Use the helper function to fetch from all configured providers + response, bifrostErr := bifrost.ListModelsFromProviders(ctx, providerKeys, request) + if bifrostErr != nil { + return nil, bifrostErr + } + response = response.ApplyPagination(request.PageSize, request.PageToken) return response, nil diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 4b64f32e4..79388acab 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -258,61 +258,67 @@ func extractExtraParams(data []byte, knownFields map[string]bool) (map[string]in return extraParams, nil } -// FilterModelsByVirtualKey filters models based on virtual key governance rules. -func filterModelsByVirtualKey(governanceStore *governance.GovernanceStore, bifrostCtx *context.Context, models []schemas.Model) []schemas.Model { +// getVirtualKeyProviders returns allowed providers and models from virtual key +func getVirtualKeyProviders(governanceStore *governance.GovernanceStore, bifrostCtx *context.Context) map[schemas.ModelProvider][]string { if governanceStore == nil || bifrostCtx == nil { - return models + return nil } vkValue := (*bifrostCtx).Value(schemas.BifrostContextKeyVirtualKey) if vkValue == nil { - return models + return nil } vkString, ok := vkValue.(string) if !ok || vkString == "" { - return models + return nil } vk, exists := governanceStore.GetVirtualKey(vkString) if !exists || !vk.IsActive { + return nil + } + + // If no provider configs, all providers allowed + if len(vk.ProviderConfigs) == 0 { + return make(map[schemas.ModelProvider][]string) + } + + allowedProviders := make(map[schemas.ModelProvider][]string, len(vk.ProviderConfigs)) + for _, pc := range vk.ProviderConfigs { + provider := schemas.ModelProvider(pc.Provider) + allowedProviders[provider] = pc.AllowedModels + } + + return allowedProviders +} + +// filterModelsByAllowedModels filters models based on allowed models list +func filterModelsByAllowedModels(allowedProviders map[schemas.ModelProvider][]string, models []schemas.Model) []schemas.Model { + if allowedProviders == nil || len(allowedProviders) == 0 { return models } filteredModels := []schemas.Model{} - for _, model := range models { modelProvider, modelName := schemas.ParseModelString(model.ID, "") - providerAllowed := len(vk.ProviderConfigs) == 0 - var allowedModelsForProvider []string - - for _, pc := range vk.ProviderConfigs { - if pc.Provider == string(modelProvider) { - providerAllowed = true - allowedModelsForProvider = pc.AllowedModels - break - } + allowedModels, providerAllowed := allowedProviders[modelProvider] + if !providerAllowed { + continue } - if !providerAllowed { + if len(allowedModels) == 0 { + filteredModels = append(filteredModels, model) continue } - if len(allowedModelsForProvider) > 0 { - modelAllowed := false - for _, allowedModel := range allowedModelsForProvider { - if allowedModel == modelName || allowedModel == model.ID { - modelAllowed = true - break - } - } - if !modelAllowed { - continue + for _, allowedModel := range allowedModels { + if allowedModel == modelName || allowedModel == model.ID { + filteredModels = append(filteredModels, model) + break } } - - filteredModels = append(filteredModels, model) } return filteredModels @@ -361,6 +367,12 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { return } + // Check virtual key and get allowed providers + var allowedProviders map[schemas.ModelProvider][]string + if h.config.ClientConfig.EnableGovernance { + allowedProviders = getVirtualKeyProviders(h.governanceStore, bifrostCtx) + } + var resp *schemas.BifrostListModelsResponse var bifrostErr *schemas.BifrostError @@ -390,29 +402,73 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { bifrostListModelsReq.ExtraParams = extraParams } - // If provider is empty, list all models from all providers - if provider == "" { - resp, bifrostErr = h.client.ListAllModels(*bifrostCtx, bifrostListModelsReq) + // Determine which providers to query based on VK restrictions + var providersToFetch []schemas.ModelProvider + + if allowedProviders != nil && len(allowedProviders) > 0 { + // Virtual key has provider restrictions + if provider != "" { + // Check if specific provider is allowed by VK + if _, providerAllowed := allowedProviders[schemas.ModelProvider(provider)]; !providerAllowed { + SendError(ctx, fasthttp.StatusForbidden, "Provider not allowed by virtual key") + return + } + providersToFetch = []schemas.ModelProvider{schemas.ModelProvider(provider)} + } else { + // Use all VK-allowed providers + providersToFetch = make([]schemas.ModelProvider, 0, len(allowedProviders)) + for p := range allowedProviders { + providersToFetch = append(providersToFetch, p) + } + } } else { + // No VK restrictions - use provider query param or all configured providers + if provider != "" { + providersToFetch = []schemas.ModelProvider{schemas.ModelProvider(provider)} + } else { + configuredProviders, err := h.client.GetConfiguredProviders() + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get configured providers: %v", err)) + return + } + providersToFetch = configuredProviders + } + } + + // Fetch models from providers + if len(providersToFetch) == 1 { + // Ensure provider is set in request + bifrostListModelsReq.Provider = providersToFetch[0] resp, bifrostErr = h.client.ListModelsRequest(*bifrostCtx, bifrostListModelsReq) + } else { + resp, bifrostErr = h.client.ListModelsFromProviders(*bifrostCtx, providersToFetch, bifrostListModelsReq) } if bifrostErr != nil { + if bifrostErr.Error != nil && bifrostErr.Error.Message == "provider not found for list models request" { + SendJSON(ctx, &schemas.BifrostListModelsResponse{ + Data: []schemas.Model{}, + }) + return + } SendBifrostError(ctx, bifrostErr) return } - // Filter models based on virtual key - if h.config.ClientConfig.EnableGovernance { - resp.Data = filterModelsByVirtualKey(h.governanceStore, bifrostCtx, resp.Data) + // Apply model-level filtering based on VK's allowed_models + if allowedProviders != nil && len(allowedProviders) > 0 { + resp.Data = filterModelsByAllowedModels(allowedProviders, resp.Data) } - // Add pricing data to the response + // Attach pricing data if not already present if len(resp.Data) > 0 && h.config.PricingManager != nil { for i, modelEntry := range resp.Data { + if modelEntry.Pricing != nil { + continue + } provider, modelName := schemas.ParseModelString(modelEntry.ID, "") pricingEntry := h.config.PricingManager.GetPricingEntryForModel(modelName, provider) - if pricingEntry != nil && modelEntry.Pricing == nil { + if pricingEntry != nil { pricing := &schemas.Pricing{ Prompt: bifrost.Ptr(fmt.Sprintf("%f", pricingEntry.InputCostPerToken)), Completion: bifrost.Ptr(fmt.Sprintf("%f", pricingEntry.OutputCostPerToken)), From 4ede283d03de07f8aca313b01826d908365adf67 Mon Sep 17 00:00:00 2001 From: KUNJ1311 Date: Thu, 20 Nov 2025 19:21:16 +0530 Subject: [PATCH 4/4] feat: add ExtraParams to BifrostListModelsRequest and apply pagination in listModels handler --- core/bifrost.go | 5 +++-- transports/bifrost-http/handlers/inference.go | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/bifrost.go b/core/bifrost.go index d30d97db3..aebd230fe 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -329,8 +329,9 @@ func (bifrost *Bifrost) ListModelsFromProviders(ctx context.Context, providers [ // Create request for this provider with limit of 1000 providerRequest := &schemas.BifrostListModelsRequest{ - Provider: providerKey, - PageSize: schemas.DefaultPageSize, + Provider: providerKey, + PageSize: schemas.DefaultPageSize, + ExtraParams: request.ExtraParams, } iterations := 0 diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 79388acab..f47b8e5a5 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -442,6 +442,9 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { resp, bifrostErr = h.client.ListModelsRequest(*bifrostCtx, bifrostListModelsReq) } else { resp, bifrostErr = h.client.ListModelsFromProviders(*bifrostCtx, providersToFetch, bifrostListModelsReq) + if bifrostErr == nil && resp != nil { + resp = resp.ApplyPagination(bifrostListModelsReq.PageSize, bifrostListModelsReq.PageToken) + } } if bifrostErr != nil {