/*
 * Licensed to the Hipparchus project under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.hipparchus.analysis.differentiation;

import org.hipparchus.CalculusFieldElement;
import org.hipparchus.CalculusFieldElementAbstractTest;
import org.hipparchus.Field;
import org.hipparchus.UnitTestUtils;
import org.hipparchus.analysis.FieldUnivariateFunction;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.util.FastMath;
import org.hipparchus.util.FieldSinCos;
import org.hipparchus.util.MathArrays;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Arrays;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;

40  /**
41   * Test for class {@link UnivariateDerivative}.
42   */
43  class GradientTest extends CalculusFieldElementAbstractTest<Gradient> {
44  
45      @Override
46      protected Gradient build(final double x) {
47          // the function is really a two variables function : f(x) = g(x, 0) with g(x, y) = x + y / 1024
48          return new Gradient(x, 1.0, FastMath.scalb(1.0, -10));
49      }
50  
51      @Test
52      void testGetGradient() {
53          Gradient g = new Gradient(-0.5, 2.5, 10.0, -1.0);
54          assertEquals(-0.5, g.getReal(), 1.0e-15);
55          assertEquals(-0.5, g.getValue(), 1.0e-15);
56          assertEquals(+2.5, g.getGradient()[0], 1.0e-15);
57          assertEquals(10.0, g.getGradient()[1], 1.0e-15);
58          assertEquals(-1.0, g.getGradient()[2], 1.0e-15);
59          assertEquals(+2.5, g.getPartialDerivative(0), 1.0e-15);
60          assertEquals(10.0, g.getPartialDerivative(1), 1.0e-15);
61          assertEquals(-1.0, g.getPartialDerivative(2), 1.0e-15);
62          assertEquals(3, g.getFreeParameters());
63          try {
64              g.getPartialDerivative(-1);
65              fail("an exception should have been thrown");
66          } catch (MathIllegalArgumentException miae) {
67              assertEquals(LocalizedCoreFormats.OUT_OF_RANGE_SIMPLE, miae.getSpecifier());
68          }
69          try {
70              g.getPartialDerivative(+3);
71              fail("an exception should have been thrown");
72          } catch (MathIllegalArgumentException miae) {
73              assertEquals(LocalizedCoreFormats.OUT_OF_RANGE_SIMPLE, miae.getSpecifier());
74          }
75      }
76  
77      @Test
78      void testConstant() {
79          Gradient g = Gradient.constant(5, -4.5);
80          assertEquals(5, g.getFreeParameters());
81          assertEquals(-4.5, g.getValue(), 1.0e-15);
82          for (int i = 0 ; i < g.getFreeParameters(); ++i) {
83              assertEquals(0.0, g.getPartialDerivative(i), 1.0e-15);
84          }
85      }
86  
87      @Test
88      void testVariable() {
89          Gradient g = Gradient.variable(5, 1, -4.5);
90          assertEquals(5, g.getFreeParameters());
91          assertEquals(-4.5, g.getValue(), 1.0e-15);
92          for (int i = 0 ; i < g.getFreeParameters(); ++i) {
93              assertEquals(i == 1 ? 1.0 : 0.0, g.getPartialDerivative(i), 1.0e-15);
94          }
95      }
96  
97      @Test
98      void testStackVariable() {
99          // GIVEN
100         final Gradient gradient = new Gradient(1, 2, 3);
101         // WHEN
102         final Gradient gradientWithMoreVariable = gradient.stackVariable();
103         // THEN
104         Assertions.assertEquals(gradient.getValue(), gradientWithMoreVariable.getValue());
105         Assertions.assertEquals(gradient.getFreeParameters() + 1, gradientWithMoreVariable.getFreeParameters());
106         Assertions.assertEquals(0., gradientWithMoreVariable.getGradient()[gradient.getFreeParameters()]);
107         Assertions.assertArrayEquals(gradient.getGradient(), Arrays.copyOfRange(gradientWithMoreVariable.getGradient(),
108                 0, gradient.getFreeParameters()));
109     }
110 
111     @Test
112     void testDoublePow() {
113         assertSame(build(3).getField().getZero(), Gradient.pow(0.0, build(1.5)));
114         Gradient g = Gradient.pow(2.0, build(1.5));
115         DSFactory factory = new DSFactory(2, 1);
116         DerivativeStructure ds = factory.constant(2.0).pow(factory.build(1.5, 1.0, FastMath.scalb(1.0, -10)));
117         assertEquals(ds.getValue(), g.getValue(), 1.0e-15);
118         final int[] indices = new int[ds.getFreeParameters()];
119         for (int i = 0; i < g.getFreeParameters(); ++i) {
120             indices[i] = 1;
121             assertEquals(ds.getPartialDerivative(indices), g.getPartialDerivative(i), 1.0e-15);
122             indices[i] = 0;
123         }
124     }
125 
126     @Test
127     void testTaylor() {
128         assertEquals(2.75, new Gradient(2, 1, 0.125).taylor(0.5, 2.0), 1.0e-15);
129     }
130 
131     @Test
132     void testOrder() {
133         assertEquals(1, new Gradient(2,  1, 0.125).getOrder());
134     }
135 
136     @Test
137     void testGetPartialDerivative() {
138         final Gradient g = new Gradient(2,  1, 0.125);
139         assertEquals(2.0,   g.getPartialDerivative(0, 0), 1.0e-15); // f(x,y)
140         assertEquals(1.0,   g.getPartialDerivative(1, 0), 1.0e-15); // ∂f/∂x
141         assertEquals(0.125, g.getPartialDerivative(0, 1), 1.0e-15); // ∂f/∂y
142     }
143 
144     @Test
145     void testGetPartialDerivativeErrors() {
146         final Gradient g = new Gradient(2,  1, 0.125);
147         try {
148             g.getPartialDerivative(0, 0, 0);
149             fail("an exception should have been thrown");
150         } catch (MathIllegalArgumentException miae) {
151             assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
152             assertEquals(3, ((Integer) miae.getParts()[0]).intValue());
153             assertEquals(2, ((Integer) miae.getParts()[1]).intValue());
154         }
155         try {
156             g.getPartialDerivative(0, 5);
157             fail("an exception should have been thrown");
158         } catch (MathIllegalArgumentException miae) {
159             assertEquals(LocalizedCoreFormats.DERIVATION_ORDER_NOT_ALLOWED, miae.getSpecifier());
160             assertEquals(5, ((Integer) miae.getParts()[0]).intValue());
161         }
162         try {
163             g.getPartialDerivative(1, 1);
164             fail("an exception should have been thrown");
165         } catch (MathIllegalArgumentException miae) {
166             assertEquals(LocalizedCoreFormats.DERIVATION_ORDER_NOT_ALLOWED, miae.getSpecifier());
167             assertEquals(1, ((Integer) miae.getParts()[0]).intValue());
168         }
169     }
170 
171     @Test
172     void testHashcode() {
173         assertEquals(1608501298, new Gradient(2, 1, -0.25).hashCode());
174     }
175 
176     @Test
177     void testEquals() {
178         Gradient g = new Gradient(12, -34, 56);
179         assertEquals(g, g);
180         assertNotEquals("", g);
181         assertEquals(g, new Gradient(12, -34, 56));
182         assertNotEquals(g, new Gradient(21, -34, 56));
183         assertNotEquals(g, new Gradient(12, -43, 56));
184         assertNotEquals(g, new Gradient(12, -34, 65));
185         assertNotEquals(g, new Gradient(21, -43, 65));
186     }
187 
188     @Test
189     void testRunTimeClass() {
190         Field<Gradient> field = build(0.0).getField();
191         assertEquals(Gradient.class, field.getRuntimeClass());
192     }
193 
194     @Test
195     void testConversion() {
196         Gradient gA = new Gradient(-0.5, 2.5, 4.5);
197         DerivativeStructure ds = gA.toDerivativeStructure();
198         assertEquals(2, ds.getFreeParameters());
199         assertEquals(1, ds.getOrder());
200         assertEquals(-0.5, ds.getValue(), 1.0e-15);
201         assertEquals(-0.5, ds.getPartialDerivative(0, 0), 1.0e-15);
202         assertEquals( 2.5, ds.getPartialDerivative(1, 0), 1.0e-15);
203         assertEquals( 4.5, ds.getPartialDerivative(0, 1), 1.0e-15);
204         Gradient gB = new Gradient(ds);
205         assertNotSame(gA, gB);
206         assertEquals(gA, gB);
207         try {
208             new Gradient(new DSFactory(1, 2).variable(0, 1.0));
209             fail("an exception should have been thrown");
210         } catch (MathIllegalArgumentException miae) {
211             assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
212         }
213     }
214 
215     @Test
216     public void testNewInstance() {
217         Gradient g = build(5.25);
218         assertEquals(5.25, g.getValue(), 1.0e-15);
219         assertEquals(1.0,  g.getPartialDerivative(0), 1.0e-15);
220         assertEquals(0.0009765625,  g.getPartialDerivative(1), 1.0e-15);
221         Gradient newInstance = g.newInstance(7.5);
222         assertEquals(7.5, newInstance.getValue(), 1.0e-15);
223         assertEquals(0.0, newInstance.getPartialDerivative(0), 1.0e-15);
224         assertEquals(0.0, newInstance.getPartialDerivative(1), 1.0e-15);
225     }
226 
227     protected void checkAgainstDS(final double x,
228                                   final FieldUnivariateFunction f) {
229         final Gradient xG = build(x);
230         final Gradient yG = f.value(xG);
231         final DerivativeStructure yDS = f.value(xG.toDerivativeStructure());
232         assertEquals(yDS.getFreeParameters(),
233                                 yG.getFreeParameters());
234 
235         if (Double.isNaN(yDS.getValue())) {
236             assertEquals(yDS.getValue(), yG.getValue());
237         } else {
238             assertEquals(yDS.getValue(), yG.getValue(),
239                                     1.0e-15 * FastMath.abs(yDS.getValue()));
240         }
241         final int[] indices = new int[yDS.getFreeParameters()];
242         for (int i = 0; i < yG.getFreeParameters(); ++i) {
243             indices[i] = 1;
244             if (Double.isNaN(yDS.getPartialDerivative(indices))) {
245                 assertEquals(yDS.getPartialDerivative(indices),
246                                         yG.getPartialDerivative(i));
247             } else {
248                 assertEquals(yDS.getPartialDerivative(indices),
249                                         yG.getPartialDerivative(i),
250                                         4.0e-14 * FastMath.abs(
251                                                         yDS.getPartialDerivative(
252                                                                         indices)));
253 
254             }
255             indices[i] = 0;
256         }
257     }
258 
259     @Test
260     void testArithmeticVsDS() {
261         for (double x = -1.25; x < 1.25; x+= 0.5) {
262             checkAgainstDS(x,
263                            new FieldUnivariateFunction() {
264                                public <S extends CalculusFieldElement<S>> S value(S x) {
265                                    final S y = x.add(3).multiply(x).subtract(5).multiply(0.5);
266                                    return y.negate().divide(4).divide(x).add(y).subtract(x).multiply(2).reciprocal();
267                                }
268                            });
269         }
270     }
271 
272     @Test
273     void testRemainderDoubleVsDS() {
274         for (double x = -1.25; x < 1.25; x+= 0.5) {
275             checkAgainstDS(x,
276                            new FieldUnivariateFunction() {
277                                public <S extends CalculusFieldElement<S>> S value(S x) {
278                                    return x.remainder(0.5);
279                                }
280                            });
281         }
282     }
283 
284     @Test
285     void testRemainderGVsDS() {
286         for (double x = -1.25; x < 1.25; x+= 0.5) {
287             checkAgainstDS(x,
288                            new FieldUnivariateFunction() {
289                               public <S extends CalculusFieldElement<S>> S value(S x) {
290                                   return x.remainder(x.divide(0.7));
291                               }
292                            });
293         }
294     }
295 
296     @Test
297     void testAbsVsDS() {
298         for (double x = -1.25; x < 1.25; x+= 0.5) {
299             checkAgainstDS(x,
300                            new FieldUnivariateFunction() {
301                                public <S extends CalculusFieldElement<S>> S value(S x) {
302                                    return x.abs();
303                                }
304                            });
305         }
306     }
307 
308     @Test
309     void testHypotVsDS() {
310         for (double x = -3.25; x < 3.25; x+= 0.5) {
311             checkAgainstDS(x,
312                            new FieldUnivariateFunction() {
313                                public <S extends CalculusFieldElement<S>> S value(S x) {
314                                    return x.cos().multiply(5).hypot(x.sin().multiply(2));
315                                }
316                            });
317         }
318     }
319 
320     @Test
321     void testAtan2VsDS() {
322         for (double x = -3.25; x < 3.25; x+= 0.5) {
323             checkAgainstDS(x,
324                            new FieldUnivariateFunction() {
325                                public <S extends CalculusFieldElement<S>> S value(S x) {
326                                    return x.cos().multiply(5).atan2(x.sin().multiply(2));
327                                }
328                            });
329         }
330     }
331 
332     @Test
333     void testPowersVsDS() {
334         for (double x = -3.25; x < 3.25; x+= 0.5) {
335             checkAgainstDS(x,
336                            new FieldUnivariateFunction() {
337                                public <S extends CalculusFieldElement<S>> S value(S x) {
338                                    final FieldSinCos<S> sc = x.sinCos();
339                                    return x.pow(3.2).add(x.pow(2)).subtract(sc.cos().abs().pow(sc.sin()));
340                                }
341                            });
342         }
343     }
344 
345     @Test
346     void testRootsVsDS() {
347         for (double x = 0.001; x < 3.25; x+= 0.5) {
348             checkAgainstDS(x,
349                            new FieldUnivariateFunction() {
350                                public <S extends CalculusFieldElement<S>> S value(S x) {
351                                    return x.rootN(5);//x.sqrt().add(x.cbrt()).subtract(x.rootN(5));
352                                }
353                            });
354         }
355     }
356 
357     @Test
358     void testExpsLogsVsDS() {
359         for (double x = 2.5; x < 3.25; x+= 0.125) {
360             checkAgainstDS(x,
361                            new FieldUnivariateFunction() {
362                                public <S extends CalculusFieldElement<S>> S value(S x) {
363                                    return x.exp().add(x.multiply(0.5).expm1()).log().log10().log1p();
364                                }
365                            });
366         }
367     }
368 
369     @Test
370     void testTrigonometryVsDS() {
371         for (double x = -3.25; x < 3.25; x+= 0.5) {
372             checkAgainstDS(x,
373                            new FieldUnivariateFunction() {
374                                public <S extends CalculusFieldElement<S>> S value(S x) {
375                                    return x.cos().multiply(x.sin()).atan().divide(12).asin().multiply(0.1).acos().tan();
376                                }
377                            });
378         }
379     }
380 
381     @Test
382     void testHyperbolicVsDS() {
383         for (double x = -1.25; x < 1.25; x+= 0.5) {
384             checkAgainstDS(x,
385                            new FieldUnivariateFunction() {
386                                public <S extends CalculusFieldElement<S>> S value(S x) {
387                                    return x.cosh().multiply(x.sinh()).multiply(12).abs().acosh().asinh().divide(7).tanh().multiply(0.1).atanh();
388                                }
389                            });
390         }
391     }
392 
393     @Test
394     void testConvertersVsDS() {
395         for (double x = -1.25; x < 1.25; x+= 0.5) {
396             checkAgainstDS(x,
397                            new FieldUnivariateFunction() {
398                                public <S extends CalculusFieldElement<S>> S value(S x) {
399                                    return x.multiply(5).toDegrees().subtract(x).toRadians();
400                                }
401                            });
402         }
403     }
404 
405     @Test
406     void testLinearCombination2D2FVsDS() {
407         for (double x = -1.25; x < 1.25; x+= 0.5) {
408             checkAgainstDS(x,
409                            new FieldUnivariateFunction() {
410                                public <S extends CalculusFieldElement<S>> S value(S x) {
411                                    return x.linearCombination(1.0, x.multiply(0.9),
412                                                               2.0, x.multiply(0.8));
413                                }
414                            });
415         }
416     }
417 
418     @Test
419     void testLinearCombination2F2FVsDS() {
420         for (double x = -1.25; x < 1.25; x+= 0.5) {
421             checkAgainstDS(x,
422                            new FieldUnivariateFunction() {
423                                public <S extends CalculusFieldElement<S>> S value(S x) {
424                                    return x.linearCombination(x.add(1), x.multiply(0.9),
425                                                               x.add(2), x.multiply(0.8));
426                                }
427                            });
428         }
429     }
430 
431     @Test
432     void testLinearCombination3D3FVsDS() {
433         for (double x = -1.25; x < 1.25; x+= 0.5) {
434             checkAgainstDS(x,
435                            new FieldUnivariateFunction() {
436                                public <S extends CalculusFieldElement<S>> S value(S x) {
437                                    return x.linearCombination(1.0, x.multiply(0.9),
438                                                               2.0, x.multiply(0.8),
439                                                               3.0, x.multiply(0.7));
440                                }
441                            });
442         }
443     }
444 
445     @Test
446     void testLinearCombination3F3FVsDS() {
447         for (double x = -1.25; x < 1.25; x+= 0.5) {
448             checkAgainstDS(x,
449                            new FieldUnivariateFunction() {
450                                public <S extends CalculusFieldElement<S>> S value(S x) {
451                                    return x.linearCombination(x.add(1), x.multiply(0.9),
452                                                               x.add(2), x.multiply(0.8),
453                                                               x.add(3), x.multiply(0.7));
454                                }
455                            });
456         }
457     }
458 
459     @Test
460     void testLinearCombination4D4FVsDS() {
461         for (double x = -1.25; x < 1.25; x+= 0.5) {
462             checkAgainstDS(x,
463                            new FieldUnivariateFunction() {
464                                public <S extends CalculusFieldElement<S>> S value(S x) {
465                                    return x.linearCombination(1.0, x.multiply(0.9),
466                                                               2.0, x.multiply(0.8),
467                                                               3.0, x.multiply(0.7),
468                                                               4.0, x.multiply(0.6));
469                                }
470                            });
471         }
472     }
473 
474     @Test
475     void testLinearCombination4F4FVsDS() {
476         for (double x = -1.25; x < 1.25; x+= 0.5) {
477             checkAgainstDS(x,
478                            new FieldUnivariateFunction() {
479                                public <S extends CalculusFieldElement<S>> S value(S x) {
480                                    return x.linearCombination(x.add(1), x.multiply(0.9),
481                                                               x.add(2), x.multiply(0.8),
482                                                               x.add(3), x.multiply(0.7),
483                                                               x.add(4), x.multiply(0.6));
484                                }
485                            });
486         }
487     }
488 
489     @Test
490     void testLinearCombinationnDnFVsDS() {
491         for (double x = -1.25; x < 1.25; x+= 0.5) {
492             checkAgainstDS(x,
493                            new FieldUnivariateFunction() {
494                                public <S extends CalculusFieldElement<S>> S value(S x) {
495                                    final S[] b = MathArrays.buildArray(x.getField(), 4);
496                                    b[0] = x.add(0.9);
497                                    b[1] = x.add(0.8);
498                                    b[2] = x.add(0.7);
499                                    b[3] = x.add(0.6);
500                                    return x.linearCombination(new double[] { 1, 2, 3, 4 }, b);
501                                }
502                            });
503         }
504     }
505 
506     @Test
507     void testLinearCombinationnFnFVsDS() {
508         for (double x = -1.25; x < 1.25; x+= 0.5) {
509             checkAgainstDS(x,
510                            new FieldUnivariateFunction() {
511                                public <S extends CalculusFieldElement<S>> S value(S x) {
512                                    final S[] a = MathArrays.buildArray(x.getField(), 4);
513                                    a[0] = x.add(1);
514                                    a[1] = x.add(2);
515                                    a[2] = x.add(3);
516                                    a[3] = x.add(4);
517                                    final S[] b = MathArrays.buildArray(x.getField(), 4);
518                                    b[0] = x.add(0.9);
519                                    b[1] = x.add(0.8);
520                                    b[2] = x.add(0.7);
521                                    b[3] = x.add(0.6);
522                                    return x.linearCombination(a, b);
523                                }
524                            });
525         }
526     }
527 
528     @Test
529     void testSerialization() {
530         Gradient a = build(1.3);
531         Gradient b = (Gradient) UnitTestUtils.serializeAndRecover(a);
532         assertEquals(a, b);
533         assertNotSame(a, b);
534     }
535 
536     @Test
537     void testZero() {
538         Gradient zero = build(17.0).getField().getZero();
539         assertEquals(0.0, zero.getValue(), 1.0e-15);
540         for (int i = 0; i < zero.getFreeParameters(); ++i) {
541             assertEquals(0.0, zero.getPartialDerivative(i), 1.0e-15);
542         }
543     }
544 
545     @Test
546     void testOne() {
547         Gradient one = build(17.0).getField().getOne();
548         assertEquals(1.0, one.getValue(), 1.0e-15);
549         for (int i = 0; i < one.getFreeParameters(); ++i) {
550             assertEquals(0.0, one.getPartialDerivative(i), 1.0e-15);
551         }
552     }
553 
554 }