Skip to content

Cathode crosser PR: Bug fixes for cdist and already matched particles #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions spine/data/out/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ def p(self):
def p(self, p):
pass

def unmatch(self):
"""
Unmatch the particle from its reco or truth particle match.
"""
self.match_ids = []
self.is_matched = False


@dataclass(eq=False)
@inherit_docstring(RecoBase, ParticleBase)
Expand Down
30 changes: 19 additions & 11 deletions spine/post/reco/cathode_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from spine.data import RecoInteraction, TruthInteraction

from spine.math.distance import cdist, farthest_pair
from scipy.spatial.distance import cdist as scipy_cdist

from spine.utils.globals import TRACK_SHP
from spine.utils.geo import Geometry
Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(self, crossing_point_tolerance, offset_tolerance,
keys['points'] = True
if run_mode != 'reco':
keys[truth_point_mode] = True

keys['meta'] = True #Needed to find shift in the cathode
self.update_keys(keys)

def process(self, data):
Expand All @@ -92,6 +93,8 @@ def process(self, data):
data : dict
Dictionary of data products
"""
#Get the drift pixel resolution
dx_res = data['meta'].size[0]
# Loop over particle types
update_dict = {}
for part_key in self.particle_keys:
Expand Down Expand Up @@ -132,7 +135,7 @@ def process(self, data):
if (part.is_cathode_crosser and self.adjust_crossers and
len(tpcs) == 2):
# Adjust positions
self.adjust_positions(data, i)
self.adjust_positions(data, i,dx_res)

# If we do not want to merge broken crossers, our job here is done
if not self.merge_crossers:
Expand Down Expand Up @@ -188,7 +191,7 @@ def process(self, data):
# Check if the two particles stop at roughly the same
# position in the plane of the cathode
compat = True
dist_mat = cdist(
dist_mat = scipy_cdist(
end_points_i[:, caxes], end_points_j[:, caxes])
argmin = np.argmin(dist_mat)
pair_i, pair_j = np.unravel_index(argmin, (2, 2))
Expand All @@ -210,7 +213,7 @@ def process(self, data):
# If compatible, merge
if compat:
# Merge particle and adjust positions
self.adjust_positions(data, ci, cj, truth=pi.is_truth)
self.adjust_positions(data, ci,dx_res, cj, truth=pi.is_truth)

# Update the candidate list to remove matched particle
candidate_ids[j:-1] = candidate_ids[j+1:] - 1
Expand Down Expand Up @@ -243,7 +246,7 @@ def process(self, data):

return update_dict

def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
def adjust_positions(self, data, idx_i,dx_res, idx_j=None, truth=False):
"""Given a cathode crosser (either in one or two pieces), apply the
necessary position offsets to match it at the cathode.

Expand All @@ -253,11 +256,12 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
Dictionary of data products
idx_i : int
Index of a cathode crosser (or a cathode crosser fragment)
dx_res : float
Drift pixel resolution [cm]. Offset the drift position by this amount.
idx_j : int, optional
Index of a matched cathode crosser fragment
truth : bool, default False
If True, adjust truth object positions

Results
-------
np.ndarray
Expand All @@ -270,6 +274,10 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
points_key = 'points' if not truth else self.truth_point_key
particles = data[part_key]
if idx_j is not None:
# Unmatch the particles from their interactions
particles[idx_i].unmatch()
particles[idx_j].unmatch()

# Merge particles
int_id_i = particles[idx_i].interaction_id
int_id_j = particles[idx_j].interaction_id
Expand Down Expand Up @@ -321,18 +329,18 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
continue

# Update the sister position and the main position tensor
self.get_points(sister)[tpc_index, daxis] -= offsets[i]
data[points_key][index, daxis] -= offsets[i]
self.get_points(sister)[tpc_index, daxis] -= offsets[i] + dx_res
data[points_key][index, daxis] -= offsets[i] + dx_res

# Update the start/end points appropriately
if sister.id == idx_i:
for attr, closest_tpc in closest_tpcs.items():
if closest_tpc == t:
getattr(sister, attr)[daxis] -= offsets[i]
getattr(sister, attr)[daxis] -= offsets[i] + dx_res

else:
sister.start_point[daxis] -= offsets[i]
sister.end_point[daxis] -= offsets[i]
sister.start_point[daxis] -= offsets[i] + dx_res
sister.end_point[daxis] -= offsets[i] + dx_res

# Store crosser information
particle.is_cathode_crosser = True
Expand Down
Loading