View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.samples;
23  
24  import java.awt.Color;
25  import java.awt.Dimension;
26  import java.awt.Graphics;
27  import java.awt.Graphics2D;
28  import java.awt.GridBagConstraints;
29  import java.awt.GridBagLayout;
30  import java.awt.Insets;
31  import java.awt.RenderingHints;
32  import java.awt.Shape;
33  import java.awt.geom.Ellipse2D;
34  import java.util.ArrayList;
35  import java.util.Arrays;
36  import java.util.Collections;
37  import java.util.List;
38  
39  import javax.swing.JComponent;
40  import javax.swing.JLabel;
41  
42  import org.hipparchus.clustering.CentroidCluster;
43  import org.hipparchus.clustering.Cluster;
44  import org.hipparchus.clustering.Clusterable;
45  import org.hipparchus.clustering.Clusterer;
46  import org.hipparchus.clustering.DBSCANClusterer;
47  import org.hipparchus.clustering.DoublePoint;
48  import org.hipparchus.clustering.FuzzyKMeansClusterer;
49  import org.hipparchus.clustering.KMeansPlusPlusClusterer;
50  import org.hipparchus.geometry.euclidean.twod.Vector2D;
51  import org.hipparchus.random.RandomAdaptor;
52  import org.hipparchus.random.RandomDataGenerator;
53  import org.hipparchus.random.RandomGenerator;
54  import org.hipparchus.random.SobolSequenceGenerator;
55  import org.hipparchus.random.Well19937c;
56  import org.hipparchus.samples.ExampleUtils.ExampleFrame;
57  import org.hipparchus.util.FastMath;
58  import org.hipparchus.util.Pair;
59  import org.hipparchus.util.SinCos;
60  
61  /**
62   * Plots clustering results for various algorithms and datasets.
63   * Based on
64   * <a href="http://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html">scikit learn</a>.
65   */
66  //CHECKSTYLE: stop HideUtilityClassConstructor
67  public class ClusterAlgorithmComparison {
68  
69      /** Empty constructor.
70       * <p>
71       * This constructor is not strictly necessary, but it prevents spurious
72       * javadoc warnings with JDK 18 and later.
73       * </p>
74       * @since 3.0
75       */
76      public ClusterAlgorithmComparison() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
77          // nothing to do
78      }
79  
80      /** Make circles patterns.
81       * @param samples number of points
82       * @param shuffle if true, shuffle points
83       * @param noise noise to add to points position
84       * @param factor reduction factor from outer to inner circle
85       * @param random generator to use
86       * @return circle patterns
87       */
88      public static List<Vector2D> makeCircles(int samples, boolean shuffle, double noise, double factor, final RandomGenerator random) {
89          if (factor < 0 || factor > 1) {
90              throw new IllegalArgumentException();
91          }
92  
93          List<Vector2D> points = new ArrayList<Vector2D>();
94          double range = 2.0 * FastMath.PI;
95          double step = range / (samples / 2.0 + 1);
96          for (double angle = 0; angle < range; angle += step) {
97              Vector2D outerCircle = buildVector(angle);
98              Vector2D innerCircle = outerCircle.scalarMultiply(factor);
99  
100             points.add(outerCircle.add(generateNoiseVector(random, noise)));
101             points.add(innerCircle.add(generateNoiseVector(random, noise)));
102         }
103 
104         if (shuffle) {
105             Collections.shuffle(points, new RandomAdaptor(random));
106         }
107 
108         return points;
109     }
110 
111     /** Make Moons patterns.
112      * @param samples number of points
113      * @param shuffle if true, shuffle points
114      * @param noise noise to add to points position
115      * @param random generator to use
116      * @return Moons patterns
117      */
118     public static List<Vector2D> makeMoons(int samples, boolean shuffle, double noise, RandomGenerator random) {
119 
120         int nSamplesOut = samples / 2;
121         int nSamplesIn = samples - nSamplesOut;
122 
123         List<Vector2D> points = new ArrayList<Vector2D>();
124         double range = FastMath.PI;
125         double step = range / (nSamplesOut / 2.0);
126         for (double angle = 0; angle < range; angle += step) {
127             Vector2D outerCircle = buildVector(angle);
128             points.add(outerCircle.add(generateNoiseVector(random, noise)));
129         }
130 
131         step = range / (nSamplesIn / 2.0);
132         for (double angle = 0; angle < range; angle += step) {
133             final SinCos sc = FastMath.sinCos(angle);
134             Vector2D innerCircle = new Vector2D(1 - sc.cos(), 1 - sc.sin() - 0.5);
135             points.add(innerCircle.add(generateNoiseVector(random, noise)));
136         }
137 
138         if (shuffle) {
139             Collections.shuffle(points, new RandomAdaptor(random));
140         }
141 
142         return points;
143     }
144 
145     /** Make blobs patterns.
146      * @param samples number of points
147      * @param centers number of centers
148      * @param clusterStd standard deviation of cluster
149      * @param min range min value
150      * @param max range max value
151      * @param shuffle if true, shuffle points
152      * @param random generator to use
153      * @return blobs patterns
154      */
155     public static List<Vector2D> makeBlobs(int samples, int centers, double clusterStd,
156                                            double min, double max, boolean shuffle, RandomGenerator random) {
157 
158         final RandomDataGenerator randomDataGenerator = RandomDataGenerator.of(random);
159         //NormalDistribution dist = new NormalDistribution(random, 0.0, clusterStd);
160 
161         double range = max - min;
162         Vector2D[] centerPoints = new Vector2D[centers];
163         for (int i = 0; i < centers; i++) {
164             double x = random.nextDouble() * range + min;
165             double y = random.nextDouble() * range + min;
166             centerPoints[i] = new Vector2D(x, y);
167         }
168 
169         int[] nSamplesPerCenter = new int[centers];
170         int count = samples / centers;
171         Arrays.fill(nSamplesPerCenter, count);
172 
173         for (int i = 0; i < samples % centers; i++) {
174             nSamplesPerCenter[i]++;
175         }
176 
177         List<Vector2D> points = new ArrayList<Vector2D>();
178         for (int i = 0; i < centers; i++) {
179             for (int j = 0; j < nSamplesPerCenter[i]; j++) {
180                 Vector2D point = new Vector2D(randomDataGenerator.nextNormal(0, clusterStd),
181                                               randomDataGenerator.nextNormal(0, clusterStd));
182                 points.add(point.add(centerPoints[i]));
183             }
184         }
185 
186         if (shuffle) {
187             Collections.shuffle(points, new RandomAdaptor(random));
188         }
189 
190         return points;
191     }
192 
193     /** Make Sobol patterns.
194      * @param samples number of points
195      * @return Moons patterns
196      */
197     public static List<Vector2D> makeSobol(int samples) {
198         SobolSequenceGenerator generator = new SobolSequenceGenerator(2);
199         generator.skipTo(999999);
200         List<Vector2D> points = new ArrayList<Vector2D>();
201         for (double i = 0; i < samples; i++) {
202             double[] vector = generator.nextVector();
203             vector[0] = vector[0] * 2 - 1;
204             vector[1] = vector[1] * 2 - 1;
205             Vector2D point = new Vector2D(vector);
206             points.add(point);
207         }
208 
209         return points;
210     }
211 
212     /** Generate a random vector.
213      * @param randomGenerator random generator to use
214      * @param noise noise level
215      * @return random vector
216      */
217     public static Vector2D generateNoiseVector(RandomGenerator randomGenerator, double noise) {
218         final RandomDataGenerator randomDataGenerator = RandomDataGenerator.of(randomGenerator);
219         return new Vector2D(randomDataGenerator.nextNormal(0, noise), randomDataGenerator.nextNormal(0, noise));
220     }
221 
222     /** Normolize points in a rectangular area
223      * @param input input points
224      * @param minX range min value in X
225      * @param maxX range max value in X
226      * @param minY range min value in Y
227      * @param maxY range max value in Y
228      * @return normalized points
229      */
230     public static List<DoublePoint> normalize(final List<Vector2D> input, double minX, double maxX, double minY, double maxY) {
231         double rangeX = maxX - minX;
232         double rangeY = maxY - minY;
233         List<DoublePoint> points = new ArrayList<DoublePoint>();
234         for (Vector2D p : input) {
235             double[] arr = p.toArray();
236             arr[0] = (arr[0] - minX) / rangeX * 2 - 1;
237             arr[1] = (arr[1] - minY) / rangeY * 2 - 1;
238             points.add(new DoublePoint(arr));
239         }
240         return points;
241     }
242 
243     /**
244      * Build the 2D vector corresponding to the given angle.
245      * @param alpha angle
246      * @return the corresponding 2D vector
247      */
248     private static Vector2D buildVector(final double alpha) {
249         final SinCos sc = FastMath.sinCos(alpha);
250         return new Vector2D(sc.cos(), sc.sin());
251     }
252 
253     /** Display frame. */
254     @SuppressWarnings("serial")
255     public static class Display extends ExampleFrame {
256 
257         /** Simple consructor. */
258         public Display() {
259             setTitle("Hipparchus: Cluster algorithm comparison");
260             setSize(800, 800);
261 
262             setLayout(new GridBagLayout());
263 
264             int nSamples = 1500;
265 
266             RandomGenerator rng = new Well19937c(0);
267             List<List<DoublePoint>> datasets = new ArrayList<List<DoublePoint>>();
268 
269             datasets.add(normalize(makeCircles(nSamples, true, 0.04, 0.5, rng), -1, 1, -1, 1));
270             datasets.add(normalize(makeMoons(nSamples, true, 0.04, rng), -1, 2, -1, 1));
271             datasets.add(normalize(makeBlobs(nSamples, 3, 1.0, -10, 10, true, rng), -12, 12, -12, 12));
272             datasets.add(normalize(makeSobol(nSamples), -1, 1, -1, 1));
273 
274             List<Pair<String, Clusterer<DoublePoint>>> algorithms = new ArrayList<Pair<String, Clusterer<DoublePoint>>>();
275 
276             algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=2)", new KMeansPlusPlusClusterer<DoublePoint>(2)));
277             algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=3)", new KMeansPlusPlusClusterer<DoublePoint>(3)));
278             algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=2)", new FuzzyKMeansClusterer<DoublePoint>(3, 2)));
279             algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=10)", new FuzzyKMeansClusterer<DoublePoint>(3, 10)));
280             algorithms.add(new Pair<String, Clusterer<DoublePoint>>("DBSCAN\n(eps=.1, min=3)", new DBSCANClusterer<DoublePoint>(0.1, 3)));
281 
282             GridBagConstraints c = new GridBagConstraints();
283             c.fill = GridBagConstraints.VERTICAL;
284             c.gridx = 0;
285             c.gridy = 0;
286             c.insets = new Insets(2, 2, 2, 2);
287 
288             for (Pair<String, Clusterer<DoublePoint>> pair : algorithms) {
289                 JLabel text = new JLabel("<html><body>" + pair.getFirst().replace("\n", "<br>"));
290                 add(text, c);
291                 c.gridx++;
292             }
293             c.gridy++;
294 
295             for (List<DoublePoint> dataset : datasets) {
296                 c.gridx = 0;
297                 for (Pair<String, Clusterer<DoublePoint>> pair : algorithms) {
298                     long start = System.currentTimeMillis();
299                     List<? extends Cluster<DoublePoint>> clusters = pair.getSecond().cluster(dataset);
300                     long end = System.currentTimeMillis();
301                     add(new ClusterPlot(clusters, end - start), c);
302                     c.gridx++;
303                 }
304                 c.gridy++;
305             }
306         }
307 
308     }
309 
310     /** Plot component. */
311     @SuppressWarnings("serial")
312     public static class ClusterPlot extends JComponent {
313 
314         /** Padding. */
315         private static final double PAD = 10;
316 
317         /** Clusters. */
318         private List<? extends Cluster<DoublePoint>> clusters;
319 
320         /** Duration of the computation. */
321         private long duration;
322 
323         /** Simple constructor.
324          * @param clusters clusters to plot
325          * @param duration duration of the computation
326          */
327         public ClusterPlot(final List<? extends Cluster<DoublePoint>> clusters, long duration) {
328             this.clusters = clusters;
329             this.duration = duration;
330         }
331 
332         @Override
333         protected void paintComponent(Graphics g) {
334             super.paintComponent(g);
335             Graphics2D g2 = (Graphics2D)g;
336             g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING,
337                                 RenderingHints.VALUE_ANTIALIAS_ON);
338 
339             int w = getWidth();
340             int h = getHeight();
341 
342             g2.clearRect(0, 0, w, h);
343 
344             g2.setPaint(Color.black);
345             g2.drawRect(0, 0, w - 1, h - 1);
346 
347             int index = 0;
348             Color[] colors = new Color[] { Color.red, Color.blue, Color.green.darker() };
349             for (Cluster<DoublePoint> cluster : clusters) {
350                 g2.setPaint(colors[index++]);
351                 for (DoublePoint point : cluster.getPoints()) {
352                     Clusterable p = transform(point, w, h);
353                     double[] arr = p.getPoint();
354                     g2.fill(new Ellipse2D.Double(arr[0] - 1, arr[1] - 1, 3, 3));
355                 }
356 
357                 if (cluster instanceof CentroidCluster) {
358                     Clusterable p = transform(((CentroidCluster<?>) cluster).getCenter(), w, h);
359                     double[] arr = p.getPoint();
360                     Shape s = new Ellipse2D.Double(arr[0] - 4, arr[1] - 4, 8, 8);
361                     g2.fill(s);
362                     g2.setPaint(Color.black);
363                     g2.draw(s);
364                 }
365             }
366 
367             g2.setPaint(Color.black);
368             g2.drawString(String.format("%.2f s", duration / 1e3), w - 40, h - 5);
369         }
370 
371         @Override
372         public Dimension getPreferredSize() {
373             return new Dimension(150, 150);
374         }
375 
376         private Clusterable transform(Clusterable point, int width, int height) {
377             double[] arr = point.getPoint();
378             return new DoublePoint(new double[] { PAD + (arr[0] + 1) / 2.0 * (width - 2 * PAD),
379                                                   height - PAD - (arr[1] + 1) / 2.0 * (height - 2 * PAD) });
380         }
381     }
382 
383     /** Example entry point.
384      * @param args arguments (not used)
385      */
386     public static void main(String[] args) {
387         ExampleUtils.showExampleFrame(new Display());
388     }
389 
390 }