View Javadoc
1   /*
2    * Licensed to the Hipparchus project under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.hipparchus.filtering.kalman.unscented;
18  
19  import org.hipparchus.exception.LocalizedCoreFormats;
20  import org.hipparchus.exception.MathIllegalArgumentException;
21  import org.hipparchus.exception.MathRuntimeException;
22  import org.hipparchus.filtering.kalman.KalmanFilter;
23  import org.hipparchus.filtering.kalman.Measurement;
24  import org.hipparchus.filtering.kalman.ProcessEstimate;
25  import org.hipparchus.linear.MatrixDecomposer;
26  import org.hipparchus.linear.MatrixUtils;
27  import org.hipparchus.linear.RealMatrix;
28  import org.hipparchus.linear.RealVector;
29  import org.hipparchus.util.UnscentedTransformProvider;
30  
31  /**
32   * Unscented Kalman filter for {@link UnscentedProcess unscented process}.
33   * @param <T> the type of the measurements
34   *
35   * @see "Wan, E. A., & Van Der Merwe, R. (2000, October). The unscented Kalman filter for nonlinear estimation.
36   *       In Proceedings of the IEEE 2000 Adaptive Systems for Signal Processing, Communications, and Control Symposium
37   *       (Cat. No. 00EX373) (pp. 153-158)"
38   * @since 2.2
39   */
40  public class UnscentedKalmanFilter<T extends Measurement> implements KalmanFilter<T> {
41  
42      /** Process to be estimated. */
43      private UnscentedProcess<T> process;
44  
45      /** Predicted state. */
46      private ProcessEstimate predicted;
47  
48      /** Corrected state. */
49      private ProcessEstimate corrected;
50  
51      /** Decompose to use for the correction phase. */
52      private final MatrixDecomposer decomposer;
53  
54      /** Number of estimated parameters. */
55      private final int n;
56  
57      /** Unscented transform provider. */
58      private final UnscentedTransformProvider utProvider;
59  
60      /** Simple constructor.
61       * @param decomposer decomposer to use for the correction phase
62       * @param process unscented process to estimate
63       * @param initialState initial state
64       * @param utProvider unscented transform provider
65       */
66      public UnscentedKalmanFilter(final MatrixDecomposer decomposer,
67                                   final UnscentedProcess<T> process,
68                                   final ProcessEstimate initialState,
69                                   final UnscentedTransformProvider utProvider) {
70          this.decomposer = decomposer;
71          this.process    = process;
72          this.corrected  = initialState;
73          this.n          = corrected.getState().getDimension();
74          this.utProvider = utProvider;
75          // Check state dimension
76          if (n == 0) {
77              // State dimension must be different from 0
78              throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_STATE_SIZE);
79          }
80      }
81  
82      /** {@inheritDoc} */
83      @Override
84      public ProcessEstimate estimationStep(final T measurement) throws MathRuntimeException {
85  
86          // Calculate sigma points
87          final RealVector[] sigmaPoints = utProvider.unscentedTransform(corrected.getState(), corrected.getCovariance());
88  
89          // Perform the prediction and correction steps
90          return predictionAndCorrectionSteps(measurement, sigmaPoints);
91  
92      }
93  
94      /** This method perform the prediction and correction steps of the Unscented Kalman Filter.
95       * @param measurement single measurement to handle
96       * @param sigmaPoints computed sigma points
97       * @return estimated state after measurement has been considered
98       * @throws MathRuntimeException if matrix cannot be decomposed
99       */
100     public ProcessEstimate predictionAndCorrectionSteps(final T measurement, final RealVector[] sigmaPoints) throws MathRuntimeException {
101 
102         // Prediction phase
103         final UnscentedEvolution evolution = process.getEvolution(getCorrected().getTime(),
104                                                                   sigmaPoints, measurement);
105 
106         predict(evolution.getCurrentTime(), evolution.getCurrentStates(),
107                 evolution.getProcessNoiseMatrix());
108 
109         // Calculate sigma points from predicted state
110         final RealVector[] predictedSigmaPoints = utProvider.unscentedTransform(predicted.getState(),
111                                                                                 predicted.getCovariance());
112 
113         // Correction phase
114         final RealVector[] predictedMeasurements = process.getPredictedMeasurements(predictedSigmaPoints, measurement);
115         final RealVector   predictedMeasurement  = utProvider.getUnscentedMeanState(predictedMeasurements);
116         final RealMatrix   r                     = computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());
117         final RealMatrix   crossCovarianceMatrix = computeCrossCovarianceMatrix(predictedSigmaPoints, predicted.getState(),
118                                                                                 predictedMeasurements, predictedMeasurement);
119         final RealVector   innovation            = (r == null) ? null : process.getInnovation(measurement, predictedMeasurement, predicted.getState(), r);
120         correct(measurement, r, crossCovarianceMatrix, innovation);
121         return getCorrected();
122 
123     }
124 
125     /** Perform prediction step.
126      * @param time process time
127      * @param predictedStates predicted state vectors
128      * @param noise process noise covariance matrix
129      */
130     private void predict(final double time, final RealVector[] predictedStates,  final RealMatrix noise) {
131 
132         // Computation of Eq. 17, weighted mean state
133         final RealVector predictedState = utProvider.getUnscentedMeanState(predictedStates);
134 
135         // Computation of Eq. 18, predicted covariance matrix
136         final RealMatrix predictedCovariance = utProvider.getUnscentedCovariance(predictedStates, predictedState).add(noise);
137 
138         predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
139         corrected = null;
140 
141     }
142 
143     /** Perform correction step.
144      * @param measurement single measurement to handle
145      * @param innovationCovarianceMatrix innovation covariance matrix
146      * (may be null if measurement should be ignored)
147      * @param crossCovarianceMatrix cross covariance matrix
148      * @param innovation innovation
149      * (may be null if measurement should be ignored)
150      * @exception MathIllegalArgumentException if matrix cannot be decomposed
151      */
152     private void correct(final T measurement, final RealMatrix innovationCovarianceMatrix,
153                            final RealMatrix crossCovarianceMatrix, final RealVector innovation)
154         throws MathIllegalArgumentException {
155 
156         if (innovation == null) {
157             // measurement should be ignored
158             corrected = predicted;
159             return;
160         }
161 
162         // compute Kalman gain k
163         // the following is equivalent to k = P_cross * (R_pred)^-1
164         // we don't want to compute the inverse of a matrix,
165         // we start by post-multiplying by R_pred and get
166         // k.(R_pred) = P_cross
167         // then we transpose, knowing that R_pred is a symmetric matrix
168         // (R_pred).k^T = P_cross^T
169         // then we can use linear system solving instead of matrix inversion
170         final RealMatrix k = decomposer.
171                              decompose(innovationCovarianceMatrix).
172                              solve(crossCovarianceMatrix.transpose()).transpose();
173 
174         // correct state vector
175         final RealVector correctedState = predicted.getState().add(k.operate(innovation));
176 
177         // correct covariance matrix
178         final RealMatrix correctedCovariance = predicted.getCovariance().
179                                                subtract(k.multiply(innovationCovarianceMatrix).multiplyTransposed(k));
180 
181         corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance,
182                                         null, null, innovationCovarianceMatrix, k);
183 
184     }
185     /** Get the predicted state.
186      * @return predicted state
187      */
188     @Override
189     public ProcessEstimate getPredicted() {
190         return predicted;
191     }
192 
193     /** Get the corrected state.
194      * @return corrected state
195      */
196     @Override
197     public ProcessEstimate getCorrected() {
198         return corrected;
199     }
200 
201     /** Get the unscented transform provider.
202      * @return unscented transform provider
203      */
204     public UnscentedTransformProvider getUnscentedTransformProvider() {
205         return utProvider;
206     }
207 
208     /** Computes innovation covariance matrix.
209      * @param predictedMeasurements predicted measurements (one per sigma point)
210      * @param predictedMeasurement predicted measurements
211      *        (may be null if measurement should be ignored)
212      * @param r measurement covariance
213      * @return innovation covariance matrix (null if predictedMeasurement is null)
214      */
215     private RealMatrix computeInnovationCovarianceMatrix(final RealVector[] predictedMeasurements,
216                                                          final RealVector predictedMeasurement,
217                                                          final RealMatrix r) {
218         if (predictedMeasurement == null) {
219             return null;
220         }
221         // Computation of the innovation covariance matrix
222         final RealMatrix innovationCovarianceMatrix = utProvider.getUnscentedCovariance(predictedMeasurements, predictedMeasurement);
223 
224         // Add the measurement covariance
225         return innovationCovarianceMatrix.add(r);
226     }
227 
228     /**
229      * Computes cross covariance matrix.
230      * @param predictedStates predicted states
231      * @param predictedState predicted state
232      * @param predictedMeasurements current measurements
233      * @param predictedMeasurement predicted measurements
234      * @return cross covariance matrix
235      */
236     private RealMatrix computeCrossCovarianceMatrix(final RealVector[] predictedStates, final RealVector predictedState,
237                                                     final RealVector[] predictedMeasurements, final RealVector predictedMeasurement) {
238 
239         // Initialize the cross covariance matrix
240         RealMatrix crossCovarianceMatrix = MatrixUtils.createRealMatrix(predictedState.getDimension(),
241                                                                         predictedMeasurement.getDimension());
242 
243         // Covariance weights
244         final RealVector wc = utProvider.getWc();
245 
246         // Compute the cross covariance matrix
247         for (int i = 0; i <= 2 * n; i++) {
248             final RealVector stateDiff = predictedStates[i].subtract(predictedState);
249             final RealVector measDiff  = predictedMeasurements[i].subtract(predictedMeasurement);
250             crossCovarianceMatrix = crossCovarianceMatrix.add(outer(stateDiff, measDiff).scalarMultiply(wc.getEntry(i)));
251         }
252 
253         // Return the cross covariance
254         return crossCovarianceMatrix;
255     }
256 
257     /** Computes the outer product of two vectors.
258      * @param a first vector
259      * @param b second vector
260      * @return the outer product of a and b
261      */
262     private RealMatrix outer(final RealVector a, final RealVector b) {
263 
264         // Initialize matrix
265         final RealMatrix outMatrix = MatrixUtils.createRealMatrix(a.getDimension(), b.getDimension());
266 
267         // Fill matrix
268         for (int row = 0; row < outMatrix.getRowDimension(); row++) {
269             for (int col = 0; col < outMatrix.getColumnDimension(); col++) {
270                 outMatrix.setEntry(row, col, a.getEntry(row) * b.getEntry(col));
271             }
272         }
273 
274         // Return
275         return outMatrix;
276     }
277 }