Skip to content

Commit d5b668b

Browse files
SK-1908: Added unit tests
1 parent b176ff8 commit d5b668b

File tree

3 files changed

+298
-5
lines changed

3 files changed

+298
-5
lines changed

ci-scripts/bump_version.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ then
77

88
sed -E "s/current_version = .+/current_version = '$SEMVER'/g" setup.py > tempfile && cat tempfile > setup.py && rm -f tempfile
99
sed -E "s/SDK_VERSION = .+/SDK_VERSION = '$SEMVER'/g" skyflow/utils/_version.py > tempfile && cat tempfile > skyflow/utils/_version.py && rm -f tempfile
10+
sed -E "s/__version__ = .+/__version__ = '$SEMVER'/g" skyflow/generated/rest/version.py > tempfile && cat tempfile > skyflow/generated/rest/version.py && rm -f tempfile
1011

1112
echo --------------------------
1213
echo "Done, Package now at $1"
@@ -18,6 +19,7 @@ else
1819

1920
sed -E "s/current_version = .+/current_version = '$DEV_VERSION'/g" setup.py > tempfile && cat tempfile > setup.py && rm -f tempfile
2021
sed -E "s/SDK_VERSION = .+/SDK_VERSION = '$DEV_VERSION'/g" skyflow/utils/_version.py > tempfile && cat tempfile > skyflow/utils/_version.py && rm -f tempfile
22+
sed -E "s/__version__ = .+/__version__ = '$DEV_VERSION'/g" skyflow/generated/rest/version.py > tempfile && cat tempfile > skyflow/generated/rest/version.py && rm -f tempfile
2123

2224
echo --------------------------
2325
echo "Done, Package now at $DEV_VERSION"

skyflow/generated/rest/version.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
from importlib import metadata
2-
31
__version__ = '2.0.0b1.dev0+3d4ee51'

tests/vault/controller/test__vault.py

Lines changed: 296 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import unittest
2-
from unittest.mock import Mock, patch
3-
from skyflow.generated.rest import V1BatchRecord, V1FieldRecords, V1DetokenizeRecordRequest, V1TokenizeRecordRequest
2+
from unittest.mock import Mock, patch, MagicMock
3+
from skyflow.generated.rest import V1BatchRecord, V1FieldRecords, V1DetokenizeRecordRequest, V1TokenizeRecordRequest, \
4+
UnauthorizedError
5+
from skyflow.generated.rest.errors import ForbiddenError
46
from skyflow.utils.enums import RedactionType, TokenMode
57
from skyflow.vault.controller import Vault
68
from skyflow.vault.data import InsertRequest, InsertResponse, UpdateResponse, UpdateRequest, DeleteResponse, \
@@ -138,6 +140,45 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val
138140
self.assertEqual(result.inserted_fields, expected_inserted_fields)
139141
self.assertEqual(result.errors, []) # No errors expected
140142

143+
@patch("skyflow.vault.controller._vault.validate_insert_request")
144+
def test_insert_handles_generic_error(self, mock_validate):
145+
request = InsertRequest(table_name="test_table", values=[{"column_name": "value"}], return_tokens=False,
146+
upsert=False,
147+
homogeneous=False, continue_on_error=False, token_mode=Mock())
148+
records_api = self.vault_client.get_records_api.return_value
149+
records_api.record_service_insert_record.side_effect = Exception("Generic Exception")
150+
151+
with self.assertRaises(Exception):
152+
self.vault.insert(request)
153+
154+
records_api.record_service_insert_record.assert_called_once()
155+
156+
@patch("skyflow.vault.controller._vault.validate_insert_request")
157+
def test_insert_handles_forbidden_error(self, mock_validate):
158+
request = InsertRequest(table_name="test_table", values=[{"column_name": "value"}], return_tokens=False,
159+
upsert=False,
160+
homogeneous=False, continue_on_error=False, token_mode=Mock())
161+
records_api = self.vault_client.get_records_api.return_value
162+
records_api.record_service_insert_record.side_effect = ForbiddenError("ForbiddenError")
163+
164+
with self.assertRaises(Exception):
165+
self.vault.insert(request)
166+
167+
records_api.record_service_insert_record.assert_called_once()
168+
169+
@patch("skyflow.vault.controller._vault.validate_insert_request")
170+
def test_insert_handles_unauthorized_error(self, mock_validate):
171+
request = InsertRequest(table_name="test_table", values=[{"column_name": "value"}], return_tokens=False,
172+
upsert=False,
173+
homogeneous=False, continue_on_error=False, token_mode=Mock())
174+
records_api = self.vault_client.get_records_api.return_value
175+
records_api.record_service_insert_record.side_effect = UnauthorizedError("Unauthorized")
176+
177+
with self.assertRaises(Exception):
178+
self.vault.insert(request)
179+
180+
records_api.record_service_insert_record.assert_called_once()
181+
141182
@patch("skyflow.vault.controller._vault.validate_insert_request")
142183
@patch("skyflow.vault.controller._vault.parse_insert_response")
143184
def test_insert_with_continue_on_error_false_when_tokens_are_not_none(self, mock_parse_response, mock_validate):
@@ -241,6 +282,42 @@ def test_update_successful(self, mock_parse_response, mock_validate):
241282
self.assertEqual(result.updated_field, expected_updated_field)
242283
self.assertEqual(result.errors, []) # No errors expected
243284

285+
@patch("skyflow.vault.controller._vault.validate_update_request")
286+
def test_update_handles_generic_error(self, mock_validate):
287+
request = UpdateRequest(table="test_table", data={"skyflow_id": "123", "field": "value"},
288+
return_tokens=False)
289+
records_api = self.vault_client.get_records_api.return_value
290+
records_api.record_service_update_record.side_effect = Exception("Generic Exception")
291+
292+
with self.assertRaises(Exception):
293+
self.vault.update(request)
294+
295+
records_api.record_service_update_record.assert_called_once()
296+
297+
@patch("skyflow.vault.controller._vault.validate_update_request")
298+
def test_update_handles_unauthorized_error(self, mock_validate):
299+
request = UpdateRequest(table="test_table", data={"skyflow_id": "123", "field": "value"},
300+
return_tokens=False)
301+
records_api = self.vault_client.get_records_api.return_value
302+
records_api.record_service_update_record.side_effect = UnauthorizedError("UnauthorizedError")
303+
304+
with self.assertRaises(Exception):
305+
self.vault.update(request)
306+
307+
records_api.record_service_update_record.assert_called_once()
308+
309+
@patch("skyflow.vault.controller._vault.validate_update_request")
310+
def test_update_handles_forbidden_error(self, mock_validate):
311+
request = UpdateRequest(table="test_table", data={"skyflow_id": "123", "field": "value"},
312+
return_tokens=False)
313+
records_api = self.vault_client.get_records_api.return_value
314+
records_api.record_service_update_record.side_effect = ForbiddenError("ForbiddenError")
315+
316+
with self.assertRaises(Exception):
317+
self.vault.update(request)
318+
319+
records_api.record_service_update_record.assert_called_once()
320+
244321
@patch("skyflow.vault.controller._vault.validate_delete_request")
245322
@patch("skyflow.vault.controller._vault.parse_delete_response")
246323
def test_delete_successful(self, mock_parse_response, mock_validate):
@@ -284,6 +361,39 @@ def test_delete_successful(self, mock_parse_response, mock_validate):
284361
self.assertEqual(result.deleted_ids, expected_deleted_ids)
285362
self.assertEqual(result.errors, []) # No errors expected
286363

364+
@patch("skyflow.vault.controller._vault.validate_delete_request")
365+
def test_delete_handles_generic_exception(self, mock_validate):
366+
request = DeleteRequest(table="test_table", ids=["id1", "id2"])
367+
records_api = self.vault_client.get_records_api.return_value
368+
records_api.record_service_bulk_delete_record.side_effect = Exception("Generic Error")
369+
370+
with self.assertRaises(Exception):
371+
self.vault.delete(request)
372+
373+
records_api.record_service_bulk_delete_record.assert_called_once()
374+
375+
@patch("skyflow.vault.controller._vault.validate_delete_request")
376+
def test_delete_handles_unauthorized_error(self, mock_validate):
377+
request = DeleteRequest(table="test_table", ids=["id1", "id2"])
378+
records_api = self.vault_client.get_records_api.return_value
379+
records_api.record_service_bulk_delete_record.side_effect = UnauthorizedError("Unauthorized")
380+
381+
with self.assertRaises(Exception):
382+
self.vault.delete(request)
383+
384+
records_api.record_service_bulk_delete_record.assert_called_once()
385+
386+
@patch("skyflow.vault.controller._vault.validate_delete_request")
387+
def test_delete_handles_forbidden_error(self, mock_validate):
388+
request = DeleteRequest(table="test_table", ids=["id1", "id2"])
389+
records_api = self.vault_client.get_records_api.return_value
390+
records_api.record_service_bulk_delete_record.side_effect = ForbiddenError("Forbidden")
391+
392+
with self.assertRaises(Exception):
393+
self.vault.delete(request)
394+
395+
records_api.record_service_bulk_delete_record.assert_called_once()
396+
287397
@patch("skyflow.vault.controller._vault.validate_get_request")
288398
@patch("skyflow.vault.controller._vault.parse_get_response")
289399
def test_get_successful(self, mock_parse_response, mock_validate):
@@ -402,6 +512,39 @@ def test_get_successful_with_column_values(self, mock_parse_response, mock_valid
402512
self.assertEqual(result.data, expected_data)
403513
self.assertEqual(result.errors, []) # No errors expected
404514

515+
@patch("skyflow.vault.controller._vault.validate_get_request")
516+
def test_get_handles_generic_error(self, mock_validate):
517+
request = GetRequest(table="test_table", ids=["id1", "id2"])
518+
records_api = self.vault_client.get_records_api.return_value
519+
records_api.record_service_bulk_get_record.side_effect = Exception("Generic Exception")
520+
521+
with self.assertRaises(Exception):
522+
self.vault.get(request)
523+
524+
records_api.record_service_bulk_get_record.assert_called_once()
525+
526+
@patch("skyflow.vault.controller._vault.validate_get_request")
527+
def test_get_handles_unauthorized_error(self, mock_validate):
528+
request = GetRequest(table="test_table", ids=["id1", "id2"])
529+
records_api = self.vault_client.get_records_api.return_value
530+
records_api.record_service_bulk_get_record.side_effect = UnauthorizedError("UnauthorizedError")
531+
532+
with self.assertRaises(Exception):
533+
self.vault.get(request)
534+
535+
records_api.record_service_bulk_get_record.assert_called_once()
536+
537+
@patch("skyflow.vault.controller._vault.validate_get_request")
538+
def test_get_handles_forbidden_error(self, mock_validate):
539+
request = GetRequest(table="test_table", ids=["id1", "id2"])
540+
records_api = self.vault_client.get_records_api.return_value
541+
records_api.record_service_bulk_get_record.side_effect = ForbiddenError("ForbiddenError")
542+
543+
with self.assertRaises(Exception):
544+
self.vault.get(request)
545+
546+
records_api.record_service_bulk_get_record.assert_called_once()
547+
405548
@patch("skyflow.vault.controller._vault.validate_query_request")
406549
@patch("skyflow.vault.controller._vault.parse_query_response")
407550
def test_query_successful(self, mock_parse_response, mock_validate):
@@ -445,6 +588,39 @@ def test_query_successful(self, mock_parse_response, mock_validate):
445588
self.assertEqual(result.fields, expected_fields)
446589
self.assertEqual(result.errors, []) # No errors expected
447590

591+
@patch("skyflow.vault.controller._vault.validate_query_request")
592+
def test_query_handles_generic_error(self, mock_validate):
593+
request = QueryRequest(query="SELECT * from table_name")
594+
query_api = self.vault_client.get_query_api.return_value
595+
query_api.query_service_execute_query.side_effect = Exception("Generic Exception")
596+
597+
with self.assertRaises(Exception):
598+
self.vault.query(request)
599+
600+
query_api.query_service_execute_query.assert_called_once()
601+
602+
@patch("skyflow.vault.controller._vault.validate_query_request")
603+
def test_query_handles_unauthorized_error(self, mock_validate):
604+
request = QueryRequest(query="SELECT * from table_name")
605+
query_api = self.vault_client.get_query_api.return_value
606+
query_api.query_service_execute_query.side_effect = UnauthorizedError("UnauthorizedError")
607+
608+
with self.assertRaises(Exception):
609+
self.vault.query(request)
610+
611+
query_api.query_service_execute_query.assert_called_once()
612+
613+
@patch("skyflow.vault.controller._vault.validate_query_request")
614+
def test_query_handles_forbidden_error(self, mock_validate):
615+
request = QueryRequest(query="SELECT * from table_name")
616+
query_api = self.vault_client.get_query_api.return_value
617+
query_api.query_service_execute_query.side_effect = ForbiddenError("ForbiddenError")
618+
619+
with self.assertRaises(Exception):
620+
self.vault.query(request)
621+
622+
query_api.query_service_execute_query.assert_called_once()
623+
448624
@patch("skyflow.vault.controller._vault.validate_detokenize_request")
449625
@patch("skyflow.vault.controller._vault.parse_detokenize_response")
450626
def test_detokenize_successful(self, mock_parse_response, mock_validate):
@@ -502,6 +678,75 @@ def test_detokenize_successful(self, mock_parse_response, mock_validate):
502678
self.assertEqual(result.detokenized_fields, expected_fields)
503679
self.assertEqual(result.errors, []) # No errors expected
504680

681+
@patch("skyflow.vault.controller._vault.validate_detokenize_request")
682+
def test_detokenize_handles_generic_error(self, mock_validate):
683+
request = DetokenizeRequest(
684+
data=[
685+
{
686+
'token': 'token1',
687+
'redaction': RedactionType.PLAIN_TEXT
688+
},
689+
{
690+
'token': 'token2',
691+
'redaction': RedactionType.PLAIN_TEXT
692+
}
693+
],
694+
continue_on_error=False
695+
)
696+
tokens_api = self.vault_client.get_tokens_api.return_value
697+
tokens_api.record_service_detokenize.side_effect = Exception("Generic Error")
698+
699+
with self.assertRaises(Exception):
700+
self.vault.detokenize(request)
701+
702+
tokens_api.record_service_detokenize.assert_called_once()
703+
704+
@patch("skyflow.vault.controller._vault.validate_detokenize_request")
705+
def test_detokenize_handles_unauthorized_error(self, mock_validate):
706+
request = DetokenizeRequest(
707+
data=[
708+
{
709+
'token': 'token1',
710+
'redaction': RedactionType.PLAIN_TEXT
711+
},
712+
{
713+
'token': 'token2',
714+
'redaction': RedactionType.PLAIN_TEXT
715+
}
716+
],
717+
continue_on_error=False
718+
)
719+
tokens_api = self.vault_client.get_tokens_api.return_value
720+
tokens_api.record_service_detokenize.side_effect = UnauthorizedError("UnauthorizedError")
721+
722+
with self.assertRaises(Exception):
723+
self.vault.detokenize(request)
724+
725+
tokens_api.record_service_detokenize.assert_called_once()
726+
727+
@patch("skyflow.vault.controller._vault.validate_detokenize_request")
728+
def test_detokenize_handles_forbidden_error(self, mock_validate):
729+
request = DetokenizeRequest(
730+
data=[
731+
{
732+
'token': 'token1',
733+
'redaction': RedactionType.PLAIN_TEXT
734+
},
735+
{
736+
'token': 'token2',
737+
'redaction': RedactionType.PLAIN_TEXT
738+
}
739+
],
740+
continue_on_error=False
741+
)
742+
tokens_api = self.vault_client.get_tokens_api.return_value
743+
tokens_api.record_service_detokenize.side_effect = ForbiddenError("ForbiddenError")
744+
745+
with self.assertRaises(Exception):
746+
self.vault.detokenize(request)
747+
748+
tokens_api.record_service_detokenize.assert_called_once()
749+
505750
@patch("skyflow.vault.controller._vault.validate_tokenize_request")
506751
@patch("skyflow.vault.controller._vault.parse_tokenize_response")
507752
def test_tokenize_successful(self, mock_parse_response, mock_validate):
@@ -551,4 +796,52 @@ def test_tokenize_successful(self, mock_parse_response, mock_validate):
551796
mock_parse_response.assert_called_once_with(mock_api_response)
552797

553798
# Check that the result matches the expected TokenizeResponse
554-
self.assertEqual(result.tokenized_fields, expected_fields)
799+
self.assertEqual(result.tokenized_fields, expected_fields)
800+
801+
@patch("skyflow.vault.controller._vault.validate_tokenize_request")
802+
def test_tokenize_handles_generic_error(self, mock_validate):
803+
request = TokenizeRequest(
804+
values=[
805+
{"value": "value1", "column_group": "group1"},
806+
{"value": "value2", "column_group": "group2"}
807+
]
808+
)
809+
tokens_api = self.vault_client.get_tokens_api.return_value
810+
tokens_api.record_service_tokenize.side_effect = Exception("Generic Error")
811+
812+
with self.assertRaises(Exception):
813+
self.vault.tokenize(request)
814+
815+
tokens_api.record_service_tokenize.assert_called_once()
816+
817+
@patch("skyflow.vault.controller._vault.validate_tokenize_request")
818+
def test_tokenize_handles_unauthorized_error(self, mock_validate):
819+
request = TokenizeRequest(
820+
values=[
821+
{"value": "value1", "column_group": "group1"},
822+
{"value": "value2", "column_group": "group2"}
823+
]
824+
)
825+
tokens_api = self.vault_client.get_tokens_api.return_value
826+
tokens_api.record_service_tokenize.side_effect = UnauthorizedError("UnauthorizedError")
827+
828+
with self.assertRaises(Exception):
829+
self.vault.tokenize(request)
830+
831+
tokens_api.record_service_tokenize.assert_called_once()
832+
833+
@patch("skyflow.vault.controller._vault.validate_tokenize_request")
834+
def test_tokenize_handles_forbidden_error(self, mock_validate):
835+
request = TokenizeRequest(
836+
values=[
837+
{"value": "value1", "column_group": "group1"},
838+
{"value": "value2", "column_group": "group2"}
839+
]
840+
)
841+
tokens_api = self.vault_client.get_tokens_api.return_value
842+
tokens_api.record_service_tokenize.side_effect = ForbiddenError("Forbidden Error")
843+
844+
with self.assertRaises(Exception):
845+
self.vault.tokenize(request)
846+
847+
tokens_api.record_service_tokenize.assert_called_once()

0 commit comments

Comments
 (0)