/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
/*
 * This is not the original file distributed by the Apache Software Foundation.
 * It has been modified by the Hipparchus project.
 */
22  
23  package org.hipparchus.analysis.differentiation;
24  
25  import org.hipparchus.UnitTestUtils;
26  import org.hipparchus.analysis.QuinticFunction;
27  import org.hipparchus.analysis.UnivariateFunction;
28  import org.hipparchus.analysis.UnivariateMatrixFunction;
29  import org.hipparchus.analysis.UnivariateVectorFunction;
30  import org.hipparchus.analysis.function.Gaussian;
31  import org.hipparchus.analysis.function.Sin;
32  import org.hipparchus.exception.LocalizedCoreFormats;
33  import org.hipparchus.exception.MathIllegalArgumentException;
34  import org.hipparchus.exception.MathRuntimeException;
35  import org.hipparchus.util.FastMath;
36  import org.junit.jupiter.api.Test;
37  
38  import static org.junit.jupiter.api.Assertions.assertEquals;
39  import static org.junit.jupiter.api.Assertions.assertThrows;
40  import static org.junit.jupiter.api.Assertions.fail;
41  
42  /**
43   * Test for class {@link FiniteDifferencesDifferentiator}.
44   */
45  class FiniteDifferencesDifferentiatorTest {
46  
47      @Test
48      void testWrongNumberOfPoints() {
49          assertThrows(MathIllegalArgumentException.class, () -> {
50              new FiniteDifferencesDifferentiator(1, 1.0);
51          });
52      }
53  
54      @Test
55      void testWrongStepSize() {
56          assertThrows(MathIllegalArgumentException.class, () -> {
57              new FiniteDifferencesDifferentiator(3, 0.0);
58          });
59      }
60  
61      @Test
62      void testSerialization() {
63          FiniteDifferencesDifferentiator differentiator =
64                  new FiniteDifferencesDifferentiator(3, 1.0e-3);
65          FiniteDifferencesDifferentiator recovered =
66                  (FiniteDifferencesDifferentiator) UnitTestUtils.serializeAndRecover(differentiator);
67          assertEquals(differentiator.getNbPoints(), recovered.getNbPoints());
68          assertEquals(differentiator.getStepSize(), recovered.getStepSize(), 1.0e-15);
69      }
70  
71      @Test
72      void testConstant() {
73          FiniteDifferencesDifferentiator differentiator =
74                  new FiniteDifferencesDifferentiator(5, 0.01);
75          UnivariateDifferentiableFunction f =
76                  differentiator.differentiate(new UnivariateFunction() {
77                      @Override
78                      public double value(double x) {
79                          return 42.0;
80                      }
81                  });
82          DSFactory factory = new DSFactory(1, 2);
83          for (double x = -10; x < 10; x += 0.1) {
84              DerivativeStructure y = f.value(factory.variable(0, x));
85              assertEquals(42.0, y.getValue(), 1.0e-15);
86              assertEquals( 0.0, y.getPartialDerivative(1), 1.0e-15);
87              assertEquals( 0.0, y.getPartialDerivative(2), 1.0e-15);
88          }
89      }
90  
91      @Test
92      void testLinear() {
93          FiniteDifferencesDifferentiator differentiator =
94                  new FiniteDifferencesDifferentiator(5, 0.01);
95          UnivariateDifferentiableFunction f =
96                  differentiator.differentiate(new UnivariateFunction() {
97                      @Override
98                      public double value(double x) {
99                          return 2 - 3 * x;
100                     }
101                 });
102         DSFactory factory = new DSFactory(1, 2);
103         for (double x = -10; x < 10; x += 0.1) {
104             DerivativeStructure y = f.value(factory.variable(0, x));
105             assertEquals(2 - 3 * x, y.getValue(), 2.0e-15, "" + (2 - 3 * x - y.getValue()));
106             assertEquals(-3.0, y.getPartialDerivative(1), 4.0e-13);
107             assertEquals( 0.0, y.getPartialDerivative(2), 9.0e-11);
108         }
109     }
110 
111     @Test
112     void testGaussian() {
113         FiniteDifferencesDifferentiator differentiator =
114                 new FiniteDifferencesDifferentiator(9, 0.02);
115         UnivariateDifferentiableFunction gaussian = new Gaussian(1.0, 2.0);
116         UnivariateDifferentiableFunction f =
117                 differentiator.differentiate(gaussian);
118         double[] expectedError = new double[] {
119             6.939e-18, 1.284e-15, 2.477e-13, 1.168e-11, 2.840e-9, 7.971e-8
120         };
121         double[] maxError = new double[expectedError.length];
122         DSFactory factory = new DSFactory(1, maxError.length - 1);
123         for (double x = -10; x < 10; x += 0.1) {
124             DerivativeStructure dsX  = factory.variable(0, x);
125             DerivativeStructure yRef = gaussian.value(dsX);
126             DerivativeStructure y    = f.value(dsX);
127             assertEquals(f.value(dsX.getValue()), f.value(dsX).getValue(), 1.0e-15);
128             for (int order = 0; order <= yRef.getOrder(); ++order) {
129                 maxError[order] = FastMath.max(maxError[order],
130                                         FastMath.abs(yRef.getPartialDerivative(order) -
131                                                      y.getPartialDerivative(order)));
132             }
133         }
134         for (int i = 0; i < maxError.length; ++i) {
135             assertEquals(expectedError[i], maxError[i], 0.01 * expectedError[i]);
136         }
137     }
138 
139     @Test
140     void testStepSizeUnstability() {
141         UnivariateDifferentiableFunction quintic = new QuinticFunction();
142         UnivariateDifferentiableFunction goodStep =
143                 new FiniteDifferencesDifferentiator(7, 0.25).differentiate(quintic);
144         UnivariateDifferentiableFunction badStep =
145                 new FiniteDifferencesDifferentiator(7, 1.0e-6).differentiate(quintic);
146         double[] maxErrorGood = new double[7];
147         double[] maxErrorBad  = new double[7];
148         DSFactory factory = new DSFactory(1, maxErrorGood.length - 1);
149         for (double x = -10; x < 10; x += 0.1) {
150             DerivativeStructure dsX  = factory.variable(0, x);
151             DerivativeStructure yRef  = quintic.value(dsX);
152             DerivativeStructure yGood = goodStep.value(dsX);
153             DerivativeStructure yBad  = badStep.value(dsX);
154             for (int order = 0; order <= 6; ++order) {
155                 maxErrorGood[order] = FastMath.max(maxErrorGood[order],
156                                                    FastMath.abs(yRef.getPartialDerivative(order) -
157                                                                 yGood.getPartialDerivative(order)));
158                 maxErrorBad[order]  = FastMath.max(maxErrorBad[order],
159                                                    FastMath.abs(yRef.getPartialDerivative(order) -
160                                                                 yBad.getPartialDerivative(order)));
161             }
162         }
163 
164         // the 0.25 step size is good for finite differences in the quintic on this abscissa range for 7 points
165         // the errors are fair
166         final double[] expectedGood = new double[] {
167             7.276e-12, 7.276e-11, 9.968e-10, 3.092e-9, 5.432e-8, 8.196e-8, 1.818e-6
168         };
169 
170         // the 1.0e-6 step size is far too small for finite differences in the quintic on this abscissa range for 7 points
171         // the errors are huge!
172         final double[] expectedBad = new double[] {
173             2.910e-11, 2.087e-5, 147.7, 3.820e7, 6.354e14, 6.548e19, 1.543e27
174         };
175 
176         for (int i = 0; i < maxErrorGood.length; ++i) {
177             assertEquals(expectedGood[i], maxErrorGood[i], 0.01 * expectedGood[i]);
178             assertEquals(expectedBad[i],  maxErrorBad[i],  0.01 * expectedBad[i]);
179         }
180 
181     }
182 
183     @Test
184     void testWrongOrder() {
185         assertThrows(MathIllegalArgumentException.class, () -> {
186             UnivariateDifferentiableFunction f =
187                 new FiniteDifferencesDifferentiator(3, 0.01).differentiate(new UnivariateFunction() {
188                     @Override
189                     public double value(double x) {
190                         // this exception should not be thrown because wrong order
191                         // should be detected before function call
192                         throw MathRuntimeException.createInternalError();
193                     }
194                 });
195             f.value(new DSFactory(1, 3).variable(0, 1.0));
196         });
197     }
198 
199     @Test
200     void testWrongOrderVector() {
201         assertThrows(MathIllegalArgumentException.class, () -> {
202             UnivariateDifferentiableVectorFunction f =
203                 new FiniteDifferencesDifferentiator(3, 0.01).differentiate(new UnivariateVectorFunction() {
204                     @Override
205                     public double[] value(double x) {
206                         // this exception should not be thrown because wrong order
207                         // should be detected before function call
208                         throw MathRuntimeException.createInternalError();
209                     }
210                 });
211             f.value(new DSFactory(1, 3).variable(0, 1.0));
212         });
213     }
214 
215     @Test
216     void testWrongOrderMatrix() {
217         assertThrows(MathIllegalArgumentException.class, () -> {
218             UnivariateDifferentiableMatrixFunction f =
219                 new FiniteDifferencesDifferentiator(3, 0.01).differentiate(new UnivariateMatrixFunction() {
220                     @Override
221                     public double[][] value(double x) {
222                         // this exception should not be thrown because wrong order
223                         // should be detected before function call
224                         throw MathRuntimeException.createInternalError();
225                     }
226                 });
227             f.value(new DSFactory(1, 3).variable(0, 1.0));
228         });
229     }
230 
231     @Test
232     void testTooLargeStep() {
233         assertThrows(MathIllegalArgumentException.class, () -> {
234             new FiniteDifferencesDifferentiator(3, 2.5, 0.0, 1.0);
235         });
236     }
237 
238     @Test
239     void testBounds() {
240 
241         final double slope = 2.5;
242         UnivariateFunction f = new UnivariateFunction() {
243             @Override
244             public double value(double x) {
245                 if (x < 0) {
246                     throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
247                                                            x, 0);
248                 } else if (x > 1) {
249                     throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
250                                                            x, 1);
251                 } else {
252                     return slope * x;
253                 }
254             }
255         };
256 
257         UnivariateDifferentiableFunction missingBounds =
258                 new FiniteDifferencesDifferentiator(3, 0.1).differentiate(f);
259         UnivariateDifferentiableFunction properlyBounded =
260                 new FiniteDifferencesDifferentiator(3, 0.1, 0.0, 1.0).differentiate(f);
261         DSFactory factory = new DSFactory(1, 1);
262         DerivativeStructure tLow  = factory.variable(0, 0.05);
263         DerivativeStructure tHigh = factory.variable(0, 0.95);
264 
265         try {
266             // here, we did not set the bounds, so the differences are evaluated out of domain
267             // using f(-0.05), f(0.05), f(0.15)
268             missingBounds.value(tLow);
269             fail("an exception should have been thrown");
270         } catch (MathIllegalArgumentException nse) {
271             assertEquals(LocalizedCoreFormats.NUMBER_TOO_SMALL, nse.getSpecifier());
272             assertEquals(-0.05, ((Double) nse.getParts()[0]).doubleValue(), 1.0e-10);
273         } catch (Exception e) {
274             fail("wrong exception caught: " + e.getClass().getName());
275         }
276 
277         try {
278             // here, we did not set the bounds, so the differences are evaluated out of domain
279             // using f(0.85), f(0.95), f(1.05)
280             missingBounds.value(tHigh);
281             fail("an exception should have been thrown");
282         } catch (MathIllegalArgumentException nle) {
283             assertEquals(LocalizedCoreFormats.NUMBER_TOO_LARGE, nle.getSpecifier());
284             assertEquals(1.05, ((Double) nle.getParts()[0]).doubleValue(), 1.0e-10);
285         } catch (Exception e) {
286             fail("wrong exception caught: " + e.getClass().getName());
287         }
288 
289         // here, we did set the bounds, so evaluations are done within domain
290         // using f(0.0), f(0.1), f(0.2)
291         assertEquals(slope, properlyBounded.value(tLow).getPartialDerivative(1), 1.0e-10);
292 
293         // here, we did set the bounds, so evaluations are done within domain
294         // using f(0.8), f(0.9), f(1.0)
295         assertEquals(slope, properlyBounded.value(tHigh).getPartialDerivative(1), 1.0e-10);
296 
297     }
298 
299     @Test
300     void testBoundedSqrt() {
301 
302         UnivariateFunctionDifferentiator differentiator =
303                 new FiniteDifferencesDifferentiator(9, 1.0 / 32, 0.0, Double.POSITIVE_INFINITY);
304         UnivariateDifferentiableFunction sqrt = differentiator.differentiate(new UnivariateFunction() {
305             @Override
306             public double value(double x) {
307                 return FastMath.sqrt(x);
308             }
309         });
310 
311         // we are able to compute derivative near 0, but the accuracy is much poorer there
312         DSFactory factory = new DSFactory(1, 1);
313         DerivativeStructure t001 = factory.variable(0, 0.01);
314         assertEquals(0.5 / FastMath.sqrt(t001.getValue()), sqrt.value(t001).getPartialDerivative(1), 1.6);
315         DerivativeStructure t01 = factory.variable(0, 0.1);
316         assertEquals(0.5 / FastMath.sqrt(t01.getValue()), sqrt.value(t01).getPartialDerivative(1), 7.0e-3);
317         DerivativeStructure t03 = factory.variable(0, 0.3);
318         assertEquals(0.5 / FastMath.sqrt(t03.getValue()), sqrt.value(t03).getPartialDerivative(1), 2.1e-7);
319 
320     }
321 
322     @Test
323     void testVectorFunction() {
324 
325         FiniteDifferencesDifferentiator differentiator =
326                 new FiniteDifferencesDifferentiator(7, 0.01);
327         UnivariateDifferentiableVectorFunction f =
328                 differentiator.differentiate(new UnivariateVectorFunction() {
329 
330             @Override
331             public double[] value(double x) {
332                 return new double[] { FastMath.cos(x), FastMath.sin(x) };
333             }
334 
335         });
336 
337         DSFactory factory = new DSFactory(1, 2);
338         for (double x = -10; x < 10; x += 0.1) {
339             DerivativeStructure dsX = factory.variable(0, x);
340             DerivativeStructure[] y = f.value(dsX);
341             double cos = FastMath.cos(x);
342             double sin = FastMath.sin(x);
343             double[] f1 = f.value(dsX.getValue());
344             DerivativeStructure[] f2 = f.value(dsX);
345             assertEquals(f1.length, f2.length);
346             for (int i = 0; i < f1.length; ++i) {
347                 assertEquals(f1[i], f2[i].getValue(), 1.0e-15);
348             }
349             assertEquals( cos, y[0].getValue(), 7.0e-16);
350             assertEquals( sin, y[1].getValue(), 7.0e-16);
351             assertEquals(-sin, y[0].getPartialDerivative(1), 6.0e-14);
352             assertEquals( cos, y[1].getPartialDerivative(1), 6.0e-14);
353             assertEquals(-cos, y[0].getPartialDerivative(2), 2.0e-11);
354             assertEquals(-sin, y[1].getPartialDerivative(2), 2.0e-11);
355         }
356 
357     }
358 
359     @Test
360     void testMatrixFunction() {
361 
362         FiniteDifferencesDifferentiator differentiator =
363                 new FiniteDifferencesDifferentiator(7, 0.01);
364         UnivariateDifferentiableMatrixFunction f =
365                 differentiator.differentiate(new UnivariateMatrixFunction() {
366 
367             @Override
368             public double[][] value(double x) {
369                 return new double[][] {
370                     { FastMath.cos(x),  FastMath.sin(x)  },
371                     { FastMath.cosh(x), FastMath.sinh(x) }
372                 };
373             }
374 
375         });
376 
377         DSFactory factory = new DSFactory(1, 2);
378         for (double x = -1; x < 1; x += 0.02) {
379             DerivativeStructure dsX = factory.variable(0, x);
380             DerivativeStructure[][] y = f.value(dsX);
381             double cos = FastMath.cos(x);
382             double sin = FastMath.sin(x);
383             double cosh = FastMath.cosh(x);
384             double sinh = FastMath.sinh(x);
385             double[][] f1 = f.value(dsX.getValue());
386             DerivativeStructure[][] f2 = f.value(dsX);
387             assertEquals(f1.length, f2.length);
388             for (int i = 0; i < f1.length; ++i) {
389                 assertEquals(f1[i].length, f2[i].length);
390                 for (int j = 0; j < f1[i].length; ++j) {
391                     assertEquals(f1[i][j], f2[i][j].getValue(), 1.0e-15);
392                 }
393             }
394             assertEquals(cos,   y[0][0].getValue(), 7.0e-18);
395             assertEquals(sin,   y[0][1].getValue(), 6.0e-17);
396             assertEquals(cosh,  y[1][0].getValue(), 3.0e-16);
397             assertEquals(sinh,  y[1][1].getValue(), 3.0e-16);
398             assertEquals(-sin,  y[0][0].getPartialDerivative(1), 2.0e-14);
399             assertEquals( cos,  y[0][1].getPartialDerivative(1), 2.0e-14);
400             assertEquals( sinh, y[1][0].getPartialDerivative(1), 3.0e-14);
401             assertEquals( cosh, y[1][1].getPartialDerivative(1), 3.0e-14);
402             assertEquals(-cos,  y[0][0].getPartialDerivative(2), 3.0e-12);
403             assertEquals(-sin,  y[0][1].getPartialDerivative(2), 3.0e-12);
404             assertEquals( cosh, y[1][0].getPartialDerivative(2), 6.0e-12);
405             assertEquals( sinh, y[1][1].getPartialDerivative(2), 6.0e-12);
406         }
407 
408     }
409 
410     @Test
411     void testSeveralFreeParameters() {
412         FiniteDifferencesDifferentiator differentiator =
413                 new FiniteDifferencesDifferentiator(5, 0.001);
414         UnivariateDifferentiableFunction sine = new Sin();
415         UnivariateDifferentiableFunction f =
416                 differentiator.differentiate(sine);
417         double[] expectedError = new double[] {
418             6.696e-16, 1.371e-12, 2.007e-8, 1.754e-5
419         };
420         double[] maxError = new double[expectedError.length];
421         DSFactory factory = new DSFactory(2, maxError.length - 1);
422         for (double x = -2; x < 2; x += 0.1) {
423            for (double y = -2; y < 2; y += 0.1) {
424                DerivativeStructure dsX  = factory.variable(0, x);
425                DerivativeStructure dsY  = factory.variable(1, y);
426                DerivativeStructure dsT  = dsX.multiply(3).subtract(dsY.multiply(2));
427                DerivativeStructure sRef = sine.value(dsT);
428                DerivativeStructure s    = f.value(dsT);
429                for (int xOrder = 0; xOrder <= sRef.getOrder(); ++xOrder) {
430                    for (int yOrder = 0; yOrder <= sRef.getOrder(); ++yOrder) {
431                        if (xOrder + yOrder <= sRef.getOrder()) {
432                            maxError[xOrder +yOrder] = FastMath.max(maxError[xOrder + yOrder],
433                                                                     FastMath.abs(sRef.getPartialDerivative(xOrder, yOrder) -
434                                                                                  s.getPartialDerivative(xOrder, yOrder)));
435                        }
436                    }
437                }
438            }
439        }
440        for (int i = 0; i < maxError.length; ++i) {
441            assertEquals(expectedError[i], maxError[i], 0.01 * expectedError[i]);
442        }
443     }
444 
445 }