View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.analysis.differentiation;
23  
24  import org.hipparchus.analysis.MultivariateVectorFunction;
25  
26  /** Class representing the gradient of a multivariate function.
27   * <p>
28   * The vectorial components of the function represent the derivatives
29   * with respect to each function parameters.
30   * </p>
31   */
32  public class GradientFunction implements MultivariateVectorFunction {
33  
34      /** Underlying real-valued function. */
35      private final MultivariateDifferentiableFunction f;
36  
37      /** Simple constructor.
38       * @param f underlying real-valued function
39       */
40      public GradientFunction(final MultivariateDifferentiableFunction f) {
41          this.f = f;
42      }
43  
44      /** {@inheritDoc} */
45      @Override
46      public double[] value(double[] point) {
47  
48          // set up parameters
49          final DSFactory factory = new DSFactory(point.length, 1);
50          final DerivativeStructure[] dsX = new DerivativeStructure[point.length];
51          for (int i = 0; i < point.length; ++i) {
52              dsX[i] = factory.variable(i, point[i]);
53          }
54  
55          // compute the derivatives
56          final DerivativeStructure dsY = f.value(dsX);
57  
58          // extract the gradient
59          final double[] y = new double[point.length];
60          final int[] orders = new int[point.length];
61          for (int i = 0; i < point.length; ++i) {
62              orders[i] = 1;
63              y[i] = dsY.getPartialDerivative(orders);
64              orders[i] = 0;
65          }
66  
67          return y;
68  
69      }
70  
71  }