View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.analysis.function;
24  
25  import org.hipparchus.analysis.FunctionUtils;
26  import org.hipparchus.analysis.UnivariateFunction;
27  import org.hipparchus.analysis.differentiation.DSFactory;
28  import org.hipparchus.analysis.differentiation.DerivativeStructure;
29  import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30  import org.hipparchus.exception.MathIllegalArgumentException;
31  import org.hipparchus.exception.NullArgumentException;
32  import org.hipparchus.random.RandomGenerator;
33  import org.hipparchus.random.Well1024a;
34  import org.hipparchus.util.FastMath;
35  import org.junit.jupiter.api.Test;
36  
37  import static org.junit.jupiter.api.Assertions.assertEquals;
38  import static org.junit.jupiter.api.Assertions.assertThrows;
39  import static org.junit.jupiter.api.Assertions.assertTrue;
40  import static org.junit.jupiter.api.Assertions.fail;
41  
42  /**
43   * Test for class {@link Logit}.
44   */
45  class LogitTest {
46      private final double EPS = Math.ulp(1d);
47  
48      @Test
49      void testPreconditions1() {
50          assertThrows(MathIllegalArgumentException.class, () -> {
51              final double lo = -1;
52              final double hi = 2;
53              final UnivariateFunction f = new Logit(lo, hi);
54  
55              f.value(lo - 1);
56          });
57      }
58  
59      @Test
60      void testPreconditions2() {
61          assertThrows(MathIllegalArgumentException.class, () -> {
62              final double lo = -1;
63              final double hi = 2;
64              final UnivariateFunction f = new Logit(lo, hi);
65  
66              f.value(hi + 1);
67          });
68      }
69  
70      @Test
71      void testSomeValues() {
72          final double lo = 1;
73          final double hi = 2;
74          final UnivariateFunction f = new Logit(lo, hi);
75  
76          assertEquals(Double.NEGATIVE_INFINITY, f.value(1), EPS);
77          assertEquals(Double.POSITIVE_INFINITY, f.value(2), EPS);
78          assertEquals(0, f.value(1.5), EPS);
79      }
80  
81      @Test
82      void testDerivative() {
83          final double lo = 1;
84          final double hi = 2;
85          final Logit f = new Logit(lo, hi);
86          final DerivativeStructure f15 = f.value(new DSFactory(1, 1).variable(0, 1.5));
87  
88          assertEquals(4, f15.getPartialDerivative(1), EPS);
89      }
90  
91      @Test
92      void testDerivativeLargeArguments() {
93          final Logit f = new Logit(1, 2);
94  
95          DSFactory factory = new DSFactory(1, 1);
96          for (double arg : new double[] {
97              Double.NEGATIVE_INFINITY, -Double.MAX_VALUE, -1e155, 1e155, Double.MAX_VALUE, Double.POSITIVE_INFINITY
98              }) {
99              try {
100                 f.value(factory.variable(0, arg));
101                 fail("an exception should have been thrown");
102             } catch (MathIllegalArgumentException ore) {
103                 // expected
104             } catch (Exception e) {
105                 fail("wrong exception caught: " + e.getMessage());
106             }
107         }
108     }
109 
110     @Test
111     void testDerivativesHighOrder() {
112         DerivativeStructure l = new Logit(1, 3).value(new DSFactory(1, 5).variable(0, 1.2));
113         assertEquals(-2.1972245773362193828, l.getPartialDerivative(0), 1.0e-16);
114         assertEquals(5.5555555555555555555,  l.getPartialDerivative(1), 9.0e-16);
115         assertEquals(-24.691358024691358025, l.getPartialDerivative(2), 2.0e-14);
116         assertEquals(250.34293552812071331,  l.getPartialDerivative(3), 2.0e-13);
117         assertEquals(-3749.4284407864654778, l.getPartialDerivative(4), 4.0e-12);
118         assertEquals(75001.270131585632282,  l.getPartialDerivative(5), 8.0e-11);
119     }
120 
121     @Test
122     void testParametricUsage1() {
123         assertThrows(NullArgumentException.class, () -> {
124             final Logit.Parametric g = new Logit.Parametric();
125             g.value(0, null);
126         });
127     }
128 
129     @Test
130     void testParametricUsage2() {
131         assertThrows(MathIllegalArgumentException.class, () -> {
132             final Logit.Parametric g = new Logit.Parametric();
133             g.value(0, new double[]{0});
134         });
135     }
136 
137     @Test
138     void testParametricUsage3() {
139         assertThrows(NullArgumentException.class, () -> {
140             final Logit.Parametric g = new Logit.Parametric();
141             g.gradient(0, null);
142         });
143     }
144 
145     @Test
146     void testParametricUsage4() {
147         assertThrows(MathIllegalArgumentException.class, () -> {
148             final Logit.Parametric g = new Logit.Parametric();
149             g.gradient(0, new double[]{0});
150         });
151     }
152 
153     @Test
154     void testParametricUsage5() {
155         assertThrows(MathIllegalArgumentException.class, () -> {
156             final Logit.Parametric g = new Logit.Parametric();
157             g.value(-1, new double[]{0, 1});
158         });
159     }
160 
161     @Test
162     void testParametricUsage6() {
163         assertThrows(MathIllegalArgumentException.class, () -> {
164             final Logit.Parametric g = new Logit.Parametric();
165             g.value(2, new double[]{0, 1});
166         });
167     }
168 
169     @Test
170     void testParametricValue() {
171         final double lo = 2;
172         final double hi = 3;
173         final Logit f = new Logit(lo, hi);
174 
175         final Logit.Parametric g = new Logit.Parametric();
176         assertEquals(f.value(2), g.value(2, new double[] {lo, hi}), 0);
177         assertEquals(f.value(2.34567), g.value(2.34567, new double[] {lo, hi}), 0);
178         assertEquals(f.value(3), g.value(3, new double[] {lo, hi}), 0);
179     }
180 
181     @Test
182     void testValueWithInverseFunction() {
183         final double lo = 2;
184         final double hi = 3;
185         final Logit f = new Logit(lo, hi);
186         final Sigmoid g = new Sigmoid(lo, hi);
187         RandomGenerator random = new Well1024a(0x49914cdd9f0b8db5l);
188         final UnivariateDifferentiableFunction id = FunctionUtils.compose((UnivariateDifferentiableFunction) g,
189                                                                 (UnivariateDifferentiableFunction) f);
190 
191         DSFactory factory = new DSFactory(1, 1);
192         for (int i = 0; i < 10; i++) {
193             final double x = lo + random.nextDouble() * (hi - lo);
194             assertEquals(x, id.value(factory.variable(0, x)).getValue(), EPS);
195         }
196 
197         assertEquals(lo, id.value(factory.variable(0, lo)).getValue(), EPS);
198         assertEquals(hi, id.value(factory.variable(0, hi)).getValue(), EPS);
199     }
200 
201     @Test
202     void testDerivativesWithInverseFunction() {
203         double[] epsilon = new double[] { 1.0e-20, 4.0e-16, 3.0e-15, 2.0e-11, 3.0e-9, 1.0e-6 };
204         final double lo = 2;
205         final double hi = 3;
206         final Logit f = new Logit(lo, hi);
207         final Sigmoid g = new Sigmoid(lo, hi);
208         RandomGenerator random = new Well1024a(0x96885e9c1f81cea5l);
209         final UnivariateDifferentiableFunction id =
210                 FunctionUtils.compose((UnivariateDifferentiableFunction) g, (UnivariateDifferentiableFunction) f);
211         for (int maxOrder = 0; maxOrder < 6; ++maxOrder) {
212             DSFactory factory = new DSFactory(1, maxOrder);
213             double max = 0;
214             for (int i = 0; i < 10; i++) {
215                 final double x = lo + random.nextDouble() * (hi - lo);
216                 final DerivativeStructure dsX = factory.variable(0, x);
217                 max = FastMath.max(max, FastMath.abs(dsX.getPartialDerivative(maxOrder) -
218                                                      id.value(dsX).getPartialDerivative(maxOrder)));
219                 assertEquals(dsX.getPartialDerivative(maxOrder),
220                                     id.value(dsX).getPartialDerivative(maxOrder),
221                                     epsilon[maxOrder]);
222             }
223 
224             // each function evaluates correctly near boundaries,
225             // but combination leads to NaN as some intermediate point is infinite
226             final DerivativeStructure dsLo = factory.variable(0, lo);
227             if (maxOrder == 0) {
228                 assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
229                 assertEquals(lo, id.value(dsLo).getPartialDerivative(maxOrder), epsilon[maxOrder]);
230             } else if (maxOrder == 1) {
231                 assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
232                 assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
233             } else {
234                 assertTrue(Double.isNaN(f.value(dsLo).getPartialDerivative(maxOrder)));
235                 assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
236             }
237 
238             final DerivativeStructure dsHi = factory.variable(0, hi);
239             if (maxOrder == 0) {
240                 assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
241                 assertEquals(hi, id.value(dsHi).getPartialDerivative(maxOrder), epsilon[maxOrder]);
242             } else if (maxOrder == 1) {
243                 assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
244                 assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
245             } else {
246                 assertTrue(Double.isNaN(f.value(dsHi).getPartialDerivative(maxOrder)));
247                 assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
248             }
249 
250         }
251     }
252 }