View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements. See the NOTICE file distributed with this
4    * work for additional information regarding copyright ownership. The ASF
5    * licenses this file to You under the Apache License, Version 2.0 (the
6    * "License"); you may not use this file except in compliance with the License.
7    * You may obtain a copy of the License at
8    *
9    * https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14   * License for the specific language governing permissions and limitations under
15   * the License.
16   */
17  package org.hipparchus.stat.fitting;
18  
import org.hipparchus.distribution.multivariate.MixtureMultivariateNormalDistribution;
import org.hipparchus.distribution.multivariate.MultivariateNormalDistribution;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.MathIllegalStateException;
import org.hipparchus.linear.Array2DRowRealMatrix;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.util.Pair;
import org.junit.jupiter.api.Test;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
40  
41  /**
42   * Test that demonstrates the use of
43   * {@link MultivariateNormalMixtureExpectationMaximization}.
44   */
45  class MultivariateNormalMixtureExpectationMaximizationTest {
46  
47      @Test
48      void testNonEmptyData() {
49          assertThrows(MathIllegalArgumentException.class, () -> {
50              // Should not accept empty data
51              new MultivariateNormalMixtureExpectationMaximization(new double[][]{});
52          });
53      }
54  
55      @Test
56      void testNonJaggedData() {
57          assertThrows(MathIllegalArgumentException.class, () -> {
58              // Reject data with nonconstant numbers of columns
59              double[][] data = new double[][]{
60                  {1, 2, 3},
61                  {4, 5, 6, 7},
62              };
63              new MultivariateNormalMixtureExpectationMaximization(data);
64          });
65      }
66  
67      @Test
68      void testMultipleColumnsRequired() {
69          assertThrows(MathIllegalArgumentException.class, () -> {
70              // Data should have at least 2 columns
71              double[][] data = new double[][]{
72                  {1}, {2}
73              };
74              new MultivariateNormalMixtureExpectationMaximization(data);
75          });
76      }
77  
78      @Test
79      void testMaxIterationsPositive() {
80          assertThrows(MathIllegalArgumentException.class, () -> {
81              // Maximum iterations for fit must be positive integer
82              double[][] data = getTestSamples();
83              MultivariateNormalMixtureExpectationMaximization fitter =
84                  new MultivariateNormalMixtureExpectationMaximization(data);
85  
86              MixtureMultivariateNormalDistribution
87                  initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
88  
89              fitter.fit(initialMix, 0, 1E-5);
90          });
91      }
92  
93      @Test
94      void testThresholdPositive() {
95          assertThrows(MathIllegalArgumentException.class, () -> {
96              // Maximum iterations for fit must be positive
97              double[][] data = getTestSamples();
98              MultivariateNormalMixtureExpectationMaximization fitter =
99                  new MultivariateNormalMixtureExpectationMaximization(
100                     data);
101 
102             MixtureMultivariateNormalDistribution
103                 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
104 
105             fitter.fit(initialMix, 1000, 0);
106         });
107     }
108 
109     @Test
110     void testConvergenceException() {
111         assertThrows(MathIllegalStateException.class, () -> {
112             // MathIllegalStateException thrown if fit terminates before threshold met
113             double[][] data = getTestSamples();
114             MultivariateNormalMixtureExpectationMaximization fitter
115                 = new MultivariateNormalMixtureExpectationMaximization(data);
116 
117             MixtureMultivariateNormalDistribution
118                 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
119 
120             // 5 iterations not enough to meet convergence threshold
121             fitter.fit(initialMix, 5, 1E-5);
122         });
123     }
124 
125     @Test
126     void testIncompatibleIntialMixture() {
127         assertThrows(MathIllegalArgumentException.class, () -> {
128             // Data has 3 columns
129             double[][] data = new double[][]{
130                 {1, 2, 3}, {4, 5, 6}, {7, 8, 9}
131             };
132             double[] weights = new double[]{0.5, 0.5};
133 
134             // These distributions are compatible with 2-column data, not 3-column
135             // data
136             MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[2];
137 
138             mvns[0] = new MultivariateNormalDistribution(new double[]{
139                     -0.0021722935000328823, 3.5432892936887908},
140                 new double[][]{
141                     {4.537422569229048, 3.5266152281729304},
142                     {3.5266152281729304, 6.175448814169779}});
143             mvns[1] = new MultivariateNormalDistribution(new double[]{
144                 5.090902706507635, 8.68540656355283}, new double[][]{
145                 {2.886778573963039, 1.5257474543463154},
146                 {1.5257474543463154, 3.3794567673616918}});
147 
148             // Create components and mixture
149             List<Pair<Double, MultivariateNormalDistribution>> components =
150                 new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
151             components.add(new Pair<Double, MultivariateNormalDistribution>(
152                 weights[0], mvns[0]));
153             components.add(new Pair<Double, MultivariateNormalDistribution>(
154                 weights[1], mvns[1]));
155 
156             MixtureMultivariateNormalDistribution badInitialMix
157                 = new MixtureMultivariateNormalDistribution(components);
158 
159             MultivariateNormalMixtureExpectationMaximization fitter
160                 = new MultivariateNormalMixtureExpectationMaximization(data);
161 
162             fitter.fit(badInitialMix);
163         });
164     }
165 
166     @Test
167     void testInitialMixture() {
168         // Testing initial mixture estimated from data
169         final double[] correctWeights = new double[] { 0.5, 0.5 };
170 
171         final double[][] correctMeans = new double[][] {
172             {-0.0021722935000328823, 3.5432892936887908},
173             {5.090902706507635, 8.68540656355283},
174         };
175 
176         final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
177 
178         correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
179                 { 4.537422569229048, 3.5266152281729304 },
180                 { 3.5266152281729304, 6.175448814169779 } });
181 
182         correctCovMats[1] = new Array2DRowRealMatrix( new double[][] {
183                 { 2.886778573963039, 1.5257474543463154 },
184                 { 1.5257474543463154, 3.3794567673616918 } });
185 
186         final MultivariateNormalDistribution[] correctMVNs = new
187                 MultivariateNormalDistribution[2];
188 
189         correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0],
190                 correctCovMats[0].getData());
191 
192         correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1],
193                 correctCovMats[1].getData());
194 
195         final MixtureMultivariateNormalDistribution initialMix
196             = MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
197 
198         int i = 0;
199         for (Pair<Double, MultivariateNormalDistribution> component : initialMix
200                 .getComponents()) {
201             assertEquals(correctWeights[i], component.getFirst(),
202                     Math.ulp(1d));
203 
204             final double[] means = component.getValue().getMeans();
205             assertTrue(Arrays.equals(correctMeans[i], means));
206 
207             final RealMatrix covMat = component.getValue().getCovariances();
208             assertEquals(correctCovMats[i], covMat);
209             i++;
210         }
211     }
212 
213     @Test
214     void testWrongData() {
215         checkWrongData(new double[1][1], 2, LocalizedCoreFormats.NUMBER_TOO_SMALL);
216         checkWrongData(new double[3][3], 1, LocalizedCoreFormats.NUMBER_TOO_SMALL);
217         checkWrongData(new double[3][3], 4, LocalizedCoreFormats.NUMBER_TOO_LARGE);
218     }
219 
220     private void checkWrongData(final double[][] data, final int numComponents,
221                                 final LocalizedCoreFormats expected) {
222         try {
223             MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
224             fail("an exception should have been thrown");
225         } catch (MathIllegalArgumentException miae) {
226             assertEquals(expected, miae.getSpecifier());
227         }
228     }
229 
230     @Test
231     void testUnusedInheritedMethods() {
232         // this test is just meant for coverage issues
233         try {
234             Class<?> dataRowClass = MultivariateNormalMixtureExpectationMaximization.class.getDeclaredClasses()[0];
235             Constructor<?> dataRowConstructor = dataRowClass.getDeclaredConstructor(double[].class);
236             dataRowConstructor.setAccessible(true);
237             Object dr1 = dataRowConstructor.newInstance(new double[] { 1, 2, 3 });
238             assertEquals(66614367, dr1.hashCode());
239             assertEquals(dr1, dr1);
240             assertNotEquals("", dr1);
241             assertEquals(dr1, dataRowConstructor.newInstance(new double[]{1, 2, 3}));
242             assertNotEquals(dr1, dataRowConstructor.newInstance(new double[]{3, 2, 1}));
243         } catch (InvocationTargetException | NoSuchMethodException | SecurityException |
244                  InstantiationException | IllegalAccessException | IllegalArgumentException e) {
245             fail(e.getLocalizedMessage());
246         }
247     }
248 
249     @Test
250     void testFit() {
251         // Test that the loglikelihood, weights, and models are determined and
252         // fitted correctly
253         final double[][] data = getTestSamples();
254         final double correctLogLikelihood = -4.292431006791994;
255         final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
256 
257         final double[][] correctMeans = new double[][]{
258             {-1.4213112715121132, 1.6924690505757753},
259             {4.213612224374709, 7.975621325853645}
260         };
261 
262         final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
263         correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
264             { 1.739356907285747, -0.5867644251487614 },
265             { -0.5867644251487614, 1.0232932029324642 } }
266                 );
267         correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
268             { 4.245384898007161, 2.5797798966382155 },
269             { 2.5797798966382155, 3.9200272522448367 } });
270 
271         final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
272         correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
273         correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
274 
275         MultivariateNormalMixtureExpectationMaximization fitter
276             = new MultivariateNormalMixtureExpectationMaximization(data);
277 
278         MixtureMultivariateNormalDistribution initialMix
279             = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
280         fitter.fit(initialMix);
281         MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
282         List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
283 
284         assertEquals(correctLogLikelihood,
285                             fitter.getLogLikelihood(),
286                             Math.ulp(1d));
287 
288         int i = 0;
289         for (Pair<Double, MultivariateNormalDistribution> component : components) {
290             final double weight = component.getFirst();
291             final MultivariateNormalDistribution mvn = component.getSecond();
292             final double[] mean = mvn.getMeans();
293             final RealMatrix covMat = mvn.getCovariances();
294             assertEquals(correctWeights[i], weight, Math.ulp(1d));
295             assertTrue(Arrays.equals(correctMeans[i], mean));
296             assertEquals(correctCovMats[i], covMat);
297             i++;
298         }
299     }
300 
301     private double[][] getTestSamples() {
302         // generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
303         // [4, 8.2]
304         return new double[][] { { 7.358553610469948, 11.31260831446758 },
305                 { 7.175770420124739, 8.988812210204454 },
306                 { 4.324151905768422, 6.837727899051482 },
307                 { 2.157832219173036, 6.317444585521968 },
308                 { -1.890157421896651, 1.74271202875498 },
309                 { 0.8922409354455803, 1.999119343923781 },
310                 { 3.396949764787055, 6.813170372579068 },
311                 { -2.057498232686068, -0.002522983830852255 },
312                 { 6.359932157365045, 8.343600029975851 },
313                 { 3.353102234276168, 7.087541882898689 },
314                 { -1.763877221595639, 0.9688890460330644 },
315                 { 6.151457185125111, 9.075011757431174 },
316                 { 4.281597398048899, 5.953270070976117 },
317                 { 3.549576703974894, 8.616038155992861 },
318                 { 6.004706732349854, 8.959423391087469 },
319                 { 2.802915014676262, 6.285676742173564 },
320                 { -0.6029879029880616, 1.083332958357485 },
321                 { 3.631827105398369, 6.743428504049444 },
322                 { 6.161125014007315, 9.60920569689001 },
323                 { -1.049582894255342, 0.2020017892080281 },
324                 { 3.910573022688315, 8.19609909534937 },
325                 { 8.180454017634863, 7.861055769719962 },
326                 { 1.488945440439716, 8.02699903761247 },
327                 { 4.813750847823778, 12.34416881332515 },
328                 { 0.0443208501259158, 5.901148093240691 },
329                 { 4.416417235068346, 4.465243084006094 },
330                 { 4.0002433603072, 6.721937850166174 },
331                 { 3.190113818788205, 10.51648348411058 },
332                 { 4.493600914967883, 7.938224231022314 },
333                 { -3.675669533266189, 4.472845076673303 },
334                 { 6.648645511703989, 12.03544085965724 },
335                 { -1.330031331404445, 1.33931042964811 },
336                 { -3.812111460708707, 2.50534195568356 },
337                 { 5.669339356648331, 6.214488981177026 },
338                 { 1.006596727153816, 1.51165463112716 },
339                 { 5.039466365033024, 7.476532610478689 },
340                 { 4.349091929968925, 7.446356406259756 },
341                 { -1.220289665119069, 3.403926955951437 },
342                 { 5.553003979122395, 6.886518211202239 },
343                 { 2.274487732222856, 7.009541508533196 },
344                 { 4.147567059965864, 7.34025244349202 },
345                 { 4.083882618965819, 6.362852861075623 },
346                 { 2.203122344647599, 7.260295257904624 },
347                 { -2.147497550770442, 1.262293431529498 },
348                 { 2.473700950426512, 6.558900135505638 },
349                 { 8.267081298847554, 12.10214104577748 },
350                 { 6.91977329776865, 9.91998488301285 },
351                 { 0.1680479852730894, 6.28286034168897 },
352                 { -1.268578659195158, 2.326711221485755 },
353                 { 1.829966451374701, 6.254187605304518 },
354                 { 5.648849025754848, 9.330002040750291 },
355                 { -2.302874793257666, 3.585545172776065 },
356                 { -2.629218791709046, 2.156215538500288 },
357                 { 4.036618140700114, 10.2962785719958 },
358                 { 0.4616386422783874, 0.6782756325806778 },
359                 { -0.3447896073408363, 0.4999834691645118 },
360                 { -0.475281453118318, 1.931470384180492 },
361                 { 2.382509690609731, 6.071782429815853 },
362                 { -3.203934441889096, 2.572079552602468 },
363                 { 8.465636032165087, 13.96462998683518 },
364                 { 2.36755660870416, 5.7844595007273 },
365                 { 0.5935496528993371, 1.374615871358943 },
366                 { -2.467481505748694, 2.097224634713005 },
367                 { 4.27867444328542, 10.24772361238549 },
368                 { -2.013791907543137, 2.013799426047639 },
369                 { 6.424588084404173, 9.185334939684516 },
370                 { -0.8448238876802175, 0.5447382022282812 },
371                 { 1.342955703473923, 8.645456317633556 },
372                 { 3.108712208751979, 8.512156853800064 },
373                 { 4.343205178315472, 8.056869549234374 },
374                 { -2.971767642212396, 3.201180146824761 },
375                 { 2.583820931523672, 5.459873414473854 },
376                 { 4.209139115268925, 8.171098193546225 },
377                 { 0.4064909057902746, 1.454390775518743 },
378                 { 3.068642411145223, 6.959485153620035 },
379                 { 6.085968972900461, 7.391429799500965 },
380                 { -1.342265795764202, 1.454550012997143 },
381                 { 6.249773274516883, 6.290269880772023 },
382                 { 4.986225847822566, 7.75266344868907 },
383                 { 7.642443254378944, 10.19914817500263 },
384                 { 6.438181159163673, 8.464396764810347 },
385                 { 2.520859761025108, 7.68222425260111 },
386                 { 2.883699944257541, 6.777960331348503 },
387                 { 2.788004550956599, 6.634735386652733 },
388                 { 3.331661231995638, 5.794191300046592 },
389                 { 3.526172276645504, 6.710802266815884 },
390                 { 3.188298528138741, 10.34495528210205 },
391                 { 0.7345539486114623, 5.807604004180681 },
392                 { 1.165044595880125, 7.830121829295257 },
393                 { 7.146962523500671, 11.62995162065415 },
394                 { 7.813872137162087, 10.62827008714735 },
395                 { 3.118099164870063, 8.286003148186371 },
396                 { -1.708739286262571, 1.561026755374264 },
397                 { 1.786163047580084, 4.172394388214604 },
398                 { 3.718506403232386, 7.807752990130349 },
399                 { 6.167414046828899, 10.01104941031293 },
400                 { -1.063477247689196, 1.61176085846339 },
401                 { -3.396739609433642, 0.7127911050002151 },
402                 { 2.438885945896797, 7.353011138689225 },
403                 { -0.2073204144780931, 0.850771146627012 }, };
404     }
405 }