From 3861996acde1edc2fb3335fbb61e2569a1aa2dc4 Mon Sep 17 00:00:00 2001
From: Eli Courtwright <eli@courtwright.org>
Date: Sun, 21 Oct 2018 23:23:08 -0400
Subject: [PATCH] added preliminary Python 3 support

---
 .gitignore    |  2 ++
 collectd.py   | 39 +++++++++++++++++++++++++++------------
 setup.py      |  2 +-
 unit_tests.py | 19 ++++++++++---------
 4 files changed, 40 insertions(+), 22 deletions(-)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a295864
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.pyc
+__pycache__
diff --git a/collectd.py b/collectd.py
index 2fee145..57561a6 100644
--- a/collectd.py
+++ b/collectd.py
@@ -1,14 +1,19 @@
 import re
+import sys
 import time
 import socket
 import struct
 import logging
 import traceback
 from functools import wraps
-from Queue import Queue, Empty
 from collections import defaultdict
 from threading import RLock, Thread, Semaphore
 
+try:
+    from Queue import Queue, Empty  # Python 2
+except ImportError:
+    from queue import Queue, Empty  # Python 3
+
 
 __all__ = ["Connection", "start_threads"]
 
@@ -17,6 +22,8 @@
 
 logger = logging.getLogger("collectd")
 
+StringTypes = (type(b""), type(u""))
+
 SEND_INTERVAL = 10      # seconds
 MAX_PACKET_SIZE = 1024  # bytes
 
@@ -46,20 +53,24 @@
 
 
 def pack_numeric(type_code, number):
-    return struct.pack("!HHq", type_code, 12, number)
+    return struct.pack("!HHq", type_code, 12, int(number))
 
 def pack_string(type_code, string):
-    return struct.pack("!HH", type_code, 5 + len(string)) + string + "\0"
+    if isinstance(string, type(u"")):
+        string = string.encode("UTF-8")
+    return struct.pack("!HH", type_code, 5 + len(string)) + string + b"\0"
 
 def pack_value(name, value):
-    return "".join([
+    if isinstance(value, type(u"")):
+        value = value.encode("UTF-8")
+    return b"".join([
         pack(TYPE_TYPE_INSTANCE, name),
         struct.pack("!HHH", TYPE_VALUES, 15, 1),
         struct.pack("<Bd", VALUE_GAUGE, value)
     ])
 
 def pack(id, value):
-    if isinstance(id, basestring):
+    if isinstance(id, StringTypes):
         return pack_value(id, value)
     elif id in LONG_INT_CODES:
         return pack_numeric(id, value)
@@ -69,7 +80,7 @@ def pack(id, value):
         raise AssertionError("invalid type code " + str(id))
 
 def message_start(when=None, host=socket.gethostname(), plugin_inst="", plugin_name="any"):
-    return "".join([
+    return b"".join([
         pack(TYPE_HOST, host),
         pack(TYPE_TIME, when or time.time()),
         pack(TYPE_PLUGIN, plugin_name),
@@ -87,16 +98,18 @@ def messages(counts, when=None, host=socket.gethostname(), plugin_inst="", plugi
         curr, curr_len = [start], len(start)
         for part in parts:
             if curr_len + len(part) > MAX_PACKET_SIZE:
-                packets.append("".join(curr))
+                packets.append(b"".join(curr))
                 curr, curr_len = [start], len(start)
             curr.append(part)
             curr_len += len(part)
-        packets.append("".join(curr))
+        packets.append(b"".join(curr))
     return packets
 
 
 
 def sanitize(s):
+    if sys.version_info.major == 3 and isinstance(s, bytes):
+        s = s.decode("UTF-8")
     return re.sub(r"[^a-zA-Z0-9]+", "_", s).strip("_")
 
 def swallow_errors(func):
@@ -127,18 +140,20 @@ def __init__(self, category):
     @swallow_errors
     @synchronized
     def record(self, *args, **kwargs):
-        for specific in list(args) + [""]:
-            assert isinstance(specific, basestring)
+        for specific in list(args) + [b""]:
+            assert isinstance(specific, StringTypes), str(type(specific))
+            if isinstance(specific, type(u"")):
+                specific = specific.encode("UTF-8")
             for stat, value in kwargs.items():
                 assert isinstance(value, (int, float))
-                self.counts[str(specific)][str(stat)] += value
+                self.counts[specific][stat] += value
     
     @swallow_errors
     @synchronized
     def set_exact(self, **kwargs):
         for stat, value in kwargs.items():
             assert isinstance(value, (int, float))
-            self.counts[""][str(stat)] = value
+            self.counts[b""][str(stat)] = value
     
     @synchronized
     def snapshot(self):
diff --git a/setup.py b/setup.py
index bec646e..7fa5ff1 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name = "collectd",
-    version = "1.0.2",
+    version = "1.0.3",
     py_modules = ["collectd"],
     
     author = "Eli Courtwright",
diff --git a/unit_tests.py b/unit_tests.py
index 88303b7..e8b3a7b 100644
--- a/unit_tests.py
+++ b/unit_tests.py
@@ -22,14 +22,15 @@ def assertValidPacket(self, expected_type_count, s):
                 self.assertEqual(size, 12)
                 struct.unpack("!q", s[4:12])
             elif type_code in collectd.STRING_CODES:
-                self.assertEqual(s[size-1], "\0")
+                self.assertIn(s[size-1], [0, "\0"])
                 struct.unpack(str(size-4) + "s", s[4:size])
             else:
                 self.assertEqual(type_code, collectd.TYPE_VALUES)
                 values = s[6:size]
                 count = 0
                 while values:
-                    value_code = struct.unpack("B", values[0])[0]
+                    val = values[0]
+                    value_code = struct.unpack("B", val)[0] if isinstance(val, collectd.StringTypes) else val
                     self.assertTrue(value_code in collectd.VALUE_CODES)
                     struct.unpack(collectd.VALUE_CODES[value_code], values[1:9])
                     values = values[9:]
@@ -277,22 +278,22 @@ def test_single(self):
         self.send_and_recv(foo = 5)
     
     def test_multiple(self):
-        stats = {"foo": 345352, "bar": -5023123}
+        stats = {u"foo": 345352, u"bar": -5023123}
         packet = self.send_and_recv(**stats)
         for name, value in stats.items():
-            self.assertTrue(name + "\0" in packet)
+            self.assertTrue(name.encode("UTF-8") + b"\0" in packet)
             self.assertTrue(struct.pack("<d", value) in packet)
             self.assertTrue(collectd.pack("test-"+name, value) in packet)
     
     def test_plugin_name(self):
         conn = collectd.Connection(collectd_port = self.TEST_PORT,
                                    plugin_name = "dckx")
-        self.assertTrue("dckx" in self.send_and_recv(conn, foo=5))
+        self.assertTrue(b"dckx" in self.send_and_recv(conn, foo=5))
 
     def test_plugin_inst(self):
         conn = collectd.Connection(collectd_port = self.TEST_PORT,
                                    plugin_inst = "xkcd")
-        self.assertTrue("xkcd" in self.send_and_recv(conn, foo=5))
+        self.assertTrue(b"xkcd" in self.send_and_recv(conn, foo=5))
     
     def test_unicode(self):
         self.send_and_recv(self.conn, u"foo.bar", hits = 1)
@@ -305,7 +306,7 @@ def test_too_large(self):
         collectd.send_stats(raise_on_empty = True)
         for name,val in stats:
             packet = self.server.recv(collectd.MAX_PACKET_SIZE)
-            self.assertTrue(name + "\0" in packet)
+            self.assertTrue(name.encode("UTF-8") + b"\0" in packet)
             self.assertTrue(struct.pack("<d", val) in packet)
             self.assertValidPacket(8, packet)
     
@@ -319,9 +320,9 @@ def test_too_many(self):
         for packet in packets:
             self.assertValidPacket(8, packet)
         
-        data = "".join(packets)
+        data = b"".join(packets)
         for name,val in stats:
-            self.assertTrue(name + "\0" in data)
+            self.assertTrue(name.encode("UTF-8") + b"\0" in data)
             self.assertTrue(struct.pack("<d", val) in data)