Skip to content

Commit ff485ab

Browse files
fix(postgresql-proxy): prevent SSL COPY stalls by draining nonblocking reads
1 parent 37a1fee commit ff485ab

1 file changed

Lines changed: 117 additions & 30 deletions

File tree

postgresql_proxy/proxy.py

Lines changed: 117 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def accept_wrapper(self, sock: socket.socket):
156156
events=events,
157157
context=context,
158158
)
159+
conn.ssl_handshake_done = not isinstance(clientsocket, ssl.SSLSocket)
159160

160161
pg_conn = self._create_pg_connection(address, context)
161162

@@ -195,9 +196,20 @@ def _handle_ssl_negotiation(
195196
Returns the SSL-wrapped socket if negotiation succeeds, or the original socket.
196197
"""
197198

198-
# Peek at the first 8 bytes to check for SSLRequest
199-
# Using MSG_PEEK so we don't consume the data if it's not SSLRequest
200-
data = client_socket.recv(8, socket.MSG_PEEK)
199+
# Peek at the first 8 bytes to check for SSLRequest.
200+
# Wait briefly to avoid race conditions where the request arrives just after accept.
201+
previous_timeout = client_socket.gettimeout()
202+
try:
203+
client_socket.settimeout(0.2)
204+
data = client_socket.recv(8, socket.MSG_PEEK)
205+
except socket.timeout:
206+
return client_socket
207+
except BlockingIOError:
208+
return client_socket
209+
except OSError:
210+
return client_socket
211+
finally:
212+
client_socket.settimeout(previous_timeout)
201213

202214
if len(data) == 8:
203215
length = int.from_bytes(data[:4], "big")
@@ -208,14 +220,68 @@ def _handle_ssl_negotiation(
208220
client_socket.recv(8)
209221
# Send 'S' to indicate SSL is supported
210222
client_socket.send(b"S")
211-
# Wrap socket with SSL
212-
ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True)
213-
LOG.debug("SSL handshake completed for PostgreSQL connection")
223+
# Wrap without immediate handshake so a stalled client can't block accept loop.
224+
ssl_socket = ssl_context.wrap_socket(
225+
client_socket,
226+
server_side=True,
227+
do_handshake_on_connect=False,
228+
)
229+
LOG.debug("SSL requested, deferring TLS handshake to selector loop")
214230
return ssl_socket
215231

216232
# Not an SSLRequest, return original socket
217233
return client_socket
218234

235+
def _set_write_interest(self, conn: "connection.Connection", enabled: bool):
    """Enable or disable EVENT_WRITE for a connection while preserving read interest.

    :param conn: connection whose ``sock`` is (possibly) registered with ``self.selector``
    :param enabled: True to add EVENT_WRITE to the registered events, False to remove it

    The call is a no-op when the socket is not known to the selector, and the
    selector is only touched when the event mask actually changes.
    """
    try:
        selector_key = self.selector.get_key(conn.sock)
    except (KeyError, ValueError):
        # KeyError: the socket is not registered (anymore).
        # ValueError: the stdlib selectors raise this instead of KeyError when the
        # socket is already closed *and* unregistered, because no file descriptor
        # can be resolved for it. Either way there is nothing to update.
        return

    current_events = selector_key.events
    if enabled:
        new_events = current_events | selectors.EVENT_WRITE
    else:
        new_events = current_events & ~selectors.EVENT_WRITE

    if new_events != current_events:
        # Re-register with the new mask and keep the connection's own copy in sync.
        self.selector.modify(conn.sock, new_events, data=conn)
        conn.events = new_events
251+
252+
def _flush_outgoing(
    self,
    source_conn: "connection.Connection",
    source_sock: socket.socket,
    target_conn: "connection.Connection | None",
):
    """Send as much of ``target_conn.out_bytes`` as the target socket accepts right now.

    :param source_conn: connection that is unregistered if the target turns out to be dead
    :param source_sock: socket of ``source_conn``; closed together with the unregistration
    :param target_conn: connection whose pending ``out_bytes`` should be written; when it
        is ``None`` or has nothing queued the call returns immediately

    When the socket cannot take more data at the moment, EVENT_WRITE interest is
    enabled on the target so the remaining bytes are retried on a later writable
    event. Once the buffer is fully drained, write interest is switched off again.
    """
    if not target_conn or not target_conn.out_bytes:
        return

    try:
        while target_conn.out_bytes:
            LOG.debug('sending to %s:\n%s', target_conn.name, target_conn.out_bytes)
            sent = target_conn.sock.send(target_conn.out_bytes)
            if sent == 0:
                break
            target_conn.sent(sent)
    except (ssl.SSLWantWriteError, ssl.SSLWantReadError, BlockingIOError):
        # "Try again later" conditions, not failures. These must be caught before
        # the generic OSError handler below because all three are OSError subclasses.
        # BlockingIOError is what a plain socket raises when its send buffer is full
        # (presumably the sockets are non-blocking, as the read path also handles it);
        # treating it as a disconnect would drop the queued bytes and kill a healthy
        # connection.
        self._set_write_interest(target_conn, True)
        return
    except OSError:
        # If one side is closed, close the other one.
        # This can happen when the client disconnects while postgres still returns a
        # response: we read the response, then close the PG side of the socket.
        LOG.debug('error sending to %s: connection closed', target_conn.name)
        self._unregister_conn(source_conn)
        source_sock.close()
        return

    # Keep write interest only while something is still queued (e.g. send() returned 0).
    self._set_write_interest(target_conn, bool(target_conn.out_bytes))
284+
219285
def service_connection(self, key: SelectorKeyProxy, mask):
220286
"""
221287
This method proxies the messages between socket. It will use properties of the Connection object to
@@ -227,37 +293,58 @@ def service_connection(self, key: SelectorKeyProxy, mask):
227293
"""
228294
sock = key.fileobj
229295
conn = key.data
296+
297+
# Drive TLS handshake in non-blocking mode so one slow client cannot block others.
298+
if isinstance(sock, ssl.SSLSocket) and not getattr(conn, "ssl_handshake_done", False):
299+
try:
300+
sock.do_handshake()
301+
conn.ssl_handshake_done = True
302+
self._set_write_interest(conn, bool(conn.out_bytes))
303+
except ssl.SSLWantReadError:
304+
return
305+
except ssl.SSLWantWriteError:
306+
self._set_write_interest(conn, True)
307+
return
308+
except OSError as e:
309+
LOG.debug('%s SSL handshake failed %s: %s', conn.name, conn.address, e)
310+
self._unregister_conn(conn)
311+
sock.close()
312+
return
313+
230314
if mask & selectors.EVENT_READ:
231315
LOG.debug('%s can receive', conn.name)
232-
try:
233-
recv_data = sock.recv(4096) # Should be ready to read
234-
if recv_data:
235-
LOG.debug('%s received data:\n%s', conn.name, recv_data)
236-
conn.received(recv_data)
237-
else:
316+
while True:
317+
try:
318+
if recv_data := sock.recv(4096):
319+
LOG.debug('%s received data:\n%s', conn.name, recv_data)
320+
conn.received(recv_data)
321+
# Keep draining bytes in the same readiness cycle until recv indicates no immediate data.
322+
continue
323+
238324
self._unregister_conn(conn)
239325
LOG.debug('%s connection closing %s', conn.name, conn.address)
240326
# A file object shall be unregistered prior to being closed.
241327
sock.close()
242-
except OSError as e:
243-
# it means the socket was closed by peer
244-
LOG.debug('%s connection closed by peer %s: %s', conn.name, conn.address, e)
245-
self._unregister_conn(conn)
328+
return
329+
except ssl.SSLWantReadError:
330+
break
331+
except ssl.SSLWantWriteError:
332+
self._set_write_interest(conn, True)
333+
break
334+
except BlockingIOError:
335+
break
336+
except OSError as e:
337+
# it means the socket was closed by peer
338+
LOG.debug('%s connection closed by peer %s: %s', conn.name, conn.address, e)
339+
self._unregister_conn(conn)
340+
return
246341

247-
next_conn = conn.redirect_conn
248-
if next_conn and next_conn.out_bytes:
249-
try:
250-
while next_conn.out_bytes:
251-
LOG.debug('sending to %s:\n%s', next_conn.name, next_conn.out_bytes)
252-
sent = next_conn.sock.send(next_conn.out_bytes)
253-
next_conn.sent(sent)
254-
except OSError:
255-
# If one side is closed, close the other one
256-
# this can happen in the case where the client disconnects, and postgres still return a response
257-
# we then read the response then close the PG side of the socket.
258-
LOG.debug('error sending to %s: connection closed', next_conn.name)
259-
self._unregister_conn(conn)
260-
sock.close()
342+
next_conn = conn.redirect_conn
343+
if next_conn and next_conn.out_bytes:
344+
self._flush_outgoing(conn, sock, next_conn)
345+
346+
if mask & selectors.EVENT_WRITE:
347+
self._flush_outgoing(conn, sock, conn)
261348

262349
def listen(self, max_connections: int = 8):
263350
"""

0 commit comments

Comments
 (0)