1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.distribution.multivariate;
23
24 import org.hipparchus.exception.MathIllegalArgumentException;
25 import org.hipparchus.exception.MathRuntimeException;
26 import org.hipparchus.util.Pair;
27 import org.junit.jupiter.api.Test;
28
29 import java.util.ArrayList;
30 import java.util.List;
31
32 import static org.junit.jupiter.api.Assertions.assertEquals;
33 import static org.junit.jupiter.api.Assertions.assertThrows;
34
35
36
37
38
39
40 class MultivariateNormalMixtureModelDistributionTest {
41
42 @Test
43 void testNonUnitWeightSum() {
44 final double[] weights = { 1, 2 };
45 final double[][] means = { { -1.5, 2.0 },
46 { 4.0, 8.2 } };
47 final double[][][] covariances = { { { 2.0, -1.1 },
48 { -1.1, 2.0 } },
49 { { 3.5, 1.5 },
50 { 1.5, 3.5 } } };
51 final MultivariateNormalMixtureModelDistribution d
52 = create(weights, means, covariances);
53
54 final List<Pair<Double, MultivariateNormalDistribution>> comp = d.getComponents();
55
56 assertEquals(1d / 3, comp.get(0).getFirst().doubleValue(), Math.ulp(1d));
57 assertEquals(2d / 3, comp.get(1).getFirst().doubleValue(), Math.ulp(1d));
58 }
59
60 @Test
61 void testWeightSumOverFlow() {
62 assertThrows(MathRuntimeException.class, () -> {
63 final double[] weights = {0.5 * Double.MAX_VALUE, 0.51 * Double.MAX_VALUE};
64 final double[][] means = {{-1.5, 2.0},
65 {4.0, 8.2}};
66 final double[][][] covariances = {{{2.0, -1.1},
67 {-1.1, 2.0}},
68 {{3.5, 1.5},
69 {1.5, 3.5}}};
70 create(weights, means, covariances);
71 });
72 }
73
74 @Test
75 void testPreconditionPositiveWeights() {
76 assertThrows(MathIllegalArgumentException.class, () -> {
77 final double[] negativeWeights = {-0.5, 1.5};
78 final double[][] means = {{-1.5, 2.0},
79 {4.0, 8.2}};
80 final double[][][] covariances = {{{2.0, -1.1},
81 {-1.1, 2.0}},
82 {{3.5, 1.5},
83 {1.5, 3.5}}};
84 create(negativeWeights, means, covariances);
85 });
86 }
87
88
89
90
91 @Test
92 void testDensities() {
93 final double[] weights = { 0.3, 0.7 };
94 final double[][] means = { { -1.5, 2.0 },
95 { 4.0, 8.2 } };
96 final double[][][] covariances = { { { 2.0, -1.1 },
97 { -1.1, 2.0 } },
98 { { 3.5, 1.5 },
99 { 1.5, 3.5 } } };
100 final MultivariateNormalMixtureModelDistribution d
101 = create(weights, means, covariances);
102
103
104 final double[][] testValues = { { -1.5, 2 },
105 { 4, 8.2 },
106 { 1.5, -2 },
107 { 0, 0 } };
108
109
110
111
112
113
114 final double[] correctDensities = { 0.02862037278930575,
115 0.03523044847314091,
116 0.000416241365629767,
117 0.009932042831700297 };
118
119 for (int i = 0; i < testValues.length; i++) {
120 assertEquals(correctDensities[i], d.density(testValues[i]), Math.ulp(1d));
121 }
122 }
123
124
125
126
127 @Test
128 void testSampling() {
129 final double[] weights = { 0.3, 0.7 };
130 final double[][] means = { { -1.5, 2.0 },
131 { 4.0, 8.2 } };
132 final double[][][] covariances = { { { 2.0, -1.1 },
133 { -1.1, 2.0 } },
134 { { 3.5, 1.5 },
135 { 1.5, 3.5 } } };
136 final MultivariateNormalMixtureModelDistribution d
137 = create(weights, means, covariances);
138 d.reseedRandomGenerator(50);
139
140 final double[][] correctSamples = getCorrectSamples();
141 final int n = correctSamples.length;
142 final double[][] samples = d.sample(n);
143
144 for (int i = 0; i < n; i++) {
145 for (int j = 0; j < samples[i].length; j++) {
146 assertEquals(correctSamples[i][j], samples[i][j], 1e-16);
147 }
148 }
149 }
150
151
152
153
154
155
156
157
158
159 private MultivariateNormalMixtureModelDistribution create(double[] weights,
160 double[][] means,
161 double[][][] covariances) {
162 final List<Pair<Double, MultivariateNormalDistribution>> mvns
163 = new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
164
165 for (int i = 0; i < weights.length; i++) {
166 final MultivariateNormalDistribution dist
167 = new MultivariateNormalDistribution(means[i], covariances[i]);
168 mvns.add(new Pair<Double, MultivariateNormalDistribution>(weights[i], dist));
169 }
170
171 return new MultivariateNormalMixtureModelDistribution(mvns);
172 }
173
174
175
176
177 private double[][] getCorrectSamples() {
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195 return new double[][] {
196 { 6.259990922080121, 11.972954175355897 },
197 { -2.5296544304801847, 1.0031292519854365 },
198 { 0.49037886081440396, 0.9758251727325711 },
199 { 5.022970993312015, 9.289348879616787 },
200 { -1.686183146603914, 2.007244382745706 },
201 { -1.4729253946002685, 2.762166644212484 },
202 { 4.329788143963888, 11.514016497132253 },
203 { 3.008674596114442, 4.960246550446107 },
204 { 3.342379304090846, 5.937630105198625 },
205 { 2.6993068328674754, 7.42190871572571 },
206 { -2.446569340219571, 1.9687117791378763 },
207 { 1.922417883170056, 4.917616702617099 },
208 { -1.1969741543898518, 2.4576126277884387 },
209 { 2.4216948702967196, 8.227710158117134 },
210 { 6.701424725804463, 9.098666475042428 },
211 { 2.9890253545698964, 9.643807939324331 },
212 { 0.7162632354907799, 8.978811120287553 },
213 { -2.7548699149775877, 4.1354812280794215 },
214 { 8.304528180745018, 11.602319388898287 },
215 { -2.7633253389165926, 2.786173883989795 },
216 { 1.3322228389460813, 5.447481218602913 },
217 { -1.8120096092851508, 1.605624499560037 },
218 { 3.6546253437206504, 8.195304526564376 },
219 { -2.312349539658588, 1.868941220444169 },
220 { -1.882322136356522, 2.033795570464242 },
221 { 4.562770714939441, 7.414967958885031 },
222 { 4.731882017875329, 8.890676665580747 },
223 { 3.492186010427425, 8.9005225241848 },
224 { -1.619700190174894, 3.314060142479045 },
225 { 3.5466090064003315, 7.75182101001913 },
226 { 5.455682472787392, 8.143119287755635 },
227 { -2.3859602945473197, 1.8826732217294837 },
228 { 3.9095306088680015, 9.258129209626317 },
229 { 7.443020189508173, 7.837840713329312 },
230 { 2.136004873917428, 6.917636475958297 },
231 { -1.7203379410395119, 2.3212878757611524 },
232 { 4.618991257611526, 12.095065976419436 },
233 { -0.4837044029854387, 0.8255970441255125 },
234 { -4.438938966557163, 4.948666297280241 },
235 { -0.4539625134045906, 4.700922454655341 },
236 { 2.1285488271265356, 8.457941480487563 },
237 { 3.4873561871454393, 11.99809827845933 },
238 { 4.723049431412658, 7.813095742563365 },
239 { 1.1245583037967455, 5.20587873556688 },
240 { 1.3411933634409197, 6.069796875785409 },
241 { 4.585119332463686, 7.967669543767418 },
242 { 1.3076522817963823, -0.647431033653445 },
243 { -1.4449446442803178, 1.9400424267464862 },
244 { -2.069794456383682, 3.5824162107496544 },
245 { -0.15959481421417276, 1.5466782303315405 },
246 { -2.0823081278810136, 3.0914366458581437 },
247 { 3.521944615248141, 10.276112932926408 },
248 { 1.0164326704884257, 4.342329556442856 },
249 { 5.3718868590295275, 8.374761158360922 },
250 { 0.3673656866959396, 8.75168581694866 },
251 { -2.250268955954753, 1.4610850300996527 },
252 { -2.312739727403522, 1.5921126297576362 },
253 { 3.138993360831055, 6.7338392374947365 },
254 { 2.6978650950790115, 7.941857288979095 },
255 { 4.387985088655384, 8.253499976968 },
256 { -1.8928961721456705, 0.23631082388724223 },
257 { 4.43509029544109, 8.565290285488782 },
258 { 4.904728034106502, 5.79936660133754 },
259 { -1.7640371853739507, 2.7343727594167433 },
260 { 2.4553674733053463, 7.875871017408807 },
261 { -2.6478965122565006, 4.465127753193949 },
262 { 3.493873671142299, 10.443093773532448 },
263 { 1.1321916197409103, 7.127108479263268 },
264 { -1.7335075535240392, 2.550629648463023 },
265 { -0.9772679734368084, 4.377196298969238 },
266 { 3.6388366973980357, 6.947299283206256 },
267 { 0.27043799318823325, 6.587978599614367 },
268 { 5.356782352010253, 7.388957912116327 },
269 { -0.09187745751354681, 0.23612399246659743 },
270 { 2.903203580353435, 3.8076727621794415 },
271 { 5.297014824937293, 8.650985262326508 },
272 { 4.934508602170976, 9.164571423190052 },
273 { -1.0004911869654256, 4.797064194444461 },
274 { 6.782491700298046, 11.852373338280497 },
275 { 2.8983678524536014, 8.303837362117521 },
276 { 4.805003269830865, 6.790462904325329 },
277 { -0.8815799740744226, 1.3015810062131394 },
278 { 5.115138859802104, 6.376895810201089 },
279 { 4.301239328205988, 8.60546337560793 },
280 { 3.276423626317666, 9.889429652591947 },
281 { -4.001924973153122, 4.3353864592328515 },
282 { 3.9571892554119517, 4.500569057308562 },
283 { 4.783067027436208, 7.451125480601317 },
284 { 4.79065438272821, 9.614122776979698 },
285 { 2.677655270279617, 6.8875223698210135 },
286 { -1.3714746289327362, 2.3992153193382437 },
287 { 3.240136859745249, 7.748339397522042 },
288 { 5.107885374416291, 8.508324480583724 },
289 { -1.5830830226666048, 0.9139127045208315 },
290 { -1.1596156791652918, -0.04502759384531929 },
291 { -0.4670021307952068, 3.6193633227841624 },
292 { -0.7026065228267798, 0.4811423031997131 },
293 { -2.719979836732917, 2.5165041618080104 },
294 { 1.0336754331123372, -0.34966029029320644 },
295 { 4.743217291882213, 5.750060115251131 }
296 };
297 }
298 }
299
300
301
302
303 class MultivariateNormalMixtureModelDistribution
304 extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
305
306 public MultivariateNormalMixtureModelDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) {
307 super(components);
308 }
309 }