View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements. See the NOTICE file distributed with this
4    * work for additional information regarding copyright ownership. The ASF
5    * licenses this file to You under the Apache License, Version 2.0 (the
6    * "License"); you may not use this file except in compliance with the License.
7    * You may obtain a copy of the License at
8    * https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law
9    * or agreed to in writing, software distributed under the License is
10   * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
11   * KIND, either express or implied. See the License for the specific language
12   * governing permissions and limitations under the License.
13   */
14  package org.hipparchus.optim.nonlinear.vector.leastsquares;
15  
16  import org.hipparchus.UnitTestUtils;
17  import org.hipparchus.analysis.MultivariateMatrixFunction;
18  import org.hipparchus.analysis.MultivariateVectorFunction;
19  import org.hipparchus.exception.MathIllegalArgumentException;
20  import org.hipparchus.exception.MathIllegalStateException;
21  import org.hipparchus.linear.ArrayRealVector;
22  import org.hipparchus.linear.DiagonalMatrix;
23  import org.hipparchus.linear.MatrixUtils;
24  import org.hipparchus.linear.RealMatrix;
25  import org.hipparchus.linear.RealVector;
26  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem.Evaluation;
27  import org.hipparchus.util.FastMath;
28  import org.hipparchus.util.Pair;
29  import org.hipparchus.util.Precision;
30  import org.junit.jupiter.api.Test;
31  
32  import java.io.IOException;
33  import java.util.Arrays;
34  
35  import static org.junit.jupiter.api.Assertions.assertArrayEquals;
36  import static org.junit.jupiter.api.Assertions.assertEquals;
37  import static org.junit.jupiter.api.Assertions.assertNotSame;
38  import static org.junit.jupiter.api.Assertions.assertTrue;
39  import static org.junit.jupiter.api.Assertions.fail;
40  
41  /**
42   * The only features tested here are utility methods defined
43   * in {@link LeastSquaresProblem.Evaluation} that compute the
44   * chi-square and parameters standard-deviations.
45   */
46  public class EvaluationTest {
47  
48      /**
49       * Create a {@link LeastSquaresBuilder} from a {@link StatisticalReferenceDataset}.
50       *
51       * @param dataset the source data
52       * @return a builder for further customization.
53       */
54      public LeastSquaresBuilder builder(StatisticalReferenceDataset dataset) {
55          StatisticalReferenceDataset.LeastSquaresProblem problem
56                  = dataset.getLeastSquaresProblem();
57          final double[] start = dataset.getParameters();
58          final double[] observed = dataset.getData()[1];
59          final double[] weights = new double[observed.length];
60          Arrays.fill(weights, 1d);
61  
62          return new LeastSquaresBuilder()
63                  .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
64                  .target(observed)
65                  .weight(new DiagonalMatrix(weights))
66                  .start(start);
67      }
68  
69      @Test
70      void testComputeResiduals() {
71          //setup
72          RealVector point = new ArrayRealVector(2);
73          Evaluation evaluation = new LeastSquaresBuilder()
74                  .target(new ArrayRealVector(new double[]{3,-1}))
75                  .model(new MultivariateJacobianFunction() {
76                      public Pair<RealVector, RealMatrix> value(RealVector point) {
77                          return new Pair<RealVector, RealMatrix>(
78                                  new ArrayRealVector(new double[]{1, 2}),
79                                  MatrixUtils.createRealIdentityMatrix(2)
80                          );
81                      }
82                  })
83                  .weight(MatrixUtils.createRealIdentityMatrix(2))
84                  .build()
85                  .evaluate(point);
86  
87          //action + verify
88          assertArrayEquals(
89              new double[]{2, -3},
90              evaluation.getResiduals().toArray(),
91              Precision.EPSILON);
92      }
93  
94      @Test
95      void testComputeCovariance() throws IOException {
96          //setup
97          RealVector point = new ArrayRealVector(2);
98          Evaluation evaluation = new LeastSquaresBuilder()
99                  .model(new MultivariateJacobianFunction() {
100                     public Pair<RealVector, RealMatrix> value(RealVector point) {
101                         return new Pair<RealVector, RealMatrix>(
102                                 new ArrayRealVector(2),
103                                 MatrixUtils.createRealDiagonalMatrix(new double[]{1, 1e-2})
104                         );
105                     }
106                 })
107                 .weight(MatrixUtils.createRealDiagonalMatrix(new double[]{1, 1}))
108                 .target(new ArrayRealVector(2))
109                 .build()
110                 .evaluate(point);
111 
112         //action
113         UnitTestUtils.customAssertEquals(
114                 "covariance",
115                 evaluation.getCovariances(FastMath.nextAfter(1e-4, 0.0)),
116                 MatrixUtils.createRealMatrix(new double[][]{{1, 0}, {0, 1e4}}),
117                 Precision.EPSILON
118         );
119 
120         //singularity fail
121         try {
122             evaluation.getCovariances(FastMath.nextAfter(1e-4, 1.0));
123             fail("Expected Exception");
124         } catch (MathIllegalArgumentException e) {
125             //expected
126         }
127     }
128 
129     @Test
130     void testComputeValueAndJacobian() {
131         //setup
132         final RealVector point = new ArrayRealVector(new double[]{1, 2});
133         Evaluation evaluation = new LeastSquaresBuilder()
134                 .weight(new DiagonalMatrix(new double[]{16, 4}))
135                 .model(new MultivariateJacobianFunction() {
136                     public Pair<RealVector, RealMatrix> value(RealVector actualPoint) {
137                         //verify correct values passed in
138                         assertArrayEquals(
139                                 point.toArray(), actualPoint.toArray(), Precision.EPSILON);
140                         //return values
141                         return new Pair<RealVector, RealMatrix>(
142                                 new ArrayRealVector(new double[]{3, 4}),
143                                 MatrixUtils.createRealMatrix(new double[][]{{5, 6}, {7, 8}})
144                         );
145                     }
146                 })
147                 .target(new double[2])
148                 .build()
149                 .evaluate(point);
150 
151         //action
152         RealVector residuals = evaluation.getResiduals();
153         RealMatrix jacobian = evaluation.getJacobian();
154 
155         //verify
156         assertArrayEquals(evaluation.getPoint().toArray(), point.toArray(), 0);
157         assertArrayEquals(new double[]{-12, -8}, residuals.toArray(), Precision.EPSILON);
158         UnitTestUtils.customAssertEquals(
159                 "jacobian",
160                 jacobian,
161                 MatrixUtils.createRealMatrix(new double[][]{{20, 24},{14, 16}}),
162                 Precision.EPSILON);
163     }
164 
165     @Test
166     void testComputeCost() throws IOException {
167         final StatisticalReferenceDataset dataset
168             = StatisticalReferenceDatasetFactory.createKirby2();
169 
170         final LeastSquaresProblem lsp = builder(dataset).build();
171 
172         final double expected = dataset.getResidualSumOfSquares();
173         final double cost = lsp.evaluate(lsp.getStart()).getCost();
174         final double actual = cost * cost;
175         assertEquals(expected, actual, 1e-11 * expected, dataset.getName());
176     }
177 
178     @Test
179     void testComputeRMS() throws IOException {
180         final StatisticalReferenceDataset dataset
181             = StatisticalReferenceDatasetFactory.createKirby2();
182 
183         final LeastSquaresProblem lsp = builder(dataset).build();
184 
185         final double expected = FastMath.sqrt(dataset.getResidualSumOfSquares() /
186                                               dataset.getNumObservations());
187         final double actual = lsp.evaluate(lsp.getStart()).getRMS();
188         assertEquals(expected, actual, 1e-11 * expected, dataset.getName());
189     }
190 
191     @Test
192     void testComputeSigma() throws IOException {
193         final StatisticalReferenceDataset dataset
194             = StatisticalReferenceDatasetFactory.createKirby2();
195 
196         final LeastSquaresProblem lsp = builder(dataset).build();
197 
198         final double[] expected = dataset.getParametersStandardDeviations();
199 
200         final Evaluation evaluation = lsp.evaluate(lsp.getStart());
201         final double cost = evaluation.getCost();
202         final RealVector sig = evaluation.getSigma(1e-14);
203         final int dof = lsp.getObservationSize() - lsp.getParameterSize();
204         for (int i = 0; i < sig.getDimension(); i++) {
205             final double actual = FastMath.sqrt(cost * cost / dof) * sig.getEntry(i);
206             assertEquals(expected[i], actual, 1e-6 * expected[i], dataset.getName() + ", parameter #" + i);
207         }
208     }
209 
210     @Test
211     void testEvaluateCopiesPoint() throws IOException {
212         //setup
213         StatisticalReferenceDataset dataset
214                 = StatisticalReferenceDatasetFactory.createKirby2();
215         LeastSquaresProblem lsp = builder(dataset).build();
216         RealVector point = new ArrayRealVector(lsp.getParameterSize());
217 
218         //action
219         Evaluation evaluation = lsp.evaluate(point);
220 
221         //verify
222         assertNotSame(point, evaluation.getPoint());
223         point.setEntry(0, 1);
224         assertEquals(0, evaluation.getPoint().getEntry(0), 0);
225     }
226 
227     @Test
228     void testLazyEvaluation() {
229         final RealVector dummy = new ArrayRealVector(new double[] { 0 });
230 
231         final LeastSquaresProblem p
232             = LeastSquaresFactory.create(LeastSquaresFactory.model(dummyModel(), dummyJacobian()),
233                                          dummy, dummy, null, null, 0, 0, true, null);
234 
235         // Should not throw because actual evaluation is deferred.
236         final Evaluation eval = p.evaluate(dummy);
237 
238         try {
239             eval.getResiduals();
240             fail("Exception expected");
241         } catch (RuntimeException e) {
242             // Expecting exception.
243             assertEquals("dummyModel", e.getMessage());
244         }
245 
246         try {
247             eval.getJacobian();
248             fail("Exception expected");
249         } catch (RuntimeException e) {
250             // Expecting exception.
251             assertEquals("dummyJacobian", e.getMessage());
252         }
253     }
254 
255     // MATH-1151
256     @Test
257     void testLazyEvaluationPrecondition() {
258         final RealVector dummy = new ArrayRealVector(new double[] { 0 });
259 
260         // "ValueAndJacobianFunction" is required but we implement only
261         // "MultivariateJacobianFunction".
262         final MultivariateJacobianFunction m1 = new MultivariateJacobianFunction() {
263                 public Pair<RealVector, RealMatrix> value(RealVector notUsed) {
264                     return new Pair<RealVector, RealMatrix>(null, null);
265                 }
266             };
267 
268         try {
269             // Should throw.
270             LeastSquaresFactory.create(m1, dummy, dummy, null, null, 0, 0, true, null);
271             fail("Expecting MathIllegalStateException");
272         } catch (MathIllegalStateException e) {
273             // Expected.
274         }
275 
276         final MultivariateJacobianFunction m2 = new ValueAndJacobianFunction() {
277                 public Pair<RealVector, RealMatrix> value(RealVector notUsed) {
278                     return new Pair<RealVector, RealMatrix>(null, null);
279                 }
280                 public RealVector computeValue(final double[] params) {
281                     return null;
282                 }
283                 public RealMatrix computeJacobian(final double[] params) {
284                     return null;
285                 }
286             };
287 
288         // Should pass.
289         LeastSquaresFactory.create(m2, dummy, dummy, null, null, 0, 0, true, null);
290     }
291 
292     @Test
293     void testDirectEvaluation() {
294         final RealVector dummy = new ArrayRealVector(new double[] { 0 });
295 
296         final LeastSquaresProblem p
297             = LeastSquaresFactory.create(LeastSquaresFactory.model(dummyModel(), dummyJacobian()),
298                                          dummy, dummy, null, null, 0, 0, false, null);
299 
300         try {
301             // Should throw.
302             p.evaluate(dummy);
303             fail("Exception expected");
304         } catch (RuntimeException e) {
305             // Expecting exception.
306             // Whether it is model or Jacobian that caused it is not significant.
307             final String msg = e.getMessage();
308             assertTrue(msg.equals("dummyModel") ||
309                               msg.equals("dummyJacobian"));
310         }
311     }
312 
313     /** Used for testing direct vs lazy evaluation. */
314     private MultivariateVectorFunction dummyModel() {
315         return new MultivariateVectorFunction() {
316             public double[] value(double[] p) {
317                 throw new RuntimeException("dummyModel");
318             }
319         };
320     }
321 
322     /** Used for testing direct vs lazy evaluation. */
323     private MultivariateMatrixFunction dummyJacobian() {
324         return new MultivariateMatrixFunction() {
325             public double[][] value(double[] p) {
326                 throw new RuntimeException("dummyJacobian");
327             }
328         };
329     }
330 }