View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.analysis.solvers;
23  
24  import org.hipparchus.analysis.QuinticFunction;
25  import org.hipparchus.analysis.UnivariateFunction;
26  import org.hipparchus.analysis.XMinus5Function;
27  import org.hipparchus.analysis.function.Sin;
28  import org.hipparchus.exception.MathIllegalArgumentException;
29  import org.hipparchus.util.FastMath;
30  import org.junit.jupiter.api.Test;
31  
32  import java.lang.reflect.Field;
33  
34  import static org.junit.jupiter.api.Assertions.assertEquals;
35  import static org.junit.jupiter.api.Assertions.assertTrue;
36  import static org.junit.jupiter.api.Assertions.fail;
37  
38  /**
39   * Base class for root-finding algorithms tests derived from
40   * {@link BaseSecantSolver}.
41   *
42   */
43  public abstract class BaseSecantSolverAbstractTest {
44      /** Returns the solver to use to perform the tests.
45       * @return the solver to use to perform the tests
46       */
47      protected abstract UnivariateSolver getSolver();
48  
49      /** Returns the expected number of evaluations for the
50       * {@link #testQuinticZero} unit test. A value of {@code -1} indicates that
51       * the test should be skipped for that solver.
52       * @return the expected number of evaluations for the
53       * {@link #testQuinticZero} unit test
54       */
55      protected abstract int[] getQuinticEvalCounts();
56  
57      @Test
58      public void testSinZero() {
59          // The sinus function is behaved well around the root at pi. The second
60          // order derivative is zero, which means linear approximating methods
61          // still converge quadratically.
62          UnivariateFunction f = new Sin();
63          double result;
64          UnivariateSolver solver = getSolver();
65  
66          result = solver.solve(100, f, 3, 4);
67          //System.out.println(
68          //    "Root: " + result + " Evaluations: " + solver.getEvaluations());
69          assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy());
70          assertTrue(solver.getEvaluations() <= 6);
71          result = solver.solve(100, f, 1, 4);
72          //System.out.println(
73          //    "Root: " + result + " Evaluations: " + solver.getEvaluations());
74          assertEquals(FastMath.PI, result, solver.getAbsoluteAccuracy());
75          assertTrue(solver.getEvaluations() <= 7);
76      }
77  
78      @Test
79      public void testQuinticZero() {
80          // The quintic function has zeros at 0, +-0.5 and +-1.
81          // Around the root of 0 the function is well behaved, with a second
82          // derivative of zero a 0.
83          // The other roots are less well to find, in particular the root at 1,
84          // because the function grows fast for x>1.
85          // The function has extrema (first derivative is zero) at 0.27195613
86          // and 0.82221643, intervals containing these values are harder for
87          // the solvers.
88          UnivariateFunction f = new QuinticFunction();
89          double result;
90          UnivariateSolver solver = getSolver();
91          double atol = solver.getAbsoluteAccuracy();
92          int[] counts = getQuinticEvalCounts();
93  
94          // Tests data: initial bounds, and expected solution, per test case.
95          double[][] testsData = {{-0.2,  0.2,  0.0},
96                                  {-0.1,  0.3,  0.0},
97                                  {-0.3,  0.45, 0.0},
98                                  { 0.3,  0.7,  0.5},
99                                  { 0.2,  0.6,  0.5},
100                                 { 0.05, 0.95, 0.5},
101                                 { 0.85, 1.25, 1.0},
102                                 { 0.8,  1.2,  1.0},
103                                 { 0.85, 1.75, 1.0},
104                                 { 0.55, 1.45, 1.0},
105                                 { 0.85, 5.0,  1.0},
106                                };
107         int maxIter = 500;
108 
109         for(int i = 0; i < testsData.length; i++) {
110             // Skip test, if needed.
111             if (counts[i] == -1) continue;
112 
113             // Compute solution.
114             double[] testData = testsData[i];
115             result = solver.solve(maxIter, f, testData[0], testData[1]);
116             //System.out.println(
117             //    "Root: " + result + " Evaluations: " + solver.getEvaluations());
118 
119             // Check solution.
120             assertEquals(result, testData[2], atol);
121             assertTrue(solver.getEvaluations() <= counts[i] + 1,
122                     "" + solver.getEvaluations() + " <= " + (counts[i] + 1));
123         }
124     }
125 
126     @Test
127     public void testRootEndpoints() {
128         UnivariateFunction f = new XMinus5Function();
129         UnivariateSolver solver = getSolver();
130 
131         // End-point is root. This should be a special case in the solver, and
132         // the initial end-point should be returned exactly.
133         double result = solver.solve(100, f, 5.0, 6.0);
134         assertEquals(5.0, result, 0.0);
135 
136         result = solver.solve(100, f, 4.0, 5.0);
137         assertEquals(5.0, result, 0.0);
138 
139         result = solver.solve(100, f, 5.0, 6.0, 5.5);
140         assertEquals(5.0, result, 0.0);
141 
142         result = solver.solve(100, f, 4.0, 5.0, 4.5);
143         assertEquals(5.0, result, 0.0);
144     }
145 
146     @Test
147     public void testCloseEndpoints() {
148         UnivariateFunction f = new XMinus5Function();
149         UnivariateSolver solver = getSolver();
150 
151         double result = solver.solve(100, f, 5.0, FastMath.nextUp(5.0));
152         assertEquals(5.0, result, 0.0);
153 
154         result = solver.solve(100, f, FastMath.nextDown(5.0), 5.0);
155         assertEquals(5.0, result, 0.0);
156     }
157 
158     @Test
159     public void testBadEndpoints() {
160         UnivariateFunction f = new Sin();
161         UnivariateSolver solver = getSolver();
162         try {  // bad interval
163             solver.solve(100, f, 1, -1);
164             fail("Expecting MathIllegalArgumentException - bad interval");
165         } catch (MathIllegalArgumentException ex) {
166             // expected
167         }
168         try {  // no bracket
169             solver.solve(100, f, 1, 1.5);
170             fail("Expecting MathIllegalArgumentException - non-bracketing");
171         } catch (MathIllegalArgumentException ex) {
172             // expected
173         }
174         try {  // no bracket
175             solver.solve(100, f, 1, 1.5, 1.2);
176             fail("Expecting MathIllegalArgumentException - non-bracketing");
177         } catch (MathIllegalArgumentException ex) {
178             // expected
179         }
180     }
181 
182     @Test
183     public void testSolutionLeftSide() {
184         UnivariateFunction f = new Sin();
185         UnivariateSolver solver = getSolver();
186         double left = -1.5;
187         double right = 0.05;
188         for(int i = 0; i < 10; i++) {
189             // Test whether the allowed solutions are taken into account.
190             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.LEFT_SIDE);
191             if (!Double.isNaN(solution)) {
192                 assertTrue(solution <= 0.0);
193             }
194 
195             // Prepare for next test.
196             left -= 0.1;
197             right += 0.3;
198         }
199     }
200 
201     @Test
202     public void testSolutionRightSide() {
203         UnivariateFunction f = new Sin();
204         UnivariateSolver solver = getSolver();
205         double left = -1.5;
206         double right = 0.05;
207         for(int i = 0; i < 10; i++) {
208             // Test whether the allowed solutions are taken into account.
209             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.RIGHT_SIDE);
210             if (!Double.isNaN(solution)) {
211                 assertTrue(solution >= 0.0);
212             }
213 
214             // Prepare for next test.
215             left -= 0.1;
216             right += 0.3;
217         }
218     }
219     @Test
220     public void testSolutionBelowSide() {
221         UnivariateFunction f = new Sin();
222         UnivariateSolver solver = getSolver();
223         double left = -1.5;
224         double right = 0.05;
225         for(int i = 0; i < 10; i++) {
226             // Test whether the allowed solutions are taken into account.
227             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.BELOW_SIDE);
228             if (!Double.isNaN(solution)) {
229                 assertTrue(f.value(solution) <= 0.0);
230             }
231 
232             // Prepare for next test.
233             left -= 0.1;
234             right += 0.3;
235         }
236     }
237 
238     @Test
239     public void testSolutionAboveSide() {
240         UnivariateFunction f = new Sin();
241         UnivariateSolver solver = getSolver();
242         double left = -1.5;
243         double right = 0.05;
244         for(int i = 0; i < 10; i++) {
245             // Test whether the allowed solutions are taken into account.
246             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.ABOVE_SIDE);
247             if (!Double.isNaN(solution)) {
248                 assertTrue(f.value(solution) >= 0.0);
249             }
250 
251             // Prepare for next test.
252             left -= 0.1;
253             right += 0.3;
254         }
255     }
256 
257     private double getSolution(UnivariateSolver solver, int maxEval, UnivariateFunction f,
258                                double left, double right, AllowedSolution allowedSolution) {
259         try {
260             @SuppressWarnings("unchecked")
261             BracketedUnivariateSolver<UnivariateFunction> bracketing =
262             (BracketedUnivariateSolver<UnivariateFunction>) solver;
263             return bracketing.solve(100, f, left, right, allowedSolution);
264         } catch (ClassCastException cce) {
265             double baseRoot = solver.solve(maxEval, f, left, right);
266             if ((baseRoot <= left) || (baseRoot >= right)) {
267                 // the solution slipped out of interval
268                 return Double.NaN;
269             }
270             PegasusSolver bracketing =
271                     new PegasusSolver(solver.getRelativeAccuracy(), solver.getAbsoluteAccuracy(),
272                                       solver.getFunctionValueAccuracy());
273             return UnivariateSolverUtils.forceSide(maxEval - solver.getEvaluations(),
274                                                        f, bracketing, baseRoot, left, right,
275                                                        allowedSolution);
276         }
277     }
278 
279     protected void checktype(UnivariateSolver solver, BaseSecantSolver.Method expected) {
280         try {
281             Field methodField = BaseSecantSolver.class.getDeclaredField("method");
282             methodField.setAccessible(true);
283             BaseSecantSolver.Method method = (BaseSecantSolver.Method) methodField.get(solver);
284             assertEquals(expected, method);
285         } catch (IllegalAccessException | NoSuchFieldException | SecurityException e) {
286             fail(e.getLocalizedMessage());
287         }
288     }
289 
290 }