1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis.function;
24
25 import org.hipparchus.analysis.UnivariateFunction;
26 import org.hipparchus.analysis.differentiation.DSFactory;
27 import org.hipparchus.analysis.differentiation.DerivativeStructure;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.exception.NullArgumentException;
30 import org.hipparchus.util.FastMath;
31 import org.junit.jupiter.api.Test;
32
33 import static org.junit.jupiter.api.Assertions.assertEquals;
34 import static org.junit.jupiter.api.Assertions.assertThrows;
35
36
37
38
39 class LogisticTest {
40 private final double EPS = Math.ulp(1d);
41
42 @Test
43 void testPreconditions1() {
44 assertThrows(MathIllegalArgumentException.class, () -> {
45 new Logistic(1, 0, 1, 1, 0, -1);
46 });
47 }
48
49 @Test
50 void testPreconditions2() {
51 assertThrows(MathIllegalArgumentException.class, () -> {
52 new Logistic(1, 0, 1, 1, 0, 0);
53 });
54 }
55
56 @Test
57 void testCompareSigmoid() {
58 final UnivariateFunction sig = new Sigmoid();
59 final UnivariateFunction sigL = new Logistic(1, 0, 1, 1, 0, 1);
60
61 final double min = -2;
62 final double max = 2;
63 final int n = 100;
64 final double delta = (max - min) / n;
65 for (int i = 0; i < n; i++) {
66 final double x = min + i * delta;
67 assertEquals(sig.value(x), sigL.value(x), EPS, "x=" + x);
68 }
69 }
70
71 @Test
72 void testSomeValues() {
73 final double k = 4;
74 final double m = 5;
75 final double b = 2;
76 final double q = 3;
77 final double a = -1;
78 final double n = 2;
79
80 final UnivariateFunction f = new Logistic(k, m, b, q, a, n);
81
82 double x;
83 x = m;
84 assertEquals(a + (k - a) / FastMath.sqrt(1 + q), f.value(x), EPS, "x=" + x);
85
86 x = Double.NEGATIVE_INFINITY;
87 assertEquals(a, f.value(x), EPS, "x=" + x);
88
89 x = Double.POSITIVE_INFINITY;
90 assertEquals(k, f.value(x), EPS, "x=" + x);
91 }
92
93 @Test
94 void testCompareDerivativeSigmoid() {
95 final double k = 3;
96 final double a = 2;
97
98 final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
99 final Sigmoid g = new Sigmoid(a, k);
100
101 final double min = -10;
102 final double max = 10;
103 final double n = 20;
104 final double delta = (max - min) / n;
105 final DSFactory factory = new DSFactory(1, 5);
106 for (int i = 0; i < n; i++) {
107 final DerivativeStructure x = factory.variable(0, min + i * delta);
108 for (int order = 0; order <= x.getOrder(); ++order) {
109 assertEquals(g.value(x).getPartialDerivative(order),
110 f.value(x).getPartialDerivative(order),
111 3.0e-15,
112 "x=" + x.getValue());
113 }
114 }
115 }
116
117 @Test
118 void testParametricUsage1() {
119 assertThrows(NullArgumentException.class, () -> {
120 final Logistic.Parametric g = new Logistic.Parametric();
121 g.value(0, null);
122 });
123 }
124
125 @Test
126 void testParametricUsage2() {
127 assertThrows(MathIllegalArgumentException.class, () -> {
128 final Logistic.Parametric g = new Logistic.Parametric();
129 g.value(0, new double[]{0});
130 });
131 }
132
133 @Test
134 void testParametricUsage3() {
135 assertThrows(NullArgumentException.class, () -> {
136 final Logistic.Parametric g = new Logistic.Parametric();
137 g.gradient(0, null);
138 });
139 }
140
141 @Test
142 void testParametricUsage4() {
143 assertThrows(MathIllegalArgumentException.class, () -> {
144 final Logistic.Parametric g = new Logistic.Parametric();
145 g.gradient(0, new double[]{0});
146 });
147 }
148
149 @Test
150 void testParametricUsage5() {
151 assertThrows(MathIllegalArgumentException.class, () -> {
152 final Logistic.Parametric g = new Logistic.Parametric();
153 g.value(0, new double[]{1, 0, 1, 1, 0, 0});
154 });
155 }
156
157 @Test
158 void testParametricUsage6() {
159 assertThrows(MathIllegalArgumentException.class, () -> {
160 final Logistic.Parametric g = new Logistic.Parametric();
161 g.gradient(0, new double[]{1, 0, 1, 1, 0, 0});
162 });
163 }
164
165 @Test
166 void testGradientComponent0Component4() {
167 final double k = 3;
168 final double a = 2;
169
170 final Logistic.Parametric f = new Logistic.Parametric();
171
172 final Sigmoid.Parametric g = new Sigmoid.Parametric();
173
174 final double x = 0.12345;
175 final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
176 final double[] gg = g.gradient(x, new double[] {a, k});
177
178 assertEquals(gg[0], gf[4], EPS);
179 assertEquals(gg[1], gf[0], EPS);
180 }
181
182 @Test
183 void testGradientComponent5() {
184 final double m = 1.2;
185 final double k = 3.4;
186 final double a = 2.3;
187 final double q = 0.567;
188 final double b = -FastMath.log(q);
189 final double n = 3.4;
190
191 final Logistic.Parametric f = new Logistic.Parametric();
192
193 final double x = m - 1;
194 final double qExp1 = 2;
195
196 final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
197
198 assertEquals((k - a) * FastMath.log(qExp1) / (n * n * FastMath.pow(qExp1, 1 / n)),
199 gf[5], EPS);
200 }
201
202 @Test
203 void testGradientComponent1Component2Component3() {
204 final double m = 1.2;
205 final double k = 3.4;
206 final double a = 2.3;
207 final double b = 0.567;
208 final double q = 1 / FastMath.exp(b * m);
209 final double n = 3.4;
210
211 final Logistic.Parametric f = new Logistic.Parametric();
212
213 final double x = 0;
214 final double qExp1 = 2;
215
216 final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
217
218 final double factor = (a - k) / (n * FastMath.pow(qExp1, 1 / n + 1));
219 assertEquals(factor * b, gf[1], EPS);
220 assertEquals(factor * m, gf[2], EPS);
221 assertEquals(factor / q, gf[3], EPS);
222 }
223 }