diff --git a/mapf/Cargo.toml b/mapf/Cargo.toml index b14771c..6c145ec 100644 --- a/mapf/Cargo.toml +++ b/mapf/Cargo.toml @@ -32,3 +32,6 @@ slotmap = "1.0" [dev-dependencies] approx = "*" + +[features] +serde = ["nalgebra/serde-serialize", "time-point/serde"] \ No newline at end of file diff --git a/mapf/src/motion/se2/mod.rs b/mapf/src/motion/se2/mod.rs index a2cc0ba..a904c34 100644 --- a/mapf/src/motion/se2/mod.rs +++ b/mapf/src/motion/se2/mod.rs @@ -37,8 +37,8 @@ impl Velocity { } pub mod timed_position; -pub use timed_position::*; +pub use timed_position::*; pub mod space; pub use space::*; @@ -47,8 +47,78 @@ pub use oriented::*; pub type LinearTrajectorySE2 = super::Trajectory; +#[cfg(feature = "serde")] +use serde::de::{Deserializer, Error, SeqAccess, Visitor}; +#[cfg(feature = "serde")] +use serde::ser::{Serialize, SerializeSeq, Serializer}; + +#[cfg(feature = "serde")] +impl serde::Serialize for LinearTrajectorySE2 { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_seq(Some(self.len()))?; + for e in self.iter() { + seq.serialize_element(&e)?; + } + seq.end() + } +} + +#[cfg(feature = "serde")] +struct SE2TrajectoryVisitor; + +#[cfg(feature = "serde")] +impl<'de> Visitor<'de> for SE2TrajectoryVisitor { + type Value = LinearTrajectorySE2; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("SE2 key value sequence.") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let start = if let Some(wp) = seq.next_element::()? { + Some(wp) + } else { + None + }; + + let second = if let Some(wp) = seq.next_element::()? { + Some(wp) + } else { + None + }; + + if start.is_some() && second.is_some() { + if let Ok(mut traj) = Trajectory::new(start.unwrap(), second.unwrap()) { + while let Some(wp) = seq.next_element::()? { + traj.insert(wp); + } + return Ok(traj); + } + } + return Err(A::Error::custom("Trajectory needs at least 2 points")); + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::Deserialize<'de> for LinearTrajectorySE2 { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_seq(SE2TrajectoryVisitor) + } +} + pub mod quickest_path; pub use quickest_path::{QuickestPathHeuristic, QuickestPathPlanner}; pub mod differential_drive_line_follow; pub use differential_drive_line_follow::*; + +use super::Trajectory; diff --git a/mapf/src/motion/se2/timed_position.rs b/mapf/src/motion/se2/timed_position.rs index 9042a3a..73c0781 100644 --- a/mapf/src/motion/se2/timed_position.rs +++ b/mapf/src/motion/se2/timed_position.rs @@ -28,6 +28,7 @@ use crate::{ use arrayvec::ArrayVec; #[derive(Clone, Copy, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct WaypointSE2 { pub time: TimePoint, pub position: Position,