Skip to content

Commit e495c00

Browse files
Merge pull request #18 from bear-is-asleep/feature/cathode_crosser
Cathode crosser PR: Bug fixes for cdist and already matched particles
2 parents 682c8c5 + 6e23ba1 commit e495c00

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

spine/data/out/particle.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ def p(self):
196196
def p(self, p):
197197
pass
198198

199+
def unmatch(self):
200+
"""
201+
Unmatch the particle from its reco or truth particle match.
202+
"""
203+
self.match_ids = []
204+
self.is_matched = False
205+
199206

200207
@dataclass(eq=False)
201208
@inherit_docstring(RecoBase, ParticleBase)

spine/post/reco/cathode_cross.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from spine.data import RecoInteraction, TruthInteraction
66

77
from spine.math.distance import cdist, farthest_pair
8+
from scipy.spatial.distance import cdist as scipy_cdist
89

910
from spine.utils.globals import TRACK_SHP
1011
from spine.utils.geo import Geometry
@@ -81,7 +82,7 @@ def __init__(self, crossing_point_tolerance, offset_tolerance,
8182
keys['points'] = True
8283
if run_mode != 'reco':
8384
keys[truth_point_mode] = True
84-
85+
keys['meta'] = True #Needed to find shift in the cathode
8586
self.update_keys(keys)
8687

8788
def process(self, data):
@@ -92,6 +93,8 @@ def process(self, data):
9293
data : dict
9394
Dictionary of data products
9495
"""
96+
#Get the drift pixel resolution
97+
dx_res = data['meta'].size[0]
9598
# Loop over particle types
9699
update_dict = {}
97100
for part_key in self.particle_keys:
@@ -132,7 +135,7 @@ def process(self, data):
132135
if (part.is_cathode_crosser and self.adjust_crossers and
133136
len(tpcs) == 2):
134137
# Adjust positions
135-
self.adjust_positions(data, i)
138+
self.adjust_positions(data, i,dx_res)
136139

137140
# If we do not want to merge broken crossers, our job here is done
138141
if not self.merge_crossers:
@@ -188,7 +191,7 @@ def process(self, data):
188191
# Check if the two particles stop at roughly the same
189192
# position in the plane of the cathode
190193
compat = True
191-
dist_mat = cdist(
194+
dist_mat = scipy_cdist(
192195
end_points_i[:, caxes], end_points_j[:, caxes])
193196
argmin = np.argmin(dist_mat)
194197
pair_i, pair_j = np.unravel_index(argmin, (2, 2))
@@ -210,7 +213,7 @@ def process(self, data):
210213
# If compatible, merge
211214
if compat:
212215
# Merge particle and adjust positions
213-
self.adjust_positions(data, ci, cj, truth=pi.is_truth)
216+
self.adjust_positions(data, ci,dx_res, cj, truth=pi.is_truth)
214217

215218
# Update the candidate list to remove matched particle
216219
candidate_ids[j:-1] = candidate_ids[j+1:] - 1
@@ -243,7 +246,7 @@ def process(self, data):
243246

244247
return update_dict
245248

246-
def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
249+
def adjust_positions(self, data, idx_i,dx_res, idx_j=None, truth=False):
247250
"""Given a cathode crosser (either in one or two pieces), apply the
248251
necessary position offsets to match it at the cathode.
249252
@@ -253,11 +256,12 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
253256
Dictionary of data products
254257
idx_i : int
255258
Index of a cathode crosser (or a cathode crosser fragment)
259+
dx_res : float
260+
Drift pixel resolution [cm]. Offset the drift position by this amount.
256261
idx_j : int, optional
257262
Index of a matched cathode crosser fragment
258263
truth : bool, default False
259264
If True, adjust truth object positions
260-
261265
Results
262266
-------
263267
np.ndarray
@@ -270,6 +274,10 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
270274
points_key = 'points' if not truth else self.truth_point_key
271275
particles = data[part_key]
272276
if idx_j is not None:
277+
# Unmatch the particles from their interactions
278+
particles[idx_i].unmatch()
279+
particles[idx_j].unmatch()
280+
273281
# Merge particles
274282
int_id_i = particles[idx_i].interaction_id
275283
int_id_j = particles[idx_j].interaction_id
@@ -321,18 +329,18 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
321329
continue
322330

323331
# Update the sister position and the main position tensor
324-
self.get_points(sister)[tpc_index, daxis] -= offsets[i]
325-
data[points_key][index, daxis] -= offsets[i]
332+
self.get_points(sister)[tpc_index, daxis] -= offsets[i] + dx_res
333+
data[points_key][index, daxis] -= offsets[i] + dx_res
326334

327335
# Update the start/end points appropriately
328336
if sister.id == idx_i:
329337
for attr, closest_tpc in closest_tpcs.items():
330338
if closest_tpc == t:
331-
getattr(sister, attr)[daxis] -= offsets[i]
339+
getattr(sister, attr)[daxis] -= offsets[i] + dx_res
332340

333341
else:
334-
sister.start_point[daxis] -= offsets[i]
335-
sister.end_point[daxis] -= offsets[i]
342+
sister.start_point[daxis] -= offsets[i] + dx_res
343+
sister.end_point[daxis] -= offsets[i] + dx_res
336344

337345
# Store crosser information
338346
particle.is_cathode_crosser = True

0 commit comments

Comments
 (0)