Skip to content

Commit 86f3522

Browse files
SK-2131 minor fix for remote disconnect error
1 parent c117285 commit 86f3522

File tree

1 file changed

+47
-24
lines changed

1 file changed

+47
-24
lines changed

skyflow/vault/_client.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import types
66
import requests
77
import asyncio
8+
from requests.adapters import HTTPAdapter
89
from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse
910
from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody
1011
from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions
@@ -36,49 +37,71 @@ def __init__(self, config: Configuration):
3637
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.TOKEN_PROVIDER_ERROR.value % (
3738
str(type(config.tokenProvider))), interface=interface)
3839

40+
self._create_session()
3941
self.vaultID = config.vaultID
4042
self.vaultURL = config.vaultURL.rstrip('/')
4143
self.tokenProvider = config.tokenProvider
4244
self.storedToken = ''
4345
log_info(InfoMessages.CLIENT_INITIALIZED.value, interface=interface)
46+
47+
def _create_session(self):
48+
self.session = requests.Session()
49+
adapter = HTTPAdapter(pool_connections=1, pool_maxsize=25, pool_block=True)
50+
self.session.mount("https://", adapter)
51+
52+
def __del__(self):
53+
if (self.session is not None):
54+
log_info(InfoMessages.CLOSING_SESSION.value, interface=InterfaceName.CLIENT.value)
55+
self.session.close()
56+
self.session = None
57+
58+
def _get_session(self):
59+
if (self.session is None):
60+
self._create_session()
61+
return self.session
4462

4563
def insert(self, records: dict, options: InsertOptions = InsertOptions()):
64+
max_retries = 1
4665
interface = InterfaceName.INSERT.value
4766
log_info(InfoMessages.INSERT_TRIGGERED.value, interface=interface)
4867
self._checkConfig(interface)
49-
5068
jsonBody = getInsertRequestBody(records, options)
5169
requestURL = self._get_complete_vault_url()
52-
self.storedToken = tokenProviderWrapper(
53-
self.storedToken, self.tokenProvider, interface)
54-
headers = {
55-
"Authorization": "Bearer " + self.storedToken,
56-
"sky-metadata": json.dumps(getMetrics())
57-
}
58-
max_retries = 3
59-
# Use for-loop for retry logic, avoid code repetition
60-
for attempt in range(max_retries+1):
70+
71+
for attempt in range(max_retries + 1):
6172
try:
62-
# If jsonBody is a dict, use json=, else use data=
63-
response = requests.post(requestURL, data=jsonBody, headers=headers)
73+
self.storedToken = tokenProviderWrapper(
74+
self.storedToken, self.tokenProvider, interface)
75+
headers = {
76+
"Authorization": "Bearer " + self.storedToken,
77+
"sky-metadata": json.dumps(getMetrics()),
78+
}
79+
response = self._get_session().post(
80+
requestURL,
81+
data=jsonBody,
82+
headers=headers,
83+
)
6484
processedResponse = processResponse(response)
6585
result, partial = convertResponse(records, processedResponse, options)
6686
if partial:
6787
log_error(SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, interface)
68-
raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, result, interface=interface)
69-
if 'records' not in result:
88+
elif 'records' not in result:
7089
log_error(SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, interface)
71-
raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, result, interface=interface)
72-
log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface)
90+
else:
91+
log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface)
7392
return result
74-
except Exception as err:
93+
except requests.exceptions.ConnectionError as err:
7594
if attempt < max_retries:
76-
continue
77-
else:
78-
if isinstance(err, SkyflowError):
79-
raise err
80-
else:
81-
raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, f"Error occurred: {err}", interface=interface)
95+
continue
96+
raise SkyflowError(
97+
SkyflowErrorCodes.SERVER_ERROR,
98+
SkyflowErrorMessages.NETWORK_ERROR.value % str(err),
99+
interface=interface
100+
)
101+
except SkyflowError as err:
102+
if err.code != SkyflowErrorCodes.SERVER_ERROR or attempt >= max_retries:
103+
raise err
104+
continue
82105

83106
def detokenize(self, records: dict, options: DetokenizeOptions = DetokenizeOptions()):
84107
interface = InterfaceName.DETOKENIZE.value
@@ -292,4 +315,4 @@ def delete(self, records: dict, options: DeleteOptions = DeleteOptions()):
292315

293316
else:
294317
log_info(InfoMessages.DELETE_DATA_SUCCESS.value, interface)
295-
return result
318+
return result

0 commit comments

Comments
 (0)