1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.fitting;
23
24 import java.util.ArrayList;
25 import java.util.Collection;
26 import java.util.Comparator;
27 import java.util.List;
28
29 import org.hipparchus.analysis.function.Gaussian;
30 import org.hipparchus.exception.LocalizedCoreFormats;
31 import org.hipparchus.exception.MathIllegalArgumentException;
32 import org.hipparchus.linear.DiagonalMatrix;
33 import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
34 import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
35 import org.hipparchus.util.FastMath;
36 import org.hipparchus.util.MathUtils;
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72 public class GaussianCurveFitter extends AbstractCurveFitter {
73
74 private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
75
76 @Override
77 public double value(double x, double ... p) {
78 double v = Double.POSITIVE_INFINITY;
79 try {
80 v = super.value(x, p);
81 } catch (MathIllegalArgumentException e) {
82
83 }
84 return v;
85 }
86
87
88 @Override
89 public double[] gradient(double x, double ... p) {
90 double[] v = { Double.POSITIVE_INFINITY,
91 Double.POSITIVE_INFINITY,
92 Double.POSITIVE_INFINITY };
93 try {
94 v = super.gradient(x, p);
95 } catch (MathIllegalArgumentException e) {
96
97 }
98 return v;
99 }
100 };
101
102 private final double[] initialGuess;
103
104 private final int maxIter;
105
106
107
108
109
110
111
112
113 private GaussianCurveFitter(double[] initialGuess, int maxIter) {
114 this.initialGuess = initialGuess == null ? null : initialGuess.clone();
115 this.maxIter = maxIter;
116 }
117
118
119
120
121
122
123
124
125
126
127
128
129 public static GaussianCurveFitter create() {
130 return new GaussianCurveFitter(null, Integer.MAX_VALUE);
131 }
132
133
134
135
136
137
138 public GaussianCurveFitter withStartPoint(double[] newStart) {
139 return new GaussianCurveFitter(newStart.clone(),
140 maxIter);
141 }
142
143
144
145
146
147
148 public GaussianCurveFitter withMaxIterations(int newMaxIter) {
149 return new GaussianCurveFitter(initialGuess,
150 newMaxIter);
151 }
152
153
154 @Override
155 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
156
157
158 final int len = observations.size();
159 final double[] target = new double[len];
160 final double[] weights = new double[len];
161
162 int i = 0;
163 for (WeightedObservedPoint obs : observations) {
164 target[i] = obs.getY();
165 weights[i] = obs.getWeight();
166 ++i;
167 }
168
169 final AbstractCurveFitter.TheoreticalValuesFunction model =
170 new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
171
172 final double[] startPoint = initialGuess != null ?
173 initialGuess :
174
175 new ParameterGuesser(observations).guess();
176
177
178
179 return new LeastSquaresBuilder().
180 maxEvaluations(Integer.MAX_VALUE).
181 maxIterations(maxIter).
182 start(startPoint).
183 target(target).
184 weight(new DiagonalMatrix(weights)).
185 model(model.getModelFunction(), model.getModelFunctionJacobian()).
186 build();
187
188 }
189
190
191
192
193
194
195 public static class ParameterGuesser {
196
197 private final double norm;
198
199 private final double mean;
200
201 private final double sigma;
202
203
204
205
206
207
208
209
210
211
212
213 public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
214 MathUtils.checkNotNull(observations);
215 if (observations.size() < 3) {
216 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
217 observations.size(), 3);
218 }
219
220 final List<WeightedObservedPoint> sorted = sortObservations(observations);
221 final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
222
223 norm = params[0];
224 mean = params[1];
225 sigma = params[2];
226 }
227
228
229
230
231
232
233
234
235
236
237
238 public double[] guess() {
239 return new double[] { norm, mean, sigma };
240 }
241
242
243
244
245
246
247
248 private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
249 final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
250
251 final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
252
253 @Override
254 public int compare(WeightedObservedPoint p1,
255 WeightedObservedPoint p2) {
256 if (p1 == null && p2 == null) {
257 return 0;
258 }
259 if (p1 == null) {
260 return -1;
261 }
262 if (p2 == null) {
263 return 1;
264 }
265 int comp = Double.compare(p1.getX(), p2.getX());
266 if (comp != 0) {
267 return comp;
268 }
269 comp = Double.compare(p1.getY(), p2.getY());
270 if (comp != 0) {
271 return comp;
272 }
273 comp = Double.compare(p1.getWeight(), p2.getWeight());
274 if (comp != 0) {
275 return comp;
276 }
277 return 0;
278 }
279 };
280
281 observations.sort(cmp);
282 return observations;
283 }
284
285
286
287
288
289
290
291
292 private double[] basicGuess(WeightedObservedPoint[] points) {
293 final int maxYIdx = findMaxY(points);
294 final double n = points[maxYIdx].getY();
295 final double m = points[maxYIdx].getX();
296
297 double fwhmApprox;
298 try {
299 final double halfY = n + ((m - n) / 2);
300 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
301 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
302 fwhmApprox = fwhmX2 - fwhmX1;
303 } catch (MathIllegalArgumentException e) {
304
305 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
306 }
307 final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
308
309 return new double[] { n, m, s };
310 }
311
312
313
314
315
316
317
318 private int findMaxY(WeightedObservedPoint[] points) {
319 int maxYIdx = 0;
320 for (int i = 1; i < points.length; i++) {
321 if (points[i].getY() > points[maxYIdx].getY()) {
322 maxYIdx = i;
323 }
324 }
325 return maxYIdx;
326 }
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342 private double interpolateXAtY(WeightedObservedPoint[] points,
343 int startIdx,
344 int idxStep,
345 double y)
346 throws MathIllegalArgumentException {
347 if (idxStep == 0) {
348 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
349 }
350 final WeightedObservedPoint[] twoPoints
351 = getInterpolationPointsForY(points, startIdx, idxStep, y);
352 final WeightedObservedPoint p1 = twoPoints[0];
353 final WeightedObservedPoint p2 = twoPoints[1];
354 if (p1.getY() == y) {
355 return p1.getX();
356 }
357 if (p2.getY() == y) {
358 return p2.getX();
359 }
360 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
361 (p2.getY() - p1.getY()));
362 }
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
380 int startIdx,
381 int idxStep,
382 double y)
383 throws MathIllegalArgumentException {
384 if (idxStep == 0) {
385 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
386 }
387 for (int i = startIdx;
388 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
389 i += idxStep) {
390 final WeightedObservedPoint p1 = points[i];
391 final WeightedObservedPoint p2 = points[i + idxStep];
392 if (isBetween(y, p1.getY(), p2.getY())) {
393 if (idxStep < 0) {
394 return new WeightedObservedPoint[] { p2, p1 };
395 } else {
396 return new WeightedObservedPoint[] { p1, p2 };
397 }
398 }
399 }
400
401
402
403
404 throw new MathIllegalArgumentException(LocalizedCoreFormats.OUT_OF_RANGE_SIMPLE,
405 y, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
406 }
407
408
409
410
411
412
413
414
415
416
417
418 private boolean isBetween(double value,
419 double boundary1,
420 double boundary2) {
421 return (value >= boundary1 && value <= boundary2) ||
422 (value >= boundary2 && value <= boundary1);
423 }
424 }
425 }