View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.fitting;
23  
24  import java.util.Collection;
25  
26  import org.hipparchus.analysis.ParametricUnivariateFunction;
27  import org.hipparchus.linear.DiagonalMatrix;
28  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
29  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
30  
31  /**
32   * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
33   *
34   */
35  public class SimpleCurveFitter extends AbstractCurveFitter {
36      /** Function to fit. */
37      private final ParametricUnivariateFunction function;
38      /** Initial guess for the parameters. */
39      private final double[] initialGuess;
40      /** Maximum number of iterations of the optimization algorithm. */
41      private final int maxIter;
42  
43      /**
44       * Constructor used by the factory methods.
45       *
46       * @param function Function to fit.
47       * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
48       * be consistent with the number of parameters of the {@code function} to fit.
49       * @param maxIter Maximum number of iterations of the optimization algorithm.
50       */
51      private SimpleCurveFitter(ParametricUnivariateFunction function, double[] initialGuess, int maxIter) {
52          this.function = function;
53          this.initialGuess = initialGuess.clone();
54          this.maxIter = maxIter;
55      }
56  
57      /**
58       * Creates a curve fitter.
59       * The maximum number of iterations of the optimization algorithm is set
60       * to {@link Integer#MAX_VALUE}.
61       *
62       * @param f Function to fit.
63       * @param start Initial guess for the parameters.  Cannot be {@code null}.
64       * Its length must be consistent with the number of parameters of the
65       * function to fit.
66       * @return a curve fitter.
67       *
68       * @see #withStartPoint(double[])
69       * @see #withMaxIterations(int)
70       */
71      public static SimpleCurveFitter create(ParametricUnivariateFunction f,
72                                             double[] start) {
73          return new SimpleCurveFitter(f, start, Integer.MAX_VALUE);
74      }
75  
76      /**
77       * Configure the start point (initial guess).
78       * @param newStart new start point (initial guess)
79       * @return a new instance.
80       */
81      public SimpleCurveFitter withStartPoint(double[] newStart) {
82          return new SimpleCurveFitter(function,
83                                       newStart.clone(),
84                                       maxIter);
85      }
86  
87      /**
88       * Configure the maximum number of iterations.
89       * @param newMaxIter maximum number of iterations
90       * @return a new instance.
91       */
92      public SimpleCurveFitter withMaxIterations(int newMaxIter) {
93          return new SimpleCurveFitter(function,
94                                       initialGuess,
95                                       newMaxIter);
96      }
97  
98      /** {@inheritDoc} */
99      @Override
100     protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
101         // Prepare least-squares problem.
102         final int len = observations.size();
103         final double[] target  = new double[len];
104         final double[] weights = new double[len];
105 
106         int count = 0;
107         for (WeightedObservedPoint obs : observations) {
108             target[count]  = obs.getY();
109             weights[count] = obs.getWeight();
110             ++count;
111         }
112 
113         final AbstractCurveFitter.TheoreticalValuesFunction model
114             = new AbstractCurveFitter.TheoreticalValuesFunction(function,
115                                                                 observations);
116 
117         // Create an optimizer for fitting the curve to the observed points.
118         return new LeastSquaresBuilder().
119                 maxEvaluations(Integer.MAX_VALUE).
120                 maxIterations(maxIter).
121                 start(initialGuess).
122                 target(target).
123                 weight(new DiagonalMatrix(weights)).
124                 model(model.getModelFunction(), model.getModelFunctionJacobian()).
125                 build();
126     }
127 }