1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis;
24
25 import org.hipparchus.analysis.differentiation.DSFactory;
26 import org.hipparchus.analysis.differentiation.Derivative;
27 import org.hipparchus.analysis.differentiation.DerivativeStructure;
28 import org.hipparchus.analysis.differentiation.MultivariateDifferentiableFunction;
29 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30 import org.hipparchus.analysis.function.Add;
31 import org.hipparchus.analysis.function.Constant;
32 import org.hipparchus.analysis.function.Cos;
33 import org.hipparchus.analysis.function.Cosh;
34 import org.hipparchus.analysis.function.Divide;
35 import org.hipparchus.analysis.function.Identity;
36 import org.hipparchus.analysis.function.Inverse;
37 import org.hipparchus.analysis.function.Log;
38 import org.hipparchus.analysis.function.Max;
39 import org.hipparchus.analysis.function.Min;
40 import org.hipparchus.analysis.function.Minus;
41 import org.hipparchus.analysis.function.Multiply;
42 import org.hipparchus.analysis.function.Pow;
43 import org.hipparchus.analysis.function.Power;
44 import org.hipparchus.analysis.function.Sin;
45 import org.hipparchus.analysis.function.Sinc;
46 import org.hipparchus.analysis.function.Subtract;
47 import org.hipparchus.exception.LocalizedCoreFormats;
48 import org.hipparchus.exception.MathIllegalArgumentException;
49 import org.hipparchus.util.FastMath;
50 import org.junit.jupiter.api.Test;
51
52 import static org.junit.jupiter.api.Assertions.assertEquals;
53 import static org.junit.jupiter.api.Assertions.assertThrows;
54 import static org.junit.jupiter.api.Assertions.fail;
55
56
57
58
59 class FunctionUtilsTest {
60 private final double EPS = FastMath.ulp(1d);
61
62 @Test
63 void testCompose() {
64 UnivariateFunction id = new Identity();
65 assertEquals(3, FunctionUtils.compose(id, id, id).value(3), EPS);
66
67 UnivariateFunction c = new Constant(4);
68 assertEquals(4, FunctionUtils.compose(id, c).value(3), EPS);
69 assertEquals(4, FunctionUtils.compose(c, id).value(3), EPS);
70
71 UnivariateFunction m = new Minus();
72 assertEquals(-3, FunctionUtils.compose(m).value(3), EPS);
73 assertEquals(3, FunctionUtils.compose(m, m).value(3), EPS);
74
75 UnivariateFunction inv = new Inverse();
76 assertEquals(-0.25, FunctionUtils.compose(inv, m, c, id).value(3), EPS);
77
78 UnivariateFunction pow = new Power(2);
79 assertEquals(81, FunctionUtils.compose(pow, pow).value(3), EPS);
80 }
81
82 @Test
83 void testComposeDifferentiable() {
84 DSFactory factory = new DSFactory(1, 1);
85 UnivariateDifferentiableFunction id = new Identity();
86 assertEquals(1, FunctionUtils.compose(id, id, id).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
87 assertEquals(1.5, FunctionUtils.compose(id, id, id).value(1.5), EPS);
88
89 UnivariateDifferentiableFunction c = new Constant(4);
90 assertEquals(0, FunctionUtils.compose(id, c).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
91 assertEquals(0, FunctionUtils.compose(c, id).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
92
93 UnivariateDifferentiableFunction m = new Minus();
94 assertEquals(-1, FunctionUtils.compose(m).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
95 assertEquals(1, FunctionUtils.compose(m, m).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
96
97 UnivariateDifferentiableFunction inv = new Inverse();
98 assertEquals(0.25, FunctionUtils.compose(inv, m, id).value(factory.variable(0, 2)).getPartialDerivative(1), EPS);
99
100 UnivariateDifferentiableFunction pow = new Power(2);
101 assertEquals(108, FunctionUtils.compose(pow, pow).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
102
103 UnivariateDifferentiableFunction log = new Log();
104 double a = 9876.54321;
105 assertEquals(pow.value(factory.variable(0, a)).getPartialDerivative(1) / pow.value(a),
106 FunctionUtils.compose(log, pow).value(factory.variable(0, a)).getPartialDerivative(1), EPS);
107 }
108
109 @Test
110 void testAdd() {
111 UnivariateFunction id = new Identity();
112 UnivariateFunction c = new Constant(4);
113 UnivariateFunction m = new Minus();
114 UnivariateFunction inv = new Inverse();
115
116 assertEquals(4.5, FunctionUtils.add(inv, m, c, id).value(2), EPS);
117 assertEquals(4 + 2, FunctionUtils.add(c, id).value(2), EPS);
118 assertEquals(4 - 2, FunctionUtils.add(c, FunctionUtils.compose(m, id)).value(2), EPS);
119 }
120
121 @Test
122 void testAddDifferentiable() {
123 UnivariateDifferentiableFunction sin = new Sin();
124 UnivariateDifferentiableFunction c = new Constant(4);
125 UnivariateDifferentiableFunction m = new Minus();
126 UnivariateDifferentiableFunction inv = new Inverse();
127
128 final double a = 123.456;
129 DSFactory factory = new DSFactory(1, 1);
130 assertEquals(- 1 / (a * a) -1 + FastMath.cos(a),
131 FunctionUtils.add(inv, m, c, sin).value(factory.variable(0, a)).getPartialDerivative(1),
132 EPS);
133 assertEquals(4 + FastMath.sin(1.2), FunctionUtils.add(sin, c).value(1.2), EPS);
134 }
135
136 @Test
137 void testMultiply() {
138 UnivariateFunction c = new Constant(4);
139 assertEquals(16, FunctionUtils.multiply(c, c).value(12345), EPS);
140
141 UnivariateFunction inv = new Inverse();
142 UnivariateFunction pow = new Power(2);
143 assertEquals(1, FunctionUtils.multiply(FunctionUtils.compose(inv, pow), pow).value(3.5), EPS);
144 }
145
146 @Test
147 void testMultiplyDifferentiable() {
148 UnivariateDifferentiableFunction c = new Constant(4);
149 UnivariateDifferentiableFunction id = new Identity();
150 DSFactory factory = new DSFactory(1, 1);
151 final double a = 1.2345678;
152 assertEquals(8 * a, FunctionUtils.multiply(c, id, id).value(factory.variable(0, a)).getPartialDerivative(1), EPS);
153
154 UnivariateDifferentiableFunction inv = new Inverse();
155 UnivariateDifferentiableFunction pow = new Power(2.5);
156 UnivariateDifferentiableFunction cos = new Cos();
157 assertEquals(1.5 * FastMath.sqrt(a) * FastMath.cos(a) - FastMath.pow(a, 1.5) * FastMath.sin(a),
158 FunctionUtils.multiply(inv, pow, cos).value(factory.variable(0, a)).getPartialDerivative(1), EPS);
159
160 UnivariateDifferentiableFunction cosh = new Cosh();
161 assertEquals(1.5 * FastMath.sqrt(a) * FastMath.cosh(a) + FastMath.pow(a, 1.5) * FastMath.sinh(a),
162 FunctionUtils.multiply(inv, pow, cosh).value(factory.variable(0, a)).getPartialDerivative(1), 8 * EPS);
163 assertEquals(16, FunctionUtils.multiply(c, c).value(FastMath.PI), EPS);
164 }
165
166 @Test
167 void testCombine() {
168 BivariateFunction bi = new Subtract();
169 UnivariateFunction id = new Identity();
170 UnivariateFunction m = new Minus();
171 UnivariateFunction c = FunctionUtils.combine(bi, id, m);
172 assertEquals(4.6912, c.value(2.3456), EPS);
173
174 bi = new Multiply();
175 UnivariateFunction inv = new Inverse();
176 c = FunctionUtils.combine(bi, id, inv);
177 assertEquals(1, c.value(2.3456), EPS);
178 }
179
180 @Test
181 void testCollector() {
182 BivariateFunction bi = new Add();
183 MultivariateFunction coll = FunctionUtils.collector(bi, 0);
184 assertEquals(10, coll.value(new double[] {1, 2, 3, 4}), EPS);
185
186 bi = new Multiply();
187 coll = FunctionUtils.collector(bi, 1);
188 assertEquals(24, coll.value(new double[] {1, 2, 3, 4}), EPS);
189
190 bi = new Max();
191 coll = FunctionUtils.collector(bi, Double.NEGATIVE_INFINITY);
192 assertEquals(10, coll.value(new double[] {1, -2, 7.5, 10, -24, 9.99}), 0);
193
194 bi = new Min();
195 coll = FunctionUtils.collector(bi, Double.POSITIVE_INFINITY);
196 assertEquals(-24, coll.value(new double[] {1, -2, 7.5, 10, -24, 9.99}), 0);
197 }
198
199 @Test
200 void testSinc() {
201 BivariateFunction div = new Divide();
202 UnivariateFunction sin = new Sin();
203 UnivariateFunction id = new Identity();
204 UnivariateFunction sinc1 = FunctionUtils.combine(div, sin, id);
205 UnivariateFunction sinc2 = new Sinc();
206
207 for (int i = 0; i < 10; i++) {
208 double x = FastMath.random();
209 assertEquals(sinc1.value(x), sinc2.value(x), EPS);
210 }
211 }
212
213 @Test
214 void testFixingArguments() {
215 UnivariateFunction scaler = FunctionUtils.fix1stArgument(new Multiply(), 10);
216 assertEquals(1.23456, scaler.value(0.123456), EPS);
217
218 UnivariateFunction pow1 = new Power(2);
219 UnivariateFunction pow2 = FunctionUtils.fix2ndArgument(new Pow(), 2);
220
221 for (int i = 0; i < 10; i++) {
222 double x = FastMath.random() * 10;
223 assertEquals(pow1.value(x), pow2.value(x), 0);
224 }
225 }
226
227 @Test
228 void testSampleWrongBounds(){
229 assertThrows(MathIllegalArgumentException.class, () -> {
230 FunctionUtils.sample(new Sin(), FastMath.PI, 0.0, 10);
231 });
232 }
233
234 @Test
235 void testSampleNegativeNumberOfPoints(){
236 assertThrows(MathIllegalArgumentException.class, () -> {
237 FunctionUtils.sample(new Sin(), 0.0, FastMath.PI, -1);
238 });
239 }
240
241 @Test
242 void testSampleNullNumberOfPoints(){
243 assertThrows(MathIllegalArgumentException.class, () -> {
244 FunctionUtils.sample(new Sin(), 0.0, FastMath.PI, 0);
245 });
246 }
247
248 @Test
249 void testSample() {
250 final int n = 11;
251 final double min = 0.0;
252 final double max = FastMath.PI;
253 final double[] actual = FunctionUtils.sample(new Sin(), min, max, n);
254 for (int i = 0; i < n; i++) {
255 final double x = min + (max - min) / n * i;
256 assertEquals(FastMath.sin(x), actual[i], 0.0, "x = " + x);
257 }
258 }
259
260 @Test
261 void testToDifferentiableUnivariate() {
262
263 final UnivariateFunction f0 = new UnivariateFunction() {
264 @Override
265 public double value(final double x) {
266 return x * x;
267 }
268 };
269 final UnivariateFunction f1 = new UnivariateFunction() {
270 @Override
271 public double value(final double x) {
272 return 2 * x;
273 }
274 };
275 final UnivariateFunction f2 = new UnivariateFunction() {
276 @Override
277 public double value(final double x) {
278 return 2;
279 }
280 };
281 final UnivariateDifferentiableFunction f = FunctionUtils.toDifferentiable(f0, f1, f2);
282
283 DSFactory factory = new DSFactory(1, 2);
284 for (double t = -1.0; t < 1; t += 0.01) {
285
286 DerivativeStructure dsT = factory.variable(0, t);
287 DerivativeStructure y = f.value(dsT.sin());
288 assertEquals(FastMath.sin(t) * FastMath.sin(t), f.value(FastMath.sin(t)), 1.0e-15);
289 assertEquals(FastMath.sin(t) * FastMath.sin(t), y.getValue(), 1.0e-15);
290 assertEquals(2 * FastMath.cos(t) * FastMath.sin(t), y.getPartialDerivative(1), 1.0e-15);
291 assertEquals(2 * (1 - 2 * FastMath.sin(t) * FastMath.sin(t)), y.getPartialDerivative(2), 1.0e-15);
292 }
293
294 try {
295 f.value(new DSFactory(1, 3).constant(0.0));
296 fail("an exception should have been thrown");
297 } catch (MathIllegalArgumentException e) {
298 assertEquals(LocalizedCoreFormats.NUMBER_TOO_LARGE, e.getSpecifier());
299 assertEquals(2, ((Integer) e.getParts()[1]).intValue());
300 assertEquals(3, ((Integer) e.getParts()[0]).intValue());
301 }
302 }
303
304 @Test
305 void testToDifferentiableMultivariate() {
306
307 final double a = 1.5;
308 final double b = 0.5;
309 final MultivariateFunction f = new MultivariateFunction() {
310 @Override
311 public double value(final double[] point) {
312 return a * point[0] + b * point[1];
313 }
314 };
315 final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
316 @Override
317 public double[] value(final double[] point) {
318 return new double[] { a, b };
319 }
320 };
321 final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
322
323 DSFactory factory11 = new DSFactory(1, 1);
324 for (double t = -1.0; t < 1; t += 0.01) {
325
326 DerivativeStructure dsT = factory11.variable(0, t);
327 DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() });
328 assertEquals(a * FastMath.sin(t) + b * FastMath.cos(t), y.getValue(), 1.0e-15);
329 assertEquals(a * FastMath.cos(t) - b * FastMath.sin(t), y.getPartialDerivative(1), 1.0e-15);
330 }
331
332 DSFactory factory21 = new DSFactory(2, 1);
333 for (double u = -1.0; u < 1; u += 0.01) {
334 DerivativeStructure dsU = factory21.variable(0, u);
335 for (double v = -1.0; v < 1; v += 0.01) {
336 DerivativeStructure dsV = factory21.variable(1, v);
337 DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsU, dsV });
338 assertEquals(a * u + b * v, mdf.value(new double[] { u, v }), 1.0e-15);
339 assertEquals(a * u + b * v, y.getValue(), 1.0e-15);
340 assertEquals(a, y.getPartialDerivative(1, 0), 1.0e-15);
341 assertEquals(b, y.getPartialDerivative(0, 1), 1.0e-15);
342 }
343 }
344
345 DSFactory factory13 = new DSFactory(1, 3);
346 try {
347 mdf.value(new DerivativeStructure[] { factory13.constant(0.0), factory13.constant(0.0) });
348 fail("an exception should have been thrown");
349 } catch (MathIllegalArgumentException e) {
350 assertEquals(LocalizedCoreFormats.NUMBER_TOO_LARGE, e.getSpecifier());
351 assertEquals(1, ((Integer) e.getParts()[1]).intValue());
352 assertEquals(3, ((Integer) e.getParts()[0]).intValue());
353 }
354 }
355
356 @Test
357 void testToDifferentiableMultivariateInconsistentGradient() {
358
359 final double a = 1.5;
360 final double b = 0.5;
361 final MultivariateFunction f = new MultivariateFunction() {
362 @Override
363 public double value(final double[] point) {
364 return a * point[0] + b * point[1];
365 }
366 };
367 final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
368 @Override
369 public double[] value(final double[] point) {
370 return new double[] { a, b, 0.0 };
371 }
372 };
373 final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
374
375 DSFactory factory = new DSFactory(1, 1);
376 try {
377 DerivativeStructure dsT = factory.variable(0, 0.0);
378 mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() });
379 fail("an exception should have been thrown");
380 } catch (MathIllegalArgumentException e) {
381 assertEquals(3, ((Integer) e.getParts()[0]).intValue());
382 assertEquals(2, ((Integer) e.getParts()[1]).intValue());
383 }
384 }
385
386 @Test
387 void testDerivativeUnivariate() {
388
389 final UnivariateDifferentiableFunction f = new UnivariateDifferentiableFunction() {
390
391 @Override
392 public double value(double x) {
393 return x * x;
394 }
395
396 @Override
397 public <T extends Derivative<T>> T value(T x) {
398 return x.square();
399 }
400
401 };
402
403 final UnivariateFunction f0 = FunctionUtils.derivative(f, 0);
404 final UnivariateFunction f1 = FunctionUtils.derivative(f, 1);
405 final UnivariateFunction f2 = FunctionUtils.derivative(f, 2);
406
407 for (double t = -1.0; t < 1; t += 0.01) {
408 assertEquals(t * t, f0.value(t), 1.0e-15);
409 assertEquals(2 * t, f1.value(t), 1.0e-15);
410 assertEquals(2, f2.value(t), 1.0e-15);
411 }
412
413 }
414
415 @Test
416 void testDerivativeMultivariate() {
417
418 final double a = 1.5;
419 final double b = 0.5;
420 final double c = 0.25;
421 final MultivariateDifferentiableFunction mdf = new MultivariateDifferentiableFunction() {
422
423 @Override
424 public double value(double[] point) {
425 return a * point[0] * point[0] + b * point[1] * point[1] + c * point[0] * point[1];
426 }
427
428 @Override
429 public DerivativeStructure value(DerivativeStructure[] point) {
430 DerivativeStructure x = point[0];
431 DerivativeStructure y = point[1];
432 DerivativeStructure x2 = x.square();
433 DerivativeStructure y2 = y.square();
434 DerivativeStructure xy = x.multiply(y);
435 return x2.multiply(a).add(y2.multiply(b)).add(xy.multiply(c));
436 }
437
438 };
439
440 final MultivariateFunction f = FunctionUtils.derivative(mdf, new int[] { 0, 0 });
441 final MultivariateFunction dfdx = FunctionUtils.derivative(mdf, new int[] { 1, 0 });
442 final MultivariateFunction dfdy = FunctionUtils.derivative(mdf, new int[] { 0, 1 });
443 final MultivariateFunction d2fdx2 = FunctionUtils.derivative(mdf, new int[] { 2, 0 });
444 final MultivariateFunction d2fdy2 = FunctionUtils.derivative(mdf, new int[] { 0, 2 });
445 final MultivariateFunction d2fdxdy = FunctionUtils.derivative(mdf, new int[] { 1, 1 });
446
447 for (double x = -1.0; x < 1; x += 0.01) {
448 for (double y = -1.0; y < 1; y += 0.01) {
449 assertEquals(a * x * x + b * y * y + c * x * y, f.value(new double[] { x, y }), 1.0e-15);
450 assertEquals(2 * a * x + c * y, dfdx.value(new double[] { x, y }), 1.0e-15);
451 assertEquals(2 * b * y + c * x, dfdy.value(new double[] { x, y }), 1.0e-15);
452 assertEquals(2 * a, d2fdx2.value(new double[] { x, y }), 1.0e-15);
453 assertEquals(2 * b, d2fdy2.value(new double[] { x, y }), 1.0e-15);
454 assertEquals(c, d2fdxdy.value(new double[] { x, y }), 1.0e-15);
455 }
456 }
457
458 }
459
460 }