1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.hipparchus.optim.nonlinear.vector.leastsquares;
15
import org.hipparchus.UnitTestUtils;
import org.hipparchus.linear.ArrayRealVector;
import org.hipparchus.linear.DiagonalMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.util.FastMath;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import static org.junit.jupiter.api.Assertions.assertEquals;
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47 class EvaluationTestValidation {
48
49 private static final int MONTE_CARLO_RUNS = Integer.parseInt(System.getProperty("mcRuns",
50 "100"));
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67 @Disabled
68 @Test
69 void testParametersErrorMonteCarloObservations() {
70
71 final double yError = 15;
72
73
74 final double slope = 123.456;
75 final double offset = -98.765;
76
77
78 final RandomStraightLinePointGenerator lineGenerator
79 = new RandomStraightLinePointGenerator(slope, offset,
80 yError,
81 -1e3, 1e4,
82 138577L);
83
84
85 final int numObs = 100;
86
87 final int numParams = 2;
88
89
90 final UnitTestUtils.SimpleStatistics[] paramsFoundByDirectSolution = new UnitTestUtils.SimpleStatistics[numParams];
91
92
93 final UnitTestUtils.SimpleStatistics[] sigmaEstimate = new UnitTestUtils.SimpleStatistics[numParams];
94
95
96 for (int i = 0; i < numParams; i++) {
97 paramsFoundByDirectSolution[i] = new UnitTestUtils.SimpleStatistics();
98 sigmaEstimate[i] = new UnitTestUtils.SimpleStatistics();
99 }
100
101 final RealVector init = new ArrayRealVector(new double[]{ slope, offset }, false);
102
103
104 final int mcRepeat = MONTE_CARLO_RUNS;
105 int mcCount = 0;
106 while (mcCount < mcRepeat) {
107
108 final Point2D.Double[] obs = lineGenerator.generate(numObs);
109
110 final StraightLineProblem problem = new StraightLineProblem(yError);
111 for (int i = 0; i < numObs; i++) {
112 final Point2D.Double p = obs[i];
113 problem.addPoint(p.x, p.y);
114 }
115
116
117 final double[] regress = problem.solve();
118
119
120
121 final LeastSquaresProblem lsp = builder(problem).build();
122
123 final RealVector sigma = lsp.evaluate(init).getSigma(1e-14);
124
125
126 for (int i = 0; i < numParams; i++) {
127 paramsFoundByDirectSolution[i].addValue(regress[i]);
128 sigmaEstimate[i].addValue(sigma.getEntry(i));
129 }
130
131
132 ++mcCount;
133 }
134
135
136 final String line = "--------------------------------------------------------------";
137 System.out.println(" True value Mean Std deviation");
138 for (int i = 0; i < numParams; i++) {
139 System.out.println(line);
140 System.out.println("Parameter #" + i);
141
142 System.out.printf(" %+.6e %+.6e %+.6e\n",
143 init.getEntry(i),
144 paramsFoundByDirectSolution[i].getMean(),
145 paramsFoundByDirectSolution[i].getStandardDeviation());
146
147 System.out.printf("sigma: %+.6e (%+.6e)\n",
148 sigmaEstimate[i].getMean(),
149 sigmaEstimate[i].getStandardDeviation());
150 }
151 System.out.println(line);
152
153
154 for (int i = 0; i < numParams; i++) {
155 assertEquals(paramsFoundByDirectSolution[i].getStandardDeviation(),
156 sigmaEstimate[i].getMean(),
157 8e-2);
158 }
159 }
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185 @Disabled
186 @Test
187 void testParametersErrorMonteCarloParameters() {
188
189 final double yError = 15;
190
191
192 final double slope = 123.456;
193 final double offset = -98.765;
194
195
196 final RandomStraightLinePointGenerator lineGenerator
197 = new RandomStraightLinePointGenerator(slope, offset,
198 yError,
199 -1e3, 1e4,
200 13839013L);
201
202
203 final int numObs = 10;
204
205
206
207 final Point2D.Double[] obs = lineGenerator.generate(numObs);
208
209 final StraightLineProblem problem = new StraightLineProblem(yError);
210 for (int i = 0; i < numObs; i++) {
211 final Point2D.Double p = obs[i];
212 problem.addPoint(p.x, p.y);
213 }
214
215
216 final RealVector regress = new ArrayRealVector(problem.solve(), false);
217
218
219 final LeastSquaresProblem lsp = builder(problem).build();
220
221
222
223 final double bestChi2N = getChi2N(lsp, regress);
224 final RealVector sigma = lsp.evaluate(regress).getSigma(1e-14);
225
226
227 final int mcRepeat = MONTE_CARLO_RUNS;
228 final int gridSize = (int) FastMath.sqrt(mcRepeat);
229
230
231
232
233
234 final List<double[]> paramsAndChi2 = new ArrayList<double[]>(gridSize * gridSize);
235
236 final double slopeRange = 10 * sigma.getEntry(0);
237 final double offsetRange = 10 * sigma.getEntry(1);
238 final double minSlope = slope - 0.5 * slopeRange;
239 final double minOffset = offset - 0.5 * offsetRange;
240 final double deltaSlope = slopeRange/ gridSize;
241 final double deltaOffset = offsetRange / gridSize;
242 for (int i = 0; i < gridSize; i++) {
243 final double s = minSlope + i * deltaSlope;
244 for (int j = 0; j < gridSize; j++) {
245 final double o = minOffset + j * deltaOffset;
246 final double chi2N = getChi2N(lsp,
247 new ArrayRealVector(new double[] {s, o}, false));
248
249 paramsAndChi2.add(new double[] {s, o, chi2N});
250 }
251 }
252
253
254
255
256
257
258 final double chi2NPlusOne = bestChi2N + 1;
259 int numLarger = 0;
260
261 final String lineFmt = "%+.10e %+.10e %.8e\n";
262
263
264 System.out.printf(lineFmt, regress.getEntry(0), regress.getEntry(1), bestChi2N);
265 System.out.println();
266
267
268 for (double[] d : paramsAndChi2) {
269 if (d[2] <= chi2NPlusOne) {
270 System.out.printf(lineFmt, d[0], d[1], d[2]);
271 }
272 }
273 System.out.println();
274
275
276 for (double[] d : paramsAndChi2) {
277 if (d[2] > chi2NPlusOne) {
278 ++numLarger;
279 System.out.printf(lineFmt, d[0], d[1], d[2]);
280 }
281 }
282 System.out.println();
283
284 System.out.println("# sigma=" + sigma.toString());
285 System.out.println("# " + numLarger + " sets filtered out");
286 }
287
288 LeastSquaresBuilder builder(StraightLineProblem problem){
289 return new LeastSquaresBuilder()
290 .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
291 .target(problem.target())
292 .weight(new DiagonalMatrix(problem.weight()))
293
294 .start(new double[2]);
295 }
296
297
298
299 private double getChi2N(LeastSquaresProblem lsp,
300 RealVector params) {
301 final double cost = lsp.evaluate(params).getCost();
302 return cost * cost / (lsp.getObservationSize() - params.getDimension());
303 }
304 }
305