View Javadoc
1   /*
2    * Licensed to the Hipparchus project under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.hipparchus.filtering.kalman.unscented;
18  
19  import org.hipparchus.exception.LocalizedCoreFormats;
20  import org.hipparchus.exception.MathIllegalArgumentException;
21  import org.hipparchus.exception.MathRuntimeException;
22  import org.hipparchus.filtering.kalman.KalmanFilter;
23  import org.hipparchus.filtering.kalman.KalmanObserver;
24  import org.hipparchus.filtering.kalman.Measurement;
25  import org.hipparchus.filtering.kalman.ProcessEstimate;
26  import org.hipparchus.linear.MatrixDecomposer;
27  import org.hipparchus.linear.MatrixUtils;
28  import org.hipparchus.linear.RealMatrix;
29  import org.hipparchus.linear.RealVector;
30  import org.hipparchus.util.UnscentedTransformProvider;
31  
32  /**
33   * Unscented Kalman filter for {@link UnscentedProcess unscented process}.
34   * @param <T> the type of the measurements
35   *
36   * @see "Wan, E. A., & Van Der Merwe, R. (2000, October). The unscented Kalman filter for nonlinear estimation.
37   *       In Proceedings of the IEEE 2000 Adaptive Systems for Signal Processing, Communications, and Control Symposium
38   *       (Cat. No. 00EX373) (pp. 153-158)"
39   * @since 2.2
40   */
41  public class UnscentedKalmanFilter<T extends Measurement> implements KalmanFilter<T> {
42  
43      /** Process to be estimated. */
44      private final UnscentedProcess<T> process;
45  
46      /** Predicted state. */
47      private ProcessEstimate predicted;
48  
49      /** Corrected state. */
50      private ProcessEstimate corrected;
51  
52      /** Decompose to use for the correction phase. */
53      private final MatrixDecomposer decomposer;
54  
55      /** Number of estimated parameters. */
56      private final int n;
57  
58      /** Unscented transform provider. */
59      private final UnscentedTransformProvider utProvider;
60  
61      /** Prior corrected sigma-points. */
62      private RealVector[] priorSigmaPoints;
63  
64      /** Predicted sigma-points. */
65      private RealVector[] predictedNoNoiseSigmaPoints;
66  
67      /** Observer. */
68      private KalmanObserver observer;
69  
70      /** Simple constructor.
71       * @param decomposer decomposer to use for the correction phase
72       * @param process unscented process to estimate
73       * @param initialState initial state
74       * @param utProvider unscented transform provider
75       */
76      public UnscentedKalmanFilter(final MatrixDecomposer decomposer,
77                                   final UnscentedProcess<T> process,
78                                   final ProcessEstimate initialState,
79                                   final UnscentedTransformProvider utProvider) {
80          this.decomposer = decomposer;
81          this.process    = process;
82          this.corrected  = initialState;
83          this.n          = corrected.getState().getDimension();
84          this.utProvider = utProvider;
85          this.priorSigmaPoints = null;
86          this.predictedNoNoiseSigmaPoints = null;
87          this.observer = null;
88  
89          // Check state dimension
90          if (n == 0) {
91              // State dimension must be different from 0
92              throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_STATE_SIZE);
93          }
94      }
95  
96      /** {@inheritDoc} */
97      @Override
98      public ProcessEstimate estimationStep(final T measurement) throws MathRuntimeException {
99  
100         // Calculate sigma points
101         final RealVector[] sigmaPoints = utProvider.unscentedTransform(corrected.getState(), corrected.getCovariance());
102         priorSigmaPoints = sigmaPoints;
103 
104         // Perform the prediction and correction steps
105         return predictionAndCorrectionSteps(measurement, sigmaPoints);
106 
107     }
108 
109     /** This method perform the prediction and correction steps of the Unscented Kalman Filter.
110      * @param measurement single measurement to handle
111      * @param sigmaPoints computed sigma points
112      * @return estimated state after measurement has been considered
113      * @throws MathRuntimeException if matrix cannot be decomposed
114      */
115     private ProcessEstimate predictionAndCorrectionSteps(final T measurement, final RealVector[] sigmaPoints) throws MathRuntimeException {
116 
117         // Prediction phase
118         final UnscentedEvolution evolution = process.getEvolution(getCorrected().getTime(),
119                                                                   sigmaPoints, measurement);
120         predictedNoNoiseSigmaPoints = evolution.getCurrentStates();
121 
122         // Computation of Eq. 17, weighted mean state
123         final RealVector predictedState = utProvider.getUnscentedMeanState(evolution.getCurrentStates());
124 
125         // Calculate process noise
126         final RealMatrix processNoiseMatrix = process.getProcessNoiseMatrix(getCorrected().getTime(), predictedState,
127                                                                             measurement);
128 
129         predict(evolution.getCurrentTime(), evolution.getCurrentStates(), processNoiseMatrix);
130 
131         // Calculate sigma points from predicted state
132         final RealVector[] predictedSigmaPoints = utProvider.unscentedTransform(predicted.getState(),
133                                                                                 predicted.getCovariance());
134 
135         // Correction phase
136         final RealVector[] predictedMeasurements = process.getPredictedMeasurements(predictedSigmaPoints, measurement);
137         final RealVector   predictedMeasurement  = utProvider.getUnscentedMeanState(predictedMeasurements);
138         final RealMatrix   r                     = computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());
139         final RealMatrix   crossCovarianceMatrix = computeCrossCovarianceMatrix(predictedSigmaPoints, predicted.getState(),
140                                                                                 predictedMeasurements, predictedMeasurement);
141         final RealVector   innovation            = (r == null) ? null : process.getInnovation(measurement, predictedMeasurement, predicted.getState(), r);
142         correct(measurement, r, crossCovarianceMatrix, innovation);
143 
144         if (observer != null) {
145             observer.updatePerformed(this);
146         }
147         return getCorrected();
148 
149     }
150 
151     /** Perform prediction step.
152      * @param time process time
153      * @param predictedStates predicted state vectors
154      * @param noise process noise covariance matrix
155      */
156     private void predict(final double time, final RealVector[] predictedStates, final RealMatrix noise) {
157 
158         // Computation of Eq. 17, weighted mean state
159         final RealVector predictedState = utProvider.getUnscentedMeanState(predictedStates);
160 
161         // Computation of Eq. 18, predicted covariance matrix
162         final RealMatrix predictedCovariance = utProvider.getUnscentedCovariance(predictedStates, predictedState).add(noise);
163 
164         predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
165         corrected = null;
166 
167     }
168 
169     /** Perform correction step.
170      * @param measurement single measurement to handle
171      * @param innovationCovarianceMatrix innovation covariance matrix
172      * (may be null if measurement should be ignored)
173      * @param crossCovarianceMatrix cross covariance matrix
174      * @param innovation innovation
175      * (may be null if measurement should be ignored)
176      * @exception MathIllegalArgumentException if matrix cannot be decomposed
177      */
178     private void correct(final T measurement, final RealMatrix innovationCovarianceMatrix,
179                            final RealMatrix crossCovarianceMatrix, final RealVector innovation)
180         throws MathIllegalArgumentException {
181 
182         if (innovation == null) {
183             // measurement should be ignored
184             corrected = predicted;
185             return;
186         }
187 
188         // compute Kalman gain k
189         // the following is equivalent to k = P_cross * (R_pred)^-1
190         // we don't want to compute the inverse of a matrix,
191         // we start by post-multiplying by R_pred and get
192         // k.(R_pred) = P_cross
193         // then we transpose, knowing that R_pred is a symmetric matrix
194         // (R_pred).k^T = P_cross^T
195         // then we can use linear system solving instead of matrix inversion
196         final RealMatrix k = decomposer.
197                              decompose(innovationCovarianceMatrix).
198                              solve(crossCovarianceMatrix.transpose()).transpose();
199 
200         // correct state vector
201         final RealVector correctedState = predicted.getState().add(k.operate(innovation));
202 
203         // correct covariance matrix
204         final RealMatrix correctedCovariance = predicted.getCovariance().
205                                                subtract(k.multiply(innovationCovarianceMatrix).multiplyTransposed(k));
206 
207         corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance,
208                                         null, null, innovationCovarianceMatrix, k);
209 
210     }
211 
212     /** {@inheritDoc} */
213     @Override
214     public void setObserver(final KalmanObserver kalmanObserver) {
215         observer = kalmanObserver;
216         observer.init(this);
217     }
218 
219     /** Get the predicted state.
220      * @return predicted state
221      */
222     @Override
223     public ProcessEstimate getPredicted() {
224         return predicted;
225     }
226 
227     /** Get the corrected state.
228      * @return corrected state
229      */
230     @Override
231     public ProcessEstimate getCorrected() {
232         return corrected;
233     }
234 
235     /** {@inheritDoc} */
236     @Override
237     public RealMatrix getStateCrossCovariance() {
238         final RealVector priorState = utProvider.getUnscentedMeanState(priorSigmaPoints);
239         final RealVector predictedState = utProvider.getUnscentedMeanState(predictedNoNoiseSigmaPoints);
240 
241         return computeCrossCovarianceMatrix(priorSigmaPoints, priorState, predictedNoNoiseSigmaPoints, predictedState);
242     }
243 
244     /** Get the unscented transform provider.
245      * @return unscented transform provider
246      */
247     public UnscentedTransformProvider getUnscentedTransformProvider() {
248         return utProvider;
249     }
250 
251     /** Computes innovation covariance matrix.
252      * @param predictedMeasurements predicted measurements (one per sigma point)
253      * @param predictedMeasurement predicted measurements
254      *        (may be null if measurement should be ignored)
255      * @param r measurement covariance
256      * @return innovation covariance matrix (null if predictedMeasurement is null)
257      */
258     private RealMatrix computeInnovationCovarianceMatrix(final RealVector[] predictedMeasurements,
259                                                          final RealVector predictedMeasurement,
260                                                          final RealMatrix r) {
261         if (predictedMeasurement == null) {
262             return null;
263         }
264         // Computation of the innovation covariance matrix
265         final RealMatrix innovationCovarianceMatrix = utProvider.getUnscentedCovariance(predictedMeasurements, predictedMeasurement);
266 
267         // Add the measurement covariance
268         return innovationCovarianceMatrix.add(r);
269     }
270 
271     /**
272      * Computes cross covariance matrix.
273      * @param predictedStates predicted states
274      * @param predictedState predicted state
275      * @param predictedMeasurements current measurements
276      * @param predictedMeasurement predicted measurements
277      * @return cross covariance matrix
278      */
279     private RealMatrix computeCrossCovarianceMatrix(final RealVector[] predictedStates, final RealVector predictedState,
280                                                     final RealVector[] predictedMeasurements, final RealVector predictedMeasurement) {
281 
282         // Initialize the cross covariance matrix
283         RealMatrix crossCovarianceMatrix = MatrixUtils.createRealMatrix(predictedState.getDimension(),
284                                                                         predictedMeasurement.getDimension());
285 
286         // Covariance weights
287         final RealVector wc = utProvider.getWc();
288 
289         // Compute the cross covariance matrix
290         for (int i = 0; i <= 2 * n; i++) {
291             final RealVector stateDiff = predictedStates[i].subtract(predictedState);
292             final RealVector measDiff  = predictedMeasurements[i].subtract(predictedMeasurement);
293             crossCovarianceMatrix = crossCovarianceMatrix.add(stateDiff.outerProduct(measDiff).scalarMultiply(wc.getEntry(i)));
294         }
295 
296         // Return the cross covariance
297         return crossCovarianceMatrix;
298     }
299 
300 }