Skip to content

Commit 2b6e1b1

Browse files
Merged security fix to admin conn (#688)
* Update workspace related documentation (#684) * Update workspace related documentation * Add more details to server/client workspace and add reference * Update documentation format (#685) * Cherry-picked security fix for admin conn to 2.1 branch Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
1 parent 81dd280 commit 2b6e1b1

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

nvflare/fuel/hci/conn.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,20 @@
2727

2828
LINE_END = "\x03" # Indicates the end of the line (end of text)
2929
ALL_END = "\x04" # Marks the end of a complete transmission (End of Transmission)
30-
31-
3230
MAX_MSG_SIZE = 1024
31+
MAX_DATA_SIZE = 512 * 1024 * 1024
32+
MAX_IDLE_TIME = 10
3333

3434

3535
def receive_til_end(sock, end=ALL_END):
3636
total_data = []
37-
37+
data_size = 0
38+
sock.settimeout(MAX_IDLE_TIME)
3839
while True:
3940
data = str(sock.recv(1024), "utf-8")
41+
data_size += len(data)
42+
if data_size > MAX_DATA_SIZE:
43+
raise BufferError(f"Data size exceeds limit ({MAX_DATA_SIZE} bytes)")
4044
if end in data:
4145
total_data.append(data[: data.find(end)])
4246
break

nvflare/fuel/hci/server/hci.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
import socketserver
1717
import ssl
1818
import threading
19-
import traceback
2019

2120
from nvflare.fuel.hci.conn import Connection, receive_til_end
2221
from nvflare.fuel.hci.proto import validate_proto
2322
from nvflare.fuel.hci.security import get_certificate_common_name
2423

2524
from .reg import ServerCommandRegister
2625

26+
MAX_ADMIN_CONNECTIONS = 16
27+
2728

2829
class _MsgHandler(socketserver.BaseRequestHandler):
2930
"""Message handler.
@@ -32,8 +33,23 @@ class _MsgHandler(socketserver.BaseRequestHandler):
3233
ServerCommandRegister.
3334
"""
3435

36+
connections = 0
37+
lock = threading.Lock()
38+
39+
def __init__(self, request, client_address, server):
40+
# handle() is called in the constructor so logger must be initialized first
41+
self.logger = logging.getLogger(self.__class__.__name__)
42+
super().__init__(request, client_address, server)
43+
3544
def handle(self):
3645
try:
46+
with _MsgHandler.lock:
47+
_MsgHandler.connections += 1
48+
49+
self.logger.debug(f"Concurrent admin connections: {_MsgHandler.connections}")
50+
if _MsgHandler.connections > MAX_ADMIN_CONNECTIONS:
51+
raise ConnectionRefusedError(f"Admin connection limit ({MAX_ADMIN_CONNECTIONS}) reached")
52+
3753
conn = Connection(self.request, self.server)
3854

3955
if self.server.use_ssl:
@@ -68,8 +84,13 @@ def handle(self):
6884

6985
if not conn.ended:
7086
conn.close()
71-
except BaseException:
72-
traceback.print_exc()
87+
except BaseException as exc:
88+
self.logger.error(f"Admin connection terminated due to exception: {str(exc)}")
89+
if self.logger.getEffectiveLevel() <= logging.DEBUG:
90+
self.logger.exception("Admin connection error")
91+
finally:
92+
with _MsgHandler.lock:
93+
_MsgHandler.connections -= 1
7394

7495

7596
def initialize_hci():
@@ -121,7 +142,7 @@ def __init__(
121142
ctx.load_verify_locations(ca_cert)
122143
ctx.load_cert_chain(certfile=server_cert, keyfile=server_key)
123144

124-
# replace the socket with an ssl version of itself
145+
# replace the socket with an SSL version of itself
125146
self.socket = ctx.wrap_socket(self.socket, server_side=True)
126147
self.use_ssl = True
127148

0 commit comments

Comments
 (0)