1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.distribution.continuous;
24
import org.hipparchus.UnitTestUtils;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.special.Gamma;
import org.hipparchus.util.FastMath;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
42
43
44
45
46 public class GammaDistributionTest extends RealDistributionAbstractTest {
47
48
49
50
51 @Override
52 public GammaDistribution makeDistribution() {
53 return new GammaDistribution(4d, 2d);
54 }
55
56
57 @Override
58 public double[] makeCumulativeTestPoints() {
59
60 return new double[] {0.857104827257, 1.64649737269, 2.17973074725, 2.7326367935, 3.48953912565,
61 26.1244815584, 20.0902350297, 17.5345461395, 15.5073130559, 13.3615661365};
62 }
63
64
65 @Override
66 public double[] makeCumulativeTestValues() {
67 return new double[] {0.001, 0.01, 0.025, 0.05, 0.1, 0.999, 0.990, 0.975, 0.950, 0.900};
68 }
69
70
71 @Override
72 public double[] makeDensityTestValues() {
73 return new double[] {0.00427280075546, 0.0204117166709, 0.0362756163658, 0.0542113174239, 0.0773195272491,
74 0.000394468852816, 0.00366559696761, 0.00874649473311, 0.0166712508128, 0.0311798227954};
75 }
76
77
78 @BeforeEach
79 @Override
80 public void setUp() {
81 super.setUp();
82 setTolerance(1e-9);
83 }
84
85
86 @Test
87 void testParameterAccessors() {
88 GammaDistribution distribution = (GammaDistribution) getDistribution();
89 assertEquals(4d, distribution.getShape(), 0);
90 assertEquals(2d, distribution.getScale(), 0);
91 }
92
93 @Test
94 void testPreconditions() {
95 try {
96 new GammaDistribution(0, 1);
97 fail("Expecting MathIllegalArgumentException for alpha = 0");
98 } catch (MathIllegalArgumentException ex) {
99
100 }
101 try {
102 new GammaDistribution(1, 0);
103 fail("Expecting MathIllegalArgumentException for alpha = 0");
104 } catch (MathIllegalArgumentException ex) {
105
106 }
107 }
108
109 @Test
110 void testProbabilities() {
111 testProbability(-1.000, 4.0, 2.0, .0000);
112 testProbability(15.501, 4.0, 2.0, .9499);
113 testProbability(0.504, 4.0, 1.0, .0018);
114 testProbability(10.011, 1.0, 2.0, .9933);
115 testProbability(5.000, 2.0, 2.0, .7127);
116 }
117
118 @Test
119 void testValues() {
120 testValue(15.501, 4.0, 2.0, .9499);
121 testValue(0.504, 4.0, 1.0, .0018);
122 testValue(10.011, 1.0, 2.0, .9933);
123 testValue(5.000, 2.0, 2.0, .7127);
124 }
125
126 private void testProbability(double x, double a, double b, double expected) {
127 GammaDistribution distribution = new GammaDistribution( a, b );
128 double actual = distribution.cumulativeProbability(x);
129 assertEquals(expected, actual, 10e-4, "probability for " + x);
130 }
131
132 private void testValue(double expected, double a, double b, double p) {
133 GammaDistribution distribution = new GammaDistribution( a, b );
134 double actual = distribution.inverseCumulativeProbability(p);
135 assertEquals(expected, actual, 10e-4, "critical value for " + p);
136 }
137
138 @Test
139 void testDensity() {
140 double[] x = new double[]{-0.1, 1e-6, 0.5, 1, 2, 5};
141
142 checkDensity(1, 1, x, new double[]{0.000000000000, 0.999999000001, 0.606530659713, 0.367879441171, 0.135335283237, 0.006737946999});
143
144 checkDensity(2, 1, x, new double[]{0.000000000000, 0.000000999999, 0.303265329856, 0.367879441171, 0.270670566473, 0.033689734995});
145
146 checkDensity(4, 1, x, new double[]{0.000000000e+00, 1.666665000e-19, 1.263605541e-02, 6.131324020e-02, 1.804470443e-01, 1.403738958e-01});
147
148 checkDensity(4, 10, x, new double[]{0.000000000e+00, 1.666650000e-15, 1.403738958e+00, 7.566654960e-02, 2.748204830e-05, 4.018228850e-17});
149
150 checkDensity(0.1, 10, x, new double[]{0.000000000e+00, 3.323953832e+04, 1.663849010e-03, 6.007786726e-06, 1.461647647e-10, 5.996008322e-24});
151
152 checkDensity(0.1, 20, x, new double[]{0.000000000e+00, 3.562489883e+04, 1.201557345e-05, 2.923295295e-10, 3.228910843e-19, 1.239484589e-45});
153
154 checkDensity(0.1, 4, x, new double[]{0.000000000e+00, 3.032938388e+04, 3.049322494e-02, 2.211502311e-03, 2.170613371e-05, 5.846590589e-11});
155
156 checkDensity(0.1, 1, x, new double[]{0.000000000e+00, 2.640334143e+04, 1.189704437e-01, 3.866916944e-02, 7.623306235e-03, 1.663849010e-04});
157 }
158
159 private void checkDensity(double alpha, double rate, double[] x, double[] expected) {
160 GammaDistribution d = new GammaDistribution(alpha, 1 / rate);
161 for (int i = 0; i < x.length; i++) {
162 assertEquals(expected[i], d.density(x[i]), 1e-5);
163 }
164 }
165
166 @Test
167 void testInverseCumulativeProbabilityExtremes() {
168 setInverseCumulativeTestPoints(new double[] {0, 1});
169 setInverseCumulativeTestValues(new double[] {0, Double.POSITIVE_INFINITY});
170 verifyInverseCumulativeProbabilities();
171 }
172
173 @Test
174 void testMoments() {
175 final double tol = 1e-9;
176 GammaDistribution dist;
177
178 dist = new GammaDistribution(1, 2);
179 assertEquals(2, dist.getNumericalMean(), tol);
180 assertEquals(4, dist.getNumericalVariance(), tol);
181
182 dist = new GammaDistribution(1.1, 4.2);
183 assertEquals(dist.getNumericalMean(), 1.1d * 4.2d, tol);
184 assertEquals(dist.getNumericalVariance(), 1.1d * 4.2d * 4.2d, tol);
185 }
186
187 private static final double HALF_LOG_2_PI = 0.5 * FastMath.log(2.0 * FastMath.PI);
188
189 public static double logGamma(double x) {
190
191
192
193
194
195 double ret;
196
197 if (Double.isNaN(x) || (x <= 0.0)) {
198 ret = Double.NaN;
199 } else {
200 double sum = Gamma.lanczos(x);
201 double tmp = x + Gamma.LANCZOS_G + .5;
202 ret = ((x + .5) * FastMath.log(tmp)) - tmp +
203 HALF_LOG_2_PI + FastMath.log(sum / x);
204 }
205
206 return ret;
207 }
208
209 public static double density(final double x, final double shape,
210 final double scale) {
211
212
213
214
215
216 if (x < 0) {
217 return 0;
218 }
219 return FastMath.pow(x / scale, shape - 1) / scale *
220 FastMath.exp(-x / scale) / FastMath.exp(logGamma(shape));
221 }
222
223
224
225
226
227
228
229
230 private void doTestMath753(final double shape,
231 final double meanNoOF, final double sdNoOF,
232 final double meanOF, final double sdOF,
233 final String resourceName) throws IOException {
234 final GammaDistribution distribution = new GammaDistribution(shape, 1.0);
235 final UnitTestUtils.SimpleStatistics statOld = new UnitTestUtils.SimpleStatistics();
236 final UnitTestUtils.SimpleStatistics statNewNoOF = new UnitTestUtils.SimpleStatistics();
237 final UnitTestUtils.SimpleStatistics statNewOF = new UnitTestUtils.SimpleStatistics();
238
239 final InputStream resourceAsStream;
240 resourceAsStream = this.getClass().getResourceAsStream(resourceName);
241 assertNotNull(resourceAsStream,
242 "Could not find resource " + resourceName);
243 final BufferedReader in;
244 in = new BufferedReader(new InputStreamReader(resourceAsStream));
245
246 try {
247 for (String line = in.readLine(); line != null; line = in.readLine()) {
248 if (line.startsWith("#")) {
249 continue;
250 }
251 final String[] tokens = line.split(", ");
252 assertEquals(2, tokens.length, "expected two floating-point values");
253 final double x = Double.parseDouble(tokens[0]);
254 final String msg = "x = " + x + ", shape = " + shape +
255 ", scale = 1.0";
256 final double expected = Double.parseDouble(tokens[1]);
257 final double ulp = FastMath.ulp(expected);
258 final double actualOld = density(x, shape, 1.0);
259 final double actualNew = distribution.density(x);
260 final double errOld, errNew;
261 errOld = FastMath.abs((actualOld - expected) / ulp);
262 errNew = FastMath.abs((actualNew - expected) / ulp);
263
264 if (Double.isNaN(actualOld) || Double.isInfinite(actualOld)) {
265 assertFalse(Double.isNaN(actualNew), msg);
266 assertFalse(Double.isInfinite(actualNew), msg);
267 statNewOF.addValue(errNew);
268 } else {
269 statOld.addValue(errOld);
270 statNewNoOF.addValue(errNew);
271 }
272 }
273 if (statOld.getN() != 0) {
274
275
276
277
278 final StringBuilder sb = new StringBuilder("shape = ");
279 sb.append(shape);
280 sb.append(", scale = 1.0\n");
281 sb.append("Old implementation\n");
282 sb.append("------------------\n");
283 sb.append(statOld.toString());
284 sb.append("New implementation\n");
285 sb.append("------------------\n");
286 sb.append(statNewNoOF.toString());
287 final String msg = sb.toString();
288
289 final double oldMin = statOld.getMin();
290 final double newMin = statNewNoOF.getMin();
291 assertTrue(newMin <= oldMin, msg);
292
293 final double oldMax = statOld.getMax();
294 final double newMax = statNewNoOF.getMax();
295 assertTrue(newMax <= oldMax, msg);
296
297 final double oldMean = statOld.getMean();
298 final double newMean = statNewNoOF.getMean();
299 assertTrue(newMean <= oldMean, msg);
300
301 final double oldSd = statOld.getStandardDeviation();
302 final double newSd = statNewNoOF.getStandardDeviation();
303 assertTrue(newSd <= oldSd, msg);
304
305 assertTrue(newMean <= meanNoOF, msg);
306 assertTrue(newSd <= sdNoOF, msg);
307 }
308 if (statNewOF.getN() != 0) {
309 final double newMean = statNewOF.getMean();
310 final double newSd = statNewOF.getStandardDeviation();
311
312 final StringBuilder sb = new StringBuilder("shape = ");
313 sb.append(shape);
314 sb.append(", scale = 1.0");
315 sb.append(", max. mean error (ulps) = ");
316 sb.append(meanOF);
317 sb.append(", actual mean error (ulps) = ");
318 sb.append(newMean);
319 sb.append(", max. sd of error (ulps) = ");
320 sb.append(sdOF);
321 sb.append(", actual sd of error (ulps) = ");
322 sb.append(newSd);
323 final String msg = sb.toString();
324
325 assertTrue(newMean <= meanOF, msg);
326 assertTrue(newSd <= sdOF, msg);
327 }
328 } catch (IOException e) {
329 fail(e.getMessage());
330 } finally {
331 in.close();
332 }
333 }
334
335
336 @Test
337 void testMath753Shape1() throws IOException {
338 doTestMath753(1.0, 1.5, 0.5, 0.0, 0.0, "gamma-distribution-shape-1.csv");
339 }
340
341 @Test
342 void testMath753Shape8() throws IOException {
343 doTestMath753(8.0, 1.5, 1.0, 0.0, 0.0, "gamma-distribution-shape-8.csv");
344 }
345
346 @Test
347 void testMath753Shape10() throws IOException {
348 doTestMath753(10.0, 1.0, 1.0, 0.0, 0.0, "gamma-distribution-shape-10.csv");
349 }
350
351 @Test
352 void testMath753Shape100() throws IOException {
353 doTestMath753(100.0, 1.5, 1.0, 0.0, 0.0, "gamma-distribution-shape-100.csv");
354 }
355
356 @Test
357 void testMath753Shape142() throws IOException {
358 doTestMath753(142.0, 3.3, 1.6, 40.0, 40.0, "gamma-distribution-shape-142.csv");
359 }
360
361 @Test
362 void testMath753Shape1000() throws IOException {
363 doTestMath753(1000.0, 1.0, 1.0, 160.0, 220.0, "gamma-distribution-shape-1000.csv");
364 }
365 }