Add Multiple Model Framework#1257
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1257 +/- ##
==========================================
+ Coverage 91.92% 94.45% +2.53%
==========================================
Files 233 245 +12
Lines 16843 17053 +210
Branches 2372 2397 +25
==========================================
+ Hits 15483 16108 +625
+ Misses 1006 636 -370
+ Partials 354 309 -45
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
sdhiscocks
left a comment
There was a problem hiding this comment.
Thanks @jswright-dstl
Doc strings are required for new classes.
It would be useful to have a brief explanation/summary for each algorithm in the tutorial, and possibly a metric for performance comparison.
| class Augmentor(Base): | ||
| transition_probabilities: TransitionMatrix = Property(doc="TPM") | ||
| transition_models: Sequence[TransitionModel] = Property(doc="List of transition models") | ||
| histories: int = Property(doc="Depth of history to be stored") |
There was a problem hiding this comment.
Could do with complete doc strings
| from ..augmentor.base import Augmentor | ||
|
|
||
|
|
||
| class IdentityAugmentor(Augmentor): |
| try: | ||
| (isinstance(states[0][0], WeightedGaussianState)) | ||
| except TypeError or IndexError: | ||
| print(len(states)) | ||
|
|
There was a problem hiding this comment.
Legacy debug?
| try: | |
| (isinstance(states[0][0], WeightedGaussianState)) | |
| except TypeError or IndexError: | |
| print(len(states)) |
| if isinstance(states[0], ExpandedModelAugmentedWeightedGaussianState): | ||
| pass | ||
| else: | ||
| Likelihood_j = [] | ||
| for i in range(len(states)): | ||
| Likelihood_j.append(mvn.pdf( | ||
| states[i].hypothesis.measurement.state_vector.T, | ||
| states[i].hypothesis.measurement_prediction.mean.ravel(), | ||
| states[i].hypothesis.measurement_prediction.covar)) | ||
| c_j = self.transition_probabilities[states[0]].T @ states.weights | ||
| weights = Likelihood_j * c_j.ravel() | ||
| weights = weights / np.sum(weights) | ||
| for i in range(len(states)): | ||
| states[i].weight = weights[i] |
There was a problem hiding this comment.
| if isinstance(states[0], ExpandedModelAugmentedWeightedGaussianState): | |
| pass | |
| else: | |
| Likelihood_j = [] | |
| for i in range(len(states)): | |
| Likelihood_j.append(mvn.pdf( | |
| states[i].hypothesis.measurement.state_vector.T, | |
| states[i].hypothesis.measurement_prediction.mean.ravel(), | |
| states[i].hypothesis.measurement_prediction.covar)) | |
| c_j = self.transition_probabilities[states[0]].T @ states.weights | |
| weights = Likelihood_j * c_j.ravel() | |
| weights = weights / np.sum(weights) | |
| for i in range(len(states)): | |
| states[i].weight = weights[i] | |
| if not isinstance(states[0], ExpandedModelAugmentedWeightedGaussianState): | |
| Likelihood_j = [] | |
| for i in range(len(states)): | |
| Likelihood_j.append(mvn.pdf( | |
| states[i].hypothesis.measurement.state_vector.T, | |
| states[i].hypothesis.measurement_prediction.mean.ravel(), | |
| states[i].hypothesis.measurement_prediction.covar)) | |
| c_j = self.transition_probabilities[states[0]].T @ states.weights | |
| weights = Likelihood_j * c_j.ravel() | |
| weights = weights / np.sum(weights) | |
| for i in range(len(states)): | |
| states[i].weight = weights[i] |
| else: | ||
| pass | ||
|
|
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
|
|
There was a problem hiding this comment.
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) |
| elif len(self.transition_probabilities[states[0]].ravel())**2 == len(states.weights): | ||
| m_ij_weights = ( | ||
| self.transition_probabilities.transition_matrices[1].ravel() * | ||
| (states.weights/np.sum(states.weights)))*np.random.rand(len(states.weights)) |
There was a problem hiding this comment.
Would be good to be able to seed random for this, like models, etc..
| self.transition_probabilities[states[0]] @ | ||
| (states.weights/np.sum(states.weights))) |
There was a problem hiding this comment.
| self.transition_probabilities[states[0]] @ | |
| (states.weights/np.sum(states.weights))) | |
| self.transition_probabilities[states[0]] | |
| @ (states.weights/np.sum(states.weights))) |
| from ..base import Base, Property | ||
|
|
||
|
|
||
| class TransitionMatrix(Base): |
| from stonesoup.models.transition.linear import (CombinedLinearGaussianTransitionModel as CLGTM, | ||
| ConstantVelocity as CV, | ||
| KnownTurnRate as CT) |
There was a problem hiding this comment.
These are re-imported a couple lines down
| segment_lengths = [150, 200, 40, 165, 45] | ||
| segment_lengths = [int(x) for x in segment_lengths] |
There was a problem hiding this comment.
| segment_lengths = [150, 200, 40, 165, 45] | |
| segment_lengths = [int(x) for x in segment_lengths] | |
| segment_lengths = [150, 200, 40, 165, 45] |
The second line will return the same list right?
| print(len(states)) | ||
|
|
||
| if isinstance(states[0], list): | ||
| if isinstance(states[0][0], WeightedGaussianState): |
There was a problem hiding this comment.
what happens if this if is false? is it correct that states will remain as the original list
| for i in range(len(states)): | ||
| Likelihood_j.append(mvn.pdf( | ||
| states[i].hypothesis.measurement.state_vector.T, | ||
| states[i].hypothesis.measurement_prediction.mean.ravel(), | ||
| states[i].hypothesis.measurement_prediction.covar)) |
There was a problem hiding this comment.
| for i in range(len(states)): | |
| Likelihood_j.append(mvn.pdf( | |
| states[i].hypothesis.measurement.state_vector.T, | |
| states[i].hypothesis.measurement_prediction.mean.ravel(), | |
| states[i].hypothesis.measurement_prediction.covar)) | |
| for state in states: | |
| Likelihood_j.append(mvn.pdf( | |
| state.hypothesis.measurement.state_vector.T, | |
| state.hypothesis.measurement_prediction.mean.ravel(), | |
| state.hypothesis.measurement_prediction.covar)) |
| if str(history_length) not in [x for x in self.transition_matrices.keys()]: | ||
| history_length = max([x for x in self.transition_matrices.keys()]) + 1 |
There was a problem hiding this comment.
| if str(history_length) not in [x for x in self.transition_matrices.keys()]: | |
| history_length = max([x for x in self.transition_matrices.keys()]) + 1 | |
| if history_length not in self.transition_matrices.keys(): | |
| history_length = max(self.transition_matrices.keys()) + 1 |
| if self.model_histories is None: | ||
| self.model_histories = [] | ||
| if self.existence is None: | ||
| self.existence = Probability(1) |
There was a problem hiding this comment.
Can move this into the "default factory" way
| gtplotter = Plotter() | ||
| ms = [m for _, m in all_measurements] | ||
| gtplotter.plot_ground_truths(truths, [0, 2], linewidth=2) | ||
| gtplotter.fig |
There was a problem hiding this comment.
| gtplotter = Plotter() | |
| ms = [m for _, m in all_measurements] | |
| gtplotter.plot_ground_truths(truths, [0, 2], linewidth=2) | |
| gtplotter.fig | |
| plotter = Plotter() | |
| ms = [m for _, m in all_measurements] | |
| plotter.plot_ground_truths(truths, [0, 2], linewidth=2) | |
| plotter.fig |
Makes sense to have the truth on the same plot as the tracks
| measurement_reducer, priors, all_measurements) | ||
|
|
||
| # %% | ||
| plotter = Plotter() |
There was a problem hiding this comment.
| plotter = Plotter() |
This PR adds the components for the Multiple model Framework and provides a tutorial introducing the KF, GPB1, GPB2 and IMM implemented through the Multiple Model Framework.