View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.distribution.multivariate;
23  
24  import java.util.ArrayList;
25  import java.util.List;
26  
27  import org.hipparchus.distribution.MultivariateRealDistribution;
28  import org.hipparchus.exception.LocalizedCoreFormats;
29  import org.hipparchus.exception.MathIllegalArgumentException;
30  import org.hipparchus.exception.MathRuntimeException;
31  import org.hipparchus.random.RandomGenerator;
32  import org.hipparchus.random.Well19937c;
33  import org.hipparchus.util.Pair;
34  
35  /**
36   * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
37   * mixture model</a> distributions.
38   *
39   * @param <T> Type of the mixture components.
40   */
41  public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
42      extends AbstractMultivariateRealDistribution {
43      /** Normalized weight of each mixture component. */
44      private final double[] weight;
45      /** Mixture components. */
46      private final List<T> distribution;
47  
48      /**
49       * Creates a mixture model from a list of distributions and their
50       * associated weights.
51       * <p>
52       * <b>Note:</b> this constructor will implicitly create an instance of
53       * {@link Well19937c} as random generator to be used for sampling only (see
54       * {@link #sample()} and {@link #sample(int)}). In case no sampling is
55       * needed for the created distribution, it is advised to pass {@code null}
56       * as random generator via the appropriate constructors to avoid the
57       * additional initialisation overhead.
58       *
59       * @param components List of (weight, distribution) pairs from which to sample.
60       */
61      public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
62          this(new Well19937c(), components);
63      }
64  
65      /**
66       * Creates a mixture model from a list of distributions and their
67       * associated weights.
68       *
69       * @param rng Random number generator.
70       * @param components Distributions from which to sample.
71       * @throws MathIllegalArgumentException if any of the weights is negative.
72       * @throws MathIllegalArgumentException if not all components have the same
73       * number of variables.
74       */
75      public MixtureMultivariateRealDistribution(RandomGenerator rng,
76                                                 List<Pair<Double, T>> components) {
77          super(rng, components.get(0).getSecond().getDimension());
78  
79          final int numComp = components.size();
80          final int dim = getDimension();
81          double weightSum = 0;
82          for (int i = 0; i < numComp; i++) {
83              final Pair<Double, T> comp = components.get(i);
84              if (comp.getSecond().getDimension() != dim) {
85                  throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
86                                                         comp.getSecond().getDimension(), dim);
87              }
88              if (comp.getFirst() < 0) {
89                  throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, comp.getFirst(), 0);
90              }
91              weightSum += comp.getFirst();
92          }
93  
94          // Check for overflow.
95          if (Double.isInfinite(weightSum)) {
96              throw new MathRuntimeException(LocalizedCoreFormats.OVERFLOW);
97          }
98  
99          // Store each distribution and its normalized weight.
100         distribution = new ArrayList<>();
101         weight = new double[numComp];
102         for (int i = 0; i < numComp; i++) {
103             final Pair<Double, T> comp = components.get(i);
104             weight[i] = comp.getFirst() / weightSum;
105             distribution.add(comp.getSecond());
106         }
107     }
108 
109     /** {@inheritDoc} */
110     @Override
111     public double density(final double[] values) {
112         double p = 0;
113         for (int i = 0; i < weight.length; i++) {
114             p += weight[i] * distribution.get(i).density(values);
115         }
116         return p;
117     }
118 
119     /** {@inheritDoc} */
120     @Override
121     public double[] sample() {
122         // Sampled values.
123         double[] vals = null;
124 
125         // Determine which component to sample from.
126         final double randomValue = random.nextDouble();
127         double sum = 0;
128 
129         for (int i = 0; i < weight.length; i++) {
130             sum += weight[i];
131             if (randomValue <= sum) {
132                 // pick model i
133                 vals = distribution.get(i).sample();
134                 break;
135             }
136         }
137 
138         if (vals == null) {
139             // This should never happen, but it ensures we won't return a null in
140             // case the loop above has some floating point inequality problem on
141             // the final iteration.
142             vals = distribution.get(weight.length - 1).sample();
143         }
144 
145         return vals;
146     }
147 
148     /** {@inheritDoc} */
149     @Override
150     public void reseedRandomGenerator(long seed) {
151         // Seed needs to be propagated to underlying components
152         // in order to maintain consistency between runs.
153         super.reseedRandomGenerator(seed);
154 
155         for (int i = 0; i < distribution.size(); i++) {
156             // Make each component's seed different in order to avoid
157             // using the same sequence of random numbers.
158             distribution.get(i).reseedRandomGenerator(i + 1 + seed);
159         }
160     }
161 
162     /**
163      * Gets the distributions that make up the mixture model.
164      *
165      * @return the component distributions and associated weights.
166      */
167     public List<Pair<Double, T>> getComponents() {
168         final List<Pair<Double, T>> list = new ArrayList<>(weight.length);
169 
170         for (int i = 0; i < weight.length; i++) {
171             list.add(new Pair<>(weight[i], distribution.get(i)));
172         }
173 
174         return list;
175     }
176 }