From e1efafa809557a92d3a4955c1f1475727bfa76a3 Mon Sep 17 00:00:00 2001 From: Arthur Gautier Date: Fri, 18 Oct 2024 16:05:40 -0700 Subject: [PATCH] adds a method to return expiration of the records This adds a `lookup_expiry` helper method that will return the validity period of the records that were just returned. This is convenient when the consumer of srvlookup is a long running process. The records should only be considered authoritative until their expiration. --- srvlookup/__init__.py | 2 +- srvlookup/main.py | 43 +++++++++++++++++++++++++++++++++++++---- tests/test_srvlookup.py | 21 ++++++++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/srvlookup/__init__.py b/srvlookup/__init__.py index abb11da..b2aa7c7 100644 --- a/srvlookup/__init__.py +++ b/srvlookup/__init__.py @@ -2,7 +2,7 @@ Service lookup using DNS SRV records """ -from .main import lookup, SRV, SRVQueryFailure +from .main import lookup, lookup_expiry, SRV, SRVQueryFailure __all__ = [ 'lookup', 'SRV', 'SRVQueryFailure' diff --git a/srvlookup/main.py b/srvlookup/main.py index f2a2e08..41501ec 100644 --- a/srvlookup/main.py +++ b/srvlookup/main.py @@ -25,6 +25,44 @@ def __str__(self) -> str: return 'SRV query failure: %s' % self.args[0] +def lookup_expiry(name: str, protocol: str = 'TCP', + domain: typing.Optional[str] = None, + tcp_resolver: bool = False) -> tuple[typing.List[SRV], float]: + """Return a list of service records and associated data for the given + service name, protocol and optional domain. If protocol is not specified, + TCP will be used. If domain is not specified, the domain name returned by + the operating system will be used. + + This also returns the timestamp of the expiration of the records. + + Service records will be returned as a named tuple with host, port, priority + and weight attributes: + + >>> import srvlookup + >>> srvlookup.lookup_expiry('api', 'memcached') + ([SRV(host='192.169.1.100', port=11211, priority=1, weight=0, + hostname='host1.example.com'), + SRV(host='192.168.1.102', port=11211, priority=1, weight=0, + hostname='host2.example.com'), + SRV(host='192.168.1.120', port=11211, priority=1, weight=0, + hostname='host3.example.com'), + SRV(host='192.168.1.126', port=11211, priority=1, weight=0, + hostname='host4.example.com')], + 1729296172.4126928) + >>> + + :param name: The service name + :param protocol: The protocol name, defaults to TCP + :param domain: The domain name to use, defaults to local domain name + + """ + answer = _query_srv_records( + f'_{name}._{protocol}.{domain or _get_domain()}', tcp_resolver) + results = _build_result_set(answer) + records = sorted(results, key=lambda r: (r.priority, -r.weight, r.host)) + return (records, answer.expiration) + + def lookup(name: str, protocol: str = 'TCP', domain: typing.Optional[str] = None, tcp_resolver: bool = False) -> typing.List[SRV]: @@ -53,10 +91,7 @@ def lookup(name: str, protocol: str = 'TCP', :param domain: The domain name to use, defaults to local domain name """ - answer = _query_srv_records( - f'_{name}._{protocol}.{domain or _get_domain()}', tcp_resolver) - results = _build_result_set(answer) - return sorted(results, key=lambda r: (r.priority, -r.weight, r.host)) + return lookup_expiry(name, protocol, domain, tcp_resolver)[0] def _get_domain() -> str: diff --git a/tests/test_srvlookup.py b/tests/test_srvlookup.py index 519fa0c..b853376 100644 --- a/tests/test_srvlookup.py +++ b/tests/test_srvlookup.py @@ -1,4 +1,5 @@ import unittest +from datetime import datetime from unittest import mock from dns import message, name, resolver @@ -93,6 +94,26 @@ def test_should_return_name_when_addt_record_is_missing(self): srvlookup.SRV('foo3.bar.baz', 11213, 3, 0, 'foo3.bar.baz') ]) + def test_should_return_a_list_of_records_with_expiry(self): + with mock.patch('dns.resolver.resolve') as query: + query_name = name.from_text('foo.bar.baz.') + msg = self.get_message() + answer = resolver.Answer(query_name, 33, 1, msg, msg.answer[0]) + query.return_value = answer + + (records, expiry) = srvlookup.lookup_expiry('foo', 'bar', 'baz') + self.assertEqual( + records, [ + srvlookup.SRV('1.2.3.5', 11212, 1, 0, 'foo2.bar.baz'), + srvlookup.SRV('1.2.3.4', 11211, 2, 0, 'foo1.bar.baz') + ]) + # Records are sent with an expiry of 0 (now) + # Because we capture the date after the packet was parsed, the resulting + # expiry should be: now() - 1 < expiry <= now() + expected_expiry = datetime.now().timestamp() + self.assertLessEqual(expiry, expected_expiry) + self.assertGreater(expiry, expected_expiry - 1) + class WhenInvokingGetDomain(unittest.TestCase):