Python: Remove model info check in Bedrock connectors #12395

Open · wants to merge 2 commits into base: main
python/semantic_kernel/connectors/ai/bedrock/README.md (0 additions, 4 deletions)
@@ -56,10 +56,6 @@ Not all models in Bedrock support tools. Refer to the [AWS documentation](https:

Not all models in Bedrock support streaming. You can use the boto3 client to check if a model supports streaming. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html) and the [Boto3 documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/get_foundation_model.html) for more information.

You can also directly call the `get_foundation_model_info("model_id")` method from the Bedrock connector to check if a model supports streaming.

> Note: The bedrock connector will check if a model supports streaming before making a streaming request to the model.
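Since this PR removes that built-in check (along with the `get_foundation_model_info` helper removed further down), a caller that still wants to fail fast can query Bedrock directly. A minimal sketch, assuming boto3 credentials and region are already configured; the model id is a placeholder:

```python
import boto3

model_id = "amazon.titan-text-premier-v1:0"  # placeholder: use the model you plan to stream from

bedrock = boto3.client("bedrock")
details = bedrock.get_foundation_model(modelIdentifier=model_id).get("modelDetails", {})

# Same field the removed connector check inspected.
if not details.get("responseStreamingSupported"):
    raise ValueError(f"The model {model_id} does not support streaming.")
```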

## Model specific parameters

Foundation models can have specific parameters that are unique to the model or the model provider. You can refer to this [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html) for more information.
@@ -1,13 +1,11 @@
# Copyright (c) Microsoft. All rights reserved.

from abc import ABC
from functools import partial
from typing import Any, ClassVar

import boto3

from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.utils.async_utils import run_in_executor


class BedrockBase(KernelBaseModel, ABC):
@@ -40,15 +38,3 @@ def __init__(
bedrock_client=client or boto3.client("bedrock"),
**kwargs,
)

async def get_foundation_model_info(self, model_id: str) -> dict[str, Any]:
"""Get the foundation model information."""
response = await run_in_executor(
None,
partial(
self.bedrock_client.get_foundation_model,
modelIdentifier=model_id,
),
)

return response.get("modelDetails")
@@ -38,7 +38,6 @@
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.exceptions.service_exceptions import (
ServiceInitializationError,
ServiceInvalidRequestError,
ServiceInvalidResponseError,
)
from semantic_kernel.utils.async_utils import run_in_executor
@@ -127,11 +126,6 @@ async def _inner_get_streaming_chat_message_contents(
settings: "PromptExecutionSettings",
function_invoke_attempt: int = 0,
) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
# Not all models support streaming: check if the model supports streaming before proceeding
model_info = await self.get_foundation_model_info(self.ai_model_id)
if not model_info.get("responseStreamingSupported"):
raise ServiceInvalidRequestError(f"The model {self.ai_model_id} does not support streaming.")

if not isinstance(settings, BedrockChatPromptExecutionSettings):
settings = self.get_prompt_execution_settings_from_settings(settings)
assert isinstance(settings, BedrockChatPromptExecutionSettings) # nosec
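With the connector-side check removed, an unsupported model will presumably surface as an error from the Bedrock service rather than a `ServiceInvalidRequestError`. A rough sketch of how a caller could reinstate the old fail-fast behaviour around the streaming call; the import paths are assumed to mirror the `services` layout used by the test imports in this PR, and the model id is a placeholder:

```python
import asyncio

import boto3

# Import paths assumed from the package layout used by the tests in this PR.
from semantic_kernel.connectors.ai.bedrock.bedrock_prompt_execution_settings import (
    BedrockChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.bedrock.services.bedrock_chat_completion import BedrockChatCompletion
from semantic_kernel.contents import ChatHistory

model_id = "anthropic.claude-3-sonnet-20240229-v1:0"  # placeholder model id


async def main() -> None:
    # Caller-side replacement for the removed responseStreamingSupported check.
    details = boto3.client("bedrock").get_foundation_model(modelIdentifier=model_id).get("modelDetails", {})
    if not details.get("responseStreamingSupported"):
        raise RuntimeError(f"The model {model_id} does not support streaming.")

    chat_completion = BedrockChatCompletion(model_id=model_id)
    chat_history = ChatHistory()
    chat_history.add_user_message("Write a haiku about streaming APIs.")

    settings = BedrockChatPromptExecutionSettings()
    async for chunks in chat_completion.get_streaming_chat_message_contents(
        chat_history=chat_history, settings=settings
    ):
        print(chunks[0].content, end="")


asyncio.run(main())
```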
@@ -24,7 +24,7 @@
from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase
from semantic_kernel.contents.streaming_text_content import StreamingTextContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidRequestError
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from semantic_kernel.utils.async_utils import run_in_executor
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import (
trace_streaming_text_completion,
@@ -108,11 +108,6 @@ async def _inner_get_streaming_text_contents(
prompt: str,
settings: "PromptExecutionSettings",
) -> AsyncGenerator[list[StreamingTextContent], Any]:
# Not all models support streaming: check if the model supports streaming before proceeding
model_info = await self.get_foundation_model_info(self.ai_model_id)
if not model_info.get("responseStreamingSupported"):
raise ServiceInvalidRequestError(f"The model {self.ai_model_id} does not support streaming.")

if not isinstance(settings, BedrockTextPromptExecutionSettings):
settings = self.get_prompt_execution_settings_from_settings(settings)
assert isinstance(settings, BedrockTextPromptExecutionSettings) # nosec
@@ -25,7 +25,7 @@
)
from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidRequestError
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from semantic_kernel.utils.async_utils import run_in_executor

if TYPE_CHECKING:
@@ -80,13 +80,6 @@ async def generate_embeddings(
settings: "PromptExecutionSettings | None" = None,
**kwargs: Any,
) -> ndarray:
model_info = await self.get_foundation_model_info(self.ai_model_id)
if "TEXT" not in model_info.get("inputModalities", []):
# Image embedding is not supported yet in SK
raise ServiceInvalidRequestError(f"The model {self.ai_model_id} does not support text input.")
if "EMBEDDING" not in model_info.get("outputModalities", []):
raise ServiceInvalidRequestError(f"The model {self.ai_model_id} does not support embedding output.")

if not settings:
settings = BedrockEmbeddingPromptExecutionSettings()
elif not isinstance(settings, BedrockEmbeddingPromptExecutionSettings):
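The same applies to the embedding service: the modality checks removed above are no longer performed, so a caller that wants to validate a model before generating embeddings can inspect the foundation model metadata directly. A minimal sketch with a placeholder model id:

```python
import boto3

model_id = "amazon.titan-embed-text-v2:0"  # placeholder embedding model id

details = boto3.client("bedrock").get_foundation_model(modelIdentifier=model_id).get("modelDetails", {})

# Mirrors the removed connector checks: text in, embeddings out.
if "TEXT" not in details.get("inputModalities", []):
    raise ValueError(f"The model {model_id} does not support text input.")
if "EMBEDDING" not in details.get("outputModalities", []):
    raise ValueError(f"The model {model_id} does not support embedding output.")
```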
@@ -18,7 +18,6 @@
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.exceptions.service_exceptions import (
ServiceInitializationError,
ServiceInvalidRequestError,
ServiceInvalidResponseError,
)
from tests.unit.connectors.ai.bedrock.conftest import MockBedrockClient, MockBedrockRuntimeClient
@@ -281,30 +280,6 @@ async def test_bedrock_streaming_chat_completion(
assert response.finish_reason == FinishReason.STOP


async def test_bedrock_streaming_chat_completion_with_unsupported_model(
model_id,
chat_history: ChatHistory,
) -> None:
"""Test Amazon Bedrock Streaming Chat Completion complete method"""
with patch.object(
MockBedrockClient, "get_foundation_model", return_value={"modelDetails": {"responseStreamingSupported": False}}
):
# Setup
bedrock_chat_completion = BedrockChatCompletion(
model_id=model_id,
runtime_client=MockBedrockRuntimeClient(),
client=MockBedrockClient(),
)

# Act
settings = BedrockChatPromptExecutionSettings()
with pytest.raises(ServiceInvalidRequestError):
async for chunk in bedrock_chat_completion.get_streaming_chat_message_contents(
chat_history=chat_history, settings=settings
):
pass


@pytest.mark.parametrize(
# These are fake model ids with the supported prefixes
"model_id",
@@ -14,7 +14,7 @@
)
from semantic_kernel.contents.streaming_text_content import StreamingTextContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidRequestError
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from tests.unit.connectors.ai.bedrock.conftest import MockBedrockClient, MockBedrockRuntimeClient

# region init
@@ -213,25 +213,4 @@ async def test_bedrock_streaming_text_completion(
assert isinstance(response.inner_content, list)


async def test_bedrock_streaming_text_completion_with_unsupported_model(
model_id,
) -> None:
"""Test Amazon Bedrock Streaming Chat Completion complete method"""
with patch.object(
MockBedrockClient, "get_foundation_model", return_value={"modelDetails": {"responseStreamingSupported": False}}
):
# Setup
bedrock_text_completion = BedrockTextCompletion(
model_id=model_id,
runtime_client=MockBedrockRuntimeClient(),
client=MockBedrockClient(),
)

# Act
settings = BedrockTextPromptExecutionSettings()
with pytest.raises(ServiceInvalidRequestError):
async for chunk in bedrock_text_completion.get_streaming_text_contents("Hello", settings=settings):
pass


# endregion
@@ -11,7 +11,6 @@
from semantic_kernel.connectors.ai.bedrock.services.bedrock_text_embedding import BedrockTextEmbedding
from semantic_kernel.exceptions.service_exceptions import (
ServiceInitializationError,
ServiceInvalidRequestError,
ServiceInvalidResponseError,
)
from tests.unit.connectors.ai.bedrock.conftest import MockBedrockClient, MockBedrockRuntimeClient
@@ -149,40 +148,6 @@ async def test_bedrock_text_embedding(model_id, mock_bedrock_text_embedding_resp
assert len(response) == 2


async def test_bedrock_text_embedding_with_unsupported_model_input_modality(model_id) -> None:
"""Test Bedrock text embedding generation with unsupported model"""
with patch.object(
MockBedrockClient, "get_foundation_model", return_value={"modelDetails": {"inputModalities": ["IMAGE"]}}
):
# Setup
bedrock_text_embedding = BedrockTextEmbedding(
model_id=model_id,
runtime_client=MockBedrockRuntimeClient(),
client=MockBedrockClient(),
)

with pytest.raises(ServiceInvalidRequestError):
await bedrock_text_embedding.generate_embeddings(["hello", "world"])


async def test_bedrock_text_embedding_with_unsupported_model_output_modality(model_id) -> None:
"""Test Bedrock text embedding generation with unsupported model"""
with patch.object(
MockBedrockClient,
"get_foundation_model",
return_value={"modelDetails": {"inputModalities": ["TEXT"], "outputModalities": ["TEXT"]}},
):
# Setup
bedrock_text_embedding = BedrockTextEmbedding(
model_id=model_id,
runtime_client=MockBedrockRuntimeClient(),
client=MockBedrockClient(),
)

with pytest.raises(ServiceInvalidRequestError):
await bedrock_text_embedding.generate_embeddings(["hello", "world"])


@pytest.mark.parametrize(
# These are fake model ids with the supported prefixes
"model_id",