View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.distribution.multivariate;
23  
24  import java.util.ArrayList;
25  import java.util.List;
26  
27  import org.hipparchus.distribution.MultivariateRealDistribution;
28  import org.hipparchus.exception.LocalizedCoreFormats;
29  import org.hipparchus.exception.MathIllegalArgumentException;
30  import org.hipparchus.exception.MathRuntimeException;
31  import org.hipparchus.random.RandomGenerator;
32  import org.hipparchus.random.Well19937c;
33  import org.hipparchus.util.Pair;
34  
35  /**
36   * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
37   * mixture model</a> distributions.
38   *
39   * @param <T> Type of the mixture components.
40   */
41  public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
42      extends AbstractMultivariateRealDistribution {
43      /** Normalized weight of each mixture component. */
44      private final double[] weight;
45      /** Mixture components. */
46      private final List<T> distribution;
47  
48      /**
49       * Creates a mixture model from a list of distributions and their
50       * associated weights.
51       * <p>
52       * <b>Note:</b> this constructor will implicitly create an instance of
53       * {@link Well19937c} as random generator to be used for sampling only (see
54       * {@link #sample()} and {@link #sample(int)}). In case no sampling is
55       * needed for the created distribution, it is advised to pass {@code null}
56       * as random generator via the appropriate constructors to avoid the
57       * additional initialisation overhead.
58       *
59       * @param components List of (weight, distribution) pairs from which to sample.
60       */
61      public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
62          this(new Well19937c(), components);
63      }
64  
65      /**
66       * Creates a mixture model from a list of distributions and their
67       * associated weights.
68       *
69       * @param rng Random number generator.
70       * @param components Distributions from which to sample.
71       * @throws MathIllegalArgumentException if any of the weights is negative.
72       * @throws MathIllegalArgumentException if not all components have the same
73       * number of variables.
74       */
75      public MixtureMultivariateRealDistribution(RandomGenerator rng,
76                                                 List<Pair<Double, T>> components) {
77          super(rng, components.get(0).getSecond().getDimension());
78  
79          final int numComp = components.size();
80          final int dim = getDimension();
81          double weightSum = 0;
82          for (final Pair<Double, T> comp : components) {
83              if (comp.getSecond().getDimension() != dim) {
84                  throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
85                          comp.getSecond().getDimension(), dim);
86              }
87              if (comp.getFirst() < 0) {
88                  throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, comp.getFirst(), 0);
89              }
90              weightSum += comp.getFirst();
91          }
92  
93          // Check for overflow.
94          if (Double.isInfinite(weightSum)) {
95              throw new MathRuntimeException(LocalizedCoreFormats.OVERFLOW);
96          }
97  
98          // Store each distribution and its normalized weight.
99          distribution = new ArrayList<>();
100         weight = new double[numComp];
101         for (int i = 0; i < numComp; i++) {
102             final Pair<Double, T> comp = components.get(i);
103             weight[i] = comp.getFirst() / weightSum;
104             distribution.add(comp.getSecond());
105         }
106     }
107 
108     /** {@inheritDoc} */
109     @Override
110     public double density(final double[] values) {
111         double p = 0;
112         for (int i = 0; i < weight.length; i++) {
113             p += weight[i] * distribution.get(i).density(values);
114         }
115         return p;
116     }
117 
118     /** {@inheritDoc} */
119     @Override
120     public double[] sample() {
121         // Sampled values.
122         double[] vals = null;
123 
124         // Determine which component to sample from.
125         final double randomValue = random.nextDouble();
126         double sum = 0;
127 
128         for (int i = 0; i < weight.length; i++) {
129             sum += weight[i];
130             if (randomValue <= sum) {
131                 // pick model i
132                 vals = distribution.get(i).sample();
133                 break;
134             }
135         }
136 
137         if (vals == null) {
138             // This should never happen, but it ensures we won't return a null in
139             // case the loop above has some floating point inequality problem on
140             // the final iteration.
141             vals = distribution.get(weight.length - 1).sample();
142         }
143 
144         return vals;
145     }
146 
147     /** {@inheritDoc} */
148     @Override
149     public void reseedRandomGenerator(long seed) {
150         // Seed needs to be propagated to underlying components
151         // in order to maintain consistency between runs.
152         super.reseedRandomGenerator(seed);
153 
154         for (int i = 0; i < distribution.size(); i++) {
155             // Make each component's seed different in order to avoid
156             // using the same sequence of random numbers.
157             distribution.get(i).reseedRandomGenerator(i + 1 + seed);
158         }
159     }
160 
161     /**
162      * Gets the distributions that make up the mixture model.
163      *
164      * @return the component distributions and associated weights.
165      */
166     public List<Pair<Double, T>> getComponents() {
167         final List<Pair<Double, T>> list = new ArrayList<>(weight.length);
168 
169         for (int i = 0; i < weight.length; i++) {
170             list.add(new Pair<>(weight[i], distribution.get(i)));
171         }
172 
173         return list;
174     }
175 }