1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.distribution.discrete;
24
25 import org.hipparchus.exception.LocalizedCoreFormats;
26 import org.hipparchus.exception.MathIllegalArgumentException;
27 import org.hipparchus.util.FastMath;
28
29
30
31
32
33
34
35 public class HypergeometricDistribution extends AbstractIntegerDistribution {
36
37 private static final long serialVersionUID = 20160320L;
38
39 private final int numberOfSuccesses;
40
41 private final int populationSize;
42
43 private final int sampleSize;
44
45 private final double numericalVariance;
46
47
48
49
50
51
52
53
54
55
56
57
58
59 public HypergeometricDistribution(int populationSize, int numberOfSuccesses, int sampleSize)
60 throws MathIllegalArgumentException {
61 if (populationSize <= 0) {
62 throw new MathIllegalArgumentException(LocalizedCoreFormats.POPULATION_SIZE,
63 populationSize);
64 }
65 if (numberOfSuccesses < 0) {
66 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_OF_SUCCESSES,
67 numberOfSuccesses);
68 }
69 if (sampleSize < 0) {
70 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_OF_SAMPLES,
71 sampleSize);
72 }
73
74 if (numberOfSuccesses > populationSize) {
75 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_OF_SUCCESS_LARGER_THAN_POPULATION_SIZE,
76 numberOfSuccesses, populationSize, true);
77 }
78 if (sampleSize > populationSize) {
79 throw new MathIllegalArgumentException(LocalizedCoreFormats.SAMPLE_SIZE_LARGER_THAN_POPULATION_SIZE,
80 sampleSize, populationSize, true);
81 }
82
83 this.numberOfSuccesses = numberOfSuccesses;
84 this.populationSize = populationSize;
85 this.sampleSize = sampleSize;
86 this.numericalVariance = calculateNumericalVariance();
87 }
88
89
90 @Override
91 public double cumulativeProbability(int x) {
92 double ret;
93
94 int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
95 if (x < domain[0]) {
96 ret = 0.0;
97 } else if (x >= domain[1]) {
98 ret = 1.0;
99 } else {
100 ret = innerCumulativeProbability(domain[0], x, 1);
101 }
102
103 return ret;
104 }
105
106
107
108
109
110
111
112
113
114
115 private int[] getDomain(int n, int m, int k) {
116 return new int[] { getLowerDomain(n, m, k), getUpperDomain(m, k) };
117 }
118
119
120
121
122
123
124
125
126
127
128 private int getLowerDomain(int n, int m, int k) {
129 return FastMath.max(0, m - (n - k));
130 }
131
132
133
134
135
136
137 public int getNumberOfSuccesses() {
138 return numberOfSuccesses;
139 }
140
141
142
143
144
145
146 public int getPopulationSize() {
147 return populationSize;
148 }
149
150
151
152
153
154
155 public int getSampleSize() {
156 return sampleSize;
157 }
158
159
160
161
162
163
164
165
166
167 private int getUpperDomain(int m, int k) {
168 return FastMath.min(k, m);
169 }
170
171
172 @Override
173 public double probability(int x) {
174 final double logProbability = logProbability(x);
175 return logProbability == Double.NEGATIVE_INFINITY ? 0 : FastMath.exp(logProbability);
176 }
177
178
179 @Override
180 public double logProbability(int x) {
181 double ret;
182
183 int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
184 if (x < domain[0] || x > domain[1]) {
185 ret = Double.NEGATIVE_INFINITY;
186 } else {
187 double p = ((double) sampleSize) / populationSize;
188 double q = ((double) (populationSize - sampleSize)) / populationSize;
189 double p1 = SaddlePointExpansion.logBinomialProbability(x, numberOfSuccesses, p, q);
190 double p2 = SaddlePointExpansion.logBinomialProbability(sampleSize - x, populationSize - numberOfSuccesses, p, q);
191 double p3 = SaddlePointExpansion.logBinomialProbability(sampleSize, populationSize, p, q);
192 ret = p1 + p2 - p3;
193 }
194
195 return ret;
196 }
197
198
199
200
201
202
203
204 public double upperCumulativeProbability(int x) {
205 double ret;
206
207 final int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
208 if (x <= domain[0]) {
209 ret = 1.0;
210 } else if (x > domain[1]) {
211 ret = 0.0;
212 } else {
213 ret = innerCumulativeProbability(domain[1], x, -1);
214 }
215
216 return ret;
217 }
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232 private double innerCumulativeProbability(int x0, int x1, int dx) {
233 double ret = probability(x0);
234 while (x0 != x1) {
235 x0 += dx;
236 ret += probability(x0);
237 }
238 return ret;
239 }
240
241
242
243
244
245
246
247 @Override
248 public double getNumericalMean() {
249 return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
250 }
251
252
253
254
255
256
257
258
259 @Override
260 public double getNumericalVariance() {
261 return numericalVariance;
262 }
263
264
265
266
267
268
269 private double calculateNumericalVariance() {
270 final double N = getPopulationSize();
271 final double m = getNumberOfSuccesses();
272 final double n = getSampleSize();
273 return (n * m * (N - n) * (N - m)) / (N * N * (N - 1));
274 }
275
276
277
278
279
280
281
282
283
284
285 @Override
286 public int getSupportLowerBound() {
287 return FastMath.max(0,
288 getSampleSize() + getNumberOfSuccesses() - getPopulationSize());
289 }
290
291
292
293
294
295
296
297
298
299 @Override
300 public int getSupportUpperBound() {
301 return FastMath.min(getNumberOfSuccesses(), getSampleSize());
302 }
303
304
305
306
307
308
309
310
311 @Override
312 public boolean isSupportConnected() {
313 return true;
314 }
315 }