View Javadoc
1   /*
2    * Licensed to the Hipparchus project under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.hipparchus.filtering.kalman;
18  
19  import org.hipparchus.exception.MathIllegalStateException;
20  import org.hipparchus.filtering.LocalizedFilterFormats;
21  import org.hipparchus.linear.MatrixDecomposer;
22  import org.hipparchus.linear.RealMatrix;
23  import org.hipparchus.linear.RealVector;
24  
25  import java.util.ArrayList;
26  import java.util.LinkedList;
27  import java.util.List;
28  
29  /**
30   * Kalman smoother for linear, extended or unscented filters.
31   * <p>
32   * This implementation is attached to a filter using the observer mechanism.  Once all measurements have been
33   * processed by the filter, the smoothing method can be called.
34   * </p>
35   * <p>
36   * For example
37   * </p>
38   * <pre>{@code
39   *     // Kalman filter
40   *     final KalmanFilter<SimpleMeasurement> filter = new LinearKalmanFilter<>(decomposer, process, initialState);
41   *
42   *     // Smoother observer
43   *     final KalmanSmoother smoother = new KalmanSmoother(decomposer);
44   *     filter.setObserver(smoother);
45   *
46   *     // Process measurements with filter (forwards pass)
47   *     measurements.forEach(filter::estimationStep);
48   *
49   *     // Smooth backwards
50   *     List<ProcessEstimate> smoothedStates = smoother.backwardsSmooth();
51   * }</pre>
52   *
53   * @see "Särkkä, S. Bayesian Filtering and Smoothing. Cambridge 2013"
54   */
55  public class KalmanSmoother implements KalmanObserver {
56  
57      /** Decomposer to use for gain calculation. */
58      private final MatrixDecomposer decomposer;
59  
60      /** Storage for smoother gain matrices. */
61      private final List<SmootherData> smootherData;
62  
63      /** Simple constructor.
64       * @param decomposer decomposer to use for the smoother gain calculations
65       */
66      public KalmanSmoother(final MatrixDecomposer decomposer) {
67          this.decomposer = decomposer;
68          this.smootherData = new ArrayList<>();
69      }
70  
71      @Override
72      public void init(KalmanEstimate estimate) {
73          // Add initial state to smoother data
74          smootherData.add(new SmootherData(
75                  estimate.getCorrected().getTime(),
76                  null,
77                  null,
78                  estimate.getCorrected().getState(),
79                  estimate.getCorrected().getCovariance(),
80                  null
81          ));
82  
83      }
84  
85      @Override
86      public void updatePerformed(KalmanEstimate estimate) {
87          // Smoother gain
88          // We want G = D * P^(-1)
89          // Calculate with G = (P^(-1) * D^T)^T
90          final RealMatrix smootherGain = decomposer
91                  .decompose(estimate.getPredicted().getCovariance())
92                  .solve(estimate.getStateCrossCovariance().transpose())
93                  .transpose();
94          smootherData.add(new SmootherData(
95                  estimate.getCorrected().getTime(),
96                  estimate.getPredicted().getState(),
97                  estimate.getPredicted().getCovariance(),
98                  estimate.getCorrected().getState(),
99                  estimate.getCorrected().getCovariance(),
100                 smootherGain
101         ));
102     }
103 
104     /** Backwards smooth.
105      * This is a backward pass over the filtered data, recursively calculating smoothed states, using the
106      * Rauch-Tung-Striebel (RTS) formulation.
107      * Note that the list result is a `LinkedList`, not an `ArrayList`.
108      * @return list of smoothed states
109      */
110     public List<ProcessEstimate> backwardsSmooth() {
111         // Check for at least one measurement
112         if (smootherData.size() < 2) {
113             throw new MathIllegalStateException(LocalizedFilterFormats.PROCESS_AT_LEAST_ONE_MEASUREMENT);
114         }
115 
116         // Initialise output
117         final LinkedList<ProcessEstimate> smootherResults = new LinkedList<>();
118 
119         // Last smoothed state is the same as the filtered state
120         final SmootherData lastUpdate = smootherData.get(smootherData.size() - 1);
121         ProcessEstimate smoothedState = new ProcessEstimate(lastUpdate.getTime(),
122                 lastUpdate.getCorrectedState(), lastUpdate.getCorrectedCovariance());
123         smootherResults.addFirst(smoothedState);
124 
125         // Backwards recursion on the smoothed state
126         for (int i = smootherData.size() - 2; i >= 0; --i) {
127 
128             // These are from equation 8.6 in Sarkka, "Bayesian Filtering and Smoothing", Cambridge, 2013.
129             final RealMatrix smootherGain = smootherData.get(i + 1).getSmootherGain();
130 
131             final RealVector smoothedMean = smootherData.get(i).getCorrectedState()
132                     .add(smootherGain.operate(smoothedState.getState()
133                             .subtract(smootherData.get(i + 1).getPredictedState())));
134 
135             final RealMatrix smoothedCovariance = smootherData.get(i).getCorrectedCovariance()
136                     .add(smootherGain.multiply(smoothedState.getCovariance()
137                                     .subtract(smootherData.get(i + 1).getPredictedCovariance()))
138                             .multiplyTransposed(smootherGain));
139 
140             // Populate smoothed state
141             smoothedState = new ProcessEstimate(smootherData.get(i).getTime(), smoothedMean, smoothedCovariance);
142             smootherResults.addFirst(smoothedState);
143         }
144 
145         return smootherResults;
146     }
147 
148     /** Container for smoother data. */
149     private static class SmootherData {
150         /** Process time (typically the time or index of a measurement). */
151         private final double time;
152 
153         /** Predicted state vector. */
154         private final RealVector predictedState;
155 
156         /** Predicted covariance. */
157         private final RealMatrix predictedCovariance;
158 
159         /** Corrected state vector. */
160         private final RealVector correctedState;
161 
162         /** Corrected covariance. */
163         private final RealMatrix correctedCovariance;
164 
165         /** Smoother gain. */
166         private final RealMatrix smootherGain;
167 
168         SmootherData(final double time,
169                      final RealVector predictedState,
170                      final RealMatrix predictedCovariance,
171                      final RealVector correctedState,
172                      final RealMatrix correctedCovariance,
173                      final RealMatrix smootherGain) {
174             this.time = time;
175             this.predictedState = predictedState;
176             this.predictedCovariance = predictedCovariance;
177             this.correctedState = correctedState;
178             this.correctedCovariance = correctedCovariance;
179             this.smootherGain = smootherGain;
180         }
181 
182         /** Get the process time.
183          * @return process time (typically the time or index of a measurement)
184          */
185         public double getTime() {
186             return time;
187         }
188 
189         /**
190          * Get predicted state
191          * @return predicted state
192          */
193         public RealVector getPredictedState() {
194             return predictedState;
195         }
196 
197         /**
198          * Get predicted covariance
199          * @return predicted covariance
200          */
201         public RealMatrix getPredictedCovariance() {
202             return predictedCovariance;
203         }
204 
205         /**
206          * Get corrected state
207          * @return corrected state
208          */
209         public RealVector getCorrectedState() {
210             return correctedState;
211         }
212 
213         /**
214          * Get corrected covariance
215          * @return corrected covariance
216          */
217         public RealMatrix getCorrectedCovariance() {
218             return correctedCovariance;
219         }
220 
221         /**
222          * Get smoother gain (for previous time-step)
223          * @return smoother gain
224          */
225         public RealMatrix getSmootherGain() {
226             return smootherGain;
227         }
228     }
229 
230 }