View Javadoc
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
/*
 * This is not the original file distributed by the Apache Software Foundation.
 * It has been modified by the Hipparchus project.
 */
22  package org.hipparchus.fitting;
23  
24  import org.hipparchus.exception.MathIllegalArgumentException;
25  import org.hipparchus.exception.MathIllegalStateException;
26  import org.junit.jupiter.api.Test;
27  
28  import static org.junit.jupiter.api.Assertions.assertEquals;
29  import static org.junit.jupiter.api.Assertions.assertThrows;
30  
31  /**
32   * Tests {@link GaussianCurveFitter}.
33   *
34   */
35  class GaussianCurveFitterTest {
36      /** Good data. */
37      protected static final double[][] DATASET1 = new double[][] {
38          {4.0254623,  531026.0},
39          {4.02804905, 664002.0},
40          {4.02934242, 787079.0},
41          {4.03128248, 984167.0},
42          {4.03386923, 1294546.0},
43          {4.03580929, 1560230.0},
44          {4.03839603, 1887233.0},
45          {4.0396894,  2113240.0},
46          {4.04162946, 2375211.0},
47          {4.04421621, 2687152.0},
48          {4.04550958, 2862644.0},
49          {4.04744964, 3078898.0},
50          {4.05003639, 3327238.0},
51          {4.05132976, 3461228.0},
52          {4.05326982, 3580526.0},
53          {4.05585657, 3576946.0},
54          {4.05779662, 3439750.0},
55          {4.06038337, 3220296.0},
56          {4.06167674, 3070073.0},
57          {4.0636168,  2877648.0},
58          {4.06620355, 2595848.0},
59          {4.06749692, 2390157.0},
60          {4.06943698, 2175960.0},
61          {4.07202373, 1895104.0},
62          {4.0733171,  1687576.0},
63          {4.07525716, 1447024.0},
64          {4.0778439,  1130879.0},
65          {4.07978396, 904900.0},
66          {4.08237071, 717104.0},
67          {4.08366408, 620014.0}
68      };
69      /** Poor data: right of peak not symmetric with left of peak. */
70      protected static final double[][] DATASET2 = new double[][] {
71          {-20.15,   1523.0},
72          {-19.65,   1566.0},
73          {-19.15,   1592.0},
74          {-18.65,   1927.0},
75          {-18.15,   3089.0},
76          {-17.65,   6068.0},
77          {-17.15,  14239.0},
78          {-16.65,  34124.0},
79          {-16.15,  64097.0},
80          {-15.65, 110352.0},
81          {-15.15, 164742.0},
82          {-14.65, 209499.0},
83          {-14.15, 267274.0},
84          {-13.65, 283290.0},
85          {-13.15, 275363.0},
86          {-12.65, 258014.0},
87          {-12.15, 225000.0},
88          {-11.65, 200000.0},
89          {-11.15, 190000.0},
90          {-10.65, 185000.0},
91          {-10.15, 180000.0},
92          { -9.65, 179000.0},
93          { -9.15, 178000.0},
94          { -8.65, 177000.0},
95          { -8.15, 176000.0},
96          { -7.65, 175000.0},
97          { -7.15, 174000.0},
98          { -6.65, 173000.0},
99          { -6.15, 172000.0},
100         { -5.65, 171000.0},
101         { -5.15, 170000.0}
102     };
103     /** Poor data: long tails. */
104     protected static final double[][] DATASET3 = new double[][] {
105         {-90.15,   1513.0},
106         {-80.15,   1514.0},
107         {-70.15,   1513.0},
108         {-60.15,   1514.0},
109         {-50.15,   1513.0},
110         {-40.15,   1514.0},
111         {-30.15,   1513.0},
112         {-20.15,   1523.0},
113         {-19.65,   1566.0},
114         {-19.15,   1592.0},
115         {-18.65,   1927.0},
116         {-18.15,   3089.0},
117         {-17.65,   6068.0},
118         {-17.15,  14239.0},
119         {-16.65,  34124.0},
120         {-16.15,  64097.0},
121         {-15.65, 110352.0},
122         {-15.15, 164742.0},
123         {-14.65, 209499.0},
124         {-14.15, 267274.0},
125         {-13.65, 283290.0},
126         {-13.15, 275363.0},
127         {-12.65, 258014.0},
128         {-12.15, 214073.0},
129         {-11.65, 182244.0},
130         {-11.15, 136419.0},
131         {-10.65,  97823.0},
132         {-10.15,  58930.0},
133         { -9.65,  35404.0},
134         { -9.15,  16120.0},
135         { -8.65,   9823.0},
136         { -8.15,   5064.0},
137         { -7.65,   2575.0},
138         { -7.15,   1642.0},
139         { -6.65,   1101.0},
140         { -6.15,    812.0},
141         { -5.65,    690.0},
142         { -5.15,    565.0},
143         {  5.15,    564.0},
144         { 15.15,    565.0},
145         { 25.15,    564.0},
146         { 35.15,    565.0},
147         { 45.15,    564.0},
148         { 55.15,    565.0},
149         { 65.15,    564.0},
150         { 75.15,    565.0}
151     };
152     /** Poor data: right of peak is missing. */
153     protected static final double[][] DATASET4 = new double[][] {
154         {-20.15,   1523.0},
155         {-19.65,   1566.0},
156         {-19.15,   1592.0},
157         {-18.65,   1927.0},
158         {-18.15,   3089.0},
159         {-17.65,   6068.0},
160         {-17.15,  14239.0},
161         {-16.65,  34124.0},
162         {-16.15,  64097.0},
163         {-15.65, 110352.0},
164         {-15.15, 164742.0},
165         {-14.65, 209499.0},
166         {-14.15, 267274.0},
167         {-13.65, 283290.0}
168     };
169     /** Good data, but few points. */
170     protected static final double[][] DATASET5 = new double[][] {
171         {4.0254623,  531026.0},
172         {4.03128248, 984167.0},
173         {4.03839603, 1887233.0},
174         {4.04421621, 2687152.0},
175         {4.05132976, 3461228.0},
176         {4.05326982, 3580526.0},
177         {4.05779662, 3439750.0},
178         {4.0636168,  2877648.0},
179         {4.06943698, 2175960.0},
180         {4.07525716, 1447024.0},
181         {4.08237071, 717104.0},
182         {4.08366408, 620014.0}
183     };
184 
185     /**
186      * Basic.
187      */
188     @Test
189     void testFit01() {
190         GaussianCurveFitter fitter = GaussianCurveFitter.create();
191         double[] parameters = fitter.fit(createDataset(DATASET1).toList());
192 
193         assertEquals(3496978.1837704973, parameters[0], 1e-4);
194         assertEquals(4.054933085999146, parameters[1], 1e-4);
195         assertEquals(0.015039355620304326, parameters[2], 1e-4);
196     }
197 
198     @Test
199     void testWithMaxIterations1() {
200         final int maxIter = 20;
201         final double[] init = { 3.5e6, 4.2, 0.1 };
202 
203         GaussianCurveFitter fitter = GaussianCurveFitter.create();
204         double[] parameters = fitter
205             .withMaxIterations(maxIter)
206             .withStartPoint(init)
207             .fit(createDataset(DATASET1).toList());
208 
209         assertEquals(3496978.1837704973, parameters[0], 1e-2);
210         assertEquals(4.054933085999146, parameters[1], 1e-4);
211         assertEquals(0.015039355620304326, parameters[2], 1e-4);
212     }
213 
214     @Test
215     void testWithMaxIterations2() {
216         assertThrows(MathIllegalStateException.class, () -> {
217             final int maxIter = 1; // Too few iterations.
218             final double[] init = {3.5e6, 4.2, 0.1};
219 
220             GaussianCurveFitter fitter = GaussianCurveFitter.create();
221             fitter.withMaxIterations(maxIter)
222                 .withStartPoint(init)
223                 .fit(createDataset(DATASET1).toList());
224         });
225     }
226 
227     @Test
228     void testWithStartPoint() {
229         final double[] init = { 3.5e6, 4.2, 0.1 };
230 
231         GaussianCurveFitter fitter = GaussianCurveFitter.create();
232         double[] parameters = fitter
233             .withStartPoint(init)
234             .fit(createDataset(DATASET1).toList());
235 
236         assertEquals(3496978.1837704973, parameters[0], 1e-2);
237         assertEquals(4.054933085999146, parameters[1], 1e-4);
238         assertEquals(0.015039355620304326, parameters[2], 1e-4);
239     }
240 
241     /**
242      * Zero points is not enough observed points.
243      */
244     @Test
245     void testFit02() {
246         assertThrows(MathIllegalArgumentException.class, () -> {
247             GaussianCurveFitter.create().fit(new WeightedObservedPoints().toList());
248         });
249     }
250 
251     /**
252      * Two points is not enough observed points.
253      */
254     @Test
255     void testFit03() {
256         assertThrows(MathIllegalArgumentException.class, () -> {
257             GaussianCurveFitter fitter = GaussianCurveFitter.create();
258             fitter.fit(createDataset(new double[][]{
259                 {4.0254623, 531026.0},
260                 {4.02804905, 664002.0}
261             }).toList());
262         });
263     }
264 
265     /**
266      * Poor data: right of peak not symmetric with left of peak.
267      */
268     @Test
269     void testFit04() {
270         GaussianCurveFitter fitter = GaussianCurveFitter.create();
271         double[] parameters = fitter.fit(createDataset(DATASET2).toList());
272 
273         assertEquals(233003.2967252038, parameters[0], 1e-4);
274         assertEquals(-10.654887521095983, parameters[1], 1e-4);
275         assertEquals(4.335937353196641, parameters[2], 1e-4);
276     }
277 
278     /**
279      * Poor data: long tails.
280      */
281     @Test
282     void testFit05() {
283         GaussianCurveFitter fitter = GaussianCurveFitter.create();
284         double[] parameters = fitter.fit(createDataset(DATASET3).toList());
285 
286         assertEquals(283863.81929180305, parameters[0], 1e-4);
287         assertEquals(-13.29641995105174, parameters[1], 1e-4);
288         assertEquals(1.7297330293549908, parameters[2], 1e-4);
289     }
290 
291     /**
292      * Poor data: right of peak is missing.
293      */
294     @Test
295     void testFit06() {
296         GaussianCurveFitter fitter = GaussianCurveFitter.create();
297         double[] parameters = fitter.fit(createDataset(DATASET4).toList());
298 
299         assertEquals(285250.66754309234, parameters[0], 1e-4);
300         assertEquals(-13.528375695228455, parameters[1], 1e-4);
301         assertEquals(1.5204344894331614, parameters[2], 1e-4);
302     }
303 
304     /**
305      * Basic with smaller dataset.
306      */
307     @Test
308     void testFit07() {
309         GaussianCurveFitter fitter = GaussianCurveFitter.create();
310         double[] parameters = fitter.fit(createDataset(DATASET5).toList());
311 
312         assertEquals(3514384.729342235, parameters[0], 1e-4);
313         assertEquals(4.054970307455625, parameters[1], 1e-4);
314         assertEquals(0.015029412832160017, parameters[2], 1e-4);
315     }
316 
317     @Test
318     void testMath519() {
319         // The optimizer will try negative sigma values but "GaussianCurveFitter"
320         // will catch the raised exceptions and return NaN values instead.
321 
322         final double[] data = {
323             1.1143831578403364E-29,
324             4.95281403484594E-28,
325             1.1171347211930288E-26,
326             1.7044813962636277E-25,
327             1.9784716574832164E-24,
328             1.8630236407866774E-23,
329             1.4820532905097742E-22,
330             1.0241963854632831E-21,
331             6.275077366673128E-21,
332             3.461808994532493E-20,
333             1.7407124684715706E-19,
334             8.056687953553974E-19,
335             3.460193945992071E-18,
336             1.3883326374011525E-17,
337             5.233894983671116E-17,
338             1.8630791465263745E-16,
339             6.288759227922111E-16,
340             2.0204433920597856E-15,
341             6.198768938576155E-15,
342             1.821419346860626E-14,
343             5.139176445538471E-14,
344             1.3956427429045787E-13,
345             3.655705706448139E-13,
346             9.253753324779779E-13,
347             2.267636001476696E-12,
348             5.3880460095836855E-12,
349             1.2431632654852931E-11
350         };
351 
352         final WeightedObservedPoints obs = new WeightedObservedPoints();
353         for (int i = 0; i < data.length; i++) {
354             obs.add(i, data[i]);
355         }
356         final double[] p = GaussianCurveFitter.create().fit(obs.toList());
357 
358         assertEquals(53.1572792, p[1], 1e-7);
359         assertEquals(5.75214622, p[2], 1e-8);
360     }
361 
362     @Test
363     void testMath798() {
364         // When the data points are not commented out below, the fit stalls.
365         // This is expected however, since the whole dataset hardly looks like
366         // a Gaussian.
367         // When commented out, the fit proceeds fine.
368 
369         final WeightedObservedPoints obs = new WeightedObservedPoints();
370 
371         obs.add(0.23, 395.0);
372         //obs.add(0.68, 0.0);
373         obs.add(1.14, 376.0);
374         //obs.add(1.59, 0.0);
375         obs.add(2.05, 163.0);
376         //obs.add(2.50, 0.0);
377         obs.add(2.95, 49.0);
378         //obs.add(3.41, 0.0);
379         obs.add(3.86, 16.0);
380         //obs.add(4.32, 0.0);
381         obs.add(4.77, 1.0);
382 
383         final double[] p = GaussianCurveFitter.create().fit(obs.toList());
384 
385         // Values are copied from a previous run of this test.
386         assertEquals(420.8397296167364, p[0], 1e-12);
387         assertEquals(0.603770729862231, p[1], 1e-15);
388         assertEquals(1.0786447936766612, p[2], 1e-14);
389     }
390 
391     /**
392      * Adds the specified points to specified <code>GaussianCurveFitter</code>
393      * instance.
394      *
395      * @param points Data points where first dimension is a point index and
396      *        second dimension is an array of length two representing the point
397      *        with the first value corresponding to X and the second value
398      *        corresponding to Y.
399      * @return the collection of observed points.
400      */
401     private static WeightedObservedPoints createDataset(double[][] points) {
402         final WeightedObservedPoints obs = new WeightedObservedPoints();
403         for (int i = 0; i < points.length; i++) {
404             obs.add(points[i][0], points[i][1]);
405         }
406         return obs;
407     }
408 }