View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.analysis;
24  
25  import org.hipparchus.analysis.differentiation.DSFactory;
26  import org.hipparchus.analysis.differentiation.Derivative;
27  import org.hipparchus.analysis.differentiation.DerivativeStructure;
28  import org.hipparchus.analysis.differentiation.MultivariateDifferentiableFunction;
29  import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30  import org.hipparchus.analysis.function.Add;
31  import org.hipparchus.analysis.function.Constant;
32  import org.hipparchus.analysis.function.Cos;
33  import org.hipparchus.analysis.function.Cosh;
34  import org.hipparchus.analysis.function.Divide;
35  import org.hipparchus.analysis.function.Identity;
36  import org.hipparchus.analysis.function.Inverse;
37  import org.hipparchus.analysis.function.Log;
38  import org.hipparchus.analysis.function.Max;
39  import org.hipparchus.analysis.function.Min;
40  import org.hipparchus.analysis.function.Minus;
41  import org.hipparchus.analysis.function.Multiply;
42  import org.hipparchus.analysis.function.Pow;
43  import org.hipparchus.analysis.function.Power;
44  import org.hipparchus.analysis.function.Sin;
45  import org.hipparchus.analysis.function.Sinc;
46  import org.hipparchus.analysis.function.Subtract;
47  import org.hipparchus.exception.LocalizedCoreFormats;
48  import org.hipparchus.exception.MathIllegalArgumentException;
49  import org.hipparchus.util.FastMath;
50  import org.junit.jupiter.api.Test;
51  
52  import static org.junit.jupiter.api.Assertions.assertEquals;
53  import static org.junit.jupiter.api.Assertions.assertThrows;
54  import static org.junit.jupiter.api.Assertions.fail;
55  
56  /**
57   * Test for {@link FunctionUtils}.
58   */
59  class FunctionUtilsTest {
60      private final double EPS = FastMath.ulp(1d);
61  
62      @Test
63      void testCompose() {
64          UnivariateFunction id = new Identity();
65          assertEquals(3, FunctionUtils.compose(id, id, id).value(3), EPS);
66  
67          UnivariateFunction c = new Constant(4);
68          assertEquals(4, FunctionUtils.compose(id, c).value(3), EPS);
69          assertEquals(4, FunctionUtils.compose(c, id).value(3), EPS);
70  
71          UnivariateFunction m = new Minus();
72          assertEquals(-3, FunctionUtils.compose(m).value(3), EPS);
73          assertEquals(3, FunctionUtils.compose(m, m).value(3), EPS);
74  
75          UnivariateFunction inv = new Inverse();
76          assertEquals(-0.25, FunctionUtils.compose(inv, m, c, id).value(3), EPS);
77  
78          UnivariateFunction pow = new Power(2);
79          assertEquals(81, FunctionUtils.compose(pow, pow).value(3), EPS);
80      }
81  
82      @Test
83      void testComposeDifferentiable() {
84          DSFactory factory = new DSFactory(1, 1);
85          UnivariateDifferentiableFunction id = new Identity();
86          assertEquals(1, FunctionUtils.compose(id, id, id).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
87          assertEquals(1.5, FunctionUtils.compose(id, id, id).value(1.5), EPS);
88  
89          UnivariateDifferentiableFunction c = new Constant(4);
90          assertEquals(0, FunctionUtils.compose(id, c).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
91          assertEquals(0, FunctionUtils.compose(c, id).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
92  
93          UnivariateDifferentiableFunction m = new Minus();
94          assertEquals(-1, FunctionUtils.compose(m).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
95          assertEquals(1, FunctionUtils.compose(m, m).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
96  
97          UnivariateDifferentiableFunction inv = new Inverse();
98          assertEquals(0.25, FunctionUtils.compose(inv, m, id).value(factory.variable(0, 2)).getPartialDerivative(1), EPS);
99  
100         UnivariateDifferentiableFunction pow = new Power(2);
101         assertEquals(108, FunctionUtils.compose(pow, pow).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
102 
103         UnivariateDifferentiableFunction log = new Log();
104         double a = 9876.54321;
105         assertEquals(pow.value(factory.variable(0, a)).getPartialDerivative(1) / pow.value(a),
106                             FunctionUtils.compose(log, pow).value(factory.variable(0, a)).getPartialDerivative(1), EPS);
107     }
108 
109     @Test
110     void testAdd() {
111         UnivariateFunction id = new Identity();
112         UnivariateFunction c = new Constant(4);
113         UnivariateFunction m = new Minus();
114         UnivariateFunction inv = new Inverse();
115 
116         assertEquals(4.5, FunctionUtils.add(inv, m, c, id).value(2), EPS);
117         assertEquals(4 + 2, FunctionUtils.add(c, id).value(2), EPS);
118         assertEquals(4 - 2, FunctionUtils.add(c, FunctionUtils.compose(m, id)).value(2), EPS);
119     }
120 
121     @Test
122     void testAddDifferentiable() {
123         UnivariateDifferentiableFunction sin = new Sin();
124         UnivariateDifferentiableFunction c = new Constant(4);
125         UnivariateDifferentiableFunction m = new Minus();
126         UnivariateDifferentiableFunction inv = new Inverse();
127 
128         final double a = 123.456;
129         DSFactory factory = new DSFactory(1, 1);
130         assertEquals(- 1 / (a * a) -1 + FastMath.cos(a),
131                             FunctionUtils.add(inv, m, c, sin).value(factory.variable(0, a)).getPartialDerivative(1),
132                             EPS);
133         assertEquals(4 + FastMath.sin(1.2), FunctionUtils.add(sin, c).value(1.2), EPS);
134     }
135 
136     @Test
137     void testMultiply() {
138         UnivariateFunction c = new Constant(4);
139         assertEquals(16, FunctionUtils.multiply(c, c).value(12345), EPS);
140 
141         UnivariateFunction inv = new Inverse();
142         UnivariateFunction pow = new Power(2);
143         assertEquals(1, FunctionUtils.multiply(FunctionUtils.compose(inv, pow), pow).value(3.5), EPS);
144     }
145 
146     @Test
147     void testMultiplyDifferentiable() {
148         UnivariateDifferentiableFunction c = new Constant(4);
149         UnivariateDifferentiableFunction id = new Identity();
150         DSFactory factory = new DSFactory(1, 1);
151         final double a = 1.2345678;
152         assertEquals(8 * a, FunctionUtils.multiply(c, id, id).value(factory.variable(0, a)).getPartialDerivative(1), EPS);
153 
154         UnivariateDifferentiableFunction inv = new Inverse();
155         UnivariateDifferentiableFunction pow = new Power(2.5);
156         UnivariateDifferentiableFunction cos = new Cos();
157         assertEquals(1.5 * FastMath.sqrt(a) * FastMath.cos(a) - FastMath.pow(a, 1.5) * FastMath.sin(a),
158                             FunctionUtils.multiply(inv, pow, cos).value(factory.variable(0, a)).getPartialDerivative(1), EPS);
159 
160         UnivariateDifferentiableFunction cosh = new Cosh();
161         assertEquals(1.5 * FastMath.sqrt(a) * FastMath.cosh(a) + FastMath.pow(a, 1.5) * FastMath.sinh(a),
162                             FunctionUtils.multiply(inv, pow, cosh).value(factory.variable(0, a)).getPartialDerivative(1), 8 * EPS);
163         assertEquals(16, FunctionUtils.multiply(c, c).value(FastMath.PI), EPS);
164     }
165 
166     @Test
167     void testCombine() {
168         BivariateFunction bi = new Subtract();
169         UnivariateFunction id = new Identity();
170         UnivariateFunction m = new Minus();
171         UnivariateFunction c = FunctionUtils.combine(bi, id, m);
172         assertEquals(4.6912, c.value(2.3456), EPS);
173 
174         bi = new Multiply();
175         UnivariateFunction inv = new Inverse();
176         c = FunctionUtils.combine(bi, id, inv);
177         assertEquals(1, c.value(2.3456), EPS);
178     }
179 
180     @Test
181     void testCollector() {
182         BivariateFunction bi = new Add();
183         MultivariateFunction coll = FunctionUtils.collector(bi, 0);
184         assertEquals(10, coll.value(new double[] {1, 2, 3, 4}), EPS);
185 
186         bi = new Multiply();
187         coll = FunctionUtils.collector(bi, 1);
188         assertEquals(24, coll.value(new double[] {1, 2, 3, 4}), EPS);
189 
190         bi = new Max();
191         coll = FunctionUtils.collector(bi, Double.NEGATIVE_INFINITY);
192         assertEquals(10, coll.value(new double[] {1, -2, 7.5, 10, -24, 9.99}), 0);
193 
194         bi = new Min();
195         coll = FunctionUtils.collector(bi, Double.POSITIVE_INFINITY);
196         assertEquals(-24, coll.value(new double[] {1, -2, 7.5, 10, -24, 9.99}), 0);
197     }
198 
199     @Test
200     void testSinc() {
201         BivariateFunction div = new Divide();
202         UnivariateFunction sin = new Sin();
203         UnivariateFunction id = new Identity();
204         UnivariateFunction sinc1 = FunctionUtils.combine(div, sin, id);
205         UnivariateFunction sinc2 = new Sinc();
206 
207         for (int i = 0; i < 10; i++) {
208             double x = FastMath.random();
209             assertEquals(sinc1.value(x), sinc2.value(x), EPS);
210         }
211     }
212 
213     @Test
214     void testFixingArguments() {
215         UnivariateFunction scaler = FunctionUtils.fix1stArgument(new Multiply(), 10);
216         assertEquals(1.23456, scaler.value(0.123456), EPS);
217 
218         UnivariateFunction pow1 = new Power(2);
219         UnivariateFunction pow2 = FunctionUtils.fix2ndArgument(new Pow(), 2);
220 
221         for (int i = 0; i < 10; i++) {
222             double x = FastMath.random() * 10;
223             assertEquals(pow1.value(x), pow2.value(x), 0);
224         }
225     }
226 
227     @Test
228     void testSampleWrongBounds(){
229         assertThrows(MathIllegalArgumentException.class, () -> {
230             FunctionUtils.sample(new Sin(), FastMath.PI, 0.0, 10);
231         });
232     }
233 
234     @Test
235     void testSampleNegativeNumberOfPoints(){
236         assertThrows(MathIllegalArgumentException.class, () -> {
237             FunctionUtils.sample(new Sin(), 0.0, FastMath.PI, -1);
238         });
239     }
240 
241     @Test
242     void testSampleNullNumberOfPoints(){
243         assertThrows(MathIllegalArgumentException.class, () -> {
244             FunctionUtils.sample(new Sin(), 0.0, FastMath.PI, 0);
245         });
246     }
247 
248     @Test
249     void testSample() {
250         final int n = 11;
251         final double min = 0.0;
252         final double max = FastMath.PI;
253         final double[] actual = FunctionUtils.sample(new Sin(), min, max, n);
254         for (int i = 0; i < n; i++) {
255             final double x = min + (max - min) / n * i;
256             assertEquals(FastMath.sin(x), actual[i], 0.0, "x = " + x);
257         }
258     }
259 
260     @Test
261     void testToDifferentiableUnivariate() {
262 
263         final UnivariateFunction f0 = new UnivariateFunction() {
264             @Override
265             public double value(final double x) {
266                 return x * x;
267             }
268         };
269         final UnivariateFunction f1 = new UnivariateFunction() {
270             @Override
271             public double value(final double x) {
272                 return 2 * x;
273             }
274         };
275         final UnivariateFunction f2 = new UnivariateFunction() {
276             @Override
277             public double value(final double x) {
278                 return 2;
279             }
280         };
281         final UnivariateDifferentiableFunction f = FunctionUtils.toDifferentiable(f0, f1, f2);
282 
283         DSFactory factory = new DSFactory(1, 2);
284         for (double t = -1.0; t < 1; t += 0.01) {
285             // x = sin(t)
286             DerivativeStructure dsT = factory.variable(0, t);
287             DerivativeStructure y = f.value(dsT.sin());
288             assertEquals(FastMath.sin(t) * FastMath.sin(t),               f.value(FastMath.sin(t)),  1.0e-15);
289             assertEquals(FastMath.sin(t) * FastMath.sin(t),               y.getValue(),              1.0e-15);
290             assertEquals(2 * FastMath.cos(t) * FastMath.sin(t),           y.getPartialDerivative(1), 1.0e-15);
291             assertEquals(2 * (1 - 2 * FastMath.sin(t) * FastMath.sin(t)), y.getPartialDerivative(2), 1.0e-15);
292         }
293 
294         try {
295             f.value(new DSFactory(1, 3).constant(0.0));
296             fail("an exception should have been thrown");
297         } catch (MathIllegalArgumentException e) {
298             assertEquals(LocalizedCoreFormats.NUMBER_TOO_LARGE, e.getSpecifier());
299             assertEquals(2, ((Integer) e.getParts()[1]).intValue());
300             assertEquals(3, ((Integer) e.getParts()[0]).intValue());
301         }
302     }
303 
304     @Test
305     void testToDifferentiableMultivariate() {
306 
307         final double a = 1.5;
308         final double b = 0.5;
309         final MultivariateFunction f = new MultivariateFunction() {
310             @Override
311             public double value(final double[] point) {
312                 return a * point[0] + b * point[1];
313             }
314         };
315         final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
316             @Override
317             public double[] value(final double[] point) {
318                 return new double[] { a, b };
319             }
320         };
321         final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
322 
323         DSFactory factory11 = new DSFactory(1, 1);
324         for (double t = -1.0; t < 1; t += 0.01) {
325             // x = sin(t), y = cos(t), hence the method really becomes univariate
326             DerivativeStructure dsT = factory11.variable(0, t);
327             DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() });
328             assertEquals(a * FastMath.sin(t) + b * FastMath.cos(t), y.getValue(),              1.0e-15);
329             assertEquals(a * FastMath.cos(t) - b * FastMath.sin(t), y.getPartialDerivative(1), 1.0e-15);
330         }
331 
332         DSFactory factory21 = new DSFactory(2, 1);
333         for (double u = -1.0; u < 1; u += 0.01) {
334             DerivativeStructure dsU = factory21.variable(0, u);
335             for (double v = -1.0; v < 1; v += 0.01) {
336                 DerivativeStructure dsV = factory21.variable(1, v);
337                 DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsU, dsV });
338                 assertEquals(a * u + b * v, mdf.value(new double[] { u, v }), 1.0e-15);
339                 assertEquals(a * u + b * v, y.getValue(),                     1.0e-15);
340                 assertEquals(a,             y.getPartialDerivative(1, 0),     1.0e-15);
341                 assertEquals(b,             y.getPartialDerivative(0, 1),     1.0e-15);
342             }
343         }
344 
345         DSFactory factory13 = new DSFactory(1, 3);
346         try {
347             mdf.value(new DerivativeStructure[] { factory13.constant(0.0), factory13.constant(0.0) });
348             fail("an exception should have been thrown");
349         } catch (MathIllegalArgumentException e) {
350             assertEquals(LocalizedCoreFormats.NUMBER_TOO_LARGE, e.getSpecifier());
351             assertEquals(1, ((Integer) e.getParts()[1]).intValue());
352             assertEquals(3, ((Integer) e.getParts()[0]).intValue());
353         }
354     }
355 
356     @Test
357     void testToDifferentiableMultivariateInconsistentGradient() {
358 
359         final double a = 1.5;
360         final double b = 0.5;
361         final MultivariateFunction f = new MultivariateFunction() {
362             @Override
363             public double value(final double[] point) {
364                 return a * point[0] + b * point[1];
365             }
366         };
367         final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
368             @Override
369             public double[] value(final double[] point) {
370                 return new double[] { a, b, 0.0 };
371             }
372         };
373         final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
374 
375         DSFactory factory = new DSFactory(1, 1);
376         try {
377             DerivativeStructure dsT = factory.variable(0, 0.0);
378             mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() });
379             fail("an exception should have been thrown");
380         } catch (MathIllegalArgumentException e) {
381             assertEquals(3, ((Integer) e.getParts()[0]).intValue());
382             assertEquals(2, ((Integer) e.getParts()[1]).intValue());
383         }
384     }
385 
386     @Test
387     void testDerivativeUnivariate() {
388 
389         final UnivariateDifferentiableFunction f = new UnivariateDifferentiableFunction() {
390 
391             @Override
392             public double value(double x) {
393                 return x * x;
394             }
395 
396             @Override
397             public <T extends Derivative<T>> T value(T x) {
398                 return x.square();
399             }
400 
401         };
402 
403         final UnivariateFunction f0 = FunctionUtils.derivative(f, 0);
404         final UnivariateFunction f1 = FunctionUtils.derivative(f, 1);
405         final UnivariateFunction f2 = FunctionUtils.derivative(f, 2);
406 
407         for (double t = -1.0; t < 1; t += 0.01) {
408             assertEquals(t * t, f0.value(t), 1.0e-15);
409             assertEquals(2 * t, f1.value(t), 1.0e-15);
410             assertEquals(2,     f2.value(t), 1.0e-15);
411         }
412 
413     }
414 
415     @Test
416     void testDerivativeMultivariate() {
417 
418         final double a = 1.5;
419         final double b = 0.5;
420         final double c = 0.25;
421         final MultivariateDifferentiableFunction mdf = new MultivariateDifferentiableFunction() {
422 
423             @Override
424             public double value(double[] point) {
425                 return a * point[0] * point[0] + b * point[1] * point[1] + c * point[0] * point[1];
426             }
427 
428             @Override
429             public DerivativeStructure value(DerivativeStructure[] point) {
430                 DerivativeStructure x  = point[0];
431                 DerivativeStructure y  = point[1];
432                 DerivativeStructure x2 = x.square();
433                 DerivativeStructure y2 = y.square();
434                 DerivativeStructure xy = x.multiply(y);
435                 return x2.multiply(a).add(y2.multiply(b)).add(xy.multiply(c));
436             }
437 
438         };
439 
440         final MultivariateFunction f       = FunctionUtils.derivative(mdf, new int[] { 0, 0 });
441         final MultivariateFunction dfdx    = FunctionUtils.derivative(mdf, new int[] { 1, 0 });
442         final MultivariateFunction dfdy    = FunctionUtils.derivative(mdf, new int[] { 0, 1 });
443         final MultivariateFunction d2fdx2  = FunctionUtils.derivative(mdf, new int[] { 2, 0 });
444         final MultivariateFunction d2fdy2  = FunctionUtils.derivative(mdf, new int[] { 0, 2 });
445         final MultivariateFunction d2fdxdy = FunctionUtils.derivative(mdf, new int[] { 1, 1 });
446 
447         for (double x = -1.0; x < 1; x += 0.01) {
448             for (double y = -1.0; y < 1; y += 0.01) {
449                 assertEquals(a * x * x + b * y * y + c * x * y, f.value(new double[]       { x, y }), 1.0e-15);
450                 assertEquals(2 * a * x + c * y,                 dfdx.value(new double[]    { x, y }), 1.0e-15);
451                 assertEquals(2 * b * y + c * x,                 dfdy.value(new double[]    { x, y }), 1.0e-15);
452                 assertEquals(2 * a,                             d2fdx2.value(new double[]  { x, y }), 1.0e-15);
453                 assertEquals(2 * b,                             d2fdy2.value(new double[]  { x, y }), 1.0e-15);
454                 assertEquals(c,                                 d2fdxdy.value(new double[] { x, y }), 1.0e-15);
455             }
456         }
457 
458     }
459 
460 }