Skip to content

Commit 7ffb5b5

Browse files
authored
Merge pull request #697 from LuckySting/issue-696-fix-async-query-client
Fix: convert gRPC stream termination to YDB errors in async query client (issue #696)
2 parents dd9e1c9 + 1effa95 commit 7ffb5b5

File tree

7 files changed

+90
-7
lines changed

7 files changed

+90
-7
lines changed

tests/aio/query/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
from unittest import mock
2+
3+
import grpc
14
import pytest
5+
from grpc._cython import cygrpc
6+
27
from ydb.aio.query.session import QuerySession
38
from ydb.aio.query.pool import QuerySessionPool
49

@@ -32,3 +37,22 @@ async def tx(session):
3237
async def pool(driver):
3338
async with QuerySessionPool(driver) as pool:
3439
yield pool
40+
41+
42+
@pytest.fixture
43+
async def ydb_terminates_streams_with_unavailable():
44+
async def _patch(self):
45+
message = await self._read() # Read the first message
46+
while message is not cygrpc.EOF: # While the message is not empty, continue reading the stream
47+
yield message
48+
message = await self._read()
49+
50+
# Emulate stream termination
51+
raise grpc.aio.AioRpcError(
52+
code=grpc.StatusCode.UNAVAILABLE,
53+
initial_metadata=await self.initial_metadata(),
54+
trailing_metadata=await self.trailing_metadata(),
55+
)
56+
57+
with mock.patch.object(grpc.aio._call._StreamResponseMixin, "_fetch_stream_responses", _patch):
58+
yield

tests/aio/query/test_query_session.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import pytest
2+
3+
import ydb
24
from ydb.aio.query.session import QuerySession
35

46

@@ -113,3 +115,13 @@ async def test_two_results(self, session: QuerySession):
113115

114116
assert res == [[1], [2]]
115117
assert counter == 2
118+
119+
@pytest.mark.asyncio
120+
@pytest.mark.usefixtures("ydb_terminates_streams_with_unavailable")
121+
async def test_terminated_stream_raises_ydb_error(self, session: QuerySession):
122+
await session.create()
123+
124+
with pytest.raises(ydb.Unavailable):
125+
async with await session.execute("select 1") as results:
126+
async for _ in results:
127+
pass

tests/aio/query/test_query_transaction.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
import ydb
34
from ydb.aio.query.transaction import QueryTxContext
45
from ydb.query.transaction import QueryTxStateEnum
56

@@ -107,3 +108,13 @@ async def test_execute_two_results(self, tx: QueryTxContext):
107108

108109
assert res == [[1], [2]]
109110
assert counter == 2
111+
112+
@pytest.mark.asyncio
113+
@pytest.mark.usefixtures("ydb_terminates_streams_with_unavailable")
114+
async def test_terminated_stream_raises_ydb_error(self, tx: QueryTxContext):
115+
await tx.begin()
116+
117+
with pytest.raises(ydb.Unavailable):
118+
async with await tx.execute("select 1") as results:
119+
async for _ in results:
120+
pass

ydb/_errors.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from dataclasses import dataclass
2-
from typing import Optional
2+
from typing import Optional, Union
3+
4+
import grpc
35

46
from . import issues
57

@@ -52,3 +54,26 @@ def check_retriable_error(err, retry_settings, attempt):
5254
class ErrorRetryInfo:
5355
is_retriable: bool
5456
sleep_timeout_seconds: Optional[float]
57+
58+
59+
def stream_error_converter(exc: BaseException) -> Union[issues.Error, BaseException]:
60+
"""Converts gRPC stream errors to appropriate YDB exception types.
61+
62+
This function takes a base exception and converts specific gRPC aio stream errors
63+
to their corresponding YDB exception types for better error handling and semantic
64+
clarity.
65+
66+
Args:
67+
exc (BaseException): The original exception to potentially convert.
68+
69+
Returns:
70+
BaseException: Either a converted YDB exception or the original exception
71+
if no specific conversion rule applies.
72+
"""
73+
if isinstance(exc, (grpc.RpcError, grpc.aio.AioRpcError)):
74+
if exc.code() == grpc.StatusCode.UNAVAILABLE:
75+
return issues.Unavailable(exc.details() or "")
76+
if exc.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
77+
return issues.DeadlineExceed("Deadline exceeded on request")
78+
return issues.Error("Stream has been terminated. Original exception: {}".format(str(exc.details())))
79+
return exc

ydb/aio/_utilities.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33

44
class AsyncResponseIterator(object):
5-
def __init__(self, it, wrapper):
5+
def __init__(self, it, wrapper, error_converter=None):
66
self.it = it.__aiter__()
77
self.wrapper = wrapper
8+
self.error_converter = error_converter
89

910
def cancel(self):
1011
self.it.cancel()
@@ -17,7 +18,13 @@ def __aiter__(self):
1718
return self
1819

1920
async def _next(self):
20-
res = self.wrapper(await self.it.__anext__())
21+
try:
22+
res = self.wrapper(await self.it.__anext__())
23+
except BaseException as e:
24+
if self.error_converter:
25+
raise self.error_converter(e) from e
26+
raise e
27+
2128
if res is not None:
2229
return res
2330
return await self._next()

ydb/aio/query/session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020

2121
from ..._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT
22+
from ..._errors import stream_error_converter
2223

2324

2425
class QuerySession(BaseQuerySession):
@@ -151,12 +152,13 @@ async def execute(
151152
)
152153

153154
return AsyncResponseContextIterator(
154-
stream_it,
155-
lambda resp: base.wrap_execute_query_response(
155+
it=stream_it,
156+
wrapper=lambda resp: base.wrap_execute_query_response(
156157
rpc_state=None,
157158
response_pb=resp,
158159
session_state=self._state,
159160
session=self,
160161
settings=self._settings,
161162
),
163+
error_converter=stream_error_converter,
162164
)

ydb/aio/query/transaction.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
BaseQueryTxContext,
1212
QueryTxStateEnum,
1313
)
14+
from ..._errors import stream_error_converter
1415

1516
logger = logging.getLogger(__name__)
1617

@@ -181,14 +182,15 @@ async def execute(
181182
)
182183

183184
self._prev_stream = AsyncResponseContextIterator(
184-
stream_it,
185-
lambda resp: base.wrap_execute_query_response(
185+
it=stream_it,
186+
wrapper=lambda resp: base.wrap_execute_query_response(
186187
rpc_state=None,
187188
response_pb=resp,
188189
session_state=self._session_state,
189190
tx=self,
190191
commit_tx=commit_tx,
191192
settings=self.session._settings,
192193
),
194+
error_converter=stream_error_converter,
193195
)
194196
return self._prev_stream

0 commit comments

Comments
 (0)