1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.distribution.multivariate;
24
25 import org.hipparchus.UnitTestUtils;
26 import org.hipparchus.distribution.continuous.NormalDistribution;
27 import org.hipparchus.exception.LocalizedCoreFormats;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.linear.Array2DRowRealMatrix;
30 import org.hipparchus.linear.RealMatrix;
31 import org.hipparchus.random.Well19937c;
32 import org.hipparchus.util.Precision;
33 import org.junit.jupiter.api.Test;
34
35 import java.util.Random;
36
37 import static org.junit.jupiter.api.Assertions.assertEquals;
38 import static org.junit.jupiter.api.Assertions.fail;
39
40
41
42
43 class MultivariateNormalDistributionTest {
44
45
46
47 @Test
48 void testGetMean() {
49 final double[] mu = { -1.5, 2 };
50 final double[][] sigma = { { 2, -1.1 },
51 { -1.1, 2 } };
52 final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma);
53
54 final double[] m = d.getMeans();
55 for (int i = 0; i < m.length; i++) {
56 assertEquals(mu[i], m[i], 0);
57 }
58 }
59
60
61
62
63 @Test
64 void testGetCovarianceMatrix() {
65 final double[] mu = { -1.5, 2 };
66 final double[][] sigma = { { 2, -1.1 },
67 { -1.1, 2 } };
68 final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma);
69
70 final RealMatrix s = d.getCovariances();
71 final int dim = d.getDimension();
72 for (int i = 0; i < dim; i++) {
73 for (int j = 0; j < dim; j++) {
74 assertEquals(sigma[i][j], s.getEntry(i, j), 0);
75 }
76 }
77 }
78
79
80
81
82 @Test
83 void testSampling() {
84 final double[] mu = { -1.5, 2 };
85 final double[][] sigma = { { 2, -1.1 },
86 { -1.1, 2 } };
87 final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma);
88 d.reseedRandomGenerator(50);
89
90 final int n = 500000;
91
92 final double[][] samples = d.sample(n);
93 final int dim = d.getDimension();
94 final double[] sampleMeans = new double[dim];
95
96 for (int i = 0; i < samples.length; i++) {
97 for (int j = 0; j < dim; j++) {
98 sampleMeans[j] += samples[i][j];
99 }
100 }
101
102 final double sampledValueTolerance = 1e-2;
103 for (int j = 0; j < dim; j++) {
104 sampleMeans[j] /= samples.length;
105 assertEquals(mu[j], sampleMeans[j], sampledValueTolerance);
106 }
107
108
109 final RealMatrix sampleSigma = UnitTestUtils.covarianceMatrix(new Array2DRowRealMatrix(samples));
110 for (int i = 0; i < dim; i++) {
111 for (int j = 0; j < dim; j++) {
112 assertEquals(sigma[i][j], sampleSigma.getEntry(i, j), sampledValueTolerance);
113 }
114 }
115 }
116
117
118
119
120 @Test
121 void testDensities() {
122 final double[] mu = { -1.5, 2 };
123 final double[][] sigma = { { 2, -1.1 },
124 { -1.1, 2 } };
125 final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma);
126
127 final double[][] testValues = { { -1.5, 2 },
128 { 4, 4 },
129 { 1.5, -2 },
130 { 0, 0 } };
131 final double[] densities = new double[testValues.length];
132 for (int i = 0; i < densities.length; i++) {
133 densities[i] = d.density(testValues[i]);
134 }
135
136
137 final double[] correctDensities = { 0.09528357207691344,
138 5.80932710124009e-09,
139 0.001387448895173267,
140 0.03309922090210541 };
141
142 for (int i = 0; i < testValues.length; i++) {
143 assertEquals(correctDensities[i], densities[i], 1e-16);
144 }
145 }
146
147
148
149
150 @Test
151 void testUnivariateDistribution() {
152 final double[] mu = { -1.5 };
153 final double[][] sigma = { { 1 } };
154
155 final MultivariateNormalDistribution multi = new MultivariateNormalDistribution(mu, sigma);
156
157 final NormalDistribution uni = new NormalDistribution(mu[0], sigma[0][0]);
158 final Random rng = new Random();
159 final int numCases = 100;
160 final double tol = Math.ulp(1d);
161 for (int i = 0; i < numCases; i++) {
162 final double v = rng.nextDouble() * 10 - 5;
163 assertEquals(uni.density(v), multi.density(new double[] { v }), tol);
164 }
165 }
166
167
168
169
170 @Test
171 void testGetSingularMatrixTolerance() {
172 final double[] mu = { -1.5 };
173 final double[][] sigma = { { 1 } };
174
175 final double tolerance1 = 1e-2;
176 final MultivariateNormalDistribution mvd1 = new MultivariateNormalDistribution(mu, sigma, tolerance1);
177 assertEquals(tolerance1, mvd1.getSingularMatrixCheckTolerance(), Precision.EPSILON);
178
179 final double tolerance2 = 1e-3;
180 final MultivariateNormalDistribution mvd2 = new MultivariateNormalDistribution(mu, sigma, tolerance2);
181 assertEquals(tolerance2, mvd2.getSingularMatrixCheckTolerance(), Precision.EPSILON);
182 }
183
184 @Test
185 void testNotPositiveDefinite() {
186 try {
187 new MultivariateNormalDistribution(new Well19937c(0x543l), new double[2],
188 new double[][] { { -1.0, 0.0 }, { 0.0, -2.0 } });
189 fail("an exception should have been thrown");
190 } catch (MathIllegalArgumentException miae) {
191 assertEquals(LocalizedCoreFormats.NOT_POSITIVE_DEFINITE_MATRIX, miae.getSpecifier());
192 }
193 }
194
195 @Test
196 void testStd() {
197 MultivariateNormalDistribution d = new MultivariateNormalDistribution(new Well19937c(0x543l), new double[2],
198 new double[][] { { 4.0, 0.0 }, { 0.0, 9.0 } });
199 double[] s = d.getStandardDeviations();
200 assertEquals(2, s.length);
201 assertEquals(2.0, s[0], 1.0e-15);
202 assertEquals(3.0, s[1], 1.0e-15);
203 }
204
205 @Test
206 void testWrongDensity() {
207 try {
208 MultivariateNormalDistribution d = new MultivariateNormalDistribution(new Well19937c(0x543l), new double[2],
209 new double[][] { { 4.0, 0.0 }, { 0.0, 4.0 } });
210 d.density(new double[3]);
211 fail("an exception should have been thrown");
212 } catch (MathIllegalArgumentException miae) {
213 assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
214 }
215 }
216
217 @Test
218 void testWrongArguments() {
219 checkWrongArguments(new double[3], new double[6][6]);
220 checkWrongArguments(new double[3], new double[3][6]);
221 }
222
223 private void checkWrongArguments(double[] means, double[][] covariances) {
224 try {
225 new MultivariateNormalDistribution(new Well19937c(0x543l), means, covariances);
226 fail("an exception should have been thrown");
227 } catch (MathIllegalArgumentException miae) {
228 assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
229 }
230 }
231 }