Skip to content

reduce boilerplate implementing comparisons for user-defined types #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
357 changes: 327 additions & 30 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// -*- mode: rust; -*-
//
// This file is part of subtle, part of the dalek cryptography project.
// Copyright (c) 2016-2018 isis lovecruft, Henry de Valence
// Copyright (c) 2016-2023 isis lovecruft, Henry de Valence
// See LICENSE for licensing information.
//
// Authors:
@@ -87,6 +87,10 @@
#[macro_use]
extern crate std;

#[cfg(test)]
extern crate rand;

use core::cmp::Ordering;
use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Neg, Not};
use core::option::Option;

@@ -111,6 +115,11 @@ use core::option::Option;
pub struct Choice(u8);

impl Choice {
#[inline]
pub(crate) const fn of_bool(of: bool) -> Self {
Self(of as u8)
}

/// Unwrap the `Choice` wrapper to reveal the underlying `u8`.
///
/// # Note
@@ -236,6 +245,15 @@ impl From<u8> for Choice {
}
}

/// A method to extend constant-time comparisons to abstract data types with multiple parts to
/// iterate over.
pub trait IteratedOperation {
/// Initialize any state retained across iterations.
fn initiate() -> Self;
/// Parse the state to determine whether the operation succeeded after iteration.
fn extract_result(self) -> Choice;
}

/// An `Eq`-like trait that produces a `Choice` instead of a `bool`.
///
/// # Example
@@ -257,24 +275,74 @@ pub trait ConstantTimeEq {
///
/// * `Choice(1u8)` if `self == other`;
/// * `Choice(0u8)` if `self != other`.
#[inline]
fn ct_eq(&self, other: &Self) -> Choice;
}

/// Get the conjunction of [`ConstantTimeEq::ct_eq`] over multiple possibly-heterogenous pairs
/// of elements.
///
///```
/// use subtle::{Choice, ConstantTimeEq, IteratedEq, IteratedOperation};
///
/// struct S { pub len: usize, pub live: bool };
/// impl ConstantTimeEq for S {
/// fn ct_eq(&self, other: &Self) -> Choice {
/// let mut x = IteratedEq::initiate();
/// x.apply_eq(&self.len, &other.len);
/// x.apply_eq(&(self.live as u8), &(other.live as u8));
/// x.extract_result()
/// }
/// }
///
/// let s1 = S { len: 2, live: true };
/// let s2 = S { len: 3, live: true };
/// assert_eq!(0, s1.ct_eq(&s2).unwrap_u8());
/// assert_eq!(1, s1.ct_eq(&s1).unwrap_u8());
/// assert_eq!(1, s2.ct_eq(&s2).unwrap_u8());
///```
pub struct IteratedEq {
still_equal: Choice,
}

impl IteratedOperation for IteratedEq {
fn initiate() -> Self {
Self {
still_equal: Choice::of_bool(true),
}
}
fn extract_result(self) -> Choice {
self.still_equal.into()
}
}

impl IteratedEq {
/// Unconditionally AND internal state with the result of a constant-time "equals" comparison.
///
/// [`Self::initiate()`] begins with internal state set to "true", so we can think of this
/// strategy as "assuming equal until proven wrong".
#[inline]
pub fn apply_eq<T: ConstantTimeEq + ?Sized>(&mut self, a: &T, b: &T) {
self.still_equal &= a.ct_eq(b);
}
}

impl<T: ConstantTimeEq> ConstantTimeEq for [T] {
/// Check whether two slices of `ConstantTimeEq` types are equal.
///
/// # Note
///
/// This function short-circuits if the lengths of the input slices
/// are different. Otherwise, it should execute in time independent
/// of the slice contents.
/// This function short-circuits if the lengths of the input slices are different. Otherwise,
/// it should execute in time independent of the slice contents. When the slice lengths differ,
/// this implementation applies the [shortlex] ordering scheme, which sorts shorter slices
/// before longer slices without checking the contents.
///
/// [shortlex]: https://en.wikipedia.org/wiki/Shortlex_order
///
/// Since arrays coerce to slices, this function works with fixed-size arrays:
///
/// ```
/// # use subtle::ConstantTimeEq;
/// #
/// use subtle::ConstantTimeEq;
///
/// let a: [u8; 8] = [0,1,2,3,4,5,6,7];
/// let b: [u8; 8] = [0,1,2,3,0,1,2,3];
///
@@ -291,18 +359,18 @@ impl<T: ConstantTimeEq> ConstantTimeEq for [T] {
// Short-circuit on the *lengths* of the slices, not their
// contents.
if len != _rhs.len() {
return Choice::from(0);
return Choice::of_bool(false);
}

// This loop shouldn't be shortcircuitable, since the compiler
// shouldn't be able to reason about the value of the `u8`
// unwrapped from the `ct_eq` result.
let mut x = 1u8;
let mut x = IteratedEq::initiate();
for (ai, bi) in self.iter().zip(_rhs.iter()) {
x &= ai.ct_eq(bi).unwrap_u8();
x.apply_eq(ai, bi);
}

x.into()
x.extract_result()
}
}

@@ -380,7 +448,6 @@ pub trait ConditionallySelectable: Copy {
/// assert_eq!(z, y);
/// # }
/// ```
#[inline]
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self;

/// Conditionally assign `other` to `self`, according to `choice`.
@@ -530,7 +597,6 @@ pub trait ConditionallyNegatable {
/// unchanged.
///
/// This function should execute in constant time.
#[inline]
fn conditional_negate(&mut self, choice: Choice);
}

@@ -769,14 +835,76 @@ pub trait ConstantTimeGreater {
fn ct_gt(&self, other: &Self) -> Choice;
}

/// Get the result of applying [`ConstantTimeGreater::ct_gt`] over multiple possibly-heterogenous
/// pairs of elements. The "greater than" comparison assumes that the order of these pairs
/// is lexicographic.
///
///```
/// use subtle::{
/// Choice, IteratedOperation, ConstantTimeEq, IteratedEq, ConstantTimeGreater, LexicographicIteratedGreater,
/// };
///
/// struct S { pub len: usize, pub live: bool };
/// impl ConstantTimeEq for S {
/// fn ct_eq(&self, other: &Self) -> Choice {
/// let mut x = IteratedEq::initiate();
/// x.apply_eq(&self.len, &other.len);
/// x.apply_eq(&(self.live as u8), &(other.live as u8));
/// x.extract_result()
/// }
/// }
/// impl ConstantTimeGreater for S {
/// fn ct_gt(&self, other: &Self) -> Choice {
/// let mut x = LexicographicIteratedGreater::initiate();
/// x.apply_gt(&(self.len as u64), &(other.len as u64));
/// x.apply_gt(&(self.live as u8), &(other.live as u8));
/// x.extract_result()
/// }
/// }
///
/// let s1 = S { len: 2, live: true };
/// let s2 = S { len: 3, live: false };
/// let s3 = S { len: 3, live: true };
/// assert_eq!(0, s1.ct_eq(&s2).unwrap_u8());
/// assert_eq!(1, s1.ct_eq(&s1).unwrap_u8());
/// assert_eq!(1, s2.ct_gt(&s1).unwrap_u8());
/// assert_eq!(1, s3.ct_gt(&s2).unwrap_u8());
///```
pub struct LexicographicIteratedGreater {
was_gt: Choice,
was_lt: Choice,
}

impl IteratedOperation for LexicographicIteratedGreater {
fn initiate() -> Self {
Self {
was_gt: Choice::of_bool(false),
was_lt: Choice::of_bool(false),
}
}
fn extract_result(self) -> Choice {
self.was_gt.into()
}
}

impl LexicographicIteratedGreater {
/// Unconditionally modify internal state with result of two directed "greater" comparisons.
#[inline]
pub fn apply_gt<T: ConstantTimeGreater + ?Sized>(&mut self, a: &T, b: &T) {
let Self { was_gt, was_lt } = self;
*was_gt |= (!*was_lt) & a.ct_gt(&b);
*was_lt |= b.ct_gt(&a);
}
}

macro_rules! generate_unsigned_integer_greater {
($t_u: ty, $bit_width: expr) => {
impl ConstantTimeGreater for $t_u {
/// Returns Choice::from(1) iff x > y, and Choice::from(0) iff x <= y.
///
/// # Note
///
/// This algoritm would also work for signed integers if we first
/// This algorithm would also work for signed integers if we first
/// flip the top bit, e.g. `let x: u8 = x ^ 0x80`, etc.
#[inline]
fn ct_gt(&self, other: &$t_u) -> Choice {
@@ -801,7 +929,9 @@ macro_rules! generate_unsigned_integer_greater {
Choice::from((bit & 1) as u8)
}
}
}

impl ConstantTimeLess for $t_u {}
};
}

generate_unsigned_integer_greater!(u8, 8);
@@ -811,18 +941,67 @@ generate_unsigned_integer_greater!(u64, 64);
#[cfg(feature = "i128")]
generate_unsigned_integer_greater!(u128, 128);

impl<T: ConstantTimeGreater> ConstantTimeGreater for [T] {
/// Compare whether one slice of `ConstantTimeGreater` types is greater than another.
///
/// # Note
///
/// This function short-circuits if the lengths of the input slices are different. Otherwise,
/// it should execute in time independent of the slice contents. When the slice lengths differ,
/// this implementation applies the [shortlex] ordering scheme, which sorts shorter slices
/// before longer slices without checking the contents.
///
/// [shortlex]: https://en.wikipedia.org/wiki/Shortlex_order
///
/// Since arrays coerce to slices, this function also works with fixed-size arrays:
///
/// ```
/// use subtle::ConstantTimeGreater;
///
/// let a: [u8; 8] = [0,1,2,3,4,5,6,7];
/// let b: [u8; 8] = [0,1,2,3,0,1,2,3];
///
/// let a_gt_a = a.ct_gt(&a);
/// let a_gt_b = a.ct_gt(&b);
///
/// assert_eq!(a_gt_a.unwrap_u8(), 0);
/// assert_eq!(a_gt_b.unwrap_u8(), 1);
/// ```
#[inline]
fn ct_gt(&self, _rhs: &[T]) -> Choice {
let len = self.len();

// Short-circuit on the *lengths* of the slices, not their contents. Here we apply shortlex
// ordering, sorting shorter slices before longer ones.
match len.cmp(&_rhs.len()) {
Ordering::Equal => (),
Ordering::Less => {
return Choice::of_bool(false);
}
Ordering::Greater => {
return Choice::of_bool(true);
}
}

// This loop shouldn't be shortcircuitable, since the compiler
// shouldn't be able to reason about the value of the `u8`
// unwrapped from the `ct_gt` result.
let mut x = LexicographicIteratedGreater::initiate();
for (ai, bi) in self.iter().zip(_rhs.iter()) {
x.apply_gt(ai, bi);
}

x.extract_result()
}
}

/// A type which can be compared in some manner and be determined to be less
/// than another of the same type.
pub trait ConstantTimeLess: ConstantTimeEq + ConstantTimeGreater {
pub trait ConstantTimeLess: ConstantTimeGreater {
/// Determine whether `self < other`.
///
/// The bitwise-NOT of the return value of this function should be usable to
/// determine if `self >= other`.
///
/// A default implementation is provided and implemented for the unsigned
/// integer types.
///
/// This function should execute in constant time.
/// This function should execute in constant time. The default implementation simply calls
/// [`ConstantTimeGreater::ct_gt`] with the arguments switched.
///
/// # Returns
///
@@ -852,13 +1031,131 @@ pub trait ConstantTimeLess: ConstantTimeEq + ConstantTimeGreater {
/// ```
#[inline]
fn ct_lt(&self, other: &Self) -> Choice {
!self.ct_gt(other) & !self.ct_eq(other)
other.ct_gt(self)
}
}

impl ConstantTimeLess for u8 {}
impl ConstantTimeLess for u16 {}
impl ConstantTimeLess for u32 {}
impl ConstantTimeLess for u64 {}
#[cfg(feature = "i128")]
impl ConstantTimeLess for u128 {}
/// Get the result of applying [`ConstantTimeGreater::ct_gt`] over multiple possibly-heterogenous
/// pairs of elements. The "greater than" comparison assumes that the order of these pairs
/// is lexicographic.
///
///```
/// use subtle::{
/// Choice, IteratedOperation, ConstantTimeEq, IteratedEq, ConstantTimeGreater, LexicographicIteratedGreater,
/// ConstantTimeLess, LexicographicIteratedLess,
/// };
///
/// struct S { pub len: usize, pub live: bool };
/// impl ConstantTimeEq for S {
/// fn ct_eq(&self, other: &Self) -> Choice {
/// let mut x = IteratedEq::initiate();
/// x.apply_eq(&self.len, &other.len);
/// x.apply_eq(&(self.live as u8), &(other.live as u8));
/// x.extract_result()
/// }
/// }
/// impl ConstantTimeGreater for S {
/// fn ct_gt(&self, other: &Self) -> Choice {
/// let mut x = LexicographicIteratedGreater::initiate();
/// x.apply_gt(&(self.len as u64), &(other.len as u64));
/// x.apply_gt(&(self.live as u8), &(other.live as u8));
/// x.extract_result()
/// }
/// }
/// impl ConstantTimeLess for S {
/// fn ct_lt(&self, other: &Self) -> Choice {
/// let mut x = LexicographicIteratedLess::initiate();
/// x.apply_lt(&(self.len as u64), &(other.len as u64));
/// x.apply_lt(&(self.live as u8), &(other.live as u8));
/// x.extract_result()
/// }
/// }
///
/// let s1 = S { len: 2, live: true };
/// let s2 = S { len: 3, live: false };
/// let s3 = S { len: 3, live: true };
/// assert_eq!(0, s1.ct_eq(&s2).unwrap_u8());
/// assert_eq!(1, s1.ct_eq(&s1).unwrap_u8());
/// assert_eq!(1, s2.ct_gt(&s1).unwrap_u8());
/// assert_eq!(1, s2.ct_lt(&s3).unwrap_u8());
///```
pub struct LexicographicIteratedLess {
was_lt: Choice,
was_gt: Choice,
}

impl IteratedOperation for LexicographicIteratedLess {
fn initiate() -> Self {
Self {
was_lt: Choice::of_bool(false),
was_gt: Choice::of_bool(false),
}
}
fn extract_result(self) -> Choice {
self.was_lt.into()
}
}

impl LexicographicIteratedLess {
/// Unconditionally modify internal state with result of two directed "less" comparisons.
#[inline]
pub fn apply_lt<T: ConstantTimeLess + ?Sized>(&mut self, a: &T, b: &T) {
let Self { was_lt, was_gt } = self;
*was_lt |= (!*was_gt) & a.ct_lt(&b);
*was_gt |= b.ct_lt(&a);
}
}

impl<T: ConstantTimeLess> ConstantTimeLess for [T] {
/// Compare whether one slice of `ConstantTimeLess` types is greater than another.
///
/// # Note
///
/// This function short-circuits if the lengths of the input slices are different. Otherwise,
/// it should execute in time independent of the slice contents. When the slice lengths differ,
/// this implementation applies the [shortlex] ordering scheme, which sorts shorter slices
/// before longer slices without checking the contents.
///
/// [shortlex]: https://en.wikipedia.org/wiki/Shortlex_order
///
/// Since arrays coerce to slices, this function also works with fixed-size arrays:
///
/// ```
/// use subtle::ConstantTimeLess;
///
/// let a: [u8; 8] = [0,1,2,3,0,1,2,3];
/// let b: [u8; 8] = [0,1,2,3,4,5,6,7];
///
/// let a_lt_a = a.ct_lt(&a);
/// let a_lt_b = a.ct_lt(&b);
///
/// assert_eq!(a_lt_a.unwrap_u8(), 0);
/// assert_eq!(a_lt_b.unwrap_u8(), 1);
/// ```
#[inline]
fn ct_lt(&self, _rhs: &[T]) -> Choice {
let len = self.len();

// Short-circuit on the *lengths* of the slices, not their contents. Here we apply shortlex
// ordering, sorting shorter slices before longer ones.
match len.cmp(&_rhs.len()) {
Ordering::Equal => (),
Ordering::Less => {
return Choice::of_bool(true);
}
Ordering::Greater => {
return Choice::of_bool(false);
}
}

// This loop shouldn't be shortcircuitable, since the compiler
// shouldn't be able to reason about the value of the `u8`
// unwrapped from the `ct_lt` result.
let mut x = LexicographicIteratedLess::initiate();
for (ai, bi) in self.iter().zip(_rhs.iter()) {
x.apply_lt(ai, bi);
}

x.extract_result()
}
}