Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions src/ell/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class _Model:
name: str
default_client: Optional[Union[openai.Client, Any]] = None
#XXX: Deprecation in 0.1.0
#XXX: We will depreciate this when streaming is implemented.
#XXX: We will depreciate this when streaming is implemented.
# Currently we stream by default for the verbose renderer,
# but in the future we will not support streaming by default
# but in the future we will not support streaming by default
# and stream=True must be passed which will then make API providers the
# single source of truth for whether or not a model supports an api parameter.
# This makes our implementation extremely light, only requiring us to provide
Expand All @@ -44,9 +44,9 @@ def __init__(self, **data):
self._lock = threading.Lock()
self._local = threading.local()


def register_model(
self,
self,
name: str,
default_client: Optional[Union[openai.Client, Any]] = None,
supports_streaming: Optional[bool] = None
Expand Down Expand Up @@ -74,12 +74,12 @@ def model_registry_override(self, overrides: Dict[str, _Model]):
"""
if not hasattr(self._local, 'stack'):
self._local.stack = []

with self._lock:
current_registry = self._local.stack[-1] if self._local.stack else self.registry
new_registry = current_registry.copy()
new_registry.update(overrides)

self._local.stack.append(new_registry)
try:
yield
Expand Down Expand Up @@ -133,11 +133,25 @@ def get_provider_for(self, client: Union[Type[Any], Any]) -> Optional[Provider]:
"""

client_type = type(client) if not isinstance(client, type) else client
for provider_type, provider in self.providers.items():
if issubclass(client_type, provider_type) or client_type == provider_type:
return provider
# First, try to find an exact match
if client_type in self.providers:
return self.providers[client_type]

# If no exact match, look for the most specific subclass
matching_providers = [
(provider_type, provider)
for provider_type, provider in self.providers.items()
if issubclass(client_type, provider_type)
]

if matching_providers:
# Sort by inheritance depth (most derived class first)
matching_providers.sort(key=lambda x: len(x[0].mro()), reverse=True)
return matching_providers[0][1]

return None


# Single* instance
# XXX: Make a singleton
config = Config()
Expand Down Expand Up @@ -187,7 +201,7 @@ def init(
def get_store() -> Union[Store, None]:
return config.store

# Will be deprecated at 0.1.0
# Will be deprecated at 0.1.0

# You can add more helper functions here if needed
def register_provider(provider: Provider, client_type: Type[Any]) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/ell/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

logger = logging.getLogger(__name__)


def register(client: openai.Client):
"""
Register OpenAI models with the provided client.
Expand Down Expand Up @@ -92,4 +93,4 @@ def register(client: openai.Client):
pass

register(default_client)
config.default_client = default_client
config.default_client = default_client
107 changes: 107 additions & 0 deletions tests/test_provider_override.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import pytest
from ell.configurator import Config, Provider


class MockProvider(Provider):
def provider_call_function(self, *args, **kwargs):
pass

def translate_from_provider(self, *args, **kwargs):
pass

def translate_to_provider(self, *args, **kwargs):
pass


class MockOpenAIProvider(MockProvider):
pass


class MockCustomProvider(MockOpenAIProvider):
pass


class MockAnthropicProvider(MockProvider):
pass


class TestProviderOverride:
@pytest.fixture
def config(self):
return Config()

@pytest.fixture
def MockOpenAIClient(self):
return type('MockOpenAIClient', (), {})

@pytest.fixture
def MockAnthropicClient(self):
return type('MockAnthropicClient', (), {})

def test_exact_match_provider(self, config, MockOpenAIClient):
provider = MockOpenAIProvider()
config.register_provider(provider, MockOpenAIClient)

assert config.get_provider_for(MockOpenAIClient) == provider

def test_subclass_match_provider(self, config, MockOpenAIClient):
provider = MockProvider()
config.register_provider(provider, MockOpenAIClient)

CustomOpenAIClient = type('CustomOpenAIClient', (MockOpenAIClient,), {})

assert config.get_provider_for(CustomOpenAIClient) == provider

def test_most_specific_subclass_match(self, config, MockOpenAIClient):
CustomOpenAIClient = type('CustomOpenAIClient', (MockOpenAIClient,), {})

base_provider = MockProvider()
openai_provider = MockOpenAIProvider()
custom_provider = MockCustomProvider()
config.register_provider(base_provider, object)
config.register_provider(openai_provider, MockOpenAIClient)
config.register_provider(custom_provider, CustomOpenAIClient)

assert config.get_provider_for(CustomOpenAIClient) == custom_provider
assert config.get_provider_for(MockOpenAIClient) == openai_provider

def test_multiple_inheritance(self, config, MockOpenAIClient, MockAnthropicClient):
openai_provider = MockOpenAIProvider()
anthropic_provider = MockAnthropicProvider()
config.register_provider(openai_provider, MockOpenAIClient)
config.register_provider(anthropic_provider, MockAnthropicClient)

HybridClient = type('HybridClient', (MockOpenAIClient, MockAnthropicClient), {})

assert config.get_provider_for(HybridClient) == openai_provider

def test_no_match_provider(self, config, MockOpenAIClient):
provider = MockProvider()
config.register_provider(provider, MockOpenAIClient)

assert config.get_provider_for(str) is None

def test_type_and_instance_input(self, config, MockOpenAIClient):
provider = MockProvider()
config.register_provider(provider, MockOpenAIClient)

assert config.get_provider_for(MockOpenAIClient) == provider
# For testing with an instance, we need to create one
mock_instance = MockOpenAIClient()
assert config.get_provider_for(type(mock_instance)) == provider

def test_custom_provider_inheritance(self, config, MockOpenAIClient):
CustomOpenAIClient = type('CustomOpenAIClient', (MockOpenAIClient,), {})
VeryCustomOpenAIClient = type('VeryCustomOpenAIClient', (CustomOpenAIClient,), {})

base_provider = MockProvider()
openai_provider = MockOpenAIProvider()
custom_provider = MockCustomProvider()

config.register_provider(base_provider, object)
config.register_provider(openai_provider, MockOpenAIClient)
config.register_provider(custom_provider, CustomOpenAIClient)

assert config.get_provider_for(MockOpenAIClient) == openai_provider
assert config.get_provider_for(CustomOpenAIClient) == custom_provider
assert config.get_provider_for(VeryCustomOpenAIClient) == custom_provider