11#include " resolve.hh"
22#include < arpa/inet.h>
33#include < netinet/in.h>
4+ #include < sys/socket.h>
45#include < algorithm>
56#include < cassert>
67#include < cstdint>
@@ -52,6 +53,20 @@ struct Zone {
5253 }
5354};
5455
56+ class TCPSocket {
57+ public:
58+ TCPSocket () {
59+ tcp_socket = socket (AF_INET, SOCK_STREAM, 0 );
60+ if (tcp_socket == -1 ) throw std::runtime_error (" Failed to create TCP socket" );
61+ }
62+ ~TCPSocket () { close (tcp_socket); }
63+
64+ operator int () const { return tcp_socket; }
65+
66+ private:
67+ int tcp_socket;
68+ };
69+
5570namespace {
5671const constexpr uint64_t MIN_QUERY_TIMEOUT_MS = 300 ;
5772const constexpr int MAX_QUERY_DEPTH = 20 ;
@@ -95,6 +110,24 @@ struct missing_referral_error : public std::runtime_error {
95110 missing_referral_error (std::string zone) : std::runtime_error(" Missing referral" ), zone(std::move(zone)) {}
96111};
97112
113+ // List of zones to ask when there is no information to guide zone selection.
114+ class SafetyBelt {
115+ public:
116+ SafetyBelt (const std::queue<std::shared_ptr<Zone>> &zones) : zones(zones) {}
117+
118+ std::shared_ptr<Zone> next () {
119+ while (!zones.empty ()) {
120+ auto zone = std::move (zones.front ());
121+ zones.pop ();
122+ if (zone != nullptr && !zone->is_being_resolved ) return zone;
123+ }
124+ return nullptr ;
125+ }
126+
127+ private:
128+ std::queue<std::shared_ptr<Zone>> zones;
129+ };
130+
98131// Check domain length, convert it to lowercase and fully qualify.
99132std::string fully_qualify_domain (const std::string &domain) {
100133 std::string fqd;
@@ -139,14 +172,26 @@ bool is_zone_closer(const std::string &sname, const std::string &old_zone, const
139172bool address_equals (struct sockaddr_in a, struct sockaddr_in b) {
140173 return a.sin_port == b.sin_port && a.sin_addr .s_addr == b.sin_addr .s_addr ;
141174}
175+
176+ void set_socket_timeout (int socket, uint64_t timeout_ms) {
177+ struct timeval tv;
178+ tv.tv_sec = timeout_ms / 1000 ;
179+ tv.tv_usec = (timeout_ms % 1000 ) * 1000 ;
180+
181+ if (setsockopt (socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof (tv)) != 0 || //
182+ setsockopt (socket, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof (tv)) != 0 ) {
183+ throw std::runtime_error (" Failed to set receive/send timeout" );
184+ }
185+ }
142186} // namespace
143187
144188Resolver::Resolver (const ResolverConfig &config)
145189 : query_timeout_ms(std::max(config.timeout_ms, MIN_QUERY_TIMEOUT_MS)),
146- udp_timeout_ms (query_timeout_ms / 3 ),
190+ net_timeout_ms (query_timeout_ms / 3 ),
147191 port(config.port),
148192 verbose(config.verbose),
149193 enable_rd(config.enable_rd),
194+ tcp(config.tcp),
150195 edns(config.edns),
151196 dnssec(config.dnssec),
152197 cookies(config.cookies),
@@ -161,8 +206,9 @@ Resolver::Resolver(const ResolverConfig &config)
161206 }
162207 if (dnssec == FeatureState::Require || cookies == FeatureState::Require) edns = FeatureState::Require;
163208
164- fd = socket (AF_INET, SOCK_DGRAM, 0 );
165- if (fd == -1 ) throw std::runtime_error (" Failed to create UDP socket" );
209+ udp_socket = socket (AF_INET, SOCK_DGRAM, 0 );
210+ if (udp_socket == -1 ) throw std::runtime_error (" Failed to create UDP socket" );
211+ set_socket_timeout (udp_socket, net_timeout_ms);
166212
167213 // While the root NS (in . zone) are signed, their addresses (in root-servers.net zone) aren't.
168214 // Load the unsigned zone of the root nameservers, otherwise query will fail (due to no RRSIG).
@@ -175,7 +221,7 @@ Resolver::Resolver(const ResolverConfig &config)
175221 zones[root_zone->domain ] = root_zone;
176222}
177223
178- Resolver::~Resolver () { close (fd ); }
224+ Resolver::~Resolver () { close (udp_socket ); }
179225
180226std::optional<std::vector<RR>> Resolver::resolve (const std::string &qname, RRType qtype) {
181227 // DNSSEC is disabled for queries of type ANY.
@@ -186,23 +232,14 @@ std::optional<std::vector<RR>> Resolver::resolve(const std::string &qname, RRTyp
186232
187233 try {
188234 query_start = std::chrono::steady_clock::now ();
189- set_socket_timeout (udp_timeout_ms) ;
235+ query_time_left_ms = query_timeout_ms ;
190236 return resolve_rec (fully_qualify_domain (qname), qtype, 0 );
191237 } catch (const std::exception &e) {
192238 if (verbose) std::println (stderr, " Failed to resolve the domain: {}." , e.what ());
193239 return std::nullopt ;
194240 }
195241}
196242
197- std::shared_ptr<Zone> Resolver::SafetyBelt::next () {
198- while (!zones.empty ()) {
199- auto zone = std::move (zones.front ());
200- zones.pop ();
201- if (zone != nullptr && !zone->is_being_resolved ) return zone;
202- }
203- return nullptr ;
204- }
205-
206243std::queue<std::shared_ptr<Zone>> Resolver::init_safety_belt (const ResolverConfig &config) const {
207244 std::queue<std::shared_ptr<Zone>> zones;
208245
@@ -304,31 +341,20 @@ void Resolver::zone_disable_dnssec(Zone &zone) const {
304341 zone.enable_dnssec = false ;
305342}
306343
307- void Resolver::set_socket_timeout (uint64_t timeout_ms) const {
308- struct timeval tv;
309- tv.tv_sec = timeout_ms / 1000 ;
310- tv.tv_usec = (timeout_ms % 1000 ) * 1000 ;
311-
312- if (setsockopt (fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof (tv)) != 0 || //
313- setsockopt (fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof (tv)) != 0 ) {
314- throw std::runtime_error (" Failed to set receive/send timeout" );
315- }
316- }
317-
318- void Resolver::update_timeout () {
344+ void Resolver::update_timeout (int socket) {
319345 using namespace std ::chrono;
320346
321347 auto query_duration_ms = duration_cast<duration<uint64_t , std::milli>>(steady_clock::now () - query_start).count ();
322348 if (query_duration_ms >= query_timeout_ms) throw query_timeout_error ();
323349
324- auto time_left_ms = query_timeout_ms - query_duration_ms;
325- if (time_left_ms < udp_timeout_ms ) set_socket_timeout (time_left_ms );
350+ query_time_left_ms = query_timeout_ms - query_duration_ms;
351+ if (query_time_left_ms < net_timeout_ms ) set_socket_timeout (socket, query_time_left_ms );
326352}
327353
328354void Resolver::udp_send (const std::vector<uint8_t > &buffer, struct sockaddr_in address) {
329355 auto *socket_address = reinterpret_cast <struct sockaddr *>(&address);
330- auto result = sendto (fd , buffer.data (), buffer.size (), 0 , socket_address, sizeof (address));
331- update_timeout ();
356+ auto result = sendto (udp_socket , buffer.data (), buffer.size (), 0 , socket_address, sizeof (address));
357+ update_timeout (udp_socket );
332358 if (result == -1 && errno == EAGAIN) throw std::runtime_error (" Request timed out" );
333359 if (result != static_cast <ssize_t >(buffer.size ())) throw std::runtime_error (" Failed to send the request" );
334360}
@@ -341,8 +367,8 @@ void Resolver::udp_receive(std::vector<uint8_t> &buffer, struct sockaddr_in requ
341367 // Read responses until we find the one from the same address and port as in request.
342368 do {
343369 address_length = sizeof (address);
344- result = recvfrom (fd , buffer.data (), buffer.size (), 0 , socket_address, &address_length);
345- update_timeout ();
370+ result = recvfrom (udp_socket , buffer.data (), buffer.size (), 0 , socket_address, &address_length);
371+ update_timeout (udp_socket );
346372 if (result == -1 ) {
347373 if (errno == EAGAIN) throw std::runtime_error (" Response timed out" );
348374 throw std::runtime_error (" Failed to receive the response" );
@@ -351,6 +377,44 @@ void Resolver::udp_receive(std::vector<uint8_t> &buffer, struct sockaddr_in requ
351377 buffer.resize (result);
352378}
353379
380+ void Resolver::tcp_connect (const TCPSocket &tcp_socket, struct sockaddr_in address) {
381+ auto *socket_address = reinterpret_cast <struct sockaddr *>(&address);
382+ auto result = connect (tcp_socket, socket_address, sizeof (address));
383+ update_timeout (tcp_socket);
384+ if (result != 0 ) {
385+ if (errno == EAGAIN) throw std::runtime_error (" Connect timed out" );
386+ throw std::runtime_error (" Failed to connect" );
387+ }
388+ }
389+
390+ void Resolver::tcp_send (const TCPSocket &tcp_socket, const std::vector<uint8_t > &buffer) {
391+ auto message_size_net = htons (buffer.size ());
392+ auto bytes_sent = send (tcp_socket, &message_size_net, sizeof (message_size_net), 0 );
393+ update_timeout (tcp_socket);
394+ if (bytes_sent == -1 && errno == EAGAIN) throw std::runtime_error (" Request timed out" );
395+ if (bytes_sent != sizeof (message_size_net)) throw std::runtime_error (" Failed to send the message size" );
396+
397+ bytes_sent = send (tcp_socket, buffer.data (), buffer.size (), 0 );
398+ update_timeout (tcp_socket);
399+ if (bytes_sent == -1 && errno == EAGAIN) throw std::runtime_error (" Request timed out" );
400+ if (bytes_sent != static_cast <ssize_t >(buffer.size ())) throw std::runtime_error (" Failed to send the request" );
401+ }
402+
403+ void Resolver::tcp_receive (const TCPSocket &tcp_socket, std::vector<uint8_t > &buffer) {
404+ uint16_t message_size_net;
405+ auto bytes_received = recv (tcp_socket, &message_size_net, sizeof (message_size_net), MSG_WAITALL);
406+ update_timeout (tcp_socket);
407+ if (bytes_received == -1 && errno == EAGAIN) throw std::runtime_error (" Response timed out" );
408+ if (bytes_received != sizeof (message_size_net)) throw std::runtime_error (" Failed to receive the message size" );
409+ auto message_size = ntohs (message_size_net);
410+
411+ buffer.resize (message_size);
412+ bytes_received = recv (tcp_socket, buffer.data (), message_size, MSG_WAITALL);
413+ update_timeout (tcp_socket);
414+ if (bytes_received == -1 && errno == EAGAIN) throw std::runtime_error (" Response timed out" );
415+ if (bytes_received != message_size) throw std::runtime_error (" Failed to receive the response" );
416+ }
417+
354418std::vector<RR> Resolver::get_unauthenticated_rrset (std::vector<RR> &rrset, RRType rr_type) {
355419 if (rr_type == RRType::ANY) return rrset;
356420
@@ -464,6 +528,39 @@ std::vector<RR> Resolver::get_rrset(std::vector<RR> &rrset, RRType rr_type, cons
464528 return result;
465529}
466530
531+ Response Resolver::send_request (std::vector<uint8_t > &buffer, const std::string &qname, RRType qtype,
532+ Nameserver &nameserver, const Zone &zone, bool use_tcp) {
533+ auto payload_size
534+ = nameserver.udp_payload_size .value_or (zone.enable_edns ? EDNS_UDP_PAYLOAD_SIZE : STANDARD_UDP_PAYLOAD_SIZE);
535+
536+ // Write and send the request.
537+ buffer.reserve (payload_size);
538+ buffer.clear ();
539+ auto id = write_request (buffer, payload_size, qname, qtype, enable_rd, zone.enable_edns , zone.enable_dnssec ,
540+ zone.enable_cookies , nameserver.cookies );
541+
542+ struct sockaddr_in address;
543+ address.sin_family = AF_INET;
544+ address.sin_port = htons (port);
545+ address.sin_addr .s_addr = std::get<in_addr_t >(nameserver.address );
546+
547+ if (use_tcp) {
548+ TCPSocket tcp_socket;
549+ set_socket_timeout (tcp_socket, std::min (net_timeout_ms, query_time_left_ms));
550+ tcp_connect (tcp_socket, address);
551+ tcp_send (tcp_socket, buffer);
552+ tcp_receive (tcp_socket, buffer);
553+ } else {
554+ udp_send (buffer, address);
555+
556+ // Ensure buffer is big enough to receive the response.
557+ buffer.resize (payload_size);
558+ udp_receive (buffer, address);
559+ }
560+
561+ return read_response (buffer, id, qname, qtype);
562+ }
563+
467564std::optional<std::vector<RR>> Resolver::resolve_rec (const std::string &qname, RRType qtype, int depth,
468565 std::shared_ptr<Zone> search_zone) {
469566 if (depth >= MAX_QUERY_DEPTH) throw std::runtime_error (" Query is too deep" );
@@ -472,10 +569,6 @@ std::optional<std::vector<RR>> Resolver::resolve_rec(const std::string &qname, R
472569 std::string sname{qname};
473570 SafetyBelt safety_belt{safety_belt_zones};
474571
475- struct sockaddr_in address;
476- address.sin_family = AF_INET;
477- address.sin_port = htons (port);
478-
479572 // Choose the initial zone.
480573 std::shared_ptr<Zone> next_zone;
481574 if (search_zone != nullptr ) {
@@ -536,32 +629,23 @@ std::optional<std::vector<RR>> Resolver::resolve_rec(const std::string &qname, R
536629 nameserver->address = a_rrset[0 ].address ;
537630 for (size_t j = 1 ; j < a_rrset.size (); j++) zone->add_nameserver (a_rrset[j].address );
538631 }
539- address.sin_addr .s_addr = std::get<in_addr_t >(nameserver->address );
540632
541633 if (verbose) {
542634 char ip_addr_buf[INET_ADDRSTRLEN];
543- const auto *address_str = inet_ntop (AF_INET, &address.sin_addr , ip_addr_buf, sizeof (ip_addr_buf));
635+ auto addr = std::get<in_addr_t >(nameserver->address );
636+ const auto *address_str = inet_ntop (AF_INET, &addr, ip_addr_buf, sizeof (ip_addr_buf));
544637 if (address_str == nullptr ) address_str = " invalid address" ;
545638 std::println (" Resolving \" {}\" using {} ({})" , sname, address_str, zone->domain );
546639 }
547640
548- auto payload_size = nameserver->udp_payload_size .value_or (
549- zone->enable_edns ? EDNS_UDP_PAYLOAD_SIZE : STANDARD_UDP_PAYLOAD_SIZE);
550-
551- // Write and send the request.
552- buffer.reserve (payload_size);
553- buffer.clear ();
554- auto id = write_request (buffer, payload_size, sname, qtype, enable_rd, zone->enable_edns ,
555- zone->enable_dnssec , zone->enable_cookies , nameserver->cookies );
556- udp_send (buffer, address);
557-
558- // Ensure buffer is big enough to receive the response.
559- buffer.resize (payload_size);
560- udp_receive (buffer, address);
561- auto response = read_response (buffer, id, sname, qtype);
562- auto rcode = response.rcode ;
641+ auto response = send_request (buffer, sname, qtype, *nameserver, *zone, tcp == FeatureState::Require);
642+ if (response.is_truncated && tcp == FeatureState::Enable) {
643+ response = send_request (buffer, sname, qtype, *nameserver, *zone, true );
644+ }
645+ if (response.is_truncated ) throw std::runtime_error (" Response is truncated" );
563646
564647 // Handle OPT record.
648+ auto rcode = response.rcode ;
565649 if (zone->enable_edns ) {
566650 std::vector<RR> opt_rrset = get_unauthenticated_rrset (response.additional , RRType::OPT);
567651 if (opt_rrset.size () == 1 ) {
0 commit comments