View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.optim.nonlinear.vector.leastsquares;
24  
25  import java.util.ArrayList;
26  
27  import org.hipparchus.UnitTestUtils;
28  import org.hipparchus.UnitTestUtils.SimpleRegression;
29  import org.hipparchus.analysis.MultivariateMatrixFunction;
30  import org.hipparchus.analysis.MultivariateVectorFunction;
31  import org.hipparchus.analysis.UnivariateFunction;
32  
33  /**
34   * Class that models a straight line defined as {@code y = a x + b}.
35   * The parameters of problem are:
36   * <ul>
37   *  <li>{@code a}</li>
38   *  <li>{@code b}</li>
39   * </ul>
40   * The model functions are:
41   * <ul>
42   *  <li>for each pair (a, b), the y-coordinate of the line.</li>
43   * </ul>
44   */
45  class StraightLineProblem {
46      /** Cloud of points assumed to be fitted by a straight line. */
47      private final ArrayList<double[]> points;
48      /** Error (on the y-coordinate of the points). */
49      private final double sigma;
50  
51      /**
52       * @param error Assumed error for the y-coordinate.
53       */
54      public StraightLineProblem(double error) {
55          points = new ArrayList<double[]>();
56          sigma = error;
57      }
58  
59      public void addPoint(double px, double py) {
60          points.add(new double[] { px, py });
61      }
62  
63      /**
64       * @return the array of x-coordinates.
65       */
66      public double[] x() {
67          final double[] v = new double[points.size()];
68          for (int i = 0; i < points.size(); i++) {
69              final double[] p = points.get(i);
70              v[i] = p[0]; // x-coordinate.
71          }
72  
73          return v;
74      }
75  
76      /**
77       * @return the array of y-coordinates.
78       */
79      public double[] y() {
80          final double[] v = new double[points.size()];
81          for (int i = 0; i < points.size(); i++) {
82              final double[] p = points.get(i);
83              v[i] = p[1]; // y-coordinate.
84          }
85  
86          return v;
87      }
88  
89      public double[] target() {
90          return y();
91      }
92  
93      public double[] weight() {
94          final double weight = 1 / (sigma * sigma);
95          final double[] w = new double[points.size()];
96          for (int i = 0; i < points.size(); i++) {
97              w[i] = weight;
98          }
99  
100         return w;
101     }
102 
103     public MultivariateVectorFunction getModelFunction() {
104         return new MultivariateVectorFunction() {
105             public double[] value(double[] params) {
106                 final Model line = new Model(params[0], params[1]);
107 
108                 final double[] model = new double[points.size()];
109                 for (int i = 0; i < points.size(); i++) {
110                     final double[] p = points.get(i);
111                     model[i] = line.value(p[0]);
112                 }
113 
114                 return model;
115             }
116         };
117     }
118 
119     public MultivariateMatrixFunction getModelFunctionJacobian() {
120         return new MultivariateMatrixFunction() {
121             public double[][] value(double[] point) {
122                 return jacobian(point);
123             }
124         };
125     }
126 
127     /**
128      * Directly solve the linear problem, using the {@link SimpleRegression}
129      * class.
130      */
131     public double[] solve() {
132         final UnitTestUtils.SimpleRegression regress = new UnitTestUtils.SimpleRegression();
133         for (double[] d : points) {
134             regress.addData(d[0], d[1]);
135         }
136 
137         final double[] result = { regress.getSlope(), regress.getIntercept() };
138         return result;
139     }
140 
141     private double[][] jacobian(double[] params) {
142         final double[][] jacobian = new double[points.size()][2];
143 
144         for (int i = 0; i < points.size(); i++) {
145             final double[] p = points.get(i);
146             // Partial derivative wrt "a".
147             jacobian[i][0] = p[0];
148             // Partial derivative wrt "b".
149             jacobian[i][1] = 1;
150         }
151 
152         return jacobian;
153     }
154 
155     /**
156      * Linear function.
157      */
158     public static class Model implements UnivariateFunction {
159         final double a;
160         final double b;
161 
162         public Model(double a,
163                      double b) {
164             this.a = a;
165             this.b = b;
166         }
167 
168         public double value(double x) {
169             return a * x + b;
170         }
171     }
172 }