@dataclass
class AddToCell(Mutation):
    """
    Adds an int64 value to an aggregate cell. The column family must be an
    aggregate family and have an "int64" input type or this mutation will be
    rejected.

    Note: The timestamp values are in microseconds but must match the
    granularity of the table (defaults to `MILLIS`). Therefore, the given value
    must be a multiple of 1000 (millisecond granularity). For example:
    `1571902339435000`.

    Args:
        family: The name of the column family to which the cell belongs.
        qualifier: The column qualifier of the cell. A `str` is converted to
            `bytes` with `str.encode()` before it is stored.
        value: The value to be accumulated into the cell.
        timestamp_micros: The timestamp of the cell. Must be provided for
            cell aggregation to work correctly. Stored on the instance as
            `timestamp`.

    Raises:
        TypeError: If `qualifier` is not `bytes` or `str`.
        TypeError: If `value` is not `int`.
        TypeError: If `timestamp_micros` is not `int`.
        ValueError: If `value` is out of bounds for a 64-bit signed int.
        ValueError: If `timestamp_micros` is less than 0.
    """

    # Declared as dataclass fields so the generated `__eq__` / `__repr__`
    # reflect the mutation's contents. With no fields declared, `@dataclass`
    # generates an `__eq__` that compares empty tuples, which makes every
    # pair of AddToCell instances compare equal.
    # The hand-written `__init__` below is kept: `@dataclass` never replaces
    # an `__init__` defined in the class body.
    family: str
    qualifier: bytes
    value: int
    timestamp: int

    def __init__(
        self,
        family: str,
        qualifier: bytes | str,
        value: int,
        timestamp_micros: int,
    ):
        # normalize str qualifiers to bytes before validating the type
        qualifier = qualifier.encode() if isinstance(qualifier, str) else qualifier
        if not isinstance(qualifier, bytes):
            raise TypeError("qualifier must be bytes or str")
        if not isinstance(value, int):
            raise TypeError("value must be int")
        if not isinstance(timestamp_micros, int):
            raise TypeError("timestamp_micros must be int")
        # A signed 64-bit range is asymmetric, so `abs(value) > max` would
        # reject the smallest representable value.
        # NOTE(review): assumes `_MAX_INCREMENT_VALUE` (defined elsewhere in
        # this module) is the int64 maximum, 2**63 - 1 -- confirm.
        if not -_MAX_INCREMENT_VALUE - 1 <= value <= _MAX_INCREMENT_VALUE:
            raise ValueError(
                "int values must be between -2**63 and 2**63 (64-bit signed int)"
            )

        if timestamp_micros < 0:
            raise ValueError("timestamp must be non-negative")

        self.family = family
        self.qualifier = qualifier
        self.value = value
        self.timestamp = timestamp_micros

    def _to_dict(self) -> dict[str, Any]:
        """
        Return the mutation as a dictionary with a single `add_to_cell` key
        holding the family name, raw qualifier, raw timestamp and int input.
        """
        return {
            "add_to_cell": {
                "family_name": self.family,
                "column_qualifier": {"raw_value": self.qualifier},
                "timestamp": {"raw_timestamp_micros": self.timestamp},
                "input": {"int_value": self.value},
            }
        }

    def is_idempotent(self) -> bool:
        """
        Always returns False, regardless of the timestamp: the value is
        accumulated into the cell, so applying the mutation a second time
        would add it again.
        """
        return False
@CrossSync.convert
async def add_aggregate_row(
    self, row_key, *, family=TEST_AGGREGATE_FAMILY, qualifier=b"q", input=0
):
    """
    Write one `add_to_cell` mutation for `row_key` through the raw GAPIC
    client, then record the key in `self.rows`.

    The mutation always uses a raw timestamp of 0 micros; `input` is sent as
    the `int_value` to accumulate into the `family`/`qualifier` cell.
    """
    # presumably `self.rows` is what `delete_rows` later cleans up -- confirm
    cell_update = {
        "family_name": family,
        "column_qualifier": {"raw_value": qualifier},
        "timestamp": {"raw_timestamp_micros": 0},
        "input": {"int_value": input},
    }
    request = {
        "table_name": self.target.table_name,
        "row_key": row_key,
        "mutations": [{"add_to_cell": cell_update}],
    }
    await self.target.client._gapic_client.mutate_row(request)
    self.rows.append(row_key)
@CrossSync.pytest
@pytest.mark.usefixtures("target")
@CrossSync.Retry(
    predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
)
async def test_mutation_add_to_cell(self, target, temp_rows):
    """
    Test add to cell mutation: two successive AddToCell mutations on the
    same cell should be summed (1, then 1 + 9 = 10).
    """
    from google.cloud.bigtable.data.mutations import AddToCell

    row_key = b"add_to_cell"
    family = TEST_AGGREGATE_FAMILY
    qualifier = b"test-qualifier"
    # register the row with temp_rows (adds 0 to the cell), for future deletion
    await temp_rows.add_aggregate_row(row_key, family=family, qualifier=qualifier)
    # NOTE(review): the Retry decorator re-runs this body on ClientError.
    # AddToCell accumulates, so an attempt that fails part-way could leave a
    # non-zero total behind for the next attempt -- confirm whether temp_rows
    # is reset between retries.
    # each step adds `increment` and then checks the running total
    for increment, expected_total in ((1, 1), (9, 10)):
        await target.mutate_row(
            row_key, AddToCell(family, qualifier, increment, timestamp_micros=0)
        )
        raw_cell = await self._retrieve_cell_value(target, row_key)
        # the cell value is decoded as a big-endian integer
        assert int.from_bytes(raw_cell, byteorder="big") == expected_total
len(md) == 3 + # we expect it to fetch each column family, plus _key + # add additional families here if column_family_config changes + assert len(md) == len(column_family_config) + 1 assert md["_key"].column_type == SqlType.Bytes() assert md[TEST_FAMILY].column_type == SqlType.Map( SqlType.Bytes(), SqlType.Bytes() @@ -1146,6 +1210,9 @@ async def test_execute_against_target( assert md[TEST_FAMILY_2].column_type == SqlType.Map( SqlType.Bytes(), SqlType.Bytes() ) + assert md[TEST_AGGREGATE_FAMILY].column_type == SqlType.Map( + SqlType.Bytes(), SqlType.Int64() + ) @pytest.mark.skipif( bool(os.environ.get(BIGTABLE_EMULATOR)), @@ -1248,7 +1315,7 @@ async def test_execute_query_params(self, client, table_id, instance_id): predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) async def test_execute_metadata_on_empty_response( - self, client, instance_id, table_id, temp_rows + self, client, instance_id, table_id, temp_rows, column_family_config ): await temp_rows.add_row(b"row_key_1") result = await client.execute_query( @@ -1258,7 +1325,9 @@ async def test_execute_metadata_on_empty_response( assert len(rows) == 0 md = result.metadata - assert len(md) == 3 + # we expect it to fetch each column family, plus _key + # add additional families here if column_family_config changes + assert len(md) == len(column_family_config) + 1 assert md["_key"].column_type == SqlType.Bytes() assert md[TEST_FAMILY].column_type == SqlType.Map( SqlType.Bytes(), SqlType.Bytes() @@ -1266,3 +1335,6 @@ async def test_execute_metadata_on_empty_response( assert md[TEST_FAMILY_2].column_type == SqlType.Map( SqlType.Bytes(), SqlType.Bytes() ) + assert md[TEST_AGGREGATE_FAMILY].column_type == SqlType.Map( + SqlType.Bytes(), SqlType.Int64() + ) diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 6b2006d7b..46e9c2215 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -26,7 +26,7 @@ from 
def add_aggregate_row(
    self, row_key, *, family=TEST_AGGREGATE_FAMILY, qualifier=b"q", input=0
):
    """
    Write one `add_to_cell` mutation for `row_key` through the raw GAPIC
    client, then record the key in `self.rows`.

    The mutation always uses a raw timestamp of 0 micros; `input` is sent as
    the `int_value` to accumulate into the `family`/`qualifier` cell.
    """
    # NOTE(review): this module appears to be generated output of the async
    # system tests -- confirm before hand-editing rather than regenerating.
    cell_update = {
        "family_name": family,
        "column_qualifier": {"raw_value": qualifier},
        "timestamp": {"raw_timestamp_micros": 0},
        "input": {"int_value": input},
    }
    request = {
        "table_name": self.target.table_name,
        "row_key": row_key,
        "mutations": [{"add_to_cell": cell_update}],
    }
    self.target.client._gapic_client.mutate_row(request)
    self.rows.append(row_key)
mutation""" + from google.cloud.bigtable.data.mutations import AddToCell + + row_key = b"add_to_cell" + family = TEST_AGGREGATE_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_aggregate_row(row_key, family=family, qualifier=qualifier) + target.mutate_row(row_key, AddToCell(family, qualifier, 1, timestamp_micros=0)) + encoded_result = self._retrieve_cell_value(target, row_key) + int_result = int.from_bytes(encoded_result, byteorder="big") + assert int_result == 1 + target.mutate_row(row_key, AddToCell(family, qualifier, 9, timestamp_micros=0)) + encoded_result = self._retrieve_cell_value(target, row_key) + int_result = int.from_bytes(encoded_result, byteorder="big") + assert int_result == 10 + @pytest.mark.skipif( bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" ) @@ -915,7 +966,9 @@ def test_execute_query_simple(self, client, table_id, instance_id): @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_execute_against_target(self, client, instance_id, table_id, temp_rows): + def test_execute_against_target( + self, client, instance_id, table_id, temp_rows, column_family_config + ): temp_rows.add_row(b"row_key_1") result = client.execute_query("SELECT * FROM `" + table_id + "`", instance_id) rows = [r for r in result] @@ -926,7 +979,7 @@ def test_execute_against_target(self, client, instance_id, table_id, temp_rows): assert family_map[b"q"] == b"test-value" assert len(rows[0][TEST_FAMILY_2]) == 0 md = result.metadata - assert len(md) == 3 + assert len(md) == len(column_family_config) + 1 assert md["_key"].column_type == SqlType.Bytes() assert md[TEST_FAMILY].column_type == SqlType.Map( SqlType.Bytes(), SqlType.Bytes() @@ -934,6 +987,9 @@ def test_execute_against_target(self, client, instance_id, table_id, temp_rows): assert md[TEST_FAMILY_2].column_type == SqlType.Map( SqlType.Bytes(), SqlType.Bytes() ) + assert md[TEST_AGGREGATE_FAMILY].column_type == SqlType.Map( + 
SqlType.Bytes(), SqlType.Int64() + ) @pytest.mark.skipif( bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't support SQL" @@ -1023,7 +1079,7 @@ def test_execute_query_params(self, client, table_id, instance_id): predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) def test_execute_metadata_on_empty_response( - self, client, instance_id, table_id, temp_rows + self, client, instance_id, table_id, temp_rows, column_family_config ): temp_rows.add_row(b"row_key_1") result = client.execute_query( @@ -1032,7 +1088,7 @@ def test_execute_metadata_on_empty_response( rows = [r for r in result] assert len(rows) == 0 md = result.metadata - assert len(md) == 3 + assert len(md) == len(column_family_config) + 1 assert md["_key"].column_type == SqlType.Bytes() assert md[TEST_FAMILY].column_type == SqlType.Map( SqlType.Bytes(), SqlType.Bytes() @@ -1040,3 +1096,6 @@ def test_execute_metadata_on_empty_response( assert md[TEST_FAMILY_2].column_type == SqlType.Map( SqlType.Bytes(), SqlType.Bytes() ) + assert md[TEST_AGGREGATE_FAMILY].column_type == SqlType.Map( + SqlType.Bytes(), SqlType.Int64() + ) diff --git a/tests/unit/data/test_mutations.py b/tests/unit/data/test_mutations.py index 485c86e42..17050162c 100644 --- a/tests/unit/data/test_mutations.py +++ b/tests/unit/data/test_mutations.py @@ -117,6 +117,17 @@ def test_size(self, test_dict): {"delete_from_family": {"family_name": "foo"}}, ), (mutations.DeleteAllFromRow, {"delete_from_row": {}}), + ( + mutations.AddToCell, + { + "add_to_cell": { + "family_name": "foo", + "column_qualifier": {"raw_value": b"bar"}, + "timestamp": {"raw_timestamp_micros": 12345}, + "input": {"int_value": 123}, + } + }, + ), ], ) def test__from_dict(self, expected_class, input_dict): @@ -162,6 +173,7 @@ def test__from_dict_wrong_subclass(self): mutations.DeleteRangeFromColumn("foo", b"bar"), mutations.DeleteAllFromFamily("foo"), mutations.DeleteAllFromRow(), + mutations.AddToCell("foo", b"bar", 123, 456), ] for instance in 
class TestAddToCell:
    """Unit tests for the `AddToCell` mutation's constructor and serializers."""

    def _target_class(self):
        from google.cloud.bigtable.data.mutations import AddToCell

        return AddToCell

    def _make_one(self, *args, **kwargs):
        return self._target_class()(*args, **kwargs)

    @pytest.mark.parametrize(
        "input_val",
        [
            2**64,
            -(2**64),
            # first values just outside the signed 64-bit range
            2**63,
            -(2**63) - 1,
        ],
    )
    def test_ctor_large_int(self, input_val):
        """Values outside the 64-bit signed range are rejected"""
        with pytest.raises(ValueError) as e:
            self._make_one(
                family="f", qualifier=b"b", value=input_val, timestamp_micros=123
            )
        assert "int values must be between" in str(e.value)

    @pytest.mark.parametrize("input_val", ["", "a", "abc", "hello world!"])
    def test_ctor_str_value(self, input_val):
        """Non-int values are rejected"""
        with pytest.raises(TypeError) as e:
            self._make_one(
                family="f", qualifier=b"b", value=input_val, timestamp_micros=123
            )
        assert "value must be int" in str(e.value)

    @pytest.mark.parametrize("input_qualifier", [123, 1.5, None, ["q"]])
    def test_ctor_bad_qualifier_type(self, input_qualifier):
        """Qualifiers that are neither bytes nor str are rejected"""
        with pytest.raises(TypeError) as e:
            self._make_one(
                family="f", qualifier=input_qualifier, value=1, timestamp_micros=123
            )
        assert "qualifier must be bytes or str" in str(e.value)

    @pytest.mark.parametrize("input_timestamp", ["123", 1.5, None])
    def test_ctor_bad_timestamp_type(self, input_timestamp):
        """Non-int timestamps are rejected"""
        with pytest.raises(TypeError) as e:
            self._make_one(
                family="f", qualifier=b"b", value=1, timestamp_micros=input_timestamp
            )
        assert "timestamp_micros must be int" in str(e.value)

    def test_ctor(self):
        """Ensure constructor sets expected values"""
        expected_family = "test-family"
        expected_qualifier = b"test-qualifier"
        expected_value = 1234
        expected_timestamp = 1234567890
        instance = self._make_one(
            expected_family, expected_qualifier, expected_value, expected_timestamp
        )
        assert instance.family == expected_family
        assert instance.qualifier == expected_qualifier
        assert instance.value == expected_value
        assert instance.timestamp == expected_timestamp

    def test_ctor_str_qualifier(self):
        """str qualifiers are encoded to bytes by the constructor"""
        instance = self._make_one("test-family", "test-qualifier", 1234, 1234567890)
        assert instance.qualifier == b"test-qualifier"
        assert (
            instance._to_dict()["add_to_cell"]["column_qualifier"]["raw_value"]
            == b"test-qualifier"
        )

    def test_ctor_negative_timestamp(self):
        """Only non-negative timestamps are valid"""
        with pytest.raises(ValueError) as e:
            self._make_one("test-family", b"test-qualifier", 1234, -2)
        assert "timestamp must be non-negative" in str(e.value)

    def test__to_dict(self):
        """ensure dict representation is as expected"""
        expected_family = "test-family"
        expected_qualifier = b"test-qualifier"
        expected_value = 1234
        expected_timestamp = 123456789
        instance = self._make_one(
            expected_family, expected_qualifier, expected_value, expected_timestamp
        )
        got_dict = instance._to_dict()
        assert list(got_dict.keys()) == ["add_to_cell"]
        got_inner_dict = got_dict["add_to_cell"]
        assert got_inner_dict["family_name"] == expected_family
        assert got_inner_dict["column_qualifier"]["raw_value"] == expected_qualifier
        assert got_inner_dict["timestamp"]["raw_timestamp_micros"] == expected_timestamp
        assert got_inner_dict["input"]["int_value"] == expected_value
        assert len(got_inner_dict.keys()) == 4

    def test__to_pb(self):
        """ensure proto representation is as expected"""
        import google.cloud.bigtable_v2.types.data as data_pb

        expected_family = "test-family"
        expected_qualifier = b"test-qualifier"
        expected_value = 1234
        expected_timestamp = 123456789
        instance = self._make_one(
            expected_family, expected_qualifier, expected_value, expected_timestamp
        )
        got_pb = instance._to_pb()
        assert isinstance(got_pb, data_pb.Mutation)
        assert got_pb.add_to_cell.family_name == expected_family
        assert got_pb.add_to_cell.column_qualifier.raw_value == expected_qualifier
        assert got_pb.add_to_cell.timestamp.raw_timestamp_micros == expected_timestamp
        assert got_pb.add_to_cell.input.int_value == expected_value

    @pytest.mark.parametrize("timestamp", [1234567890, 1, 0])
    def test_is_idempotent(self, timestamp):
        """is_idempotent is not based on the timestamp"""
        instance = self._make_one("test-family", b"test-qualifier", 1234, timestamp)
        assert not instance.is_idempotent()

    def test___str__(self):
        """Str representation of mutations should be to_dict"""
        instance = self._make_one("test-family", b"test-qualifier", 1234, 1234567890)
        str_value = instance.__str__()
        dict_value = instance._to_dict()
        assert str_value == str(dict_value)