1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.distribution.multivariate;
23
24 import java.util.ArrayList;
25 import java.util.List;
26
27 import org.hipparchus.distribution.MultivariateRealDistribution;
28 import org.hipparchus.exception.LocalizedCoreFormats;
29 import org.hipparchus.exception.MathIllegalArgumentException;
30 import org.hipparchus.exception.MathRuntimeException;
31 import org.hipparchus.random.RandomGenerator;
32 import org.hipparchus.random.Well19937c;
33 import org.hipparchus.util.Pair;
34
35
36
37
38
39
40
41 public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
42 extends AbstractMultivariateRealDistribution {
43
44 private final double[] weight;
45
46 private final List<T> distribution;
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61 public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
62 this(new Well19937c(), components);
63 }
64
65
66
67
68
69
70
71
72
73
74
75 public MixtureMultivariateRealDistribution(RandomGenerator rng,
76 List<Pair<Double, T>> components) {
77 super(rng, components.get(0).getSecond().getDimension());
78
79 final int numComp = components.size();
80 final int dim = getDimension();
81 double weightSum = 0;
82 for (final Pair<Double, T> comp : components) {
83 if (comp.getSecond().getDimension() != dim) {
84 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
85 comp.getSecond().getDimension(), dim);
86 }
87 if (comp.getFirst() < 0) {
88 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, comp.getFirst(), 0);
89 }
90 weightSum += comp.getFirst();
91 }
92
93
94 if (Double.isInfinite(weightSum)) {
95 throw new MathRuntimeException(LocalizedCoreFormats.OVERFLOW);
96 }
97
98
99 distribution = new ArrayList<>();
100 weight = new double[numComp];
101 for (int i = 0; i < numComp; i++) {
102 final Pair<Double, T> comp = components.get(i);
103 weight[i] = comp.getFirst() / weightSum;
104 distribution.add(comp.getSecond());
105 }
106 }
107
108
109 @Override
110 public double density(final double[] values) {
111 double p = 0;
112 for (int i = 0; i < weight.length; i++) {
113 p += weight[i] * distribution.get(i).density(values);
114 }
115 return p;
116 }
117
118
119 @Override
120 public double[] sample() {
121
122 double[] vals = null;
123
124
125 final double randomValue = random.nextDouble();
126 double sum = 0;
127
128 for (int i = 0; i < weight.length; i++) {
129 sum += weight[i];
130 if (randomValue <= sum) {
131
132 vals = distribution.get(i).sample();
133 break;
134 }
135 }
136
137 if (vals == null) {
138
139
140
141 vals = distribution.get(weight.length - 1).sample();
142 }
143
144 return vals;
145 }
146
147
148 @Override
149 public void reseedRandomGenerator(long seed) {
150
151
152 super.reseedRandomGenerator(seed);
153
154 for (int i = 0; i < distribution.size(); i++) {
155
156
157 distribution.get(i).reseedRandomGenerator(i + 1 + seed);
158 }
159 }
160
161
162
163
164
165
166 public List<Pair<Double, T>> getComponents() {
167 final List<Pair<Double, T>> list = new ArrayList<>(weight.length);
168
169 for (int i = 0; i < weight.length; i++) {
170 list.add(new Pair<>(weight[i], distribution.get(i)));
171 }
172
173 return list;
174 }
175 }