Skip to content

adds a method to return expiration of the records #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion srvlookup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Service lookup using DNS SRV records

"""
from .main import lookup, SRV, SRVQueryFailure
from .main import lookup, lookup_expiry, SRV, SRVQueryFailure

__all__ = [
'lookup', 'SRV', 'SRVQueryFailure'
Expand Down
43 changes: 39 additions & 4 deletions srvlookup/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,44 @@ def __str__(self) -> str:
return 'SRV query failure: %s' % self.args[0]


def lookup_expiry(name: str, protocol: str = 'TCP',
domain: typing.Optional[str] = None,
tcp_resolver: bool = False) -> tuple[typing.List[SRV], float]:
"""Return a list of service records and associated data for the given
service name, protocol and optional domain. If protocol is not specified,
TCP will be used. If domain is not specified, the domain name returned by
the operating system will be used.

This also returns the timestamp of the expiration of the records.

Service records will be returned as a named tuple with host, port, priority
and weight attributes:

>>> import srvlookup
>>> srvlookup.lookup_expiry('api', 'memcached')
([SRV(host='192.169.1.100', port=11211, priority=1, weight=0,
hostname='host1.example.com'),
SRV(host='192.168.1.102', port=11211, priority=1, weight=0,
hostname='host2.example.com'),
SRV(host='192.168.1.120', port=11211, priority=1, weight=0,
hostname='host3.example.com'),
SRV(host='192.168.1.126', port=11211, priority=1, weight=0,
hostname='host4.example.com')],
1729296172.4126928)
>>>

:param name: The service name
:param protocol: The protocol name, defaults to TCP
:param domain: The domain name to use, defaults to local domain name

"""
answer = _query_srv_records(
f'_{name}._{protocol}.{domain or _get_domain()}', tcp_resolver)
results = _build_result_set(answer)
records = sorted(results, key=lambda r: (r.priority, -r.weight, r.host))
return (records, answer.expiration)


def lookup(name: str, protocol: str = 'TCP',
domain: typing.Optional[str] = None,
tcp_resolver: bool = False) -> typing.List[SRV]:
Expand Down Expand Up @@ -53,10 +91,7 @@ def lookup(name: str, protocol: str = 'TCP',
:param domain: The domain name to use, defaults to local domain name

"""
answer = _query_srv_records(
f'_{name}._{protocol}.{domain or _get_domain()}', tcp_resolver)
results = _build_result_set(answer)
return sorted(results, key=lambda r: (r.priority, -r.weight, r.host))
return lookup_expiry(name, protocol, domain, tcp_resolver)[0]


def _get_domain() -> str:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_srvlookup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from datetime import datetime
from unittest import mock

from dns import message, name, resolver
Expand Down Expand Up @@ -93,6 +94,26 @@ def test_should_return_name_when_addt_record_is_missing(self):
srvlookup.SRV('foo3.bar.baz', 11213, 3, 0, 'foo3.bar.baz')
])

def test_should_return_a_list_of_records_with_expiry(self):
with mock.patch('dns.resolver.resolve') as query:
query_name = name.from_text('foo.bar.baz.')
msg = self.get_message()
answer = resolver.Answer(query_name, 33, 1, msg, msg.answer[0])
query.return_value = answer

(records, expiry) = srvlookup.lookup_expiry('foo', 'bar', 'baz')
self.assertEqual(
records, [
srvlookup.SRV('1.2.3.5', 11212, 1, 0, 'foo2.bar.baz'),
srvlookup.SRV('1.2.3.4', 11211, 2, 0, 'foo1.bar.baz')
])
# Records are sent with an expiry of 0 (now)
# Because we capture the date after the packet was parsed, the resulting
# expiry should be: now() - 1 < expiry <= now()
expected_expiry = datetime.now().timestamp()
self.assertLessEqual(expiry, expected_expiry)
self.assertGreater(expiry, expected_expiry - 1)


class WhenInvokingGetDomain(unittest.TestCase):

Expand Down