diff --git a/crates/bevy_ecs/src/relationship/related_methods.rs b/crates/bevy_ecs/src/relationship/related_methods.rs index fc6d1f1183cea..1983b6b37c11b 100644 --- a/crates/bevy_ecs/src/relationship/related_methods.rs +++ b/crates/bevy_ecs/src/relationship/related_methods.rs @@ -139,22 +139,26 @@ impl<'w> EntityWorldMut<'w> { return self; } - let Some(mut existing_relations) = self.get_mut::() else { + let Some(existing_relations) = self.get_mut::() else { return self.add_related::(related); }; - // We take the collection here so we can modify it without taking the component itself (this would create archetype move). + // We replace the component here with a dummy value so we can modify it without taking it (this would create archetype move). // SAFETY: We eventually return the correctly initialized collection into the target. - let mut existing_relations = mem::replace( - existing_relations.collection_mut_risky(), - Collection::::with_capacity(0), + let mut relations = mem::replace( + existing_relations.into_inner(), + ::RelationshipTarget::from_collection_risky( + Collection::::with_capacity(0), + ), ); + let collection = relations.collection_mut_risky(); + let mut potential_relations = EntityHashSet::from_iter(related.iter().copied()); let id = self.id(); self.world_scope(|world| { - for related in existing_relations.iter() { + for related in collection.iter() { if !potential_relations.remove(related) { world.entity_mut(related).remove::(); } @@ -169,11 +173,9 @@ impl<'w> EntityWorldMut<'w> { }); // SAFETY: The entities we're inserting will be the entities that were either already there or entities that we've just inserted. - existing_relations.clear(); - existing_relations.extend_from_iter(related.iter().copied()); - self.insert(R::RelationshipTarget::from_collection_risky( - existing_relations, - )); + collection.clear(); + collection.extend_from_iter(related.iter().copied()); + self.insert(relations); self } @@ -239,11 +241,20 @@ impl<'w> EntityWorldMut<'w> { assert_eq!(newly_related_entities, entities_to_relate, "`entities_to_relate` ({entities_to_relate:?}) didn't contain all entities that would end up related"); }; - if !self.contains::() { - self.add_related::(entities_to_relate); + match self.get_mut::() { + None => { + self.add_related::(entities_to_relate); - return self; - }; + return self; + } + Some(mut target) => { + // SAFETY: The invariants expected by this function mean we'll only be inserting entities that are already related. + let collection = target.collection_mut_risky(); + collection.clear(); + + collection.extend_from_iter(entities_to_relate.iter().copied()); + } + } let this = self.id(); self.world_scope(|world| { @@ -252,32 +263,13 @@ impl<'w> EntityWorldMut<'w> { } for new_relation in newly_related_entities { - // We're changing the target collection manually so don't run the insert hook + // We changed the target collection manually so don't run the insert hook world .entity_mut(*new_relation) .insert_with_relationship_hook_mode(R::from(this), RelationshipHookMode::Skip); } }); - if !entities_to_relate.is_empty() { - if let Some(mut target) = self.get_mut::() { - // SAFETY: The invariants expected by this function mean we'll only be inserting entities that are already related. - let collection = target.collection_mut_risky(); - collection.clear(); - - collection.extend_from_iter(entities_to_relate.iter().copied()); - } else { - let mut empty = - ::Collection::with_capacity( - entities_to_relate.len(), - ); - empty.extend_from_iter(entities_to_relate.iter().copied()); - - // SAFETY: We've just initialized this collection and we know there's no `RelationshipTarget` on `self` - self.insert(R::RelationshipTarget::from_collection_risky(empty)); - } - } - self } @@ -668,4 +660,61 @@ mod tests { assert_eq!(world.entity(b).get::(), None); assert_eq!(world.entity(c).get::(), None); } + + #[test] + fn replace_related_works() { + let mut world = World::new(); + let child1 = world.spawn_empty().id(); + let child2 = world.spawn_empty().id(); + let child3 = world.spawn_empty().id(); + + let mut parent = world.spawn_empty(); + parent.add_children(&[child1, child2]); + let child_value = ChildOf(parent.id()); + let some_child = Some(&child_value); + + parent.replace_children(&[child2, child3]); + let children = parent.get::().unwrap().collection(); + assert_eq!(children, &[child2, child3]); + assert_eq!(parent.world().get::(child1), None); + assert_eq!(parent.world().get::(child2), some_child); + assert_eq!(parent.world().get::(child3), some_child); + + parent.replace_children_with_difference(&[child3], &[child1, child2], &[child1]); + let children = parent.get::().unwrap().collection(); + assert_eq!(children, &[child1, child2]); + assert_eq!(parent.world().get::(child1), some_child); + assert_eq!(parent.world().get::(child2), some_child); + assert_eq!(parent.world().get::(child3), None); + } + + #[test] + fn replace_related_keeps_data() { + #[derive(Component)] + #[relationship(relationship_target = Parent)] + pub struct Child(Entity); + + #[derive(Component)] + #[relationship_target(relationship = Child)] + pub struct Parent { + #[relationship] + children: Vec, + pub data: u8, + } + + let mut world = World::new(); + let child1 = world.spawn_empty().id(); + let child2 = world.spawn_empty().id(); + let mut parent = world.spawn_empty(); + parent.add_related::(&[child1]); + parent.get_mut::().unwrap().data = 42; + + parent.replace_related_with_difference::(&[child1], &[child2], &[child2]); + let data = parent.get::().unwrap().data; + assert_eq!(data, 42); + + parent.replace_related::(&[child1]); + let data = parent.get::().unwrap().data; + assert_eq!(data, 42); + } }