View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.stat.regression;
23  
24  import org.hipparchus.exception.MathIllegalArgumentException;
25  import org.hipparchus.exception.NullArgumentException;
26  import org.hipparchus.linear.RealMatrix;
27  import org.hipparchus.linear.RealVector;
28  import org.junit.jupiter.api.BeforeEach;
29  import org.junit.jupiter.api.Test;
30  
31  import static org.junit.jupiter.api.Assertions.assertEquals;
32  import static org.junit.jupiter.api.Assertions.assertThrows;
33  import static org.junit.jupiter.api.Assertions.assertTrue;
34  
35  
36  public abstract class MultipleLinearRegressionAbstractTest {
37  
38      protected AbstractMultipleLinearRegression regression;
39  
40      @BeforeEach
41      public void setUp(){
42          regression = createRegression();
43      }
44  
45      protected abstract AbstractMultipleLinearRegression createRegression();
46  
47      protected abstract int getNumberOfRegressors();
48  
49      protected abstract int getSampleSize();
50  
51      @Test
52      public void canEstimateRegressionParameters(){
53          double[] beta = regression.estimateRegressionParameters();
54          assertEquals(getNumberOfRegressors(), beta.length);
55      }
56  
57      @Test
58      public void canEstimateResiduals(){
59          double[] e = regression.estimateResiduals();
60          assertEquals(getSampleSize(), e.length);
61      }
62  
63      @Test
64      public void canEstimateRegressionParametersVariance(){
65          double[][] variance = regression.estimateRegressionParametersVariance();
66          assertEquals(getNumberOfRegressors(), variance.length);
67      }
68  
69      @Test
70      public void canEstimateRegressandVariance(){
71          if (getSampleSize() > getNumberOfRegressors()) {
72              double variance = regression.estimateRegressandVariance();
73              assertTrue(variance > 0.0);
74          }
75      }
76  
77      /**
78       * Verifies that newSampleData methods consistently insert unitary columns
79       * in design matrix.  Confirms the fix for MATH-411.
80       */
81      @Test
82      public void testNewSample() {
83          double[] design = new double[] {
84            1, 19, 22, 33,
85            2, 20, 30, 40,
86            3, 25, 35, 45,
87            4, 27, 37, 47
88          };
89          double[] y = new double[] {1, 2, 3, 4};
90          double[][] x = new double[][] {
91            {19, 22, 33},
92            {20, 30, 40},
93            {25, 35, 45},
94            {27, 37, 47}
95          };
96          AbstractMultipleLinearRegression regression = createRegression();
97          regression.newSampleData(design, 4, 3);
98          RealMatrix flatX = regression.getX().copy();
99          RealVector flatY = regression.getY().copy();
100         regression.newXSampleData(x);
101         regression.newYSampleData(y);
102         assertEquals(flatX, regression.getX());
103         assertEquals(flatY, regression.getY());
104 
105         // No intercept
106         regression.setNoIntercept(true);
107         regression.newSampleData(design, 4, 3);
108         flatX = regression.getX().copy();
109         flatY = regression.getY().copy();
110         regression.newXSampleData(x);
111         regression.newYSampleData(y);
112         assertEquals(flatX, regression.getX());
113         assertEquals(flatY, regression.getY());
114     }
115 
116     @Test
117     public void testNewSampleNullData() {
118         assertThrows(NullArgumentException.class, () -> {
119             double[] data = null;
120             createRegression().newSampleData(data, 2, 3);
121         });
122     }
123 
124     @Test
125     public void testNewSampleInvalidData() {
126         assertThrows(MathIllegalArgumentException.class, () -> {
127             double[] data = new double[]{1, 2, 3, 4};
128             createRegression().newSampleData(data, 2, 3);
129         });
130     }
131 
132     @Test
133     public void testNewSampleInsufficientData() {
134         assertThrows(MathIllegalArgumentException.class, () -> {
135             double[] data = new double[]{1, 2, 3, 4};
136             createRegression().newSampleData(data, 1, 3);
137         });
138     }
139 
140     @Test
141     public void testXSampleDataNull() {
142         assertThrows(NullArgumentException.class, () -> {
143             createRegression().newXSampleData(null);
144         });
145     }
146 
147     @Test
148     public void testYSampleDataNull() {
149         assertThrows(NullArgumentException.class, () -> {
150             createRegression().newYSampleData(null);
151         });
152     }
153 
154 }