5
5
from spine .data import RecoInteraction , TruthInteraction
6
6
7
7
from spine .math .distance import cdist , farthest_pair
8
+ from scipy .spatial .distance import cdist as scipy_cdist
8
9
9
10
from spine .utils .globals import TRACK_SHP
10
11
from spine .utils .geo import Geometry
@@ -81,7 +82,7 @@ def __init__(self, crossing_point_tolerance, offset_tolerance,
81
82
keys ['points' ] = True
82
83
if run_mode != 'reco' :
83
84
keys [truth_point_mode ] = True
84
-
85
+ keys [ 'meta' ] = True #Needed to find shift in the cathode
85
86
self .update_keys (keys )
86
87
87
88
def process (self , data ):
@@ -92,6 +93,8 @@ def process(self, data):
92
93
data : dict
93
94
Dictionary of data products
94
95
"""
96
+ #Get the drift pixel resolution
97
+ dx_res = data ['meta' ].size [0 ]
95
98
# Loop over particle types
96
99
update_dict = {}
97
100
for part_key in self .particle_keys :
@@ -132,7 +135,7 @@ def process(self, data):
132
135
if (part .is_cathode_crosser and self .adjust_crossers and
133
136
len (tpcs ) == 2 ):
134
137
# Adjust positions
135
- self .adjust_positions (data , i )
138
+ self .adjust_positions (data , i , dx_res )
136
139
137
140
# If we do not want to merge broken crossers, our job here is done
138
141
if not self .merge_crossers :
@@ -188,7 +191,7 @@ def process(self, data):
188
191
# Check if the two particles stop at roughly the same
189
192
# position in the plane of the cathode
190
193
compat = True
191
- dist_mat = cdist (
194
+ dist_mat = scipy_cdist (
192
195
end_points_i [:, caxes ], end_points_j [:, caxes ])
193
196
argmin = np .argmin (dist_mat )
194
197
pair_i , pair_j = np .unravel_index (argmin , (2 , 2 ))
@@ -210,7 +213,7 @@ def process(self, data):
210
213
# If compatible, merge
211
214
if compat :
212
215
# Merge particle and adjust positions
213
- self .adjust_positions (data , ci , cj , truth = pi .is_truth )
216
+ self .adjust_positions (data , ci ,dx_res , cj , truth = pi .is_truth )
214
217
215
218
# Update the candidate list to remove matched particle
216
219
candidate_ids [j :- 1 ] = candidate_ids [j + 1 :] - 1
@@ -243,7 +246,7 @@ def process(self, data):
243
246
244
247
return update_dict
245
248
246
- def adjust_positions (self , data , idx_i , idx_j = None , truth = False ):
249
+ def adjust_positions (self , data , idx_i ,dx_res , idx_j = None , truth = False ):
247
250
"""Given a cathode crosser (either in one or two pieces), apply the
248
251
necessary position offsets to match it at the cathode.
249
252
@@ -253,11 +256,12 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
253
256
Dictionary of data products
254
257
idx_i : int
255
258
Index of a cathode crosser (or a cathode crosser fragment)
259
+ dx_res : float
260
+ Drift pixel resolution [cm]. Offset the drift position by this amount.
256
261
idx_j : int, optional
257
262
Index of a matched cathode crosser fragment
258
263
truth : bool, default False
259
264
If True, adjust truth object positions
260
-
261
265
Results
262
266
-------
263
267
np.ndarray
@@ -270,6 +274,10 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
270
274
points_key = 'points' if not truth else self .truth_point_key
271
275
particles = data [part_key ]
272
276
if idx_j is not None :
277
+ # Unmatch the particles from their interactions
278
+ particles [idx_i ].unmatch ()
279
+ particles [idx_j ].unmatch ()
280
+
273
281
# Merge particles
274
282
int_id_i = particles [idx_i ].interaction_id
275
283
int_id_j = particles [idx_j ].interaction_id
@@ -321,18 +329,18 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
321
329
continue
322
330
323
331
# Update the sister position and the main position tensor
324
- self .get_points (sister )[tpc_index , daxis ] -= offsets [i ]
325
- data [points_key ][index , daxis ] -= offsets [i ]
332
+ self .get_points (sister )[tpc_index , daxis ] -= offsets [i ] + dx_res
333
+ data [points_key ][index , daxis ] -= offsets [i ] + dx_res
326
334
327
335
# Update the start/end points appropriately
328
336
if sister .id == idx_i :
329
337
for attr , closest_tpc in closest_tpcs .items ():
330
338
if closest_tpc == t :
331
- getattr (sister , attr )[daxis ] -= offsets [i ]
339
+ getattr (sister , attr )[daxis ] -= offsets [i ] + dx_res
332
340
333
341
else :
334
- sister .start_point [daxis ] -= offsets [i ]
335
- sister .end_point [daxis ] -= offsets [i ]
342
+ sister .start_point [daxis ] -= offsets [i ] + dx_res
343
+ sister .end_point [daxis ] -= offsets [i ] + dx_res
336
344
337
345
# Store crosser information
338
346
particle .is_cathode_crosser = True
0 commit comments