View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.stat.regression;
23  
24  import org.hipparchus.exception.LocalizedCoreFormats;
25  import org.hipparchus.exception.MathIllegalArgumentException;
26  import org.hipparchus.exception.NullArgumentException;
27  import org.hipparchus.linear.Array2DRowRealMatrix;
28  import org.hipparchus.linear.ArrayRealVector;
29  import org.hipparchus.linear.RealMatrix;
30  import org.hipparchus.linear.RealVector;
31  import org.hipparchus.stat.LocalizedStatFormats;
32  import org.hipparchus.stat.descriptive.moment.Variance;
33  import org.hipparchus.util.FastMath;
34  import org.hipparchus.util.MathUtils;
35  
36  /**
37   * Abstract base class for implementations of MultipleLinearRegression.
38   */
39  public abstract class AbstractMultipleLinearRegression implements
40          MultipleLinearRegression {
41  
42      /** X sample data. */
43      private RealMatrix xMatrix;
44  
45      /** Y sample data. */
46      private RealVector yVector;
47  
48      /** Whether or not the regression model includes an intercept.  True means no intercept. */
49      private boolean noIntercept;
50  
51      /** Empty constructor.
52       * <p>
53       * This constructor is not strictly necessary, but it prevents spurious
54       * javadoc warnings with JDK 18 and later.
55       * </p>
56       * @since 3.0
57       */
58      public AbstractMultipleLinearRegression() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
59          // nothing to do
60      }
61  
62      /** Get the X sample data.
63       * @return the X sample data.
64       */
65      protected RealMatrix getX() {
66          return xMatrix;
67      }
68  
69      /** Get the Y sample data.
70       * @return the Y sample data.
71       */
72      protected RealVector getY() {
73          return yVector;
74      }
75  
76      /** Chekc if the model has no intercept term.
77       * @return true if the model has no intercept term; false otherwise
78       */
79      public boolean isNoIntercept() {
80          return noIntercept;
81      }
82  
83      /** Set intercept flag.
84       * @param noIntercept true means the model is to be estimated without an intercept term
85       */
86      public void setNoIntercept(boolean noIntercept) {
87          this.noIntercept = noIntercept;
88      }
89  
90      /**
91       * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
92       * </p>
93       * <p>Assumes that rows are concatenated with y values first in each row.  For example, an input
94       * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
95       * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
96       * independent variables, as below:
97       * </p>
98       * <pre>
99       *   y   x[0]  x[1]
100      *   --------------
101      *   1     2     3
102      *   4     5     6
103      *   7     8     9
104      * </pre>
105      * <p>Note that there is no need to add an initial unitary column (column of 1's) when
106      * specifying a model including an intercept term.  If {@link #isNoIntercept()} is <code>true</code>,
107      * the X matrix will be created without an initial column of "1"s; otherwise this column will
108      * be added.
109      * </p>
110      * <p>Throws IllegalArgumentException if any of the following preconditions fail:</p>
111      * <ul><li><code>data</code> cannot be null</li>
112      * <li><code>data.length = nobs * (nvars + 1)</code></li>
113      * <li><code>nobs &gt; nvars</code></li></ul>
114      *
115      * @param data input data array
116      * @param nobs number of observations (rows)
117      * @param nvars number of independent variables (columns, not counting y)
118      * @throws NullArgumentException if the data array is null
119      * @throws MathIllegalArgumentException if the length of the data array is not equal
120      * to <code>nobs * (nvars + 1)</code>
121      * @throws MathIllegalArgumentException if <code>nobs</code> is less than
122      * <code>nvars + 1</code>
123      */
124     public void newSampleData(double[] data, int nobs, int nvars) {
125         MathUtils.checkNotNull(data, LocalizedCoreFormats.INPUT_ARRAY);
126         MathUtils.checkDimension(data.length, nobs * (nvars + 1));
127         if (nobs <= nvars) {
128             throw new MathIllegalArgumentException(LocalizedCoreFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE,
129                                                    nobs, nvars + 1);
130         }
131         double[] y = new double[nobs];
132         final int cols = noIntercept ? nvars: nvars + 1;
133         double[][] x = new double[nobs][cols];
134         int pointer = 0;
135         for (int i = 0; i < nobs; i++) {
136             y[i] = data[pointer++];
137             if (!noIntercept) {
138                 x[i][0] = 1.0d;
139             }
140             for (int j = noIntercept ? 0 : 1; j < cols; j++) {
141                 x[i][j] = data[pointer++];
142             }
143         }
144         this.xMatrix = new Array2DRowRealMatrix(x);
145         this.yVector = new ArrayRealVector(y);
146     }
147 
148     /**
149      * Loads new y sample data, overriding any previous data.
150      *
151      * @param y the array representing the y sample
152      * @throws NullArgumentException if y is null
153      * @throws MathIllegalArgumentException if y is empty
154      */
155     protected void newYSampleData(double[] y) {
156         if (y == null) {
157             throw new NullArgumentException();
158         }
159         if (y.length == 0) {
160             throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA);
161         }
162         this.yVector = new ArrayRealVector(y);
163     }
164 
165     /**
166      * <p>Loads new x sample data, overriding any previous data.
167      * </p>
168      * <p>
169      * The input <code>x</code> array should have one row for each sample
170      * observation, with columns corresponding to independent variables.
171      * For example, if
172      * </p>
173      * <pre>
174      * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
175      * <p>
176      * then <code>setXSampleData(x) </code> results in a model with two independent
177      * variables and 3 observations:
178      * </p>
179      * <pre>
180      *   x[0]  x[1]
181      *   ----------
182      *     1    2
183      *     3    4
184      *     5    6
185      * </pre>
186      * <p>Note that there is no need to add an initial unitary column (column of 1's) when
187      * specifying a model including an intercept term.
188      * </p>
189      * @param x the rectangular array representing the x sample
190      * @throws NullArgumentException if x is null
191      * @throws MathIllegalArgumentException if x is empty
192      * @throws MathIllegalArgumentException if x is not rectangular
193      */
194     protected void newXSampleData(double[][] x) {
195         if (x == null) {
196             throw new NullArgumentException();
197         }
198         if (x.length == 0) {
199             throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA);
200         }
201         if (noIntercept) {
202             this.xMatrix = new Array2DRowRealMatrix(x, true);
203         } else { // Augment design matrix with initial unitary column
204             final int nVars = x[0].length;
205             final double[][] xAug = new double[x.length][nVars + 1];
206             for (int i = 0; i < x.length; i++) {
207                 MathUtils.checkDimension(x[i].length, nVars);
208                 xAug[i][0] = 1.0d;
209                 System.arraycopy(x[i], 0, xAug[i], 1, nVars);
210             }
211             this.xMatrix = new Array2DRowRealMatrix(xAug, false);
212         }
213     }
214 
215     /**
216      * Validates sample data.
217      * <p>Checks that</p>
218      * <ul><li>Neither x nor y is null or empty;</li>
219      * <li>The length (i.e. number of rows) of x equals the length of y</li>
220      * <li>x has at least one more row than it has columns (i.e. there is
221      * sufficient data to estimate regression coefficients for each of the
222      * columns in x plus an intercept.</li>
223      * </ul>
224      *
225      * @param x the [n,k] array representing the x data
226      * @param y the [n,1] array representing the y data
227      * @throws NullArgumentException if {@code x} or {@code y} is null
228      * @throws MathIllegalArgumentException if {@code x} and {@code y} do not
229      * have the same length
230      * @throws MathIllegalArgumentException if {@code x} or {@code y} are zero-length
231      * @throws MathIllegalArgumentException if the number of rows of {@code x}
232      * is not larger than the number of columns + 1 if the model has an intercept;
233      * or the number of columns if there is no intercept term
234      */
235     protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
236         if ((x == null) || (y == null)) {
237             throw new NullArgumentException();
238         }
239         MathUtils.checkDimension(x.length, y.length);
240         if (x.length == 0) {  // Must be no y data either
241             throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA);
242         }
243         if (x[0].length + (noIntercept ? 0 : 1) > x.length) {
244             throw new MathIllegalArgumentException(
245                     LocalizedStatFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
246                     x.length, x[0].length);
247         }
248     }
249 
250     /**
251      * Validates that the x data and covariance matrix have the same
252      * number of rows and that the covariance matrix is square.
253      *
254      * @param x the [n,k] array representing the x sample
255      * @param covariance the [n,n] array representing the covariance matrix
256      * @throws MathIllegalArgumentException if the number of rows in x is not equal
257      * to the number of rows in covariance
258      * @throws MathIllegalArgumentException if the covariance matrix is not square
259      */
260     protected void validateCovarianceData(double[][] x, double[][] covariance) {
261         MathUtils.checkDimension(x.length, covariance.length);
262         if (covariance.length > 0 && covariance.length != covariance[0].length) {
263             throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SQUARE_MATRIX,
264                                                    covariance.length, covariance[0].length);
265         }
266     }
267 
268     /**
269      * {@inheritDoc}
270      */
271     @Override
272     public double[] estimateRegressionParameters() {
273         RealVector b = calculateBeta();
274         return b.toArray();
275     }
276 
277     /**
278      * {@inheritDoc}
279      */
280     @Override
281     public double[] estimateResiduals() {
282         RealVector b = calculateBeta();
283         RealVector e = yVector.subtract(xMatrix.operate(b));
284         return e.toArray();
285     }
286 
287     /**
288      * {@inheritDoc}
289      */
290     @Override
291     public double[][] estimateRegressionParametersVariance() {
292         return calculateBetaVariance().getData();
293     }
294 
295     /**
296      * {@inheritDoc}
297      */
298     @Override
299     public double[] estimateRegressionParametersStandardErrors() {
300         double[][] betaVariance = estimateRegressionParametersVariance();
301         double sigma = calculateErrorVariance();
302         int length = betaVariance[0].length;
303         double[] result = new double[length];
304         for (int i = 0; i < length; i++) {
305             result[i] = FastMath.sqrt(sigma * betaVariance[i][i]);
306         }
307         return result;
308     }
309 
310     /**
311      * {@inheritDoc}
312      */
313     @Override
314     public double estimateRegressandVariance() {
315         return calculateYVariance();
316     }
317 
318     /**
319      * Estimates the variance of the error.
320      *
321      * @return estimate of the error variance
322      */
323     public double estimateErrorVariance() {
324         return calculateErrorVariance();
325 
326     }
327 
328     /**
329      * Estimates the standard error of the regression.
330      *
331      * @return regression standard error
332      */
333     public double estimateRegressionStandardError() {
334         return FastMath.sqrt(estimateErrorVariance());
335     }
336 
337     /**
338      * Calculates the beta of multiple linear regression in matrix notation.
339      *
340      * @return beta
341      */
342     protected abstract RealVector calculateBeta();
343 
344     /**
345      * Calculates the beta variance of multiple linear regression in matrix
346      * notation.
347      *
348      * @return beta variance
349      */
350     protected abstract RealMatrix calculateBetaVariance();
351 
352 
353     /**
354      * Calculates the variance of the y values.
355      *
356      * @return Y variance
357      */
358     protected double calculateYVariance() {
359         return new Variance().evaluate(yVector.toArray());
360     }
361 
362     /**
363      * <p>Calculates the variance of the error term.</p>
364      * Uses the formula <pre>
365      * var(u) = u &middot; u / (n - k)
366      * </pre>
367      * where n and k are the row and column dimensions of the design
368      * matrix X.
369      *
370      * @return error variance estimate
371      */
372     protected double calculateErrorVariance() {
373         RealVector residuals = calculateResiduals();
374         return residuals.dotProduct(residuals) /
375                (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
376     }
377 
378     /**
379      * Calculates the residuals of multiple linear regression in matrix
380      * notation.
381      *
382      * <pre>
383      * u = y - X * b
384      * </pre>
385      *
386      * @return The residuals [n,1] matrix
387      */
388     protected RealVector calculateResiduals() {
389         RealVector b = calculateBeta();
390         return yVector.subtract(xMatrix.operate(b));
391     }
392 
393 }