Skip to content

Commit 770a708

Browse files
committed
Refactor
1 parent 7f02a7e commit 770a708

File tree

9 files changed

+253
-251
lines changed

9 files changed

+253
-251
lines changed

src/dns.cc

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ template <std::integral T>
1919
T random_int() {
2020
T result;
2121
if (RAND_bytes(reinterpret_cast<unsigned char *>(&result), sizeof(result)) != 1) {
22-
throw std::runtime_error{"Failed to generate random bytes"};
22+
throw std::runtime_error("Failed to generate random bytes");
2323
}
2424
return result;
2525
}
@@ -79,40 +79,38 @@ class ResponseReader {
7979

8080
// Read response header.
8181
auto id = read_u16();
82-
if (id != request_id) throw std::runtime_error("Wrong response ID");
83-
8482
auto flags = read_u16();
8583
bool is_response = (flags >> 15) & 1;
84+
auto opcode = static_cast<OpCode>((flags >> 11) & 0b1111);
8685
response.is_authoritative = (flags >> 10) & 1;
8786
bool is_truncated = (flags >> 9) & 1;
88-
auto opcode = static_cast<OpCode>((flags >> 11) & 0b1111);
8987
response.rcode = static_cast<RCode>(flags & 0b1111);
90-
91-
if (!is_response) throw std::runtime_error("Response is a query");
92-
if (is_truncated) throw std::runtime_error("Response is truncated");
93-
if (opcode != OpCode::Query) throw std::runtime_error("Response has wrong opcode");
94-
9588
auto question_count = read_u16();
96-
if (question_count != 1) throw std::runtime_error("Wrong question count");
9789
auto answer_count = read_u16();
9890
auto authority_count = read_u16();
9991
auto additional_count = read_u16();
10092

93+
// Validate response header.
94+
if (id != request_id) throw std::runtime_error("Wrong response ID");
95+
if (!is_response) throw std::runtime_error("Response is a query");
96+
if (opcode != OpCode::Query) throw std::runtime_error("Response has wrong opcode");
97+
if (is_truncated) throw std::runtime_error("Response is truncated");
98+
if (question_count != 1) throw std::runtime_error("Wrong question count");
99+
101100
// Validate question.
102101
if (read_domain() != request_domain) throw std::runtime_error("Wrong question domain");
103102
if (read_u16<RRType>() != request_rr_type) throw std::runtime_error("Wrong question type");
104103
if (read_u16<DNSClass>() != DNSClass::Internet) throw std::runtime_error("Unknown DNS class");
105104

105+
// Read the answer.
106106
response.answers.reserve(answer_count);
107107
for (uint16_t i = 0; i < answer_count; i++) response.answers.push_back(read_rr());
108-
109108
response.authority.reserve(authority_count);
110109
for (uint16_t i = 0; i < authority_count; i++) response.authority.push_back(read_rr());
111-
112110
response.additional.reserve(additional_count);
113111
for (uint16_t i = 0; i < additional_count; i++) response.additional.push_back(read_rr());
114112

115-
if (offset != buffer.size()) throw std::runtime_error("Response is too long");
113+
if (offset != buffer.size()) throw std::runtime_error("Failed to parse the response");
116114
return response;
117115
}
118116

@@ -193,7 +191,7 @@ class ResponseReader {
193191
std::string domain;
194192
read_domain_rec(allow_compression, domain);
195193
// Handle root domain.
196-
if (domain.empty()) domain.push_back('.');
194+
if (domain.empty()) return ".";
197195
if (domain.size() > MAX_DOMAIN_LENGTH) throw std::runtime_error("Domain is too long");
198196
return domain;
199197
}
@@ -280,9 +278,9 @@ class ResponseReader {
280278
ds.key_tag = read_u16();
281279
ds.signing_algorithm = read_u8<SigningAlgorithm>();
282280
ds.digest_algorithm = read_u8<DigestAlgorithm>();
283-
read(get_ds_digest_size(ds.digest_algorithm), ds.digest);
281+
read(dnssec::get_ds_digest_size(ds.digest_algorithm), ds.digest);
284282

285-
// Save data to verify the RRSIG later.
283+
// Save data to authenticate the RRSIG later.
286284
ResponseReader data_reader{buffer, data_offset};
287285
data_reader.read(data_length, ds.data);
288286

@@ -306,7 +304,7 @@ class ResponseReader {
306304
auto signature_length = data_length - (offset - data_offset);
307305
read(signature_length, rrsig.signature);
308306

309-
// Save data without signature to verify it later.
307+
// Save data without signature to authenticate it later.
310308
ResponseReader data_reader{buffer, data_offset};
311309
data_reader.read(data_length - signature_length, rrsig.data);
312310

@@ -345,7 +343,7 @@ class ResponseReader {
345343
auto domain_size = offset - data_offset;
346344
nsec.types = read_rr_type_bitmap(data_length - domain_size);
347345

348-
// Save data to verify the RRSIG later.
346+
// Save data to authenticate the RRSIG later.
349347
ResponseReader data_reader{buffer, data_offset};
350348
data_reader.read(data_length, nsec.data);
351349

@@ -369,10 +367,10 @@ class ResponseReader {
369367
auto key_size = data_length - 4;
370368
read(key_size, dnskey.key);
371369

372-
// Save data to verify the RRSIG later.
370+
// Save data to authenticate the RRSIG later.
373371
ResponseReader data_reader{buffer, data_offset};
374372
data_reader.read(data_length, dnskey.data);
375-
dnskey.key_tag = compute_key_tag(dnskey.data);
373+
dnskey.key_tag = dnssec::compute_key_tag(dnskey.data);
376374

377375
return dnskey;
378376
}
@@ -394,7 +392,7 @@ class ResponseReader {
394392
auto nsec3_data_length = offset - data_offset;
395393
nsec3.types = read_rr_type_bitmap(data_length - nsec3_data_length);
396394

397-
// Save data to verify the RRSIG later.
395+
// Save data to authenticate the RRSIG later.
398396
ResponseReader data_reader{buffer, data_offset};
399397
data_reader.read(data_length, nsec3.data);
400398

src/dns.hh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ struct RRSIG {
162162
uint16_t key_tag;
163163
std::string signer_name;
164164
std::vector<uint8_t> signature;
165-
// Data does not include the signature since it is only used to verify it.
165+
// Data does not include the signature since it is only used to authenticate it.
166166
std::vector<uint8_t> data;
167167
};
168168

@@ -270,7 +270,7 @@ public:
270270
case RRType::DS: {
271271
const auto &ds = std::get<DS>(rr.data);
272272
std::format_to(out, "{} {} {} {}", ds.key_tag, std::to_underlying(ds.signing_algorithm),
273-
std::to_underlying(ds.digest_algorithm), hex_string_encode(ds.digest));
273+
std::to_underlying(ds.digest_algorithm), hex_encode(ds.digest));
274274
} break;
275275
case RRType::RRSIG: {
276276
const auto &rrsig = std::get<RRSIG>(rr.data);
@@ -293,8 +293,8 @@ public:
293293
case RRType::NSEC3: {
294294
const auto &nsec3 = std::get<NSEC3>(rr.data);
295295
std::format_to(out, "{} {} {} {} {} (", std::to_underlying(nsec3.algorithm), nsec3.flags,
296-
nsec3.iterations, nsec3.salt.empty() ? "-" : hex_string_encode(nsec3.salt),
297-
base32_encode(nsec3.next_domain_hash));
296+
nsec3.iterations, nsec3.salt.empty() ? "-" : hex_encode(nsec3.salt),
297+
base32hex_encode(nsec3.next_domain_hash));
298298
print_types(out, nsec3.types);
299299
std::format_to(out, ")");
300300
} break;

src/dnssec.cc

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@
2121
#include "write.hh"
2222

2323
namespace {
24-
struct RRWithData {
25-
std::reference_wrapper<const RR> rr;
26-
std::vector<std::string_view> labels;
27-
std::vector<uint8_t> data;
28-
29-
RRWithData(const RR &rr, std::vector<std::string_view> &&labels) : rr(rr), labels(labels) {}
30-
};
31-
3224
using EVP_PKEY_unique_ptr = std::unique_ptr<EVP_PKEY, decltype([](auto *pkey) { EVP_PKEY_free(pkey); })>;
3325
using BIGNUM_unique_ptr = std::unique_ptr<BIGNUM, decltype([](auto *bn) { BN_free(bn); })>;
3426
using OSSL_PARAM_BLD_unique_ptr
@@ -105,7 +97,7 @@ EVP_PKEY_unique_ptr load_ecdsa_key(const std::vector<uint8_t> &dnskey, const std
10597
}
10698

10799
EVP_PKEY_unique_ptr load_eddsa_key(const std::vector<uint8_t> &dnskey, int type) {
108-
auto *pkey = EVP_PKEY_new_raw_public_key(type, NULL, dnskey.data(), dnskey.size());
100+
auto *pkey = EVP_PKEY_new_raw_public_key(type, nullptr, dnskey.data(), dnskey.size());
109101
if (pkey == nullptr) throw std::runtime_error("Failed to load EdDSA key");
110102
return EVP_PKEY_unique_ptr{pkey};
111103
}
@@ -215,6 +207,14 @@ std::vector<std::string_view> domain_to_labels(const std::string_view &domain) {
215207
return labels;
216208
}
217209

210+
struct RRWithData {
211+
std::reference_wrapper<const RR> rr;
212+
std::vector<std::string_view> labels;
213+
std::vector<uint8_t> data;
214+
215+
RRWithData(const RR &rr, std::vector<std::string_view> &&labels) : rr(rr), labels(labels) {}
216+
};
217+
218218
std::vector<RRWithData> add_data_to_rrset(const std::vector<RR> &rrset) {
219219
std::vector<RRWithData> result;
220220
result.reserve(rrset.size());
@@ -359,10 +359,10 @@ int compare_domains(const std::vector<std::string_view> &a, const std::vector<st
359359
return 0;
360360
}
361361

362-
bool is_domain_between(const std::string_view &domain, const std::string_view &before, const std::string_view &after) {
362+
bool is_domain_between(const std::string &domain, const std::string &before, const std::string &after) {
363+
auto domain_labels = domain_to_labels(domain);
363364
auto before_labels = domain_to_labels(before);
364365
auto after_labels = domain_to_labels(after);
365-
auto domain_labels = domain_to_labels(domain);
366366
return compare_domains(before_labels, domain_labels) < 0 && compare_domains(domain_labels, after_labels) < 0;
367367
}
368368

@@ -394,7 +394,7 @@ std::string get_nsec3_domain(const NSEC3 &nsec3, const std::string_view &domain,
394394
}
395395
}
396396

397-
return base32_encode(digest) + "." + zone_domain;
397+
return base32hex_encode(digest) + "." + zone_domain;
398398
}
399399

400400
std::optional<NSEC3> find_covering_nsec3(const std::vector<RR> &nsec3_rrset, const std::string_view &domain,
@@ -406,7 +406,7 @@ std::optional<NSEC3> find_covering_nsec3(const std::vector<RR> &nsec3_rrset, con
406406
auto covered_domain = get_nsec3_domain(nsec3, domain, zone_domain);
407407
for (const auto &nsec3_rr : nsec3_rrset) {
408408
const auto &nsec3 = std::get<NSEC3>(nsec3_rr.data);
409-
auto next_domain = base32_encode(nsec3.next_domain_hash) + "." + zone_domain;
409+
auto next_domain = base32hex_encode(nsec3.next_domain_hash) + "." + zone_domain;
410410
if (is_domain_between(covered_domain, nsec3_rr.domain, next_domain)) return std::get<NSEC3>(nsec3_rr.data);
411411
}
412412
} catch (...) {
@@ -463,6 +463,7 @@ std::optional<EncloserProof> verify_closest_encloser_proof(const std::vector<RR>
463463
}
464464
} // namespace
465465

466+
namespace dnssec {
466467
int get_ds_digest_size(DigestAlgorithm algorithm) {
467468
auto digest_size = EVP_MD_get_size(get_ds_digest_algorithm(algorithm));
468469
if (digest_size <= 0) throw std::runtime_error("Failed to get digest size");
@@ -550,10 +551,10 @@ bool authenticate_rrset(const std::vector<RR> &rrset, const std::vector<RRSIG> &
550551

551552
bool authenticate_delegation(const std::vector<RR> &dnskey_rrset, const std::vector<DS> &dss,
552553
const std::vector<RRSIG> &rrsigs, const std::string &zone_domain) {
553-
if (dnskey_rrset.empty() || dss.empty()) return {};
554+
if (dnskey_rrset.empty() || dss.empty()) return false;
554555

555556
EVP_MD_CTX_unique_ptr ctx{EVP_MD_CTX_new()};
556-
if (ctx == nullptr) return {};
557+
if (ctx == nullptr) return false;
557558

558559
std::vector<uint8_t> canonical_domain;
559560
std::vector<uint8_t> digest;
@@ -565,7 +566,9 @@ bool authenticate_delegation(const std::vector<RR> &dnskey_rrset, const std::vec
565566
if (digest_size <= 0) continue;
566567

567568
for (const auto &dnskey_rr : dnskey_rrset) {
569+
if (dnskey_rr.type != RRType::DNSKEY) return false;
568570
const auto &dnskey = std::get<DNSKEY>(dnskey_rr.data);
571+
569572
if (dnskey.key_tag != ds.key_tag) continue;
570573

571574
canonical_domain.clear();
@@ -602,9 +605,10 @@ bool authenticate_name_error(const std::string &domain, const std::vector<RR> &n
602605
return find_covering_nsec3(nsec3_rrset, wildcard_domain, zone_domain).has_value();
603606
}
604607

605-
for (const auto &rr : nsec_rrset) {
606-
const auto &nsec = std::get<NSEC>(rr.data);
607-
if (is_domain_between(domain, rr.domain, nsec.next_domain)) return true;
608+
for (const auto &nsec_rr : nsec_rrset) {
609+
if (nsec_rr.type != RRType::NSEC) return false;
610+
const auto &nsec = std::get<NSEC>(nsec_rr.data);
611+
if (is_domain_between(domain, nsec_rr.domain, nsec.next_domain)) return true;
608612
}
609613

610614
return false;
@@ -625,9 +629,9 @@ bool authenticate_no_ds(const std::string &domain, const std::vector<RR> &nsec3_
625629
return encloser_proof.has_value() && encloser_proof->next_closer_opt_out;
626630
}
627631

628-
if (!nsec_rr.has_value()) return false;
629-
632+
if (!nsec_rr.has_value() || nsec_rr->type != RRType::NSEC) return false;
630633
const auto &nsec = std::get<NSEC>(nsec_rr->data);
634+
631635
if (nsec.types.contains(RRType::DS) || nsec.types.contains(RRType::CNAME)) return false;
632636

633637
return true;
@@ -642,9 +646,9 @@ bool authenticate_no_rrset(RRType rr_type, const std::string &domain, const std:
642646
return true;
643647
}
644648

645-
if (!nsec_rr.has_value()) return false;
646-
649+
if (!nsec_rr.has_value() || nsec_rr->type != RRType::NSEC) return false;
647650
const auto &nsec = std::get<NSEC>(nsec_rr->data);
648-
if (nsec.types.contains(rr_type) || nsec.types.contains(RRType::CNAME)) return false;
649-
return true;
651+
652+
return !nsec.types.contains(rr_type) && !nsec.types.contains(RRType::CNAME);
650653
}
654+
}; // namespace dnssec

src/dnssec.hh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <vector>
66
#include "dns.hh"
77

8+
namespace dnssec {
89
int get_ds_digest_size(DigestAlgorithm algorithm);
910
uint16_t compute_key_tag(const std::vector<uint8_t> &data);
1011

@@ -18,3 +19,4 @@ bool authenticate_no_ds(const std::string &domain, const std::vector<RR> &nsec3_
1819
const std::string &zone_domain);
1920
bool authenticate_no_rrset(RRType rr_type, const std::string &domain, const std::vector<RR> &nsec3_rrset,
2021
const std::optional<RR> &nsec_rr, const std::string &zone_domain);
22+
}; // namespace dnssec

src/encode.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ std::string base64_encode(const std::vector<uint8_t> &src) {
3737
return output;
3838
}
3939

40-
// Base 32 Encoding with Extended Hex Alphabet.
41-
std::string base32_encode(const std::vector<uint8_t> &src) {
40+
std::string base32hex_encode(const std::vector<uint8_t> &src) {
4241
std::string output;
4342
output.reserve(((src.size() + 4) / 5) * 8);
4443

@@ -98,7 +97,7 @@ std::string base32_encode(const std::vector<uint8_t> &src) {
9897
return output;
9998
}
10099

101-
std::string hex_string_encode(const std::vector<uint8_t> &src) {
100+
std::string hex_encode(const std::vector<uint8_t> &src) {
102101
std::string output;
103102
output.reserve(src.size() * 2);
104103

src/encode.hh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include <vector>
66

77
std::string base64_encode(const std::vector<uint8_t> &src);
8-
std::string base32_encode(const std::vector<uint8_t> &src);
9-
std::string hex_string_encode(const std::vector<uint8_t> &src);
8+
std::string base32hex_encode(const std::vector<uint8_t> &src);
9+
std::string hex_encode(const std::vector<uint8_t> &src);

src/main.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ int main(int argc, char **argv) {
7070
("edns", "EDNS", cxxopts::value<FeatureState>()->default_value("on")) //
7171
("dnssec", "DNSSEC", cxxopts::value<FeatureState>()->default_value("on")) //
7272
("cookies", "Cookies", cxxopts::value<FeatureState>()->default_value("on")) //
73-
("use-root", "Load root nameservers", cxxopts::value<bool>()->default_value("true")) //
74-
("use-resolv-conf", "Load /etc/resolv.conf", cxxopts::value<bool>()->default_value("true"));
73+
("use-root", "Use root nameservers", cxxopts::value<bool>()->default_value("true")) //
74+
("use-config", "Use nameservers from /etc/resolv.conf", cxxopts::value<bool>()->default_value("true"));
7575
options.parse_positional({"domain"});
7676

7777
auto result = options.parse(argc, argv);
@@ -84,7 +84,7 @@ int main(int argc, char **argv) {
8484
.timeout_ms = result["timeout"].as<uint64_t>() * 1000,
8585
.nameserver = result["server"].as_optional<std::string>(),
8686
.use_root_nameservers = result["use-root"].as<bool>(),
87-
.use_resolve_config = result["use-resolv-conf"].as<bool>(),
87+
.use_resolve_config = result["use-config"].as<bool>(),
8888
.port = result["port"].as<uint16_t>(),
8989
.verbose = result["verbose"].as<bool>(),
9090
.enable_rd = result["rdflag"].as<bool>(),

0 commit comments

Comments
 (0)