
Commit 4d0adf6

Add mypy type checks
1 parent 9b7c35c commit 4d0adf6

5 files changed: +148, -85 lines


.github/workflows/lint.yml

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ jobs:
         run: uv run ruff format --check
       - name: Run ruff check
         run: uv run ruff check
-      # - name: Run mypy
-      #   run: uv run mypy .
+      - name: Run mypy
+        run: uv run mypy .
       - name: Minimize uv cache
         run: uv cache prune --ci

langchain/langchain_vectorize/retrievers.py

Lines changed: 34 additions & 21 deletions
@@ -4,20 +4,18 @@

 from typing import TYPE_CHECKING, Any, Literal, Optional

-import vectorize_client
 from langchain_core.documents import Document
 from langchain_core.retrievers import BaseRetriever
 from typing_extensions import override
-from vectorize_client import (
-    ApiClient,
-    Configuration,
-    PipelinesApi,
-    RetrieveDocumentsRequest,
-)
+from vectorize_client.api.pipelines_api import PipelinesApi
+from vectorize_client.api_client import ApiClient
+from vectorize_client.configuration import Configuration
+from vectorize_client.models.retrieve_documents_request import RetrieveDocumentsRequest

 if TYPE_CHECKING:
     from langchain_core.callbacks import CallbackManagerForRetrieverRun
     from langchain_core.runnables import RunnableConfig
+    from vectorize_client.models.document import Document as VectorizeDocument

 _METADATA_FIELDS = {
     "relevancy",
@@ -122,7 +120,7 @@ def format_docs(docs):
     metadata_filters: list[dict[str, Any]] = []
     """The metadata filters to apply when retrieving the documents."""

-    _pipelines: PipelinesApi | None = None
+    _pipelines: PipelinesApi = _NOT_SET  # type: ignore[assignment]

     @override
     def model_post_init(self, /, context: Any) -> None:
@@ -146,7 +144,7 @@ def model_post_init(self, /, context: Any) -> None:
         self._pipelines = PipelinesApi(api)

     @staticmethod
-    def _convert_document(document: vectorize_client.models.Document) -> Document:
+    def _convert_document(document: VectorizeDocument) -> Document:
         metadata = {field: getattr(document, field) for field in _METADATA_FIELDS}
         return Document(id=document.id, page_content=document.text, metadata=metadata)

@@ -162,14 +160,29 @@ def _get_relevant_documents(
         rerank: bool | None = None,
         metadata_filters: list[dict[str, Any]] | None = None,
     ) -> list[Document]:
-        request = RetrieveDocumentsRequest(
+        request = RetrieveDocumentsRequest(  # type: ignore[call-arg]
             question=query,
             num_results=num_results or self.num_results,
             rerank=rerank or self.rerank,
             metadata_filters=metadata_filters or self.metadata_filters,
         )
+        organization_ = organization or self.organization
+        if not organization_:
+            msg = (
+                "Organization must be set either at initialization "
+                "or in the invoke method."
+            )
+            raise ValueError(msg)
+        pipeline_id_ = pipeline_id or self.pipeline_id
+        if not pipeline_id_:
+            msg = (
+                "Pipeline ID must be set either at initialization "
+                "or in the invoke method."
+            )
+            raise ValueError(msg)
+
         response = self._pipelines.retrieve_documents(
-            organization or self.organization, pipeline_id or self.pipeline_id, request
+            organization_, pipeline_id_, request
         )
         return [self._convert_document(doc) for doc in response.documents]

@@ -181,9 +194,10 @@ def invoke(
         *,
         organization: str = "",
         pipeline_id: str = "",
-        num_results: int = _NOT_SET,
-        rerank: bool = _NOT_SET,
-        metadata_filters: list[dict[str, Any]] = _NOT_SET,
+        num_results: int = _NOT_SET,  # type: ignore[assignment]
+        rerank: bool = _NOT_SET,  # type: ignore[assignment]
+        metadata_filters: list[dict[str, Any]] = _NOT_SET,  # type: ignore[assignment]
+        **_kwargs: Any,
     ) -> list[Document]:
         """Invoke the retriever to get relevant documents.

@@ -218,16 +232,15 @@ def invoke(
             query = "what year was breath of the wild released?"
             docs = retriever.invoke(query, num_results=2)
         """
-        kwargs = {}
         if organization:
-            kwargs["organization"] = organization
+            _kwargs["organization"] = organization
         if pipeline_id:
-            kwargs["pipeline_id"] = pipeline_id
+            _kwargs["pipeline_id"] = pipeline_id
         if num_results is not _NOT_SET:
-            kwargs["num_results"] = num_results
+            _kwargs["num_results"] = num_results
         if rerank is not _NOT_SET:
-            kwargs["rerank"] = rerank
+            _kwargs["rerank"] = rerank
         if metadata_filters is not _NOT_SET:
-            kwargs["metadata_filters"] = metadata_filters
+            _kwargs["metadata_filters"] = metadata_filters

-        return super().invoke(input, config, **kwargs)
+        return super().invoke(input, config, **_kwargs)
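
A note for readers of these hunks: `_NOT_SET` is a module-level sentinel defined elsewhere in retrievers.py (its definition is outside this diff), used as a "value not provided" default so that `invoke()` forwards only the arguments the caller actually passed, and so the new organization/pipeline checks can fail loudly instead of sending empty ids. Below is a minimal sketch of that pattern; `_NOT_SET = object()` and the helper name `resolve_required` are assumptions for illustration, not code from the repository.

from typing import Any

_NOT_SET = object()  # assumed definition: a unique "argument omitted" marker


def invoke_sketch(
    num_results: int = _NOT_SET,  # type: ignore[assignment]  # sentinel, not an int
    rerank: bool = _NOT_SET,  # type: ignore[assignment]
) -> dict[str, Any]:
    """Collect only the arguments the caller explicitly provided."""
    kwargs: dict[str, Any] = {}
    if num_results is not _NOT_SET:
        kwargs["num_results"] = num_results
    if rerank is not _NOT_SET:
        kwargs["rerank"] = rerank
    return kwargs


def resolve_required(per_call: str, configured: str, name: str) -> str:
    """Mirror the new checks in _get_relevant_documents: the per-call value
    wins, then the instance value, otherwise raise instead of sending "".
    """
    value = per_call or configured
    if not value:
        msg = f"{name} must be set either at initialization or in the invoke method."
        raise ValueError(msg)
    return value

Under `strict = true` mypy rejects an `object()` sentinel as a default for an `int` or `bool` parameter, which is why the diff adds the `# type: ignore[assignment]` comments rather than widening the public signatures to `Optional`.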

langchain/pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -31,7 +31,7 @@ Issues = "https://github.com/vectorize-io/integrations-python/issues"

 [dependency-groups]
 dev = [
-    "mypy>=1.13.0",
+    "mypy>=1.17.1,<1.18",
     "pytest>=8.3.3",
     "ruff>=0.9.0,<0.10",
 ]
@@ -59,6 +59,8 @@ flake8-annotations.mypy-init-return = true

 [tool.mypy]
 strict = true
+strict_bytes = true
+enable_error_code = "deprecated"
 warn_unreachable = true
 pretty = true
 show_error_codes = true
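
For context on the two new settings: `strict_bytes = true` stops mypy from treating `bytearray` and `memoryview` as implicitly compatible with `bytes`, and `enable_error_code = "deprecated"` turns on the opt-in check that flags uses of APIs marked with `@deprecated` (PEP 702 / `typing_extensions.deprecated`). A small illustrative sketch follows; the function names are made up and not from this repository.

from typing_extensions import deprecated


@deprecated("Use fetch_v2() instead")
def fetch_v1() -> str:
    return "data"


def send(payload: bytes) -> None:
    """Accepts real bytes only once bytearray promotion is disabled."""


fetch_v1()              # reported with the [deprecated] error code enabled above
send(bytearray(b"hi"))  # reported as an arg-type error under strict_bytes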

langchain/tests/test_retrievers.py

Lines changed: 58 additions & 26 deletions
@@ -8,8 +8,42 @@

 import pytest
 import urllib3
-import vectorize_client as v
-from vectorize_client import ApiClient
+from vectorize_client.api.ai_platform_connectors_api import AIPlatformConnectorsApi
+from vectorize_client.api.destination_connectors_api import DestinationConnectorsApi
+from vectorize_client.api.pipelines_api import PipelinesApi
+from vectorize_client.api.source_connectors_api import SourceConnectorsApi
+from vectorize_client.api.uploads_api import UploadsApi
+from vectorize_client.api_client import ApiClient
+from vectorize_client.configuration import Configuration
+from vectorize_client.models.ai_platform_config_schema import AIPlatformConfigSchema
+from vectorize_client.models.ai_platform_type_for_pipeline import (
+    AIPlatformTypeForPipeline,
+)
+from vectorize_client.models.create_source_connector_request import (
+    CreateSourceConnectorRequest,
+)
+from vectorize_client.models.destination_connector_type_for_pipeline import (
+    DestinationConnectorTypeForPipeline,
+)
+from vectorize_client.models.file_upload import FileUpload
+from vectorize_client.models.pipeline_ai_platform_connector_schema import (
+    PipelineAIPlatformConnectorSchema,
+)
+from vectorize_client.models.pipeline_configuration_schema import (
+    PipelineConfigurationSchema,
+)
+from vectorize_client.models.pipeline_destination_connector_schema import (
+    PipelineDestinationConnectorSchema,
+)
+from vectorize_client.models.pipeline_source_connector_schema import (
+    PipelineSourceConnectorSchema,
+)
+from vectorize_client.models.schedule_schema import ScheduleSchema
+from vectorize_client.models.schedule_schema_type import ScheduleSchemaType
+from vectorize_client.models.source_connector_type import SourceConnectorType
+from vectorize_client.models.start_file_upload_to_connector_request import (
+    StartFileUploadToConnectorRequest,
+)

 from langchain_vectorize.retrievers import VectorizeRetriever

@@ -38,7 +72,7 @@ def environment() -> Literal["prod", "dev", "local", "staging"]:
     if env not in ["prod", "dev", "local", "staging"]:
         msg = "Invalid VECTORIZE_ENV environment variable."
         raise ValueError(msg)
-    return env
+    return env  # type: ignore[return-value]


 @pytest.fixture(scope="session")
@@ -56,33 +90,31 @@ def api_client(api_token: str, environment: str) -> Iterator[ApiClient]:
     else:
         host = "https://api-staging.vectorize.io/v1"

-    with v.ApiClient(
-        v.Configuration(host=host, access_token=api_token, debug=True),
+    with ApiClient(
+        Configuration(host=host, access_token=api_token, debug=True),
         header_name,
         header_value,
     ) as api:
         yield api


 @pytest.fixture(scope="session")
-def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
-    pipelines = v.PipelinesApi(api_client)
+def pipeline_id(api_client: ApiClient, org_id: str) -> Iterator[str]:
+    pipelines = PipelinesApi(api_client)

-    connectors_api = v.SourceConnectorsApi(api_client)
+    connectors_api = SourceConnectorsApi(api_client)
     response = connectors_api.create_source_connector(
         org_id,
-        v.CreateSourceConnectorRequest(
-            v.FileUpload(name="from api", type="FILE_UPLOAD")
-        ),
+        CreateSourceConnectorRequest(FileUpload(name="from api", type="FILE_UPLOAD")),
     )
     source_connector_id = response.connector.id
     logging.info("Created source connector %s", source_connector_id)

-    uploads_api = v.UploadsApi(api_client)
+    uploads_api = UploadsApi(api_client)
     upload_response = uploads_api.start_file_upload_to_connector(
         org_id,
         source_connector_id,
-        v.StartFileUploadToConnectorRequest(
+        StartFileUploadToConnectorRequest(  # type: ignore[call-arg]
             name="research.pdf",
             content_type="application/pdf",
             metadata=json.dumps({"created-from-api": True}),
@@ -109,44 +141,44 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
     else:
         logging.info("Upload successful")

-    ai_platforms = v.AIPlatformConnectorsApi(api_client).get_ai_platform_connectors(
+    ai_platforms = AIPlatformConnectorsApi(api_client).get_ai_platform_connectors(
        org_id
     )
     builtin_ai_platform = next(
         c.id for c in ai_platforms.ai_platform_connectors if c.type == "VECTORIZE"
     )
     logging.info("Using AI platform %s", builtin_ai_platform)

-    vector_databases = v.DestinationConnectorsApi(
-        api_client
-    ).get_destination_connectors(org_id)
+    vector_databases = DestinationConnectorsApi(api_client).get_destination_connectors(
+        org_id
+    )
     builtin_vector_db = next(
         c.id for c in vector_databases.destination_connectors if c.type == "VECTORIZE"
     )
     logging.info("Using destination connector %s", builtin_vector_db)

     pipeline_response = pipelines.create_pipeline(
         org_id,
-        v.PipelineConfigurationSchema(
+        PipelineConfigurationSchema(  # type: ignore[call-arg]
             source_connectors=[
-                v.PipelineSourceConnectorSchema(
+                PipelineSourceConnectorSchema(
                     id=source_connector_id,
-                    type=v.SourceConnectorType.FILE_UPLOAD,
+                    type=SourceConnectorType.FILE_UPLOAD,
                     config={},
                 )
             ],
-            destination_connector=v.PipelineDestinationConnectorSchema(
+            destination_connector=PipelineDestinationConnectorSchema(
                 id=builtin_vector_db,
-                type="VECTORIZE",
+                type=DestinationConnectorTypeForPipeline.VECTORIZE,
                 config={},
             ),
-            ai_platform_connector=v.PipelineAIPlatformConnectorSchema(
+            ai_platform_connector=PipelineAIPlatformConnectorSchema(
                 id=builtin_ai_platform,
-                type="VECTORIZE",
-                config={},
+                type=AIPlatformTypeForPipeline.VECTORIZE,
+                config=AIPlatformConfigSchema(),
             ),
             pipeline_name="Test pipeline",
-            schedule=v.ScheduleSchema(type="manual"),
+            schedule=ScheduleSchema(type=ScheduleSchemaType.MANUAL),
         ),
     )
     pipeline_id = pipeline_response.data.id
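
One detail worth calling out in the `environment()` fixture above: the `# type: ignore[return-value]` is needed because `env` is typed as a plain `str`, and the membership check against a list does not narrow it to the declared `Literal` return type. An equivalent alternative, shown here only as a hedged sketch with illustrative names (the "prod" default is an assumption, not taken from the fixture), is to `cast` after the runtime check:

import os
from typing import Literal, cast

Env = Literal["prod", "dev", "local", "staging"]


def environment_sketch() -> Env:
    # Default value is an assumption; the fixture's real default is outside this diff.
    env = os.environ.get("VECTORIZE_ENV", "prod")
    if env not in ("prod", "dev", "local", "staging"):
        msg = "Invalid VECTORIZE_ENV environment variable."
        raise ValueError(msg)
    # The runtime check above guarantees the value, so tell the type checker explicitly.
    return cast(Env, env)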

0 commit comments
