1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis.function;
24
25 import org.hipparchus.analysis.FunctionUtils;
26 import org.hipparchus.analysis.UnivariateFunction;
27 import org.hipparchus.analysis.differentiation.DSFactory;
28 import org.hipparchus.analysis.differentiation.DerivativeStructure;
29 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30 import org.hipparchus.exception.MathIllegalArgumentException;
31 import org.hipparchus.exception.NullArgumentException;
32 import org.hipparchus.random.RandomGenerator;
33 import org.hipparchus.random.Well1024a;
34 import org.hipparchus.util.FastMath;
35 import org.junit.jupiter.api.Test;
36
37 import static org.junit.jupiter.api.Assertions.assertEquals;
38 import static org.junit.jupiter.api.Assertions.assertThrows;
39 import static org.junit.jupiter.api.Assertions.assertTrue;
40 import static org.junit.jupiter.api.Assertions.fail;
41
42
43
44
45 class LogitTest {
46 private final double EPS = Math.ulp(1d);
47
48 @Test
49 void testPreconditions1() {
50 assertThrows(MathIllegalArgumentException.class, () -> {
51 final double lo = -1;
52 final double hi = 2;
53 final UnivariateFunction f = new Logit(lo, hi);
54
55 f.value(lo - 1);
56 });
57 }
58
59 @Test
60 void testPreconditions2() {
61 assertThrows(MathIllegalArgumentException.class, () -> {
62 final double lo = -1;
63 final double hi = 2;
64 final UnivariateFunction f = new Logit(lo, hi);
65
66 f.value(hi + 1);
67 });
68 }
69
70 @Test
71 void testSomeValues() {
72 final double lo = 1;
73 final double hi = 2;
74 final UnivariateFunction f = new Logit(lo, hi);
75
76 assertEquals(Double.NEGATIVE_INFINITY, f.value(1), EPS);
77 assertEquals(Double.POSITIVE_INFINITY, f.value(2), EPS);
78 assertEquals(0, f.value(1.5), EPS);
79 }
80
81 @Test
82 void testDerivative() {
83 final double lo = 1;
84 final double hi = 2;
85 final Logit f = new Logit(lo, hi);
86 final DerivativeStructure f15 = f.value(new DSFactory(1, 1).variable(0, 1.5));
87
88 assertEquals(4, f15.getPartialDerivative(1), EPS);
89 }
90
91 @Test
92 void testDerivativeLargeArguments() {
93 final Logit f = new Logit(1, 2);
94
95 DSFactory factory = new DSFactory(1, 1);
96 for (double arg : new double[] {
97 Double.NEGATIVE_INFINITY, -Double.MAX_VALUE, -1e155, 1e155, Double.MAX_VALUE, Double.POSITIVE_INFINITY
98 }) {
99 try {
100 f.value(factory.variable(0, arg));
101 fail("an exception should have been thrown");
102 } catch (MathIllegalArgumentException ore) {
103
104 } catch (Exception e) {
105 fail("wrong exception caught: " + e.getMessage());
106 }
107 }
108 }
109
110 @Test
111 void testDerivativesHighOrder() {
112 DerivativeStructure l = new Logit(1, 3).value(new DSFactory(1, 5).variable(0, 1.2));
113 assertEquals(-2.1972245773362193828, l.getPartialDerivative(0), 1.0e-16);
114 assertEquals(5.5555555555555555555, l.getPartialDerivative(1), 9.0e-16);
115 assertEquals(-24.691358024691358025, l.getPartialDerivative(2), 2.0e-14);
116 assertEquals(250.34293552812071331, l.getPartialDerivative(3), 2.0e-13);
117 assertEquals(-3749.4284407864654778, l.getPartialDerivative(4), 4.0e-12);
118 assertEquals(75001.270131585632282, l.getPartialDerivative(5), 8.0e-11);
119 }
120
121 @Test
122 void testParametricUsage1() {
123 assertThrows(NullArgumentException.class, () -> {
124 final Logit.Parametric g = new Logit.Parametric();
125 g.value(0, null);
126 });
127 }
128
129 @Test
130 void testParametricUsage2() {
131 assertThrows(MathIllegalArgumentException.class, () -> {
132 final Logit.Parametric g = new Logit.Parametric();
133 g.value(0, new double[]{0});
134 });
135 }
136
137 @Test
138 void testParametricUsage3() {
139 assertThrows(NullArgumentException.class, () -> {
140 final Logit.Parametric g = new Logit.Parametric();
141 g.gradient(0, null);
142 });
143 }
144
145 @Test
146 void testParametricUsage4() {
147 assertThrows(MathIllegalArgumentException.class, () -> {
148 final Logit.Parametric g = new Logit.Parametric();
149 g.gradient(0, new double[]{0});
150 });
151 }
152
153 @Test
154 void testParametricUsage5() {
155 assertThrows(MathIllegalArgumentException.class, () -> {
156 final Logit.Parametric g = new Logit.Parametric();
157 g.value(-1, new double[]{0, 1});
158 });
159 }
160
161 @Test
162 void testParametricUsage6() {
163 assertThrows(MathIllegalArgumentException.class, () -> {
164 final Logit.Parametric g = new Logit.Parametric();
165 g.value(2, new double[]{0, 1});
166 });
167 }
168
169 @Test
170 void testParametricValue() {
171 final double lo = 2;
172 final double hi = 3;
173 final Logit f = new Logit(lo, hi);
174
175 final Logit.Parametric g = new Logit.Parametric();
176 assertEquals(f.value(2), g.value(2, new double[] {lo, hi}), 0);
177 assertEquals(f.value(2.34567), g.value(2.34567, new double[] {lo, hi}), 0);
178 assertEquals(f.value(3), g.value(3, new double[] {lo, hi}), 0);
179 }
180
181 @Test
182 void testValueWithInverseFunction() {
183 final double lo = 2;
184 final double hi = 3;
185 final Logit f = new Logit(lo, hi);
186 final Sigmoid g = new Sigmoid(lo, hi);
187 RandomGenerator random = new Well1024a(0x49914cdd9f0b8db5l);
188 final UnivariateDifferentiableFunction id = FunctionUtils.compose((UnivariateDifferentiableFunction) g,
189 (UnivariateDifferentiableFunction) f);
190
191 DSFactory factory = new DSFactory(1, 1);
192 for (int i = 0; i < 10; i++) {
193 final double x = lo + random.nextDouble() * (hi - lo);
194 assertEquals(x, id.value(factory.variable(0, x)).getValue(), EPS);
195 }
196
197 assertEquals(lo, id.value(factory.variable(0, lo)).getValue(), EPS);
198 assertEquals(hi, id.value(factory.variable(0, hi)).getValue(), EPS);
199 }
200
201 @Test
202 void testDerivativesWithInverseFunction() {
203 double[] epsilon = new double[] { 1.0e-20, 4.0e-16, 3.0e-15, 2.0e-11, 3.0e-9, 1.0e-6 };
204 final double lo = 2;
205 final double hi = 3;
206 final Logit f = new Logit(lo, hi);
207 final Sigmoid g = new Sigmoid(lo, hi);
208 RandomGenerator random = new Well1024a(0x96885e9c1f81cea5l);
209 final UnivariateDifferentiableFunction id =
210 FunctionUtils.compose((UnivariateDifferentiableFunction) g, (UnivariateDifferentiableFunction) f);
211 for (int maxOrder = 0; maxOrder < 6; ++maxOrder) {
212 DSFactory factory = new DSFactory(1, maxOrder);
213 double max = 0;
214 for (int i = 0; i < 10; i++) {
215 final double x = lo + random.nextDouble() * (hi - lo);
216 final DerivativeStructure dsX = factory.variable(0, x);
217 max = FastMath.max(max, FastMath.abs(dsX.getPartialDerivative(maxOrder) -
218 id.value(dsX).getPartialDerivative(maxOrder)));
219 assertEquals(dsX.getPartialDerivative(maxOrder),
220 id.value(dsX).getPartialDerivative(maxOrder),
221 epsilon[maxOrder]);
222 }
223
224
225
226 final DerivativeStructure dsLo = factory.variable(0, lo);
227 if (maxOrder == 0) {
228 assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
229 assertEquals(lo, id.value(dsLo).getPartialDerivative(maxOrder), epsilon[maxOrder]);
230 } else if (maxOrder == 1) {
231 assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
232 assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
233 } else {
234 assertTrue(Double.isNaN(f.value(dsLo).getPartialDerivative(maxOrder)));
235 assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
236 }
237
238 final DerivativeStructure dsHi = factory.variable(0, hi);
239 if (maxOrder == 0) {
240 assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
241 assertEquals(hi, id.value(dsHi).getPartialDerivative(maxOrder), epsilon[maxOrder]);
242 } else if (maxOrder == 1) {
243 assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
244 assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
245 } else {
246 assertTrue(Double.isNaN(f.value(dsHi).getPartialDerivative(maxOrder)));
247 assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
248 }
249
250 }
251 }
252 }