View Javadoc
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * This is not the original file distributed by the Apache Software Foundation
 * It has been modified by the Hipparchus project
 */
22  
23  package org.hipparchus.analysis.function;
24  
25  import org.hipparchus.analysis.FunctionUtils;
26  import org.hipparchus.analysis.UnivariateFunction;
27  import org.hipparchus.analysis.differentiation.DSFactory;
28  import org.hipparchus.analysis.differentiation.DerivativeStructure;
29  import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30  import org.hipparchus.exception.MathIllegalArgumentException;
31  import org.hipparchus.exception.NullArgumentException;
32  import org.hipparchus.random.RandomGenerator;
33  import org.hipparchus.random.Well1024a;
34  import org.hipparchus.util.FastMath;
35  import org.junit.Assert;
36  import org.junit.Test;
37  
38  /**
39   * Test for class {@link Logit}.
40   */
41  public class LogitTest {
42      private final double EPS = Math.ulp(1d);
43  
44      @Test(expected=MathIllegalArgumentException.class)
45      public void testPreconditions1() {
46          final double lo = -1;
47          final double hi = 2;
48          final UnivariateFunction f = new Logit(lo, hi);
49  
50          f.value(lo - 1);
51      }
52  
53      @Test(expected=MathIllegalArgumentException.class)
54      public void testPreconditions2() {
55          final double lo = -1;
56          final double hi = 2;
57          final UnivariateFunction f = new Logit(lo, hi);
58  
59          f.value(hi + 1);
60      }
61  
62      @Test
63      public void testSomeValues() {
64          final double lo = 1;
65          final double hi = 2;
66          final UnivariateFunction f = new Logit(lo, hi);
67  
68          Assert.assertEquals(Double.NEGATIVE_INFINITY, f.value(1), EPS);
69          Assert.assertEquals(Double.POSITIVE_INFINITY, f.value(2), EPS);
70          Assert.assertEquals(0, f.value(1.5), EPS);
71      }
72  
73      @Test
74      public void testDerivative() {
75          final double lo = 1;
76          final double hi = 2;
77          final Logit f = new Logit(lo, hi);
78          final DerivativeStructure f15 = f.value(new DSFactory(1, 1).variable(0, 1.5));
79  
80          Assert.assertEquals(4, f15.getPartialDerivative(1), EPS);
81      }
82  
83      @Test
84      public void testDerivativeLargeArguments() {
85          final Logit f = new Logit(1, 2);
86  
87          DSFactory factory = new DSFactory(1, 1);
88          for (double arg : new double[] {
89              Double.NEGATIVE_INFINITY, -Double.MAX_VALUE, -1e155, 1e155, Double.MAX_VALUE, Double.POSITIVE_INFINITY
90              }) {
91              try {
92                  f.value(factory.variable(0, arg));
93                  Assert.fail("an exception should have been thrown");
94              } catch (MathIllegalArgumentException ore) {
95                  // expected
96              } catch (Exception e) {
97                  Assert.fail("wrong exception caught: " + e.getMessage());
98              }
99          }
100     }
101 
102     @Test
103     public void testDerivativesHighOrder() {
104         DerivativeStructure l = new Logit(1, 3).value(new DSFactory(1, 5).variable(0, 1.2));
105         Assert.assertEquals(-2.1972245773362193828, l.getPartialDerivative(0), 1.0e-16);
106         Assert.assertEquals(5.5555555555555555555,  l.getPartialDerivative(1), 9.0e-16);
107         Assert.assertEquals(-24.691358024691358025, l.getPartialDerivative(2), 2.0e-14);
108         Assert.assertEquals(250.34293552812071331,  l.getPartialDerivative(3), 2.0e-13);
109         Assert.assertEquals(-3749.4284407864654778, l.getPartialDerivative(4), 4.0e-12);
110         Assert.assertEquals(75001.270131585632282,  l.getPartialDerivative(5), 8.0e-11);
111     }
112 
113     @Test(expected=NullArgumentException.class)
114     public void testParametricUsage1() {
115         final Logit.Parametric g = new Logit.Parametric();
116         g.value(0, null);
117     }
118 
119     @Test(expected=MathIllegalArgumentException.class)
120     public void testParametricUsage2() {
121         final Logit.Parametric g = new Logit.Parametric();
122         g.value(0, new double[] {0});
123     }
124 
125     @Test(expected=NullArgumentException.class)
126     public void testParametricUsage3() {
127         final Logit.Parametric g = new Logit.Parametric();
128         g.gradient(0, null);
129     }
130 
131     @Test(expected=MathIllegalArgumentException.class)
132     public void testParametricUsage4() {
133         final Logit.Parametric g = new Logit.Parametric();
134         g.gradient(0, new double[] {0});
135     }
136 
137     @Test(expected=MathIllegalArgumentException.class)
138     public void testParametricUsage5() {
139         final Logit.Parametric g = new Logit.Parametric();
140         g.value(-1, new double[] {0, 1});
141     }
142 
143     @Test(expected=MathIllegalArgumentException.class)
144     public void testParametricUsage6() {
145         final Logit.Parametric g = new Logit.Parametric();
146         g.value(2, new double[] {0, 1});
147     }
148 
149     @Test
150     public void testParametricValue() {
151         final double lo = 2;
152         final double hi = 3;
153         final Logit f = new Logit(lo, hi);
154 
155         final Logit.Parametric g = new Logit.Parametric();
156         Assert.assertEquals(f.value(2), g.value(2, new double[] {lo, hi}), 0);
157         Assert.assertEquals(f.value(2.34567), g.value(2.34567, new double[] {lo, hi}), 0);
158         Assert.assertEquals(f.value(3), g.value(3, new double[] {lo, hi}), 0);
159     }
160 
161     @Test
162     public void testValueWithInverseFunction() {
163         final double lo = 2;
164         final double hi = 3;
165         final Logit f = new Logit(lo, hi);
166         final Sigmoid g = new Sigmoid(lo, hi);
167         RandomGenerator random = new Well1024a(0x49914cdd9f0b8db5l);
168         final UnivariateDifferentiableFunction id = FunctionUtils.compose((UnivariateDifferentiableFunction) g,
169                                                                 (UnivariateDifferentiableFunction) f);
170 
171         DSFactory factory = new DSFactory(1, 1);
172         for (int i = 0; i < 10; i++) {
173             final double x = lo + random.nextDouble() * (hi - lo);
174             Assert.assertEquals(x, id.value(factory.variable(0, x)).getValue(), EPS);
175         }
176 
177         Assert.assertEquals(lo, id.value(factory.variable(0, lo)).getValue(), EPS);
178         Assert.assertEquals(hi, id.value(factory.variable(0, hi)).getValue(), EPS);
179     }
180 
181     @Test
182     public void testDerivativesWithInverseFunction() {
183         double[] epsilon = new double[] { 1.0e-20, 4.0e-16, 3.0e-15, 2.0e-11, 3.0e-9, 1.0e-6 };
184         final double lo = 2;
185         final double hi = 3;
186         final Logit f = new Logit(lo, hi);
187         final Sigmoid g = new Sigmoid(lo, hi);
188         RandomGenerator random = new Well1024a(0x96885e9c1f81cea5l);
189         final UnivariateDifferentiableFunction id =
190                 FunctionUtils.compose((UnivariateDifferentiableFunction) g, (UnivariateDifferentiableFunction) f);
191         for (int maxOrder = 0; maxOrder < 6; ++maxOrder) {
192             DSFactory factory = new DSFactory(1, maxOrder);
193             double max = 0;
194             for (int i = 0; i < 10; i++) {
195                 final double x = lo + random.nextDouble() * (hi - lo);
196                 final DerivativeStructure dsX = factory.variable(0, x);
197                 max = FastMath.max(max, FastMath.abs(dsX.getPartialDerivative(maxOrder) -
198                                                      id.value(dsX).getPartialDerivative(maxOrder)));
199                 Assert.assertEquals(dsX.getPartialDerivative(maxOrder),
200                                     id.value(dsX).getPartialDerivative(maxOrder),
201                                     epsilon[maxOrder]);
202             }
203 
204             // each function evaluates correctly near boundaries,
205             // but combination leads to NaN as some intermediate point is infinite
206             final DerivativeStructure dsLo = factory.variable(0, lo);
207             if (maxOrder == 0) {
208                 Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
209                 Assert.assertEquals(lo, id.value(dsLo).getPartialDerivative(maxOrder), epsilon[maxOrder]);
210             } else if (maxOrder == 1) {
211                 Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
212                 Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
213             } else {
214                 Assert.assertTrue(Double.isNaN(f.value(dsLo).getPartialDerivative(maxOrder)));
215                 Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
216             }
217 
218             final DerivativeStructure dsHi = factory.variable(0, hi);
219             if (maxOrder == 0) {
220                 Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
221                 Assert.assertEquals(hi, id.value(dsHi).getPartialDerivative(maxOrder), epsilon[maxOrder]);
222             } else if (maxOrder == 1) {
223                 Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
224                 Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
225             } else {
226                 Assert.assertTrue(Double.isNaN(f.value(dsHi).getPartialDerivative(maxOrder)));
227                 Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
228             }
229 
230         }
231     }
232 }