Skip to content

Commit 58cf5a6

Browse files
authored
Restore non-aio grpc driver (#2077)
* restore non-aio grpc driver * improve grpc drivers * changed to log_debug for job hb * fix f-str
1 parent f208127 commit 58cf5a6

File tree

16 files changed

+502
-90
lines changed

16 files changed

+502
-90
lines changed

nvflare/fuel/f3/comm_config.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,42 @@ class VarName:
3333
SUBNET_HEARTBEAT_INTERVAL = "subnet_heartbeat_interval"
3434
SUBNET_TROUBLE_THRESHOLD = "subnet_trouble_threshold"
3535
COMM_DRIVER_PATH = "comm_driver_path"
36+
USE_AIO_GRPC_VAR_NAME = "use_aio_grpc"
3637

3738

3839
class CommConfigurator:
40+
_config_loaded = False
41+
_configuration = None
42+
3943
def __init__(self):
44+
# only load once!
4045
self.logger = logging.getLogger(self.__class__.__name__)
41-
config = None
42-
for file_name in _comm_config_files:
43-
try:
44-
config = ConfigService.load_json(file_name)
45-
if config:
46-
break
47-
except FileNotFoundError:
48-
self.logger.debug(f"config file {file_name} not found from config path")
49-
config = None
50-
except Exception as ex:
51-
self.logger.error(f"failed to load config file {file_name}: {secure_format_exception(ex)}")
52-
config = None
53-
self.config = config
46+
if not CommConfigurator._config_loaded:
47+
config = None
48+
for file_name in _comm_config_files:
49+
try:
50+
config = ConfigService.load_json(file_name)
51+
if config:
52+
break
53+
except FileNotFoundError:
54+
self.logger.debug(f"config file {file_name} not found from config path")
55+
config = None
56+
except Exception as ex:
57+
self.logger.error(f"failed to load config file {file_name}: {secure_format_exception(ex)}")
58+
config = None
59+
60+
CommConfigurator._configuration = config
61+
CommConfigurator._config_loaded = True
62+
self.config = CommConfigurator._configuration
63+
64+
@staticmethod
65+
def reset():
66+
"""Reset the configurator to allow reloading config files.
67+
68+
Returns:
69+
70+
"""
71+
CommConfigurator._config_loaded = False
5472

5573
def get_config(self):
5674
return self.config
@@ -78,3 +96,18 @@ def get_subnet_trouble_threshold(self, default):
7896

7997
def get_comm_driver_path(self, default):
8098
return ConfigService.get_str_var(VarName.COMM_DRIVER_PATH, self.config, default=default)
99+
100+
def use_aio_grpc(self, default):
101+
return ConfigService.get_bool_var(VarName.USE_AIO_GRPC_VAR_NAME, self.config, default)
102+
103+
def get_int_var(self, name: str, default=None):
104+
return ConfigService.get_int_var(name, self.config, default=default)
105+
106+
def get_float_var(self, name: str, default=None):
107+
return ConfigService.get_float_var(name, self.config, default=default)
108+
109+
def get_bool_var(self, name: str, default=None):
110+
return ConfigService.get_bool_var(name, self.config, default=default)
111+
112+
def get_str_var(self, name: str, default=None):
113+
return ConfigService.get_str_var(name, self.config, default=default)

nvflare/fuel/f3/drivers/aio_grpc_driver.py

Lines changed: 51 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import asyncio
15+
import random
1516
import threading
1617
import time
1718
from typing import Any, Dict, List
@@ -34,6 +35,7 @@
3435
from .base_driver import BaseDriver
3536
from .driver_params import DriverCap, DriverParams
3637
from .grpc.streamer_pb2 import Frame
38+
from .grpc.utils import get_grpc_client_credentials, get_grpc_server_credentials, use_aio_grpc
3739
from .net_utils import MAX_FRAME_SIZE, get_address, get_tcp_urls, ssl_required
3840

3941
GRPC_DEFAULT_OPTIONS = [
@@ -65,6 +67,18 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di
6567
self.channel = channel # for client side
6668
self.lock = threading.Lock()
6769

70+
conf = CommConfigurator()
71+
if conf.get_bool_var("simulate_unstable_network", default=False):
72+
self.disconn = threading.Thread(target=self._disconnect, daemon=True)
73+
self.disconn.start()
74+
75+
def _disconnect(self):
76+
t = random.randint(10, 60)
77+
self.logger.info(f"will close connection after {t} secs")
78+
time.sleep(t)
79+
self.logger.info(f"close connection now after {t} secs")
80+
self.close()
81+
6882
def get_conn_properties(self) -> dict:
6983
return self.conn_props
7084

@@ -101,18 +115,18 @@ async def read_loop(self, msg_iter):
101115
except grpc.aio.AioRpcError as error:
102116
if not self.closing:
103117
if error.code() == grpc.StatusCode.CANCELLED:
104-
self.logger.debug(f"Connection {self} is closed by peer")
118+
self.logger.info(f"Connection {self} is closed by peer")
105119
else:
106-
self.logger.debug(f"Connection {self} Error: {error.details()}")
120+
self.logger.info(f"Connection {self} Error: {error.details()}")
107121
self.logger.debug(secure_format_traceback())
108122
else:
109-
self.logger.debug(f"Connection {self} is closed locally")
123+
self.logger.info(f"Connection {self} is closed locally")
110124
except Exception as ex:
111125
if not self.closing:
112-
self.logger.debug(f"{self}: exception {type(ex)} in read_loop: {secure_format_exception(ex)}")
126+
self.logger.info(f"{self}: exception {type(ex)} in read_loop: {secure_format_exception(ex)}")
113127
self.logger.debug(secure_format_traceback())
114128

115-
self.logger.debug(f"{self}: in {ct.name}: done read_loop")
129+
self.logger.info(f"{self}: in {ct.name}: done read_loop")
116130

117131
async def generate_output(self):
118132
ct = threading.current_thread()
@@ -123,11 +137,10 @@ async def generate_output(self):
123137
yield item
124138
except Exception as ex:
125139
if self.closing:
126-
self.logger.debug(f"{self}: connection closed by {type(ex)}: {secure_format_exception(ex)}")
140+
self.logger.info(f"{self}: connection closed by {type(ex)}: {secure_format_exception(ex)}")
127141
else:
128-
self.logger.debug(f"{self}: generate_output exception {type(ex)}: {secure_format_exception(ex)}")
142+
self.logger.info(f"{self}: generate_output exception {type(ex)}: {secure_format_exception(ex)}")
129143
self.logger.debug(secure_format_traceback())
130-
131144
self.logger.debug(f"{self}: done generate_output")
132145

133146

@@ -137,20 +150,10 @@ def __init__(self, server, aio_ctx: AioContext):
137150
self.aio_ctx = aio_ctx
138151
self.logger = get_logger(self)
139152

140-
async def _write_loop(self, connection, grpc_context):
141-
self.logger.debug("started _write_loop")
142-
try:
143-
while True:
144-
f = await connection.oq.get()
145-
await grpc_context.write(f)
146-
except Exception as ex:
147-
self.logger.debug(f"_write_loop except: {type(ex)}: {secure_format_exception(ex)}")
148-
self.logger.debug("finished _write_loop")
149-
150153
async def Stream(self, request_iterator, context):
151154
connection = None
152-
ct = threading.current_thread()
153155
try:
156+
ct = threading.current_thread()
154157
self.logger.debug(f"SERVER started Stream CB in thread {ct.name}")
155158
conn_props = {
156159
DriverParams.PEER_ADDR.value: context.peer(),
@@ -169,23 +172,22 @@ async def Stream(self, request_iterator, context):
169172
)
170173
self.logger.debug(f"SERVER created connection in thread {ct.name}")
171174
self.server.driver.add_connection(connection)
172-
try:
173-
await asyncio.gather(self._write_loop(connection, context), connection.read_loop(request_iterator))
174-
except asyncio.CancelledError:
175-
self.logger.debug("SERVER: RPC cancelled")
176-
except Exception as ex:
177-
self.logger.debug(f"await gather except: {type(ex)}: {secure_format_exception(ex)}")
178-
self.logger.debug(f"SERVER: done await gather in thread {ct.name}")
179-
175+
self.aio_ctx.run_coro(connection.read_loop(request_iterator))
176+
while True:
177+
item = await connection.oq.get()
178+
yield item
179+
except asyncio.CancelledError:
180+
self.logger.info("SERVER: RPC cancelled")
180181
except Exception as ex:
181-
self.logger.debug(f"Connection closed due to error: {secure_format_exception(ex)}")
182+
if connection:
183+
self.logger.info(f"{connection}: connection exception: {secure_format_exception(ex)}")
184+
self.logger.debug(secure_format_traceback())
182185
finally:
183186
if connection:
184-
with connection.lock:
185-
connection.context = None
186-
self.logger.debug(f"SERVER: closing connection {connection.name}")
187+
connection.close()
188+
self.logger.info(f"SERVER: closed connection {connection.name}")
187189
self.server.driver.close_connection(connection)
188-
self.logger.debug(f"SERVER: cleanly finished Stream CB in thread {ct.name}")
190+
self.logger.info("SERVER: finished Stream CB")
189191

190192

191193
class Server:
@@ -207,10 +209,12 @@ def __init__(self, driver, connector, aio_ctx: AioContext, options, conn_ctx: _C
207209

208210
secure = ssl_required(params)
209211
if secure:
210-
credentials = AioGrpcDriver.get_grpc_server_credentials(params)
212+
credentials = get_grpc_server_credentials(params)
211213
self.grpc_server.add_secure_port(addr, server_credentials=credentials)
214+
self.logger.info(f"added secure port at {addr}")
212215
else:
213216
self.grpc_server.add_insecure_port(addr)
217+
self.logger.info(f"added insecure port at {addr}")
214218
except Exception as ex:
215219
conn_ctx.error = f"cannot listen on {addr}: {type(ex)}: {secure_format_exception(ex)}"
216220
self.logger.debug(conn_ctx.error)
@@ -251,7 +255,10 @@ def __init__(self):
251255

252256
@staticmethod
253257
def supported_transports() -> List[str]:
254-
return ["grpc", "grpcs"]
258+
if use_aio_grpc():
259+
return ["grpc", "grpcs"]
260+
else:
261+
return ["agrpc", "agrpcs"]
255262

256263
@staticmethod
257264
def capabilities() -> Dict[str, Any]:
@@ -280,9 +287,9 @@ def listen(self, connector: ConnectorInfo):
280287
time.sleep(0.1)
281288
if conn_ctx.error:
282289
raise CommError(code=CommError.ERROR, message=conn_ctx.error)
283-
self.logger.debug("SERVER: waiting for server to finish")
290+
self.logger.info(f"SERVER: listening on {connector}")
284291
conn_ctx.waiter.wait()
285-
self.logger.debug("SERVER: server is done")
292+
self.logger.info(f"SERVER: server is done listening on {connector}")
286293

287294
async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, conn_ctx: _ConnCtx):
288295
self.logger.debug("Started _start_connect coro")
@@ -295,10 +302,12 @@ async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, co
295302
secure = ssl_required(params)
296303
if secure:
297304
grpc_channel = grpc.aio.secure_channel(
298-
address, options=self.options, credentials=self.get_grpc_client_credentials(params)
305+
address, options=self.options, credentials=get_grpc_client_credentials(params)
299306
)
307+
self.logger.info(f"created secure channel at {address}")
300308
else:
301309
grpc_channel = grpc.aio.insecure_channel(address, options=self.options)
310+
self.logger.info(f"created insecure channel at {address}")
302311

303312
async with grpc_channel as channel:
304313
self.logger.debug(f"CLIENT: connected to {address}")
@@ -358,6 +367,7 @@ def connect(self, connector: ConnectorInfo):
358367
self.add_connection(conn_ctx.conn)
359368
conn_ctx.waiter.wait()
360369
self.close_connection(conn_ctx.conn)
370+
self.logger.info(f"CLIENT: connection {conn_ctx.conn} closed")
361371

362372
def shutdown(self):
363373
if self.closing:
@@ -374,38 +384,9 @@ def shutdown(self):
374384
def get_urls(scheme: str, resources: dict) -> (str, str):
375385
secure = resources.get(DriverParams.SECURE)
376386
if secure:
377-
scheme = "grpcs"
387+
if use_aio_grpc():
388+
scheme = "grpcs"
389+
else:
390+
scheme = "agrpcs"
378391

379392
return get_tcp_urls(scheme, resources)
380-
381-
@staticmethod
382-
def get_grpc_client_credentials(params: dict):
383-
384-
root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value))
385-
cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_CERT))
386-
private_key = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_KEY))
387-
388-
return grpc.ssl_channel_credentials(
389-
certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert
390-
)
391-
392-
@staticmethod
393-
def get_grpc_server_credentials(params: dict):
394-
395-
root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value))
396-
cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_CERT))
397-
private_key = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_KEY))
398-
399-
return grpc.ssl_server_credentials(
400-
[(private_key, cert_chain)],
401-
root_certificates=root_cert,
402-
require_client_auth=True,
403-
)
404-
405-
@staticmethod
406-
def read_file(file_name: str):
407-
if not file_name:
408-
return None
409-
410-
with open(file_name, "rb") as f:
411-
return f.read()

nvflare/fuel/f3/drivers/grpc/qq.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
import queue
17+
18+
19+
class QueueClosed(Exception):
20+
pass
21+
22+
23+
class QQ:
24+
def __init__(self):
25+
self.q = queue.Queue()
26+
self.closed = False
27+
self.logger = logging.getLogger(self.__class__.__name__)
28+
29+
def close(self):
30+
self.closed = True
31+
32+
def append(self, i):
33+
if self.closed:
34+
raise QueueClosed("queue stopped")
35+
self.q.put_nowait(i)
36+
37+
def __iter__(self):
38+
return self
39+
40+
def __next__(self):
41+
if self.closed:
42+
raise StopIteration()
43+
while True:
44+
try:
45+
return self.q.get(block=True, timeout=0.1)
46+
except queue.Empty:
47+
if self.closed:
48+
self.logger.debug("Queue closed - stop iteration")
49+
raise StopIteration()
50+
except Exception as e:
51+
self.logger.error(f"queue exception {type(e)}")
52+
raise e

0 commit comments

Comments
 (0)