1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.hipparchus.filtering.kalman.unscented;
18
19 import org.hipparchus.exception.LocalizedCoreFormats;
20 import org.hipparchus.exception.MathIllegalArgumentException;
21 import org.hipparchus.exception.MathRuntimeException;
22 import org.hipparchus.filtering.kalman.KalmanFilter;
23 import org.hipparchus.filtering.kalman.KalmanObserver;
24 import org.hipparchus.filtering.kalman.Measurement;
25 import org.hipparchus.filtering.kalman.ProcessEstimate;
26 import org.hipparchus.linear.MatrixDecomposer;
27 import org.hipparchus.linear.MatrixUtils;
28 import org.hipparchus.linear.RealMatrix;
29 import org.hipparchus.linear.RealVector;
30 import org.hipparchus.util.UnscentedTransformProvider;
31
32
33
34
35
36
37
38
39
40
41 public class UnscentedKalmanFilter<T extends Measurement> implements KalmanFilter<T> {
42
43
44 private final UnscentedProcess<T> process;
45
46
47 private ProcessEstimate predicted;
48
49
50 private ProcessEstimate corrected;
51
52
53 private final MatrixDecomposer decomposer;
54
55
56 private final int n;
57
58
59 private final UnscentedTransformProvider utProvider;
60
61
62 private RealVector[] priorSigmaPoints;
63
64
65 private RealVector[] predictedNoNoiseSigmaPoints;
66
67
68 private KalmanObserver observer;
69
70
71
72
73
74
75
76 public UnscentedKalmanFilter(final MatrixDecomposer decomposer,
77 final UnscentedProcess<T> process,
78 final ProcessEstimate initialState,
79 final UnscentedTransformProvider utProvider) {
80 this.decomposer = decomposer;
81 this.process = process;
82 this.corrected = initialState;
83 this.n = corrected.getState().getDimension();
84 this.utProvider = utProvider;
85 this.priorSigmaPoints = null;
86 this.predictedNoNoiseSigmaPoints = null;
87 this.observer = null;
88
89
90 if (n == 0) {
91
92 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_STATE_SIZE);
93 }
94 }
95
96
97 @Override
98 public ProcessEstimate estimationStep(final T measurement) throws MathRuntimeException {
99
100
101 final RealVector[] sigmaPoints = utProvider.unscentedTransform(corrected.getState(), corrected.getCovariance());
102 priorSigmaPoints = sigmaPoints;
103
104
105 return predictionAndCorrectionSteps(measurement, sigmaPoints);
106
107 }
108
109
110
111
112
113
114
115 private ProcessEstimate predictionAndCorrectionSteps(final T measurement, final RealVector[] sigmaPoints) throws MathRuntimeException {
116
117
118 final UnscentedEvolution evolution = process.getEvolution(getCorrected().getTime(),
119 sigmaPoints, measurement);
120 predictedNoNoiseSigmaPoints = evolution.getCurrentStates();
121
122
123 final RealVector predictedState = utProvider.getUnscentedMeanState(evolution.getCurrentStates());
124
125
126 final RealMatrix processNoiseMatrix = process.getProcessNoiseMatrix(getCorrected().getTime(), predictedState,
127 measurement);
128
129 predict(evolution.getCurrentTime(), evolution.getCurrentStates(), processNoiseMatrix);
130
131
132 final RealVector[] predictedSigmaPoints = utProvider.unscentedTransform(predicted.getState(),
133 predicted.getCovariance());
134
135
136 final RealVector[] predictedMeasurements = process.getPredictedMeasurements(predictedSigmaPoints, measurement);
137 final RealVector predictedMeasurement = utProvider.getUnscentedMeanState(predictedMeasurements);
138 final RealMatrix r = computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());
139 final RealMatrix crossCovarianceMatrix = computeCrossCovarianceMatrix(predictedSigmaPoints, predicted.getState(),
140 predictedMeasurements, predictedMeasurement);
141 final RealVector innovation = (r == null) ? null : process.getInnovation(measurement, predictedMeasurement, predicted.getState(), r);
142 correct(measurement, r, crossCovarianceMatrix, innovation);
143
144 if (observer != null) {
145 observer.updatePerformed(this);
146 }
147 return getCorrected();
148
149 }
150
151
152
153
154
155
156 private void predict(final double time, final RealVector[] predictedStates, final RealMatrix noise) {
157
158
159 final RealVector predictedState = utProvider.getUnscentedMeanState(predictedStates);
160
161
162 final RealMatrix predictedCovariance = utProvider.getUnscentedCovariance(predictedStates, predictedState).add(noise);
163
164 predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
165 corrected = null;
166
167 }
168
169
170
171
172
173
174
175
176
177
178 private void correct(final T measurement, final RealMatrix innovationCovarianceMatrix,
179 final RealMatrix crossCovarianceMatrix, final RealVector innovation)
180 throws MathIllegalArgumentException {
181
182 if (innovation == null) {
183
184 corrected = predicted;
185 return;
186 }
187
188
189
190
191
192
193
194
195
196 final RealMatrix k = decomposer.
197 decompose(innovationCovarianceMatrix).
198 solve(crossCovarianceMatrix.transpose()).transpose();
199
200
201 final RealVector correctedState = predicted.getState().add(k.operate(innovation));
202
203
204 final RealMatrix correctedCovariance = predicted.getCovariance().
205 subtract(k.multiply(innovationCovarianceMatrix).multiplyTransposed(k));
206
207 corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance,
208 null, null, innovationCovarianceMatrix, k);
209
210 }
211
212
213 @Override
214 public void setObserver(final KalmanObserver kalmanObserver) {
215 observer = kalmanObserver;
216 observer.init(this);
217 }
218
219
220
221
222 @Override
223 public ProcessEstimate getPredicted() {
224 return predicted;
225 }
226
227
228
229
230 @Override
231 public ProcessEstimate getCorrected() {
232 return corrected;
233 }
234
235
236 @Override
237 public RealMatrix getStateCrossCovariance() {
238 final RealVector priorState = utProvider.getUnscentedMeanState(priorSigmaPoints);
239 final RealVector predictedState = utProvider.getUnscentedMeanState(predictedNoNoiseSigmaPoints);
240
241 return computeCrossCovarianceMatrix(priorSigmaPoints, priorState, predictedNoNoiseSigmaPoints, predictedState);
242 }
243
244
245
246
247 public UnscentedTransformProvider getUnscentedTransformProvider() {
248 return utProvider;
249 }
250
251
252
253
254
255
256
257
258 private RealMatrix computeInnovationCovarianceMatrix(final RealVector[] predictedMeasurements,
259 final RealVector predictedMeasurement,
260 final RealMatrix r) {
261 if (predictedMeasurement == null) {
262 return null;
263 }
264
265 final RealMatrix innovationCovarianceMatrix = utProvider.getUnscentedCovariance(predictedMeasurements, predictedMeasurement);
266
267
268 return innovationCovarianceMatrix.add(r);
269 }
270
271
272
273
274
275
276
277
278
279 private RealMatrix computeCrossCovarianceMatrix(final RealVector[] predictedStates, final RealVector predictedState,
280 final RealVector[] predictedMeasurements, final RealVector predictedMeasurement) {
281
282
283 RealMatrix crossCovarianceMatrix = MatrixUtils.createRealMatrix(predictedState.getDimension(),
284 predictedMeasurement.getDimension());
285
286
287 final RealVector wc = utProvider.getWc();
288
289
290 for (int i = 0; i <= 2 * n; i++) {
291 final RealVector stateDiff = predictedStates[i].subtract(predictedState);
292 final RealVector measDiff = predictedMeasurements[i].subtract(predictedMeasurement);
293 crossCovarianceMatrix = crossCovarianceMatrix.add(stateDiff.outerProduct(measDiff).scalarMultiply(wc.getEntry(i)));
294 }
295
296
297 return crossCovarianceMatrix;
298 }
299
300 }