View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.analysis.function;
24  
25  import org.hipparchus.analysis.ParametricUnivariateFunction;
26  import org.hipparchus.analysis.differentiation.Derivative;
27  import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
28  import org.hipparchus.exception.MathIllegalArgumentException;
29  import org.hipparchus.exception.NullArgumentException;
30  import org.hipparchus.util.FastMath;
31  import org.hipparchus.util.MathUtils;
32  
33  /**
34   * <a href="http://en.wikipedia.org/wiki/Logit">
35   *  Logit</a> function.
36   * It is the inverse of the {@link Sigmoid sigmoid} function.
37   *
38   */
39  public class Logit implements UnivariateDifferentiableFunction {
40      /** Lower bound. */
41      private final double lo;
42      /** Higher bound. */
43      private final double hi;
44  
45      /**
46       * Usual logit function, where the lower bound is 0 and the higher
47       * bound is 1.
48       */
49      public Logit() {
50          this(0, 1);
51      }
52  
53      /**
54       * Logit function.
55       *
56       * @param lo Lower bound of the function domain.
57       * @param hi Higher bound of the function domain.
58       */
59      public Logit(double lo,
60                   double hi) {
61          this.lo = lo;
62          this.hi = hi;
63      }
64  
65      /** {@inheritDoc} */
66      @Override
67      public double value(double x)
68          throws MathIllegalArgumentException {
69          return value(x, lo, hi);
70      }
71  
72      /**
73       * Parametric function where the input array contains the parameters of
74       * the logit function, ordered as follows:
75       * <ul>
76       *  <li>Lower bound</li>
77       *  <li>Higher bound</li>
78       * </ul>
79       */
80      public static class Parametric implements ParametricUnivariateFunction {
81  
82          /** Empty constructor.
83           * <p>
84           * This constructor is not strictly necessary, but it prevents spurious
85           * javadoc warnings with JDK 18 and later.
86           * </p>
87           * @since 3.0
88           */
89          public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
90              // nothing to do
91          }
92  
93          /**
94           * Computes the value of the logit at {@code x}.
95           *
96           * @param x Value for which the function must be computed.
97           * @param param Values of lower bound and higher bounds.
98           * @return the value of the function.
99           * @throws NullArgumentException if {@code param} is {@code null}.
100          * @throws MathIllegalArgumentException if the size of {@code param} is
101          * not 2.
102          */
103         @Override
104         public double value(double x, double ... param)
105             throws MathIllegalArgumentException, NullArgumentException {
106             validateParameters(param);
107             return Logit.value(x, param[0], param[1]);
108         }
109 
110         /**
111          * Computes the value of the gradient at {@code x}.
112          * The components of the gradient vector are the partial
113          * derivatives of the function with respect to each of the
114          * <em>parameters</em> (lower bound and higher bound).
115          *
116          * @param x Value at which the gradient must be computed.
117          * @param param Values for lower and higher bounds.
118          * @return the gradient vector at {@code x}.
119          * @throws NullArgumentException if {@code param} is {@code null}.
120          * @throws MathIllegalArgumentException if the size of {@code param} is
121          * not 2.
122          */
123         @Override
124         public double[] gradient(double x, double ... param)
125             throws MathIllegalArgumentException, NullArgumentException {
126             validateParameters(param);
127 
128             final double lo = param[0];
129             final double hi = param[1];
130 
131             return new double[] { 1 / (lo - x), 1 / (hi - x) };
132         }
133 
134         /**
135          * Validates parameters to ensure they are appropriate for the evaluation of
136          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
137          * methods.
138          *
139          * @param param Values for lower and higher bounds.
140          * @throws NullArgumentException if {@code param} is {@code null}.
141          * @throws MathIllegalArgumentException if the size of {@code param} is
142          * not 2.
143          */
144         private void validateParameters(double[] param)
145             throws MathIllegalArgumentException, NullArgumentException {
146             MathUtils.checkNotNull(param);
147             MathUtils.checkDimension(param.length, 2);
148         }
149     }
150 
151     /**
152      * @param x Value at which to compute the logit.
153      * @param lo Lower bound.
154      * @param hi Higher bound.
155      * @return the value of the logit function at {@code x}.
156      * @throws MathIllegalArgumentException if {@code x < lo} or {@code x > hi}.
157      */
158     private static double value(double x,
159                                 double lo,
160                                 double hi)
161         throws MathIllegalArgumentException {
162         MathUtils.checkRangeInclusive(x, lo, hi);
163         return FastMath.log((x - lo) / (hi - x));
164     }
165 
166     /** {@inheritDoc}
167      * @exception MathIllegalArgumentException if parameter is outside of function domain
168      */
169     @Override
170     public <T extends Derivative<T>> T value(T t)
171         throws MathIllegalArgumentException {
172         final double x = t.getValue();
173         MathUtils.checkRangeInclusive(x, lo, hi);
174         double[] f = new double[t.getOrder() + 1];
175 
176         // function value
177         f[0] = FastMath.log((x - lo) / (hi - x));
178 
179         if (Double.isInfinite(f[0])) {
180 
181             if (f.length > 1) {
182                 f[1] = Double.POSITIVE_INFINITY;
183             }
184             // fill the array with infinities
185             // (for x close to lo the signs will flip between -inf and +inf,
186             //  for x close to hi the signs will always be +inf)
187             // this is probably overkill, since the call to compose at the end
188             // of the method will transform most infinities into NaN ...
189             for (int i = 2; i < f.length; ++i) {
190                 f[i] = f[i - 2];
191             }
192 
193         } else {
194 
195             // function derivatives
196             final double invL = 1.0 / (x - lo);
197             double xL = invL;
198             final double invH = 1.0 / (hi - x);
199             double xH = invH;
200             for (int i = 1; i < f.length; ++i) {
201                 f[i] = xL + xH;
202                 xL  *= -i * invL;
203                 xH  *=  i * invH;
204             }
205         }
206 
207         return t.compose(f);
208     }
209 }