/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
/*
 * This is not the original file distributed by the Apache Software Foundation.
 * It has been modified by the Hipparchus project.
 */
22  package org.hipparchus.fitting;
23  
24  import java.util.Collection;
25  
26  import org.hipparchus.analysis.MultivariateMatrixFunction;
27  import org.hipparchus.analysis.MultivariateVectorFunction;
28  import org.hipparchus.analysis.ParametricUnivariateFunction;
29  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresOptimizer;
30  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
31  import org.hipparchus.optim.nonlinear.vector.leastsquares.LevenbergMarquardtOptimizer;
32  
33  /**
34   * Base class that contains common code for fitting parametric univariate
35   * real functions <code>y = f(p<sub>i</sub>;x)</code>, where {@code x} is
36   * the independent variable and the <code>p<sub>i</sub></code> are the
37   * <em>parameters</em>.
38   * <br>
39   * A fitter will find the optimal values of the parameters by
40   * <em>fitting</em> the curve so it remains very close to a set of
41   * {@code N} observed points <code>(x<sub>k</sub>, y<sub>k</sub>)</code>,
42   * {@code 0 <= k < N}.
43   * <br>
44   * An algorithm usually performs the fit by finding the parameter
45   * values that minimizes the objective function
46   * <pre><code>
47   *  &sum;y<sub>k</sub> - f(x<sub>k</sub>)<sup>2</sup>,
48   * </code></pre>
49   * which is actually a least-squares problem.
50   * This class contains boilerplate code for calling the
51   * {@link #fit(Collection)} method for obtaining the parameters.
52   * The problem setup, such as the choice of optimization algorithm
53   * for fitting a specific function is delegated to subclasses.
54   *
55   */
56  public abstract class AbstractCurveFitter {
57  
58      /** Empty constructor.
59       * <p>
60       * This constructor is not strictly necessary, but it prevents spurious
61       * javadoc warnings with JDK 18 and later.
62       * </p>
63       * @since 3.0
64       */
65      public AbstractCurveFitter() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
66          // nothing to do
67      }
68  
69      /**
70       * Fits a curve.
71       * This method computes the coefficients of the curve that best
72       * fit the sample of observed points.
73       *
74       * @param points Observations.
75       * @return the fitted parameters.
76       */
77      public double[] fit(Collection<WeightedObservedPoint> points) {
78          // Perform the fit.
79          return getOptimizer().optimize(getProblem(points)).getPoint().toArray();
80      }
81  
82      /**
83       * Creates an optimizer set up to fit the appropriate curve.
84       * <p>
85       * The default implementation uses a {@link LevenbergMarquardtOptimizer
86       * Levenberg-Marquardt} optimizer.
87       * </p>
88       * @return the optimizer to use for fitting the curve to the
89       * given {@code points}.
90       */
91      protected LeastSquaresOptimizer getOptimizer() {
92          return new LevenbergMarquardtOptimizer();
93      }
94  
95      /**
96       * Creates a least squares problem corresponding to the appropriate curve.
97       *
98       * @param points Sample points.
99       * @return the least squares problem to use for fitting the curve to the
100      * given {@code points}.
101      */
102     protected abstract LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points);
103 
104     /**
105      * Vector function for computing function theoretical values.
106      */
107     protected static class TheoreticalValuesFunction {
108         /** Function to fit. */
109         private final ParametricUnivariateFunction f;
110         /** Observations. */
111         private final double[] points;
112 
113         /** Simple constructor.
114          * @param f function to fit.
115          * @param observations Observations.
116          */
117         public TheoreticalValuesFunction(final ParametricUnivariateFunction f,
118                                          final Collection<WeightedObservedPoint> observations) {
119             this.f = f;
120 
121             final int len = observations.size();
122             this.points = new double[len];
123             int i = 0;
124             for (WeightedObservedPoint obs : observations) {
125                 this.points[i++] = obs.getX();
126             }
127         }
128 
129         /** Get model function value.
130          * @return the model function value
131          */
132         public MultivariateVectorFunction getModelFunction() {
133             return new MultivariateVectorFunction() {
134                 /** {@inheritDoc} */
135                 @Override
136                 public double[] value(double[] p) {
137                     final int len = points.length;
138                     final double[] values = new double[len];
139                     for (int i = 0; i < len; i++) {
140                         values[i] = f.value(points[i], p);
141                     }
142 
143                     return values;
144                 }
145             };
146         }
147 
148         /** Get model function Jacobian.
149          * @return the model function Jacobian
150          */
151         public MultivariateMatrixFunction getModelFunctionJacobian() {
152             return new MultivariateMatrixFunction() {
153                 /** {@inheritDoc} */
154                 @Override
155                 public double[][] value(double[] p) {
156                     final int len = points.length;
157                     final double[][] jacobian = new double[len][];
158                     for (int i = 0; i < len; i++) {
159                         jacobian[i] = f.gradient(points[i], p);
160                     }
161                     return jacobian;
162                 }
163             };
164         }
165     }
166 }