From 7c513572d97d00b005a92b9a1fef2bce2c9e2182 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Sch=C3=B6nfeld?= Date: Tue, 15 Mar 2016 11:41:22 +0100 Subject: [PATCH 1/2] Use poll() instead of select() if available select() doesn't work with file-descriptiors > FD_SETSIZE, which is 1024 per default. In long-running processes that use a lot of FDs, this limit can be reached quickly, causing the calls to select() to fail. poll() doesn't have this restriction, while it is available on most systems. Therefore, it is used over select() if possible. On systems where it isn't available (esp. older versions of windows), the code falls back to utilize select(). --- apns.py | 33 +++++++++++++++++++++++++-------- tests.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/apns.py b/apns.py index ec52fd2..04c157b 100644 --- a/apns.py +++ b/apns.py @@ -88,10 +88,28 @@ WAIT_WRITE_TIMEOUT_SEC = 10 WAIT_READ_TIMEOUT_SEC = 10 WRITE_RETRY = 3 +WAIT_READ = 1 +WAIT_WRITE = 2 ER_STATUS = 'status' ER_IDENTIFER = 'identifier' + +def _wait_for_socket(sock, direction, timeout=None): + try: + poll = select.poll() + poll.register(sock, select.POLLIN if direction == WAIT_READ else select.POLLOUT) + if timeout: + timeout *= 1000 + events = poll.poll(timeout) + return bool(events) + except AttributeError: # fallback for systems not supporting poll() + rlist = [sock] if direction == WAIT_READ else [] + wlist = [sock] if direction == WAIT_WRITE else [] + rlist, wlist, _ = select.select(rlist, wlist, [], timeout) + return bool(rlist or wlist) + + class APNs(object): """A class representing an Apple Push Notification service connection""" @@ -217,9 +235,9 @@ def _connect(self): break except ssl.SSLError as err: if ssl.SSL_ERROR_WANT_READ == err.args[0]: - select.select([self._ssl], [], []) + _wait_for_socket(self._ssl, WAIT_READ) elif ssl.SSL_ERROR_WANT_WRITE == err.args[0]: - select.select([], [self._ssl], []) + _wait_for_socket(self._ssl, WAIT_WRITE) else: raise @@ -260,9 +278,9 @@ def read(self, n=None): def write(self, string): if self.enhanced: # nonblocking socket self._last_activity_time = time.time() - _, wlist, _ = select.select([], [self._connection()], [], WAIT_WRITE_TIMEOUT_SEC) - - if len(wlist) > 0: + writeable = _wait_for_socket(self._connection(), WAIT_WRITE, WAIT_WRITE_TIMEOUT_SEC) + + if writeable: length = self._connection().sendall(string) if length == 0: _logger.debug("sent length: %d" % length) #DEBUG @@ -595,9 +613,8 @@ def run(self): continue try: - rlist, _, _ = select.select([self._apns_connection._connection()], [], [], WAIT_READ_TIMEOUT_SEC) - - if len(rlist) > 0: # there's some data from APNs + readable = _wait_for_socket(self._apns_connection._connection(), WAIT_READ, WAIT_READ_TIMEOUT_SEC) + if readable: # there's some data from APNs with self._apns_connection._send_lock: buff = self._apns_connection.read(ERROR_RESPONSE_LENGTH) if len(buff) == ERROR_RESPONSE_LENGTH: diff --git a/tests.py b/tests.py index fb17a54..d66ef96 100644 --- a/tests.py +++ b/tests.py @@ -1,8 +1,12 @@ #!/usr/bin/env python # coding: utf-8 +from contextlib import contextmanager + from apns import * +from apns import _wait_for_socket from binascii import a2b_hex from random import random +import socket import hashlib import os @@ -209,5 +213,54 @@ def testPayloadTooLargeError(self): self.assertRaises(PayloadTooLargeError, Payload, u'\u0100' * (int(max_raw_payload_bytes / 2) + 1)) + def testWaitForSocket(self): + @contextmanager + def assert_timing(expected, delta): + start = time.time() + yield + end = time.time() + self.assertAlmostEqual(expected, end - start, delta=delta) + + socket1, socket2 = socket.socketpair() + socket1.setblocking(False) + socket2.setblocking(False) + + # Nothing was written, therefore waiting for reading should time out + with assert_timing(1, 0.1): + result = _wait_for_socket(socket1, WAIT_READ, 1) + self.assertFalse(result) + + # Send-buffer is empty, waiting for write shouldn't block + with assert_timing(0, 0.1): + result = _wait_for_socket(socket1, WAIT_WRITE, 5) + self.assertTrue(result) + socket2.send('test') + + # We just sent something, reading on the other ending shouldn't block now + with assert_timing(0, 0.1): + result = _wait_for_socket(socket1, WAIT_READ, 5) + self.assertTrue(result) + self.assertEquals(socket1.recv(1024), 'test') + + # Fill up the write-buffer + try: + while socket1.send(1024 * 'a') == 1024: + continue + except socket.error: + pass + + # Waiting for write should block now + with assert_timing(1, 0.1): + result = _wait_for_socket(socket1, WAIT_WRITE, 1) + self.assertFalse(result) + + # Closed socket returns being readable + socket2.close() + with assert_timing(0, 0.1): + result = _wait_for_socket(socket1, WAIT_READ) + self.assertTrue(result) + + socket1.close() + if __name__ == '__main__': unittest.main() From f3e0e28da0637674b9451814a9a4738312a7f0e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Sch=C3=B6nfeld?= Date: Thu, 24 Mar 2016 10:06:48 +0100 Subject: [PATCH 2/2] fix tests on python 2.6 --- tests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests.py b/tests.py index d66ef96..3b15235 100644 --- a/tests.py +++ b/tests.py @@ -219,7 +219,9 @@ def assert_timing(expected, delta): start = time.time() yield end = time.time() - self.assertAlmostEqual(expected, end - start, delta=delta) + took = end - start + self.assertTrue(expected > took - delta / 2) + self.assertTrue(expected < took + delta / 2) socket1, socket2 = socket.socketpair() socket1.setblocking(False)