View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus;
24  
25  import java.io.ByteArrayInputStream;
26  import java.io.ByteArrayOutputStream;
27  import java.io.IOException;
28  import java.io.ObjectInputStream;
29  import java.io.ObjectOutputStream;
30  import java.text.DecimalFormat;
31  import java.util.ArrayList;
32  import java.util.Arrays;
33  import java.util.HashMap;
34  import java.util.List;
35  import java.util.Map;
36  
37  import org.hipparchus.complex.Complex;
38  import org.hipparchus.complex.ComplexFormat;
39  import org.hipparchus.complex.FieldComplex;
40  import org.hipparchus.distribution.RealDistribution;
41  import org.hipparchus.distribution.continuous.ChiSquaredDistribution;
42  import org.hipparchus.linear.BlockRealMatrix;
43  import org.hipparchus.linear.FieldMatrix;
44  import org.hipparchus.linear.RealMatrix;
45  import org.hipparchus.linear.RealVector;
46  import org.hipparchus.util.Binary64;
47  import org.hipparchus.util.FastMath;
48  import org.hipparchus.util.Precision;
49  import org.junit.Assert;
50  
/**
 * Collection of static methods used in math unit tests.
 */
53  public class UnitTestUtils {
54      /**
55       * Collection of static methods used in math unit tests.
56       */
57      private UnitTestUtils() {
58          super();
59      }
60  
61      /**
62       * Verifies that expected and actual are within delta, or are both NaN or
63       * infinities of the same sign.
64       */
65      public static void assertEquals(double expected, double actual, double delta) {
66          Assert.assertEquals(null, expected, actual, delta);
67      }
68  
69      /**
70       * Verifies that expected and actual are within delta, or are both NaN or
71       * infinities of the same sign.
72       */
73      public static void assertEquals(String msg, double expected, double actual, double delta) {
74          // check for NaN
75          if(Double.isNaN(expected)){
76              Assert.assertTrue("" + actual + " is not NaN.",
77                  Double.isNaN(actual));
78          } else {
79              Assert.assertEquals(msg, expected, actual, delta);
80          }
81      }
82  
83      /**
84       * Verifies that the two arguments are exactly the same, either
85       * both NaN or infinities of same sign, or identical floating point values.
86       */
87      public static void assertSame(double expected, double actual) {
88       Assert.assertEquals(expected, actual, 0);
89      }
90  
91      /**
92       * Verifies that real and imaginary parts of the two complex arguments
93       * are exactly the same.  Also ensures that NaN / infinite components match.
94       */
95      public static void assertSame(Complex expected, Complex actual) {
96          assertSame(expected.getRealPart(), actual.getRealPart());
97          assertSame(expected.getImaginaryPart(), actual.getImaginaryPart());
98      }
99  
100     /**
101      * Verifies that real and imaginary parts of the two complex arguments
102      * differ by at most delta.  Also ensures that NaN / infinite components match.
103      */
104     public static void assertEquals(Complex expected, Complex actual, double delta) {
105         Assert.assertEquals(expected.getRealPart(), actual.getRealPart(), delta);
106         Assert.assertEquals(expected.getImaginaryPart(), actual.getImaginaryPart(), delta);
107     }
108 
109     /**
110      * Verifies that real and imaginary parts of the two complex arguments
111      * differ by at most delta.  Also ensures that NaN / infinite components match.
112      */
113     public static void assertEquals(FieldComplex<Binary64> expected, FieldComplex<Binary64> actual, double delta) {
114         Assert.assertEquals(expected.getRealPart().getReal(), actual.getRealPart().getReal(), delta);
115         Assert.assertEquals(expected.getImaginaryPart().getReal(), actual.getImaginaryPart().getReal(), delta);
116     }
117 
118     /**
119      * Verifies that real and imaginary parts of the two complex arguments
120      * are exactly the same.  Also ensures that NaN / infinite components match.
121      */
122     public static void assertSame(FieldComplex<Binary64> expected, FieldComplex<Binary64> actual) {
123         assertSame(expected.getRealPart().getReal(), actual.getRealPart().getReal());
124         assertSame(expected.getImaginaryPart().getReal(), actual.getImaginaryPart().getReal());
125     }
126 
127     /**
128      * Verifies that two double arrays have equal entries, up to tolerance
129      */
130     public static void assertEquals(double expected[], double observed[], double tolerance) {
131         assertEquals("Array comparison failure", expected, observed, tolerance);
132     }
133 
134     /**
135      * Serializes an object to a bytes array and then recovers the object from the bytes array.
136      * Returns the deserialized object.
137      *
138      * @param o  object to serialize and recover
139      * @return  the recovered, deserialized object
140      */
141     public static Object serializeAndRecover(Object o) {
142         try {
143             // serialize the Object
144             ByteArrayOutputStream bos = new ByteArrayOutputStream();
145             ObjectOutputStream so = new ObjectOutputStream(bos);
146             so.writeObject(o);
147 
148             // deserialize the Object
149             ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
150             ObjectInputStream si = new ObjectInputStream(bis);
151             return si.readObject();
152         } catch (IOException ioe) {
153             return null;
154         } catch (ClassNotFoundException cnfe) {
155             return null;
156         }
157     }
158 
159     /**
160      * Verifies that serialization preserves equals and hashCode.
161      * Serializes the object, then recovers it and checks equals and hash code.
162      *
163      * @param object  the object to serialize and recover
164      */
165     public static void checkSerializedEquality(Object object) {
166         Object object2 = serializeAndRecover(object);
167         Assert.assertEquals("Equals check", object, object2);
168         Assert.assertEquals("HashCode check", object.hashCode(), object2.hashCode());
169     }
170 
171     /**
172      * Verifies that the relative error in actual vs. expected is less than or
173      * equal to relativeError.  If expected is infinite or NaN, actual must be
174      * the same (NaN or infinity of the same sign).
175      *
176      * @param expected expected value
177      * @param actual  observed value
178      * @param relativeError  maximum allowable relative error
179      */
180     public static void assertRelativelyEquals(double expected, double actual,
181             double relativeError) {
182         assertRelativelyEquals(null, expected, actual, relativeError);
183     }
184 
185     /**
186      * Verifies that the relative error in actual vs. expected is less than or
187      * equal to relativeError.  If expected is infinite or NaN, actual must be
188      * the same (NaN or infinity of the same sign).
189      *
190      * @param msg  message to return with failure
191      * @param expected expected value
192      * @param actual  observed value
193      * @param relativeError  maximum allowable relative error
194      */
195     public static void assertRelativelyEquals(String msg, double expected,
196             double actual, double relativeError) {
197         if (Double.isNaN(expected)) {
198             Assert.assertTrue(msg, Double.isNaN(actual));
199         } else if (Double.isNaN(actual)) {
200             Assert.assertTrue(msg, Double.isNaN(expected));
201         } else if (Double.isInfinite(actual) || Double.isInfinite(expected)) {
202             Assert.assertEquals(expected, actual, relativeError);
203         } else if (expected == 0.0) {
204             Assert.assertEquals(msg, actual, expected, relativeError);
205         } else {
206             double absError = FastMath.abs(expected) * relativeError;
207             Assert.assertEquals(msg, expected, actual, absError);
208         }
209     }
210 
211     /**
212      * Fails iff values does not contain a number within epsilon of z.
213      *
214      * @param msg  message to return with failure
215      * @param values complex array to search
216      * @param z  value sought
217      * @param epsilon  tolerance
218      */
219     public static void assertContains(String msg, Complex[] values,
220                                       Complex z, double epsilon) {
221         for (Complex value : values) {
222             if (Precision.equals(value.getReal(), z.getReal(), epsilon) &&
223                 Precision.equals(value.getImaginary(), z.getImaginary(), epsilon)) {
224                 return;
225             }
226         }
227         Assert.fail(msg + " Unable to find " + (new ComplexFormat()).format(z));
228     }
229 
230     /**
231      * Fails iff values does not contain a number within epsilon of z.
232      *
233      * @param values complex array to search
234      * @param z  value sought
235      * @param epsilon  tolerance
236      */
237     public static void assertContains(Complex[] values,
238             Complex z, double epsilon) {
239         assertContains(null, values, z, epsilon);
240     }
241 
242     /**
243      * Fails iff values does not contain a number within epsilon of x.
244      *
245      * @param msg  message to return with failure
246      * @param values double array to search
247      * @param x value sought
248      * @param epsilon  tolerance
249      */
250     public static void assertContains(String msg, double[] values,
251             double x, double epsilon) {
252         for (double value : values) {
253             if (Precision.equals(value, x, epsilon)) {
254                 return;
255             }
256         }
257         Assert.fail(msg + " Unable to find " + x);
258     }
259 
260     /**
261      * Fails iff values does not contain a number within epsilon of x.
262      *
263      * @param values double array to search
264      * @param x value sought
265      * @param epsilon  tolerance
266      */
267     public static void assertContains(double[] values, double x,
268             double epsilon) {
269        assertContains(null, values, x, epsilon);
270     }
271 
272     /**
273      * Asserts that all entries of the specified vectors are equal to within a
274      * positive {@code delta}.
275      *
276      * @param message the identifying message for the assertion error (can be
277      * {@code null})
278      * @param expected expected value
279      * @param actual actual value
280      * @param delta the maximum difference between the entries of the expected
281      * and actual vectors for which both entries are still considered equal
282      */
283     public static void assertEquals(final String message,
284         final double[] expected, final RealVector actual, final double delta) {
285         final String msgAndSep = message.equals("") ? "" : message + ", ";
286         Assert.assertEquals(msgAndSep + "dimension", expected.length,
287             actual.getDimension());
288         for (int i = 0; i < expected.length; i++) {
289             Assert.assertEquals(msgAndSep + "entry #" + i, expected[i],
290                 actual.getEntry(i), delta);
291         }
292     }
293 
294     /**
295      * Asserts that all entries of the specified vectors are equal to within a
296      * positive {@code delta}.
297      *
298      * @param message the identifying message for the assertion error (can be
299      * {@code null})
300      * @param expected expected value
301      * @param actual actual value
302      * @param delta the maximum difference between the entries of the expected
303      * and actual vectors for which both entries are still considered equal
304      */
305     public static void assertEquals(final String message,
306         final RealVector expected, final RealVector actual, final double delta) {
307         final String msgAndSep = message.equals("") ? "" : message + ", ";
308         Assert.assertEquals(msgAndSep + "dimension", expected.getDimension(),
309             actual.getDimension());
310         final int dim = expected.getDimension();
311         for (int i = 0; i < dim; i++) {
312             Assert.assertEquals(msgAndSep + "entry #" + i,
313                 expected.getEntry(i), actual.getEntry(i), delta);
314         }
315     }
316 
317     /** verifies that two matrices are close (1-norm) */
318     public static void assertEquals(String msg, RealMatrix expected, RealMatrix observed, double tolerance) {
319 
320         Assert.assertNotNull(msg + "\nObserved should not be null",observed);
321 
322         if (expected.getColumnDimension() != observed.getColumnDimension() ||
323                 expected.getRowDimension() != observed.getRowDimension()) {
324             StringBuilder messageBuffer = new StringBuilder(msg);
325             messageBuffer.append("\nObserved has incorrect dimensions.");
326             messageBuffer.append("\nobserved is " + observed.getRowDimension() +
327                     " x " + observed.getColumnDimension());
328             messageBuffer.append("\nexpected " + expected.getRowDimension() +
329                     " x " + expected.getColumnDimension());
330             Assert.fail(messageBuffer.toString());
331         }
332 
333         RealMatrix delta = expected.subtract(observed);
334         if (delta.getNorm1() >= tolerance) {
335             StringBuilder messageBuffer = new StringBuilder(msg);
336             messageBuffer.append("\nExpected: " + expected);
337             messageBuffer.append("\nObserved: " + observed);
338             messageBuffer.append("\nexpected - observed: " + delta);
339             Assert.fail(messageBuffer.toString());
340         }
341     }
342 
343     /** verifies that two matrices are equal */
344     public static void assertEquals(FieldMatrix<? extends FieldElement<?>> expected,
345                                     FieldMatrix<? extends FieldElement<?>> observed) {
346 
347         Assert.assertNotNull("Observed should not be null",observed);
348 
349         if (expected.getColumnDimension() != observed.getColumnDimension() ||
350                 expected.getRowDimension() != observed.getRowDimension()) {
351             StringBuilder messageBuffer = new StringBuilder();
352             messageBuffer.append("Observed has incorrect dimensions.");
353             messageBuffer.append("\nobserved is " + observed.getRowDimension() +
354                     " x " + observed.getColumnDimension());
355             messageBuffer.append("\nexpected " + expected.getRowDimension() +
356                     " x " + expected.getColumnDimension());
357             Assert.fail(messageBuffer.toString());
358         }
359 
360         for (int i = 0; i < expected.getRowDimension(); ++i) {
361             for (int j = 0; j < expected.getColumnDimension(); ++j) {
362                 FieldElement<?> eij = expected.getEntry(i, j);
363                 FieldElement<?> oij = observed.getEntry(i, j);
364                 Assert.assertEquals(eij, oij);
365             }
366         }
367     }
368 
369     /** verifies that two arrays are close (sup norm) */
370     public static void assertEquals(String msg, double[] expected, double[] observed, double tolerance) {
371         StringBuilder out = new StringBuilder(msg);
372         if (expected.length != observed.length) {
373             out.append("\n Arrays not same length. \n");
374             out.append("expected has length ");
375             out.append(expected.length);
376             out.append(" observed length = ");
377             out.append(observed.length);
378             Assert.fail(out.toString());
379         }
380         boolean failure = false;
381         for (int i=0; i < expected.length; i++) {
382             if (!Precision.equalsIncludingNaN(expected[i], observed[i], tolerance)) {
383                 failure = true;
384                 out.append("\n Elements at index ");
385                 out.append(i);
386                 out.append(" differ. ");
387                 out.append(" expected = ");
388                 out.append(expected[i]);
389                 out.append(" observed = ");
390                 out.append(observed[i]);
391             }
392         }
393         if (failure) {
394             Assert.fail(out.toString());
395         }
396     }
397 
398     /** verifies that two int arrays are equal */
399     public static void assertEquals(int[] expected, int[] observed) {
400         StringBuilder out = new StringBuilder();
401         if (expected.length != observed.length) {
402             out.append("\n Arrays not same length. \n");
403             out.append("expected has length ");
404             out.append(expected.length);
405             out.append(" observed length = ");
406             out.append(observed.length);
407             Assert.fail(out.toString());
408         }
409         boolean failure = false;
410         for (int i=0; i < expected.length; i++) {
411             if (expected[i] != observed[i]) {
412                 failure = true;
413                 out.append("\n Elements at index ");
414                 out.append(i);
415                 out.append(" differ. ");
416                 out.append(" expected = ");
417                 out.append(expected[i]);
418                 out.append(" observed = ");
419                 out.append(observed[i]);
420             }
421         }
422         if (failure) {
423             Assert.fail(out.toString());
424         }
425     }
426 
427     /**
428      * verifies that for i = 0,..., observed.length, observed[i] is within epsilon of one of the values in expected[i]
429      * or observed[i] is NaN and expected[i] contains a NaN.
430      */
431     public static void assertContains(double[][] expected, double[] observed, double epsilon) {
432         StringBuilder out = new StringBuilder();
433         if (expected.length != observed.length) {
434             out.append("\n Arrays not same length. \n");
435             out.append("expected has length ");
436             out.append(expected.length);
437             out.append(" observed length = ");
438             out.append(observed.length);
439             Assert.fail(out.toString());
440         }
441         boolean failure = false;
442         for (int i = 0; i < expected.length; i++) {
443             boolean found = false;
444             for (int j = 0; j < expected[i].length; j++) {
445                 if (Precision.equalsIncludingNaN(expected[i][j], observed[i], epsilon)) {
446                     found = true;
447                     break;
448                 }
449             }
450             if (!found) {
451                 out.append("\n Observed element at index ");
452                 out.append(i);
453                 out.append(" is not among the expected values. ");
454                 out.append(" expected = " + Arrays.toString(expected[i]));
455                 out.append(" observed = ");
456                 out.append(observed[i]);
457                 failure = true;
458             }
459         }
460         if (failure) {
461             Assert.fail(out.toString());
462         }
463     }
464 
465 
466 
467     /** verifies that two int arrays are equal */
468     public static void assertEquals(long[] expected, long[] observed) {
469         StringBuilder out = new StringBuilder();
470         if (expected.length != observed.length) {
471             out.append("\n Arrays not same length. \n");
472             out.append("expected has length ");
473             out.append(expected.length);
474             out.append(" observed length = ");
475             out.append(observed.length);
476             Assert.fail(out.toString());
477         }
478         boolean failure = false;
479         for (int i=0; i < expected.length; i++) {
480             if (expected[i] != observed[i]) {
481                 failure = true;
482                 out.append("\n Elements at index ");
483                 out.append(i);
484                 out.append(" differ. ");
485                 out.append(" expected = ");
486                 out.append(expected[i]);
487                 out.append(" observed = ");
488                 out.append(observed[i]);
489             }
490         }
491         if (failure) {
492             Assert.fail(out.toString());
493         }
494     }
495 
496     /** verifies that two arrays are equal */
497     public static <T extends FieldElement<T>> void assertEquals(T[] m, T[] n) {
498         if (m.length != n.length) {
499             Assert.fail("vectors not same length");
500         }
501         for (int i = 0; i < m.length; i++) {
502             Assert.assertEquals(m[i],n[i]);
503         }
504     }
505 
506     /**
507      * Computes the sum of squared deviations of <values> from <target>
508      * @param values array of deviates
509      * @param target value to compute deviations from
510      *
511      * @return sum of squared deviations
512      */
513     public static double sumSquareDev(double[] values, double target) {
514         double sumsq = 0d;
515         for (int i = 0; i < values.length; i++) {
516             final double dev = values[i] - target;
517             sumsq += (dev * dev);
518         }
519         return sumsq;
520     }
521 
522     /**
523      * Asserts the null hypothesis for a ChiSquare test.  Fails and dumps arguments and test
524      * statistics if the null hypothesis can be rejected with confidence 100 * (1 - alpha)%
525      *
526      * @param valueLabels labels for the values of the discrete distribution under test
527      * @param expected expected counts
528      * @param observed observed counts
529      * @param alpha significance level of the test
530      */
531     public static void assertChiSquareAccept(String[] valueLabels, double[] expected, long[] observed, double alpha) {
532 
533         // Fail if we can reject null hypothesis that distributions are the same
534         if (chiSquareTest(expected, observed) <= alpha) {
535             StringBuilder msgBuffer = new StringBuilder();
536             DecimalFormat df = new DecimalFormat("#.##");
537             msgBuffer.append("Chisquare test failed");
538             msgBuffer.append(" p-value = ");
539             msgBuffer.append(chiSquareTest(expected, observed));
540             msgBuffer.append(" chisquare statistic = ");
541             msgBuffer.append(chiSquare(expected, observed));
542             msgBuffer.append(". \n");
543             msgBuffer.append("value\texpected\tobserved\n");
544             for (int i = 0; i < expected.length; i++) {
545                 msgBuffer.append(valueLabels[i]);
546                 msgBuffer.append("\t");
547                 msgBuffer.append(df.format(expected[i]));
548                 msgBuffer.append("\t\t");
549                 msgBuffer.append(observed[i]);
550                 msgBuffer.append("\n");
551             }
552             msgBuffer.append("This test can fail randomly due to sampling error with probability ");
553             msgBuffer.append(alpha);
554             msgBuffer.append(".");
555             Assert.fail(msgBuffer.toString());
556         }
557     }
558 
559     /**
560      * Asserts the null hypothesis for a ChiSquare test.  Fails and dumps arguments and test
561      * statistics if the null hypothesis can be rejected with confidence 100 * (1 - alpha)%
562      *
563      * @param values integer values whose observed and expected counts are being compared
564      * @param expected expected counts
565      * @param observed observed counts
566      * @param alpha significance level of the test
567      */
568     public static void assertChiSquareAccept(int[] values, double[] expected, long[] observed, double alpha) {
569         String[] labels = new String[values.length];
570         for (int i = 0; i < values.length; i++) {
571             labels[i] = Integer.toString(values[i]);
572         }
573         assertChiSquareAccept(labels, expected, observed, alpha);
574     }
575 
576     /**
577      * Asserts the null hypothesis for a ChiSquare test.  Fails and dumps arguments and test
578      * statistics if the null hypothesis can be rejected with confidence 100 * (1 - alpha)%
579      *
580      * @param expected expected counts
581      * @param observed observed counts
582      * @param alpha significance level of the test
583      */
584     public static void assertChiSquareAccept(double[] expected, long[] observed, double alpha) {
585         String[] labels = new String[expected.length];
586         for (int i = 0; i < labels.length; i++) {
587             labels[i] = Integer.toString(i + 1);
588         }
589         assertChiSquareAccept(labels, expected, observed, alpha);
590     }
591 
592     /**
593      * Asserts the null hypothesis that the sample follows the given distribution, using a G-test
594      *
595      * @param expectedDistribution distribution values are supposed to follow
596      * @param values sample data
597      * @param alpha significance level of the test
598      */
599     public static void assertGTest(final RealDistribution expectedDistribution, final double[] values, double alpha) {
600         final int numBins = values.length / 30;
601         final double[] breaks = new double[numBins];
602         for (int b = 0; b < breaks.length; b++) {
603             breaks[b] = expectedDistribution.inverseCumulativeProbability((double) b / numBins);
604         }
605 
606         final long[] observed = new long[numBins];
607         for (final double value : values) {
608             int b = 0;
609             do {
610                 b++;
611             } while (b < numBins && value >= breaks[b]);
612 
613             observed[b - 1]++;
614         }
615 
616         final double[] expected = new double[numBins];
617         Arrays.fill(expected, (double) values.length / numBins);
618 
619         assertGTest(expected, observed, alpha);
620     }
621 
622     /**
623      * Asserts the null hypothesis that the observed counts follow the given distribution implied by expected,
624      * using a G-test
625      *
626      * @param expected expected counts
627      * @param observed observed counts
628      * @param alpha significance level of the test
629      */
630     public static void assertGTest(final double[] expected, long[] observed, double alpha) {
631         if (gTest(expected, observed) <  alpha) {
632             StringBuilder msgBuffer = new StringBuilder();
633             DecimalFormat df = new DecimalFormat("#.##");
634             msgBuffer.append("G test failed");
635             msgBuffer.append(" p-value = ");
636             msgBuffer.append(gTest(expected, observed));
637             msgBuffer.append(". \n");
638             msgBuffer.append("value\texpected\tobserved\n");
639             for (int i = 0; i < expected.length; i++) {
640                 msgBuffer.append(df.format(expected[i]));
641                 msgBuffer.append("\t\t");
642                 msgBuffer.append(observed[i]);
643                 msgBuffer.append("\n");
644             }
645             msgBuffer.append("This test can fail randomly due to sampling error with probability ");
646             msgBuffer.append(alpha);
647             msgBuffer.append(".");
648             Assert.fail(msgBuffer.toString());
649         }
650     }
651 
652     /**
653      * Computes the 25th, 50th and 75th percentiles of the given distribution and returns
654      * these values in an array.
655      */
656     public static double[] getDistributionQuartiles(RealDistribution distribution) {
657         double[] quantiles = new double[3];
658         quantiles[0] = distribution.inverseCumulativeProbability(0.25d);
659         quantiles[1] = distribution.inverseCumulativeProbability(0.5d);
660         quantiles[2] = distribution.inverseCumulativeProbability(0.75d);
661         return quantiles;
662     }
663 
664     /**
665      * Updates observed counts of values in quartiles.
666      * counts[0] ↔ 1st quartile ... counts[3] ↔ top quartile
667      */
668     public static void updateCounts(double value, long[] counts, double[] quartiles) {
669         if (value < quartiles[0]) {
670             counts[0]++;
671         } else if (value > quartiles[2]) {
672             counts[3]++;
673         } else if (value > quartiles[1]) {
674             counts[2]++;
675         } else {
676             counts[1]++;
677         }
678     }
679 
680     /**
681      * Eliminates points with zero mass from densityPoints and densityValues parallel
682      * arrays.  Returns the number of positive mass points and collapses the arrays so
683      * that the first <returned value> elements of the input arrays represent the positive
684      * mass points.
685      */
686     public static int eliminateZeroMassPoints(int[] densityPoints, double[] densityValues) {
687         int positiveMassCount = 0;
688         for (int i = 0; i < densityValues.length; i++) {
689             if (densityValues[i] > 0) {
690                 positiveMassCount++;
691             }
692         }
693         if (positiveMassCount < densityValues.length) {
694             int[] newPoints = new int[positiveMassCount];
695             double[] newValues = new double[positiveMassCount];
696             int j = 0;
697             for (int i = 0; i < densityValues.length; i++) {
698                 if (densityValues[i] > 0) {
699                     newPoints[j] = densityPoints[i];
700                     newValues[j] = densityValues[i];
701                     j++;
702                 }
703             }
704             System.arraycopy(newPoints,0,densityPoints,0,positiveMassCount);
705             System.arraycopy(newValues,0,densityValues,0,positiveMassCount);
706         }
707         return positiveMassCount;
708     }
709 
710     /*************************************************************************************
711      * Stripped-down implementations of some basic statistics borrowed from hipparchus-stat.
712      * NOTE: These implementations are NOT intended for reuse.  They are neither robust,
713      * nor efficient; nor do they handle NaN, infinity or other corner cases in
714      * a predictable way. They DO NOT CHECK PARAMETERS - the assumption is that incorrect
715      * or meaningless results from bad parameters will trigger test failures in unit
716      * tests using these methods.
717      ************************************************************************************/
718 
719     /**
720      * Returns p-value associated with null hypothesis that observed counts follow
721      * expected distribution.  Will normalize inputs if necessary.
722      *
723      * @param expected expected counts
724      * @param observed observed counts
725      * @return p-value of Chi-square test
726      */
727     public static double chiSquareTest(final double[] expected, final long[] observed) {
728             final org.hipparchus.distribution.continuous.ChiSquaredDistribution distribution =
729                 new ChiSquaredDistribution(expected.length - 1.0);
730             return 1.0 - distribution.cumulativeProbability(chiSquare(expected, observed));
731     }
732 
733     /**
734      * Returns chi-square test statistic for expected and observed arrays. Rescales arrays
735      * if necessary.
736      *
737      * @param expected expected counts
738      * @param observed observed counts
739      * @return chi-square statistic
740      */
741     public static double chiSquare(final double[] expected, final long[] observed) {
742             double sumExpected = 0d;
743             double sumObserved = 0d;
744             for (int i = 0; i < observed.length; i++) {
745                 sumExpected += expected[i];
746                 sumObserved += observed[i];
747             }
748             double ratio = 1.0d;
749             boolean rescale = false;
750             if (FastMath.abs(sumExpected - sumObserved) > 10E-6) {
751                 ratio = sumObserved / sumExpected;
752                 rescale = true;
753             }
754             double sumSq = 0.0d;
755             for (int i = 0; i < observed.length; i++) {
756                 if (rescale) {
757                     final double dev = observed[i] - ratio * expected[i];
758                     sumSq += dev * dev / (ratio * expected[i]);
759                 } else {
760                     final double dev = observed[i] - expected[i];
761                     sumSq += dev * dev / expected[i];
762                 }
763             }
764             return sumSq;
765 
766         }
767 
768     /**
769      * Computes G-test statistic for expected, observed counts.
770      * @param expected expected counts
771      * @param observed observed counts
772      * @return G statistic
773      */
774     private static double g(final double[] expected, final long[] observed) {
775         double sumExpected = 0d;
776         double sumObserved = 0d;
777         for (int i = 0; i < observed.length; i++) {
778             sumExpected += expected[i];
779             sumObserved += observed[i];
780         }
781         double ratio = 1d;
782         boolean rescale = false;
783         if (FastMath.abs(sumExpected - sumObserved) > 10E-6) {
784             ratio = sumObserved / sumExpected;
785             rescale = true;
786         }
787         double sum = 0d;
788         for (int i = 0; i < observed.length; i++) {
789             final double dev = rescale ?
790                     FastMath.log(observed[i] / (ratio * expected[i])) :
791                         FastMath.log(observed[i] / expected[i]);
792             sum += (observed[i]) * dev;
793         }
794         return 2d * sum;
795     }
796 
797     /**
798      * Computes p-value for G-test.
799      *
800      * @param expected expected counts
801      * @param observed observed counts
802      * @return p-value
803      */
804     private static double gTest(final double[] expected, final long[] observed) {
805         final ChiSquaredDistribution distribution =
806                 new ChiSquaredDistribution(expected.length - 1.0);
807         return 1.0 - distribution.cumulativeProbability(g(expected, observed));
808     }
809 
810     /**
811      * Computes the mean of the values in the array.
812      *
813      * @param values input values
814      * @return arithmetic mean
815      */
816     public static double mean(final double[] values) {
817         double sum = 0;
818         for (double val : values) {
819             sum += val;
820         }
821         return sum / values.length;
822     }
823 
824     /**
825      * Computes the (bias-adjusted) variance of the values in the input array.
826      *
827      * @param values input values
828      * @return bias-adjusted variance
829      */
830     public static double variance(final double[] values) {
831         final int length = values.length;
832         final double mean = mean(values);
833         double var = Double.NaN;
834         if (length == 1) {
835             var = 0.0;
836         } else if (length > 1) {
837             double accum = 0.0;
838             double dev = 0.0;
839             double accum2 = 0.0;
840             for (int i = 0; i < length; i++) {
841                 dev = values[i] - mean;
842                 accum += dev * dev;
843                 accum2 += dev;
844             }
845             final double len = length;
846             var = (accum - (accum2 * accum2 / len)) / (len - 1.0);
847         }
848         return var;
849     }
850 
851     /**
852      * Computes the standard deviation of the values in the input array.
853      *
854      * @param values input values
855      * @return standard deviation
856      */
857     public static double standardDeviation(final double[] values) {
858         return FastMath.sqrt(variance(values));
859     }
860 
861     /**
862      * Computes the median of the values in the input array.
863      *
864      * @param values input values
865      * @return estimated median
866      */
867     public static double median(final double[] values) {
868         final int len = values.length;
869         final double[] sortedValues = Arrays.copyOf(values, len);
870         Arrays.sort(sortedValues);
871         if (len % 2 == 0) {
872             return ((double)sortedValues[len/2] + (double)sortedValues[len/2 - 1])/2;
873         } else {
874             return (double) sortedValues[len/2];
875         }
876     }
877 
878     /**
879      * Computes the covariance of the two input arrays.
880      *
881      * @param xArray first covariate
882      * @param yArray second covariate
883      * @return covariance
884      */
885     public static double covariance(final double[] xArray, final double[] yArray) {
886         double result = 0d;
887         final int length = xArray.length;
888         final double xMean = mean(xArray);
889         final double yMean = mean(yArray);
890         for (int i = 0; i < length; i++) {
891             final double xDev = xArray[i] - xMean;
892             final double yDev = yArray[i] - yMean;
893             result += (xDev * yDev - result) / (i + 1);
894         }
895         return result * ((double) length / (double)(length - 1));
896     }
897 
898     /**
899      * Computes a covariance matrix from a matrix whose columns represent covariates.
900      *
901      * @param matrix input matrix
902      * @return covariance matrix
903      */
904     public static RealMatrix covarianceMatrix(RealMatrix matrix) {
905         int dimension = matrix.getColumnDimension();
906         final RealMatrix outMatrix = new BlockRealMatrix(dimension, dimension);
907         for (int i = 0; i < dimension; i++) {
908             for (int j = 0; j < i; j++) {
909                 final double cov = covariance(matrix.getColumn(i), matrix.getColumn(j));
910                 outMatrix.setEntry(i, j, cov);
911                 outMatrix.setEntry(j, i, cov);
912             }
913             outMatrix.setEntry(i, i, variance(matrix.getColumn(i)));
914         }
915         return outMatrix;
916     }
917 
918     public static double min(final double[] values) {
919         double min = values[0];
920         for (int i = 1; i < values.length; i++) {
921             if (values[i] < min) {
922                 min = values[i];
923             }
924         }
925         return min;
926     }
927 
928     /**
929      * Computes the maximum of the values in the input array.
930      *
931      * @param values input array
932      * @return the maximum value
933      */
934     public static double max(final double[] values) {
935         double max = values[0];
936         for (int i = 1; i < values.length; i++) {
937             if (values[i] > max) {
938                 max = values[i];
939             }
940         }
941         return max;
942     }
943 
944     /**
945      * Unpacks a list of Doubles into a double[].
946      *
947      * @param values list of Double
948      * @return double array
949      */
950     private static double[] unpack(List<Double> values) {
951         int n = values.size();
952         if (values == null || n == 0) {
953             return new double[] {};
954         }
955         double[] out = new double[n];
956         for (int i = 0; i < n; i++) {
957             out[i] = values.get(i);
958         }
959         return out;
960     }
961 
962     /**
963      * Keeps track of the number of occurrences of distinct T instances
964      * added via {@link #addValue(Object)}.
965      *
966      * @param <T> type of objects being tracked
967      */
968     public static class Frequency<T> {
969         private Map<T, Integer> counts = new HashMap<>();
970         public void addValue(T value) {
971            Integer old = counts.put(value, 0);
972            if (old != null) {
973                counts.put(value, old++);
974            }
975         }
976         public int getCount(T value) {
977            Integer ret = counts.get(value);
978            return ret == null ? 0 : ret;
979         }
980     }
981 
982     /**
983      * Stripped down implementation of StreamingStatistics from o.h.stat.descriptive.
984      * Actually holds all values in memory, so not suitable for very large streams of data.
985      */
986     public static class SimpleStatistics {
987         private final List<Double> values = new ArrayList<>();
988         public void addValue(double value) {
989             values.add(value);
990         }
991         public double getMean() {
992             return mean(unpack(values));
993         }
994         public double getStandardDeviation() {
995             return standardDeviation(unpack(values));
996         }
997         public double getMin() {
998             return min(unpack(values));
999         }
1000         public double getMax() {
1001             return max(unpack(values));
1002         }
1003         public double getMedian() {
1004             return median(unpack(values));
1005         }
1006         public double getVariance() {
1007             return variance(unpack(values));
1008         }
1009         public long getN() {
1010             return values.size();
1011         }
1012     }
1013 
1014     /**
1015      * Stripped-down version of the bivariate regression class with the same name
1016      * in o.h.stat.regression.
1017      * Always estimates the model with an intercept term.
1018      */
1019     public static class SimpleRegression {
1020         private double sumX = 0d;
1021         private double sumXX = 0d;
1022         private double sumY = 0d;
1023         private double sumXY = 0d;
1024         private long n = 0;
1025         private double xbar = 0;
1026         private double ybar = 0;
1027 
1028         public void addData(double x, double y) {
1029             if (n == 0) {
1030                 xbar = x;
1031                 ybar = y;
1032             } else {
1033                 final double fact1 = 1.0 + n;
1034                 final double fact2 = n / (1.0 + n);
1035                 final double dx = x - xbar;
1036                 final double dy = y - ybar;
1037                 sumXX += dx * dx * fact2;
1038                 sumXY += dx * dy * fact2;
1039                 xbar += dx / fact1;
1040                 ybar += dy / fact1;
1041             }
1042             sumX += x;
1043             sumY += y;
1044             n++;
1045         }
1046 
1047         public double getSlope() {
1048             if (n < 2) {
1049                 return Double.NaN; //not enough data
1050             }
1051             if (FastMath.abs(sumXX) < 10 * Double.MIN_VALUE) {
1052                 return Double.NaN; //not enough variation in x
1053             }
1054             return sumXY / sumXX;
1055         }
1056 
1057         public double getIntercept() {
1058             return (sumY - getSlope() * sumX) / n;
1059         }
1060     }
1061 }