1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.clustering;
24
25 import org.hipparchus.clustering.distance.EuclideanDistance;
26 import org.hipparchus.exception.MathIllegalArgumentException;
27 import org.hipparchus.random.JDKRandomGenerator;
28 import org.hipparchus.random.RandomGenerator;
29 import org.junit.jupiter.api.BeforeEach;
30 import org.junit.jupiter.api.Test;
31
32 import java.util.ArrayList;
33 import java.util.Arrays;
34 import java.util.Collection;
35 import java.util.List;
36
37 import static org.junit.jupiter.api.Assertions.assertEquals;
38 import static org.junit.jupiter.api.Assertions.assertThrows;
39 import static org.junit.jupiter.api.Assertions.assertTrue;
40
41 class KMeansPlusPlusClustererTest {
42
43 private RandomGenerator random;
44
45 @BeforeEach
46 void setUp() {
47 random = new JDKRandomGenerator();
48 random.setSeed(1746432956321l);
49 }
50
51
52
53
54
55
56 @Test
57 void testPerformClusterAnalysisDegenerate() {
58 KMeansPlusPlusClusterer<DoublePoint> transformer =
59 new KMeansPlusPlusClusterer<DoublePoint>(1, 1);
60
61 DoublePoint[] points = new DoublePoint[] {
62 new DoublePoint(new int[] { 1959, 325100 }),
63 new DoublePoint(new int[] { 1960, 373200 }), };
64 List<? extends Cluster<DoublePoint>> clusters = transformer.cluster(Arrays.asList(points));
65 assertEquals(1, clusters.size());
66 assertEquals(2, (clusters.get(0).getPoints().size()));
67 DoublePoint pt1 = new DoublePoint(new int[] { 1959, 325100 });
68 DoublePoint pt2 = new DoublePoint(new int[] { 1960, 373200 });
69 assertTrue(clusters.get(0).getPoints().contains(pt1));
70 assertTrue(clusters.get(0).getPoints().contains(pt2));
71
72 }
73
74 @Test
75 void testCertainSpace() {
76 KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = {
77 KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE,
78 KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER,
79 KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT
80 };
81 for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) {
82 int numberOfVariables = 27;
83
84 int position1 = 1;
85 int position2 = position1 + numberOfVariables;
86 int position3 = position2 + numberOfVariables;
87 int position4 = position3 + numberOfVariables;
88
89 int multiplier = 1000000;
90
91 DoublePoint[] breakingPoints = new DoublePoint[numberOfVariables];
92
93 for (int i = 0; i < numberOfVariables; i++) {
94 int[] points = { position1, position2, position3, position4 };
95
96 for (int j = 0; j < points.length; j++) {
97 points[j] *= multiplier;
98 }
99 DoublePoint DoublePoint = new DoublePoint(points);
100 breakingPoints[i] = DoublePoint;
101 position1 += numberOfVariables;
102 position2 += numberOfVariables;
103 position3 += numberOfVariables;
104 position4 += numberOfVariables;
105 }
106
107 for (int n = 2; n < 27; ++n) {
108 KMeansPlusPlusClusterer<DoublePoint> transformer =
109 new KMeansPlusPlusClusterer<DoublePoint>(n, 100, new EuclideanDistance(), random, strategy);
110
111 List<? extends Cluster<DoublePoint>> clusters =
112 transformer.cluster(Arrays.asList(breakingPoints));
113
114 assertEquals(n, clusters.size());
115 int sum = 0;
116 for (Cluster<DoublePoint> cluster : clusters) {
117 sum += cluster.getPoints().size();
118 }
119 assertEquals(numberOfVariables, sum);
120 }
121 }
122
123 }
124
125
126
127
128
129 private class CloseDistance extends EuclideanDistance {
130 private static final long serialVersionUID = 1L;
131
132 @Override
133 public double compute(double[] a, double[] b) {
134 return super.compute(a, b) * 0.001;
135 }
136 }
137
138
139
140
141 @Test
142 void testSmallDistances() {
143
144
145 int[] repeatedArray = { 0 };
146 int[] uniqueArray = { 1 };
147 DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
148 DoublePoint uniquePoint = new DoublePoint(uniqueArray);
149
150 Collection<DoublePoint> points = new ArrayList<DoublePoint>();
151 final int NUM_REPEATED_POINTS = 10 * 1000;
152 for (int i = 0; i < NUM_REPEATED_POINTS; ++i) {
153 points.add(repeatedPoint);
154 }
155 points.add(uniquePoint);
156
157
158
159 final long RANDOM_SEED = 0;
160 final int NUM_CLUSTERS = 2;
161 final int NUM_ITERATIONS = 0;
162 random.setSeed(RANDOM_SEED);
163
164 KMeansPlusPlusClusterer<DoublePoint> clusterer =
165 new KMeansPlusPlusClusterer<DoublePoint>(NUM_CLUSTERS, NUM_ITERATIONS,
166 new CloseDistance(), random);
167 List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
168
169
170 boolean uniquePointIsCenter = false;
171 for (CentroidCluster<DoublePoint> cluster : clusters) {
172 if (cluster.getCenter().equals(uniquePoint)) {
173 uniquePointIsCenter = true;
174 }
175 }
176 assertTrue(uniquePointIsCenter);
177 }
178
179
180
181
182 @Test
183 void testPerformClusterAnalysisToManyClusters() {
184 assertThrows(MathIllegalArgumentException.class, () -> {
185 KMeansPlusPlusClusterer<DoublePoint> transformer =
186 new KMeansPlusPlusClusterer<DoublePoint>(3, 1, new EuclideanDistance(), random);
187
188 DoublePoint[] points = new DoublePoint[]{
189 new DoublePoint(new int[]{
190 1959, 325100
191 }), new DoublePoint(new int[]{
192 1960, 373200
193 })
194 };
195
196 transformer.cluster(Arrays.asList(points));
197
198 });
199
200 }
201
202 }