1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import asyncio
15+ import random
1516import threading
1617import time
1718from typing import Any, Dict, List
3435from .base_driver import BaseDriver
3536from .driver_params import DriverCap, DriverParams
3637from .grpc.streamer_pb2 import Frame
38+ from .grpc.utils import get_grpc_client_credentials, get_grpc_server_credentials, use_aio_grpc
3739from .net_utils import MAX_FRAME_SIZE, get_address, get_tcp_urls, ssl_required
3840
3941GRPC_DEFAULT_OPTIONS = [
@@ -65,6 +67,18 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di
6567 self.channel = channel # for client side
6668 self.lock = threading.Lock()
6769
70+ conf = CommConfigurator()
71+ if conf.get_bool_var("simulate_unstable_network", default=False):
72+ self.disconn = threading.Thread(target=self._disconnect, daemon=True)
73+ self.disconn.start()
74+
75+ def _disconnect(self):
76+ t = random.randint(10, 60)
77+ self.logger.info(f"will close connection after {t} secs")
78+ time.sleep(t)
79+ self.logger.info(f"close connection now after {t} secs")
80+ self.close()
81+
6882 def get_conn_properties(self) -> dict:
6983 return self.conn_props
7084
@@ -101,18 +115,18 @@ async def read_loop(self, msg_iter):
101115 except grpc.aio.AioRpcError as error:
102116 if not self.closing:
103117 if error.code() == grpc.StatusCode.CANCELLED:
104- self.logger.debug (f"Connection {self} is closed by peer")
118+ self.logger.info (f"Connection {self} is closed by peer")
105119 else:
106- self.logger.debug (f"Connection {self} Error: {error.details()}")
120+ self.logger.info (f"Connection {self} Error: {error.details()}")
107121 self.logger.debug(secure_format_traceback())
108122 else:
109- self.logger.debug (f"Connection {self} is closed locally")
123+ self.logger.info (f"Connection {self} is closed locally")
110124 except Exception as ex:
111125 if not self.closing:
112- self.logger.debug (f"{self}: exception {type(ex)} in read_loop: {secure_format_exception(ex)}")
126+ self.logger.info (f"{self}: exception {type(ex)} in read_loop: {secure_format_exception(ex)}")
113127 self.logger.debug(secure_format_traceback())
114128
115- self.logger.debug (f"{self}: in {ct.name}: done read_loop")
129+ self.logger.info (f"{self}: in {ct.name}: done read_loop")
116130
117131 async def generate_output(self):
118132 ct = threading.current_thread()
@@ -123,11 +137,10 @@ async def generate_output(self):
123137 yield item
124138 except Exception as ex:
125139 if self.closing:
126- self.logger.debug (f"{self}: connection closed by {type(ex)}: {secure_format_exception(ex)}")
140+ self.logger.info (f"{self}: connection closed by {type(ex)}: {secure_format_exception(ex)}")
127141 else:
128- self.logger.debug (f"{self}: generate_output exception {type(ex)}: {secure_format_exception(ex)}")
142+ self.logger.info (f"{self}: generate_output exception {type(ex)}: {secure_format_exception(ex)}")
129143 self.logger.debug(secure_format_traceback())
130-
131144 self.logger.debug(f"{self}: done generate_output")
132145
133146
@@ -137,20 +150,10 @@ def __init__(self, server, aio_ctx: AioContext):
137150 self.aio_ctx = aio_ctx
138151 self.logger = get_logger(self)
139152
140- async def _write_loop(self, connection, grpc_context):
141- self.logger.debug("started _write_loop")
142- try:
143- while True:
144- f = await connection.oq.get()
145- await grpc_context.write(f)
146- except Exception as ex:
147- self.logger.debug(f"_write_loop except: {type(ex)}: {secure_format_exception(ex)}")
148- self.logger.debug("finished _write_loop")
149-
150153 async def Stream(self, request_iterator, context):
151154 connection = None
152- ct = threading.current_thread()
153155 try:
156+ ct = threading.current_thread()
154157 self.logger.debug(f"SERVER started Stream CB in thread {ct.name}")
155158 conn_props = {
156159 DriverParams.PEER_ADDR.value: context.peer(),
@@ -169,23 +172,22 @@ async def Stream(self, request_iterator, context):
169172 )
170173 self.logger.debug(f"SERVER created connection in thread {ct.name}")
171174 self.server.driver.add_connection(connection)
172- try:
173- await asyncio.gather(self._write_loop(connection, context), connection.read_loop(request_iterator))
174- except asyncio.CancelledError:
175- self.logger.debug("SERVER: RPC cancelled")
176- except Exception as ex:
177- self.logger.debug(f"await gather except: {type(ex)}: {secure_format_exception(ex)}")
178- self.logger.debug(f"SERVER: done await gather in thread {ct.name}")
179-
175+ self.aio_ctx.run_coro(connection.read_loop(request_iterator))
176+ while True:
177+ item = await connection.oq.get()
178+ yield item
179+ except asyncio.CancelledError:
180+ self.logger.info("SERVER: RPC cancelled")
180181 except Exception as ex:
181- self.logger.debug(f"Connection closed due to error: {secure_format_exception(ex)}")
182+ if connection:
183+ self.logger.info(f"{connection}: connection exception: {secure_format_exception(ex)}")
184+ self.logger.debug(secure_format_traceback())
182185 finally:
183186 if connection:
184- with connection.lock:
185- connection.context = None
186- self.logger.debug(f"SERVER: closing connection {connection.name}")
187+ connection.close()
188+ self.logger.info(f"SERVER: closed connection {connection.name}")
187189 self.server.driver.close_connection(connection)
188- self.logger.debug(f "SERVER: cleanly finished Stream CB in thread {ct.name} ")
190+ self.logger.info( "SERVER: finished Stream CB")
189191
190192
191193class Server:
@@ -207,10 +209,12 @@ def __init__(self, driver, connector, aio_ctx: AioContext, options, conn_ctx: _C
207209
208210 secure = ssl_required(params)
209211 if secure:
210- credentials = AioGrpcDriver. get_grpc_server_credentials(params)
212+ credentials = get_grpc_server_credentials(params)
211213 self.grpc_server.add_secure_port(addr, server_credentials=credentials)
214+ self.logger.info(f"added secure port at {addr}")
212215 else:
213216 self.grpc_server.add_insecure_port(addr)
217+ self.logger.info(f"added insecure port at {addr}")
214218 except Exception as ex:
215219 conn_ctx.error = f"cannot listen on {addr}: {type(ex)}: {secure_format_exception(ex)}"
216220 self.logger.debug(conn_ctx.error)
@@ -251,7 +255,10 @@ def __init__(self):
251255
252256 @staticmethod
253257 def supported_transports() -> List[str]:
254- return ["grpc", "grpcs"]
258+ if use_aio_grpc():
259+ return ["grpc", "grpcs"]
260+ else:
261+ return ["agrpc", "agrpcs"]
255262
256263 @staticmethod
257264 def capabilities() -> Dict[str, Any]:
@@ -280,9 +287,9 @@ def listen(self, connector: ConnectorInfo):
280287 time.sleep(0.1)
281288 if conn_ctx.error:
282289 raise CommError(code=CommError.ERROR, message=conn_ctx.error)
283- self.logger.debug( "SERVER: waiting for server to finish ")
290+ self.logger.info(f "SERVER: listening on {connector} ")
284291 conn_ctx.waiter.wait()
285- self.logger.debug( "SERVER: server is done")
292+ self.logger.info(f "SERVER: server is done listening on {connector} ")
286293
287294 async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, conn_ctx: _ConnCtx):
288295 self.logger.debug("Started _start_connect coro")
@@ -295,10 +302,12 @@ async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, co
295302 secure = ssl_required(params)
296303 if secure:
297304 grpc_channel = grpc.aio.secure_channel(
298- address, options=self.options, credentials=self. get_grpc_client_credentials(params)
305+ address, options=self.options, credentials=get_grpc_client_credentials(params)
299306 )
307+ self.logger.info(f"created secure channel at {address}")
300308 else:
301309 grpc_channel = grpc.aio.insecure_channel(address, options=self.options)
310+ self.logger.info(f"created insecure channel at {address}")
302311
303312 async with grpc_channel as channel:
304313 self.logger.debug(f"CLIENT: connected to {address}")
@@ -358,6 +367,7 @@ def connect(self, connector: ConnectorInfo):
358367 self.add_connection(conn_ctx.conn)
359368 conn_ctx.waiter.wait()
360369 self.close_connection(conn_ctx.conn)
370+ self.logger.info(f"CLIENT: connection {conn_ctx.conn} closed")
361371
362372 def shutdown(self):
363373 if self.closing:
@@ -374,38 +384,9 @@ def shutdown(self):
374384 def get_urls(scheme: str, resources: dict) -> (str, str):
375385 secure = resources.get(DriverParams.SECURE)
376386 if secure:
377- scheme = "grpcs"
387+ if use_aio_grpc():
388+ scheme = "grpcs"
389+ else:
390+ scheme = "agrpcs"
378391
379392 return get_tcp_urls(scheme, resources)
380-
381- @staticmethod
382- def get_grpc_client_credentials(params: dict):
383-
384- root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value))
385- cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_CERT))
386- private_key = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_KEY))
387-
388- return grpc.ssl_channel_credentials(
389- certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert
390- )
391-
392- @staticmethod
393- def get_grpc_server_credentials(params: dict):
394-
395- root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value))
396- cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_CERT))
397- private_key = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_KEY))
398-
399- return grpc.ssl_server_credentials(
400- [(private_key, cert_chain)],
401- root_certificates=root_cert,
402- require_client_auth=True,
403- )
404-
405- @staticmethod
406- def read_file(file_name: str):
407- if not file_name:
408- return None
409-
410- with open(file_name, "rb") as f:
411- return f.read()
0 commit comments