1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.analysis.solvers;
23
import org.hipparchus.analysis.QuinticFunction;
import org.hipparchus.analysis.UnivariateFunction;
import org.hipparchus.analysis.XMinus5Function;
import org.hipparchus.analysis.function.Sin;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.util.FastMath;
import org.junit.jupiter.api.Test;

import java.lang.reflect.Field;
import java.util.function.DoublePredicate;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
37
38
39
40
41
42
43 public abstract class BaseSecantSolverAbstractTest {
44
45
46
47 protected abstract UnivariateSolver getSolver();
48
49
50
51
52
53
54
55 protected abstract int[] getQuinticEvalCounts();
56
57 @Test
58 public void testSinZero() {
59
60
61
62 UnivariateFunction f = new Sin();
63 double result;
64 UnivariateSolver solver = getSolver();
65
66 result = solver.solve(100, f, 3, 4);
67
68
69 assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy());
70 assertTrue(solver.getEvaluations() <= 6);
71 result = solver.solve(100, f, 1, 4);
72
73
74 assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy());
75 assertTrue(solver.getEvaluations() <= 7);
76 }
77
78 @Test
79 public void testQuinticZero() {
80
81
82
83
84
85
86
87
88 UnivariateFunction f = new QuinticFunction();
89 double result;
90 UnivariateSolver solver = getSolver();
91 double atol = solver.getAbsoluteAccuracy();
92 int[] counts = getQuinticEvalCounts();
93
94
95 double[][] testsData = {{-0.2, 0.2, 0.0},
96 {-0.1, 0.3, 0.0},
97 {-0.3, 0.45, 0.0},
98 { 0.3, 0.7, 0.5},
99 { 0.2, 0.6, 0.5},
100 { 0.05, 0.95, 0.5},
101 { 0.85, 1.25, 1.0},
102 { 0.8, 1.2, 1.0},
103 { 0.85, 1.75, 1.0},
104 { 0.55, 1.45, 1.0},
105 { 0.85, 5.0, 1.0},
106 };
107 int maxIter = 500;
108
109 for(int i = 0; i < testsData.length; i++) {
110
111 if (counts[i] == -1) continue;
112
113
114 double[] testData = testsData[i];
115 result = solver.solve(maxIter, f, testData[0], testData[1]);
116
117
118
119
120 assertEquals(result, testData[2], atol);
121 assertTrue(solver.getEvaluations() <= counts[i] + 1,
122 "" + solver.getEvaluations() + " <= " + (counts[i] + 1));
123 }
124 }
125
126 @Test
127 public void testRootEndpoints() {
128 UnivariateFunction f = new XMinus5Function();
129 UnivariateSolver solver = getSolver();
130
131
132
133 double result = solver.solve(100, f, 5.0, 6.0);
134 assertEquals(5.0, result, 0.0);
135
136 result = solver.solve(100, f, 4.0, 5.0);
137 assertEquals(5.0, result, 0.0);
138
139 result = solver.solve(100, f, 5.0, 6.0, 5.5);
140 assertEquals(5.0, result, 0.0);
141
142 result = solver.solve(100, f, 4.0, 5.0, 4.5);
143 assertEquals(5.0, result, 0.0);
144 }
145
146 @Test
147 public void testCloseEndpoints() {
148 UnivariateFunction f = new XMinus5Function();
149 UnivariateSolver solver = getSolver();
150
151 double result = solver.solve(100, f, 5.0, FastMath.nextUp(5.0));
152 assertEquals(5.0, result, 0.0);
153
154 result = solver.solve(100, f, FastMath.nextDown(5.0), 5.0);
155 assertEquals(5.0, result, 0.0);
156 }
157
158 @Test
159 public void testBadEndpoints() {
160 UnivariateFunction f = new Sin();
161 UnivariateSolver solver = getSolver();
162 try {
163 solver.solve(100, f, 1, -1);
164 fail("Expecting MathIllegalArgumentException - bad interval");
165 } catch (MathIllegalArgumentException ex) {
166
167 }
168 try {
169 solver.solve(100, f, 1, 1.5);
170 fail("Expecting MathIllegalArgumentException - non-bracketing");
171 } catch (MathIllegalArgumentException ex) {
172
173 }
174 try {
175 solver.solve(100, f, 1, 1.5, 1.2);
176 fail("Expecting MathIllegalArgumentException - non-bracketing");
177 } catch (MathIllegalArgumentException ex) {
178
179 }
180 }
181
182 @Test
183 public void testSolutionLeftSide() {
184 UnivariateFunction f = new Sin();
185 UnivariateSolver solver = getSolver();
186 double left = -1.5;
187 double right = 0.05;
188 for(int i = 0; i < 10; i++) {
189
190 double solution = getSolution(solver, 100, f, left, right, AllowedSolution.LEFT_SIDE);
191 if (!Double.isNaN(solution)) {
192 assertTrue(solution <= 0.0);
193 }
194
195
196 left -= 0.1;
197 right += 0.3;
198 }
199 }
200
201 @Test
202 public void testSolutionRightSide() {
203 UnivariateFunction f = new Sin();
204 UnivariateSolver solver = getSolver();
205 double left = -1.5;
206 double right = 0.05;
207 for(int i = 0; i < 10; i++) {
208
209 double solution = getSolution(solver, 100, f, left, right, AllowedSolution.RIGHT_SIDE);
210 if (!Double.isNaN(solution)) {
211 assertTrue(solution >= 0.0);
212 }
213
214
215 left -= 0.1;
216 right += 0.3;
217 }
218 }
219 @Test
220 public void testSolutionBelowSide() {
221 UnivariateFunction f = new Sin();
222 UnivariateSolver solver = getSolver();
223 double left = -1.5;
224 double right = 0.05;
225 for(int i = 0; i < 10; i++) {
226
227 double solution = getSolution(solver, 100, f, left, right, AllowedSolution.BELOW_SIDE);
228 if (!Double.isNaN(solution)) {
229 assertTrue(f.value(solution) <= 0.0);
230 }
231
232
233 left -= 0.1;
234 right += 0.3;
235 }
236 }
237
238 @Test
239 public void testSolutionAboveSide() {
240 UnivariateFunction f = new Sin();
241 UnivariateSolver solver = getSolver();
242 double left = -1.5;
243 double right = 0.05;
244 for(int i = 0; i < 10; i++) {
245
246 double solution = getSolution(solver, 100, f, left, right, AllowedSolution.ABOVE_SIDE);
247 if (!Double.isNaN(solution)) {
248 assertTrue(f.value(solution) >= 0.0);
249 }
250
251
252 left -= 0.1;
253 right += 0.3;
254 }
255 }
256
257 private double getSolution(UnivariateSolver solver, int maxEval, UnivariateFunction f,
258 double left, double right, AllowedSolution allowedSolution) {
259 try {
260 @SuppressWarnings("unchecked")
261 BracketedUnivariateSolver<UnivariateFunction> bracketing =
262 (BracketedUnivariateSolver<UnivariateFunction>) solver;
263 return bracketing.solve(100, f, left, right, allowedSolution);
264 } catch (ClassCastException cce) {
265 double baseRoot = solver.solve(maxEval, f, left, right);
266 if ((baseRoot <= left) || (baseRoot >= right)) {
267
268 return Double.NaN;
269 }
270 PegasusSolver bracketing =
271 new PegasusSolver(solver.getRelativeAccuracy(), solver.getAbsoluteAccuracy(),
272 solver.getFunctionValueAccuracy());
273 return UnivariateSolverUtils.forceSide(maxEval - solver.getEvaluations(),
274 f, bracketing, baseRoot, left, right,
275 allowedSolution);
276 }
277 }
278
279 protected void checktype(UnivariateSolver solver, BaseSecantSolver.Method expected) {
280 try {
281 Field methodField = BaseSecantSolver.class.getDeclaredField("method");
282 methodField.setAccessible(true);
283 BaseSecantSolver.Method method = (BaseSecantSolver.Method) methodField.get(solver);
284 assertEquals(expected, method);
285 } catch (IllegalAccessException | NoSuchFieldException | SecurityException e) {
286 fail(e.getLocalizedMessage());
287 }
288 }
289
290 }