View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.optim.nonlinear.vector.leastsquares;
23  
24  import java.io.BufferedReader;
25  import java.io.IOException;
26  import java.util.ArrayList;
27  
28  import org.hipparchus.analysis.MultivariateMatrixFunction;
29  import org.hipparchus.analysis.MultivariateVectorFunction;
30  
31  /**
32   * This class gives access to the statistical reference datasets provided by the
33   * NIST (available
34   * <a href="http://www.itl.nist.gov/div898/strd/general/dataarchive.html">here</a>).
35   * Instances of this class can be created by invocation of the
36   * {@link StatisticalReferenceDatasetFactory}.
37   */
38  public abstract class StatisticalReferenceDataset {
39      /** The name of this dataset. */
40      private final String name;
41      /** The total number of observations (data points). */
42      private final int numObservations;
43      /** The total number of parameters. */
44      private final int numParameters;
45      /** The total number of starting points for the optimizations. */
46      private final int numStartingPoints;
47      /** The values of the predictor. */
48      private final double[] x;
49      /** The values of the response. */
50      private final double[] y;
51      /**
52       * The starting values. {@code startingValues[j][i]} is the value of the
53       * {@code i}-th parameter in the {@code j}-th set of starting values.
54       */
55      private final double[][] startingValues;
56      /** The certified values of the parameters. */
57      private final double[] a;
58      /** The certified values of the standard deviation of the parameters. */
59      private final double[] sigA;
60      /** The certified value of the residual sum of squares. */
61      private double residualSumOfSquares;
62      /** The least-squares problem. */
63      private final LeastSquaresProblem problem;
64  
65      /**
66       * Creates a new instance of this class from the specified data file. The
67       * file must follow the StRD format.
68       *
69       * @param in the data file
70       * @throws IOException if an I/O error occurs
71       */
72      public StatisticalReferenceDataset(final BufferedReader in)
73          throws IOException {
74  
75          final ArrayList<String> lines = new ArrayList<String>();
76          for (String line = in.readLine(); line != null; line = in.readLine()) {
77              lines.add(line);
78          }
79          int[] index = findLineNumbers("Data", lines);
80          if (index == null) {
81              throw new AssertionError("could not find line indices for data");
82          }
83          this.numObservations = index[1] - index[0] + 1;
84          this.x = new double[this.numObservations];
85          this.y = new double[this.numObservations];
86          for (int i = 0; i < this.numObservations; i++) {
87              final String line = lines.get(index[0] + i - 1);
88              final String[] tokens = line.trim().split(" ++");
89              // Data columns are in reverse order!!!
90              this.y[i] = Double.parseDouble(tokens[0]);
91              this.x[i] = Double.parseDouble(tokens[1]);
92          }
93  
94          index = findLineNumbers("Starting Values", lines);
95          if (index == null) {
96              throw new AssertionError(
97                                       "could not find line indices for starting values");
98          }
99          this.numParameters = index[1] - index[0] + 1;
100 
101         double[][] start = null;
102         this.a = new double[numParameters];
103         this.sigA = new double[numParameters];
104         for (int i = 0; i < numParameters; i++) {
105             final String line = lines.get(index[0] + i - 1);
106             final String[] tokens = line.trim().split(" ++");
107             if (start == null) {
108                 start = new double[tokens.length - 4][numParameters];
109             }
110             for (int j = 2; j < tokens.length - 2; j++) {
111                 start[j - 2][i] = Double.parseDouble(tokens[j]);
112             }
113             this.a[i] = Double.parseDouble(tokens[tokens.length - 2]);
114             this.sigA[i] = Double.parseDouble(tokens[tokens.length - 1]);
115         }
116         if (start == null) {
117             throw new IOException("could not find starting values");
118         }
119         this.numStartingPoints = start.length;
120         this.startingValues = start;
121 
122         double dummyDouble = Double.NaN;
123         String dummyString = null;
124         for (String line : lines) {
125             if (line.contains("Dataset Name:")) {
126                 dummyString = line
127                     .substring(line.indexOf("Dataset Name:") + 13,
128                                line.indexOf("(")).trim();
129             }
130             if (line.contains("Residual Sum of Squares")) {
131                 final String[] tokens = line.split(" ++");
132                 dummyDouble = Double.parseDouble(tokens[4].trim());
133             }
134         }
135         if (Double.isNaN(dummyDouble)) {
136             throw new IOException(
137                                   "could not find certified value of residual sum of squares");
138         }
139         this.residualSumOfSquares = dummyDouble;
140 
141         if (dummyString == null) {
142             throw new IOException("could not find dataset name");
143         }
144         this.name = dummyString;
145 
146         this.problem = new LeastSquaresProblem();
147     }
148 
149     class LeastSquaresProblem {
150         public MultivariateVectorFunction getModelFunction() {
151             return new MultivariateVectorFunction() {
152                 public double[] value(final double[] a) {
153                     final int n = getNumObservations();
154                     final double[] yhat = new double[n];
155                     for (int i = 0; i < n; i++) {
156                         yhat[i] = getModelValue(getX(i), a);
157                     }
158                     return yhat;
159                 }
160             };
161         }
162 
163         public MultivariateMatrixFunction getModelFunctionJacobian() {
164             return new MultivariateMatrixFunction() {
165                 public double[][] value(final double[] a)
166                     throws IllegalArgumentException {
167                     final int n = getNumObservations();
168                     final double[][] j = new double[n][];
169                     for (int i = 0; i < n; i++) {
170                         j[i] = getModelDerivatives(getX(i), a);
171                     }
172                     return j;
173                 }
174             };
175         }
176     }
177 
178     /**
179      * Returns the name of this dataset.
180      *
181      * @return the name of the dataset
182      */
183     public String getName() {
184         return name;
185     }
186 
187     /**
188      * Returns the total number of observations (data points).
189      *
190      * @return the number of observations
191      */
192     public int getNumObservations() {
193         return numObservations;
194     }
195 
196     /**
197      * Returns a copy of the data arrays. The data is laid out as follows <li>
198      * {@code data[0][i] = x[i]},</li> <li>{@code data[1][i] = y[i]},</li>
199      *
200      * @return the array of data points.
201      */
202     public double[][] getData() {
203         return new double[][] { x.clone(), y.clone() };
204     }
205 
206     /**
207      * Returns the x-value of the {@code i}-th data point.
208      *
209      * @param i the index of the data point
210      * @return the x-value
211      */
212     public double getX(final int i) {
213         return x[i];
214     }
215 
216     /**
217      * Returns the y-value of the {@code i}-th data point.
218      *
219      * @param i the index of the data point
220      * @return the y-value
221      */
222     public double getY(final int i) {
223         return y[i];
224     }
225 
226     /**
227      * Returns the total number of parameters.
228      *
229      * @return the number of parameters
230      */
231     public int getNumParameters() {
232         return numParameters;
233     }
234 
235     /**
236      * Returns the certified values of the paramters.
237      *
238      * @return the values of the parameters
239      */
240     public double[] getParameters() {
241         return a.clone();
242     }
243 
244     /**
245      * Returns the certified value of the {@code i}-th parameter.
246      *
247      * @param i the index of the parameter
248      * @return the value of the parameter
249      */
250     public double getParameter(final int i) {
251         return a[i];
252     }
253 
254     /**
255      * Returns the certified values of the standard deviations of the parameters.
256      *
257      * @return the standard deviations of the parameters
258      */
259     public double[] getParametersStandardDeviations() {
260         return sigA.clone();
261     }
262 
263     /**
264      * Returns the certified value of the standard deviation of the {@code i}-th
265      * parameter.
266      *
267      * @param i the index of the parameter
268      * @return the standard deviation of the parameter
269      */
270     public double getParameterStandardDeviation(final int i) {
271         return sigA[i];
272     }
273 
274     /**
275      * Returns the certified value of the residual sum of squares.
276      *
277      * @return the residual sum of squares
278      */
279     public double getResidualSumOfSquares() {
280         return residualSumOfSquares;
281     }
282 
283     /**
284      * Returns the total number of starting points (initial guesses for the
285      * optimization process).
286      *
287      * @return the number of starting points
288      */
289     public int getNumStartingPoints() {
290         return numStartingPoints;
291     }
292 
293     /**
294      * Returns the {@code i}-th set of initial values of the parameters.
295      *
296      * @param i the index of the starting point
297      * @return the starting point
298      */
299     public double[] getStartingPoint(final int i) {
300         return startingValues[i].clone();
301     }
302 
303     /**
304      * Returns the least-squares problem corresponding to fitting the model to
305      * the specified data.
306      *
307      * @return the least-squares problem
308      */
309     public LeastSquaresProblem getLeastSquaresProblem() {
310         return problem;
311     }
312 
313     /**
314      * Returns the value of the model for the specified values of the predictor
315      * variable and the parameters.
316      *
317      * @param x the predictor variable
318      * @param a the parameters
319      * @return the value of the model
320      */
321     public abstract double getModelValue(final double x, final double[] a);
322 
323     /**
324      * Returns the values of the partial derivatives of the model with respect
325      * to the parameters.
326      *
327      * @param x the predictor variable
328      * @param a the parameters
329      * @return the partial derivatives
330      */
331     public abstract double[] getModelDerivatives(final double x,
332                                                  final double[] a);
333 
334     /**
335      * <p>
336      * Parses the specified text lines, and extracts the indices of the first
337      * and last lines of the data defined by the specified {@code key}. This key
338      * must be one of
339      * </p>
340      * <ul>
341      * <li>{@code "Starting Values"},</li>
342      * <li>{@code "Certified Values"},</li>
343      * <li>{@code "Data"}.</li>
344      * </ul>
345      * <p>
346      * In the NIST data files, the line indices are separated by the keywords
347      * {@code "lines"} and {@code "to"}.
348      * </p>
349      *
350      * @param lines the line of text to be parsed
351      * @return an array of two {@code int}s. First value is the index of the
352      *         first line, second value is the index of the last line.
353      *         {@code null} if the line could not be parsed.
354      */
355     private static int[] findLineNumbers(final String key,
356                                          final Iterable<String> lines) {
357         for (String text : lines) {
358             boolean flag = text.contains(key) && text.contains("lines") &&
359                            text.contains("to") && text.contains(")");
360             if (flag) {
361                 final int[] numbers = new int[2];
362                 final String from = text.substring(text.indexOf("lines") + 5,
363                                                    text.indexOf("to"));
364                 numbers[0] = Integer.parseInt(from.trim());
365                 final String to = text.substring(text.indexOf("to") + 2,
366                                                  text.indexOf(")"));
367                 numbers[1] = Integer.parseInt(to.trim());
368                 return numbers;
369             }
370         }
371         return null;
372     }
373 }