diff --git a/spine/data/out/particle.py b/spine/data/out/particle.py index 4169697d..b81fd786 100644 --- a/spine/data/out/particle.py +++ b/spine/data/out/particle.py @@ -196,6 +196,13 @@ def p(self): def p(self, p): pass + def unmatch(self): + """ + Unmatch the particle from its reco or truth particle match. + """ + self.match_ids = [] + self.is_matched = False + @dataclass(eq=False) @inherit_docstring(RecoBase, ParticleBase) diff --git a/spine/post/reco/cathode_cross.py b/spine/post/reco/cathode_cross.py index c0d3c2e3..4536ec89 100644 --- a/spine/post/reco/cathode_cross.py +++ b/spine/post/reco/cathode_cross.py @@ -5,6 +5,7 @@ from spine.data import RecoInteraction, TruthInteraction from spine.math.distance import cdist, farthest_pair +from scipy.spatial.distance import cdist as scipy_cdist from spine.utils.globals import TRACK_SHP from spine.utils.geo import Geometry @@ -81,7 +82,7 @@ def __init__(self, crossing_point_tolerance, offset_tolerance, keys['points'] = True if run_mode != 'reco': keys[truth_point_mode] = True - + keys['meta'] = True #Needed to find shift in the cathode self.update_keys(keys) def process(self, data): @@ -92,6 +93,8 @@ def process(self, data): data : dict Dictionary of data products """ + #Get the drift pixel resolution + dx_res = data['meta'].size[0] # Loop over particle types update_dict = {} for part_key in self.particle_keys: @@ -132,7 +135,7 @@ def process(self, data): if (part.is_cathode_crosser and self.adjust_crossers and len(tpcs) == 2): # Adjust positions - self.adjust_positions(data, i) + self.adjust_positions(data, i,dx_res) # If we do not want to merge broken crossers, our job here is done if not self.merge_crossers: @@ -188,7 +191,7 @@ def process(self, data): # Check if the two particles stop at roughly the same # position in the plane of the cathode compat = True - dist_mat = cdist( + dist_mat = scipy_cdist( end_points_i[:, caxes], end_points_j[:, caxes]) argmin = np.argmin(dist_mat) pair_i, pair_j = np.unravel_index(argmin, (2, 2)) @@ -210,7 +213,7 @@ def process(self, data): # If compatible, merge if compat: # Merge particle and adjust positions - self.adjust_positions(data, ci, cj, truth=pi.is_truth) + self.adjust_positions(data, ci,dx_res, cj, truth=pi.is_truth) # Update the candidate list to remove matched particle candidate_ids[j:-1] = candidate_ids[j+1:] - 1 @@ -243,7 +246,7 @@ def process(self, data): return update_dict - def adjust_positions(self, data, idx_i, idx_j=None, truth=False): + def adjust_positions(self, data, idx_i,dx_res, idx_j=None, truth=False): """Given a cathode crosser (either in one or two pieces), apply the necessary position offsets to match it at the cathode. @@ -253,11 +256,12 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False): Dictionary of data products idx_i : int Index of a cathode crosser (or a cathode crosser fragment) + dx_res : float + Drift pixel resolution [cm]. Offset the drift position by this amount. idx_j : int, optional Index of a matched cathode crosser fragment truth : bool, default False If True, adjust truth object positions - Results ------- np.ndarray @@ -270,6 +274,10 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False): points_key = 'points' if not truth else self.truth_point_key particles = data[part_key] if idx_j is not None: + # Unmatch the particles from their interactions + particles[idx_i].unmatch() + particles[idx_j].unmatch() + # Merge particles int_id_i = particles[idx_i].interaction_id int_id_j = particles[idx_j].interaction_id @@ -321,18 +329,18 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False): continue # Update the sister position and the main position tensor - self.get_points(sister)[tpc_index, daxis] -= offsets[i] - data[points_key][index, daxis] -= offsets[i] + self.get_points(sister)[tpc_index, daxis] -= offsets[i] + dx_res + data[points_key][index, daxis] -= offsets[i] + dx_res # Update the start/end points appropriately if sister.id == idx_i: for attr, closest_tpc in closest_tpcs.items(): if closest_tpc == t: - getattr(sister, attr)[daxis] -= offsets[i] + getattr(sister, attr)[daxis] -= offsets[i] + dx_res else: - sister.start_point[daxis] -= offsets[i] - sister.end_point[daxis] -= offsets[i] + sister.start_point[daxis] -= offsets[i] + dx_res + sister.end_point[daxis] -= offsets[i] + dx_res # Store crosser information particle.is_cathode_crosser = True