View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.distribution.multivariate;
23  
24  import org.hipparchus.exception.LocalizedCoreFormats;
25  import org.hipparchus.exception.MathIllegalArgumentException;
26  import org.hipparchus.linear.Array2DRowRealMatrix;
27  import org.hipparchus.linear.EigenDecompositionSymmetric;
28  import org.hipparchus.linear.RealMatrix;
29  import org.hipparchus.random.RandomGenerator;
30  import org.hipparchus.random.Well19937c;
31  import org.hipparchus.util.FastMath;
32  import org.hipparchus.util.Precision;
33  
34  /**
35   * Implementation of the multivariate normal (Gaussian) distribution.
36   *
37   * @see <a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution">
38   * Multivariate normal distribution (Wikipedia)</a>
39   * @see <a href="http://mathworld.wolfram.com/MultivariateNormalDistribution.html">
40   * Multivariate normal distribution (MathWorld)</a>
41   */
42  public class MultivariateNormalDistribution
43      extends AbstractMultivariateRealDistribution {
44      /** Default singular matrix tolerance check value **/
45      private static final double DEFAULT_TOLERANCE = Precision.EPSILON;
46  
47      /** Vector of means. */
48      private final double[] means;
49      /** Covariance matrix. */
50      private final RealMatrix covarianceMatrix;
51      /** The matrix inverse of the covariance matrix. */
52      private final RealMatrix covarianceMatrixInverse;
53      /** The determinant of the covariance matrix. */
54      private final double covarianceMatrixDeterminant;
55      /** Matrix used in computation of samples. */
56      private final RealMatrix samplingMatrix;
57      /** Inverse singular check tolerance when testing if invertable **/
58      private final double singularMatrixCheckTolerance;
59  
60      /**
61       * Creates a multivariate normal distribution with the given mean vector and
62       * covariance matrix.<br>
63       * The number of dimensions is equal to the length of the mean vector
64       * and to the number of rows and columns of the covariance matrix.
65       * It is frequently written as "p" in formulae.
66       * <p>
67       * <b>Note:</b> this constructor will implicitly create an instance of
68       * {@link Well19937c} as random generator to be used for sampling only (see
69       * {@link #sample()} and {@link #sample(int)}). In case no sampling is
70       * needed for the created distribution, it is advised to pass {@code null}
71       * as random generator via the appropriate constructors to avoid the
72       * additional initialisation overhead.
73       *
74       * @param means Vector of means.
75       * @param covariances Covariance matrix.
76       * @throws MathIllegalArgumentException if the arrays length are
77       * inconsistent.
78       * @throws MathIllegalArgumentException if the eigenvalue decomposition cannot
79       * be performed on the provided covariance matrix.
80       * @throws MathIllegalArgumentException if any of the eigenvalues is
81       * negative.
82       */
83      public MultivariateNormalDistribution(final double[] means,
84                                            final double[][] covariances)
85          throws MathIllegalArgumentException {
86          this(means, covariances, DEFAULT_TOLERANCE);
87      }
88  
89      /**
90       * Creates a multivariate normal distribution with the given mean vector and
91       * covariance matrix.<br>
92       * The number of dimensions is equal to the length of the mean vector
93       * and to the number of rows and columns of the covariance matrix.
94       * It is frequently written as "p" in formulae.
95       * <p>
96       * <b>Note:</b> this constructor will implicitly create an instance of
97       * {@link Well19937c} as random generator to be used for sampling only (see
98       * {@link #sample()} and {@link #sample(int)}). In case no sampling is
99       * needed for the created distribution, it is advised to pass {@code null}
100      * as random generator via the appropriate constructors to avoid the
101      * additional initialisation overhead.
102      *
103      * @param means Vector of means.
104      * @param covariances Covariance matrix.
105      * @param singularMatrixCheckTolerance Tolerance used during the singular matrix check before inversion
106      * @throws MathIllegalArgumentException if the arrays length are
107      * inconsistent.
108      * @throws MathIllegalArgumentException if the eigenvalue decomposition cannot
109      * be performed on the provided covariance matrix.
110      * @throws MathIllegalArgumentException if any of the eigenvalues is
111      * negative.
112      */
113     public MultivariateNormalDistribution(final double[] means,
114                                           final double[][] covariances,
115                                           final double singularMatrixCheckTolerance)
116         throws MathIllegalArgumentException {
117         this(new Well19937c(), means, covariances, singularMatrixCheckTolerance);
118     }
119 
120 
121     /**
122      * Creates a multivariate normal distribution with the given mean vector and
123      * covariance matrix.
124      * <br>
125      * The number of dimensions is equal to the length of the mean vector
126      * and to the number of rows and columns of the covariance matrix.
127      * It is frequently written as "p" in formulae.
128      *
129      * @param rng Random Number Generator.
130      * @param means Vector of means.
131      * @param covariances Covariance matrix.
132      * @throws MathIllegalArgumentException if the arrays length are
133      * inconsistent.
134      * @throws MathIllegalArgumentException if the eigenvalue decomposition cannot
135      * be performed on the provided covariance matrix.
136      * @throws MathIllegalArgumentException if any of the eigenvalues is
137      * negative.
138      */
139     public MultivariateNormalDistribution(RandomGenerator rng,
140                                           final double[] means,
141                                           final double[][] covariances) {
142         this(rng, means, covariances, DEFAULT_TOLERANCE);
143     }
144 
145     /**
146      * Creates a multivariate normal distribution with the given mean vector and
147      * covariance matrix.
148      * <br>
149      * The number of dimensions is equal to the length of the mean vector
150      * and to the number of rows and columns of the covariance matrix.
151      * It is frequently written as "p" in formulae.
152      *
153      * @param rng Random Number Generator.
154      * @param means Vector of means.
155      * @param covariances Covariance matrix.
156      * @param singularMatrixCheckTolerance Tolerance used during the singular matrix check before inversion
157      * @throws MathIllegalArgumentException if the arrays length are
158      * inconsistent.
159      * @throws MathIllegalArgumentException if the eigenvalue decomposition cannot
160      * be performed on the provided covariance matrix.
161      * @throws MathIllegalArgumentException if any of the eigenvalues is
162      * negative.
163      */
164     public MultivariateNormalDistribution(RandomGenerator rng,
165                                           final double[] means,
166                                           final double[][] covariances,
167                                           final double singularMatrixCheckTolerance)
168             throws MathIllegalArgumentException {
169         super(rng, means.length);
170 
171         final int dim = means.length;
172 
173         if (covariances.length != dim) {
174             throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
175                                                    covariances.length, dim);
176         }
177 
178         for (int i = 0; i < dim; i++) {
179             if (dim != covariances[i].length) {
180                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
181                                                        covariances[i].length, dim);
182             }
183         }
184 
185         this.means = means.clone();
186         this.singularMatrixCheckTolerance = singularMatrixCheckTolerance;
187 
188         covarianceMatrix = new Array2DRowRealMatrix(covariances);
189 
190         // Covariance matrix eigen decomposition.
191         final EigenDecompositionSymmetric covMatDec =
192                         new EigenDecompositionSymmetric(covarianceMatrix, singularMatrixCheckTolerance, true);
193 
194         // Compute and store the inverse.
195         covarianceMatrixInverse = covMatDec.getSolver().getInverse();
196         // Compute and store the determinant.
197         covarianceMatrixDeterminant = covMatDec.getDeterminant();
198 
199         // Eigenvalues of the covariance matrix.
200         final double[] covMatEigenvalues = covMatDec.getEigenvalues();
201 
202         for (int i = 0; i < covMatEigenvalues.length; i++) {
203             if (covMatEigenvalues[i] < 0) {
204                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NOT_POSITIVE_DEFINITE_MATRIX);
205             }
206         }
207 
208         // Matrix where each column is an eigenvector of the covariance matrix.
209         final Array2DRowRealMatrix covMatEigenvectors = new Array2DRowRealMatrix(dim, dim);
210         for (int v = 0; v < dim; v++) {
211             final double[] evec = covMatDec.getEigenvector(v).toArray();
212             covMatEigenvectors.setColumn(v, evec);
213         }
214 
215         final RealMatrix tmpMatrix = covMatEigenvectors.transpose();
216 
217         // Scale each eigenvector by the square root of its eigenvalue.
218         for (int row = 0; row < dim; row++) {
219             final double factor = FastMath.sqrt(covMatEigenvalues[row]);
220             for (int col = 0; col < dim; col++) {
221                 tmpMatrix.multiplyEntry(row, col, factor);
222             }
223         }
224 
225         samplingMatrix = covMatEigenvectors.multiply(tmpMatrix);
226     }
227 
228     /**
229      * Gets the mean vector.
230      *
231      * @return the mean vector.
232      */
233     public double[] getMeans() {
234         return means.clone();
235     }
236 
237     /**
238      * Gets the covariance matrix.
239      *
240      * @return the covariance matrix.
241      */
242     public RealMatrix getCovariances() {
243         return covarianceMatrix.copy();
244     }
245 
246     /**
247      * Gets the current setting for the tolerance check used during singular checks before inversion
248      * @return tolerance
249      */
250     public double getSingularMatrixCheckTolerance() { return singularMatrixCheckTolerance; }
251 
252     /** {@inheritDoc} */
253     @Override
254     public double density(final double[] vals) throws MathIllegalArgumentException {
255         final int dim = getDimension();
256         if (vals.length != dim) {
257             throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
258                                                    vals.length, dim);
259         }
260 
261         return FastMath.pow(2 * FastMath.PI, -0.5 * dim) *
262             FastMath.pow(covarianceMatrixDeterminant, -0.5) *
263             getExponentTerm(vals);
264     }
265 
266     /**
267      * Gets the square root of each element on the diagonal of the covariance
268      * matrix.
269      *
270      * @return the standard deviations.
271      */
272     public double[] getStandardDeviations() {
273         final int dim = getDimension();
274         final double[] std = new double[dim];
275         final double[][] s = covarianceMatrix.getData();
276         for (int i = 0; i < dim; i++) {
277             std[i] = FastMath.sqrt(s[i][i]);
278         }
279         return std;
280     }
281 
282     /** {@inheritDoc} */
283     @Override
284     public double[] sample() {
285         final int dim = getDimension();
286         final double[] normalVals = new double[dim];
287 
288         for (int i = 0; i < dim; i++) {
289             normalVals[i] = random.nextGaussian();
290         }
291 
292         final double[] vals = samplingMatrix.operate(normalVals);
293 
294         for (int i = 0; i < dim; i++) {
295             vals[i] += means[i];
296         }
297 
298         return vals;
299     }
300 
301     /**
302      * Computes the term used in the exponent (see definition of the distribution).
303      *
304      * @param values Values at which to compute density.
305      * @return the multiplication factor of density calculations.
306      */
307     private double getExponentTerm(final double[] values) {
308         final double[] centered = new double[values.length];
309         for (int i = 0; i < centered.length; i++) {
310             centered[i] = values[i] - getMeans()[i];
311         }
312         final double[] preMultiplied = covarianceMatrixInverse.preMultiply(centered);
313         double sum = 0;
314         for (int i = 0; i < preMultiplied.length; i++) {
315             sum += preMultiplied[i] * centered[i];
316         }
317         return FastMath.exp(-0.5 * sum);
318     }
319 }