From 4e7758ec974807ea9502acad45713ba0fa248cb9 Mon Sep 17 00:00:00 2001 From: Marcin Rataj Date: Wed, 3 Jun 2026 17:17:46 +0200 Subject: [PATCH] feat: add LookupTXTWithTTL to expose TXT record TTL LookupTXT discarded the TTL that doRequestTXT already computes from the DNS answer, so callers had no way to learn how long a TXT record set is valid. This blocks consumers (e.g. a gateway resolving DNSLink) from setting Cache-Control max-age based on the real DNS TTL. Add LookupTXTWithTTL which returns the records plus their TTL. LookupTXT delegates to it so behavior is unchanged. On a cache hit the remaining lifetime is returned so the value does not over-report as the entry ages, and it is capped by the resolver max cache TTL. A TTL of 0 means unknown. --- resolver.go | 31 ++++++++++++++++++++------- resolver_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/resolver.go b/resolver.go index a9bfe36..f7d67cc 100644 --- a/resolver.go +++ b/resolver.go @@ -124,19 +124,29 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) (result []ne } func (r *Resolver) LookupTXT(ctx context.Context, domain string) ([]string, error) { - result, ok := r.getCachedTXT(domain) - if ok { - return result, nil + result, _, err := r.LookupTXTWithTTL(ctx, domain) + return result, err +} + +// LookupTXTWithTTL is like [Resolver.LookupTXT] but also returns how long the +// TXT records may be cached. The TTL is the smallest Ttl across the answer's +// TXT resource records, capped by the resolver's max cache TTL. On a cache hit +// it is the remaining lifetime of the cached entry, so the value shrinks as the +// entry ages. A TTL of 0 means the TTL is unknown, for example when the +// upstream resolver does not provide one. +func (r *Resolver) LookupTXTWithTTL(ctx context.Context, domain string) ([]string, time.Duration, error) { + if result, ttl, ok := r.getCachedTXTWithTTL(domain); ok { + return result, ttl, nil } result, ttl, err := doRequestTXT(ctx, r.url, domain) if err != nil { - return nil, err + return nil, 0, err } cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL) r.cacheTXT(domain, result, cacheTTL) - return result, nil + return result, cacheTTL, nil } func (r *Resolver) getCachedIPAddr(domain string) ([]net.IPAddr, bool) { @@ -170,21 +180,26 @@ func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl time.Duratio } func (r *Resolver) getCachedTXT(domain string) ([]string, bool) { + txt, _, ok := r.getCachedTXTWithTTL(domain) + return txt, ok +} + +func (r *Resolver) getCachedTXTWithTTL(domain string) ([]string, time.Duration, bool) { r.mx.Lock() defer r.mx.Unlock() fqdn := dns.Fqdn(domain) entry, ok := r.txtCache[fqdn] if !ok { - return nil, false + return nil, 0, false } if time.Now().After(entry.expire) { delete(r.txtCache, fqdn) - return nil, false + return nil, 0, false } - return entry.txt, true + return entry.txt, time.Until(entry.expire), true } func (r *Resolver) cacheTXT(domain string, txt []string, ttl time.Duration) { diff --git a/resolver_test.go b/resolver_test.go index a58b1ff..4f8c188 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -147,6 +147,62 @@ func TestLookupTXT(t *testing.T) { } } +func TestLookupTXTWithTTL(t *testing.T) { + domain := "example.com" + resolver := mockDoHResolver(t, map[uint16]*dns.Msg{ + dns.TypeTXT: mockDNSAnswerTXT(dns.Fqdn(domain), []string{"dnslink=/ipns/example.com"}), + }) + defer resolver.Close() + + r, err := NewResolver(resolver.URL) + if err != nil { + t.Fatal("resolver cannot be initialised") + } + + // cold lookup returns the record TTL from the answer (300s in the mock) + txt, ttl, err := r.LookupTXTWithTTL(context.Background(), domain) + if err != nil { + t.Fatal(err) + } + if len(txt) == 0 { + t.Fatal("got no TXT entries") + } + if ttl != 300*time.Second { + t.Fatalf("expected ttl 300s, got %s", ttl) + } + + // warm lookup (cache hit) returns the remaining TTL, never more than the record TTL + _, ttl2, err := r.LookupTXTWithTTL(context.Background(), domain) + if err != nil { + t.Fatal(err) + } + if ttl2 <= 0 || ttl2 > 300*time.Second { + t.Fatalf("expected remaining ttl in (0s, 300s], got %s", ttl2) + } +} + +func TestLookupTXTWithTTLCappedByMaxCacheTTL(t *testing.T) { + domain := "example.com" + resolver := mockDoHResolver(t, map[uint16]*dns.Msg{ + dns.TypeTXT: mockDNSAnswerTXT(dns.Fqdn(domain), []string{"dnslink=/ipns/example.com"}), + }) + defer resolver.Close() + + // record TTL (300s) is larger than the max cache TTL, so the returned TTL is capped + r, err := NewResolver(resolver.URL, WithMaxCacheTTL(10*time.Second)) + if err != nil { + t.Fatal("resolver cannot be initialised") + } + + _, ttl, err := r.LookupTXTWithTTL(context.Background(), domain) + if err != nil { + t.Fatal(err) + } + if ttl != 10*time.Second { + t.Fatalf("expected ttl capped to 10s, got %s", ttl) + } +} + func TestLookupCache(t *testing.T) { domain := "example.com" resolver := mockDoHResolver(t, map[uint16]*dns.Msg{