diff --git a/core/bifrost.go b/core/bifrost.go
index 16b71dde2..aebd230fe 100644
--- a/core/bifrost.go
+++ b/core/bifrost.go
@@ -297,24 +297,12 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
return response, nil
}
-// ListAllModels lists all models from all configured providers.
-// It accumulates responses from all providers with a limit of 1000 per provider to get all results.
-func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
+// ListModelsFromProviders fetches models concurrently from the specified providers and returns the combined result without applying pagination; callers paginate it themselves.
+func (bifrost *Bifrost) ListModelsFromProviders(ctx context.Context, providers []schemas.ModelProvider, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
if request == nil {
request = &schemas.BifrostListModelsRequest{}
}
- providerKeys, err := bifrost.GetConfiguredProviders()
- if err != nil {
- return nil, &schemas.BifrostError{
- IsBifrostError: false,
- Error: &schemas.ErrorField{
- Message: err.Error(),
- Error: err,
- },
- }
- }
-
startTime := time.Now()
// Result structure for collecting provider responses
@@ -323,11 +311,11 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
err *schemas.BifrostError
}
- results := make(chan providerResult, len(providerKeys))
+ results := make(chan providerResult, len(providers))
var wg sync.WaitGroup
- // Launch concurrent requests for all providers
- for _, providerKey := range providerKeys {
+ // Launch concurrent requests for specified providers
+ for _, providerKey := range providers {
if strings.TrimSpace(string(providerKey)) == "" {
continue
}
@@ -341,8 +329,9 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
// Create request for this provider with limit of 1000
providerRequest := &schemas.BifrostListModelsRequest{
- Provider: providerKey,
- PageSize: schemas.DefaultPageSize,
+ Provider: providerKey,
+ PageSize: schemas.DefaultPageSize,
+ ExtraParams: request.ExtraParams,
}
iterations := 0
@@ -427,6 +416,33 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
},
}
+ return response, nil
+}
+
+// ListAllModels lists all models from all configured providers.
+// It accumulates responses from all providers with a limit of 1000 per provider to get all results.
+func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
+ if request == nil {
+ request = &schemas.BifrostListModelsRequest{}
+ }
+
+ providerKeys, err := bifrost.GetConfiguredProviders()
+ if err != nil {
+ return nil, &schemas.BifrostError{
+ IsBifrostError: false,
+ Error: &schemas.ErrorField{
+ Message: err.Error(),
+ Error: err,
+ },
+ }
+ }
+
+ // Use the helper function to fetch from all configured providers
+ response, bifrostErr := bifrost.ListModelsFromProviders(ctx, providerKeys, request)
+ if bifrostErr != nil {
+ return nil, bifrostErr
+ }
+
response = response.ApplyPagination(request.PageSize, request.PageToken)
return response, nil
diff --git a/docs/features/governance/virtual-keys.mdx b/docs/features/governance/virtual-keys.mdx
index bee781f83..135ef24e0 100644
--- a/docs/features/governance/virtual-keys.mdx
+++ b/docs/features/governance/virtual-keys.mdx
@@ -16,7 +16,7 @@ Virtual Keys are the primary governance entity in Bifrost. Users and application
You can also use `Authorization` and `x-api-key` headers to pass direct keys to the provider. Read more about it in [Direct Key Bypass](../keys-management#direct-key-bypass).
**Key Features:**
-- **Access Control** - Model and provider filtering
+- **Access Control** - Model and provider filtering (applies to `/v1/models` endpoint and inference requests)
- **Cost Management** - Independent budgets (checked along with team/customer budgets if attached)
- **Rate Limiting** - Token and request-based throttling (VK-level only)
- **Key Restrictions** - Limit VK to specific provider API keys (if configured, VK can only use those keys)
@@ -499,6 +499,34 @@ curl -X POST http://localhost:8080/v1/chat/completions \
}'
```
+The virtual key can also filter which models are returned by the `/v1/models` endpoint:
+
+```bash
+# List models allowed for a specific virtual key
+curl -X GET http://localhost:8080/v1/models \
+ -H "x-bf-vk: vk-engineering-main"
+```
+
+**How it works:**
+- **With `x-bf-vk` header**: Returns only models specified in the virtual key's `provider_configs`
+- **Without `x-bf-vk` header**: Returns all available models
+- **Empty `provider_configs`**: Returns all models (no filtering)
+- **Empty `allowed_models` for a provider**: Returns all models from that provider
+
+**Example:**
+```json
+{
+ "provider_configs": [
+ {
+ "provider": "openai",
+ "allowed_models": ["gpt-4o", "gpt-4o-mini"],
+ "weight": 1.0
+ }
+ ]
+}
+```
+With this configuration, `GET /v1/models` sent with the `x-bf-vk` header returns only `openai/gpt-4o` and `openai/gpt-4o-mini`.
+
By default governance is optional, meaning that if the `x-bf-vk` header is not present, the request will be allowed but without any governance checks/routing. But you can make it mandatory by enforcing the governance header.
diff --git a/tests/governance/README.md b/tests/governance/README.md
index 1cbc0f988..d619d57aa 100644
--- a/tests/governance/README.md
+++ b/tests/governance/README.md
@@ -48,6 +48,14 @@ This test suite provides extensive coverage of the Bifrost governance system inc
- Reset functionality
- Debug and health endpoints
+5. **`test_virtual_keys_access.py`** - Virtual Key Model Filtering
+ - `/v1/models` endpoint filtering based on virtual key `provider_configs`
+ - Specific models via `allowed_models` list
+ - Empty `allowed_models` returns all models from that provider
+ - Empty `provider_configs` returns all models
+ - Multiple provider configurations
+ - Invalid and inactive virtual key handling
+
### Configuration Files
- **`conftest.py`** - Test fixtures, utilities, and configuration
@@ -162,6 +170,7 @@ The test suite uses pytest markers for categorization:
- `@pytest.mark.concurrency` - Concurrency tests
- `@pytest.mark.slow` - Slow running tests (>5s)
- `@pytest.mark.smoke` - Quick smoke tests
+- `@pytest.mark.access_control` - Access control and filtering tests
## API Endpoints Tested
@@ -195,6 +204,11 @@ The test suite uses pytest markers for categorization:
### Integration Endpoints
- `POST /v1/chat/completions` - Chat completion with governance headers
+- `GET /v1/models` - List available models with optional virtual key filtering
+
+#### Model Filtering (`/v1/models`)
+- **With `x-bf-vk` header**: Returns models specified in the virtual key's `provider_configs`
+- **Without `x-bf-vk` header**: Returns all available models
## Test Data and Schemas
diff --git a/tests/governance/test_virtual_keys_access.py b/tests/governance/test_virtual_keys_access.py
new file mode 100644
index 000000000..981c5926a
--- /dev/null
+++ b/tests/governance/test_virtual_keys_access.py
@@ -0,0 +1,478 @@
+"""
+Tests for /v1/models endpoint filtering based on virtual key provider_configs.
+
+Tests cover:
+- Filtering by specific allowed_models
+- Empty allowed_models (all models from provider)
+- Empty provider_configs (all models)
+- Non-existent providers
+- Multiple providers
+- Invalid and inactive virtual keys
+"""
+
+import pytest
+import requests
+from conftest import BIFROST_BASE_URL, assert_response_success, generate_unique_name
+
+
+class TestModelsFiltering:
+ """Test /v1/models endpoint filtering based on virtual key provider configs"""
+
+ @pytest.mark.virtual_keys
+ @pytest.mark.integration
+ @pytest.mark.smoke
+ def test_list_models_with_vk_filters_by_provider_config(
+ self, governance_client, cleanup_tracker
+ ):
+ """Test that /v1/models filters models when VK has provider_configs with specific allowed_models"""
+ response = requests.get(f"{BIFROST_BASE_URL}/v1/models")
+ assert_response_success(response, 200)
+ all_models = response.json().get("data", [])
+
+ providers_map = {}
+ for model in all_models:
+ if "/" in model["id"]:
+ provider, model_name = model["id"].split("/", 1)
+ if provider not in providers_map:
+ providers_map[provider] = []
+ providers_map[provider].append(model_name)
+
+ if not providers_map or not any(
+ len(models) >= 2 for models in providers_map.values()
+ ):
+ pytest.skip("Need at least one provider with 2+ models")
+
+ test_provider = next(
+ p for p, models in providers_map.items() if len(models) >= 2
+ )
+ selected_models = providers_map[test_provider][:2]
+
+ vk_data = {
+ "name": generate_unique_name("Test VK Specific Models"),
+ "provider_configs": [
+ {
+ "provider": test_provider,
+ "allowed_models": selected_models,
+ "weight": 1.0,
+ }
+ ],
+ }
+ response = governance_client.create_virtual_key(vk_data)
+ assert_response_success(response, 200)
+ vk = response.json()["virtual_key"]
+ cleanup_tracker.add_virtual_key(vk["id"])
+
+ assert vk["provider_configs"] is not None
+ assert len(vk["provider_configs"]) == 1
+ assert vk["provider_configs"][0]["provider"] == test_provider
+ assert set(vk["provider_configs"][0]["allowed_models"]) == set(selected_models)
+
+ response = requests.get(
+ f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]}
+ )
+ assert_response_success(response, 200)
+
+ filtered_models = response.json().get("data", [])
+ filtered_ids = [m["id"] for m in filtered_models]
+ expected_ids = [f"{test_provider}/{m}" for m in selected_models]
+
+ assert len(filtered_models) == len(selected_models), (
+ f"Expected {len(selected_models)} models, got {len(filtered_models)}"
+ )
+ assert set(filtered_ids) == set(expected_ids), (
+ f"Model IDs mismatch. Expected: {expected_ids}, Got: {filtered_ids}"
+ )
+
+ provider_models = [
+ m["id"] for m in all_models if m["id"].startswith(f"{test_provider}/")
+ ]
+ excluded = [m for m in provider_models if m not in expected_ids]
+ for excluded_id in excluded:
+ assert excluded_id not in filtered_ids
+
+ for model in filtered_models:
+ assert "id" in model
+ assert "created" in model
+
+ @pytest.mark.virtual_keys
+ @pytest.mark.integration
+ def test_list_models_with_empty_allowed_models_returns_all_provider_models(
+ self, governance_client, cleanup_tracker
+ ):
+ """Test that empty allowed_models returns all models from that provider"""
+ response = requests.get(f"{BIFROST_BASE_URL}/v1/models")
+ assert_response_success(response, 200)
+ all_models = response.json().get("data", [])
+
+ providers = list(
+ set([m["id"].split("/")[0] for m in all_models if "/" in m["id"]])
+ )
+ if not providers:
+ pytest.skip("No providers found")
+
+ test_provider = providers[0]
+ provider_model_ids = [
+ m["id"] for m in all_models if m["id"].startswith(f"{test_provider}/")
+ ]
+
+ vk_data = {
+ "name": generate_unique_name("Test VK All Provider Models"),
+ "provider_configs": [
+ {
+ "provider": test_provider,
+ "allowed_models": [],
+ "weight": 1.0,
+ }
+ ],
+ }
+ response = governance_client.create_virtual_key(vk_data)
+ assert_response_success(response, 200)
+ vk = response.json()["virtual_key"]
+ cleanup_tracker.add_virtual_key(vk["id"])
+
+ assert vk["provider_configs"] is not None
+ assert len(vk["provider_configs"]) == 1
+ assert vk["provider_configs"][0]["provider"] == test_provider
+ assert vk["provider_configs"][0]["allowed_models"] == []
+
+ response = requests.get(
+ f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]}
+ )
+ assert_response_success(response, 200)
+
+ filtered_models = response.json().get("data", [])
+ filtered_ids = [m["id"] for m in filtered_models]
+
+ assert len(filtered_models) == len(provider_model_ids), (
+ f"Expected {len(provider_model_ids)} models, got {len(filtered_models)}"
+ )
+ assert set(filtered_ids) == set(provider_model_ids)
+
+ other_provider_models = [
+ m["id"] for m in all_models if not m["id"].startswith(f"{test_provider}/")
+ ]
+ for other_model_id in other_provider_models:
+ assert other_model_id not in filtered_ids
+
+ @pytest.mark.virtual_keys
+ @pytest.mark.integration
+ def test_list_models_without_provider_configs_returns_all(
+ self, governance_client, cleanup_tracker
+ ):
+ """Test that VK with no provider_configs returns all models"""
+ response = requests.get(f"{BIFROST_BASE_URL}/v1/models")
+ assert_response_success(response, 200)
+ all_models = response.json().get("data", [])
+ all_model_ids = [m["id"] for m in all_models]
+
+ if len(all_models) == 0:
+ pytest.skip("No models available")
+
+ vk_data = {
+ "name": generate_unique_name("Test VK Unrestricted"),
+ "provider_configs": [],
+ }
+ response = governance_client.create_virtual_key(vk_data)
+ assert_response_success(response, 200)
+ vk = response.json()["virtual_key"]
+ cleanup_tracker.add_virtual_key(vk["id"])
+
+ assert vk["provider_configs"] is not None
+ assert len(vk["provider_configs"]) == 0
+
+ response = requests.get(
+ f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]}
+ )
+ assert_response_success(response, 200)
+
+ filtered_models = response.json().get("data", [])
+ filtered_ids = [m["id"] for m in filtered_models]
+
+ assert len(filtered_models) == len(all_models), (
+ f"Expected {len(all_models)} models, got {len(filtered_models)}"
+ )
+ assert set(filtered_ids) == set(all_model_ids)
+
+ @pytest.mark.virtual_keys
+ @pytest.mark.integration
+ def test_list_models_with_nonexistent_provider_returns_empty(
+ self, governance_client, cleanup_tracker
+ ):
+ """Test that VK with non-existent provider returns empty list"""
+ vk_data = {
+ "name": generate_unique_name("Test VK Nonexistent Provider"),
+ "provider_configs": [
+ {
+ "provider": "nonexistent-provider-xyz-123",
+ "allowed_models": [],
+ "weight": 1.0,
+ }
+ ],
+ }
+ response = governance_client.create_virtual_key(vk_data)
+ assert_response_success(response, 200)
+ vk = response.json()["virtual_key"]
+ cleanup_tracker.add_virtual_key(vk["id"])
+
+ assert vk["provider_configs"] is not None
+ assert len(vk["provider_configs"]) == 1
+ assert vk["provider_configs"][0]["provider"] == "nonexistent-provider-xyz-123"
+
+ response = requests.get(
+ f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]}
+ )
+ assert_response_success(response, 200)
+
+ filtered_models = response.json().get("data", [])
+ assert len(filtered_models) == 0
+
+ @pytest.mark.virtual_keys
+ @pytest.mark.integration
+ def test_list_models_with_multiple_providers(
+ self, governance_client, cleanup_tracker
+ ):
+ """Test that VK with multiple providers returns models from all configured providers"""
+ response = requests.get(f"{BIFROST_BASE_URL}/v1/models")
+ assert_response_success(response, 200)
+ all_models = response.json().get("data", [])
+
+ available_providers = list(
+ set([m["id"].split("/")[0] for m in all_models if "/" in m["id"]])
+ )
+
+ if len(available_providers) < 2:
+ pytest.skip(f"Need at least 2 providers, have: {available_providers}")
+
+ provider1 = available_providers[0]
+ provider2 = available_providers[1]
+
+ provider1_models = [
+ m["id"].split("/")[1]
+ for m in all_models
+ if m["id"].startswith(f"{provider1}/")
+ ]
+ provider2_models = [
+ m["id"].split("/")[1]
+ for m in all_models
+ if m["id"].startswith(f"{provider2}/")
+ ]
+
+ if len(provider1_models) == 0 or len(provider2_models) == 0:
+ pytest.skip("Each provider needs at least one model")
+
+ selected_model1 = provider1_models[0]
+ selected_model2 = provider2_models[0]
+
+ vk_data = {
+ "name": generate_unique_name("Test VK Multi Provider"),
+ "provider_configs": [
+ {
+ "provider": provider1,
+ "allowed_models": [selected_model1],
+ "weight": 1.0,
+ },
+ {
+ "provider": provider2,
+ "allowed_models": [selected_model2],
+ "weight": 1.0,
+ },
+ ],
+ }
+ response = governance_client.create_virtual_key(vk_data)
+ assert_response_success(response, 200)
+ vk = response.json()["virtual_key"]
+ cleanup_tracker.add_virtual_key(vk["id"])
+
+ assert vk["provider_configs"] is not None
+ assert len(vk["provider_configs"]) == 2
+ vk_providers = [pc["provider"] for pc in vk["provider_configs"]]
+ assert provider1 in vk_providers
+ assert provider2 in vk_providers
+
+ response = requests.get(
+ f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]}
+ )
+ assert_response_success(response, 200)
+
+ filtered_models = response.json().get("data", [])
+ filtered_ids = [m["id"] for m in filtered_models]
+ expected_ids = [
+ f"{provider1}/{selected_model1}",
+ f"{provider2}/{selected_model2}",
+ ]
+
+ assert len(filtered_models) == 2, (
+ f"Expected 2 models, got {len(filtered_models)}"
+ )
+ assert set(filtered_ids) == set(expected_ids)
+
+ provider1_found = any(
+ m["id"].startswith(f"{provider1}/") for m in filtered_models
+ )
+ provider2_found = any(
+ m["id"].startswith(f"{provider2}/") for m in filtered_models
+ )
+ assert provider1_found
+ assert provider2_found
+
+ @pytest.mark.virtual_keys
+ @pytest.mark.integration
+ def test_list_models_without_vk_header_returns_all(self):
+ """Test that requesting models without VK header returns all models"""
+ response = requests.get(f"{BIFROST_BASE_URL}/v1/models")
+ assert_response_success(response, 200)
+
+ all_models = response.json().get("data", [])
+ assert len(all_models) > 0
+
+ @pytest.mark.virtual_keys
+ @pytest.mark.integration
+ def test_list_models_with_invalid_vk_header(self):
+ """Test that invalid VK header returns error or all models without filtering"""
+ response_without_vk = requests.get(f"{BIFROST_BASE_URL}/v1/models")
+ assert_response_success(response_without_vk, 200)
+ all_models = response_without_vk.json().get("data", [])
+ all_model_ids = set([m["id"] for m in all_models])
+
+ response = requests.get(
+ f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": "invalid-vk-value-xyz"}
+ )
+
+ assert response.status_code in [200, 401, 403]
+
+ if response.status_code == 200:
+ models = response.json().get("data", [])
+ model_ids = set([m["id"] for m in models])
+ assert isinstance(models, list)
+ assert model_ids == all_model_ids, "Invalid VK should not filter models"
+
+ @pytest.mark.virtual_keys
+ @pytest.mark.integration
+ def test_list_models_with_inactive_vk(self, governance_client, cleanup_tracker):
+ """Test that inactive VK is rejected or returns all models without filtering"""
+ response = requests.get(f"{BIFROST_BASE_URL}/v1/models")
+ assert_response_success(response, 200)
+ all_models = response.json().get("data", [])
+ all_model_ids = set([m["id"] for m in all_models])
+
+ if len(all_models) == 0:
+ pytest.skip("No models available")
+
+ vk_data = {
+ "name": generate_unique_name("Test VK Inactive"),
+ "is_active": False,
+ "provider_configs": [
+ {
+ "provider": "test-provider",
+ "allowed_models": ["test-model"],
+ "weight": 1.0,
+ }
+ ],
+ }
+ response = governance_client.create_virtual_key(vk_data)
+ assert_response_success(response, 200)
+ vk = response.json()["virtual_key"]
+ cleanup_tracker.add_virtual_key(vk["id"])
+
+ if vk["is_active"] is True:
+ update_response = governance_client.update_virtual_key(
+ vk["id"], {"is_active": False}
+ )
+ if update_response.status_code == 200:
+ vk = update_response.json()["virtual_key"]
+ else:
+ pytest.skip("Cannot deactivate VK")
+
+ assert vk["is_active"] is False
+
+ response = requests.get(
+ f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]}
+ )
+
+ assert response.status_code in [200, 401, 403]
+
+ if response.status_code == 200:
+ filtered_models = response.json().get("data", [])
+ filtered_ids = set([m["id"] for m in filtered_models])
+
+ assert filtered_ids == all_model_ids, (
+ f"Inactive VK should not filter models. Expected {len(all_models)} models, "
+ f"got {len(filtered_models)}"
+ )
+
+ @pytest.mark.virtual_keys
+ @pytest.mark.integration
+ def test_list_models_with_mixed_provider_configs(
+ self, governance_client, cleanup_tracker
+ ):
+ """Test VK with one provider having specific models and another with all models"""
+ response = requests.get(f"{BIFROST_BASE_URL}/v1/models")
+ assert_response_success(response, 200)
+ all_models = response.json().get("data", [])
+
+ providers = list(
+ set([m["id"].split("/")[0] for m in all_models if "/" in m["id"]])
+ )
+
+ if len(providers) < 2:
+ pytest.skip(f"Need at least 2 providers, have: {providers}")
+
+ provider1 = providers[0]
+ provider2 = providers[1]
+
+ provider1_models = [
+ m["id"].split("/")[1]
+ for m in all_models
+ if m["id"].startswith(f"{provider1}/")
+ ]
+ provider2_all_models = [
+ m["id"] for m in all_models if m["id"].startswith(f"{provider2}/")
+ ]
+
+ if len(provider1_models) < 2 or len(provider2_all_models) == 0:
+ pytest.skip("Need sufficient models")
+
+ selected_provider1_model = provider1_models[0]
+
+ vk_data = {
+ "name": generate_unique_name("Test VK Mixed Configs"),
+ "provider_configs": [
+ {
+ "provider": provider1,
+ "allowed_models": [selected_provider1_model],
+ "weight": 1.0,
+ },
+ {
+ "provider": provider2,
+ "allowed_models": [],
+ "weight": 1.0,
+ },
+ ],
+ }
+ response = governance_client.create_virtual_key(vk_data)
+ assert_response_success(response, 200)
+ vk = response.json()["virtual_key"]
+ cleanup_tracker.add_virtual_key(vk["id"])
+
+ response = requests.get(
+ f"{BIFROST_BASE_URL}/v1/models", headers={"x-bf-vk": vk["value"]}
+ )
+ assert_response_success(response, 200)
+
+ filtered_models = response.json().get("data", [])
+
+ expected_count = 1 + len(provider2_all_models)
+ assert len(filtered_models) == expected_count, (
+ f"Expected {expected_count} models, got {len(filtered_models)}"
+ )
+
+ provider1_filtered = [
+ m["id"] for m in filtered_models if m["id"].startswith(f"{provider1}/")
+ ]
+ assert len(provider1_filtered) == 1
+ assert f"{provider1}/{selected_provider1_model}" in provider1_filtered
+
+ provider2_filtered = [
+ m["id"] for m in filtered_models if m["id"].startswith(f"{provider2}/")
+ ]
+ assert set(provider2_filtered) == set(provider2_all_models)
diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go
index c029f85d6..f47b8e5a5 100644
--- a/transports/bifrost-http/handlers/inference.go
+++ b/transports/bifrost-http/handlers/inference.go
@@ -18,23 +18,26 @@ import (
"github.com/fasthttp/router"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
+ "github.com/maximhq/bifrost/plugins/governance"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// CompletionHandler manages HTTP requests for completion operations
type CompletionHandler struct {
- client *bifrost.Bifrost
- handlerStore lib.HandlerStore
- config *lib.Config
+ client *bifrost.Bifrost
+ handlerStore lib.HandlerStore
+ config *lib.Config
+ governanceStore *governance.GovernanceStore
}
// NewInferenceHandler creates a new completion handler instance
-func NewInferenceHandler(client *bifrost.Bifrost, config *lib.Config) *CompletionHandler {
+func NewInferenceHandler(client *bifrost.Bifrost, config *lib.Config, governanceStore *governance.GovernanceStore) *CompletionHandler {
return &CompletionHandler{
- client: client,
- handlerStore: config,
- config: config,
+ client: client,
+ handlerStore: config,
+ config: config,
+ governanceStore: governanceStore,
}
}
@@ -255,6 +258,72 @@ func extractExtraParams(data []byte, knownFields map[string]bool) (map[string]in
return extraParams, nil
}
+// getVirtualKeyProviders resolves the virtual key stored on the request context
+// into a provider -> allowed-models map.
+//
+// It returns nil when no filtering applies: a nil store or context, a missing,
+// non-string or empty key, or a key that is unknown or inactive. An active key
+// without provider configs yields an empty, non-nil map. Each value is the
+// config's AllowedModels slice itself, not a copy.
+func getVirtualKeyProviders(governanceStore *governance.GovernanceStore, bifrostCtx *context.Context) map[schemas.ModelProvider][]string {
+	if governanceStore == nil || bifrostCtx == nil {
+		return nil
+	}
+
+	// A failed type assertion leaves key as "", so one check covers a missing,
+	// non-string and empty value alike.
+	key, _ := (*bifrostCtx).Value(schemas.BifrostContextKeyVirtualKey).(string)
+	if key == "" {
+		return nil
+	}
+
+	virtualKey, found := governanceStore.GetVirtualKey(key)
+	if !found || !virtualKey.IsActive {
+		return nil
+	}
+
+	// With no provider configs the loop is skipped and the map stays empty,
+	// which means all providers are allowed.
+	configs := virtualKey.ProviderConfigs
+	byProvider := make(map[schemas.ModelProvider][]string, len(configs))
+	for _, cfg := range configs {
+		byProvider[schemas.ModelProvider(cfg.Provider)] = cfg.AllowedModels
+	}
+
+	return byProvider
+}
+
+// filterModelsByAllowedModels keeps only the models permitted by allowedProviders.
+//
+// An empty or nil allowedProviders returns models unchanged. Otherwise a model is
+// kept when its provider (parsed from model.ID) is a key of the map and either that
+// provider's list is empty or it contains the bare model name or the full model.ID.
+// Once filtering applies the result is a fresh, non-nil slice.
+func filterModelsByAllowedModels(allowedProviders map[schemas.ModelProvider][]string, models []schemas.Model) []schemas.Model {
+	if len(allowedProviders) == 0 {
+		return models
+	}
+
+	kept := []schemas.Model{}
+	for _, candidate := range models {
+		provider, name := schemas.ParseModelString(candidate.ID, "")
+		permitted, ok := allowedProviders[provider]
+		if !ok {
+			continue
+		}
+
+		// An empty list admits every model of the provider.
+		keep := len(permitted) == 0
+		for i := 0; !keep && i < len(permitted); i++ {
+			keep = permitted[i] == name || permitted[i] == candidate.ID
+		}
+		if keep {
+			kept = append(kept, candidate)
+		}
+	}
+	return kept
+}
+
const (
// Maximum file size (25MB)
MaxFileSize = 25 * 1024 * 1024
@@ -298,6 +367,12 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) {
return
}
+ // Check virtual key and get allowed providers
+ var allowedProviders map[schemas.ModelProvider][]string
+ if h.config.ClientConfig.EnableGovernance {
+ allowedProviders = getVirtualKeyProviders(h.governanceStore, bifrostCtx)
+ }
+
var resp *schemas.BifrostListModelsResponse
var bifrostErr *schemas.BifrostError
@@ -327,24 +402,76 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) {
bifrostListModelsReq.ExtraParams = extraParams
}
- // If provider is empty, list all models from all providers
- if provider == "" {
- resp, bifrostErr = h.client.ListAllModels(*bifrostCtx, bifrostListModelsReq)
+ // Determine which providers to query based on VK restrictions
+ var providersToFetch []schemas.ModelProvider
+
+ if allowedProviders != nil && len(allowedProviders) > 0 {
+ // Virtual key has provider restrictions
+ if provider != "" {
+ // Check if specific provider is allowed by VK
+ if _, providerAllowed := allowedProviders[schemas.ModelProvider(provider)]; !providerAllowed {
+ SendError(ctx, fasthttp.StatusForbidden, "Provider not allowed by virtual key")
+ return
+ }
+ providersToFetch = []schemas.ModelProvider{schemas.ModelProvider(provider)}
+ } else {
+ // Use all VK-allowed providers
+ providersToFetch = make([]schemas.ModelProvider, 0, len(allowedProviders))
+ for p := range allowedProviders {
+ providersToFetch = append(providersToFetch, p)
+ }
+ }
} else {
+ // No VK restrictions - use provider query param or all configured providers
+ if provider != "" {
+ providersToFetch = []schemas.ModelProvider{schemas.ModelProvider(provider)}
+ } else {
+ configuredProviders, err := h.client.GetConfiguredProviders()
+ if err != nil {
+ SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get configured providers: %v", err))
+ return
+ }
+ providersToFetch = configuredProviders
+ }
+ }
+
+ // Fetch models from providers
+ if len(providersToFetch) == 1 {
+ // Ensure provider is set in request
+ bifrostListModelsReq.Provider = providersToFetch[0]
resp, bifrostErr = h.client.ListModelsRequest(*bifrostCtx, bifrostListModelsReq)
+ } else {
+ resp, bifrostErr = h.client.ListModelsFromProviders(*bifrostCtx, providersToFetch, bifrostListModelsReq)
+ if bifrostErr == nil && resp != nil {
+ resp = resp.ApplyPagination(bifrostListModelsReq.PageSize, bifrostListModelsReq.PageToken)
+ }
}
if bifrostErr != nil {
+ if bifrostErr.Error != nil && bifrostErr.Error.Message == "provider not found for list models request" {
+ SendJSON(ctx, &schemas.BifrostListModelsResponse{
+ Data: []schemas.Model{},
+ })
+ return
+ }
SendBifrostError(ctx, bifrostErr)
return
}
- // Add pricing data to the response
+ // Apply model-level filtering based on VK's allowed_models
+ if allowedProviders != nil && len(allowedProviders) > 0 {
+ resp.Data = filterModelsByAllowedModels(allowedProviders, resp.Data)
+ }
+
+ // Attach pricing data if not already present
if len(resp.Data) > 0 && h.config.PricingManager != nil {
for i, modelEntry := range resp.Data {
+ if modelEntry.Pricing != nil {
+ continue
+ }
provider, modelName := schemas.ParseModelString(modelEntry.ID, "")
pricingEntry := h.config.PricingManager.GetPricingEntryForModel(modelName, provider)
- if pricingEntry != nil && modelEntry.Pricing == nil {
+ if pricingEntry != nil {
pricing := &schemas.Pricing{
Prompt: bifrost.Ptr(fmt.Sprintf("%f", pricingEntry.InputCostPerToken)),
Completion: bifrost.Ptr(fmt.Sprintf("%f", pricingEntry.OutputCostPerToken)),
diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go
index 15d478f7a..d52e0ff75 100644
--- a/transports/bifrost-http/server/server.go
+++ b/transports/bifrost-http/server/server.go
@@ -842,7 +842,13 @@ func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, name string) error
// RegisterInferenceRoutes initializes the routes for the inference handler
func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlewares ...lib.BifrostHTTPMiddleware) error {
- inferenceHandler := handlers.NewInferenceHandler(s.Client, s.Config)
+ var governanceStore *governance.GovernanceStore
+ governancePlugin, _ := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName)
+ if governancePlugin != nil {
+ governanceStore = governancePlugin.GetGovernanceStore()
+ }
+
+ inferenceHandler := handlers.NewInferenceHandler(s.Client, s.Config, governanceStore)
integrationHandler := handlers.NewIntegrationHandler(s.Client, s.Config)
integrationHandler.RegisterRoutes(s.Router, middlewares...)
inferenceHandler.RegisterRoutes(s.Router, middlewares...)
@@ -1077,7 +1083,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error {
s.Server = &fasthttp.Server{
Handler: handlers.CorsMiddleware(s.Config)(handlers.TransportInterceptorMiddleware(s.Config)(s.Router.Handler)),
MaxRequestBodySize: s.Config.ClientConfig.MaxRequestBodySizeMB * 1024 * 1024,
- ReadBufferSize: 1024 * 16, // 16kb
+ ReadBufferSize: 1024 * 16, // 16kb
}
return nil
}