11import unittest
2- from unittest .mock import Mock , patch
3- from skyflow .generated .rest import V1BatchRecord , V1FieldRecords , V1DetokenizeRecordRequest , V1TokenizeRecordRequest
2+ from unittest .mock import Mock , patch , MagicMock
3+ from skyflow .generated .rest import V1BatchRecord , V1FieldRecords , V1DetokenizeRecordRequest , V1TokenizeRecordRequest , \
4+ UnauthorizedError
5+ from skyflow .generated .rest .errors import ForbiddenError
46from skyflow .utils .enums import RedactionType , TokenMode
57from skyflow .vault .controller import Vault
68from skyflow .vault .data import InsertRequest , InsertResponse , UpdateResponse , UpdateRequest , DeleteResponse , \
@@ -138,6 +140,45 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val
138140 self .assertEqual (result .inserted_fields , expected_inserted_fields )
139141 self .assertEqual (result .errors , []) # No errors expected
140142
143+ @patch ("skyflow.vault.controller._vault.validate_insert_request" )
144+ def test_insert_handles_generic_error (self , mock_validate ):
145+ request = InsertRequest (table_name = "test_table" , values = [{"column_name" : "value" }], return_tokens = False ,
146+ upsert = False ,
147+ homogeneous = False , continue_on_error = False , token_mode = Mock ())
148+ records_api = self .vault_client .get_records_api .return_value
149+ records_api .record_service_insert_record .side_effect = Exception ("Generic Exception" )
150+
151+ with self .assertRaises (Exception ):
152+ self .vault .insert (request )
153+
154+ records_api .record_service_insert_record .assert_called_once ()
155+
156+ @patch ("skyflow.vault.controller._vault.validate_insert_request" )
157+ def test_insert_handles_forbidden_error (self , mock_validate ):
158+ request = InsertRequest (table_name = "test_table" , values = [{"column_name" : "value" }], return_tokens = False ,
159+ upsert = False ,
160+ homogeneous = False , continue_on_error = False , token_mode = Mock ())
161+ records_api = self .vault_client .get_records_api .return_value
162+ records_api .record_service_insert_record .side_effect = ForbiddenError ("ForbiddenError" )
163+
164+ with self .assertRaises (Exception ):
165+ self .vault .insert (request )
166+
167+ records_api .record_service_insert_record .assert_called_once ()
168+
169+ @patch ("skyflow.vault.controller._vault.validate_insert_request" )
170+ def test_insert_handles_unauthorized_error (self , mock_validate ):
171+ request = InsertRequest (table_name = "test_table" , values = [{"column_name" : "value" }], return_tokens = False ,
172+ upsert = False ,
173+ homogeneous = False , continue_on_error = False , token_mode = Mock ())
174+ records_api = self .vault_client .get_records_api .return_value
175+ records_api .record_service_insert_record .side_effect = UnauthorizedError ("Unauthorized" )
176+
177+ with self .assertRaises (Exception ):
178+ self .vault .insert (request )
179+
180+ records_api .record_service_insert_record .assert_called_once ()
181+
141182 @patch ("skyflow.vault.controller._vault.validate_insert_request" )
142183 @patch ("skyflow.vault.controller._vault.parse_insert_response" )
143184 def test_insert_with_continue_on_error_false_when_tokens_are_not_none (self , mock_parse_response , mock_validate ):
@@ -241,6 +282,42 @@ def test_update_successful(self, mock_parse_response, mock_validate):
241282 self .assertEqual (result .updated_field , expected_updated_field )
242283 self .assertEqual (result .errors , []) # No errors expected
243284
285+ @patch ("skyflow.vault.controller._vault.validate_update_request" )
286+ def test_update_handles_generic_error (self , mock_validate ):
287+ request = UpdateRequest (table = "test_table" , data = {"skyflow_id" : "123" , "field" : "value" },
288+ return_tokens = False )
289+ records_api = self .vault_client .get_records_api .return_value
290+ records_api .record_service_update_record .side_effect = Exception ("Generic Exception" )
291+
292+ with self .assertRaises (Exception ):
293+ self .vault .update (request )
294+
295+ records_api .record_service_update_record .assert_called_once ()
296+
297+ @patch ("skyflow.vault.controller._vault.validate_update_request" )
298+ def test_update_handles_unauthorized_error (self , mock_validate ):
299+ request = UpdateRequest (table = "test_table" , data = {"skyflow_id" : "123" , "field" : "value" },
300+ return_tokens = False )
301+ records_api = self .vault_client .get_records_api .return_value
302+ records_api .record_service_update_record .side_effect = UnauthorizedError ("UnauthorizedError" )
303+
304+ with self .assertRaises (Exception ):
305+ self .vault .update (request )
306+
307+ records_api .record_service_update_record .assert_called_once ()
308+
309+ @patch ("skyflow.vault.controller._vault.validate_update_request" )
310+ def test_update_handles_forbidden_error (self , mock_validate ):
311+ request = UpdateRequest (table = "test_table" , data = {"skyflow_id" : "123" , "field" : "value" },
312+ return_tokens = False )
313+ records_api = self .vault_client .get_records_api .return_value
314+ records_api .record_service_update_record .side_effect = ForbiddenError ("ForbiddenError" )
315+
316+ with self .assertRaises (Exception ):
317+ self .vault .update (request )
318+
319+ records_api .record_service_update_record .assert_called_once ()
320+
244321 @patch ("skyflow.vault.controller._vault.validate_delete_request" )
245322 @patch ("skyflow.vault.controller._vault.parse_delete_response" )
246323 def test_delete_successful (self , mock_parse_response , mock_validate ):
@@ -284,6 +361,39 @@ def test_delete_successful(self, mock_parse_response, mock_validate):
284361 self .assertEqual (result .deleted_ids , expected_deleted_ids )
285362 self .assertEqual (result .errors , []) # No errors expected
286363
364+ @patch ("skyflow.vault.controller._vault.validate_delete_request" )
365+ def test_delete_handles_generic_exception (self , mock_validate ):
366+ request = DeleteRequest (table = "test_table" , ids = ["id1" , "id2" ])
367+ records_api = self .vault_client .get_records_api .return_value
368+ records_api .record_service_bulk_delete_record .side_effect = Exception ("Generic Error" )
369+
370+ with self .assertRaises (Exception ):
371+ self .vault .delete (request )
372+
373+ records_api .record_service_bulk_delete_record .assert_called_once ()
374+
375+ @patch ("skyflow.vault.controller._vault.validate_delete_request" )
376+ def test_delete_handles_unauthorized_error (self , mock_validate ):
377+ request = DeleteRequest (table = "test_table" , ids = ["id1" , "id2" ])
378+ records_api = self .vault_client .get_records_api .return_value
379+ records_api .record_service_bulk_delete_record .side_effect = UnauthorizedError ("Unauthorized" )
380+
381+ with self .assertRaises (Exception ):
382+ self .vault .delete (request )
383+
384+ records_api .record_service_bulk_delete_record .assert_called_once ()
385+
386+ @patch ("skyflow.vault.controller._vault.validate_delete_request" )
387+ def test_delete_handles_forbidden_error (self , mock_validate ):
388+ request = DeleteRequest (table = "test_table" , ids = ["id1" , "id2" ])
389+ records_api = self .vault_client .get_records_api .return_value
390+ records_api .record_service_bulk_delete_record .side_effect = ForbiddenError ("Forbidden" )
391+
392+ with self .assertRaises (Exception ):
393+ self .vault .delete (request )
394+
395+ records_api .record_service_bulk_delete_record .assert_called_once ()
396+
287397 @patch ("skyflow.vault.controller._vault.validate_get_request" )
288398 @patch ("skyflow.vault.controller._vault.parse_get_response" )
289399 def test_get_successful (self , mock_parse_response , mock_validate ):
@@ -402,6 +512,39 @@ def test_get_successful_with_column_values(self, mock_parse_response, mock_valid
402512 self .assertEqual (result .data , expected_data )
403513 self .assertEqual (result .errors , []) # No errors expected
404514
515+ @patch ("skyflow.vault.controller._vault.validate_get_request" )
516+ def test_get_handles_generic_error (self , mock_validate ):
517+ request = GetRequest (table = "test_table" , ids = ["id1" , "id2" ])
518+ records_api = self .vault_client .get_records_api .return_value
519+ records_api .record_service_bulk_get_record .side_effect = Exception ("Generic Exception" )
520+
521+ with self .assertRaises (Exception ):
522+ self .vault .get (request )
523+
524+ records_api .record_service_bulk_get_record .assert_called_once ()
525+
526+ @patch ("skyflow.vault.controller._vault.validate_get_request" )
527+ def test_get_handles_unauthorized_error (self , mock_validate ):
528+ request = GetRequest (table = "test_table" , ids = ["id1" , "id2" ])
529+ records_api = self .vault_client .get_records_api .return_value
530+ records_api .record_service_bulk_get_record .side_effect = UnauthorizedError ("UnauthorizedError" )
531+
532+ with self .assertRaises (Exception ):
533+ self .vault .get (request )
534+
535+ records_api .record_service_bulk_get_record .assert_called_once ()
536+
537+ @patch ("skyflow.vault.controller._vault.validate_get_request" )
538+ def test_get_handles_forbidden_error (self , mock_validate ):
539+ request = GetRequest (table = "test_table" , ids = ["id1" , "id2" ])
540+ records_api = self .vault_client .get_records_api .return_value
541+ records_api .record_service_bulk_get_record .side_effect = ForbiddenError ("ForbiddenError" )
542+
543+ with self .assertRaises (Exception ):
544+ self .vault .get (request )
545+
546+ records_api .record_service_bulk_get_record .assert_called_once ()
547+
405548 @patch ("skyflow.vault.controller._vault.validate_query_request" )
406549 @patch ("skyflow.vault.controller._vault.parse_query_response" )
407550 def test_query_successful (self , mock_parse_response , mock_validate ):
@@ -445,6 +588,39 @@ def test_query_successful(self, mock_parse_response, mock_validate):
445588 self .assertEqual (result .fields , expected_fields )
446589 self .assertEqual (result .errors , []) # No errors expected
447590
591+ @patch ("skyflow.vault.controller._vault.validate_query_request" )
592+ def test_query_handles_generic_error (self , mock_validate ):
593+ request = QueryRequest (query = "SELECT * from table_name" )
594+ query_api = self .vault_client .get_query_api .return_value
595+ query_api .query_service_execute_query .side_effect = Exception ("Generic Exception" )
596+
597+ with self .assertRaises (Exception ):
598+ self .vault .query (request )
599+
600+ query_api .query_service_execute_query .assert_called_once ()
601+
602+ @patch ("skyflow.vault.controller._vault.validate_query_request" )
603+ def test_query_handles_unauthorized_error (self , mock_validate ):
604+ request = QueryRequest (query = "SELECT * from table_name" )
605+ query_api = self .vault_client .get_query_api .return_value
606+ query_api .query_service_execute_query .side_effect = UnauthorizedError ("UnauthorizedError" )
607+
608+ with self .assertRaises (Exception ):
609+ self .vault .query (request )
610+
611+ query_api .query_service_execute_query .assert_called_once ()
612+
613+ @patch ("skyflow.vault.controller._vault.validate_query_request" )
614+ def test_query_handles_forbidden_error (self , mock_validate ):
615+ request = QueryRequest (query = "SELECT * from table_name" )
616+ query_api = self .vault_client .get_query_api .return_value
617+ query_api .query_service_execute_query .side_effect = ForbiddenError ("ForbiddenError" )
618+
619+ with self .assertRaises (Exception ):
620+ self .vault .query (request )
621+
622+ query_api .query_service_execute_query .assert_called_once ()
623+
448624 @patch ("skyflow.vault.controller._vault.validate_detokenize_request" )
449625 @patch ("skyflow.vault.controller._vault.parse_detokenize_response" )
450626 def test_detokenize_successful (self , mock_parse_response , mock_validate ):
@@ -502,6 +678,75 @@ def test_detokenize_successful(self, mock_parse_response, mock_validate):
502678 self .assertEqual (result .detokenized_fields , expected_fields )
503679 self .assertEqual (result .errors , []) # No errors expected
504680
681+ @patch ("skyflow.vault.controller._vault.validate_detokenize_request" )
682+ def test_detokenize_handles_generic_error (self , mock_validate ):
683+ request = DetokenizeRequest (
684+ data = [
685+ {
686+ 'token' : 'token1' ,
687+ 'redaction' : RedactionType .PLAIN_TEXT
688+ },
689+ {
690+ 'token' : 'token2' ,
691+ 'redaction' : RedactionType .PLAIN_TEXT
692+ }
693+ ],
694+ continue_on_error = False
695+ )
696+ tokens_api = self .vault_client .get_tokens_api .return_value
697+ tokens_api .record_service_detokenize .side_effect = Exception ("Generic Error" )
698+
699+ with self .assertRaises (Exception ):
700+ self .vault .detokenize (request )
701+
702+ tokens_api .record_service_detokenize .assert_called_once ()
703+
704+ @patch ("skyflow.vault.controller._vault.validate_detokenize_request" )
705+ def test_detokenize_handles_unauthorized_error (self , mock_validate ):
706+ request = DetokenizeRequest (
707+ data = [
708+ {
709+ 'token' : 'token1' ,
710+ 'redaction' : RedactionType .PLAIN_TEXT
711+ },
712+ {
713+ 'token' : 'token2' ,
714+ 'redaction' : RedactionType .PLAIN_TEXT
715+ }
716+ ],
717+ continue_on_error = False
718+ )
719+ tokens_api = self .vault_client .get_tokens_api .return_value
720+ tokens_api .record_service_detokenize .side_effect = UnauthorizedError ("UnauthorizedError" )
721+
722+ with self .assertRaises (Exception ):
723+ self .vault .detokenize (request )
724+
725+ tokens_api .record_service_detokenize .assert_called_once ()
726+
727+ @patch ("skyflow.vault.controller._vault.validate_detokenize_request" )
728+ def test_detokenize_handles_forbidden_error (self , mock_validate ):
729+ request = DetokenizeRequest (
730+ data = [
731+ {
732+ 'token' : 'token1' ,
733+ 'redaction' : RedactionType .PLAIN_TEXT
734+ },
735+ {
736+ 'token' : 'token2' ,
737+ 'redaction' : RedactionType .PLAIN_TEXT
738+ }
739+ ],
740+ continue_on_error = False
741+ )
742+ tokens_api = self .vault_client .get_tokens_api .return_value
743+ tokens_api .record_service_detokenize .side_effect = ForbiddenError ("ForbiddenError" )
744+
745+ with self .assertRaises (Exception ):
746+ self .vault .detokenize (request )
747+
748+ tokens_api .record_service_detokenize .assert_called_once ()
749+
505750 @patch ("skyflow.vault.controller._vault.validate_tokenize_request" )
506751 @patch ("skyflow.vault.controller._vault.parse_tokenize_response" )
507752 def test_tokenize_successful (self , mock_parse_response , mock_validate ):
@@ -551,4 +796,52 @@ def test_tokenize_successful(self, mock_parse_response, mock_validate):
551796 mock_parse_response .assert_called_once_with (mock_api_response )
552797
553798 # Check that the result matches the expected TokenizeResponse
554- self .assertEqual (result .tokenized_fields , expected_fields )
799+ self .assertEqual (result .tokenized_fields , expected_fields )
800+
801+ @patch ("skyflow.vault.controller._vault.validate_tokenize_request" )
802+ def test_tokenize_handles_generic_error (self , mock_validate ):
803+ request = TokenizeRequest (
804+ values = [
805+ {"value" : "value1" , "column_group" : "group1" },
806+ {"value" : "value2" , "column_group" : "group2" }
807+ ]
808+ )
809+ tokens_api = self .vault_client .get_tokens_api .return_value
810+ tokens_api .record_service_tokenize .side_effect = Exception ("Generic Error" )
811+
812+ with self .assertRaises (Exception ):
813+ self .vault .tokenize (request )
814+
815+ tokens_api .record_service_tokenize .assert_called_once ()
816+
817+ @patch ("skyflow.vault.controller._vault.validate_tokenize_request" )
818+ def test_tokenize_handles_unauthorized_error (self , mock_validate ):
819+ request = TokenizeRequest (
820+ values = [
821+ {"value" : "value1" , "column_group" : "group1" },
822+ {"value" : "value2" , "column_group" : "group2" }
823+ ]
824+ )
825+ tokens_api = self .vault_client .get_tokens_api .return_value
826+ tokens_api .record_service_tokenize .side_effect = UnauthorizedError ("UnauthorizedError" )
827+
828+ with self .assertRaises (Exception ):
829+ self .vault .tokenize (request )
830+
831+ tokens_api .record_service_tokenize .assert_called_once ()
832+
833+ @patch ("skyflow.vault.controller._vault.validate_tokenize_request" )
834+ def test_tokenize_handles_forbidden_error (self , mock_validate ):
835+ request = TokenizeRequest (
836+ values = [
837+ {"value" : "value1" , "column_group" : "group1" },
838+ {"value" : "value2" , "column_group" : "group2" }
839+ ]
840+ )
841+ tokens_api = self .vault_client .get_tokens_api .return_value
842+ tokens_api .record_service_tokenize .side_effect = ForbiddenError ("Forbidden Error" )
843+
844+ with self .assertRaises (Exception ):
845+ self .vault .tokenize (request )
846+
847+ tokens_api .record_service_tokenize .assert_called_once ()
0 commit comments