diff --git a/src/ell/configurator.py b/src/ell/configurator.py index 57e281fe1..b62300a97 100644 --- a/src/ell/configurator.py +++ b/src/ell/configurator.py @@ -16,9 +16,9 @@ class _Model: name: str default_client: Optional[Union[openai.Client, Any]] = None #XXX: Deprecation in 0.1.0 - #XXX: We will depreciate this when streaming is implemented. + #XXX: We will depreciate this when streaming is implemented. # Currently we stream by default for the verbose renderer, - # but in the future we will not support streaming by default + # but in the future we will not support streaming by default # and stream=True must be passed which will then make API providers the # single source of truth for whether or not a model supports an api parameter. # This makes our implementation extremely light, only requiring us to provide @@ -44,9 +44,9 @@ def __init__(self, **data): self._lock = threading.Lock() self._local = threading.local() - + def register_model( - self, + self, name: str, default_client: Optional[Union[openai.Client, Any]] = None, supports_streaming: Optional[bool] = None @@ -74,12 +74,12 @@ def model_registry_override(self, overrides: Dict[str, _Model]): """ if not hasattr(self._local, 'stack'): self._local.stack = [] - + with self._lock: current_registry = self._local.stack[-1] if self._local.stack else self.registry new_registry = current_registry.copy() new_registry.update(overrides) - + self._local.stack.append(new_registry) try: yield @@ -133,11 +133,25 @@ def get_provider_for(self, client: Union[Type[Any], Any]) -> Optional[Provider]: """ client_type = type(client) if not isinstance(client, type) else client - for provider_type, provider in self.providers.items(): - if issubclass(client_type, provider_type) or client_type == provider_type: - return provider + # First, try to find an exact match + if client_type in self.providers: + return self.providers[client_type] + + # If no exact match, look for the most specific subclass + matching_providers = [ + (provider_type, provider) + for provider_type, provider in self.providers.items() + if issubclass(client_type, provider_type) + ] + + if matching_providers: + # Sort by inheritance depth (most derived class first) + matching_providers.sort(key=lambda x: len(x[0].mro()), reverse=True) + return matching_providers[0][1] + return None + # Single* instance # XXX: Make a singleton config = Config() @@ -187,7 +201,7 @@ def init( def get_store() -> Union[Store, None]: return config.store -# Will be deprecated at 0.1.0 +# Will be deprecated at 0.1.0 # You can add more helper functions here if needed def register_provider(provider: Provider, client_type: Type[Any]) -> None: diff --git a/src/ell/models/openai.py b/src/ell/models/openai.py index a53941fd0..dff37e92a 100644 --- a/src/ell/models/openai.py +++ b/src/ell/models/openai.py @@ -30,6 +30,7 @@ logger = logging.getLogger(__name__) + def register(client: openai.Client): """ Register OpenAI models with the provided client. @@ -92,4 +93,4 @@ def register(client: openai.Client): pass register(default_client) -config.default_client = default_client \ No newline at end of file +config.default_client = default_client diff --git a/tests/test_provider_override.py b/tests/test_provider_override.py new file mode 100644 index 000000000..9096825b0 --- /dev/null +++ b/tests/test_provider_override.py @@ -0,0 +1,107 @@ +import pytest +from ell.configurator import Config, Provider + + +class MockProvider(Provider): + def provider_call_function(self, *args, **kwargs): + pass + + def translate_from_provider(self, *args, **kwargs): + pass + + def translate_to_provider(self, *args, **kwargs): + pass + + +class MockOpenAIProvider(MockProvider): + pass + + +class MockCustomProvider(MockOpenAIProvider): + pass + + +class MockAnthropicProvider(MockProvider): + pass + + +class TestProviderOverride: + @pytest.fixture + def config(self): + return Config() + + @pytest.fixture + def MockOpenAIClient(self): + return type('MockOpenAIClient', (), {}) + + @pytest.fixture + def MockAnthropicClient(self): + return type('MockAnthropicClient', (), {}) + + def test_exact_match_provider(self, config, MockOpenAIClient): + provider = MockOpenAIProvider() + config.register_provider(provider, MockOpenAIClient) + + assert config.get_provider_for(MockOpenAIClient) == provider + + def test_subclass_match_provider(self, config, MockOpenAIClient): + provider = MockProvider() + config.register_provider(provider, MockOpenAIClient) + + CustomOpenAIClient = type('CustomOpenAIClient', (MockOpenAIClient,), {}) + + assert config.get_provider_for(CustomOpenAIClient) == provider + + def test_most_specific_subclass_match(self, config, MockOpenAIClient): + CustomOpenAIClient = type('CustomOpenAIClient', (MockOpenAIClient,), {}) + + base_provider = MockProvider() + openai_provider = MockOpenAIProvider() + custom_provider = MockCustomProvider() + config.register_provider(base_provider, object) + config.register_provider(openai_provider, MockOpenAIClient) + config.register_provider(custom_provider, CustomOpenAIClient) + + assert config.get_provider_for(CustomOpenAIClient) == custom_provider + assert config.get_provider_for(MockOpenAIClient) == openai_provider + + def test_multiple_inheritance(self, config, MockOpenAIClient, MockAnthropicClient): + openai_provider = MockOpenAIProvider() + anthropic_provider = MockAnthropicProvider() + config.register_provider(openai_provider, MockOpenAIClient) + config.register_provider(anthropic_provider, MockAnthropicClient) + + HybridClient = type('HybridClient', (MockOpenAIClient, MockAnthropicClient), {}) + + assert config.get_provider_for(HybridClient) == openai_provider + + def test_no_match_provider(self, config, MockOpenAIClient): + provider = MockProvider() + config.register_provider(provider, MockOpenAIClient) + + assert config.get_provider_for(str) is None + + def test_type_and_instance_input(self, config, MockOpenAIClient): + provider = MockProvider() + config.register_provider(provider, MockOpenAIClient) + + assert config.get_provider_for(MockOpenAIClient) == provider + # For testing with an instance, we need to create one + mock_instance = MockOpenAIClient() + assert config.get_provider_for(type(mock_instance)) == provider + + def test_custom_provider_inheritance(self, config, MockOpenAIClient): + CustomOpenAIClient = type('CustomOpenAIClient', (MockOpenAIClient,), {}) + VeryCustomOpenAIClient = type('VeryCustomOpenAIClient', (CustomOpenAIClient,), {}) + + base_provider = MockProvider() + openai_provider = MockOpenAIProvider() + custom_provider = MockCustomProvider() + + config.register_provider(base_provider, object) + config.register_provider(openai_provider, MockOpenAIClient) + config.register_provider(custom_provider, CustomOpenAIClient) + + assert config.get_provider_for(MockOpenAIClient) == openai_provider + assert config.get_provider_for(CustomOpenAIClient) == custom_provider + assert config.get_provider_for(VeryCustomOpenAIClient) == custom_provider