55import types
66import requests
77import asyncio
8+ from requests .adapters import HTTPAdapter
89from skyflow .vault ._insert import getInsertRequestBody , processResponse , convertResponse
910from skyflow .vault ._update import sendUpdateRequests , createUpdateResponseBody
1011from skyflow .vault ._config import Configuration , ConnectionConfig , DeleteOptions , DetokenizeOptions , GetOptions , InsertOptions , UpdateOptions , QueryOptions
@@ -36,49 +37,71 @@ def __init__(self, config: Configuration):
3637 raise SkyflowError (SkyflowErrorCodes .INVALID_INPUT , SkyflowErrorMessages .TOKEN_PROVIDER_ERROR .value % (
3738 str (type (config .tokenProvider ))), interface = interface )
3839
40+ self ._create_session ()
3941 self .vaultID = config .vaultID
4042 self .vaultURL = config .vaultURL .rstrip ('/' )
4143 self .tokenProvider = config .tokenProvider
4244 self .storedToken = ''
4345 log_info (InfoMessages .CLIENT_INITIALIZED .value , interface = interface )
46+
47+ def _create_session (self ):
48+ self .session = requests .Session ()
49+ adapter = HTTPAdapter (pool_connections = 1 , pool_maxsize = 25 , pool_block = True )
50+ self .session .mount ("https://" , adapter )
51+
52+ def __del__ (self ):
53+ if (self .session is not None ):
54+ log_info (InfoMessages .CLOSING_SESSION .value , interface = InterfaceName .CLIENT .value )
55+ self .session .close ()
56+ self .session = None
57+
58+ def _get_session (self ):
59+ if (self .session is None ):
60+ self ._create_session ()
61+ return self .session
4462
4563 def insert (self , records : dict , options : InsertOptions = InsertOptions ()):
64+ max_retries = 1
4665 interface = InterfaceName .INSERT .value
4766 log_info (InfoMessages .INSERT_TRIGGERED .value , interface = interface )
4867 self ._checkConfig (interface )
49-
5068 jsonBody = getInsertRequestBody (records , options )
5169 requestURL = self ._get_complete_vault_url ()
52- self .storedToken = tokenProviderWrapper (
53- self .storedToken , self .tokenProvider , interface )
54- headers = {
55- "Authorization" : "Bearer " + self .storedToken ,
56- "sky-metadata" : json .dumps (getMetrics ())
57- }
58- max_retries = 3
59- # Use for-loop for retry logic, avoid code repetition
60- for attempt in range (max_retries + 1 ):
70+
71+ for attempt in range (max_retries + 1 ):
6172 try :
62- # If jsonBody is a dict, use json=, else use data=
63- response = requests .post (requestURL , data = jsonBody , headers = headers )
73+ self .storedToken = tokenProviderWrapper (
74+ self .storedToken , self .tokenProvider , interface )
75+ headers = {
76+ "Authorization" : "Bearer " + self .storedToken ,
77+ "sky-metadata" : json .dumps (getMetrics ()),
78+ }
79+ response = self ._get_session ().post (
80+ requestURL ,
81+ data = jsonBody ,
82+ headers = headers ,
83+ )
6484 processedResponse = processResponse (response )
6585 result , partial = convertResponse (records , processedResponse , options )
6686 if partial :
6787 log_error (SkyflowErrorMessages .BATCH_INSERT_PARTIAL_SUCCESS .value , interface )
68- raise SkyflowError (SkyflowErrorCodes .PARTIAL_SUCCESS , SkyflowErrorMessages .BATCH_INSERT_PARTIAL_SUCCESS .value , result , interface = interface )
69- if 'records' not in result :
88+ elif 'records' not in result :
7089 log_error (SkyflowErrorMessages .BATCH_INSERT_FAILURE .value , interface )
71- raise SkyflowError ( SkyflowErrorCodes . SERVER_ERROR , SkyflowErrorMessages . BATCH_INSERT_FAILURE . value , result , interface = interface )
72- log_info (InfoMessages .INSERT_DATA_SUCCESS .value , interface )
90+ else :
91+ log_info (InfoMessages .INSERT_DATA_SUCCESS .value , interface )
7392 return result
74- except Exception as err :
93+ except requests . exceptions . ConnectionError as err :
7594 if attempt < max_retries :
76- continue
77- else :
78- if isinstance (err , SkyflowError ):
79- raise err
80- else :
81- raise SkyflowError (SkyflowErrorCodes .SERVER_ERROR , f"Error occurred: { err } " , interface = interface )
95+ continue
96+ raise SkyflowError (
97+ SkyflowErrorCodes .SERVER_ERROR ,
98+ SkyflowErrorMessages .NETWORK_ERROR .value % str (err ),
99+ interface = interface
100+ )
101+ except SkyflowError as err :
102+ if err .code != SkyflowErrorCodes .SERVER_ERROR or attempt >= max_retries :
103+ raise err
104+ continue
82105
83106 def detokenize (self , records : dict , options : DetokenizeOptions = DetokenizeOptions ()):
84107 interface = InterfaceName .DETOKENIZE .value
@@ -292,4 +315,4 @@ def delete(self, records: dict, options: DeleteOptions = DeleteOptions()):
292315
293316 else :
294317 log_info (InfoMessages .DELETE_DATA_SUCCESS .value , interface )
295- return result
318+ return result
0 commit comments