Skip to content
233 changes: 209 additions & 24 deletions rust/candid/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ struct Deserializer<'de> {
config: DecoderConfig,
recursion_depth: crate::utils::RecursionDepth,
primitive_vec_fast_path: Option<PrimitiveType>,
#[cfg(feature = "bignum")]
bignum_vec_fast_path: Option<BigNumFastPath>,
}

impl<'de> Deserializer<'de> {
Expand All @@ -322,6 +324,8 @@ impl<'de> Deserializer<'de> {
config: config.clone(),
recursion_depth: crate::utils::RecursionDepth::new(),
primitive_vec_fast_path: None,
#[cfg(feature = "bignum")]
bignum_vec_fast_path: None,
})
}
fn dump_state(&self) -> String {
Expand Down Expand Up @@ -442,10 +446,37 @@ impl<'de> Deserializer<'de> {
where
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(*self.expect_type == TypeInner::Int);
let mut bytes = vec![0u8];
if self.bignum_vec_fast_path.is_none() {
self.unroll_type()?;
assert!(*self.expect_type == TypeInner::Int);
}
let pos = self.input.position();
if !self.is_untyped {
match self.wire_type.as_ref() {
TypeInner::Int => match leb128::read::signed(&mut self.input) {
Ok(value) => {
self.add_cost((self.input.position() - pos) as usize)?;
return visitor.visit_i64(value);
}
Err(leb128::read::Error::Overflow) => {
self.input.set_position(pos);
}
Err(e) => return Err(Error::msg(e)),
},
TypeInner::Nat => match leb128::read::unsigned(&mut self.input) {
Ok(value) => {
self.add_cost((self.input.position() - pos) as usize)?;
return visitor.visit_u64(value);
}
Err(leb128::read::Error::Overflow) => {
self.input.set_position(pos);
}
Err(e) => return Err(Error::msg(e)),
},
t => return Err(Error::subtype(format!("{t} cannot be deserialized to int"))),
}
}
let mut bytes = vec![0u8];
let int = match self.wire_type.as_ref() {
TypeInner::Int => Int::decode(&mut self.input).map_err(Error::msg)?,
TypeInner::Nat => Int(Nat::decode(&mut self.input).map_err(Error::msg)?.0.into()),
Expand All @@ -461,13 +492,27 @@ impl<'de> Deserializer<'de> {
where
V: Visitor<'de>,
{
self.unroll_type()?;
check!(
*self.expect_type == TypeInner::Nat && *self.wire_type == TypeInner::Nat,
"nat"
);
let mut bytes = vec![1u8];
if self.bignum_vec_fast_path.is_none() {
self.unroll_type()?;
check!(
*self.expect_type == TypeInner::Nat && *self.wire_type == TypeInner::Nat,
"nat"
);
}
let pos = self.input.position();
if !self.is_untyped {
match leb128::read::unsigned(&mut self.input) {
Ok(value) => {
self.add_cost((self.input.position() - pos) as usize)?;
return visitor.visit_u64(value);
}
Err(leb128::read::Error::Overflow) => {
self.input.set_position(pos);
}
Err(e) => return Err(Error::msg(e)),
}
}
let mut bytes = vec![1u8];
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
self.add_cost((self.input.position() - pos) as usize)?;
bytes.extend_from_slice(&nat.0.to_bytes_le());
Expand Down Expand Up @@ -642,6 +687,14 @@ fn primitive_byte_cost(p: PrimitiveType) -> usize {
}
}

#[cfg(feature = "bignum")]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum BigNumFastPath {
Nat,
Int,
NatAsInt,
}

fn exact_primitive_type(expect: &Type, wire: &Type) -> Option<PrimitiveType> {
match (expect.as_ref(), wire.as_ref()) {
(TypeInner::Bool, TypeInner::Bool) => Some(PrimitiveType::Bool),
Expand Down Expand Up @@ -687,6 +740,13 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
if self.field_name.is_some() {
return self.deserialize_identifier(visitor);
}
#[cfg(feature = "bignum")]
if let Some(fast) = self.bignum_vec_fast_path {
return match fast {
BigNumFastPath::Nat => self.deserialize_nat(visitor),
BigNumFastPath::Int | BigNumFastPath::NatAsInt => self.deserialize_int(visitor),
};
}
self.unroll_type()?;
match self.expect_type.as_ref() {
#[cfg(feature = "bignum")]
Expand Down Expand Up @@ -870,16 +930,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
where
V: Visitor<'de>,
{
self.unroll_type()?;
check!(
*self.expect_type == TypeInner::Text && *self.wire_type == TypeInner::Text,
"text"
);
let len = Len::read(&mut self.input)?.0;
self.add_cost(len.saturating_add(1))?;
let bytes = self.borrow_bytes(len)?.to_owned();
let value = String::from_utf8(bytes).map_err(Error::msg)?;
visitor.visit_string(value)
self.deserialize_str(visitor)
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where
Expand Down Expand Up @@ -957,7 +1008,61 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
len.checked_mul(per_element_cost)
.ok_or_else(|| Error::msg("Vec length overflow"))?,
)?;
self.primitive_vec_fast_path = exact_primitive;

#[cfg(target_endian = "little")]
{
let byte_size = primitive_byte_cost(prim);
let total_bytes = len
.checked_mul(byte_size)
.ok_or_else(|| Error::msg("Vec byte length overflow"))?;
let pos = self.input.position() as usize;
let slice = self.input.get_ref();
if pos + total_bytes > slice.len() {
return Err(Error::msg(format!(
"Not enough bytes for primitive vec: need {total_bytes}, have {}",
slice.len() - pos
)));
}
let data = &slice[pos..pos + total_bytes];
let mut access = PrimitiveVecAccess {
data,
offset: 0,
remaining: len,
element_size: byte_size,
prim,
};
let result = visitor.visit_seq(&mut access);
// Advance by bytes actually consumed, not total_bytes, so
// the cursor is correct if the visitor short-circuits.
self.input.set_position((pos + access.offset) as u64);
return result;
}

#[cfg(not(target_endian = "little"))]
{
self.primitive_vec_fast_path = exact_primitive;
}
}
#[cfg(feature = "bignum")]
let bignum_fast = if exact_primitive.is_none() {
match (expect.as_ref(), wire.as_ref()) {
(TypeInner::Nat, TypeInner::Nat) => Some(BigNumFastPath::Nat),
(TypeInner::Int, TypeInner::Int) => Some(BigNumFastPath::Int),
(TypeInner::Int, TypeInner::Nat) => Some(BigNumFastPath::NatAsInt),
_ => None,
}
} else {
None
};
#[cfg(feature = "bignum")]
if let Some(fast) = bignum_fast {
self.add_cost(
len.checked_mul(3)
.ok_or_else(|| Error::msg("Vec length overflow"))?,
)?;
self.bignum_vec_fast_path = Some(fast);
self.expect_type = expect.clone();
self.wire_type = wire.clone();
}
let result = visitor.visit_seq(Compound::new(
self,
Expand All @@ -968,9 +1073,6 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
exact_primitive,
},
));
if exact_primitive.is_some() {
self.primitive_vec_fast_path = None;
}
result
}
(TypeInner::Record(_), TypeInner::Record(_)) => {
Expand Down Expand Up @@ -1192,6 +1294,81 @@ enum Style {
},
}

#[cfg(target_endian = "little")]
struct PrimitiveVecAccess<'de> {
data: &'de [u8],
offset: usize,
remaining: usize,
element_size: usize,
prim: PrimitiveType,
}

#[cfg(target_endian = "little")]
impl<'de> de::SeqAccess<'de> for PrimitiveVecAccess<'de> {
type Error = Error;

fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where
T: de::DeserializeSeed<'de>,
{
use serde::de::IntoDeserializer;
if self.remaining == 0 {
return Ok(None);
}
self.remaining -= 1;
let bytes = &self.data[self.offset..self.offset + self.element_size];
self.offset += self.element_size;

match self.prim {
PrimitiveType::Bool => match bytes[0] {
0 => seed.deserialize(false.into_deserializer()).map(Some),
1 => seed.deserialize(true.into_deserializer()).map(Some),
_ => Err(Error::msg("Expect 00 or 01")),
},
PrimitiveType::Nat8 => seed.deserialize(bytes[0].into_deserializer()).map(Some),
PrimitiveType::Int8 => seed
.deserialize((bytes[0] as i8).into_deserializer())
.map(Some),
PrimitiveType::Nat16 => {
let v = u16::from_le_bytes(bytes.try_into().unwrap());
seed.deserialize(v.into_deserializer()).map(Some)
}
PrimitiveType::Int16 => {
let v = i16::from_le_bytes(bytes.try_into().unwrap());
seed.deserialize(v.into_deserializer()).map(Some)
}
PrimitiveType::Nat32 => {
let v = u32::from_le_bytes(bytes.try_into().unwrap());
seed.deserialize(v.into_deserializer()).map(Some)
}
PrimitiveType::Int32 => {
let v = i32::from_le_bytes(bytes.try_into().unwrap());
seed.deserialize(v.into_deserializer()).map(Some)
}
PrimitiveType::Float32 => {
let v = f32::from_le_bytes(bytes.try_into().unwrap());
seed.deserialize(v.into_deserializer()).map(Some)
}
PrimitiveType::Nat64 => {
let v = u64::from_le_bytes(bytes.try_into().unwrap());
seed.deserialize(v.into_deserializer()).map(Some)
}
PrimitiveType::Int64 => {
let v = i64::from_le_bytes(bytes.try_into().unwrap());
seed.deserialize(v.into_deserializer()).map(Some)
}
PrimitiveType::Float64 => {
let v = f64::from_le_bytes(bytes.try_into().unwrap());
seed.deserialize(v.into_deserializer()).map(Some)
}
}
}

fn size_hint(&self) -> Option<usize> {
Some(self.remaining)
}
}

struct Compound<'a, 'de> {
de: &'a mut Deserializer<'de>,
style: Style,
Expand Down Expand Up @@ -1243,7 +1420,11 @@ impl<'de> de::SeqAccess<'de> for Compound<'_, 'de> {
*len -= 1;
self.de.expect_type = expect.clone();
self.de.wire_type = wire.clone();
if exact_primitive.is_none() {
#[cfg(feature = "bignum")]
let is_fast = exact_primitive.is_some() || self.de.bignum_vec_fast_path.is_some();
#[cfg(not(feature = "bignum"))]
let is_fast = exact_primitive.is_some();
if !is_fast {
self.de.add_cost(3)?;
}
seed.deserialize(&mut *self.de).map(Some)
Expand Down Expand Up @@ -1299,6 +1480,10 @@ impl Drop for Compound<'_, '_> {
// Reset fast-path state so it cannot leak if this Compound is dropped
// before all elements are consumed (e.g., on an error path).
self.de.primitive_vec_fast_path = None;
#[cfg(feature = "bignum")]
{
self.de.bignum_vec_fast_path = None;
}
}
}

Expand Down
Loading