1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.hipparchus.stat.fitting;
18
19 import org.hipparchus.distribution.multivariate.MixtureMultivariateNormalDistribution;
20 import org.hipparchus.distribution.multivariate.MultivariateNormalDistribution;
21 import org.hipparchus.exception.LocalizedCoreFormats;
22 import org.hipparchus.exception.MathIllegalArgumentException;
23 import org.hipparchus.exception.MathIllegalStateException;
24 import org.hipparchus.linear.Array2DRowRealMatrix;
25 import org.hipparchus.linear.RealMatrix;
26 import org.hipparchus.util.Pair;
27 import org.junit.jupiter.api.Test;
28
29 import java.lang.reflect.Constructor;
30 import java.lang.reflect.InvocationTargetException;
31 import java.util.ArrayList;
32 import java.util.Arrays;
33 import java.util.List;
34
35 import static org.junit.jupiter.api.Assertions.assertEquals;
36 import static org.junit.jupiter.api.Assertions.assertNotEquals;
37 import static org.junit.jupiter.api.Assertions.assertThrows;
38 import static org.junit.jupiter.api.Assertions.assertTrue;
39 import static org.junit.jupiter.api.Assertions.fail;
40
41
42
43
44
45 class MultivariateNormalMixtureExpectationMaximizationTest {
46
47 @Test
48 void testNonEmptyData() {
49 assertThrows(MathIllegalArgumentException.class, () -> {
50
51 new MultivariateNormalMixtureExpectationMaximization(new double[][]{});
52 });
53 }
54
55 @Test
56 void testNonJaggedData() {
57 assertThrows(MathIllegalArgumentException.class, () -> {
58
59 double[][] data = new double[][]{
60 {1, 2, 3},
61 {4, 5, 6, 7},
62 };
63 new MultivariateNormalMixtureExpectationMaximization(data);
64 });
65 }
66
67 @Test
68 void testMultipleColumnsRequired() {
69 assertThrows(MathIllegalArgumentException.class, () -> {
70
71 double[][] data = new double[][]{
72 {1}, {2}
73 };
74 new MultivariateNormalMixtureExpectationMaximization(data);
75 });
76 }
77
78 @Test
79 void testMaxIterationsPositive() {
80 assertThrows(MathIllegalArgumentException.class, () -> {
81
82 double[][] data = getTestSamples();
83 MultivariateNormalMixtureExpectationMaximization fitter =
84 new MultivariateNormalMixtureExpectationMaximization(data);
85
86 MixtureMultivariateNormalDistribution
87 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
88
89 fitter.fit(initialMix, 0, 1E-5);
90 });
91 }
92
93 @Test
94 void testThresholdPositive() {
95 assertThrows(MathIllegalArgumentException.class, () -> {
96
97 double[][] data = getTestSamples();
98 MultivariateNormalMixtureExpectationMaximization fitter =
99 new MultivariateNormalMixtureExpectationMaximization(
100 data);
101
102 MixtureMultivariateNormalDistribution
103 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
104
105 fitter.fit(initialMix, 1000, 0);
106 });
107 }
108
109 @Test
110 void testConvergenceException() {
111 assertThrows(MathIllegalStateException.class, () -> {
112
113 double[][] data = getTestSamples();
114 MultivariateNormalMixtureExpectationMaximization fitter
115 = new MultivariateNormalMixtureExpectationMaximization(data);
116
117 MixtureMultivariateNormalDistribution
118 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
119
120
121 fitter.fit(initialMix, 5, 1E-5);
122 });
123 }
124
125 @Test
126 void testIncompatibleIntialMixture() {
127 assertThrows(MathIllegalArgumentException.class, () -> {
128
129 double[][] data = new double[][]{
130 {1, 2, 3}, {4, 5, 6}, {7, 8, 9}
131 };
132 double[] weights = new double[]{0.5, 0.5};
133
134
135
136 MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[2];
137
138 mvns[0] = new MultivariateNormalDistribution(new double[]{
139 -0.0021722935000328823, 3.5432892936887908},
140 new double[][]{
141 {4.537422569229048, 3.5266152281729304},
142 {3.5266152281729304, 6.175448814169779}});
143 mvns[1] = new MultivariateNormalDistribution(new double[]{
144 5.090902706507635, 8.68540656355283}, new double[][]{
145 {2.886778573963039, 1.5257474543463154},
146 {1.5257474543463154, 3.3794567673616918}});
147
148
149 List<Pair<Double, MultivariateNormalDistribution>> components =
150 new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
151 components.add(new Pair<Double, MultivariateNormalDistribution>(
152 weights[0], mvns[0]));
153 components.add(new Pair<Double, MultivariateNormalDistribution>(
154 weights[1], mvns[1]));
155
156 MixtureMultivariateNormalDistribution badInitialMix
157 = new MixtureMultivariateNormalDistribution(components);
158
159 MultivariateNormalMixtureExpectationMaximization fitter
160 = new MultivariateNormalMixtureExpectationMaximization(data);
161
162 fitter.fit(badInitialMix);
163 });
164 }
165
166 @Test
167 void testInitialMixture() {
168
169 final double[] correctWeights = new double[] { 0.5, 0.5 };
170
171 final double[][] correctMeans = new double[][] {
172 {-0.0021722935000328823, 3.5432892936887908},
173 {5.090902706507635, 8.68540656355283},
174 };
175
176 final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
177
178 correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
179 { 4.537422569229048, 3.5266152281729304 },
180 { 3.5266152281729304, 6.175448814169779 } });
181
182 correctCovMats[1] = new Array2DRowRealMatrix( new double[][] {
183 { 2.886778573963039, 1.5257474543463154 },
184 { 1.5257474543463154, 3.3794567673616918 } });
185
186 final MultivariateNormalDistribution[] correctMVNs = new
187 MultivariateNormalDistribution[2];
188
189 correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0],
190 correctCovMats[0].getData());
191
192 correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1],
193 correctCovMats[1].getData());
194
195 final MixtureMultivariateNormalDistribution initialMix
196 = MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
197
198 int i = 0;
199 for (Pair<Double, MultivariateNormalDistribution> component : initialMix
200 .getComponents()) {
201 assertEquals(correctWeights[i], component.getFirst(),
202 Math.ulp(1d));
203
204 final double[] means = component.getValue().getMeans();
205 assertTrue(Arrays.equals(correctMeans[i], means));
206
207 final RealMatrix covMat = component.getValue().getCovariances();
208 assertEquals(correctCovMats[i], covMat);
209 i++;
210 }
211 }
212
213 @Test
214 void testWrongData() {
215 checkWrongData(new double[1][1], 2, LocalizedCoreFormats.NUMBER_TOO_SMALL);
216 checkWrongData(new double[3][3], 1, LocalizedCoreFormats.NUMBER_TOO_SMALL);
217 checkWrongData(new double[3][3], 4, LocalizedCoreFormats.NUMBER_TOO_LARGE);
218 }
219
220 private void checkWrongData(final double[][] data, final int numComponents,
221 final LocalizedCoreFormats expected) {
222 try {
223 MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
224 fail("an exception should have been thrown");
225 } catch (MathIllegalArgumentException miae) {
226 assertEquals(expected, miae.getSpecifier());
227 }
228 }
229
230 @Test
231 void testUnusedInheritedMethods() {
232
233 try {
234 Class<?> dataRowClass = MultivariateNormalMixtureExpectationMaximization.class.getDeclaredClasses()[0];
235 Constructor<?> dataRowConstructor = dataRowClass.getDeclaredConstructor(double[].class);
236 dataRowConstructor.setAccessible(true);
237 Object dr1 = dataRowConstructor.newInstance(new double[] { 1, 2, 3 });
238 assertEquals(66614367, dr1.hashCode());
239 assertEquals(dr1, dr1);
240 assertNotEquals("", dr1);
241 assertEquals(dr1, dataRowConstructor.newInstance(new double[]{1, 2, 3}));
242 assertNotEquals(dr1, dataRowConstructor.newInstance(new double[]{3, 2, 1}));
243 } catch (InvocationTargetException | NoSuchMethodException | SecurityException |
244 InstantiationException | IllegalAccessException | IllegalArgumentException e) {
245 fail(e.getLocalizedMessage());
246 }
247 }
248
249 @Test
250 void testFit() {
251
252
253 final double[][] data = getTestSamples();
254 final double correctLogLikelihood = -4.292431006791994;
255 final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
256
257 final double[][] correctMeans = new double[][]{
258 {-1.4213112715121132, 1.6924690505757753},
259 {4.213612224374709, 7.975621325853645}
260 };
261
262 final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
263 correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
264 { 1.739356907285747, -0.5867644251487614 },
265 { -0.5867644251487614, 1.0232932029324642 } }
266 );
267 correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
268 { 4.245384898007161, 2.5797798966382155 },
269 { 2.5797798966382155, 3.9200272522448367 } });
270
271 final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
272 correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
273 correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
274
275 MultivariateNormalMixtureExpectationMaximization fitter
276 = new MultivariateNormalMixtureExpectationMaximization(data);
277
278 MixtureMultivariateNormalDistribution initialMix
279 = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
280 fitter.fit(initialMix);
281 MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
282 List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
283
284 assertEquals(correctLogLikelihood,
285 fitter.getLogLikelihood(),
286 Math.ulp(1d));
287
288 int i = 0;
289 for (Pair<Double, MultivariateNormalDistribution> component : components) {
290 final double weight = component.getFirst();
291 final MultivariateNormalDistribution mvn = component.getSecond();
292 final double[] mean = mvn.getMeans();
293 final RealMatrix covMat = mvn.getCovariances();
294 assertEquals(correctWeights[i], weight, Math.ulp(1d));
295 assertTrue(Arrays.equals(correctMeans[i], mean));
296 assertEquals(correctCovMats[i], covMat);
297 i++;
298 }
299 }
300
301 private double[][] getTestSamples() {
302
303
304 return new double[][] { { 7.358553610469948, 11.31260831446758 },
305 { 7.175770420124739, 8.988812210204454 },
306 { 4.324151905768422, 6.837727899051482 },
307 { 2.157832219173036, 6.317444585521968 },
308 { -1.890157421896651, 1.74271202875498 },
309 { 0.8922409354455803, 1.999119343923781 },
310 { 3.396949764787055, 6.813170372579068 },
311 { -2.057498232686068, -0.002522983830852255 },
312 { 6.359932157365045, 8.343600029975851 },
313 { 3.353102234276168, 7.087541882898689 },
314 { -1.763877221595639, 0.9688890460330644 },
315 { 6.151457185125111, 9.075011757431174 },
316 { 4.281597398048899, 5.953270070976117 },
317 { 3.549576703974894, 8.616038155992861 },
318 { 6.004706732349854, 8.959423391087469 },
319 { 2.802915014676262, 6.285676742173564 },
320 { -0.6029879029880616, 1.083332958357485 },
321 { 3.631827105398369, 6.743428504049444 },
322 { 6.161125014007315, 9.60920569689001 },
323 { -1.049582894255342, 0.2020017892080281 },
324 { 3.910573022688315, 8.19609909534937 },
325 { 8.180454017634863, 7.861055769719962 },
326 { 1.488945440439716, 8.02699903761247 },
327 { 4.813750847823778, 12.34416881332515 },
328 { 0.0443208501259158, 5.901148093240691 },
329 { 4.416417235068346, 4.465243084006094 },
330 { 4.0002433603072, 6.721937850166174 },
331 { 3.190113818788205, 10.51648348411058 },
332 { 4.493600914967883, 7.938224231022314 },
333 { -3.675669533266189, 4.472845076673303 },
334 { 6.648645511703989, 12.03544085965724 },
335 { -1.330031331404445, 1.33931042964811 },
336 { -3.812111460708707, 2.50534195568356 },
337 { 5.669339356648331, 6.214488981177026 },
338 { 1.006596727153816, 1.51165463112716 },
339 { 5.039466365033024, 7.476532610478689 },
340 { 4.349091929968925, 7.446356406259756 },
341 { -1.220289665119069, 3.403926955951437 },
342 { 5.553003979122395, 6.886518211202239 },
343 { 2.274487732222856, 7.009541508533196 },
344 { 4.147567059965864, 7.34025244349202 },
345 { 4.083882618965819, 6.362852861075623 },
346 { 2.203122344647599, 7.260295257904624 },
347 { -2.147497550770442, 1.262293431529498 },
348 { 2.473700950426512, 6.558900135505638 },
349 { 8.267081298847554, 12.10214104577748 },
350 { 6.91977329776865, 9.91998488301285 },
351 { 0.1680479852730894, 6.28286034168897 },
352 { -1.268578659195158, 2.326711221485755 },
353 { 1.829966451374701, 6.254187605304518 },
354 { 5.648849025754848, 9.330002040750291 },
355 { -2.302874793257666, 3.585545172776065 },
356 { -2.629218791709046, 2.156215538500288 },
357 { 4.036618140700114, 10.2962785719958 },
358 { 0.4616386422783874, 0.6782756325806778 },
359 { -0.3447896073408363, 0.4999834691645118 },
360 { -0.475281453118318, 1.931470384180492 },
361 { 2.382509690609731, 6.071782429815853 },
362 { -3.203934441889096, 2.572079552602468 },
363 { 8.465636032165087, 13.96462998683518 },
364 { 2.36755660870416, 5.7844595007273 },
365 { 0.5935496528993371, 1.374615871358943 },
366 { -2.467481505748694, 2.097224634713005 },
367 { 4.27867444328542, 10.24772361238549 },
368 { -2.013791907543137, 2.013799426047639 },
369 { 6.424588084404173, 9.185334939684516 },
370 { -0.8448238876802175, 0.5447382022282812 },
371 { 1.342955703473923, 8.645456317633556 },
372 { 3.108712208751979, 8.512156853800064 },
373 { 4.343205178315472, 8.056869549234374 },
374 { -2.971767642212396, 3.201180146824761 },
375 { 2.583820931523672, 5.459873414473854 },
376 { 4.209139115268925, 8.171098193546225 },
377 { 0.4064909057902746, 1.454390775518743 },
378 { 3.068642411145223, 6.959485153620035 },
379 { 6.085968972900461, 7.391429799500965 },
380 { -1.342265795764202, 1.454550012997143 },
381 { 6.249773274516883, 6.290269880772023 },
382 { 4.986225847822566, 7.75266344868907 },
383 { 7.642443254378944, 10.19914817500263 },
384 { 6.438181159163673, 8.464396764810347 },
385 { 2.520859761025108, 7.68222425260111 },
386 { 2.883699944257541, 6.777960331348503 },
387 { 2.788004550956599, 6.634735386652733 },
388 { 3.331661231995638, 5.794191300046592 },
389 { 3.526172276645504, 6.710802266815884 },
390 { 3.188298528138741, 10.34495528210205 },
391 { 0.7345539486114623, 5.807604004180681 },
392 { 1.165044595880125, 7.830121829295257 },
393 { 7.146962523500671, 11.62995162065415 },
394 { 7.813872137162087, 10.62827008714735 },
395 { 3.118099164870063, 8.286003148186371 },
396 { -1.708739286262571, 1.561026755374264 },
397 { 1.786163047580084, 4.172394388214604 },
398 { 3.718506403232386, 7.807752990130349 },
399 { 6.167414046828899, 10.01104941031293 },
400 { -1.063477247689196, 1.61176085846339 },
401 { -3.396739609433642, 0.7127911050002151 },
402 { 2.438885945896797, 7.353011138689225 },
403 { -0.2073204144780931, 0.850771146627012 }, };
404 }
405 }