Skip to content

Commit ff02db1

Browse files
authored
Merge pull request #434 from lsst/tickets/DM-51266
DM-51266: Adapt C++ glint finding code to python
2 parents 4a8f988 + 2e50151 commit ff02db1

File tree

2 files changed

+518
-0
lines changed

2 files changed

+518
-0
lines changed
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
# This file is part of meas_algorithms.
2+
#
3+
# Developed for the LSST Data Management System.
4+
# This product includes software developed by the LSST Project
5+
# (https://www.lsst.org).
6+
# See the COPYRIGHT file at the top-level directory of this distribution
7+
# for details of code ownership.
8+
#
9+
# This program is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 3 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# This program is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
21+
22+
__all__ = ["FindGlintTrailsConfig", "FindGlintTrailsTask", "GlintTrailParameters"]
23+
24+
import collections
25+
import dataclasses
26+
import math
27+
28+
import numpy as np
29+
import scipy.spatial
30+
import sklearn.linear_model
31+
32+
import lsst.afw.table
33+
import lsst.pex.config
34+
import lsst.pipe.base
35+
36+
37+
class FindGlintTrailsConfig(lsst.pex.config.Config):
38+
radius = lsst.pex.config.Field(
39+
doc="Radius to search for glint trail candidates from each source (pixels).",
40+
dtype=float,
41+
default=500,
42+
)
43+
min_points = lsst.pex.config.Field(
44+
doc="Minimum number of points to be considered a possible glint trail.",
45+
dtype=int,
46+
default=5,
47+
check=lambda x: x >= 3,
48+
)
49+
threshold = lsst.pex.config.Field(
50+
doc="Maximum root mean squared deviation from a straight line (pixels).",
51+
dtype=float,
52+
default=15.0,
53+
)
54+
seed = lsst.pex.config.Field(
55+
doc="Random seed for RANSAC fitter, to ensure stable fitting.",
56+
dtype=int,
57+
default=42,
58+
)
59+
bad_flags = lsst.pex.config.ListField[str](
60+
doc="Do not fit sources that have these flags set.",
61+
default=["ip_diffim_DipoleFit_classification",
62+
"is_negative",
63+
],
64+
)
65+
66+
67+
@dataclasses.dataclass(frozen=True, kw_only=True)
68+
class GlintTrailParameters:
69+
"""Holds values from the line fit to a single glint trail."""
70+
slope: float
71+
intercept: float
72+
stderr: float
73+
length: float # pixels
74+
angle: float # radians, from +X axis
75+
76+
77+
class FindGlintTrailsTask(lsst.pipe.base.Task):
78+
"""Find glint trails in a catalog by searching for sources that lie in a
79+
line.
80+
81+
Notes
82+
-----
83+
For each source ("anchor") in the input catalog that was not included in
84+
an an earlier iteration as part of a trail:
85+
* Find all sources within a given radius.
86+
* For each pair of anchor and match, identify the other sources that
87+
could lie on the same line(s).
88+
* Take the longest set of such pairs as a candidate trail.
89+
* Fit a line to the identified pairs with the RANSAC algorithm.
90+
* Find all sources in the catalog that could lie on that line.
91+
* Refit a line to all of the matched sources.
92+
* If the error is below the threshold and the number of sources on the
93+
line is greater than the minimum, return the sources that were
94+
considered inliers during the fit, and the fit parameters.
95+
"""
96+
97+
ConfigClass = FindGlintTrailsConfig
98+
_DefaultName = "findGlintTrails"
99+
100+
def run(self, catalog):
101+
"""Find glint trails in a catalog.
102+
103+
Parameters
104+
----------
105+
catalog : `lsst.afw.table.SourceCatalog`
106+
Catalog to search for glint trails.
107+
108+
Returns
109+
-------
110+
result : `lsst.pipe.base.Struct`
111+
Results as a struct with attributes:
112+
113+
``trails``
114+
Catalog subsets containing sources in each trail that was found.
115+
(`list` [`lsst.afw.table.SourceCatalog`])
116+
``trailed_ids``
117+
Ids of all the sources that were included in any fit trail.
118+
(`set` [`int`])
119+
``parameters``
120+
Parameters of all the trails that were found.
121+
(`list` [`GlintTrailParameters`])
122+
"""
123+
good_catalog = self._select_good_sources(catalog)
124+
125+
matches = lsst.afw.table.matchXy(good_catalog, self.config.radius)
126+
per_id = collections.defaultdict(list)
127+
for match in matches:
128+
per_id[match.first["id"]].append(match)
129+
counts = {id: len(value) for id, value in per_id.items()}
130+
131+
trails = []
132+
parameters = []
133+
trailed_ids = set()
134+
# Search starting with the source with the largest number of matches.
135+
for id in dict(sorted(counts.items(), key=lambda item: item[1], reverse=True)):
136+
# Don't search this point if it was already included in a trail.
137+
if counts[id] < self.config.min_points or id in trailed_ids:
138+
continue
139+
140+
self.log.debug("id=%d at %.1f,%.1f has %d matches within %d pixels.",
141+
id,
142+
per_id[id][0].first.getX(),
143+
per_id[id][0].first.getY(),
144+
counts[id],
145+
self.config.radius)
146+
if (trail := self._search_one(per_id[id], good_catalog)) is not None:
147+
trail, result = trail
148+
# Check that we didn't already find this trail.
149+
n_new = len(set(trail["id"]).difference(trailed_ids))
150+
if n_new > 0:
151+
self.log.info("Found %.1f pixel length trail with %d points, "
152+
"%d not in any other trail (slope=%.4f, intercept=%.2f)",
153+
result.length, len(trail), n_new, result.slope, result.intercept)
154+
trails.append(trail)
155+
trailed_ids.update(trail["id"])
156+
parameters.append(result)
157+
158+
self.log.info("Found %d glint trails containing %d total sources.",
159+
len(trails), len(trailed_ids))
160+
return lsst.pipe.base.Struct(trails=trails,
161+
trailed_ids=trailed_ids,
162+
parameters=parameters)
163+
164+
def _select_good_sources(self, catalog):
165+
"""Return sources that could possibly be in a glint trail, i.e. ones
166+
that do not have bad flags set.
167+
168+
Parameters
169+
----------
170+
catalog : `lsst.afw.table.SourceCatalog`
171+
Original catalog to be selected from.
172+
173+
Returns
174+
-------
175+
good_catalog : `lsst.afw.table.SourceCatalog`
176+
Catalog that has had bad sources removed.
177+
"""
178+
bad = np.zeros(len(catalog), dtype=bool)
179+
for flag in self.config.bad_flags:
180+
bad |= catalog[flag]
181+
return catalog[~bad]
182+
183+
def _search_one(self, matches, catalog):
184+
"""Search one set of matches for a possible trail.
185+
186+
Parameters
187+
----------
188+
matches : `list` [`lsst.afw.table.Match`]
189+
Matches for one anchor source to search for lines.
190+
catalog : `lsst.afw.SourceCatalog`
191+
Catalog of all sources, to refit lines to.
192+
193+
Returns
194+
-------
195+
trail, result : `tuple` or None
196+
If the no trails matching the criteria are found, return None,
197+
otherwise return a tuple of the sources in the trail and the
198+
trail parameters.
199+
"""
200+
components = collections.defaultdict(list)
201+
# Normalized distances from the first record to all the others.
202+
xy_deltas = {pair.second["id"]: (pair.second.getX() - pair.first.getX(),
203+
pair.second.getY() - pair.first.getY()) for pair in matches}
204+
205+
# Find all sets of pairs from this anchor that could lie on a line.
206+
for i, (id1, pair1) in enumerate(xy_deltas.items()):
207+
distance = math.sqrt(pair1[0]**2 + pair1[1]**2)
208+
for j, (id2, pair2) in enumerate(xy_deltas.items()):
209+
if i == j:
210+
continue
211+
delta = abs(pair1[0] * pair2[1] - pair1[1] * pair2[0])
212+
# 2x threshold to search more broadly; will be refined later.
213+
if delta / distance < 2 * self.config.threshold:
214+
components[i].append(j)
215+
216+
# There are no lines with at least 3 components.
217+
if len(components) == 0:
218+
return None
219+
220+
longest, value = max(components.items(), key=lambda x: len(x[1]))
221+
n_points = len(value)
222+
n_points += 2 # to account for the base source and the first pair
223+
if n_points < self.config.min_points:
224+
return None
225+
226+
candidate = [longest] + components[longest]
227+
trail, result = self._other_points(n_points, candidate, matches, catalog)
228+
229+
if trail is None or len(trail) < self.config.min_points:
230+
return None
231+
if result.stderr > self.config.threshold:
232+
self.log.info("Candidate trail with %d sources rejected with stderr %.6f > %.3f",
233+
len(trail), result.stderr, self.config.threshold)
234+
return None
235+
else:
236+
return trail, result
237+
238+
def _other_points(self, n_points, indexes, matches, catalog):
239+
"""Find all catalog records that could lie on this line.
240+
241+
Parameters
242+
----------
243+
n_points : `int`
244+
Number of sources in this candidate trail.
245+
indexes : `list` [`int`]
246+
Indexes into matches on this candidate trail.
247+
matches : `list` [`lsst.afw.table.Match`]
248+
Matches for one anchor sources to search for lines.
249+
catalog : `lsst.afw.SourceCatalog`
250+
Catalog of all sources, to refit lines to.
251+
252+
Returns
253+
-------
254+
trail : `lsst.afw.table.SourceCatalog`
255+
Sources that are in the fitted trail.
256+
result : `GlintTrailParameters`
257+
Parameters of the fitted trail.
258+
"""
259+
260+
def extract(fitter, x, y, prefix=""):
261+
"""Extract values from the fit and log and return them."""
262+
x = x[fitter.inlier_mask_]
263+
y = y[fitter.inlier_mask_]
264+
predicted = fitter.predict(x).flatten()
265+
stderr = math.sqrt(((predicted - y.flatten())**2).sum())
266+
m, b = fitter.estimator_.coef_[0][0], fitter.estimator_.intercept_[0]
267+
self.log.debug("%s fit: score=%.6f, stderr=%.6f, inliers/total=%d/%d",
268+
prefix, fitter.score(x, y), stderr, sum(fitter.inlier_mask_), len(x))
269+
# Simple O(N^2) search for longest distance; there will never be
270+
# enough points in a trail a for "faster" approach to be worth it.
271+
length = max(scipy.spatial.distance.pdist(np.hstack((x, y))))
272+
angle = math.atan(m)
273+
return GlintTrailParameters(slope=m, intercept=b, stderr=stderr, length=length, angle=angle)
274+
275+
# min_samples=2 is necessary here for some sets of only 5 matches,
276+
# otherwise we sometimes get "UndefinedMetricWarning: R^2 score is not
277+
# well-defined with less than two samples" from RANSAC.
278+
fitter = sklearn.linear_model.RANSACRegressor(residual_threshold=self.config.threshold,
279+
loss="squared_error",
280+
random_state=self.config.seed,
281+
min_samples=2)
282+
283+
# The (-1,1) shape is to keep sklearn happy.
284+
x = np.empty(n_points).reshape(-1, 1)
285+
x[0] = matches[0].first.getX()
286+
x[1:, 0] = [matches[i].second.getX() for i in indexes]
287+
y = np.empty(n_points).reshape(-1, 1)
288+
y[0] = matches[0].first.getY()
289+
y[1:, 0] = [matches[i].second.getY() for i in indexes]
290+
291+
fitter.fit(x, y)
292+
result = extract(fitter, x, y, prefix="preliminary")
293+
# Reject trails that have too many outliers after the first fit.
294+
if (n_inliers := sum(fitter.inlier_mask_)) < self.config.min_points:
295+
self.log.debug("Candidate trail rejected with %d < %d points.", n_inliers, self.config.min_points)
296+
return None, None
297+
298+
# Find all points that are close to this line and refit with them.
299+
x = catalog["slot_Centroid_x"]
300+
y = catalog["slot_Centroid_y"]
301+
dist = abs(result.intercept + result.slope * x - y) / math.sqrt(1 + result.slope**2)
302+
# 2x threshold to search more broadly: outlier rejection may change
303+
# the line parameters some and we want to grab all candidates here.
304+
candidates = (dist < 2 * self.config.threshold).flatten()
305+
# min_samples>2 should make the fit more stable.
306+
fitter = sklearn.linear_model.RANSACRegressor(residual_threshold=self.config.threshold,
307+
loss="squared_error",
308+
random_state=self.config.seed,
309+
min_samples=3)
310+
# The (-1,1) shape is to keep sklearn happy.
311+
x = x[candidates].reshape(-1, 1)
312+
y = y[candidates].reshape(-1, 1)
313+
fitter.fit(x, y)
314+
result = extract(fitter, x, y, prefix="final")
315+
316+
return catalog[candidates][fitter.inlier_mask_], result

0 commit comments

Comments
 (0)