View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements. See the NOTICE file distributed with this
4    * work for additional information regarding copyright ownership. The ASF
5    * licenses this file to You under the Apache License, Version 2.0 (the
6    * "License"); you may not use this file except in compliance with the License.
7    * You may obtain a copy of the License at
8    *
9    * https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14   * License for the specific language governing permissions and limitations under
15   * the License.
16   */
17  package org.hipparchus.stat.fitting;
18  
19  import java.lang.reflect.Constructor;
20  import java.lang.reflect.InvocationTargetException;
21  import java.util.ArrayList;
22  import java.util.Arrays;
23  import java.util.List;
24  
25  import org.hipparchus.distribution.multivariate.MixtureMultivariateNormalDistribution;
26  import org.hipparchus.distribution.multivariate.MultivariateNormalDistribution;
27  import org.hipparchus.exception.LocalizedCoreFormats;
28  import org.hipparchus.exception.MathIllegalArgumentException;
29  import org.hipparchus.exception.MathIllegalStateException;
30  import org.hipparchus.linear.Array2DRowRealMatrix;
31  import org.hipparchus.linear.RealMatrix;
32  import org.hipparchus.util.Pair;
33  import org.junit.Assert;
34  import org.junit.Test;
35  
36  /**
37   * Test that demonstrates the use of
38   * {@link MultivariateNormalMixtureExpectationMaximization}.
39   */
40  public class MultivariateNormalMixtureExpectationMaximizationTest {
41  
42      @Test(expected = MathIllegalArgumentException.class)
43      public void testNonEmptyData() {
44          // Should not accept empty data
45          new MultivariateNormalMixtureExpectationMaximization(new double[][] {});
46      }
47  
48      @Test(expected = MathIllegalArgumentException.class)
49      public void testNonJaggedData() {
50          // Reject data with nonconstant numbers of columns
51          double[][] data = new double[][] {
52                  { 1, 2, 3 },
53                  { 4, 5, 6, 7 },
54          };
55          new MultivariateNormalMixtureExpectationMaximization(data);
56      }
57  
58      @Test(expected = MathIllegalArgumentException.class)
59      public void testMultipleColumnsRequired() {
60          // Data should have at least 2 columns
61          double[][] data = new double[][] {
62                  { 1 }, { 2 }
63          };
64          new MultivariateNormalMixtureExpectationMaximization(data);
65      }
66  
67      @Test(expected = MathIllegalArgumentException.class)
68      public void testMaxIterationsPositive() {
69          // Maximum iterations for fit must be positive integer
70          double[][] data = getTestSamples();
71          MultivariateNormalMixtureExpectationMaximization fitter =
72                  new MultivariateNormalMixtureExpectationMaximization(data);
73  
74          MixtureMultivariateNormalDistribution
75              initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
76  
77          fitter.fit(initialMix, 0, 1E-5);
78      }
79  
80      @Test(expected = MathIllegalArgumentException.class)
81      public void testThresholdPositive() {
82          // Maximum iterations for fit must be positive
83          double[][] data = getTestSamples();
84          MultivariateNormalMixtureExpectationMaximization fitter =
85                  new MultivariateNormalMixtureExpectationMaximization(
86                      data);
87  
88          MixtureMultivariateNormalDistribution
89              initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
90  
91          fitter.fit(initialMix, 1000, 0);
92      }
93  
94      @Test(expected = MathIllegalStateException.class)
95      public void testConvergenceException() {
96          // MathIllegalStateException thrown if fit terminates before threshold met
97          double[][] data = getTestSamples();
98          MultivariateNormalMixtureExpectationMaximization fitter
99              = new MultivariateNormalMixtureExpectationMaximization(data);
100 
101         MixtureMultivariateNormalDistribution
102             initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
103 
104         // 5 iterations not enough to meet convergence threshold
105         fitter.fit(initialMix, 5, 1E-5);
106     }
107 
108     @Test(expected = MathIllegalArgumentException.class)
109     public void testIncompatibleIntialMixture() {
110         // Data has 3 columns
111         double[][] data = new double[][] {
112                 { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 }
113         };
114         double[] weights = new double[] { 0.5, 0.5 };
115 
116         // These distributions are compatible with 2-column data, not 3-column
117         // data
118         MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[2];
119 
120         mvns[0] = new MultivariateNormalDistribution(new double[] {
121                         -0.0021722935000328823, 3.5432892936887908 },
122                         new double[][] {
123                                 { 4.537422569229048, 3.5266152281729304 },
124                                 { 3.5266152281729304, 6.175448814169779 } });
125         mvns[1] = new MultivariateNormalDistribution(new double[] {
126                         5.090902706507635, 8.68540656355283 }, new double[][] {
127                         { 2.886778573963039, 1.5257474543463154 },
128                         { 1.5257474543463154, 3.3794567673616918 } });
129 
130         // Create components and mixture
131         List<Pair<Double, MultivariateNormalDistribution>> components =
132                 new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
133         components.add(new Pair<Double, MultivariateNormalDistribution>(
134                 weights[0], mvns[0]));
135         components.add(new Pair<Double, MultivariateNormalDistribution>(
136                 weights[1], mvns[1]));
137 
138         MixtureMultivariateNormalDistribution badInitialMix
139             = new MixtureMultivariateNormalDistribution(components);
140 
141         MultivariateNormalMixtureExpectationMaximization fitter
142             = new MultivariateNormalMixtureExpectationMaximization(data);
143 
144         fitter.fit(badInitialMix);
145     }
146 
147     @Test
148     public void testInitialMixture() {
149         // Testing initial mixture estimated from data
150         final double[] correctWeights = new double[] { 0.5, 0.5 };
151 
152         final double[][] correctMeans = new double[][] {
153             {-0.0021722935000328823, 3.5432892936887908},
154             {5.090902706507635, 8.68540656355283},
155         };
156 
157         final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
158 
159         correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
160                 { 4.537422569229048, 3.5266152281729304 },
161                 { 3.5266152281729304, 6.175448814169779 } });
162 
163         correctCovMats[1] = new Array2DRowRealMatrix( new double[][] {
164                 { 2.886778573963039, 1.5257474543463154 },
165                 { 1.5257474543463154, 3.3794567673616918 } });
166 
167         final MultivariateNormalDistribution[] correctMVNs = new
168                 MultivariateNormalDistribution[2];
169 
170         correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0],
171                 correctCovMats[0].getData());
172 
173         correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1],
174                 correctCovMats[1].getData());
175 
176         final MixtureMultivariateNormalDistribution initialMix
177             = MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
178 
179         int i = 0;
180         for (Pair<Double, MultivariateNormalDistribution> component : initialMix
181                 .getComponents()) {
182             Assert.assertEquals(correctWeights[i], component.getFirst(),
183                     Math.ulp(1d));
184 
185             final double[] means = component.getValue().getMeans();
186             Assert.assertTrue(Arrays.equals(correctMeans[i], means));
187 
188             final RealMatrix covMat = component.getValue().getCovariances();
189             Assert.assertEquals(correctCovMats[i], covMat);
190             i++;
191         }
192     }
193 
194     @Test
195     public void testWrongData() {
196         checkWrongData(new double[1][1], 2, LocalizedCoreFormats.NUMBER_TOO_SMALL);
197         checkWrongData(new double[3][3], 1, LocalizedCoreFormats.NUMBER_TOO_SMALL);
198         checkWrongData(new double[3][3], 4, LocalizedCoreFormats.NUMBER_TOO_LARGE);
199     }
200 
201     private void checkWrongData(final double[][] data, final int numComponents,
202                                 final LocalizedCoreFormats expected) {
203         try {
204             MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
205             Assert.fail("an exception should have been thrown");
206         } catch (MathIllegalArgumentException miae) {
207             Assert.assertEquals(expected, miae.getSpecifier());
208         }
209     }
210 
211     @Test
212     public void testUnusedInheritedMethods() {
213         // this test is just meant for coverage issues
214         try {
215             Class<?> dataRowClass = MultivariateNormalMixtureExpectationMaximization.class.getDeclaredClasses()[0];
216             Constructor<?> dataRowConstructor = dataRowClass.getDeclaredConstructor(double[].class);
217             dataRowConstructor.setAccessible(true);
218             Object dr1 = dataRowConstructor.newInstance(new double[] { 1, 2, 3 });
219             Assert.assertEquals(66614367, dr1.hashCode());
220             Assert.assertTrue(dr1.equals(dr1));
221             Assert.assertFalse(dr1.equals(""));
222             Assert.assertTrue(dr1.equals(dataRowConstructor.newInstance(new double[] { 1, 2, 3 })));
223             Assert.assertFalse(dr1.equals(dataRowConstructor.newInstance(new double[] { 3, 2, 1 })));
224         } catch (InvocationTargetException | NoSuchMethodException | SecurityException |
225                  InstantiationException | IllegalAccessException | IllegalArgumentException e) {
226             Assert.fail(e.getLocalizedMessage());
227         }
228     }
229 
230     @Test
231     public void testFit() {
232         // Test that the loglikelihood, weights, and models are determined and
233         // fitted correctly
234         final double[][] data = getTestSamples();
235         final double correctLogLikelihood = -4.292431006791994;
236         final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
237 
238         final double[][] correctMeans = new double[][]{
239             {-1.4213112715121132, 1.6924690505757753},
240             {4.213612224374709, 7.975621325853645}
241         };
242 
243         final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
244         correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
245             { 1.739356907285747, -0.5867644251487614 },
246             { -0.5867644251487614, 1.0232932029324642 } }
247                 );
248         correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
249             { 4.245384898007161, 2.5797798966382155 },
250             { 2.5797798966382155, 3.9200272522448367 } });
251 
252         final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
253         correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
254         correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
255 
256         MultivariateNormalMixtureExpectationMaximization fitter
257             = new MultivariateNormalMixtureExpectationMaximization(data);
258 
259         MixtureMultivariateNormalDistribution initialMix
260             = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
261         fitter.fit(initialMix);
262         MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
263         List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
264 
265         Assert.assertEquals(correctLogLikelihood,
266                             fitter.getLogLikelihood(),
267                             Math.ulp(1d));
268 
269         int i = 0;
270         for (Pair<Double, MultivariateNormalDistribution> component : components) {
271             final double weight = component.getFirst();
272             final MultivariateNormalDistribution mvn = component.getSecond();
273             final double[] mean = mvn.getMeans();
274             final RealMatrix covMat = mvn.getCovariances();
275             Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
276             Assert.assertTrue(Arrays.equals(correctMeans[i], mean));
277             Assert.assertEquals(correctCovMats[i], covMat);
278             i++;
279         }
280     }
281 
282     private double[][] getTestSamples() {
283         // generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
284         // [4, 8.2]
285         return new double[][] { { 7.358553610469948, 11.31260831446758 },
286                 { 7.175770420124739, 8.988812210204454 },
287                 { 4.324151905768422, 6.837727899051482 },
288                 { 2.157832219173036, 6.317444585521968 },
289                 { -1.890157421896651, 1.74271202875498 },
290                 { 0.8922409354455803, 1.999119343923781 },
291                 { 3.396949764787055, 6.813170372579068 },
292                 { -2.057498232686068, -0.002522983830852255 },
293                 { 6.359932157365045, 8.343600029975851 },
294                 { 3.353102234276168, 7.087541882898689 },
295                 { -1.763877221595639, 0.9688890460330644 },
296                 { 6.151457185125111, 9.075011757431174 },
297                 { 4.281597398048899, 5.953270070976117 },
298                 { 3.549576703974894, 8.616038155992861 },
299                 { 6.004706732349854, 8.959423391087469 },
300                 { 2.802915014676262, 6.285676742173564 },
301                 { -0.6029879029880616, 1.083332958357485 },
302                 { 3.631827105398369, 6.743428504049444 },
303                 { 6.161125014007315, 9.60920569689001 },
304                 { -1.049582894255342, 0.2020017892080281 },
305                 { 3.910573022688315, 8.19609909534937 },
306                 { 8.180454017634863, 7.861055769719962 },
307                 { 1.488945440439716, 8.02699903761247 },
308                 { 4.813750847823778, 12.34416881332515 },
309                 { 0.0443208501259158, 5.901148093240691 },
310                 { 4.416417235068346, 4.465243084006094 },
311                 { 4.0002433603072, 6.721937850166174 },
312                 { 3.190113818788205, 10.51648348411058 },
313                 { 4.493600914967883, 7.938224231022314 },
314                 { -3.675669533266189, 4.472845076673303 },
315                 { 6.648645511703989, 12.03544085965724 },
316                 { -1.330031331404445, 1.33931042964811 },
317                 { -3.812111460708707, 2.50534195568356 },
318                 { 5.669339356648331, 6.214488981177026 },
319                 { 1.006596727153816, 1.51165463112716 },
320                 { 5.039466365033024, 7.476532610478689 },
321                 { 4.349091929968925, 7.446356406259756 },
322                 { -1.220289665119069, 3.403926955951437 },
323                 { 5.553003979122395, 6.886518211202239 },
324                 { 2.274487732222856, 7.009541508533196 },
325                 { 4.147567059965864, 7.34025244349202 },
326                 { 4.083882618965819, 6.362852861075623 },
327                 { 2.203122344647599, 7.260295257904624 },
328                 { -2.147497550770442, 1.262293431529498 },
329                 { 2.473700950426512, 6.558900135505638 },
330                 { 8.267081298847554, 12.10214104577748 },
331                 { 6.91977329776865, 9.91998488301285 },
332                 { 0.1680479852730894, 6.28286034168897 },
333                 { -1.268578659195158, 2.326711221485755 },
334                 { 1.829966451374701, 6.254187605304518 },
335                 { 5.648849025754848, 9.330002040750291 },
336                 { -2.302874793257666, 3.585545172776065 },
337                 { -2.629218791709046, 2.156215538500288 },
338                 { 4.036618140700114, 10.2962785719958 },
339                 { 0.4616386422783874, 0.6782756325806778 },
340                 { -0.3447896073408363, 0.4999834691645118 },
341                 { -0.475281453118318, 1.931470384180492 },
342                 { 2.382509690609731, 6.071782429815853 },
343                 { -3.203934441889096, 2.572079552602468 },
344                 { 8.465636032165087, 13.96462998683518 },
345                 { 2.36755660870416, 5.7844595007273 },
346                 { 0.5935496528993371, 1.374615871358943 },
347                 { -2.467481505748694, 2.097224634713005 },
348                 { 4.27867444328542, 10.24772361238549 },
349                 { -2.013791907543137, 2.013799426047639 },
350                 { 6.424588084404173, 9.185334939684516 },
351                 { -0.8448238876802175, 0.5447382022282812 },
352                 { 1.342955703473923, 8.645456317633556 },
353                 { 3.108712208751979, 8.512156853800064 },
354                 { 4.343205178315472, 8.056869549234374 },
355                 { -2.971767642212396, 3.201180146824761 },
356                 { 2.583820931523672, 5.459873414473854 },
357                 { 4.209139115268925, 8.171098193546225 },
358                 { 0.4064909057902746, 1.454390775518743 },
359                 { 3.068642411145223, 6.959485153620035 },
360                 { 6.085968972900461, 7.391429799500965 },
361                 { -1.342265795764202, 1.454550012997143 },
362                 { 6.249773274516883, 6.290269880772023 },
363                 { 4.986225847822566, 7.75266344868907 },
364                 { 7.642443254378944, 10.19914817500263 },
365                 { 6.438181159163673, 8.464396764810347 },
366                 { 2.520859761025108, 7.68222425260111 },
367                 { 2.883699944257541, 6.777960331348503 },
368                 { 2.788004550956599, 6.634735386652733 },
369                 { 3.331661231995638, 5.794191300046592 },
370                 { 3.526172276645504, 6.710802266815884 },
371                 { 3.188298528138741, 10.34495528210205 },
372                 { 0.7345539486114623, 5.807604004180681 },
373                 { 1.165044595880125, 7.830121829295257 },
374                 { 7.146962523500671, 11.62995162065415 },
375                 { 7.813872137162087, 10.62827008714735 },
376                 { 3.118099164870063, 8.286003148186371 },
377                 { -1.708739286262571, 1.561026755374264 },
378                 { 1.786163047580084, 4.172394388214604 },
379                 { 3.718506403232386, 7.807752990130349 },
380                 { 6.167414046828899, 10.01104941031293 },
381                 { -1.063477247689196, 1.61176085846339 },
382                 { -3.396739609433642, 0.7127911050002151 },
383                 { 2.438885945896797, 7.353011138689225 },
384                 { -0.2073204144780931, 0.850771146627012 }, };
385     }
386 }