View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.analysis;
24  
25  import org.hipparchus.analysis.differentiation.DSFactory;
26  import org.hipparchus.analysis.differentiation.Derivative;
27  import org.hipparchus.analysis.differentiation.DerivativeStructure;
28  import org.hipparchus.analysis.differentiation.MultivariateDifferentiableFunction;
29  import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30  import org.hipparchus.analysis.function.Identity;
31  import org.hipparchus.exception.LocalizedCoreFormats;
32  import org.hipparchus.exception.MathIllegalArgumentException;
33  import org.hipparchus.util.MathArrays;
34  import org.hipparchus.util.MathUtils;
35  
36  /**
37   * Utilities for manipulating function objects.
38   *
39   */
40  public class FunctionUtils {
41      /**
42       * Class only contains static methods.
43       */
44      private FunctionUtils() {}
45  
46      /**
47       * Composes functions.
48       * <p>
49       * The functions in the argument list are composed sequentially, in the
50       * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
51       *
52       * @param f List of functions.
53       * @return the composite function.
54       */
55      public static UnivariateFunction compose(final UnivariateFunction ... f) {
56          return new UnivariateFunction() {
57              /** {@inheritDoc} */
58              @Override
59              public double value(double x) {
60                  double r = x;
61                  for (int i = f.length - 1; i >= 0; i--) {
62                      r = f[i].value(r);
63                  }
64                  return r;
65              }
66          };
67      }
68  
69      /**
70       * Composes functions.
71       * <p>
72       * The functions in the argument list are composed sequentially, in the
73       * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
74       *
75       * @param f List of functions.
76       * @return the composite function.
77       */
78      public static UnivariateDifferentiableFunction compose(final UnivariateDifferentiableFunction ... f) {
79          return new UnivariateDifferentiableFunction() {
80  
81              /** {@inheritDoc} */
82              @Override
83              public double value(final double t) {
84                  double r = t;
85                  for (int i = f.length - 1; i >= 0; i--) {
86                      r = f[i].value(r);
87                  }
88                  return r;
89              }
90  
91              /** {@inheritDoc} */
92              @Override
93              public <T extends Derivative<T>> T value(final T t) {
94                  T r = t;
95                  for (int i = f.length - 1; i >= 0; i--) {
96                      r = f[i].value(r);
97                  }
98                  return r;
99              }
100 
101         };
102     }
103 
104     /**
105      * Adds functions.
106      *
107      * @param f List of functions.
108      * @return a function that computes the sum of the functions.
109      */
110     public static UnivariateFunction add(final UnivariateFunction ... f) {
111         return new UnivariateFunction() {
112             /** {@inheritDoc} */
113             @Override
114             public double value(double x) {
115                 double r = f[0].value(x);
116                 for (int i = 1; i < f.length; i++) {
117                     r += f[i].value(x);
118                 }
119                 return r;
120             }
121         };
122     }
123 
124     /**
125      * Adds functions.
126      *
127      * @param f List of functions.
128      * @return a function that computes the sum of the functions.
129      */
130     public static UnivariateDifferentiableFunction add(final UnivariateDifferentiableFunction ... f) {
131         return new UnivariateDifferentiableFunction() {
132 
133             /** {@inheritDoc} */
134             @Override
135             public double value(final double t) {
136                 double r = f[0].value(t);
137                 for (int i = 1; i < f.length; i++) {
138                     r += f[i].value(t);
139                 }
140                 return r;
141             }
142 
143             /** {@inheritDoc}
144              * @throws MathIllegalArgumentException if functions are not consistent with each other
145              */
146             @Override
147             public <T extends Derivative<T>> T value(final T t)
148                 throws MathIllegalArgumentException {
149                 T r = f[0].value(t);
150                 for (int i = 1; i < f.length; i++) {
151                     r = r.add(f[i].value(t));
152                 }
153                 return r;
154             }
155 
156         };
157     }
158 
159     /**
160      * Multiplies functions.
161      *
162      * @param f List of functions.
163      * @return a function that computes the product of the functions.
164      */
165     public static UnivariateFunction multiply(final UnivariateFunction ... f) {
166         return new UnivariateFunction() {
167             /** {@inheritDoc} */
168             @Override
169             public double value(double x) {
170                 double r = f[0].value(x);
171                 for (int i = 1; i < f.length; i++) {
172                     r *= f[i].value(x);
173                 }
174                 return r;
175             }
176         };
177     }
178 
179     /**
180      * Multiplies functions.
181      *
182      * @param f List of functions.
183      * @return a function that computes the product of the functions.
184      */
185     public static UnivariateDifferentiableFunction multiply(final UnivariateDifferentiableFunction ... f) {
186         return new UnivariateDifferentiableFunction() {
187 
188             /** {@inheritDoc} */
189             @Override
190             public double value(final double t) {
191                 double r = f[0].value(t);
192                 for (int i = 1; i < f.length; i++) {
193                     r  *= f[i].value(t);
194                 }
195                 return r;
196             }
197 
198             /** {@inheritDoc} */
199             @Override
200             public <T extends Derivative<T>> T value(final T t) {
201                 T r = f[0].value(t);
202                 for (int i = 1; i < f.length; i++) {
203                     r = r.multiply(f[i].value(t));
204                 }
205                 return r;
206             }
207 
208         };
209     }
210 
211     /**
212      * Returns the univariate function
213      * {@code h(x) = combiner(f(x), g(x)).}
214      *
215      * @param combiner Combiner function.
216      * @param f Function.
217      * @param g Function.
218      * @return the composite function.
219      */
220     public static UnivariateFunction combine(final BivariateFunction combiner,
221                                              final UnivariateFunction f,
222                                              final UnivariateFunction g) {
223         return new UnivariateFunction() {
224             /** {@inheritDoc} */
225             @Override
226             public double value(double x) {
227                 return combiner.value(f.value(x), g.value(x));
228             }
229         };
230     }
231 
232     /**
233      * Returns a MultivariateFunction h(x[]) defined by <pre> <code>
234      * h(x[]) = combiner(...combiner(combiner(initialValue,f(x[0])),f(x[1]))...),f(x[x.length-1]))
235      * </code></pre>
236      *
237      * @param combiner Combiner function.
238      * @param f Function.
239      * @param initialValue Initial value.
240      * @return a collector function.
241      */
242     public static MultivariateFunction collector(final BivariateFunction combiner,
243                                                  final UnivariateFunction f,
244                                                  final double initialValue) {
245         return new MultivariateFunction() {
246             /** {@inheritDoc} */
247             @Override
248             public double value(double[] point) {
249                 double result = combiner.value(initialValue, f.value(point[0]));
250                 for (int i = 1; i < point.length; i++) {
251                     result = combiner.value(result, f.value(point[i]));
252                 }
253                 return result;
254             }
255         };
256     }
257 
258     /**
259      * Returns a MultivariateFunction h(x[]) defined by <pre> <code>
260      * h(x[]) = combiner(...combiner(combiner(initialValue,x[0]),x[1])...),x[x.length-1])
261      * </code></pre>
262      *
263      * @param combiner Combiner function.
264      * @param initialValue Initial value.
265      * @return a collector function.
266      */
267     public static MultivariateFunction collector(final BivariateFunction combiner,
268                                                  final double initialValue) {
269         return collector(combiner, new Identity(), initialValue);
270     }
271 
272     /**
273      * Creates a unary function by fixing the first argument of a binary function.
274      *
275      * @param f Binary function.
276      * @param fixed value to which the first argument of {@code f} is set.
277      * @return the unary function h(x) = f(fixed, x)
278      */
279     public static UnivariateFunction fix1stArgument(final BivariateFunction f,
280                                                     final double fixed) {
281         return new UnivariateFunction() {
282             /** {@inheritDoc} */
283             @Override
284             public double value(double x) {
285                 return f.value(fixed, x);
286             }
287         };
288     }
289     /**
290      * Creates a unary function by fixing the second argument of a binary function.
291      *
292      * @param f Binary function.
293      * @param fixed value to which the second argument of {@code f} is set.
294      * @return the unary function h(x) = f(x, fixed)
295      */
296     public static UnivariateFunction fix2ndArgument(final BivariateFunction f,
297                                                     final double fixed) {
298         return new UnivariateFunction() {
299             /** {@inheritDoc} */
300             @Override
301             public double value(double x) {
302                 return f.value(x, fixed);
303             }
304         };
305     }
306 
307     /**
308      * Samples the specified univariate real function on the specified interval.
309      * <p>
310      * The interval is divided equally into {@code n} sections and sample points
311      * are taken from {@code min} to {@code max - (max - min) / n}; therefore
312      * {@code f} is not sampled at the upper bound {@code max}.</p>
313      *
314      * @param f Function to be sampled
315      * @param min Lower bound of the interval (included).
316      * @param max Upper bound of the interval (excluded).
317      * @param n Number of sample points.
318      * @return the array of samples.
319      * @throws MathIllegalArgumentException if the lower bound {@code min} is
320      * greater than, or equal to the upper bound {@code max}.
321      * @throws MathIllegalArgumentException if the number of sample points
322      * {@code n} is negative.
323      */
324     public static double[] sample(UnivariateFunction f, double min, double max, int n)
325        throws MathIllegalArgumentException {
326 
327         if (n <= 0) {
328             throw new MathIllegalArgumentException(
329                     LocalizedCoreFormats.NOT_POSITIVE_NUMBER_OF_SAMPLES,
330                     Integer.valueOf(n));
331         }
332         if (min >= max) {
333             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE_BOUND_EXCLUDED,
334                                                    min, max);
335         }
336 
337         final double[] s = new double[n];
338         final double h = (max - min) / n;
339         for (int i = 0; i < n; i++) {
340             s[i] = f.value(min + i * h);
341         }
342         return s;
343     }
344 
345     /** Convert regular functions to {@link UnivariateDifferentiableFunction}.
346      * <p>
347      * This method handle the case with one free parameter and several derivatives.
348      * For the case with several free parameters and only first order derivatives,
349      * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
350      * There are no direct support for intermediate cases, with several free parameters
351      * and order 2 or more derivatives, as is would be difficult to specify all the
352      * cross derivatives.
353      * </p>
354      * <p>
355      * Note that the derivatives are expected to be computed only with respect to the
356      * raw parameter x of the base function, i.e. they are df/dx, df<sup>2</sup>/dx<sup>2</sup>, ...
357      * Even if the built function is later used in a composition like f(sin(t)), the provided
358      * derivatives should <em>not</em> apply the composition with sine and its derivatives by
359      * themselves. The composition will be done automatically here and the result will properly
360      * contain f(sin(t)), df(sin(t))/dt, df<sup>2</sup>(sin(t))/dt<sup>2</sup> despite the
361      * provided derivatives functions know nothing about the sine function.
362      * </p>
363      * @param f base function f(x)
364      * @param derivatives derivatives of the base function, in increasing differentiation order
365      * @return a differentiable function with value and all specified derivatives
366      * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
367      * @see #derivative(UnivariateDifferentiableFunction, int)
368      */
369     public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
370                                                                     final UnivariateFunction ... derivatives) {
371 
372         return new UnivariateDifferentiableFunction() {
373 
374             /** {@inheritDoc} */
375             @Override
376             public double value(final double x) {
377                 return f.value(x);
378             }
379 
380             /** {@inheritDoc} */
381             @Override
382             public <T extends Derivative<T>> T value(final T x) {
383                 if (x.getOrder() > derivatives.length) {
384                     throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
385                                                            x.getOrder(), derivatives.length);
386                 }
387                 final double[] packed = new double[x.getOrder() + 1];
388                 packed[0] = f.value(x.getValue());
389                 for (int i = 0; i < x.getOrder(); ++i) {
390                     packed[i + 1] = derivatives[i].value(x.getValue());
391                 }
392                 return x.compose(packed);
393             }
394 
395         };
396 
397     }
398 
399     /** Convert regular functions to {@link MultivariateDifferentiableFunction}.
400      * <p>
401      * This method handle the case with several free parameters and only first order derivatives.
402      * For the case with one free parameter and several derivatives,
403      * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
404      * There are no direct support for intermediate cases, with several free parameters
405      * and order 2 or more derivatives, as is would be difficult to specify all the
406      * cross derivatives.
407      * </p>
408      * <p>
409      * Note that the gradient is expected to be computed only with respect to the
410      * raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
411      * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
412      * gradient should <em>not</em> apply the composition with sine or cosine and their derivative by
413      * itself. The composition will be done automatically here and the result will properly
414      * contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt despite the provided derivatives functions
415      * know nothing about the sine or cosine functions.
416      * </p>
417      * @param f base function f(x)
418      * @param gradient gradient of the base function
419      * @return a differentiable function with value and gradient
420      * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
421      * @see #derivative(MultivariateDifferentiableFunction, int[])
422      */
423     public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
424                                                                       final MultivariateVectorFunction gradient) {
425 
426         return new MultivariateDifferentiableFunction() {
427 
428             /** {@inheritDoc} */
429             @Override
430             public double value(final double[] point) {
431                 return f.value(point);
432             }
433 
434             /** {@inheritDoc} */
435             @Override
436             public DerivativeStructure value(final DerivativeStructure[] point) {
437 
438                 // set up the input parameters
439                 final double[] dPoint = new double[point.length];
440                 for (int i = 0; i < point.length; ++i) {
441                     dPoint[i] = point[i].getValue();
442                     if (point[i].getOrder() > 1) {
443                         throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
444                                                                point[i].getOrder(), 1);
445                     }
446                 }
447 
448                 // evaluate regular functions
449                 final double    v = f.value(dPoint);
450                 final double[] dv = gradient.value(dPoint);
451                 MathUtils.checkDimension(dv.length, point.length);
452 
453                 // build the combined derivative
454                 final int parameters = point[0].getFreeParameters();
455                 final double[] partials = new double[point.length];
456                 final double[] packed = new double[parameters + 1];
457                 packed[0] = v;
458                 final int[] orders = new int[parameters];
459                 for (int i = 0; i < parameters; ++i) {
460 
461                     // we differentiate once with respect to parameter i
462                     orders[i] = 1;
463                     for (int j = 0; j < point.length; ++j) {
464                         partials[j] = point[j].getPartialDerivative(orders);
465                     }
466                     orders[i] = 0;
467 
468                     // compose partial derivatives
469                     packed[i + 1] = MathArrays.linearCombination(dv, partials);
470 
471                 }
472 
473                 return point[0].getFactory().build(packed);
474 
475             }
476 
477         };
478 
479     }
480 
481     /** Convert an {@link UnivariateDifferentiableFunction} to an
482      * {@link UnivariateFunction} computing n<sup>th</sup> order derivative.
483      * <p>
484      * This converter is only a convenience method. Beware computing only one derivative does
485      * not save any computation as the original function will really be called under the hood.
486      * The derivative will be extracted from the full {@link DerivativeStructure} result.
487      * </p>
488      * @param f original function, with value and all its derivatives
489      * @param order of the derivative to extract
490      * @return function computing the derivative at required order
491      * @see #derivative(MultivariateDifferentiableFunction, int[])
492      * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
493      */
494     public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
495 
496         final DSFactory factory = new DSFactory(1, order);
497 
498         return new UnivariateFunction() {
499 
500             /** {@inheritDoc} */
501             @Override
502             public double value(final double x) {
503                 final DerivativeStructure dsX = factory.variable(0, x);
504                 return f.value(dsX).getPartialDerivative(order);
505             }
506 
507         };
508     }
509 
510     /** Convert an {@link MultivariateDifferentiableFunction} to an
511      * {@link MultivariateFunction} computing n<sup>th</sup> order derivative.
512      * <p>
513      * This converter is only a convenience method. Beware computing only one derivative does
514      * not save any computation as the original function will really be called under the hood.
515      * The derivative will be extracted from the full {@link DerivativeStructure} result.
516      * </p>
517      * @param f original function, with value and all its derivatives
518      * @param orders of the derivative to extract, for each free parameters
519      * @return function computing the derivative at required order
520      * @see #derivative(UnivariateDifferentiableFunction, int)
521      * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
522      */
523     public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
524 
525         // the maximum differentiation order is the sum of all orders
526         int sum = 0;
527         for (final int order : orders) {
528             sum += order;
529         }
530         final int sumOrders = sum;
531 
532         return new MultivariateFunction() {
533 
534             /** Factory used for building derivatives. */
535             private DSFactory factory;
536 
537             /** {@inheritDoc} */
538             @Override
539             public double value(final double[] point) {
540 
541                 if (factory == null || point.length != factory.getCompiler().getFreeParameters()) {
542                     // rebuild the factory in case of mismatch
543                     factory = new DSFactory(point.length, sumOrders);
544                 }
545 
546                 // set up the input parameters
547                 final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
548                 for (int i = 0; i < point.length; ++i) {
549                     dsPoint[i] = factory.variable(i, point[i]);
550                 }
551 
552                 return f.value(dsPoint).getPartialDerivative(orders);
553 
554             }
555 
556         };
557     }
558 
559 }