View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.analysis.function;
24  
25  import org.hipparchus.analysis.UnivariateFunction;
26  import org.hipparchus.analysis.differentiation.DSFactory;
27  import org.hipparchus.analysis.differentiation.DerivativeStructure;
28  import org.hipparchus.exception.MathIllegalArgumentException;
29  import org.hipparchus.exception.NullArgumentException;
30  import org.hipparchus.util.FastMath;
31  import org.junit.jupiter.api.Test;
32  
33  import static org.junit.jupiter.api.Assertions.assertEquals;
34  import static org.junit.jupiter.api.Assertions.assertThrows;
35  
36  /**
37   * Test for class {@link Logistic}.
38   */
39  class LogisticTest {
40      private final double EPS = Math.ulp(1d);
41  
42      @Test
43      void testPreconditions1() {
44          assertThrows(MathIllegalArgumentException.class, () -> {
45              new Logistic(1, 0, 1, 1, 0, -1);
46          });
47      }
48  
49      @Test
50      void testPreconditions2() {
51          assertThrows(MathIllegalArgumentException.class, () -> {
52              new Logistic(1, 0, 1, 1, 0, 0);
53          });
54      }
55  
56      @Test
57      void testCompareSigmoid() {
58          final UnivariateFunction sig = new Sigmoid();
59          final UnivariateFunction sigL = new Logistic(1, 0, 1, 1, 0, 1);
60  
61          final double min = -2;
62          final double max = 2;
63          final int n = 100;
64          final double delta = (max - min) / n;
65          for (int i = 0; i < n; i++) {
66              final double x = min + i * delta;
67              assertEquals(sig.value(x), sigL.value(x), EPS, "x=" + x);
68          }
69      }
70  
71      @Test
72      void testSomeValues() {
73          final double k = 4;
74          final double m = 5;
75          final double b = 2;
76          final double q = 3;
77          final double a = -1;
78          final double n = 2;
79  
80          final UnivariateFunction f = new Logistic(k, m, b, q, a, n);
81  
82          double x;
83          x = m;
84          assertEquals(a + (k - a) / FastMath.sqrt(1 + q), f.value(x), EPS, "x=" + x);
85  
86          x = Double.NEGATIVE_INFINITY;
87          assertEquals(a, f.value(x), EPS, "x=" + x);
88  
89          x = Double.POSITIVE_INFINITY;
90          assertEquals(k, f.value(x), EPS, "x=" + x);
91      }
92  
93      @Test
94      void testCompareDerivativeSigmoid() {
95          final double k = 3;
96          final double a = 2;
97  
98          final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
99          final Sigmoid g = new Sigmoid(a, k);
100 
101         final double min = -10;
102         final double max = 10;
103         final double n = 20;
104         final double delta = (max - min) / n;
105         final DSFactory factory = new DSFactory(1, 5);
106         for (int i = 0; i < n; i++) {
107             final DerivativeStructure x = factory.variable(0, min + i * delta);
108             for (int order = 0; order <= x.getOrder(); ++order) {
109                 assertEquals(g.value(x).getPartialDerivative(order),
110                                     f.value(x).getPartialDerivative(order),
111                                     3.0e-15,
112                                     "x=" + x.getValue());
113             }
114         }
115     }
116 
117     @Test
118     void testParametricUsage1() {
119         assertThrows(NullArgumentException.class, () -> {
120             final Logistic.Parametric g = new Logistic.Parametric();
121             g.value(0, null);
122         });
123     }
124 
125     @Test
126     void testParametricUsage2() {
127         assertThrows(MathIllegalArgumentException.class, () -> {
128             final Logistic.Parametric g = new Logistic.Parametric();
129             g.value(0, new double[]{0});
130         });
131     }
132 
133     @Test
134     void testParametricUsage3() {
135         assertThrows(NullArgumentException.class, () -> {
136             final Logistic.Parametric g = new Logistic.Parametric();
137             g.gradient(0, null);
138         });
139     }
140 
141     @Test
142     void testParametricUsage4() {
143         assertThrows(MathIllegalArgumentException.class, () -> {
144             final Logistic.Parametric g = new Logistic.Parametric();
145             g.gradient(0, new double[]{0});
146         });
147     }
148 
149     @Test
150     void testParametricUsage5() {
151         assertThrows(MathIllegalArgumentException.class, () -> {
152             final Logistic.Parametric g = new Logistic.Parametric();
153             g.value(0, new double[]{1, 0, 1, 1, 0, 0});
154         });
155     }
156 
157     @Test
158     void testParametricUsage6() {
159         assertThrows(MathIllegalArgumentException.class, () -> {
160             final Logistic.Parametric g = new Logistic.Parametric();
161             g.gradient(0, new double[]{1, 0, 1, 1, 0, 0});
162         });
163     }
164 
165     @Test
166     void testGradientComponent0Component4() {
167         final double k = 3;
168         final double a = 2;
169 
170         final Logistic.Parametric f = new Logistic.Parametric();
171         // Compare using the "Sigmoid" function.
172         final Sigmoid.Parametric g = new Sigmoid.Parametric();
173 
174         final double x = 0.12345;
175         final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
176         final double[] gg = g.gradient(x, new double[] {a, k});
177 
178         assertEquals(gg[0], gf[4], EPS);
179         assertEquals(gg[1], gf[0], EPS);
180     }
181 
182     @Test
183     void testGradientComponent5() {
184         final double m = 1.2;
185         final double k = 3.4;
186         final double a = 2.3;
187         final double q = 0.567;
188         final double b = -FastMath.log(q);
189         final double n = 3.4;
190 
191         final Logistic.Parametric f = new Logistic.Parametric();
192 
193         final double x = m - 1;
194         final double qExp1 = 2;
195 
196         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
197 
198         assertEquals((k - a) * FastMath.log(qExp1) / (n * n * FastMath.pow(qExp1, 1 / n)),
199                             gf[5], EPS);
200     }
201 
202     @Test
203     void testGradientComponent1Component2Component3() {
204         final double m = 1.2;
205         final double k = 3.4;
206         final double a = 2.3;
207         final double b = 0.567;
208         final double q = 1 / FastMath.exp(b * m);
209         final double n = 3.4;
210 
211         final Logistic.Parametric f = new Logistic.Parametric();
212 
213         final double x = 0;
214         final double qExp1 = 2;
215 
216         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
217 
218         final double factor = (a - k) / (n * FastMath.pow(qExp1, 1 / n + 1));
219         assertEquals(factor * b, gf[1], EPS);
220         assertEquals(factor * m, gf[2], EPS);
221         assertEquals(factor / q, gf[3], EPS);
222     }
223 }