Skip to content

Commit 1a5b4b5

Browse files
authored
feat: add support for AddToCell in Data Client (#1147)
1 parent c3e3eb0 commit 1a5b4b5

File tree

6 files changed

+337
-13
lines changed

6 files changed

+337
-13
lines changed

google/cloud/bigtable/data/mutations.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,14 @@ def _from_dict(cls, input_dict: dict[str, Any]) -> Mutation:
123123
instance = DeleteAllFromFamily(details["family_name"])
124124
elif "delete_from_row" in input_dict:
125125
instance = DeleteAllFromRow()
126+
elif "add_to_cell" in input_dict:
127+
details = input_dict["add_to_cell"]
128+
instance = AddToCell(
129+
details["family_name"],
130+
details["column_qualifier"]["raw_value"],
131+
details["input"]["int_value"],
132+
details["timestamp"]["raw_timestamp_micros"],
133+
)
126134
except KeyError as e:
127135
raise ValueError("Invalid mutation dictionary") from e
128136
if instance is None:
@@ -276,6 +284,75 @@ def _to_dict(self) -> dict[str, Any]:
276284
}
277285

278286

287+
@dataclass
288+
class AddToCell(Mutation):
289+
"""
290+
Adds an int64 value to an aggregate cell. The column family must be an
291+
aggregate family and have an "int64" input type or this mutation will be
292+
rejected.
293+
294+
Note: The timestamp values are in microseconds but must match the
295+
granularity of the table (defaults to `MILLIS`). Therefore, the given value
296+
must be a multiple of 1000 (millisecond granularity). For example:
297+
`1571902339435000`.
298+
299+
Args:
300+
family: The name of the column family to which the cell belongs.
301+
qualifier: The column qualifier of the cell.
302+
value: The value to be accumulated into the cell.
303+
timestamp_micros: The timestamp of the cell. Must be provided for
304+
cell aggregation to work correctly.
305+
306+
307+
Raises:
308+
TypeError: If `qualifier` is not `bytes` or `str`.
309+
TypeError: If `value` is not `int`.
310+
TypeError: If `timestamp_micros` is not `int`.
311+
ValueError: If `value` is out of bounds for a 64-bit signed int.
312+
ValueError: If `timestamp_micros` is less than 0.
313+
"""
314+
315+
def __init__(
316+
self,
317+
family: str,
318+
qualifier: bytes | str,
319+
value: int,
320+
timestamp_micros: int,
321+
):
322+
qualifier = qualifier.encode() if isinstance(qualifier, str) else qualifier
323+
if not isinstance(qualifier, bytes):
324+
raise TypeError("qualifier must be bytes or str")
325+
if not isinstance(value, int):
326+
raise TypeError("value must be int")
327+
if not isinstance(timestamp_micros, int):
328+
raise TypeError("timestamp_micros must be int")
329+
if abs(value) > _MAX_INCREMENT_VALUE:
330+
raise ValueError(
331+
"int values must be between -2**63 and 2**63 (64-bit signed int)"
332+
)
333+
334+
if timestamp_micros < 0:
335+
raise ValueError("timestamp must be non-negative")
336+
337+
self.family = family
338+
self.qualifier = qualifier
339+
self.value = value
340+
self.timestamp = timestamp_micros
341+
342+
def _to_dict(self) -> dict[str, Any]:
343+
return {
344+
"add_to_cell": {
345+
"family_name": self.family,
346+
"column_qualifier": {"raw_value": self.qualifier},
347+
"timestamp": {"raw_timestamp_micros": self.timestamp},
348+
"input": {"int_value": self.value},
349+
}
350+
}
351+
352+
def is_idempotent(self) -> bool:
353+
return False
354+
355+
279356
class RowMutationEntry:
280357
"""
281358
A single entry in a `MutateRows` request.

tests/system/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616

1717
TEST_FAMILY = "test-family"
1818
TEST_FAMILY_2 = "test-family-2"
19+
TEST_AGGREGATE_FAMILY = "test-aggregate-family"

tests/system/data/setup_fixtures.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import os
2121
import uuid
2222

23-
from . import TEST_FAMILY, TEST_FAMILY_2
23+
from . import TEST_FAMILY, TEST_FAMILY_2, TEST_AGGREGATE_FAMILY
2424

2525
# authorized view subset to allow all qualifiers
2626
ALLOW_ALL = ""
@@ -183,6 +183,7 @@ def authorized_view_id(
183183
"family_subsets": {
184184
TEST_FAMILY: ALL_QUALIFIERS,
185185
TEST_FAMILY_2: ALL_QUALIFIERS,
186+
TEST_AGGREGATE_FAMILY: ALL_QUALIFIERS,
186187
},
187188
},
188189
},

tests/system/data/test_system_async.py

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from google.cloud.bigtable.data._cross_sync import CrossSync
2929

30-
from . import TEST_FAMILY, TEST_FAMILY_2
30+
from . import TEST_FAMILY, TEST_FAMILY_2, TEST_AGGREGATE_FAMILY
3131

3232

3333
__CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system_autogen"
@@ -76,6 +76,27 @@ async def add_row(
7676
await self.target.client._gapic_client.mutate_row(request)
7777
self.rows.append(row_key)
7878

79+
@CrossSync.convert
80+
async def add_aggregate_row(
81+
self, row_key, *, family=TEST_AGGREGATE_FAMILY, qualifier=b"q", input=0
82+
):
83+
request = {
84+
"table_name": self.target.table_name,
85+
"row_key": row_key,
86+
"mutations": [
87+
{
88+
"add_to_cell": {
89+
"family_name": family,
90+
"column_qualifier": {"raw_value": qualifier},
91+
"timestamp": {"raw_timestamp_micros": 0},
92+
"input": {"int_value": input},
93+
}
94+
}
95+
],
96+
}
97+
await self.target.client._gapic_client.mutate_row(request)
98+
self.rows.append(row_key)
99+
79100
@CrossSync.convert
80101
async def delete_rows(self):
81102
if self.rows:
@@ -132,7 +153,17 @@ def column_family_config(self):
132153
"""
133154
from google.cloud.bigtable_admin_v2 import types
134155

135-
return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()}
156+
int_aggregate_type = types.Type.Aggregate(
157+
input_type=types.Type(int64_type={"encoding": {"big_endian_bytes": {}}}),
158+
sum={},
159+
)
160+
return {
161+
TEST_FAMILY: types.ColumnFamily(),
162+
TEST_FAMILY_2: types.ColumnFamily(),
163+
TEST_AGGREGATE_FAMILY: types.ColumnFamily(
164+
value_type=types.Type(aggregate_type=int_aggregate_type)
165+
),
166+
}
136167

137168
@pytest.fixture(scope="session")
138169
def init_table_id(self):
@@ -281,6 +312,37 @@ async def test_mutation_set_cell(self, target, temp_rows):
281312
# ensure cell is updated
282313
assert (await self._retrieve_cell_value(target, row_key)) == new_value
283314

315+
@CrossSync.pytest
316+
@pytest.mark.usefixtures("target")
317+
@CrossSync.Retry(
318+
predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
319+
)
320+
async def test_mutation_add_to_cell(self, target, temp_rows):
321+
"""
322+
Test add to cell mutation
323+
"""
324+
from google.cloud.bigtable.data.mutations import AddToCell
325+
326+
row_key = b"add_to_cell"
327+
family = TEST_AGGREGATE_FAMILY
328+
qualifier = b"test-qualifier"
329+
# add row to temp_rows, for future deletion
330+
await temp_rows.add_aggregate_row(row_key, family=family, qualifier=qualifier)
331+
# set and check cell value
332+
await target.mutate_row(
333+
row_key, AddToCell(family, qualifier, 1, timestamp_micros=0)
334+
)
335+
encoded_result = await self._retrieve_cell_value(target, row_key)
336+
int_result = int.from_bytes(encoded_result, byteorder="big")
337+
assert int_result == 1
338+
# update again
339+
await target.mutate_row(
340+
row_key, AddToCell(family, qualifier, 9, timestamp_micros=0)
341+
)
342+
encoded_result = await self._retrieve_cell_value(target, row_key)
343+
int_result = int.from_bytes(encoded_result, byteorder="big")
344+
assert int_result == 10
345+
284346
@pytest.mark.skipif(
285347
bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits"
286348
)
@@ -1123,7 +1185,7 @@ async def test_execute_query_simple(self, client, table_id, instance_id):
11231185
predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
11241186
)
11251187
async def test_execute_against_target(
1126-
self, client, instance_id, table_id, temp_rows
1188+
self, client, instance_id, table_id, temp_rows, column_family_config
11271189
):
11281190
await temp_rows.add_row(b"row_key_1")
11291191
result = await client.execute_query(
@@ -1138,14 +1200,19 @@ async def test_execute_against_target(
11381200
assert family_map[b"q"] == b"test-value"
11391201
assert len(rows[0][TEST_FAMILY_2]) == 0
11401202
md = result.metadata
1141-
assert len(md) == 3
1203+
# we expect it to fetch each column family, plus _key
1204+
# add additional families here if column_family_config changes
1205+
assert len(md) == len(column_family_config) + 1
11421206
assert md["_key"].column_type == SqlType.Bytes()
11431207
assert md[TEST_FAMILY].column_type == SqlType.Map(
11441208
SqlType.Bytes(), SqlType.Bytes()
11451209
)
11461210
assert md[TEST_FAMILY_2].column_type == SqlType.Map(
11471211
SqlType.Bytes(), SqlType.Bytes()
11481212
)
1213+
assert md[TEST_AGGREGATE_FAMILY].column_type == SqlType.Map(
1214+
SqlType.Bytes(), SqlType.Int64()
1215+
)
11491216

11501217
@pytest.mark.skipif(
11511218
bool(os.environ.get(BIGTABLE_EMULATOR)),
@@ -1248,7 +1315,7 @@ async def test_execute_query_params(self, client, table_id, instance_id):
12481315
predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
12491316
)
12501317
async def test_execute_metadata_on_empty_response(
1251-
self, client, instance_id, table_id, temp_rows
1318+
self, client, instance_id, table_id, temp_rows, column_family_config
12521319
):
12531320
await temp_rows.add_row(b"row_key_1")
12541321
result = await client.execute_query(
@@ -1258,11 +1325,16 @@ async def test_execute_metadata_on_empty_response(
12581325

12591326
assert len(rows) == 0
12601327
md = result.metadata
1261-
assert len(md) == 3
1328+
# we expect it to fetch each column family, plus _key
1329+
# add additional families here if column_family_config change
1330+
assert len(md) == len(column_family_config) + 1
12621331
assert md["_key"].column_type == SqlType.Bytes()
12631332
assert md[TEST_FAMILY].column_type == SqlType.Map(
12641333
SqlType.Bytes(), SqlType.Bytes()
12651334
)
12661335
assert md[TEST_FAMILY_2].column_type == SqlType.Map(
12671336
SqlType.Bytes(), SqlType.Bytes()
12681337
)
1338+
assert md[TEST_AGGREGATE_FAMILY].column_type == SqlType.Map(
1339+
SqlType.Bytes(), SqlType.Int64()
1340+
)

0 commit comments

Comments
 (0)