Skip to content

Commit 7164af1

Browse files
Added Test script and also corrected the spellings mistake in import in embeddings.__init__.py
1 parent e5209ab commit 7164af1

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

src/neo4j_graphrag/embeddings/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
2020
from .sentence_transformers import SentenceTransformerEmbeddings
2121
from .vertexai import VertexAIEmbeddings
22-
from .bedrock_embeddings import BedrockEmbeddings
22+
from .bedrockembeddings import BedrockEmbeddings
2323

2424
__all__ = [
2525
"Embedder",
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from unittest.mock import patch, MagicMock
2+
import pytest
3+
import json
4+
5+
from neo4j_graphrag.embeddings.bedrockembeddings import BedrockEmbeddings
6+
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
7+
8+
9+
@patch("neo4j_graphrag.embeddings.bedrockembeddings.boto3.client")
10+
def test_bedrock_embedder_happy_path(mock_boto_client):
11+
# Mock AWS response with valid embedding
12+
fake_embedding = [0.1] * 1024
13+
fake_response = {
14+
"embedding": fake_embedding
15+
}
16+
17+
# Mock the .read() to return the fake response as JSON bytes
18+
mock_body = MagicMock()
19+
mock_body.read.return_value = json.dumps(fake_response).encode("utf-8")
20+
21+
# Mock the bedrock client
22+
mock_bedrock_client = MagicMock()
23+
mock_bedrock_client.invoke_model.return_value = {"body": mock_body}
24+
mock_boto_client.return_value = mock_bedrock_client
25+
26+
# Instantiate the embedder and run embed_query
27+
embedder = BedrockEmbeddings()
28+
result = embedder.embed_query("Hello, Bedrock!")
29+
30+
# Assertions
31+
assert isinstance(result, list)
32+
assert len(result) == 1024
33+
assert result == fake_embedding
34+
35+
36+
@patch("neo4j_graphrag.embeddings.bedrockembeddings.boto3.client")
37+
def test_bedrock_embedder_error_path(mock_boto_client):
38+
# Simulate AWS client raising an exception
39+
mock_bedrock_client = MagicMock()
40+
mock_bedrock_client.invoke_model.side_effect = Exception("AWS error")
41+
mock_boto_client.return_value = mock_bedrock_client
42+
43+
embedder = BedrockEmbeddings()
44+
45+
with pytest.raises(EmbeddingsGenerationError) as exc_info:
46+
embedder.embed_query("This will fail.")
47+
48+
assert "Issue Generating Embeddings" in str(exc_info.value)

0 commit comments

Comments
 (0)