1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22 package org.hipparchus.fitting;
23
24 import java.util.Collection;
25
26 import org.hipparchus.analysis.ParametricUnivariateFunction;
27 import org.hipparchus.linear.DiagonalMatrix;
28 import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
29 import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
30
31 /**
32 * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
33 *
34 */
35 public class SimpleCurveFitter extends AbstractCurveFitter {
36 /** Function to fit. */
37 private final ParametricUnivariateFunction function;
38 /** Initial guess for the parameters. */
39 private final double[] initialGuess;
40 /** Maximum number of iterations of the optimization algorithm. */
41 private final int maxIter;
42
43 /**
44 * Constructor used by the factory methods.
45 *
46 * @param function Function to fit.
47 * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
48 * be consistent with the number of parameters of the {@code function} to fit.
49 * @param maxIter Maximum number of iterations of the optimization algorithm.
50 */
51 private SimpleCurveFitter(ParametricUnivariateFunction function, double[] initialGuess, int maxIter) {
52 this.function = function;
53 this.initialGuess = initialGuess.clone();
54 this.maxIter = maxIter;
55 }
56
57 /**
58 * Creates a curve fitter.
59 * The maximum number of iterations of the optimization algorithm is set
60 * to {@link Integer#MAX_VALUE}.
61 *
62 * @param f Function to fit.
63 * @param start Initial guess for the parameters. Cannot be {@code null}.
64 * Its length must be consistent with the number of parameters of the
65 * function to fit.
66 * @return a curve fitter.
67 *
68 * @see #withStartPoint(double[])
69 * @see #withMaxIterations(int)
70 */
71 public static SimpleCurveFitter create(ParametricUnivariateFunction f,
72 double[] start) {
73 return new SimpleCurveFitter(f, start, Integer.MAX_VALUE);
74 }
75
76 /**
77 * Configure the start point (initial guess).
78 * @param newStart new start point (initial guess)
79 * @return a new instance.
80 */
81 public SimpleCurveFitter withStartPoint(double[] newStart) {
82 return new SimpleCurveFitter(function,
83 newStart.clone(),
84 maxIter);
85 }
86
87 /**
88 * Configure the maximum number of iterations.
89 * @param newMaxIter maximum number of iterations
90 * @return a new instance.
91 */
92 public SimpleCurveFitter withMaxIterations(int newMaxIter) {
93 return new SimpleCurveFitter(function,
94 initialGuess,
95 newMaxIter);
96 }
97
98 /** {@inheritDoc} */
99 @Override
100 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
101 // Prepare least-squares problem.
102 final int len = observations.size();
103 final double[] target = new double[len];
104 final double[] weights = new double[len];
105
106 int count = 0;
107 for (WeightedObservedPoint obs : observations) {
108 target[count] = obs.getY();
109 weights[count] = obs.getWeight();
110 ++count;
111 }
112
113 final AbstractCurveFitter.TheoreticalValuesFunction model
114 = new AbstractCurveFitter.TheoreticalValuesFunction(function,
115 observations);
116
117 // Create an optimizer for fitting the curve to the observed points.
118 return new LeastSquaresBuilder().
119 maxEvaluations(Integer.MAX_VALUE).
120 maxIterations(maxIter).
121 start(initialGuess).
122 target(target).
123 weight(new DiagonalMatrix(weights)).
124 model(model.getModelFunction(), model.getModelFunctionJacobian()).
125 build();
126 }
127 }