1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22 package org.hipparchus.fitting;
23
24 import java.util.ArrayList;
25 import java.util.Collection;
26 import java.util.Comparator;
27 import java.util.List;
28
29 import org.hipparchus.analysis.function.Gaussian;
30 import org.hipparchus.exception.LocalizedCoreFormats;
31 import org.hipparchus.exception.MathIllegalArgumentException;
32 import org.hipparchus.linear.DiagonalMatrix;
33 import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
34 import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
35 import org.hipparchus.util.FastMath;
36 import org.hipparchus.util.MathUtils;
37
38 /**
39 * Fits points to a {@link
40 * org.hipparchus.analysis.function.Gaussian.Parametric Gaussian}
41 * function.
42 * <br>
43 * The {@link #withStartPoint(double[]) initial guess values} must be passed
44 * in the following order:
45 * <ul>
46 * <li>Normalization</li>
47 * <li>Mean</li>
48 * <li>Sigma</li>
49 * </ul>
50 * The optimal values will be returned in the same order.
51 *
52 * <p>
53 * Usage example:
54 * <pre>
55 * WeightedObservedPoints obs = new WeightedObservedPoints();
56 * obs.add(4.0254623, 531026.0);
57 * obs.add(4.03128248, 984167.0);
58 * obs.add(4.03839603, 1887233.0);
59 * obs.add(4.04421621, 2687152.0);
60 * obs.add(4.05132976, 3461228.0);
61 * obs.add(4.05326982, 3580526.0);
62 * obs.add(4.05779662, 3439750.0);
63 * obs.add(4.0636168, 2877648.0);
64 * obs.add(4.06943698, 2175960.0);
65 * obs.add(4.07525716, 1447024.0);
66 * obs.add(4.08237071, 717104.0);
67 * obs.add(4.08366408, 620014.0);
68 * double[] parameters = GaussianCurveFitter.create().fit(obs.toList());
69 * </pre>
70 *
71 */
72 public class GaussianCurveFitter extends AbstractCurveFitter {
73 /** Parametric function to be fitted. */
74 private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
75 /** {@inheritDoc} */
76 @Override
77 public double value(double x, double ... p) {
78 double v = Double.POSITIVE_INFINITY;
79 try {
80 v = super.value(x, p);
81 } catch (MathIllegalArgumentException e) { // NOPMD
82 // Do nothing.
83 }
84 return v;
85 }
86
87 /** {@inheritDoc} */
88 @Override
89 public double[] gradient(double x, double ... p) {
90 double[] v = { Double.POSITIVE_INFINITY,
91 Double.POSITIVE_INFINITY,
92 Double.POSITIVE_INFINITY };
93 try {
94 v = super.gradient(x, p);
95 } catch (MathIllegalArgumentException e) { // NOPMD
96 // Do nothing.
97 }
98 return v;
99 }
100 };
101 /** Initial guess. */
102 private final double[] initialGuess;
103 /** Maximum number of iterations of the optimization algorithm. */
104 private final int maxIter;
105
106 /**
107 * Constructor used by the factory methods.
108 *
109 * @param initialGuess Initial guess. If set to {@code null}, the initial guess
110 * will be estimated using the {@link ParameterGuesser}.
111 * @param maxIter Maximum number of iterations of the optimization algorithm.
112 */
113 private GaussianCurveFitter(double[] initialGuess, int maxIter) {
114 this.initialGuess = initialGuess == null ? null : initialGuess.clone();
115 this.maxIter = maxIter;
116 }
117
118 /**
119 * Creates a default curve fitter.
120 * The initial guess for the parameters will be {@link ParameterGuesser}
121 * computed automatically, and the maximum number of iterations of the
122 * optimization algorithm is set to {@link Integer#MAX_VALUE}.
123 *
124 * @return a curve fitter.
125 *
126 * @see #withStartPoint(double[])
127 * @see #withMaxIterations(int)
128 */
129 public static GaussianCurveFitter create() {
130 return new GaussianCurveFitter(null, Integer.MAX_VALUE);
131 }
132
133 /**
134 * Configure the start point (initial guess).
135 * @param newStart new start point (initial guess)
136 * @return a new instance.
137 */
138 public GaussianCurveFitter withStartPoint(double[] newStart) {
139 return new GaussianCurveFitter(newStart.clone(),
140 maxIter);
141 }
142
143 /**
144 * Configure the maximum number of iterations.
145 * @param newMaxIter maximum number of iterations
146 * @return a new instance.
147 */
148 public GaussianCurveFitter withMaxIterations(int newMaxIter) {
149 return new GaussianCurveFitter(initialGuess,
150 newMaxIter);
151 }
152
153 /** {@inheritDoc} */
154 @Override
155 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
156
157 // Prepare least-squares problem.
158 final int len = observations.size();
159 final double[] target = new double[len];
160 final double[] weights = new double[len];
161
162 int i = 0;
163 for (WeightedObservedPoint obs : observations) {
164 target[i] = obs.getY();
165 weights[i] = obs.getWeight();
166 ++i;
167 }
168
169 final AbstractCurveFitter.TheoreticalValuesFunction model =
170 new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
171
172 final double[] startPoint = initialGuess != null ?
173 initialGuess :
174 // Compute estimation.
175 new ParameterGuesser(observations).guess();
176
177 // Return a new least squares problem set up to fit a Gaussian curve to the
178 // observed points.
179 return new LeastSquaresBuilder().
180 maxEvaluations(Integer.MAX_VALUE).
181 maxIterations(maxIter).
182 start(startPoint).
183 target(target).
184 weight(new DiagonalMatrix(weights)).
185 model(model.getModelFunction(), model.getModelFunctionJacobian()).
186 build();
187
188 }
189
190 /**
191 * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
192 * of a {@link org.hipparchus.analysis.function.Gaussian.Parametric}
193 * based on the specified observed points.
194 */
195 public static class ParameterGuesser {
196 /** Normalization factor. */
197 private final double norm;
198 /** Mean. */
199 private final double mean;
200 /** Standard deviation. */
201 private final double sigma;
202
203 /**
204 * Constructs instance with the specified observed points.
205 *
206 * @param observations Observed points from which to guess the
207 * parameters of the Gaussian.
208 * @throws org.hipparchus.exception.NullArgumentException if {@code observations} is
209 * {@code null}.
210 * @throws MathIllegalArgumentException if there are less than 3
211 * observations.
212 */
213 public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
214 MathUtils.checkNotNull(observations);
215 if (observations.size() < 3) {
216 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
217 observations.size(), 3);
218 }
219
220 final List<WeightedObservedPoint> sorted = sortObservations(observations);
221 final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
222
223 norm = params[0];
224 mean = params[1];
225 sigma = params[2];
226 }
227
228 /**
229 * Gets an estimation of the parameters.
230 *
231 * @return the guessed parameters, in the following order:
232 * <ul>
233 * <li>Normalization factor</li>
234 * <li>Mean</li>
235 * <li>Standard deviation</li>
236 * </ul>
237 */
238 public double[] guess() {
239 return new double[] { norm, mean, sigma };
240 }
241
242 /**
243 * Sort the observations.
244 *
245 * @param unsorted Input observations.
246 * @return the input observations, sorted.
247 */
248 private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
249 final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
250
251 final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
252 /** {@inheritDoc} */
253 @Override
254 public int compare(WeightedObservedPoint p1,
255 WeightedObservedPoint p2) {
256 if (p1 == null && p2 == null) {
257 return 0;
258 }
259 if (p1 == null) {
260 return -1;
261 }
262 if (p2 == null) {
263 return 1;
264 }
265 int comp = Double.compare(p1.getX(), p2.getX());
266 if (comp != 0) {
267 return comp;
268 }
269 comp = Double.compare(p1.getY(), p2.getY());
270 if (comp != 0) {
271 return comp;
272 }
273 comp = Double.compare(p1.getWeight(), p2.getWeight());
274 if (comp != 0) {
275 return comp;
276 }
277 return 0;
278 }
279 };
280
281 observations.sort(cmp);
282 return observations;
283 }
284
285 /**
286 * Guesses the parameters based on the specified observed points.
287 *
288 * @param points Observed points, sorted.
289 * @return the guessed parameters (normalization factor, mean and
290 * sigma).
291 */
292 private double[] basicGuess(WeightedObservedPoint[] points) {
293 final int maxYIdx = findMaxY(points);
294 final double n = points[maxYIdx].getY();
295 final double m = points[maxYIdx].getX();
296
297 double fwhmApprox;
298 try {
299 final double halfY = n + ((m - n) / 2);
300 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
301 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
302 fwhmApprox = fwhmX2 - fwhmX1;
303 } catch (MathIllegalArgumentException e) {
304 // TODO: Exceptions should not be used for flow control.
305 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
306 }
307 final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
308
309 return new double[] { n, m, s };
310 }
311
312 /**
313 * Finds index of point in specified points with the largest Y.
314 *
315 * @param points Points to search.
316 * @return the index in specified points array.
317 */
318 private int findMaxY(WeightedObservedPoint[] points) {
319 int maxYIdx = 0;
320 for (int i = 1; i < points.length; i++) {
321 if (points[i].getY() > points[maxYIdx].getY()) {
322 maxYIdx = i;
323 }
324 }
325 return maxYIdx;
326 }
327
328 /**
329 * Interpolates using the specified points to determine X at the
330 * specified Y.
331 *
332 * @param points Points to use for interpolation.
333 * @param startIdx Index within points from which to start the search for
334 * interpolation bounds points.
335 * @param idxStep Index step for searching interpolation bounds points.
336 * @param y Y value for which X should be determined.
337 * @return the value of X for the specified Y.
338 * @throws MathIllegalArgumentException if {@code idxStep} is 0.
339 * @throws MathIllegalArgumentException if specified {@code y} is not within the
340 * range of the specified {@code points}.
341 */
342 private double interpolateXAtY(WeightedObservedPoint[] points,
343 int startIdx,
344 int idxStep,
345 double y)
346 throws MathIllegalArgumentException {
347 if (idxStep == 0) {
348 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
349 }
350 final WeightedObservedPoint[] twoPoints
351 = getInterpolationPointsForY(points, startIdx, idxStep, y);
352 final WeightedObservedPoint p1 = twoPoints[0];
353 final WeightedObservedPoint p2 = twoPoints[1];
354 if (p1.getY() == y) {
355 return p1.getX();
356 }
357 if (p2.getY() == y) {
358 return p2.getX();
359 }
360 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
361 (p2.getY() - p1.getY()));
362 }
363
364 /**
365 * Gets the two bounding interpolation points from the specified points
366 * suitable for determining X at the specified Y.
367 *
368 * @param points Points to use for interpolation.
369 * @param startIdx Index within points from which to start search for
370 * interpolation bounds points.
371 * @param idxStep Index step for search for interpolation bounds points.
372 * @param y Y value for which X should be determined.
373 * @return the array containing two points suitable for determining X at
374 * the specified Y.
375 * @throws MathIllegalArgumentException if {@code idxStep} is 0.
376 * @throws MathIllegalArgumentException if specified {@code y} is not within the
377 * range of the specified {@code points}.
378 */
379 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
380 int startIdx,
381 int idxStep,
382 double y)
383 throws MathIllegalArgumentException {
384 if (idxStep == 0) {
385 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
386 }
387 for (int i = startIdx;
388 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
389 i += idxStep) {
390 final WeightedObservedPoint p1 = points[i];
391 final WeightedObservedPoint p2 = points[i + idxStep];
392 if (isBetween(y, p1.getY(), p2.getY())) {
393 if (idxStep < 0) {
394 return new WeightedObservedPoint[] { p2, p1 };
395 } else {
396 return new WeightedObservedPoint[] { p1, p2 };
397 }
398 }
399 }
400
401 // Boundaries are replaced by dummy values because the raised
402 // exception is caught and the message never displayed.
403 // TODO: Exceptions should not be used for flow control.
404 throw new MathIllegalArgumentException(LocalizedCoreFormats.OUT_OF_RANGE_SIMPLE,
405 y, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
406 }
407
408 /**
409 * Determines whether a value is between two other values.
410 *
411 * @param value Value to test whether it is between {@code boundary1}
412 * and {@code boundary2}.
413 * @param boundary1 One end of the range.
414 * @param boundary2 Other end of the range.
415 * @return {@code true} if {@code value} is between {@code boundary1} and
416 * {@code boundary2} (inclusive), {@code false} otherwise.
417 */
418 private boolean isBetween(double value,
419 double boundary1,
420 double boundary2) {
421 return (value >= boundary1 && value <= boundary2) ||
422 (value >= boundary2 && value <= boundary1);
423 }
424 }
425 }