View Javadoc
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
/*
 * This is not the original file distributed by the Apache Software Foundation.
 * It has been modified by the Hipparchus project.
 */
22  package org.hipparchus.analysis.interpolation;
23  
24  import java.io.Serializable;
25  import java.util.Arrays;
26  
27  import org.hipparchus.analysis.polynomials.PolynomialSplineFunction;
28  import org.hipparchus.exception.LocalizedCoreFormats;
29  import org.hipparchus.exception.MathIllegalArgumentException;
30  import org.hipparchus.util.FastMath;
31  import org.hipparchus.util.MathArrays;
32  import org.hipparchus.util.MathUtils;
33  
34  /**
35   * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
36   * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
37   * real univariate functions.
38   * <p>
39   * For reference, see
40   * <a href="http://amstat.tandfonline.com/doi/abs/10.1080/01621459.1979.10481038">
41   * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
42   * Scatterplots</a></p>
43   * <p>
44   * This class implements both the loess method and serves as an interpolation
45   * adapter to it, allowing one to build a spline on the obtained loess fit.</p>
46   *
47   */
48  public class LoessInterpolator
49      implements UnivariateInterpolator, Serializable {
50      /** Default value of the bandwidth parameter. */
51      public static final double DEFAULT_BANDWIDTH = 0.3;
52      /** Default value of the number of robustness iterations. */
53      public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
54      /**
55       * Default value for accuracy.
56       */
57      public static final double DEFAULT_ACCURACY = 1e-12;
58      /** serializable version identifier. */
59      private static final long serialVersionUID = 5204927143605193821L;
60      /**
61       * The bandwidth parameter: when computing the loess fit at
62       * a particular point, this fraction of source points closest
63       * to the current point is taken into account for computing
64       * a least-squares regression.
65       * <p>
66       * A sensible value is usually 0.25 to 0.5.</p>
67       */
68      private final double bandwidth;
69      /**
70       * The number of robustness iterations parameter: this many
71       * robustness iterations are done.
72       * <p>
73       * A sensible value is usually 0 (just the initial fit without any
74       * robustness iterations) to 4.</p>
75       */
76      private final int robustnessIters;
77      /**
78       * If the median residual at a certain robustness iteration
79       * is less than this amount, no more iterations are done.
80       */
81      private final double accuracy;
82  
83      /**
84       * Constructs a new {@link LoessInterpolator}
85       * with a bandwidth of {@link #DEFAULT_BANDWIDTH},
86       * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations
87       * and an accuracy of {#link #DEFAULT_ACCURACY}.
88       * See {@link #LoessInterpolator(double, int, double)} for an explanation of
89       * the parameters.
90       */
91      public LoessInterpolator() {
92          this.bandwidth = DEFAULT_BANDWIDTH;
93          this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
94          this.accuracy = DEFAULT_ACCURACY;
95      }
96  
97      /**
98       * Construct a new {@link LoessInterpolator}
99       * with given bandwidth and number of robustness iterations.
100      * <p>
101      * Calling this constructor is equivalent to calling {link {@link
102      * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth,
103      * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)}
104      * </p>
105      *
106      * @param bandwidth  when computing the loess fit at
107      * a particular point, this fraction of source points closest
108      * to the current point is taken into account for computing
109      * a least-squares regression.
110      * A sensible value is usually 0.25 to 0.5, the default value is
111      * {@link #DEFAULT_BANDWIDTH}.
112      * @param robustnessIters This many robustness iterations are done.
113      * A sensible value is usually 0 (just the initial fit without any
114      * robustness iterations) to 4, the default value is
115      * {@link #DEFAULT_ROBUSTNESS_ITERS}.
116 
117      * @see #LoessInterpolator(double, int, double)
118      */
119     public LoessInterpolator(double bandwidth, int robustnessIters) {
120         this(bandwidth, robustnessIters, DEFAULT_ACCURACY);
121     }
122 
123     /**
124      * Construct a new {@link LoessInterpolator}
125      * with given bandwidth, number of robustness iterations and accuracy.
126      *
127      * @param bandwidth  when computing the loess fit at
128      * a particular point, this fraction of source points closest
129      * to the current point is taken into account for computing
130      * a least-squares regression.
131      * A sensible value is usually 0.25 to 0.5, the default value is
132      * {@link #DEFAULT_BANDWIDTH}.
133      * @param robustnessIters This many robustness iterations are done.
134      * A sensible value is usually 0 (just the initial fit without any
135      * robustness iterations) to 4, the default value is
136      * {@link #DEFAULT_ROBUSTNESS_ITERS}.
137      * @param accuracy If the median residual at a certain robustness iteration
138      * is less than this amount, no more iterations are done.
139      * @throws MathIllegalArgumentException if bandwidth does not lie in the interval [0,1].
140      * @throws MathIllegalArgumentException if {@code robustnessIters} is negative.
141      * @see #LoessInterpolator(double, int)
142      */
143     public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy)
144         throws MathIllegalArgumentException {
145         if (bandwidth < 0 ||
146             bandwidth > 1) {
147             throw new MathIllegalArgumentException(LocalizedCoreFormats.BANDWIDTH, bandwidth, 0, 1);
148         }
149         this.bandwidth = bandwidth;
150         if (robustnessIters < 0) {
151             throw new MathIllegalArgumentException(LocalizedCoreFormats.ROBUSTNESS_ITERATIONS, robustnessIters);
152         }
153         this.robustnessIters = robustnessIters;
154         this.accuracy = accuracy;
155     }
156 
157     /**
158      * Compute an interpolating function by performing a loess fit
159      * on the data at the original abscissae and then building a cubic spline
160      * with a
161      * {@link org.hipparchus.analysis.interpolation.SplineInterpolator}
162      * on the resulting fit.
163      *
164      * @param xval the arguments for the interpolation points
165      * @param yval the values for the interpolation points
166      * @return A cubic spline built upon a loess fit to the data at the original abscissae
167      * @throws MathIllegalArgumentException if {@code xval} not sorted in
168      * strictly increasing order.
169      * @throws MathIllegalArgumentException if {@code xval} and {@code yval} have
170      * different sizes.
171      * @throws MathIllegalArgumentException if {@code xval} or {@code yval} has zero size.
172      * @throws MathIllegalArgumentException if any of the arguments and values are
173      * not finite real numbers.
174      * @throws MathIllegalArgumentException if the bandwidth is too small to
175      * accomodate the size of the input data (i.e. the bandwidth must be
176      * larger than 2/n).
177      */
178     @Override
179     public final PolynomialSplineFunction interpolate(final double[] xval,
180                                                       final double[] yval)
181         throws MathIllegalArgumentException {
182         return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
183     }
184 
185     /**
186      * Compute a weighted loess fit on the data at the original abscissae.
187      *
188      * @param xval Arguments for the interpolation points.
189      * @param yval Values for the interpolation points.
190      * @param weights point weights: coefficients by which the robustness weight
191      * of a point is multiplied.
192      * @return the values of the loess fit at corresponding original abscissae.
193      * @throws MathIllegalArgumentException if {@code xval} not sorted in
194      * strictly increasing order.
195      * @throws MathIllegalArgumentException if {@code xval} and {@code yval} have
196      * different sizes.
197      * @throws MathIllegalArgumentException if {@code xval} or {@code yval} has zero size.
198      * @throws MathIllegalArgumentException if any of the arguments and values are
199      not finite real numbers.
200      * @throws MathIllegalArgumentException if the bandwidth is too small to
201      * accomodate the size of the input data (i.e. the bandwidth must be
202      * larger than 2/n).
203      */
204     public final double[] smooth(final double[] xval, final double[] yval,
205                                  final double[] weights)
206         throws MathIllegalArgumentException {
207         if (xval.length != yval.length) {
208             throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
209                                                    xval.length, yval.length);
210         }
211 
212         final int n = xval.length;
213 
214         if (n == 0) {
215             throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA);
216         }
217 
218         checkAllFiniteReal(xval);
219         checkAllFiniteReal(yval);
220         checkAllFiniteReal(weights);
221 
222         MathArrays.checkOrder(xval);
223 
224         if (n == 1) {
225             return new double[]{yval[0]};
226         }
227 
228         if (n == 2) {
229             return new double[]{yval[0], yval[1]};
230         }
231 
232         int bandwidthInPoints = (int) (bandwidth * n);
233 
234         if (bandwidthInPoints < 2) {
235             throw new MathIllegalArgumentException(LocalizedCoreFormats.BANDWIDTH,
236                                                 bandwidthInPoints, 2, true);
237         }
238 
239         final double[] res = new double[n];
240 
241         final double[] residuals = new double[n];
242         final double[] sortedResiduals = new double[n];
243 
244         final double[] robustnessWeights = new double[n];
245 
246         // Do an initial fit and 'robustnessIters' robustness iterations.
247         // This is equivalent to doing 'robustnessIters+1' robustness iterations
248         // starting with all robustness weights set to 1.
249         Arrays.fill(robustnessWeights, 1);
250 
251         for (int iter = 0; iter <= robustnessIters; ++iter) {
252             final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
253             // At each x, compute a local weighted linear regression
254             for (int i = 0; i < n; ++i) {
255                 final double x = xval[i];
256 
257                 // Find out the interval of source points on which
258                 // a regression is to be made.
259                 if (i > 0) {
260                     updateBandwidthInterval(xval, weights, i, bandwidthInterval);
261                 }
262 
263                 final int ileft = bandwidthInterval[0];
264                 final int iright = bandwidthInterval[1];
265 
266                 // Compute the point of the bandwidth interval that is
267                 // farthest from x
268                 final int edge;
269                 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
270                     edge = ileft;
271                 } else {
272                     edge = iright;
273                 }
274 
275                 // Compute a least-squares linear fit weighted by
276                 // the product of robustness weights and the tricube
277                 // weight function.
278                 // See http://en.wikipedia.org/wiki/Linear_regression
279                 // (section "Univariate linear case")
280                 // and http://en.wikipedia.org/wiki/Weighted_least_squares
281                 // (section "Weighted least squares")
282                 double sumWeights = 0;
283                 double sumX = 0;
284                 double sumXSquared = 0;
285                 double sumY = 0;
286                 double sumXY = 0;
287                 double denom = FastMath.abs(1.0 / (xval[edge] - x));
288                 for (int k = ileft; k <= iright; ++k) {
289                     final double xk   = xval[k];
290                     final double yk   = yval[k];
291                     final double dist = (k < i) ? x - xk : xk - x;
292                     final double w    = tricube(dist * denom) * robustnessWeights[k] * weights[k];
293                     final double xkw  = xk * w;
294                     sumWeights += w;
295                     sumX += xkw;
296                     sumXSquared += xk * xkw;
297                     sumY += yk * w;
298                     sumXY += yk * xkw;
299                 }
300 
301                 final double meanX = sumX / sumWeights;
302                 final double meanY = sumY / sumWeights;
303                 final double meanXY = sumXY / sumWeights;
304                 final double meanXSquared = sumXSquared / sumWeights;
305 
306                 final double beta;
307                 if (FastMath.sqrt(FastMath.abs(meanXSquared - meanX * meanX)) < accuracy) {
308                     beta = 0;
309                 } else {
310                     beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
311                 }
312 
313                 final double alpha = meanY - beta * meanX;
314 
315                 res[i] = beta * x + alpha;
316                 residuals[i] = FastMath.abs(yval[i] - res[i]);
317             }
318 
319             // No need to recompute the robustness weights at the last
320             // iteration, they won't be needed anymore
321             if (iter == robustnessIters) {
322                 break;
323             }
324 
325             // Recompute the robustness weights.
326 
327             // Find the median residual.
328             // An arraycopy and a sort are completely tractable here,
329             // because the preceding loop is a lot more expensive
330             System.arraycopy(residuals, 0, sortedResiduals, 0, n);
331             Arrays.sort(sortedResiduals);
332             final double medianResidual = sortedResiduals[n / 2];
333 
334             if (FastMath.abs(medianResidual) < accuracy) {
335                 break;
336             }
337 
338             for (int i = 0; i < n; ++i) {
339                 final double arg = residuals[i] / (6 * medianResidual);
340                 if (arg >= 1) {
341                     robustnessWeights[i] = 0;
342                 } else {
343                     final double w = 1 - arg * arg;
344                     robustnessWeights[i] = w * w;
345                 }
346             }
347         }
348 
349         return res;
350     }
351 
352     /**
353      * Compute a loess fit on the data at the original abscissae.
354      *
355      * @param xval the arguments for the interpolation points
356      * @param yval the values for the interpolation points
357      * @return values of the loess fit at corresponding original abscissae
358      * @throws MathIllegalArgumentException if {@code xval} not sorted in
359      * strictly increasing order.
360      * @throws MathIllegalArgumentException if {@code xval} and {@code yval} have
361      * different sizes.
362      * @throws MathIllegalArgumentException if {@code xval} or {@code yval} has zero size.
363      * @throws MathIllegalArgumentException if any of the arguments and values are
364      * not finite real numbers.
365      * @throws MathIllegalArgumentException if the bandwidth is too small to
366      * accomodate the size of the input data (i.e. the bandwidth must be
367      * larger than 2/n).
368      */
369     public final double[] smooth(final double[] xval, final double[] yval)
370         throws MathIllegalArgumentException {
371         if (xval.length != yval.length) {
372             throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
373                                                    xval.length, yval.length);
374         }
375 
376         final double[] unitWeights = new double[xval.length];
377         Arrays.fill(unitWeights, 1.0);
378 
379         return smooth(xval, yval, unitWeights);
380     }
381 
382     /**
383      * Given an index interval into xval that embraces a certain number of
384      * points closest to {@code xval[i-1]}, update the interval so that it
385      * embraces the same number of points closest to {@code xval[i]},
386      * ignoring zero weights.
387      *
388      * @param xval Arguments array.
389      * @param weights Weights array.
390      * @param i Index around which the new interval should be computed.
391      * @param bandwidthInterval a two-element array {left, right} such that:
392      * {@code (left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])}
393      * and
394      * {@code (right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])}.
395      * The array will be updated.
396      */
397     private static void updateBandwidthInterval(final double[] xval, final double[] weights,
398                                                 final int i,
399                                                 final int[] bandwidthInterval) {
400         final int left = bandwidthInterval[0];
401         final int right = bandwidthInterval[1];
402 
403         // The right edge should be adjusted if the next point to the right
404         // is closer to xval[i] than the leftmost point of the current interval
405         int nextRight = nextNonzero(weights, right);
406         if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) {
407             int nextLeft = nextNonzero(weights, bandwidthInterval[0]);
408             bandwidthInterval[0] = nextLeft;
409             bandwidthInterval[1] = nextRight;
410         }
411     }
412 
413     /**
414      * Return the smallest index {@code j} such that
415      * {@code j > i && (j == weights.length || weights[j] != 0)}.
416      *
417      * @param weights Weights array.
418      * @param i Index from which to start search.
419      * @return the smallest compliant index.
420      */
421     private static int nextNonzero(final double[] weights, final int i) {
422         int j = i + 1;
423         while(j < weights.length && weights[j] == 0) {
424             ++j;
425         }
426         return j;
427     }
428 
429     /**
430      * Compute the
431      * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
432      * weight function
433      *
434      * @param x Argument.
435      * @return <code>(1 - |x|<sup>3</sup>)<sup>3</sup></code> for |x| &lt; 1, 0 otherwise.
436      */
437     private static double tricube(final double x) {
438         final double absX = FastMath.abs(x);
439         if (absX >= 1.0) {
440             return 0.0;
441         }
442         final double tmp = 1 - absX * absX * absX;
443         return tmp * tmp * tmp;
444     }
445 
446     /**
447      * Check that all elements of an array are finite real numbers.
448      *
449      * @param values Values array.
450      * @throws org.hipparchus.exception.MathIllegalArgumentException
451      * if one of the values is not a finite real number.
452      */
453     private static void checkAllFiniteReal(final double[] values) {
454         for (int i = 0; i < values.length; i++) {
455             MathUtils.checkFinite(values[i]);
456         }
457     }
458 }