diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index 341d9ae9..d84491f9 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -303,6 +303,8 @@ struct Deserializer<'de> { config: DecoderConfig, recursion_depth: crate::utils::RecursionDepth, primitive_vec_fast_path: Option, + #[cfg(feature = "bignum")] + bignum_vec_fast_path: Option, } impl<'de> Deserializer<'de> { @@ -322,6 +324,8 @@ impl<'de> Deserializer<'de> { config: config.clone(), recursion_depth: crate::utils::RecursionDepth::new(), primitive_vec_fast_path: None, + #[cfg(feature = "bignum")] + bignum_vec_fast_path: None, }) } fn dump_state(&self) -> String { @@ -442,10 +446,37 @@ impl<'de> Deserializer<'de> { where V: Visitor<'de>, { - self.unroll_type()?; - assert!(*self.expect_type == TypeInner::Int); - let mut bytes = vec![0u8]; + if self.bignum_vec_fast_path.is_none() { + self.unroll_type()?; + assert!(*self.expect_type == TypeInner::Int); + } let pos = self.input.position(); + if !self.is_untyped { + match self.wire_type.as_ref() { + TypeInner::Int => match leb128::read::signed(&mut self.input) { + Ok(value) => { + self.add_cost((self.input.position() - pos) as usize)?; + return visitor.visit_i64(value); + } + Err(leb128::read::Error::Overflow) => { + self.input.set_position(pos); + } + Err(e) => return Err(Error::msg(e)), + }, + TypeInner::Nat => match leb128::read::unsigned(&mut self.input) { + Ok(value) => { + self.add_cost((self.input.position() - pos) as usize)?; + return visitor.visit_u64(value); + } + Err(leb128::read::Error::Overflow) => { + self.input.set_position(pos); + } + Err(e) => return Err(Error::msg(e)), + }, + t => return Err(Error::subtype(format!("{t} cannot be deserialized to int"))), + } + } + let mut bytes = vec![0u8]; let int = match self.wire_type.as_ref() { TypeInner::Int => Int::decode(&mut self.input).map_err(Error::msg)?, TypeInner::Nat => Int(Nat::decode(&mut self.input).map_err(Error::msg)?.0.into()), @@ -461,13 +492,27 @@ impl<'de> Deserializer<'de> { where V: Visitor<'de>, { - self.unroll_type()?; - check!( - *self.expect_type == TypeInner::Nat && *self.wire_type == TypeInner::Nat, - "nat" - ); - let mut bytes = vec![1u8]; + if self.bignum_vec_fast_path.is_none() { + self.unroll_type()?; + check!( + *self.expect_type == TypeInner::Nat && *self.wire_type == TypeInner::Nat, + "nat" + ); + } let pos = self.input.position(); + if !self.is_untyped { + match leb128::read::unsigned(&mut self.input) { + Ok(value) => { + self.add_cost((self.input.position() - pos) as usize)?; + return visitor.visit_u64(value); + } + Err(leb128::read::Error::Overflow) => { + self.input.set_position(pos); + } + Err(e) => return Err(Error::msg(e)), + } + } + let mut bytes = vec![1u8]; let nat = Nat::decode(&mut self.input).map_err(Error::msg)?; self.add_cost((self.input.position() - pos) as usize)?; bytes.extend_from_slice(&nat.0.to_bytes_le()); @@ -642,6 +687,14 @@ fn primitive_byte_cost(p: PrimitiveType) -> usize { } } +#[cfg(feature = "bignum")] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +enum BigNumFastPath { + Nat, + Int, + NatAsInt, +} + fn exact_primitive_type(expect: &Type, wire: &Type) -> Option { match (expect.as_ref(), wire.as_ref()) { (TypeInner::Bool, TypeInner::Bool) => Some(PrimitiveType::Bool), @@ -687,6 +740,13 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { if self.field_name.is_some() { return self.deserialize_identifier(visitor); } + #[cfg(feature = "bignum")] + if let Some(fast) = self.bignum_vec_fast_path { + return match fast { + BigNumFastPath::Nat => self.deserialize_nat(visitor), + BigNumFastPath::Int | BigNumFastPath::NatAsInt => self.deserialize_int(visitor), + }; + } self.unroll_type()?; match self.expect_type.as_ref() { #[cfg(feature = "bignum")] @@ -870,16 +930,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { where V: Visitor<'de>, { - self.unroll_type()?; - check!( - *self.expect_type == TypeInner::Text && *self.wire_type == TypeInner::Text, - "text" - ); - let len = Len::read(&mut self.input)?.0; - self.add_cost(len.saturating_add(1))?; - let bytes = self.borrow_bytes(len)?.to_owned(); - let value = String::from_utf8(bytes).map_err(Error::msg)?; - visitor.visit_string(value) + self.deserialize_str(visitor) } fn deserialize_str(self, visitor: V) -> Result where @@ -957,7 +1008,61 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { len.checked_mul(per_element_cost) .ok_or_else(|| Error::msg("Vec length overflow"))?, )?; - self.primitive_vec_fast_path = exact_primitive; + + #[cfg(target_endian = "little")] + { + let byte_size = primitive_byte_cost(prim); + let total_bytes = len + .checked_mul(byte_size) + .ok_or_else(|| Error::msg("Vec byte length overflow"))?; + let pos = self.input.position() as usize; + let slice = self.input.get_ref(); + if pos + total_bytes > slice.len() { + return Err(Error::msg(format!( + "Not enough bytes for primitive vec: need {total_bytes}, have {}", + slice.len() - pos + ))); + } + let data = &slice[pos..pos + total_bytes]; + let mut access = PrimitiveVecAccess { + data, + offset: 0, + remaining: len, + element_size: byte_size, + prim, + }; + let result = visitor.visit_seq(&mut access); + // Advance by bytes actually consumed, not total_bytes, so + // the cursor is correct if the visitor short-circuits. + self.input.set_position((pos + access.offset) as u64); + return result; + } + + #[cfg(not(target_endian = "little"))] + { + self.primitive_vec_fast_path = exact_primitive; + } + } + #[cfg(feature = "bignum")] + let bignum_fast = if exact_primitive.is_none() { + match (expect.as_ref(), wire.as_ref()) { + (TypeInner::Nat, TypeInner::Nat) => Some(BigNumFastPath::Nat), + (TypeInner::Int, TypeInner::Int) => Some(BigNumFastPath::Int), + (TypeInner::Int, TypeInner::Nat) => Some(BigNumFastPath::NatAsInt), + _ => None, + } + } else { + None + }; + #[cfg(feature = "bignum")] + if let Some(fast) = bignum_fast { + self.add_cost( + len.checked_mul(3) + .ok_or_else(|| Error::msg("Vec length overflow"))?, + )?; + self.bignum_vec_fast_path = Some(fast); + self.expect_type = expect.clone(); + self.wire_type = wire.clone(); } let result = visitor.visit_seq(Compound::new( self, @@ -968,9 +1073,6 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { exact_primitive, }, )); - if exact_primitive.is_some() { - self.primitive_vec_fast_path = None; - } result } (TypeInner::Record(_), TypeInner::Record(_)) => { @@ -1192,6 +1294,81 @@ enum Style { }, } +#[cfg(target_endian = "little")] +struct PrimitiveVecAccess<'de> { + data: &'de [u8], + offset: usize, + remaining: usize, + element_size: usize, + prim: PrimitiveType, +} + +#[cfg(target_endian = "little")] +impl<'de> de::SeqAccess<'de> for PrimitiveVecAccess<'de> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: de::DeserializeSeed<'de>, + { + use serde::de::IntoDeserializer; + if self.remaining == 0 { + return Ok(None); + } + self.remaining -= 1; + let bytes = &self.data[self.offset..self.offset + self.element_size]; + self.offset += self.element_size; + + match self.prim { + PrimitiveType::Bool => match bytes[0] { + 0 => seed.deserialize(false.into_deserializer()).map(Some), + 1 => seed.deserialize(true.into_deserializer()).map(Some), + _ => Err(Error::msg("Expect 00 or 01")), + }, + PrimitiveType::Nat8 => seed.deserialize(bytes[0].into_deserializer()).map(Some), + PrimitiveType::Int8 => seed + .deserialize((bytes[0] as i8).into_deserializer()) + .map(Some), + PrimitiveType::Nat16 => { + let v = u16::from_le_bytes(bytes.try_into().unwrap()); + seed.deserialize(v.into_deserializer()).map(Some) + } + PrimitiveType::Int16 => { + let v = i16::from_le_bytes(bytes.try_into().unwrap()); + seed.deserialize(v.into_deserializer()).map(Some) + } + PrimitiveType::Nat32 => { + let v = u32::from_le_bytes(bytes.try_into().unwrap()); + seed.deserialize(v.into_deserializer()).map(Some) + } + PrimitiveType::Int32 => { + let v = i32::from_le_bytes(bytes.try_into().unwrap()); + seed.deserialize(v.into_deserializer()).map(Some) + } + PrimitiveType::Float32 => { + let v = f32::from_le_bytes(bytes.try_into().unwrap()); + seed.deserialize(v.into_deserializer()).map(Some) + } + PrimitiveType::Nat64 => { + let v = u64::from_le_bytes(bytes.try_into().unwrap()); + seed.deserialize(v.into_deserializer()).map(Some) + } + PrimitiveType::Int64 => { + let v = i64::from_le_bytes(bytes.try_into().unwrap()); + seed.deserialize(v.into_deserializer()).map(Some) + } + PrimitiveType::Float64 => { + let v = f64::from_le_bytes(bytes.try_into().unwrap()); + seed.deserialize(v.into_deserializer()).map(Some) + } + } + } + + fn size_hint(&self) -> Option { + Some(self.remaining) + } +} + struct Compound<'a, 'de> { de: &'a mut Deserializer<'de>, style: Style, @@ -1243,7 +1420,11 @@ impl<'de> de::SeqAccess<'de> for Compound<'_, 'de> { *len -= 1; self.de.expect_type = expect.clone(); self.de.wire_type = wire.clone(); - if exact_primitive.is_none() { + #[cfg(feature = "bignum")] + let is_fast = exact_primitive.is_some() || self.de.bignum_vec_fast_path.is_some(); + #[cfg(not(feature = "bignum"))] + let is_fast = exact_primitive.is_some(); + if !is_fast { self.de.add_cost(3)?; } seed.deserialize(&mut *self.de).map(Some) @@ -1299,6 +1480,10 @@ impl Drop for Compound<'_, '_> { // Reset fast-path state so it cannot leak if this Compound is dropped // before all elements are consumed (e.g., on an error path). self.de.primitive_vec_fast_path = None; + #[cfg(feature = "bignum")] + { + self.de.bignum_vec_fast_path = None; + } } }