Skip to content

Commit ed47ac6

Browse files
authored
Add support for CFRG Hybrid KEMs (#455)
* Add support for ML-KEM * clang-format * Add testing of ML-KEM * Specuatively enable OQS for OpenSSL 3 and BoringSSL * Update build configuration to work for OpenSSL 3 and BoringSSL * Add implementations using OpenSSL3 and BoringSSL * Remove BoringSSL support because of missing SHAKE256 * clang-format * CI fixes * Test against the HPKE PQ test vectors * Pass test vectors * clang-format * Change config order to avoid imposing flags on libOQS * Use liboqs from environment instead of vendored version * Use vcpkg for libOQS * Revert hack changes * Build interop tests with OpenSSL 3 * More hybrid KEM implementation * Add SHA3_256 * Factor out SHAKE256 * More hybrid KEM implementation * Add support for hybrid KEMs * Add PQ test vectors * clang-format * Get rid of strlen * Add a flag to disable PQ and use it consistently * clang-format * Add missing comma * Fix CI issue * More integer conversions
1 parent 6a0b31e commit ed47ac6

File tree

19 files changed

+1958
-51
lines changed

19 files changed

+1958
-51
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ option(SANITIZERS "Enable sanitizers" OFF)
1111
option(MLS_NAMESPACE_SUFFIX "Namespace Suffix for CXX and CMake Export")
1212
option(DISABLE_GREASE "Disables the inclusion of MLS protocol recommended GREASE values" OFF)
1313
option(REQUIRE_BORINGSSL "Require BoringSSL instead of OpenSSL" OFF)
14+
option(DISABLE_PQ "Disables support for PQ algorithms even when they would otherwise be enabled" OFF)
1415

1516
if(MLS_NAMESPACE_SUFFIX)
1617
set(MLS_CXX_NAMESPACE "mls_${MLS_NAMESPACE_SUFFIX}" CACHE STRING "Top-level Namespace for CXX")

Makefile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ devB: ${TOOLCHAIN_FILE}
4646
-DVCPKG_MANIFEST_DIR=${BORINGSSL_MANIFEST} \
4747
-DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE}
4848

49+
# Like `dev`, but using OpenSSL 3 with PQ disabled
50+
dev-no-pq:
51+
cmake -B${BUILD_DIR} -DTESTING=ON -DCMAKE_BUILD_TYPE=Debug \
52+
-DDISABLE_PQ=ON \
53+
-DVCPKG_MANIFEST_DIR=${OPENSSL3_MANIFEST} \
54+
-DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE}
55+
4956
test: ${BUILD_DIR} test/*
5057
cmake --build ${BUILD_DIR} --target mlspp_test
5158

lib/hpke/CMakeLists.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,19 @@ if ( OPENSSL_FOUND )
4242

4343
elseif (REQUIRE_BORINGSSL)
4444
message(FATAL_ERROR "BoringSSL required but not found")
45+
elseif (${OPENSSL_VERSION} VERSION_GREATER_EQUAL 3.5)
46+
target_compile_definitions(${CURRENT_LIB_NAME} PUBLIC WITH_OPENSSL3)
47+
48+
if(NOT DISABLE_PQ)
49+
target_compile_definitions(${CURRENT_LIB_NAME} PUBLIC WITH_PQ)
50+
endif()
4551
elseif (${OPENSSL_VERSION} VERSION_GREATER_EQUAL 3)
4652
target_compile_definitions(${CURRENT_LIB_NAME} PUBLIC WITH_OPENSSL3)
4753
elseif (${OPENSSL_VERSION} VERSION_GREATER_EQUAL 1.1.1)
48-
set(USING_LIBOQS ON)
54+
if(NOT DISABLE_PQ)
55+
set(USING_LIBOQS ON)
56+
target_compile_definitions(${CURRENT_LIB_NAME} PUBLIC WITH_PQ)
57+
endif()
4958
else()
5059
message(FATAL_ERROR "OpenSSL 1.1.1 or greater is required")
5160
endif()

lib/hpke/include/hpke/digest.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <memory>
44

55
#include <bytes/bytes.h>
6+
#include <hpke/hpke.h>
67
#include <namespace.h>
78

89
using namespace MLS_NAMESPACE::bytes_ns;
@@ -16,6 +17,9 @@ struct Digest
1617
SHA256,
1718
SHA384,
1819
SHA512,
20+
#if !defined(WITH_BORINGSSL)
21+
SHA3_256,
22+
#endif
1923
};
2024

2125
template<ID id>
@@ -35,4 +39,16 @@ struct Digest
3539
friend struct HKDF;
3640
};
3741

42+
#if !defined(WITH_BORINGSSL)
43+
struct SHAKE256
44+
{
45+
static bytes derive(const bytes& ikm, size_t length);
46+
static bytes labeled_derive(KEM::ID kem_id,
47+
const bytes& ikm,
48+
const std::string& label,
49+
const bytes& context,
50+
size_t length);
51+
};
52+
#endif
53+
3854
} // namespace MLS_NAMESPACE::hpke

lib/hpke/include/hpke/hpke.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@ struct KEM
1919
DHKEM_X25519_SHA256 = 0x0020,
2020
#if !defined(WITH_BORINGSSL)
2121
DHKEM_X448_SHA512 = 0x0021,
22+
#endif
23+
#if defined(WITH_PQ)
2224
MLKEM512 = 0x0040,
2325
MLKEM768 = 0x0041,
2426
MLKEM1024 = 0x0042,
27+
MLKEM768_P256 = 0x0050,
28+
MLKEM1024_P384 = 0x0051,
29+
MLKEM768_X25519 = 0x647a,
2530
#endif
2631
};
2732

@@ -42,6 +47,7 @@ struct KEM
4247
};
4348

4449
const ID id;
50+
const size_t seed_size;
4551
const size_t secret_size;
4652
const size_t enc_size;
4753
const size_t pk_size;
@@ -71,6 +77,7 @@ struct KEM
7177

7278
protected:
7379
KEM(ID id_in,
80+
size_t seed_size_in,
7481
size_t secret_size_in,
7582
size_t enc_size_in,
7683
size_t pk_size_in,

lib/hpke/scripts/test-vectors-pq.json

Lines changed: 1380 additions & 0 deletions
Large diffs are not rendered by default.

lib/hpke/src/dhkem.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ DHKEM::get<KEM::ID::DHKEM_X448_SHA512>()
7676

7777
DHKEM::DHKEM(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in)
7878
: KEM(kem_id_in,
79+
group_in.seed_size,
7980
kdf_in.hash_size,
8081
group_in.pk_size,
8182
group_in.pk_size,

lib/hpke/src/digest.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <openssl/core_names.h>
88
#endif
99

10+
#include "common.h"
1011
#include "openssl_common.h"
1112

1213
namespace MLS_NAMESPACE::hpke {
@@ -24,6 +25,11 @@ openssl_digest_type(Digest::ID digest)
2425
case Digest::ID::SHA512:
2526
return EVP_sha512();
2627

28+
#if !defined(WITH_BORINGSSL)
29+
case Digest::ID::SHA3_256:
30+
return EVP_sha3_256();
31+
#endif
32+
2733
default:
2834
throw std::runtime_error("Unsupported ciphersuite");
2935
}
@@ -43,6 +49,9 @@ openssl_digest_name(Digest::ID digest)
4349
case Digest::ID::SHA512:
4450
return OSSL_DIGEST_NAME_SHA2_512;
4551

52+
case Digest::ID::SHA3_256:
53+
return OSSL_DIGEST_NAME_SHA3_256;
54+
4655
default:
4756
throw std::runtime_error("Unsupported digest algorithm");
4857
}
@@ -73,6 +82,16 @@ Digest::get<Digest::ID::SHA512>()
7382
return instance;
7483
}
7584

85+
#if !defined(WITH_BORINGSSL)
86+
template<>
87+
const Digest&
88+
Digest::get<Digest::ID::SHA3_256>()
89+
{
90+
static const Digest instance(Digest::ID::SHA3_256);
91+
return instance;
92+
}
93+
#endif
94+
7695
Digest::Digest(Digest::ID id_in)
7796
: id(id_in)
7897
, hash_size(EVP_MD_size(openssl_digest_type(id_in)))
@@ -185,4 +204,50 @@ Digest::hmac_for_hkdf_extract(const bytes& key, const bytes& data) const
185204
return md;
186205
}
187206

207+
#if !defined(WITH_BORINGSSL)
208+
bytes
209+
SHAKE256::derive(const bytes& ikm, size_t length)
210+
{
211+
auto ctx = make_typed_unique(EVP_MD_CTX_new());
212+
if (!ctx) {
213+
throw openssl_error();
214+
}
215+
216+
if (EVP_DigestInit_ex(ctx.get(), EVP_shake256(), nullptr) != 1) {
217+
throw openssl_error();
218+
}
219+
220+
if (EVP_DigestUpdate(ctx.get(), ikm.data(), ikm.size()) != 1) {
221+
throw openssl_error();
222+
}
223+
224+
auto out = bytes(length);
225+
if (EVP_DigestFinalXOF(ctx.get(), out.data(), out.size()) != 1) {
226+
throw openssl_error();
227+
}
228+
229+
return out;
230+
}
231+
232+
bytes
233+
SHAKE256::labeled_derive(KEM::ID kem_id,
234+
const bytes& ikm,
235+
const std::string& label,
236+
const bytes& context,
237+
size_t length)
238+
{
239+
const auto hpke_version = from_ascii("HPKE-v1");
240+
const auto label_kem = from_ascii("KEM");
241+
const auto suite_id = label_kem + i2osp(uint16_t(kem_id), 2);
242+
const auto label_bytes = from_ascii(label);
243+
const auto label_len = i2osp(uint16_t(label_bytes.size()), 2);
244+
const auto length_bytes = i2osp(uint16_t(length), 2);
245+
246+
return derive(ikm + hpke_version + suite_id + label_len + label_bytes +
247+
length_bytes + context,
248+
length);
249+
}
250+
251+
#endif // !defined(WITH_BORINGSSL)
252+
188253
} // namespace MLS_NAMESPACE::hpke

lib/hpke/src/group.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,57 @@ struct ECKeyGroup : public EVPGroup
350350
}
351351
}
352352

353+
#if defined(WITH_OPENSSL3)
354+
auto key = keypair_evp_key(sk);
355+
return std::make_unique<EVPGroup::PrivateKey>(key.release());
356+
#else
357+
auto pt = make_typed_unique(EC_POINT_new(group));
358+
EC_POINT_mul(group, pt.get(), sk.get(), nullptr, nullptr, nullptr);
359+
360+
EC_KEY_set_private_key(eckey.get(), sk.get());
361+
EC_KEY_set_public_key(eckey.get(), pt.get());
362+
363+
auto pkey = to_pkey(eckey.release());
364+
return std::make_unique<PrivateKey>(pkey.release());
365+
#endif
366+
}
367+
368+
std::unique_ptr<Group::PrivateKey> random_scalar(
369+
const bytes& seed) const override
370+
{
371+
#if defined(WITH_OPENSSL3)
372+
auto* group = EC_GROUP_new_by_curve_name_ex(nullptr, nullptr, curve_nid);
373+
auto group_ptr = make_typed_unique(group);
374+
#else
375+
auto eckey = new_ec_key();
376+
const auto* group = EC_KEY_get0_group(eckey.get());
377+
#endif
378+
379+
auto order = make_typed_unique(BN_new());
380+
if (1 != EC_GROUP_get_order(group, order.get(), nullptr)) {
381+
throw openssl_error();
382+
}
383+
384+
auto sk = make_typed_unique(BN_new());
385+
BN_zero(sk.get());
386+
387+
auto start = size_t(0);
388+
auto end = sk_size;
389+
auto candidate = seed.slice(start, end);
390+
auto candidate_size = static_cast<int>(candidate.size());
391+
sk.reset(BN_bin2bn(candidate.data(), candidate_size, nullptr));
392+
393+
while (BN_is_zero(sk.get()) != 0 || BN_cmp(sk.get(), order.get()) != -1) {
394+
start = end;
395+
end = end + sk_size;
396+
if (end > seed.size()) {
397+
throw std::runtime_error("Rejection sampling failed");
398+
}
399+
400+
candidate = seed.slice(start, end);
401+
sk.reset(BN_bin2bn(candidate.data(), candidate_size, nullptr));
402+
}
403+
353404
#if defined(WITH_OPENSSL3)
354405
auto key = keypair_evp_key(sk);
355406
return std::make_unique<EVPGroup::PrivateKey>(key.release());
@@ -785,6 +836,16 @@ struct RawKeyGroup : public EVPGroup
785836
return deserialize_private(skm);
786837
}
787838

839+
std::unique_ptr<Group::PrivateKey> random_scalar(
840+
const bytes& seed) const override
841+
{
842+
if (seed.size() != sk_size) {
843+
throw std::runtime_error("Invalid seed");
844+
}
845+
846+
return deserialize_private(seed);
847+
}
848+
788849
bytes serialize(const Group::PublicKey& pk) const override
789850
{
790851
const auto& rpk = dynamic_cast<const PublicKey&>(pk);
@@ -952,6 +1013,32 @@ Group::get<Group::ID::Ed448>()
9521013
return instance;
9531014
}
9541015

1016+
static inline size_t
1017+
group_seed_size(Group::ID group_id)
1018+
{
1019+
switch (group_id) {
1020+
case Group::ID::P256:
1021+
return 128;
1022+
case Group::ID::P384:
1023+
return 48;
1024+
case Group::ID::P521:
1025+
// XXX(RLB): This may be wrong, but we're never going to use it
1026+
return 66;
1027+
case Group::ID::X25519:
1028+
return 32;
1029+
case Group::ID::X448:
1030+
return 56;
1031+
1032+
// Non-DH groups
1033+
case Group::ID::Ed25519:
1034+
case Group::ID::Ed448:
1035+
return 0;
1036+
1037+
default:
1038+
throw std::runtime_error("Unknown group");
1039+
}
1040+
}
1041+
9551042
static inline size_t
9561043
group_dh_size(Group::ID group_id)
9571044
{
@@ -1066,6 +1153,7 @@ group_jwk_key_type(Group::ID group_id)
10661153

10671154
Group::Group(ID group_id_in, const KDF& kdf_in)
10681155
: id(group_id_in)
1156+
, seed_size(group_seed_size(group_id_in))
10691157
, dh_size(group_dh_size(group_id_in))
10701158
, pk_size(group_pk_size(group_id_in))
10711159
, sk_size(group_sk_size(group_id_in))

lib/hpke/src/group.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ struct Group
4141
virtual ~Group() = default;
4242

4343
const ID id;
44+
const size_t seed_size;
4445
const size_t dh_size;
4546
const size_t pk_size;
4647
const size_t sk_size;
@@ -51,6 +52,8 @@ struct Group
5152
virtual std::unique_ptr<PrivateKey> derive_key_pair(
5253
const bytes& suite_id,
5354
const bytes& ikm) const = 0;
55+
virtual std::unique_ptr<PrivateKey> random_scalar(
56+
const bytes& seed) const = 0;
5457

5558
virtual bytes serialize(const PublicKey& pk) const = 0;
5659
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;

0 commit comments

Comments
 (0)