View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.distribution.multivariate;
24  
25  import org.hipparchus.UnitTestUtils;
26  import org.hipparchus.distribution.continuous.NormalDistribution;
27  import org.hipparchus.exception.LocalizedCoreFormats;
28  import org.hipparchus.exception.MathIllegalArgumentException;
29  import org.hipparchus.linear.Array2DRowRealMatrix;
30  import org.hipparchus.linear.RealMatrix;
31  import org.hipparchus.random.Well19937c;
32  import org.hipparchus.util.Precision;
33  import org.junit.jupiter.api.Test;
34  
35  import java.util.Random;
36  
37  import static org.junit.jupiter.api.Assertions.assertEquals;
38  import static org.junit.jupiter.api.Assertions.fail;
39  
40  /**
41   * Test cases for {@link MultivariateNormalDistribution}.
42   */
43  class MultivariateNormalDistributionTest {
44      /**
45       * Test the ability of the distribution to report its mean value parameter.
46       */
47      @Test
48      void testGetMean() {
49          final double[] mu = { -1.5, 2 };
50          final double[][] sigma = { { 2, -1.1 },
51                                     { -1.1, 2 } };
52          final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma);
53  
54          final double[] m = d.getMeans();
55          for (int i = 0; i < m.length; i++) {
56              assertEquals(mu[i], m[i], 0);
57          }
58      }
59  
60      /**
61       * Test the ability of the distribution to report its covariance matrix parameter.
62       */
63      @Test
64      void testGetCovarianceMatrix() {
65          final double[] mu = { -1.5, 2 };
66          final double[][] sigma = { { 2, -1.1 },
67                                     { -1.1, 2 } };
68          final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma);
69  
70          final RealMatrix s = d.getCovariances();
71          final int dim = d.getDimension();
72          for (int i = 0; i < dim; i++) {
73              for (int j = 0; j < dim; j++) {
74                  assertEquals(sigma[i][j], s.getEntry(i, j), 0);
75              }
76          }
77      }
78  
79      /**
80       * Test the accuracy of sampling from the distribution.
81       */
82      @Test
83      void testSampling() {
84          final double[] mu = { -1.5, 2 };
85          final double[][] sigma = { { 2, -1.1 },
86                                     { -1.1, 2 } };
87          final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma);
88          d.reseedRandomGenerator(50);
89  
90          final int n = 500000;
91  
92          final double[][] samples = d.sample(n);
93          final int dim = d.getDimension();
94          final double[] sampleMeans = new double[dim];
95  
96          for (int i = 0; i < samples.length; i++) {
97              for (int j = 0; j < dim; j++) {
98                  sampleMeans[j] += samples[i][j];
99              }
100         }
101 
102         final double sampledValueTolerance = 1e-2;
103         for (int j = 0; j < dim; j++) {
104             sampleMeans[j] /= samples.length;
105             assertEquals(mu[j], sampleMeans[j], sampledValueTolerance);
106         }
107 
108         //final double[][] sampleSigma = new Covariance(samples).getCovarianceMatrix().getData();
109         final RealMatrix sampleSigma = UnitTestUtils.covarianceMatrix(new Array2DRowRealMatrix(samples));
110         for (int i = 0; i < dim; i++) {
111             for (int j = 0; j < dim; j++) {
112                 assertEquals(sigma[i][j], sampleSigma.getEntry(i, j), sampledValueTolerance);
113             }
114         }
115     }
116 
117     /**
118      * Test the accuracy of the distribution when calculating densities.
119      */
120     @Test
121     void testDensities() {
122         final double[] mu = { -1.5, 2 };
123         final double[][] sigma = { { 2, -1.1 },
124                                    { -1.1, 2 } };
125         final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma);
126 
127         final double[][] testValues = { { -1.5, 2 },
128                                         { 4, 4 },
129                                         { 1.5, -2 },
130                                         { 0, 0 } };
131         final double[] densities = new double[testValues.length];
132         for (int i = 0; i < densities.length; i++) {
133             densities[i] = d.density(testValues[i]);
134         }
135 
136         // From dmvnorm function in R 2.15 CRAN package Mixtools v0.4.5
137         final double[] correctDensities = { 0.09528357207691344,
138                                             5.80932710124009e-09,
139                                             0.001387448895173267,
140                                             0.03309922090210541 };
141 
142         for (int i = 0; i < testValues.length; i++) {
143             assertEquals(correctDensities[i], densities[i], 1e-16);
144         }
145     }
146 
147     /**
148      * Test the accuracy of the distribution when calculating densities.
149      */
150     @Test
151     void testUnivariateDistribution() {
152         final double[] mu = { -1.5 };
153         final double[][] sigma = { { 1 } };
154 
155         final MultivariateNormalDistribution multi = new MultivariateNormalDistribution(mu, sigma);
156 
157         final NormalDistribution uni = new NormalDistribution(mu[0], sigma[0][0]);
158         final Random rng = new Random();
159         final int numCases = 100;
160         final double tol = Math.ulp(1d);
161         for (int i = 0; i < numCases; i++) {
162             final double v = rng.nextDouble() * 10 - 5;
163             assertEquals(uni.density(v), multi.density(new double[] { v }), tol);
164         }
165     }
166 
167     /**
168      * Test getting/setting custom singularMatrixTolerance
169      */
170     @Test
171     void testGetSingularMatrixTolerance() {
172         final double[] mu = { -1.5 };
173         final double[][] sigma = { { 1 } };
174 
175         final double tolerance1 = 1e-2;
176         final MultivariateNormalDistribution mvd1 = new MultivariateNormalDistribution(mu, sigma, tolerance1);
177         assertEquals(tolerance1, mvd1.getSingularMatrixCheckTolerance(), Precision.EPSILON);
178 
179         final double tolerance2 = 1e-3;
180         final MultivariateNormalDistribution mvd2 = new MultivariateNormalDistribution(mu, sigma, tolerance2);
181         assertEquals(tolerance2, mvd2.getSingularMatrixCheckTolerance(), Precision.EPSILON);
182     }
183 
184     @Test
185     void testNotPositiveDefinite() {
186         try {
187             new MultivariateNormalDistribution(new Well19937c(0x543l), new double[2],
188                                                new double[][] { { -1.0, 0.0 }, { 0.0, -2.0 } });
189             fail("an exception should have been thrown");
190         } catch (MathIllegalArgumentException miae) {
191             assertEquals(LocalizedCoreFormats.NOT_POSITIVE_DEFINITE_MATRIX, miae.getSpecifier());
192         }
193     }
194 
195     @Test
196     void testStd() {
197         MultivariateNormalDistribution d = new MultivariateNormalDistribution(new Well19937c(0x543l), new double[2],
198                                                                               new double[][] { { 4.0, 0.0 }, { 0.0, 9.0 } });
199         double[] s = d.getStandardDeviations();
200         assertEquals(2, s.length);
201         assertEquals(2.0, s[0], 1.0e-15);
202         assertEquals(3.0, s[1], 1.0e-15);
203     }
204 
205     @Test
206     void testWrongDensity() {
207         try {
208             MultivariateNormalDistribution d = new MultivariateNormalDistribution(new Well19937c(0x543l), new double[2],
209                                                                                   new double[][] { { 4.0, 0.0 }, { 0.0, 4.0 } });
210             d.density(new double[3]);
211             fail("an exception should have been thrown");
212         } catch (MathIllegalArgumentException miae) {
213             assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
214         }
215     }
216 
217     @Test
218     void testWrongArguments() {
219         checkWrongArguments(new double[3], new double[6][6]);
220         checkWrongArguments(new double[3], new double[3][6]);
221     }
222 
223     private void checkWrongArguments(double[] means, double[][] covariances) {
224         try {
225             new MultivariateNormalDistribution(new Well19937c(0x543l), means, covariances);
226             fail("an exception should have been thrown");
227         } catch (MathIllegalArgumentException miae) {
228             assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
229         }
230     }
231 }