Skip to content

Commit b20ccbf

Browse files
[2/3] sdks/python: sink data with Milvus Search I/O connector (#36729)
* sdks/python: add milvus sink integration * CHANGES.md: update release notes * sdks/python: fix py docs formatting issues * sdks/python: fix linting issues * sdks/python: delegate auto-flushing to milvus backend * sdks/python: address gemini comments
1 parent 7f49978 commit b20ccbf

File tree

5 files changed

+1143
-0
lines changed

5 files changed

+1143
-0
lines changed

CHANGES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@
7575
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
7676
* Python examples added for Milvus search enrichment handler on [Beam Website](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment-milvus/)
7777
including jupyter notebook example (Python) ([#36176](https://github.com/apache/beam/issues/36176)).
78+
* Milvus sink I/O connector added (Python) ([#36702](https://github.com/apache/beam/issues/36702)).
79+
Beam now fully supports Milvus integration, including both enrichment and sink operations.
7880

7981
## Breaking Changes
8082

Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import logging
18+
from dataclasses import dataclass
19+
from dataclasses import field
20+
from typing import Any
21+
from typing import Callable
22+
from typing import Dict
23+
from typing import List
24+
from typing import Optional
25+
26+
from pymilvus import MilvusClient
27+
from pymilvus.exceptions import MilvusException
28+
29+
import apache_beam as beam
30+
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
31+
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
32+
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
33+
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
34+
from apache_beam.ml.rag.types import Chunk
35+
from apache_beam.ml.rag.utils import DEFAULT_WRITE_BATCH_SIZE
36+
from apache_beam.ml.rag.utils import MilvusConnectionParameters
37+
from apache_beam.ml.rag.utils import MilvusHelpers
38+
from apache_beam.ml.rag.utils import retry_with_backoff
39+
from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs
40+
from apache_beam.transforms import DoFn
41+
42+
# Module-level logger, named after this module; used for debug output when
# records are upserted into Milvus.
_LOGGER = logging.getLogger(__name__)
43+
44+
45+
@dataclass
class MilvusWriteConfig:
  """Settings that control how records are written to a Milvus collection.

  Groups together the target collection and partition, the per-operation
  timeout, the generic batching configuration, and any extra keyword
  arguments intended for the Milvus client's write call.

  Args:
    collection_name: Target Milvus collection. Must be a non-empty string.
    partition_name: Partition inside the collection to write to. An empty
      string means the default partition.
    timeout: Upper bound, in seconds, on a write operation. ``None`` defers
      to the client's default timeout.
    write_config: Generic write settings, including the batch size.
    kwargs: Extra keyword arguments for write operations, kept for forward
      compatibility with future Milvus client parameters.

  Raises:
    ValueError: If ``collection_name`` is empty.
  """
  collection_name: str
  partition_name: str = ""
  timeout: Optional[float] = None
  write_config: WriteConfig = field(default_factory=WriteConfig)
  kwargs: Dict[str, Any] = field(default_factory=dict)

  def __post_init__(self):
    # Guard clause: the collection name is the only field validated here.
    if self.collection_name:
      return
    raise ValueError("Collection name must be provided")

  @property
  def write_batch_size(self):
    """Number of records to buffer before a write is issued.

    Returns:
      The batch size from ``write_config`` when it is truthy, otherwise
      ``DEFAULT_WRITE_BATCH_SIZE``.
    """
    configured = self.write_config.write_batch_size
    return configured if configured else DEFAULT_WRITE_BATCH_SIZE
82+
83+
84+
@dataclass
class MilvusVectorWriterConfig(VectorDatabaseWriteConfig):
  """Milvus-specific implementation of ``VectorDatabaseWriteConfig``.

  Describes everything needed to ingest chunks into Milvus: how to reach the
  server, which collection/partition to write to and with what batching, and
  how each ``Chunk`` is turned into a Milvus record via column specifications.

  Args:
    connection_params: How to connect to the Milvus server (URI, credentials
      and connection options).
    write_config: Write settings: collection name, partition, batch size and
      timeout.
    column_specs: Column specifications mapping chunk fields onto Milvus
      collection fields. Defaults to the result of ``default_column_specs()``
      (id, embedding, sparse_embedding, content, metadata).

  Example:
    config = MilvusVectorWriterConfig(
      connection_params=MilvusConnectionParameters(
        uri="http://localhost:19530"),
      write_config=MilvusWriteConfig(collection_name="my_collection"),
      column_specs=MilvusVectorWriterConfig.default_column_specs())
  """
  connection_params: MilvusConnectionParameters
  write_config: MilvusWriteConfig
  # The lambda defers the lookup of the class name until an instance is
  # created; the name is not yet bound while the class body executes.
  column_specs: List[ColumnSpec] = field(
      default_factory=lambda: MilvusVectorWriterConfig.default_column_specs())

  def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]:
    """Builds the Chunk-to-record conversion function.

    Returns:
      A callable that maps a ``Chunk`` to a dictionary keyed by each column
      spec's ``column_name``, with values produced by that spec's
      ``value_fn``.
    """
    def convert(chunk: Chunk) -> Dict[str, Any]:
      # One entry per column spec, in the order the specs are listed.
      return {
          spec.column_name: spec.value_fn(chunk)
          for spec in self.column_specs
      }

    return convert

  def create_write_transform(self) -> beam.PTransform:
    """Builds the PTransform that writes chunks to Milvus.

    Returns:
      A PTransform, configured from this object, to apply to a PCollection
      of Chunks.
    """
    return _WriteToMilvusVectorDatabase(self)

  @staticmethod
  def default_column_specs() -> List[ColumnSpec]:
    """Builds the default column specifications for RAG use cases.

    Covers the standard RAG fields: id, dense embedding, sparse embedding,
    content text and metadata. Dense embeddings are converted with ``list``,
    sparse embeddings with ``MilvusHelpers.sparse_embedding`` and metadata
    with ``dict``.

    Returns:
      List of ColumnSpec objects describing the default field mappings.
    """
    return (
        ColumnSpecsBuilder()
        .with_id_spec()
        .with_embedding_spec(convert_fn=lambda values: list(values))
        .with_sparse_embedding_spec(conv_fn=MilvusHelpers.sparse_embedding)
        .with_content_spec()
        .with_metadata_spec(convert_fn=lambda values: dict(values))
        .build())
161+
162+
163+
class _WriteToMilvusVectorDatabase(beam.PTransform):
  """PTransform that converts Chunks to records and writes them to Milvus.

  The transform first maps every Chunk through the converter supplied by the
  configuration, then hands the resulting records to a batching DoFn that
  performs the actual writes.

  Args:
    config: MilvusVectorWriterConfig holding the connection parameters,
      write settings and column specifications.
  """
  def __init__(self, config: MilvusVectorWriterConfig):
    self.config = config

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Applies the conversion and write steps to the input PCollection.

    Args:
      pcoll: PCollection of Chunk objects to write to Milvus.

    Returns:
      PCollection of the record dictionaries emitted by the write DoFn.
    """
    to_record = self.config.create_converter()
    records = pcoll | "Convert to Records" >> beam.Map(to_record)

    writer = _WriteMilvusFn(
        self.config.connection_params, self.config.write_config)
    return records | beam.ParDo(writer)
192+
193+
194+
class _WriteMilvusFn(DoFn):
  """DoFn that buffers records and writes them to Milvus in batches.

  Records are appended to an in-memory batch. The batch is written when it
  reaches the configured batch size and once more when the bundle finishes,
  so fewer individual write operations are issued.

  Args:
    connection_params: Configuration for connecting to the Milvus server.
    write_config: Configuration for write operations including batch size
      and collection details.
  """
  def __init__(
      self,
      connection_params: MilvusConnectionParameters,
      write_config: MilvusWriteConfig):
    self._connection_params = connection_params
    self._write_config = write_config
    self.batch = []

  def start_bundle(self):
    """Starts every bundle with an empty batch.

    ``finish_bundle`` normally leaves the batch empty. Resetting here as well
    ensures that records buffered by a bundle that did not finish are never
    written as part of a later bundle if the instance is reused.
    """
    self.batch = []

  def process(self, element, *args, **kwargs):
    """Buffers one record and flushes when the batch is full.

    Note that the element is emitted as soon as it has been buffered, which
    can be before the batch containing it is actually written.

    Args:
      element: A dictionary representing a Milvus record to write.
      *args: Additional positional arguments (unused).
      **kwargs: Additional keyword arguments (unused).

    Yields:
      The original element, unchanged.
    """
    _ = args, kwargs  # Unused parameters
    self.batch.append(element)
    if len(self.batch) >= self._write_config.write_batch_size:
      self._flush()
    yield element

  def finish_bundle(self):
    """Writes whatever is still buffered when the bundle ends."""
    self._flush()

  def _flush(self):
    """Writes the buffered records to Milvus and empties the batch.

    Does nothing when the batch is empty. The batch is only cleared after the
    write returns, so a failing write leaves the records in place.
    """
    if not self.batch:
      return
    with _MilvusSink(self._connection_params, self._write_config) as sink:
      sink.write(self.batch)
    self.batch = []

  def display_data(self):
    """Returns display data for monitoring and debugging.

    Returns:
      The parent's display data extended with the database name, collection
      name and batch size.
    """
    res = super().display_data()
    res["database"] = self._connection_params.db_name
    res["collection"] = self._write_config.collection_name
    res["batch_size"] = self._write_config.write_batch_size
    return res
262+
263+
264+
class _MilvusSink:
  """Low-level sink that upserts records into Milvus.

  Owns a single Milvus client. The client is created on entering the context
  manager (or lazily on the first ``write`` when used without ``with``) and
  closed on exit.

  Args:
    connection_params: Configuration for connecting to the Milvus server.
    write_config: Configuration for write operations including collection
      and partition targeting.
  """
  def __init__(
      self,
      connection_params: MilvusConnectionParameters,
      write_config: MilvusWriteConfig):
    self._connection_params = connection_params
    self._write_config = write_config
    self._client = None

  def _connect(self):
    """Creates a new Milvus client through ``retry_with_backoff``.

    The retry settings (``max_retries``, ``retry_delay``,
    ``retry_backoff_factor``) are removed from the unpacked connection
    parameters before the client is built, so they are never forwarded to
    ``MilvusClient`` itself.

    Returns:
      The client returned by ``retry_with_backoff``.
    """
    connection_params = unpack_dataclass_with_kwargs(self._connection_params)

    # Extract retry parameters from connection_params.
    max_retries = connection_params.pop('max_retries', 3)
    retry_delay = connection_params.pop('retry_delay', 1.0)
    retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)

    def create_client():
      return MilvusClient(**connection_params)

    return retry_with_backoff(
        create_client,
        max_retries=max_retries,
        retry_delay=retry_delay,
        retry_backoff_factor=retry_backoff_factor,
        operation_name="Milvus connection",
        exception_types=(MilvusException, ))

  def write(self, documents):
    """Upserts a batch of documents into the configured collection.

    Reuses the client opened by ``__enter__``; a client is only created here
    when none exists yet. No explicit flush is issued after the upsert.

    Args:
      documents: List of dictionaries representing Milvus records to write.
        Each dictionary should contain fields matching the collection schema.
    """
    # Creating a fresh client unconditionally would orphan the connection
    # opened in __enter__ and bypass the retry handling.
    if self._client is None:
      self._client = self._connect()

    resp = self._client.upsert(
        collection_name=self._write_config.collection_name,
        partition_name=self._write_config.partition_name,
        data=documents,
        timeout=self._write_config.timeout,
        **self._write_config.kwargs)

    _LOGGER.debug(
        "Upserted into Milvus: upsert_count=%d, cost=%d",
        resp.get("upsert_count", 0),
        resp.get("cost", 0))

  def __enter__(self):
    """Enters the context manager, connecting to Milvus if necessary.

    Returns:
      Self, enabling use in 'with' statements.
    """
    if self._client is None:
      self._client = self._connect()
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Exits the context manager and closes the Milvus connection.

    Exceptions raised inside the ``with`` body are not suppressed.

    Args:
      exc_type: Exception type if an exception was raised.
      exc_val: Exception value if an exception was raised.
      exc_tb: Exception traceback if an exception was raised.
    """
    _ = exc_type, exc_val, exc_tb  # Unused parameters
    # Drop the reference before closing so a closed client is never reused,
    # even if close() itself raises.
    client, self._client = self._client, None
    if client is not None:
      client.close()

0 commit comments

Comments
 (0)