Skip to content

Commit 733b3f5

Browse files
committed
Add TCP support
1 parent 97b9dae commit 733b3f5

File tree

8 files changed

+190
-119
lines changed

8 files changed

+190
-119
lines changed

src/dns.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class ResponseReader {
8383
bool is_response = (flags >> 15) & 1;
8484
auto opcode = static_cast<OpCode>((flags >> 11) & 0b1111);
8585
response.is_authoritative = (flags >> 10) & 1;
86-
bool is_truncated = (flags >> 9) & 1;
86+
response.is_truncated = (flags >> 9) & 1;
8787
response.rcode = static_cast<RCode>(flags & 0b1111);
8888
auto question_count = read_u16();
8989
auto answer_count = read_u16();
@@ -94,7 +94,6 @@ class ResponseReader {
9494
if (id != request_id) throw std::runtime_error("Wrong response ID");
9595
if (!is_response) throw std::runtime_error("Response is a query");
9696
if (opcode != OpCode::Query) throw std::runtime_error("Response has wrong opcode");
97-
if (is_truncated) throw std::runtime_error("Response is truncated");
9897
if (question_count != 1) throw std::runtime_error("Wrong question count");
9998

10099
// Validate question.

src/dns.hh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ private:
322322

323323
struct Response {
324324
bool is_authoritative;
325+
bool is_truncated;
325326
RCode rcode;
326327
std::vector<RR> answers;
327328
std::vector<RR> authority;

src/main.cc

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,21 @@ int main(int argc, char **argv) {
5959
options.custom_help("[options]");
6060
options.positional_help("<domain>");
6161

62-
options.add_options() //
63-
("domain", "Domain name to resolve", cxxopts::value<std::string>()) //
64-
("h,help", "Print usage", cxxopts::value<bool>()->default_value("false")) //
65-
("s,server", "Nameserver domain or address", cxxopts::value<std::string>()) //
66-
("p,port", "Nameserver port", cxxopts::value<uint16_t>()->default_value("53")) //
67-
("t,type", "Query type", cxxopts::value<RRType>()->default_value("A")) //
68-
("T,timeout", "Timeout in seconds", cxxopts::value<uint64_t>()->default_value("5")) //
69-
("v,verbose", "Verbose output", cxxopts::value<bool>()->default_value("false")) //
70-
("rdflag", "Set recursion desired flag", cxxopts::value<bool>()->default_value("true")) //
71-
("edns", "EDNS", cxxopts::value<FeatureState>()->default_value("on")) //
72-
("dnssec", "DNSSEC", cxxopts::value<FeatureState>()->default_value("on")) //
73-
("cookies", "Cookies", cxxopts::value<FeatureState>()->default_value("on")) //
74-
("use-root", "Use root nameservers", cxxopts::value<bool>()->default_value("true")) //
75-
("use-config", "Use nameservers from /etc/resolv.conf", cxxopts::value<bool>()->default_value("true"));
62+
options.add_options() //
63+
("domain", "Domain name to resolve", cxxopts::value<std::string>()) //
64+
("h,help", "Print usage", cxxopts::value<bool>()->default_value("false")) //
65+
("s,server", "Nameserver domain or address", cxxopts::value<std::string>()) //
66+
("p,port", "Nameserver port", cxxopts::value<uint16_t>()->default_value("53")) //
67+
("t,type", "Query type", cxxopts::value<RRType>()->default_value("A")) //
68+
("T,timeout", "Timeout in seconds", cxxopts::value<uint64_t>()->default_value("10")) //
69+
("v,verbose", "Verbose output", cxxopts::value<bool>()->default_value("false")) //
70+
("rdflag", "Set recursion desired flag", cxxopts::value<bool>()->default_value("true")) //
71+
("tcp", "on = fallback, require = TCP-only", cxxopts::value<FeatureState>()->default_value("on")) //
72+
("edns", "EDNS", cxxopts::value<FeatureState>()->default_value("on")) //
73+
("dnssec", "DNSSEC", cxxopts::value<FeatureState>()->default_value("on")) //
74+
("cookies", "Cookies", cxxopts::value<FeatureState>()->default_value("on")) //
75+
("use-root", "Use root nameservers", cxxopts::value<bool>()->default_value("true")) //
76+
("use-config", "Use nameservers from /etc/resolv.conf", cxxopts::value<bool>()->default_value("false"));
7677
options.parse_positional({"domain"});
7778

7879
auto result = options.parse(argc, argv);
@@ -99,6 +100,7 @@ int main(int argc, char **argv) {
99100
.port = result["port"].as<uint16_t>(),
100101
.verbose = result["verbose"].as<bool>(),
101102
.enable_rd = result["rdflag"].as<bool>(),
103+
.tcp = result["tcp"].as<FeatureState>(),
102104
.edns = result["edns"].as<FeatureState>(),
103105
.dnssec = result["dnssec"].as<FeatureState>(),
104106
.cookies = result["cookies"].as<FeatureState>(),

src/resolve.cc

Lines changed: 137 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "resolve.hh"
22
#include <arpa/inet.h>
33
#include <netinet/in.h>
4+
#include <sys/socket.h>
45
#include <algorithm>
56
#include <cassert>
67
#include <cstdint>
@@ -52,6 +53,20 @@ struct Zone {
5253
}
5354
};
5455

56+
class TCPSocket {
57+
public:
58+
TCPSocket() {
59+
tcp_socket = socket(AF_INET, SOCK_STREAM, 0);
60+
if (tcp_socket == -1) throw std::runtime_error("Failed to create TCP socket");
61+
}
62+
~TCPSocket() { close(tcp_socket); }
63+
64+
operator int() const { return tcp_socket; }
65+
66+
private:
67+
int tcp_socket;
68+
};
69+
5570
namespace {
5671
const constexpr uint64_t MIN_QUERY_TIMEOUT_MS = 300;
5772
const constexpr int MAX_QUERY_DEPTH = 20;
@@ -95,6 +110,24 @@ struct missing_referral_error : public std::runtime_error {
95110
missing_referral_error(std::string zone) : std::runtime_error("Missing referral"), zone(std::move(zone)) {}
96111
};
97112

113+
// List of zones to ask when there is no information to guide zone selection.
114+
class SafetyBelt {
115+
public:
116+
SafetyBelt(const std::queue<std::shared_ptr<Zone>> &zones) : zones(zones) {}
117+
118+
std::shared_ptr<Zone> next() {
119+
while (!zones.empty()) {
120+
auto zone = std::move(zones.front());
121+
zones.pop();
122+
if (zone != nullptr && !zone->is_being_resolved) return zone;
123+
}
124+
return nullptr;
125+
}
126+
127+
private:
128+
std::queue<std::shared_ptr<Zone>> zones;
129+
};
130+
98131
// Check domain length, convert it to lowercase and fully qualify.
99132
std::string fully_qualify_domain(const std::string &domain) {
100133
std::string fqd;
@@ -139,14 +172,26 @@ bool is_zone_closer(const std::string &sname, const std::string &old_zone, const
139172
bool address_equals(struct sockaddr_in a, struct sockaddr_in b) {
140173
return a.sin_port == b.sin_port && a.sin_addr.s_addr == b.sin_addr.s_addr;
141174
}
175+
176+
void set_socket_timeout(int socket, uint64_t timeout_ms) {
177+
struct timeval tv;
178+
tv.tv_sec = timeout_ms / 1000;
179+
tv.tv_usec = (timeout_ms % 1000) * 1000;
180+
181+
if (setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) != 0 || //
182+
setsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) != 0) {
183+
throw std::runtime_error("Failed to set receive/send timeout");
184+
}
185+
}
142186
} // namespace
143187

144188
Resolver::Resolver(const ResolverConfig &config)
145189
: query_timeout_ms(std::max(config.timeout_ms, MIN_QUERY_TIMEOUT_MS)),
146-
udp_timeout_ms(query_timeout_ms / 3),
190+
net_timeout_ms(query_timeout_ms / 3),
147191
port(config.port),
148192
verbose(config.verbose),
149193
enable_rd(config.enable_rd),
194+
tcp(config.tcp),
150195
edns(config.edns),
151196
dnssec(config.dnssec),
152197
cookies(config.cookies),
@@ -161,8 +206,9 @@ Resolver::Resolver(const ResolverConfig &config)
161206
}
162207
if (dnssec == FeatureState::Require || cookies == FeatureState::Require) edns = FeatureState::Require;
163208

164-
fd = socket(AF_INET, SOCK_DGRAM, 0);
165-
if (fd == -1) throw std::runtime_error("Failed to create UDP socket");
209+
udp_socket = socket(AF_INET, SOCK_DGRAM, 0);
210+
if (udp_socket == -1) throw std::runtime_error("Failed to create UDP socket");
211+
set_socket_timeout(udp_socket, net_timeout_ms);
166212

167213
// While the root NS (in . zone) are signed, their addresses (in root-servers.net zone) aren't.
168214
// Load the unsigned zone of the root nameservers, otherwise query will fail (due to no RRSIG).
@@ -175,7 +221,7 @@ Resolver::Resolver(const ResolverConfig &config)
175221
zones[root_zone->domain] = root_zone;
176222
}
177223

178-
Resolver::~Resolver() { close(fd); }
224+
Resolver::~Resolver() { close(udp_socket); }
179225

180226
std::optional<std::vector<RR>> Resolver::resolve(const std::string &qname, RRType qtype) {
181227
// DNSSEC is disabled for queries of type ANY.
@@ -186,23 +232,14 @@ std::optional<std::vector<RR>> Resolver::resolve(const std::string &qname, RRTyp
186232

187233
try {
188234
query_start = std::chrono::steady_clock::now();
189-
set_socket_timeout(udp_timeout_ms);
235+
query_time_left_ms = query_timeout_ms;
190236
return resolve_rec(fully_qualify_domain(qname), qtype, 0);
191237
} catch (const std::exception &e) {
192238
if (verbose) std::println(stderr, "Failed to resolve the domain: {}.", e.what());
193239
return std::nullopt;
194240
}
195241
}
196242

197-
std::shared_ptr<Zone> Resolver::SafetyBelt::next() {
198-
while (!zones.empty()) {
199-
auto zone = std::move(zones.front());
200-
zones.pop();
201-
if (zone != nullptr && !zone->is_being_resolved) return zone;
202-
}
203-
return nullptr;
204-
}
205-
206243
std::queue<std::shared_ptr<Zone>> Resolver::init_safety_belt(const ResolverConfig &config) const {
207244
std::queue<std::shared_ptr<Zone>> zones;
208245

@@ -304,31 +341,20 @@ void Resolver::zone_disable_dnssec(Zone &zone) const {
304341
zone.enable_dnssec = false;
305342
}
306343

307-
void Resolver::set_socket_timeout(uint64_t timeout_ms) const {
308-
struct timeval tv;
309-
tv.tv_sec = timeout_ms / 1000;
310-
tv.tv_usec = (timeout_ms % 1000) * 1000;
311-
312-
if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) != 0 || //
313-
setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) != 0) {
314-
throw std::runtime_error("Failed to set receive/send timeout");
315-
}
316-
}
317-
318-
void Resolver::update_timeout() {
344+
void Resolver::update_timeout(int socket) {
319345
using namespace std::chrono;
320346

321347
auto query_duration_ms = duration_cast<duration<uint64_t, std::milli>>(steady_clock::now() - query_start).count();
322348
if (query_duration_ms >= query_timeout_ms) throw query_timeout_error();
323349

324-
auto time_left_ms = query_timeout_ms - query_duration_ms;
325-
if (time_left_ms < udp_timeout_ms) set_socket_timeout(time_left_ms);
350+
query_time_left_ms = query_timeout_ms - query_duration_ms;
351+
if (query_time_left_ms < net_timeout_ms) set_socket_timeout(socket, query_time_left_ms);
326352
}
327353

328354
void Resolver::udp_send(const std::vector<uint8_t> &buffer, struct sockaddr_in address) {
329355
auto *socket_address = reinterpret_cast<struct sockaddr *>(&address);
330-
auto result = sendto(fd, buffer.data(), buffer.size(), 0, socket_address, sizeof(address));
331-
update_timeout();
356+
auto result = sendto(udp_socket, buffer.data(), buffer.size(), 0, socket_address, sizeof(address));
357+
update_timeout(udp_socket);
332358
if (result == -1 && errno == EAGAIN) throw std::runtime_error("Request timed out");
333359
if (result != static_cast<ssize_t>(buffer.size())) throw std::runtime_error("Failed to send the request");
334360
}
@@ -341,8 +367,8 @@ void Resolver::udp_receive(std::vector<uint8_t> &buffer, struct sockaddr_in requ
341367
// Read responses until we find the one from the same address and port as in request.
342368
do {
343369
address_length = sizeof(address);
344-
result = recvfrom(fd, buffer.data(), buffer.size(), 0, socket_address, &address_length);
345-
update_timeout();
370+
result = recvfrom(udp_socket, buffer.data(), buffer.size(), 0, socket_address, &address_length);
371+
update_timeout(udp_socket);
346372
if (result == -1) {
347373
if (errno == EAGAIN) throw std::runtime_error("Response timed out");
348374
throw std::runtime_error("Failed to receive the response");
@@ -351,6 +377,44 @@ void Resolver::udp_receive(std::vector<uint8_t> &buffer, struct sockaddr_in requ
351377
buffer.resize(result);
352378
}
353379

380+
void Resolver::tcp_connect(const TCPSocket &tcp_socket, struct sockaddr_in address) {
381+
auto *socket_address = reinterpret_cast<struct sockaddr *>(&address);
382+
auto result = connect(tcp_socket, socket_address, sizeof(address));
383+
update_timeout(tcp_socket);
384+
if (result != 0) {
385+
if (errno == EAGAIN) throw std::runtime_error("Connect timed out");
386+
throw std::runtime_error("Failed to connect");
387+
}
388+
}
389+
390+
void Resolver::tcp_send(const TCPSocket &tcp_socket, const std::vector<uint8_t> &buffer) {
391+
auto message_size_net = htons(buffer.size());
392+
auto bytes_sent = send(tcp_socket, &message_size_net, sizeof(message_size_net), 0);
393+
update_timeout(tcp_socket);
394+
if (bytes_sent == -1 && errno == EAGAIN) throw std::runtime_error("Request timed out");
395+
if (bytes_sent != sizeof(message_size_net)) throw std::runtime_error("Failed to send the message size");
396+
397+
bytes_sent = send(tcp_socket, buffer.data(), buffer.size(), 0);
398+
update_timeout(tcp_socket);
399+
if (bytes_sent == -1 && errno == EAGAIN) throw std::runtime_error("Request timed out");
400+
if (bytes_sent != static_cast<ssize_t>(buffer.size())) throw std::runtime_error("Failed to send the request");
401+
}
402+
403+
void Resolver::tcp_receive(const TCPSocket &tcp_socket, std::vector<uint8_t> &buffer) {
404+
uint16_t message_size_net;
405+
auto bytes_received = recv(tcp_socket, &message_size_net, sizeof(message_size_net), MSG_WAITALL);
406+
update_timeout(tcp_socket);
407+
if (bytes_received == -1 && errno == EAGAIN) throw std::runtime_error("Response timed out");
408+
if (bytes_received != sizeof(message_size_net)) throw std::runtime_error("Failed to receive the message size");
409+
auto message_size = ntohs(message_size_net);
410+
411+
buffer.resize(message_size);
412+
bytes_received = recv(tcp_socket, buffer.data(), message_size, MSG_WAITALL);
413+
update_timeout(tcp_socket);
414+
if (bytes_received == -1 && errno == EAGAIN) throw std::runtime_error("Response timed out");
415+
if (bytes_received != message_size) throw std::runtime_error("Failed to receive the response");
416+
}
417+
354418
std::vector<RR> Resolver::get_unauthenticated_rrset(std::vector<RR> &rrset, RRType rr_type) {
355419
if (rr_type == RRType::ANY) return rrset;
356420

@@ -464,6 +528,39 @@ std::vector<RR> Resolver::get_rrset(std::vector<RR> &rrset, RRType rr_type, cons
464528
return result;
465529
}
466530

531+
Response Resolver::send_request(std::vector<uint8_t> &buffer, const std::string &qname, RRType qtype,
532+
Nameserver &nameserver, const Zone &zone, bool use_tcp) {
533+
auto payload_size
534+
= nameserver.udp_payload_size.value_or(zone.enable_edns ? EDNS_UDP_PAYLOAD_SIZE : STANDARD_UDP_PAYLOAD_SIZE);
535+
536+
// Write and send the request.
537+
buffer.reserve(payload_size);
538+
buffer.clear();
539+
auto id = write_request(buffer, payload_size, qname, qtype, enable_rd, zone.enable_edns, zone.enable_dnssec,
540+
zone.enable_cookies, nameserver.cookies);
541+
542+
struct sockaddr_in address;
543+
address.sin_family = AF_INET;
544+
address.sin_port = htons(port);
545+
address.sin_addr.s_addr = std::get<in_addr_t>(nameserver.address);
546+
547+
if (use_tcp) {
548+
TCPSocket tcp_socket;
549+
set_socket_timeout(tcp_socket, std::min(net_timeout_ms, query_time_left_ms));
550+
tcp_connect(tcp_socket, address);
551+
tcp_send(tcp_socket, buffer);
552+
tcp_receive(tcp_socket, buffer);
553+
} else {
554+
udp_send(buffer, address);
555+
556+
// Ensure buffer is big enough to receive the response.
557+
buffer.resize(payload_size);
558+
udp_receive(buffer, address);
559+
}
560+
561+
return read_response(buffer, id, qname, qtype);
562+
}
563+
467564
std::optional<std::vector<RR>> Resolver::resolve_rec(const std::string &qname, RRType qtype, int depth,
468565
std::shared_ptr<Zone> search_zone) {
469566
if (depth >= MAX_QUERY_DEPTH) throw std::runtime_error("Query is too deep");
@@ -472,10 +569,6 @@ std::optional<std::vector<RR>> Resolver::resolve_rec(const std::string &qname, R
472569
std::string sname{qname};
473570
SafetyBelt safety_belt{safety_belt_zones};
474571

475-
struct sockaddr_in address;
476-
address.sin_family = AF_INET;
477-
address.sin_port = htons(port);
478-
479572
// Choose the initial zone.
480573
std::shared_ptr<Zone> next_zone;
481574
if (search_zone != nullptr) {
@@ -536,32 +629,23 @@ std::optional<std::vector<RR>> Resolver::resolve_rec(const std::string &qname, R
536629
nameserver->address = a_rrset[0].address;
537630
for (size_t j = 1; j < a_rrset.size(); j++) zone->add_nameserver(a_rrset[j].address);
538631
}
539-
address.sin_addr.s_addr = std::get<in_addr_t>(nameserver->address);
540632

541633
if (verbose) {
542634
char ip_addr_buf[INET_ADDRSTRLEN];
543-
const auto *address_str = inet_ntop(AF_INET, &address.sin_addr, ip_addr_buf, sizeof(ip_addr_buf));
635+
auto addr = std::get<in_addr_t>(nameserver->address);
636+
const auto *address_str = inet_ntop(AF_INET, &addr, ip_addr_buf, sizeof(ip_addr_buf));
544637
if (address_str == nullptr) address_str = "invalid address";
545638
std::println("Resolving \"{}\" using {} ({})", sname, address_str, zone->domain);
546639
}
547640

548-
auto payload_size = nameserver->udp_payload_size.value_or(
549-
zone->enable_edns ? EDNS_UDP_PAYLOAD_SIZE : STANDARD_UDP_PAYLOAD_SIZE);
550-
551-
// Write and send the request.
552-
buffer.reserve(payload_size);
553-
buffer.clear();
554-
auto id = write_request(buffer, payload_size, sname, qtype, enable_rd, zone->enable_edns,
555-
zone->enable_dnssec, zone->enable_cookies, nameserver->cookies);
556-
udp_send(buffer, address);
557-
558-
// Ensure buffer is big enough to receive the response.
559-
buffer.resize(payload_size);
560-
udp_receive(buffer, address);
561-
auto response = read_response(buffer, id, sname, qtype);
562-
auto rcode = response.rcode;
641+
auto response = send_request(buffer, sname, qtype, *nameserver, *zone, tcp == FeatureState::Require);
642+
if (response.is_truncated && tcp == FeatureState::Enable) {
643+
response = send_request(buffer, sname, qtype, *nameserver, *zone, true);
644+
}
645+
if (response.is_truncated) throw std::runtime_error("Response is truncated");
563646

564647
// Handle OPT record.
648+
auto rcode = response.rcode;
565649
if (zone->enable_edns) {
566650
std::vector<RR> opt_rrset = get_unauthenticated_rrset(response.additional, RRType::OPT);
567651
if (opt_rrset.size() == 1) {

0 commit comments

Comments
 (0)