View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements. See the NOTICE file distributed with this
4    * work for additional information regarding copyright ownership. The ASF
5    * licenses this file to You under the Apache License, Version 2.0 (the
6    * "License"); you may not use this file except in compliance with the License.
7    * You may obtain a copy of the License at
8    * https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law
9    * or agreed to in writing, software distributed under the License is
10   * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
11   * KIND, either express or implied. See the License for the specific language
12   * governing permissions and limitations under the License.
13   */
14  package org.hipparchus.optim.nonlinear.vector.leastsquares;
15  
16  import org.hipparchus.UnitTestUtils;
17  import org.hipparchus.linear.ArrayRealVector;
18  import org.hipparchus.linear.DiagonalMatrix;
19  import org.hipparchus.linear.RealVector;
20  import org.hipparchus.util.FastMath;
21  import org.junit.jupiter.api.Disabled;
22  import org.junit.jupiter.api.Test;
23  
24  import java.awt.geom.Point2D;
25  import java.util.ArrayList;
26  import java.util.List;
27  
28  import static org.junit.jupiter.api.Assertions.assertEquals;
29  
30  /**
31   * This class demonstrates the main functionality of the
32   * {@link LeastSquaresProblem.Evaluation}, common to the
33   * optimizer implementations in package
34   * {@link org.hipparchus.optim.nonlinear.vector.leastsquares}.
35   * <br>
36   * Not enabled by default, as the class name does not end with "Test".
37   * <br>
38   * Invoke by running
39   * <pre><code>
40   *  mvn test -Dtest=EvaluationTestValidation
41   * </code></pre>
42   * or by running
43   * <pre><code>
44   *  mvn test -Dtest=EvaluationTestValidation -DargLine="-DmcRuns=1234 -server"
45   * </code></pre>
46   */
47  class EvaluationTestValidation {
48      /** Number of runs. */
49      private static final int MONTE_CARLO_RUNS = Integer.parseInt(System.getProperty("mcRuns",
50                                                                                      "100"));
51  
52      /**
53       * Using a Monte-Carlo procedure, this test checks the error estimations
54       * as provided by the square-root of the diagonal elements of the
55       * covariance matrix.
56       * <br>
57       * The test generates sets of observations, each sampled from
58       * a Gaussian distribution.
59       * <br>
60       * The optimization problem solved is defined in class
61       * {@link StraightLineProblem}.
62       * <br>
63       * The output (on stdout) will be a table summarizing the distribution
64       * of parameters generated by the Monte-Carlo process and by the direct
65       * estimation provided by the diagonal elements of the covariance matrix.
66       */
67      @Disabled
68      @Test
69      void testParametersErrorMonteCarloObservations() {
70          // Error on the observations.
71          final double yError = 15;
72  
73          // True values of the parameters.
74          final double slope = 123.456;
75          final double offset = -98.765;
76  
77          // Samples generator.
78          final RandomStraightLinePointGenerator lineGenerator
79              = new RandomStraightLinePointGenerator(slope, offset,
80                                                     yError,
81                                                     -1e3, 1e4,
82                                                     138577L);
83  
84          // Number of observations.
85          final int numObs = 100; // XXX Should be a command-line option.
86          // number of parameters.
87          final int numParams = 2;
88  
89          // Parameters found for each of Monte-Carlo run.
90          final UnitTestUtils.SimpleStatistics[] paramsFoundByDirectSolution = new UnitTestUtils.SimpleStatistics[numParams];
91          // Sigma estimations (square-root of the diagonal elements of the
92          // covariance matrix), for each Monte-Carlo run.
93          final UnitTestUtils.SimpleStatistics[] sigmaEstimate = new UnitTestUtils.SimpleStatistics[numParams];
94  
95          // Initialize statistics accumulators.
96          for (int i = 0; i < numParams; i++) {
97              paramsFoundByDirectSolution[i] = new UnitTestUtils.SimpleStatistics();
98              sigmaEstimate[i] = new UnitTestUtils.SimpleStatistics();
99          }
100 
101         final RealVector init = new ArrayRealVector(new double[]{ slope, offset }, false);
102 
103         // Monte-Carlo (generates many sets of observations).
104         final int mcRepeat = MONTE_CARLO_RUNS;
105         int mcCount = 0;
106         while (mcCount < mcRepeat) {
107             // Observations.
108             final Point2D.Double[] obs = lineGenerator.generate(numObs);
109 
110             final StraightLineProblem problem = new StraightLineProblem(yError);
111             for (int i = 0; i < numObs; i++) {
112                 final Point2D.Double p = obs[i];
113                 problem.addPoint(p.x, p.y);
114             }
115 
116             // Direct solution (using simple regression).
117             final double[] regress = problem.solve();
118 
119             // Estimation of the standard deviation (diagonal elements of the
120             // covariance matrix).
121             final LeastSquaresProblem lsp = builder(problem).build();
122 
123             final RealVector sigma = lsp.evaluate(init).getSigma(1e-14);
124 
125             // Accumulate statistics.
126             for (int i = 0; i < numParams; i++) {
127                 paramsFoundByDirectSolution[i].addValue(regress[i]);
128                 sigmaEstimate[i].addValue(sigma.getEntry(i));
129             }
130 
131             // Next Monte-Carlo.
132             ++mcCount;
133         }
134 
135         // Print statistics.
136         final String line = "--------------------------------------------------------------";
137         System.out.println("                 True value       Mean        Std deviation");
138         for (int i = 0; i < numParams; i++) {
139             System.out.println(line);
140             System.out.println("Parameter #" + i);
141 
142             System.out.printf("              %+.6e   %+.6e   %+.6e\n",
143                               init.getEntry(i),
144                               paramsFoundByDirectSolution[i].getMean(),
145                               paramsFoundByDirectSolution[i].getStandardDeviation());
146 
147             System.out.printf("sigma: %+.6e (%+.6e)\n",
148                               sigmaEstimate[i].getMean(),
149                               sigmaEstimate[i].getStandardDeviation());
150         }
151         System.out.println(line);
152 
153         // Check the error estimation.
154         for (int i = 0; i < numParams; i++) {
155             assertEquals(paramsFoundByDirectSolution[i].getStandardDeviation(),
156                                 sigmaEstimate[i].getMean(),
157                                 8e-2);
158         }
159     }
160 
161     /**
162      * In this test, the set of observations is fixed.
163      * Using a Monte-Carlo procedure, it generates sets of parameters,
164      * and determine the parameter change that will result in the
165      * normalized chi-square becoming larger by one than the value from
166      * the best fit solution.
167      * <br>
168      * The optimization problem solved is defined in class
169      * {@link StraightLineProblem}.
170      * <br>
171      * The output (on stdout) will be a list of lines containing:
172      * <ul>
173      *  <li>slope of the straight line,</li>
174      *  <li>intercept of the straight line,</li>
175      *  <li>chi-square of the solution defined by the above two values.</li>
176      * </ul>
177      * The output is separated into two blocks (with a blank line between
178      * them); the first block will contain all parameter sets for which
179      * {@code chi2 < chi2_b + 1}
180      * and the second block, all sets for which
181      * {@code chi2 >= chi2_b + 1}
182      * where {@code chi2_b} is the lowest chi-square (corresponding to the
183      * best solution).
184      */
185     @Disabled
186     @Test
187     void testParametersErrorMonteCarloParameters() {
188         // Error on the observations.
189         final double yError = 15;
190 
191         // True values of the parameters.
192         final double slope = 123.456;
193         final double offset = -98.765;
194 
195         // Samples generator.
196         final RandomStraightLinePointGenerator lineGenerator
197             = new RandomStraightLinePointGenerator(slope, offset,
198                                                    yError,
199                                                    -1e3, 1e4,
200                                                    13839013L);
201 
202         // Number of observations.
203         final int numObs = 10;
204         // number of parameters.
205 
206         // Create a single set of observations.
207         final Point2D.Double[] obs = lineGenerator.generate(numObs);
208 
209         final StraightLineProblem problem = new StraightLineProblem(yError);
210         for (int i = 0; i < numObs; i++) {
211             final Point2D.Double p = obs[i];
212             problem.addPoint(p.x, p.y);
213         }
214 
215         // Direct solution (using simple regression).
216         final RealVector regress = new ArrayRealVector(problem.solve(), false);
217 
218         // Dummy optimizer (to compute the chi-square).
219         final LeastSquaresProblem lsp = builder(problem).build();
220 
221         // Get chi-square of the best parameters set for the given set of
222         // observations.
223         final double bestChi2N = getChi2N(lsp, regress);
224         final RealVector sigma = lsp.evaluate(regress).getSigma(1e-14);
225 
226         // Monte-Carlo (generates a grid of parameters).
227         final int mcRepeat = MONTE_CARLO_RUNS;
228         final int gridSize = (int) FastMath.sqrt(mcRepeat);
229 
230         // Parameters found for each of Monte-Carlo run.
231         // Index 0 = slope
232         // Index 1 = offset
233         // Index 2 = normalized chi2
234         final List<double[]> paramsAndChi2 = new ArrayList<double[]>(gridSize * gridSize);
235 
236         final double slopeRange = 10 * sigma.getEntry(0);
237         final double offsetRange = 10 * sigma.getEntry(1);
238         final double minSlope = slope - 0.5 * slopeRange;
239         final double minOffset = offset - 0.5 * offsetRange;
240         final double deltaSlope =  slopeRange/ gridSize;
241         final double deltaOffset = offsetRange / gridSize;
242         for (int i = 0; i < gridSize; i++) {
243             final double s = minSlope + i * deltaSlope;
244             for (int j = 0; j < gridSize; j++) {
245                 final double o = minOffset + j * deltaOffset;
246                 final double chi2N = getChi2N(lsp,
247                         new ArrayRealVector(new double[] {s, o}, false));
248 
249                 paramsAndChi2.add(new double[] {s, o, chi2N});
250             }
251         }
252 
253         // Output (for use with "gnuplot").
254 
255         // Some info.
256 
257         // For plotting separately sets of parameters that have a large chi2.
258         final double chi2NPlusOne = bestChi2N + 1;
259         int numLarger = 0;
260 
261         final String lineFmt = "%+.10e %+.10e   %.8e\n";
262 
263         // Point with smallest chi-square.
264         System.out.printf(lineFmt, regress.getEntry(0), regress.getEntry(1), bestChi2N);
265         System.out.println(); // Empty line.
266 
267         // Points within the confidence interval.
268         for (double[] d : paramsAndChi2) {
269             if (d[2] <= chi2NPlusOne) {
270                 System.out.printf(lineFmt, d[0], d[1], d[2]);
271             }
272         }
273         System.out.println(); // Empty line.
274 
275         // Points outside the confidence interval.
276         for (double[] d : paramsAndChi2) {
277             if (d[2] > chi2NPlusOne) {
278                 ++numLarger;
279                 System.out.printf(lineFmt, d[0], d[1], d[2]);
280             }
281         }
282         System.out.println(); // Empty line.
283 
284         System.out.println("# sigma=" + sigma.toString());
285         System.out.println("# " + numLarger + " sets filtered out");
286     }
287 
288     LeastSquaresBuilder builder(StraightLineProblem problem){
289         return new LeastSquaresBuilder()
290                 .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
291                 .target(problem.target())
292                 .weight(new DiagonalMatrix(problem.weight()))
293                 //unused start point to avoid NPE
294                 .start(new double[2]);
295     }
296     /**
297      * @return the normalized chi-square.
298      */
299     private double getChi2N(LeastSquaresProblem lsp,
300                             RealVector params) {
301         final double cost = lsp.evaluate(params).getCost();
302         return cost * cost / (lsp.getObservationSize() - params.getDimension());
303     }
304 }
305