View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.optim.nonlinear.vector.leastsquares;
24  
25  import org.hipparchus.analysis.MultivariateMatrixFunction;
26  import org.hipparchus.analysis.MultivariateVectorFunction;
27  import org.hipparchus.exception.LocalizedCoreFormats;
28  import org.hipparchus.exception.MathIllegalArgumentException;
29  import org.hipparchus.exception.MathIllegalStateException;
30  import org.hipparchus.geometry.euclidean.twod.Vector2D;
31  import org.hipparchus.linear.DiagonalMatrix;
32  import org.hipparchus.linear.RealMatrix;
33  import org.hipparchus.linear.RealVector;
34  import org.hipparchus.optim.ConvergenceChecker;
35  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresOptimizer.Optimum;
36  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem.Evaluation;
37  import org.hipparchus.util.FastMath;
38  import org.hipparchus.util.Incrementor;
39  import org.hipparchus.util.Precision;
40  import org.junit.jupiter.api.Test;
41  
42  import java.util.ArrayList;
43  import java.util.List;
44  
45  import static org.hamcrest.CoreMatchers.is;
46  import static org.hamcrest.MatcherAssert.assertThat;
47  import static org.junit.jupiter.api.Assertions.assertEquals;
48  import static org.junit.jupiter.api.Assertions.assertFalse;
49  import static org.junit.jupiter.api.Assertions.assertTrue;
50  
51  /**
52   * <p>Some of the unit tests are re-implementations of the MINPACK <a
53   * href="http://www.netlib.org/minpack/ex/file17">file17</a> and <a
54   * href="http://www.netlib.org/minpack/ex/file22">file22</a> test files.
55   * The redistribution policy for MINPACK is available <a
56   * href="http://www.netlib.org/minpack/disclaimer">here</a>.
57   *
58   */
59  public class LevenbergMarquardtOptimizerTest
60      extends AbstractLeastSquaresOptimizerAbstractTest{
61  
62      public LeastSquaresBuilder builder(BevingtonProblem problem){
63          return base()
64                  .model(problem.getModelFunction(), problem.getModelFunctionJacobian());
65      }
66  
67      public LeastSquaresBuilder builder(CircleProblem problem){
68          return base()
69                  .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
70                  .target(problem.target())
71                  .weight(new DiagonalMatrix(problem.weight()));
72      }
73  
74      @Override
75      public int getMaxIterations() {
76          return 25;
77      }
78  
79      @Override
80      public LeastSquaresOptimizer getOptimizer() {
81          return new LevenbergMarquardtOptimizer();
82      }
83  
84      @Override
85      @Test
86      public void testNonInvertible() {
87          /*
88           * Overrides the method from parent class, since the default singularity
89           * threshold (1e-14) does not trigger the expected exception.
90           */
91          LinearProblem problem = new LinearProblem(new double[][] {
92              {  1, 2, -3 },
93              {  2, 1,  3 },
94              { -3, 0, -9 }
95          }, new double[] { 1, 1, 1 });
96  
97          final Optimum optimum = optimizer.optimize(
98                                                     problem.getBuilder().maxIterations(20).build());
99  
100         //TODO check that it is a bad fit? Why the extra conditions?
101         assertTrue(FastMath.sqrt(problem.getTarget().length) * optimum.getRMS() > 0.6);
102 
103         try {
104             optimum.getCovariances(1.5e-14);
105             customFail(optimizer);
106         } catch (MathIllegalArgumentException e) {
107             assertEquals(LocalizedCoreFormats.SINGULAR_MATRIX, e.getSpecifier());
108         }
109 
110     }
111 
112     @Test
113     void testControlParameters() {
114         CircleVectorial circle = new CircleVectorial();
115         circle.addPoint( 30.0,  68.0);
116         circle.addPoint( 50.0,  -6.0);
117         circle.addPoint(110.0, -20.0);
118         circle.addPoint( 35.0,  15.0);
119         circle.addPoint( 45.0,  97.0);
120         checkEstimate(
121                 circle, 0.1, 10, 1.0e-14, 1.0e-16, 1.0e-10, false);
122         checkEstimate(
123                 circle, 0.1, 10, 1.0e-15, 1.0e-17, 1.0e-10, true);
124         checkEstimate(
125                 circle, 0.1,  5, 1.0e-15, 1.0e-16, 1.0e-10, true);
126         circle.addPoint(300, -300);
127         //wardev I changed true => false
128         //TODO why should this fail? It uses 15 evaluations.
129         checkEstimate(
130                 circle, 0.1, 20, 1.0e-18, 1.0e-16, 1.0e-10, false);
131     }
132 
133     private void checkEstimate(CircleVectorial circle,
134                                double initialStepBoundFactor, int maxCostEval,
135                                double costRelativeTolerance, double parRelativeTolerance,
136                                double orthoTolerance, boolean shouldFail) {
137         try {
138             final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer()
139                 .withInitialStepBoundFactor(initialStepBoundFactor)
140                 .withCostRelativeTolerance(costRelativeTolerance)
141                 .withParameterRelativeTolerance(parRelativeTolerance)
142                 .withOrthoTolerance(orthoTolerance)
143                 .withRankingThreshold(Precision.SAFE_MIN);
144 
145             final LeastSquaresProblem problem = builder(circle)
146                     .maxEvaluations(maxCostEval)
147                     .maxIterations(100)
148                     .start(new double[] { 98.680, 47.345 })
149                     .build();
150 
151             optimizer.optimize(problem);
152 
153             assertFalse(shouldFail);
154             //TODO check it got the right answer
155 
156         } catch (MathIllegalArgumentException ee) {
157             assertTrue(shouldFail);
158         } catch (MathIllegalStateException ee) {
159             assertTrue(shouldFail);
160         }
161     }
162 
163     /**
164      * Non-linear test case: fitting of decay curve (from Chapter 8 of
165      * Bevington's textbook, "Data reduction and analysis for the physical sciences").
166      * XXX The expected ("reference") values may not be accurate and the tolerance too
167      * relaxed for this test to be currently really useful (the issue is under
168      * investigation).
169      */
170     @Test
171     void testBevington() {
172         final double[][] dataPoints = {
173             // column 1 = times
174             { 15, 30, 45, 60, 75, 90, 105, 120, 135, 150,
175               165, 180, 195, 210, 225, 240, 255, 270, 285, 300,
176               315, 330, 345, 360, 375, 390, 405, 420, 435, 450,
177               465, 480, 495, 510, 525, 540, 555, 570, 585, 600,
178               615, 630, 645, 660, 675, 690, 705, 720, 735, 750,
179               765, 780, 795, 810, 825, 840, 855, 870, 885, },
180             // column 2 = measured counts
181             { 775, 479, 380, 302, 185, 157, 137, 119, 110, 89,
182               74, 61, 66, 68, 48, 54, 51, 46, 55, 29,
183               28, 37, 49, 26, 35, 29, 31, 24, 25, 35,
184               24, 30, 26, 28, 21, 18, 20, 27, 17, 17,
185               14, 17, 24, 11, 22, 17, 12, 10, 13, 16,
186               9, 9, 14, 21, 17, 13, 12, 18, 10, },
187         };
188         final double[] start = {10, 900, 80, 27, 225};
189 
190         final BevingtonProblem problem = new BevingtonProblem();
191 
192         final int len = dataPoints[0].length;
193         final double[] weights = new double[len];
194         for (int i = 0; i < len; i++) {
195             problem.addPoint(dataPoints[0][i],
196                              dataPoints[1][i]);
197 
198             weights[i] = 1 / dataPoints[1][i];
199         }
200 
201         final Optimum optimum = optimizer.optimize(
202                 builder(problem)
203                         .target(dataPoints[1])
204                         .weight(new DiagonalMatrix(weights))
205                         .start(start)
206                         .maxIterations(20)
207                         .build()
208         );
209 
210         final RealVector solution = optimum.getPoint();
211         final double[] expectedSolution = { 10.4, 958.3, 131.4, 33.9, 205.0 };
212 
213         final RealMatrix covarMatrix = optimum.getCovariances(1e-14);
214         final double[][] expectedCovarMatrix = {
215             { 3.38, -3.69, 27.98, -2.34, -49.24 },
216             { -3.69, 2492.26, 81.89, -69.21, -8.9 },
217             { 27.98, 81.89, 468.99, -44.22, -615.44 },
218             { -2.34, -69.21, -44.22, 6.39, 53.80 },
219             { -49.24, -8.9, -615.44, 53.8, 929.45 }
220         };
221 
222         final int numParams = expectedSolution.length;
223 
224         // Check that the computed solution is within the reference error range.
225         for (int i = 0; i < numParams; i++) {
226             final double error = FastMath.sqrt(expectedCovarMatrix[i][i]);
227             assertEquals(expectedSolution[i], solution.getEntry(i), error, "Parameter " + i);
228         }
229 
230         // Check that each entry of the computed covariance matrix is within 10%
231         // of the reference matrix entry.
232         for (int i = 0; i < numParams; i++) {
233             for (int j = 0; j < numParams; j++) {
234                 assertEquals(expectedCovarMatrix[i][j],
235                                     covarMatrix.getEntry(i, j),
236                                     FastMath.abs(0.1 * expectedCovarMatrix[i][j]),
237                                     "Covariance matrix [" + i + "][" + j + "]");
238             }
239         }
240 
241         // Check various measures of goodness-of-fit.
242         final double chi2 = optimum.getChiSquare();
243         final double cost = optimum.getCost();
244         final double rms = optimum.getRMS();
245         final double reducedChi2 = optimum.getReducedChiSquare(start.length);
246 
247         // XXX Values computed by the CM code: It would be better to compare
248         // with the results from another library.
249         final double expectedChi2 = 66.07852350839286;
250         final double expectedReducedChi2 = 1.2014277001525975;
251         final double expectedCost = 8.128869755900439;
252         final double expectedRms = 1.0582887010256337;
253 
254         final double tol = 1e14;
255         assertEquals(expectedChi2, chi2, tol);
256         assertEquals(expectedReducedChi2, reducedChi2, tol);
257         assertEquals(expectedCost, cost, tol);
258         assertEquals(expectedRms, rms, tol);
259     }
260 
261     @Test
262     void testCircleFitting2() {
263         final double xCenter = 123.456;
264         final double yCenter = 654.321;
265         final double xSigma = 10;
266         final double ySigma = 15;
267         final double radius = 111.111;
268         // The test is extremely sensitive to the seed.
269         final long seed = 59421061L;
270         final RandomCirclePointGenerator factory
271             = new RandomCirclePointGenerator(xCenter, yCenter, radius,
272                                              xSigma, ySigma,
273                                              seed);
274         final CircleProblem circle = new CircleProblem(xSigma, ySigma);
275 
276         final int numPoints = 10;
277         for (Vector2D p : factory.generate(numPoints)) {
278             circle.addPoint(p.getX(), p.getY());
279         }
280 
281         // First guess for the center's coordinates and radius.
282         final double[] init = { 90, 659, 115 };
283 
284         Incrementor incrementor = new Incrementor();
285         final Optimum optimum = optimizer.optimize(
286                 LeastSquaresFactory.countEvaluations(builder(circle).maxIterations(50).start(init).build(),
287                                                      incrementor));
288 
289         final double[] paramFound = optimum.getPoint().toArray();
290 
291         // Retrieve errors estimation.
292         final double[] asymptoticStandardErrorFound = optimum.getSigma(1e-14).toArray();
293 
294         // Check that the parameters are found within the assumed error bars.
295         assertEquals(xCenter, paramFound[0], 3 * asymptoticStandardErrorFound[0]);
296         assertEquals(yCenter, paramFound[1], 3 * asymptoticStandardErrorFound[1]);
297         assertEquals(radius,  paramFound[2], 3 * asymptoticStandardErrorFound[2]);
298         assertTrue(incrementor.getCount() < 40);
299     }
300 
301     @Test
302     void testParameterValidator() {
303         // Setup.
304         final double xCenter = 123.456;
305         final double yCenter = 654.321;
306         final double xSigma = 10;
307         final double ySigma = 15;
308         final double radius = 111.111;
309         final long seed = 3456789L;
310         final RandomCirclePointGenerator factory
311             = new RandomCirclePointGenerator(xCenter, yCenter, radius,
312                                              xSigma, ySigma,
313                                              seed);
314         final CircleProblem circle = new CircleProblem(xSigma, ySigma);
315 
316         final int numPoints = 10;
317         for (Vector2D p : factory.generate(numPoints)) {
318             circle.addPoint(p.getX(), p.getY());
319         }
320 
321         // First guess for the center's coordinates and radius.
322         final double[] init = { 90, 659, 115 };
323         final Optimum optimum
324             = optimizer.optimize(builder(circle).maxIterations(50).start(init).build());
325         final int numEval = optimum.getEvaluations();
326         assertTrue(numEval > 1);
327 
328         // Build a new problem with a validator that amounts to cheating.
329         final ParameterValidator cheatValidator
330             = new ParameterValidator() {
331                     public RealVector validate(RealVector params) {
332                         // Cheat: return the optimum found previously.
333                         return optimum.getPoint();
334                     }
335                 };
336 
337         final Optimum cheatOptimum
338             = optimizer.optimize(builder(circle).maxIterations(50).start(init).parameterValidator(cheatValidator).build());
339         final int cheatNumEval = cheatOptimum.getEvaluations();
340         assertTrue(cheatNumEval < numEval);
341         // System.out.println("n=" + numEval + " nc=" + cheatNumEval);
342     }
343 
344     @Test
345     void testEvaluationCount() {
346         //setup
347         LeastSquaresProblem lsp = new LinearProblem(new double[][] {{1}}, new double[] {1})
348                 .getBuilder()
349                 .checker(new ConvergenceChecker<Evaluation>() {
350                     public boolean converged(int iteration, Evaluation previous, Evaluation current) {
351                         return true;
352                     }
353                 })
354                 .build();
355 
356         //action
357         Optimum optimum = optimizer.optimize(lsp);
358 
359         //verify
360         //check iterations and evaluations are not switched.
361         assertThat(optimum.getIterations(), is(1));
362         assertThat(optimum.getEvaluations(), is(2));
363     }
364 
365     private static class BevingtonProblem {
366         private List<Double> time;
367         private List<Double> count;
368 
369         public BevingtonProblem() {
370             time = new ArrayList<Double>();
371             count = new ArrayList<Double>();
372         }
373 
374         public void addPoint(double t, double c) {
375             time.add(t);
376             count.add(c);
377         }
378 
379         public MultivariateVectorFunction getModelFunction() {
380             return new MultivariateVectorFunction() {
381                 public double[] value(double[] params) {
382                     double[] values = new double[time.size()];
383                     for (int i = 0; i < values.length; ++i) {
384                         final double t = time.get(i);
385                         values[i] = params[0] +
386                             params[1] * FastMath.exp(-t / params[3]) +
387                             params[2] * FastMath.exp(-t / params[4]);
388                     }
389                     return values;
390                 }
391             };
392         }
393 
394         public MultivariateMatrixFunction getModelFunctionJacobian() {
395             return new MultivariateMatrixFunction() {
396                 public double[][] value(double[] params) {
397                     double[][] jacobian = new double[time.size()][5];
398 
399                     for (int i = 0; i < jacobian.length; ++i) {
400                         final double t = time.get(i);
401                         jacobian[i][0] = 1;
402 
403                         final double p3 =  params[3];
404                         final double p4 =  params[4];
405                         final double tOp3 = t / p3;
406                         final double tOp4 = t / p4;
407                         jacobian[i][1] = FastMath.exp(-tOp3);
408                         jacobian[i][2] = FastMath.exp(-tOp4);
409                         jacobian[i][3] = params[1] * FastMath.exp(-tOp3) * tOp3 / p3;
410                         jacobian[i][4] = params[2] * FastMath.exp(-tOp4) * tOp4 / p4;
411                     }
412                     return jacobian;
413                 }
414             };
415         }
416     }
417 }