1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.stat.regression;
23
24 import org.hipparchus.exception.MathIllegalArgumentException;
25 import org.hipparchus.exception.NullArgumentException;
26 import org.hipparchus.linear.RealMatrix;
27 import org.hipparchus.linear.RealVector;
28 import org.junit.jupiter.api.BeforeEach;
29 import org.junit.jupiter.api.Test;
30
31 import static org.junit.jupiter.api.Assertions.assertEquals;
32 import static org.junit.jupiter.api.Assertions.assertThrows;
33 import static org.junit.jupiter.api.Assertions.assertTrue;
34
35
36 public abstract class MultipleLinearRegressionAbstractTest {
37
38 protected AbstractMultipleLinearRegression regression;
39
40 @BeforeEach
41 public void setUp(){
42 regression = createRegression();
43 }
44
45 protected abstract AbstractMultipleLinearRegression createRegression();
46
47 protected abstract int getNumberOfRegressors();
48
49 protected abstract int getSampleSize();
50
51 @Test
52 public void canEstimateRegressionParameters(){
53 double[] beta = regression.estimateRegressionParameters();
54 assertEquals(getNumberOfRegressors(), beta.length);
55 }
56
57 @Test
58 public void canEstimateResiduals(){
59 double[] e = regression.estimateResiduals();
60 assertEquals(getSampleSize(), e.length);
61 }
62
63 @Test
64 public void canEstimateRegressionParametersVariance(){
65 double[][] variance = regression.estimateRegressionParametersVariance();
66 assertEquals(getNumberOfRegressors(), variance.length);
67 }
68
69 @Test
70 public void canEstimateRegressandVariance(){
71 if (getSampleSize() > getNumberOfRegressors()) {
72 double variance = regression.estimateRegressandVariance();
73 assertTrue(variance > 0.0);
74 }
75 }
76
77
78
79
80
81 @Test
82 public void testNewSample() {
83 double[] design = new double[] {
84 1, 19, 22, 33,
85 2, 20, 30, 40,
86 3, 25, 35, 45,
87 4, 27, 37, 47
88 };
89 double[] y = new double[] {1, 2, 3, 4};
90 double[][] x = new double[][] {
91 {19, 22, 33},
92 {20, 30, 40},
93 {25, 35, 45},
94 {27, 37, 47}
95 };
96 AbstractMultipleLinearRegression regression = createRegression();
97 regression.newSampleData(design, 4, 3);
98 RealMatrix flatX = regression.getX().copy();
99 RealVector flatY = regression.getY().copy();
100 regression.newXSampleData(x);
101 regression.newYSampleData(y);
102 assertEquals(flatX, regression.getX());
103 assertEquals(flatY, regression.getY());
104
105
106 regression.setNoIntercept(true);
107 regression.newSampleData(design, 4, 3);
108 flatX = regression.getX().copy();
109 flatY = regression.getY().copy();
110 regression.newXSampleData(x);
111 regression.newYSampleData(y);
112 assertEquals(flatX, regression.getX());
113 assertEquals(flatY, regression.getY());
114 }
115
116 @Test
117 public void testNewSampleNullData() {
118 assertThrows(NullArgumentException.class, () -> {
119 double[] data = null;
120 createRegression().newSampleData(data, 2, 3);
121 });
122 }
123
124 @Test
125 public void testNewSampleInvalidData() {
126 assertThrows(MathIllegalArgumentException.class, () -> {
127 double[] data = new double[]{1, 2, 3, 4};
128 createRegression().newSampleData(data, 2, 3);
129 });
130 }
131
132 @Test
133 public void testNewSampleInsufficientData() {
134 assertThrows(MathIllegalArgumentException.class, () -> {
135 double[] data = new double[]{1, 2, 3, 4};
136 createRegression().newSampleData(data, 1, 3);
137 });
138 }
139
140 @Test
141 public void testXSampleDataNull() {
142 assertThrows(NullArgumentException.class, () -> {
143 createRegression().newXSampleData(null);
144 });
145 }
146
147 @Test
148 public void testYSampleDataNull() {
149 assertThrows(NullArgumentException.class, () -> {
150 createRegression().newYSampleData(null);
151 });
152 }
153
154 }