View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.distribution.continuous;
24  
import org.hipparchus.UnitTestUtils;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.special.Gamma;
import org.hipparchus.util.FastMath;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
42  
43  /**
44   * Test cases for GammaDistribution.
45   */
46  public class GammaDistributionTest extends RealDistributionAbstractTest {
47  
48      //-------------- Implementations for abstract methods -----------------------
49  
50      /** Creates the default continuous distribution instance to use in tests. */
51      @Override
52      public GammaDistribution makeDistribution() {
53          return new GammaDistribution(4d, 2d);
54      }
55  
56      /** Creates the default cumulative probability distribution test input values */
57      @Override
58      public double[] makeCumulativeTestPoints() {
59          // quantiles computed using R version 2.9.2
60          return new double[] {0.857104827257, 1.64649737269, 2.17973074725, 2.7326367935, 3.48953912565,
61                  26.1244815584, 20.0902350297, 17.5345461395, 15.5073130559, 13.3615661365};
62      }
63  
64      /** Creates the default cumulative probability density test expected values */
65      @Override
66      public double[] makeCumulativeTestValues() {
67          return new double[] {0.001, 0.01, 0.025, 0.05, 0.1, 0.999, 0.990, 0.975, 0.950, 0.900};
68      }
69  
70      /** Creates the default probability density test expected values */
71      @Override
72      public double[] makeDensityTestValues() {
73          return new double[] {0.00427280075546, 0.0204117166709, 0.0362756163658, 0.0542113174239, 0.0773195272491,
74                  0.000394468852816, 0.00366559696761, 0.00874649473311, 0.0166712508128, 0.0311798227954};
75      }
76  
77      // --------------------- Override tolerance  --------------
78      @BeforeEach
79      @Override
80      public void setUp() {
81          super.setUp();
82          setTolerance(1e-9);
83      }
84  
85      //---------------------------- Additional test cases -------------------------
86      @Test
87      void testParameterAccessors() {
88          GammaDistribution distribution = (GammaDistribution) getDistribution();
89          assertEquals(4d, distribution.getShape(), 0);
90          assertEquals(2d, distribution.getScale(), 0);
91      }
92  
93      @Test
94      void testPreconditions() {
95          try {
96              new GammaDistribution(0, 1);
97              fail("Expecting MathIllegalArgumentException for alpha = 0");
98          } catch (MathIllegalArgumentException ex) {
99              // Expected.
100         }
101         try {
102             new GammaDistribution(1, 0);
103             fail("Expecting MathIllegalArgumentException for alpha = 0");
104         } catch (MathIllegalArgumentException ex) {
105             // Expected.
106         }
107     }
108 
109     @Test
110     void testProbabilities() {
111         testProbability(-1.000, 4.0, 2.0, .0000);
112         testProbability(15.501, 4.0, 2.0, .9499);
113         testProbability(0.504, 4.0, 1.0, .0018);
114         testProbability(10.011, 1.0, 2.0, .9933);
115         testProbability(5.000, 2.0, 2.0, .7127);
116     }
117 
118     @Test
119     void testValues() {
120         testValue(15.501, 4.0, 2.0, .9499);
121         testValue(0.504, 4.0, 1.0, .0018);
122         testValue(10.011, 1.0, 2.0, .9933);
123         testValue(5.000, 2.0, 2.0, .7127);
124     }
125 
126     private void testProbability(double x, double a, double b, double expected) {
127         GammaDistribution distribution = new GammaDistribution( a, b );
128         double actual = distribution.cumulativeProbability(x);
129         assertEquals(expected, actual, 10e-4, "probability for " + x);
130     }
131 
132     private void testValue(double expected, double a, double b, double p) {
133         GammaDistribution distribution = new GammaDistribution( a, b );
134         double actual = distribution.inverseCumulativeProbability(p);
135         assertEquals(expected, actual, 10e-4, "critical value for " + p);
136     }
137 
138     @Test
139     void testDensity() {
140         double[] x = new double[]{-0.1, 1e-6, 0.5, 1, 2, 5};
141         // R2.5: print(dgamma(x, shape=1, rate=1), digits=10)
142         checkDensity(1, 1, x, new double[]{0.000000000000, 0.999999000001, 0.606530659713, 0.367879441171, 0.135335283237, 0.006737946999});
143         // R2.5: print(dgamma(x, shape=2, rate=1), digits=10)
144         checkDensity(2, 1, x, new double[]{0.000000000000, 0.000000999999, 0.303265329856, 0.367879441171, 0.270670566473, 0.033689734995});
145         // R2.5: print(dgamma(x, shape=4, rate=1), digits=10)
146         checkDensity(4, 1, x, new double[]{0.000000000e+00, 1.666665000e-19, 1.263605541e-02, 6.131324020e-02, 1.804470443e-01, 1.403738958e-01});
147         // R2.5: print(dgamma(x, shape=4, rate=10), digits=10)
148         checkDensity(4, 10, x, new double[]{0.000000000e+00, 1.666650000e-15, 1.403738958e+00, 7.566654960e-02, 2.748204830e-05, 4.018228850e-17});
149         // R2.5: print(dgamma(x, shape=.1, rate=10), digits=10)
150         checkDensity(0.1, 10, x, new double[]{0.000000000e+00, 3.323953832e+04, 1.663849010e-03, 6.007786726e-06, 1.461647647e-10, 5.996008322e-24});
151         // R2.5: print(dgamma(x, shape=.1, rate=20), digits=10)
152         checkDensity(0.1, 20, x, new double[]{0.000000000e+00, 3.562489883e+04, 1.201557345e-05, 2.923295295e-10, 3.228910843e-19, 1.239484589e-45});
153         // R2.5: print(dgamma(x, shape=.1, rate=4), digits=10)
154         checkDensity(0.1, 4, x, new double[]{0.000000000e+00, 3.032938388e+04, 3.049322494e-02, 2.211502311e-03, 2.170613371e-05, 5.846590589e-11});
155         // R2.5: print(dgamma(x, shape=.1, rate=1), digits=10)
156         checkDensity(0.1, 1, x, new double[]{0.000000000e+00, 2.640334143e+04, 1.189704437e-01, 3.866916944e-02, 7.623306235e-03, 1.663849010e-04});
157     }
158 
159     private void checkDensity(double alpha, double rate, double[] x, double[] expected) {
160         GammaDistribution d = new GammaDistribution(alpha, 1 / rate);
161         for (int i = 0; i < x.length; i++) {
162             assertEquals(expected[i], d.density(x[i]), 1e-5);
163         }
164     }
165 
166     @Test
167     void testInverseCumulativeProbabilityExtremes() {
168         setInverseCumulativeTestPoints(new double[] {0, 1});
169         setInverseCumulativeTestValues(new double[] {0, Double.POSITIVE_INFINITY});
170         verifyInverseCumulativeProbabilities();
171     }
172 
173     @Test
174     void testMoments() {
175         final double tol = 1e-9;
176         GammaDistribution dist;
177 
178         dist = new GammaDistribution(1, 2);
179         assertEquals(2, dist.getNumericalMean(), tol);
180         assertEquals(4, dist.getNumericalVariance(), tol);
181 
182         dist = new GammaDistribution(1.1, 4.2);
183         assertEquals(dist.getNumericalMean(), 1.1d * 4.2d, tol);
184         assertEquals(dist.getNumericalVariance(), 1.1d * 4.2d * 4.2d, tol);
185     }
186 
187     private static final double HALF_LOG_2_PI = 0.5 * FastMath.log(2.0 * FastMath.PI);
188 
189     public static double logGamma(double x) {
190         /*
191          * This is a copy of
192          * double Gamma.logGamma(double)
193          * prior to MATH-849
194          */
195         double ret;
196 
197         if (Double.isNaN(x) || (x <= 0.0)) {
198             ret = Double.NaN;
199         } else {
200             double sum = Gamma.lanczos(x);
201             double tmp = x + Gamma.LANCZOS_G + .5;
202             ret = ((x + .5) * FastMath.log(tmp)) - tmp +
203                 HALF_LOG_2_PI + FastMath.log(sum / x);
204         }
205 
206         return ret;
207     }
208 
209     public static double density(final double x, final double shape,
210                                  final double scale) {
211         /*
212          * This is a copy of
213          * double GammaDistribution.density(double)
214          * prior to MATH-753.
215          */
216         if (x < 0) {
217             return 0;
218         }
219         return FastMath.pow(x / scale, shape - 1) / scale *
220                FastMath.exp(-x / scale) / FastMath.exp(logGamma(shape));
221     }
222 
223     /*
224      * MATH-753: large values of x or shape parameter cause density(double) to
225      * overflow. Reference data is generated with the Maxima script
226      * gamma-distribution.mac, which can be found in
227      * src/test/resources/org.hipparchus/distribution.
228      */
229 
230     private void doTestMath753(final double shape,
231         final double meanNoOF, final double sdNoOF,
232         final double meanOF, final double sdOF,
233         final String resourceName) throws IOException {
234         final GammaDistribution distribution = new GammaDistribution(shape, 1.0);
235         final UnitTestUtils.SimpleStatistics statOld = new UnitTestUtils.SimpleStatistics();
236         final UnitTestUtils.SimpleStatistics statNewNoOF = new UnitTestUtils.SimpleStatistics();
237         final UnitTestUtils.SimpleStatistics statNewOF = new UnitTestUtils.SimpleStatistics();
238 
239         final InputStream resourceAsStream;
240         resourceAsStream = this.getClass().getResourceAsStream(resourceName);
241         assertNotNull(resourceAsStream,
242                              "Could not find resource " + resourceName);
243         final BufferedReader in;
244         in = new BufferedReader(new InputStreamReader(resourceAsStream));
245 
246         try {
247             for (String line = in.readLine(); line != null; line = in.readLine()) {
248                 if (line.startsWith("#")) {
249                     continue;
250                 }
251                 final String[] tokens = line.split(", ");
252                 assertEquals(2, tokens.length, "expected two floating-point values");
253                 final double x = Double.parseDouble(tokens[0]);
254                 final String msg = "x = " + x + ", shape = " + shape +
255                                    ", scale = 1.0";
256                 final double expected = Double.parseDouble(tokens[1]);
257                 final double ulp = FastMath.ulp(expected);
258                 final double actualOld = density(x, shape, 1.0);
259                 final double actualNew = distribution.density(x);
260                 final double errOld, errNew;
261                 errOld = FastMath.abs((actualOld - expected) / ulp);
262                 errNew = FastMath.abs((actualNew - expected) / ulp);
263 
264                 if (Double.isNaN(actualOld) || Double.isInfinite(actualOld)) {
265                     assertFalse(Double.isNaN(actualNew), msg);
266                     assertFalse(Double.isInfinite(actualNew), msg);
267                     statNewOF.addValue(errNew);
268                 } else {
269                     statOld.addValue(errOld);
270                     statNewNoOF.addValue(errNew);
271                 }
272             }
273             if (statOld.getN() != 0) {
274                 /*
275                  * If no overflow occurs, check that new implementation is
276                  * better than old one.
277                  */
278                 final StringBuilder sb = new StringBuilder("shape = ");
279                 sb.append(shape);
280                 sb.append(", scale = 1.0\n");
281                 sb.append("Old implementation\n");
282                 sb.append("------------------\n");
283                 sb.append(statOld.toString());
284                 sb.append("New implementation\n");
285                 sb.append("------------------\n");
286                 sb.append(statNewNoOF.toString());
287                 final String msg = sb.toString();
288 
289                 final double oldMin = statOld.getMin();
290                 final double newMin = statNewNoOF.getMin();
291                 assertTrue(newMin <= oldMin, msg);
292 
293                 final double oldMax = statOld.getMax();
294                 final double newMax = statNewNoOF.getMax();
295                 assertTrue(newMax <= oldMax, msg);
296 
297                 final double oldMean = statOld.getMean();
298                 final double newMean = statNewNoOF.getMean();
299                 assertTrue(newMean <= oldMean, msg);
300 
301                 final double oldSd = statOld.getStandardDeviation();
302                 final double newSd = statNewNoOF.getStandardDeviation();
303                 assertTrue(newSd <= oldSd, msg);
304 
305                 assertTrue(newMean <= meanNoOF, msg);
306                 assertTrue(newSd <= sdNoOF, msg);
307             }
308             if (statNewOF.getN() != 0) {
309                 final double newMean = statNewOF.getMean();
310                 final double newSd = statNewOF.getStandardDeviation();
311 
312                 final StringBuilder sb = new StringBuilder("shape = ");
313                 sb.append(shape);
314                 sb.append(", scale = 1.0");
315                 sb.append(", max. mean error (ulps) = ");
316                 sb.append(meanOF);
317                 sb.append(", actual mean error (ulps) = ");
318                 sb.append(newMean);
319                 sb.append(", max. sd of error (ulps) = ");
320                 sb.append(sdOF);
321                 sb.append(", actual sd of error (ulps) = ");
322                 sb.append(newSd);
323                 final String msg = sb.toString();
324 
325                 assertTrue(newMean <= meanOF, msg);
326                 assertTrue(newSd <= sdOF, msg);
327             }
328         } catch (IOException e) {
329             fail(e.getMessage());
330         } finally {
331             in.close();
332         }
333     }
334 
335 
336     @Test
337     void testMath753Shape1() throws IOException {
338         doTestMath753(1.0, 1.5, 0.5, 0.0, 0.0, "gamma-distribution-shape-1.csv");
339     }
340 
341     @Test
342     void testMath753Shape8() throws IOException {
343         doTestMath753(8.0, 1.5, 1.0, 0.0, 0.0, "gamma-distribution-shape-8.csv");
344     }
345 
346     @Test
347     void testMath753Shape10() throws IOException {
348         doTestMath753(10.0, 1.0, 1.0, 0.0, 0.0, "gamma-distribution-shape-10.csv");
349     }
350 
351     @Test
352     void testMath753Shape100() throws IOException {
353         doTestMath753(100.0, 1.5, 1.0, 0.0, 0.0, "gamma-distribution-shape-100.csv");
354     }
355 
356     @Test
357     void testMath753Shape142() throws IOException {
358         doTestMath753(142.0, 3.3, 1.6, 40.0, 40.0, "gamma-distribution-shape-142.csv");
359     }
360 
361     @Test
362     void testMath753Shape1000() throws IOException {
363         doTestMath753(1000.0, 1.0, 1.0, 160.0, 220.0, "gamma-distribution-shape-1000.csv");
364     }
365 }