View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.stat.correlation;
23  
24  import java.util.Arrays;
25  
26  import org.hipparchus.exception.MathIllegalArgumentException;
27  import org.hipparchus.linear.BlockRealMatrix;
28  import org.hipparchus.linear.MatrixUtils;
29  import org.hipparchus.linear.RealMatrix;
30  import org.hipparchus.util.FastMath;
31  import org.hipparchus.util.MathArrays;
32  
33  /**
34   * Implementation of Kendall's Tau-b rank correlation.
35   * <p>
36   * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and
37   * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if
38   * x<sub>1</sub> &lt; x<sub>2</sub> and y<sub>1</sub> &lt; y<sub>2</sub>
39   * or x<sub>2</sub> &lt; x<sub>1</sub> and y<sub>2</sub> &lt; y<sub>1</sub>.
40   * The pair is <i>discordant</i> if x<sub>1</sub> &lt; x<sub>2</sub> and
41   * y<sub>2</sub> &lt; y<sub>1</sub> or x<sub>2</sub> &lt; x<sub>1</sub> and
42   * y<sub>1</sub> &lt; y<sub>2</sub>.  If either x<sub>1</sub> = x<sub>2</sub>
43   * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor
44   * discordant.
45   * <p>
46   * Kendall's Tau-b is defined as:
47   * \[
48   * \tau_b = \frac{n_c - n_d}{\sqrt{(n_0 - n_1) (n_0 - n_2)}}
49   * \]
50   * <p>
51   * where:
52   * <ul>
53   *     <li>n<sub>0</sub> = n * (n - 1) / 2</li>
54   *     <li>n<sub>c</sub> = Number of concordant pairs</li>
55   *     <li>n<sub>d</sub> = Number of discordant pairs</li>
56   *     <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li>
57   *     <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li>
58   *     <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li>
59   *     <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li>
60   * </ul>
61   * <p>
62   * This implementation uses the O(n log n) algorithm described in
63   * William R. Knight's 1966 paper "A Computer Method for Calculating
64   * Kendall's Tau with Ungrouped Data" in the Journal of the American
65   * Statistical Association.
66   *
67   * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient">
68   * Kendall tau rank correlation coefficient (Wikipedia)</a>
69   * @see <a href="http://www.jstor.org/stable/2282833">A Computer
70   * Method for Calculating Kendall's Tau with Ungrouped Data</a>
71   */
72  public class KendallsCorrelation {
73  
74      /** correlation matrix */
75      private final RealMatrix correlationMatrix;
76  
77      /**
78       * Create a KendallsCorrelation instance without data.
79       */
80      public KendallsCorrelation() {
81          correlationMatrix = null;
82      }
83  
84      /**
85       * Create a KendallsCorrelation from a rectangular array
86       * whose columns represent values of variables to be correlated.
87       *
88       * @param data rectangular array with columns representing variables
89       * @throws IllegalArgumentException if the input data array is not
90       * rectangular with at least two rows and two columns.
91       */
92      public KendallsCorrelation(double[][] data) {
93          this(MatrixUtils.createRealMatrix(data));
94      }
95  
96      /**
97       * Create a KendallsCorrelation from a RealMatrix whose columns
98       * represent variables to be correlated.
99       *
100      * @param matrix matrix with columns representing variables to correlate
101      */
102     public KendallsCorrelation(RealMatrix matrix) {
103         correlationMatrix = computeCorrelationMatrix(matrix);
104     }
105 
106     /**
107      * Returns the correlation matrix.
108      *
109      * @return correlation matrix
110      */
111     public RealMatrix getCorrelationMatrix() {
112         return correlationMatrix;
113     }
114 
115     /**
116      * Computes the Kendall's Tau rank correlation matrix for the columns of
117      * the input matrix.
118      *
119      * @param matrix matrix with columns representing variables to correlate
120      * @return correlation matrix
121      */
122     public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
123         int nVars = matrix.getColumnDimension();
124         RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
125         for (int i = 0; i < nVars; i++) {
126             for (int j = 0; j < i; j++) {
127                 double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
128                 outMatrix.setEntry(i, j, corr);
129                 outMatrix.setEntry(j, i, corr);
130             }
131             outMatrix.setEntry(i, i, 1d);
132         }
133         return outMatrix;
134     }
135 
136     /**
137      * Computes the Kendall's Tau rank correlation matrix for the columns of
138      * the input rectangular array.  The columns of the array represent values
139      * of variables to be correlated.
140      *
141      * @param matrix matrix with columns representing variables to correlate
142      * @return correlation matrix
143      */
144     public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
145        return computeCorrelationMatrix(new BlockRealMatrix(matrix));
146     }
147 
148     /**
149      * Computes the Kendall's Tau rank correlation coefficient between the two arrays.
150      *
151      * @param xArray first data array
152      * @param yArray second data array
153      * @return Returns Kendall's Tau rank correlation coefficient for the two arrays
154      * @throws MathIllegalArgumentException if the arrays lengths do not match
155      */
156     public double correlation(final double[] xArray, final double[] yArray)
157             throws MathIllegalArgumentException {
158 
159         MathArrays.checkEqualLength(xArray, yArray);
160 
161         final int n = xArray.length;
162         final long numPairs = sum(n - 1);
163 
164         DoublePair[] pairs = new DoublePair[n];
165         for (int i = 0; i < n; i++) {
166             pairs[i] = new DoublePair(xArray[i], yArray[i]);
167         }
168 
169         Arrays.sort(pairs, (p1, p2) -> {
170             int compareKey = Double.compare(p1.getFirst(), p2.getFirst());
171             return compareKey != 0 ? compareKey : Double.compare(p1.getSecond(), p2.getSecond());
172         });
173 
174         long tiedXPairs = 0;
175         long tiedXYPairs = 0;
176         long consecutiveXTies = 1;
177         long consecutiveXYTies = 1;
178         DoublePair prev = pairs[0];
179         for (int i = 1; i < n; i++) {
180             final DoublePair curr = pairs[i];
181             if (Double.compare(curr.getFirst(), prev.getFirst()) == 0) {
182                 consecutiveXTies++;
183                 if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
184                     consecutiveXYTies++;
185                 } else {
186                     tiedXYPairs += sum(consecutiveXYTies - 1);
187                     consecutiveXYTies = 1;
188                 }
189             } else {
190                 tiedXPairs += sum(consecutiveXTies - 1);
191                 consecutiveXTies = 1;
192                 tiedXYPairs += sum(consecutiveXYTies - 1);
193                 consecutiveXYTies = 1;
194             }
195             prev = curr;
196         }
197         tiedXPairs += sum(consecutiveXTies - 1);
198         tiedXYPairs += sum(consecutiveXYTies - 1);
199 
200         long swaps = 0;
201         DoublePair[] pairsDestination = new DoublePair[n];
202         for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
203             for (int offset = 0; offset < n; offset += 2 * segmentSize) {
204                 int i = offset;
205                 final int iEnd = FastMath.min(i + segmentSize, n);
206                 int j = iEnd;
207                 final int jEnd = FastMath.min(j + segmentSize, n);
208 
209                 int copyLocation = offset;
210                 while (i < iEnd || j < jEnd) {
211                     if (i < iEnd) {
212                         if (j < jEnd) {
213                             if (Double.compare(pairs[i].getSecond(), pairs[j].getSecond()) <= 0) {
214                                 pairsDestination[copyLocation] = pairs[i];
215                                 i++;
216                             } else {
217                                 pairsDestination[copyLocation] = pairs[j];
218                                 j++;
219                                 swaps += iEnd - i;
220                             }
221                         } else {
222                             pairsDestination[copyLocation] = pairs[i];
223                             i++;
224                         }
225                     } else {
226                         pairsDestination[copyLocation] = pairs[j];
227                         j++;
228                     }
229                     copyLocation++;
230                 }
231             }
232             final DoublePair[] pairsTemp = pairs;
233             pairs = pairsDestination;
234             pairsDestination = pairsTemp;
235         }
236 
237         long tiedYPairs = 0;
238         long consecutiveYTies = 1;
239         prev = pairs[0];
240         for (int i = 1; i < n; i++) {
241             final DoublePair curr = pairs[i];
242             if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
243                 consecutiveYTies++;
244             } else {
245                 tiedYPairs += sum(consecutiveYTies - 1);
246                 consecutiveYTies = 1;
247             }
248             prev = curr;
249         }
250         tiedYPairs += sum(consecutiveYTies - 1);
251 
252         final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
253         final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
254         return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied);
255     }
256 
257     /**
258      * Returns the sum of the number from 1 .. n according to Gauss' summation formula:
259      * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \]
260      *
261      * @param n the summation end
262      * @return the sum of the number from 1 to n
263      */
264     private static long sum(long n) {
265         return n * (n + 1) / 2l;
266     }
267 
268     /**
269      * Helper data structure holding a (double, double) pair.
270      */
271     private static class DoublePair {
272         /** The first value */
273         private final double first;
274         /** The second value */
275         private final double second;
276 
277         /**
278          * @param first first value.
279          * @param second second value.
280          */
281         DoublePair(double first, double second) {
282             this.first = first;
283             this.second = second;
284         }
285 
286         /** @return the first value. */
287         public double getFirst() {
288             return first;
289         }
290 
291         /** @return the second value. */
292         public double getSecond() {
293             return second;
294         }
295 
296     }
297 
298 }