View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.linear;
24  
25  import java.util.Arrays;
26  import java.util.function.Predicate;
27  
28  import org.hipparchus.CalculusFieldElement;
29  import org.hipparchus.FieldElement;
30  import org.hipparchus.exception.LocalizedCoreFormats;
31  import org.hipparchus.exception.MathIllegalArgumentException;
32  import org.hipparchus.util.FastMath;
33  import org.hipparchus.util.MathArrays;
34  
35  
36  /**
37   * Calculates the QR-decomposition of a field matrix.
38   * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
39   * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
40   * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
41   * <p>This class compute the decomposition using Householder reflectors.</p>
42   * <p>For efficiency purposes, the decomposition in packed form is transposed.
43   * This allows inner loop to iterate inside rows, which is much more cache-efficient
44   * in Java.</p>
45   * <p>This class is based on the class {@link QRDecomposition}.</p>
46   *
47   * @param <T> type of the underlying field elements
48   * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
49   * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
50   *
51   */
52  public class FieldQRDecomposition<T extends CalculusFieldElement<T>> {
53      /**
54       * A packed TRANSPOSED representation of the QR decomposition.
55       * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
56       * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
57       * from which an explicit form of Q can be recomputed if desired.</p>
58       */
59      private T[][] qrt;
60      /** The diagonal elements of R. */
61      private T[] rDiag;
62      /** Cached value of Q. */
63      private FieldMatrix<T> cachedQ;
64      /** Cached value of QT. */
65      private FieldMatrix<T> cachedQT;
66      /** Cached value of R. */
67      private FieldMatrix<T> cachedR;
68      /** Cached value of H. */
69      private FieldMatrix<T> cachedH;
70      /** Singularity threshold. */
71      private final T threshold;
72      /** checker for zero. */
73      private final Predicate<T> zeroChecker;
74  
75      /**
76       * Calculates the QR-decomposition of the given matrix.
77       * The singularity threshold defaults to zero.
78       *
79       * @param matrix The matrix to decompose.
80       *
81       * @see #FieldQRDecomposition(FieldMatrix, CalculusFieldElement)
82       */
83      public FieldQRDecomposition(FieldMatrix<T> matrix) {
84          this(matrix, matrix.getField().getZero());
85      }
86  
87      /**
88       * Calculates the QR-decomposition of the given matrix.
89       *
90       * @param matrix The matrix to decompose.
91       * @param threshold Singularity threshold.
92       */
93      public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold) {
94          this(matrix, threshold, FieldElement::isZero);
95      }
96  
97      /**
98       * Calculates the QR-decomposition of the given matrix.
99       *
100      * @param matrix The matrix to decompose.
101      * @param threshold Singularity threshold.
102      * @param zeroChecker checker for zero
103      */
104     public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold, Predicate<T> zeroChecker) {
105         this.threshold   = threshold;
106         this.zeroChecker = zeroChecker;
107 
108         final int m = matrix.getRowDimension();
109         final int n = matrix.getColumnDimension();
110         qrt = matrix.transpose().getData();
111         rDiag = MathArrays.buildArray(threshold.getField(),FastMath.min(m, n));
112         cachedQ  = null;
113         cachedQT = null;
114         cachedR  = null;
115         cachedH  = null;
116 
117         decompose(qrt);
118 
119     }
120 
121     /** Decompose matrix.
122      * @param matrix transposed matrix
123      */
124     protected void decompose(T[][] matrix) {
125         for (int minor = 0; minor < FastMath.min(matrix.length, matrix[0].length); minor++) {
126             performHouseholderReflection(minor, matrix);
127         }
128     }
129 
130     /** Perform Householder reflection for a minor A(minor, minor) of A.
131      * @param minor minor index
132      * @param matrix transposed matrix
133      */
134     protected void performHouseholderReflection(int minor, T[][] matrix) {
135 
136         final T[] qrtMinor = matrix[minor];
137         final T zero = threshold.getField().getZero();
138         /*
139          * Let x be the first column of the minor, and a^2 = |x|^2.
140          * x will be in the positions qr[minor][minor] through qr[m][minor].
141          * The first column of the transformed minor will be (a,0,0,..)'
142          * The sign of a is chosen to be opposite to the sign of the first
143          * component of x. Let's find a:
144          */
145         T xNormSqr = zero;
146         for (int row = minor; row < qrtMinor.length; row++) {
147             final T c = qrtMinor[row];
148             xNormSqr = xNormSqr.add(c.square());
149         }
150         final T a = (qrtMinor[minor].getReal() > 0) ? xNormSqr.sqrt().negate() : xNormSqr.sqrt();
151         rDiag[minor] = a;
152 
153         if (!zeroChecker.test(a)) {
154 
155             /*
156              * Calculate the normalized reflection vector v and transform
157              * the first column. We know the norm of v beforehand: v = x-ae
158              * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
159              * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
160              * Here <x, e> is now qr[minor][minor].
161              * v = x-ae is stored in the column at qr:
162              */
163             qrtMinor[minor] = qrtMinor[minor].subtract(a); // now |v|^2 = -2a*(qr[minor][minor])
164 
165             /*
166              * Transform the rest of the columns of the minor:
167              * They will be transformed by the matrix H = I-2vv'/|v|^2.
168              * If x is a column vector of the minor, then
169              * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
170              * Therefore the transformation is easily calculated by
171              * subtracting the column vector (2<x,v>/|v|^2)v from x.
172              *
173              * Let 2<x,v>/|v|^2 = alpha. From above we have
174              * |v|^2 = -2a*(qr[minor][minor]), so
175              * alpha = -<x,v>/(a*qr[minor][minor])
176              */
177             for (int col = minor+1; col < matrix.length; col++) {
178                 final T[] qrtCol = matrix[col];
179                 T alpha = zero;
180                 for (int row = minor; row < qrtCol.length; row++) {
181                     alpha = alpha.subtract(qrtCol[row].multiply(qrtMinor[row]));
182                 }
183                 alpha = alpha.divide(a.multiply(qrtMinor[minor]));
184 
185                 // Subtract the column vector alpha*v from x.
186                 for (int row = minor; row < qrtCol.length; row++) {
187                     qrtCol[row] = qrtCol[row].subtract(alpha.multiply(qrtMinor[row]));
188                 }
189             }
190         }
191     }
192 
193 
194     /**
195      * Returns the matrix R of the decomposition.
196      * <p>R is an upper-triangular matrix</p>
197      * @return the R matrix
198      */
199     public FieldMatrix<T> getR() {
200 
201         if (cachedR == null) {
202 
203             // R is supposed to be m x n
204             final int n = qrt.length;
205             final int m = qrt[0].length;
206             T[][] ra = MathArrays.buildArray(threshold.getField(), m, n);
207             // copy the diagonal from rDiag and the upper triangle of qr
208             for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
209                 ra[row][row] = rDiag[row];
210                 for (int col = row + 1; col < n; col++) {
211                     ra[row][col] = qrt[col][row];
212                 }
213             }
214             cachedR = MatrixUtils.createFieldMatrix(ra);
215         }
216 
217         // return the cached matrix
218         return cachedR;
219     }
220 
221     /**
222      * Returns the matrix Q of the decomposition.
223      * <p>Q is an orthogonal matrix</p>
224      * @return the Q matrix
225      */
226     public FieldMatrix<T> getQ() {
227         if (cachedQ == null) {
228             cachedQ = getQT().transpose();
229         }
230         return cachedQ;
231     }
232 
233     /**
234      * Returns the transpose of the matrix Q of the decomposition.
235      * <p>Q is an orthogonal matrix</p>
236      * @return the transpose of the Q matrix, Q<sup>T</sup>
237      */
238     public FieldMatrix<T> getQT() {
239         if (cachedQT == null) {
240 
241             // QT is supposed to be m x m
242             final int n = qrt.length;
243             final int m = qrt[0].length;
244             T[][] qta = MathArrays.buildArray(threshold.getField(), m, m);
245 
246             /*
247              * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
248              * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
249              * succession to the result
250              */
251             for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
252                 qta[minor][minor] = threshold.getField().getOne();
253             }
254 
255             for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
256                 final T[] qrtMinor = qrt[minor];
257                 qta[minor][minor] = threshold.getField().getOne();
258                 if (!qrtMinor[minor].isZero()) {
259                     for (int col = minor; col < m; col++) {
260                         T alpha = threshold.getField().getZero();
261                         for (int row = minor; row < m; row++) {
262                             alpha = alpha.subtract(qta[col][row].multiply(qrtMinor[row]));
263                         }
264                         alpha = alpha.divide(rDiag[minor].multiply(qrtMinor[minor]));
265 
266                         for (int row = minor; row < m; row++) {
267                             qta[col][row] = qta[col][row].add(alpha.negate().multiply(qrtMinor[row]));
268                         }
269                     }
270                 }
271             }
272             cachedQT = MatrixUtils.createFieldMatrix(qta);
273         }
274 
275         // return the cached matrix
276         return cachedQT;
277     }
278 
279     /**
280      * Returns the Householder reflector vectors.
281      * <p>H is a lower trapezoidal matrix whose columns represent
282      * each successive Householder reflector vector. This matrix is used
283      * to compute Q.</p>
284      * @return a matrix containing the Householder reflector vectors
285      */
286     public FieldMatrix<T> getH() {
287         if (cachedH == null) {
288 
289             final int n = qrt.length;
290             final int m = qrt[0].length;
291             T[][] ha = MathArrays.buildArray(threshold.getField(), m, n);
292             for (int i = 0; i < m; ++i) {
293                 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
294                     ha[i][j] = qrt[j][i].divide(rDiag[j].negate());
295                 }
296             }
297             cachedH = MatrixUtils.createFieldMatrix(ha);
298         }
299 
300         // return the cached matrix
301         return cachedH;
302     }
303 
304     /**
305      * Get a solver for finding the A &times; X = B solution in least square sense.
306      * <p>
307      * Least Square sense means a solver can be computed for an overdetermined system,
308      * (i.e. a system with more equations than unknowns, which corresponds to a tall A
309      * matrix with more rows than columns). In any case, if the matrix is singular
310      * within the tolerance set at {@link #FieldQRDecomposition(FieldMatrix,
311      * CalculusFieldElement) construction}, an error will be triggered when
312      * the {@link DecompositionSolver#solve(RealVector) solve} method will be called.
313      * </p>
314      * @return a solver
315      */
316     public FieldDecompositionSolver<T> getSolver() {
317         return new FieldSolver();
318     }
319 
320     /**
321      * Specialized solver.
322      */
323     private class FieldSolver implements FieldDecompositionSolver<T>{
324 
325         /** {@inheritDoc} */
326         @Override
327         public boolean isNonSingular() {
328             return !checkSingular(rDiag, threshold, false);
329         }
330 
331         /** {@inheritDoc} */
332         @Override
333         public FieldVector<T> solve(FieldVector<T> b) {
334             final int n = qrt.length;
335             final int m = qrt[0].length;
336             if (b.getDimension() != m) {
337                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
338                                                        b.getDimension(), m);
339             }
340             checkSingular(rDiag, threshold, true);
341 
342             final T[] x =MathArrays.buildArray(threshold.getField(),n);
343             final T[] y = b.toArray();
344 
345             // apply Householder transforms to solve Q.y = b
346             for (int minor = 0; minor < FastMath.min(m, n); minor++) {
347 
348                 final T[] qrtMinor = qrt[minor];
349                 T dotProduct = threshold.getField().getZero();
350                 for (int row = minor; row < m; row++) {
351                     dotProduct = dotProduct.add(y[row].multiply(qrtMinor[row]));
352                 }
353                 dotProduct =  dotProduct.divide(rDiag[minor].multiply(qrtMinor[minor]));
354 
355                 for (int row = minor; row < m; row++) {
356                     y[row] = y[row].add(dotProduct.multiply(qrtMinor[row]));
357                 }
358             }
359 
360             // solve triangular system R.x = y
361             for (int row = rDiag.length - 1; row >= 0; --row) {
362                 y[row] = y[row].divide(rDiag[row]);
363                 final T yRow = y[row];
364                 final T[] qrtRow = qrt[row];
365                 x[row] = yRow;
366                 for (int i = 0; i < row; i++) {
367                     y[i] = y[i].subtract(yRow.multiply(qrtRow[i]));
368                 }
369             }
370 
371             return new ArrayFieldVector<>(x, false);
372         }
373 
374         /** {@inheritDoc} */
375         @Override
376         public FieldMatrix<T> solve(FieldMatrix<T> b) {
377             final int n = qrt.length;
378             final int m = qrt[0].length;
379             if (b.getRowDimension() != m) {
380                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
381                                                        b.getRowDimension(), m);
382             }
383             checkSingular(rDiag, threshold, true);
384 
385             final int columns        = b.getColumnDimension();
386             final int blockSize      = BlockFieldMatrix.BLOCK_SIZE;
387             final int cBlocks        = (columns + blockSize - 1) / blockSize;
388             final T[][] xBlocks = BlockFieldMatrix.createBlocksLayout(threshold.getField(),n, columns);
389             final T[][] y       = MathArrays.buildArray(threshold.getField(), b.getRowDimension(), blockSize);
390             final T[]   alpha   = MathArrays.buildArray(threshold.getField(), blockSize);
391 
392             for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
393                 final int kStart = kBlock * blockSize;
394                 final int kEnd   = FastMath.min(kStart + blockSize, columns);
395                 final int kWidth = kEnd - kStart;
396 
397                 // get the right hand side vector
398                 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
399 
400                 // apply Householder transforms to solve Q.y = b
401                 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
402                     final T[] qrtMinor = qrt[minor];
403                     final T factor     = rDiag[minor].multiply(qrtMinor[minor]).reciprocal();
404 
405                     Arrays.fill(alpha, 0, kWidth, threshold.getField().getZero());
406                     for (int row = minor; row < m; ++row) {
407                         final T   d    = qrtMinor[row];
408                         final T[] yRow = y[row];
409                         for (int k = 0; k < kWidth; ++k) {
410                             alpha[k] = alpha[k].add(d.multiply(yRow[k]));
411                         }
412                     }
413 
414                     for (int k = 0; k < kWidth; ++k) {
415                         alpha[k] = alpha[k].multiply(factor);
416                     }
417 
418                     for (int row = minor; row < m; ++row) {
419                         final T   d    = qrtMinor[row];
420                         final T[] yRow = y[row];
421                         for (int k = 0; k < kWidth; ++k) {
422                             yRow[k] = yRow[k].add(alpha[k].multiply(d));
423                         }
424                     }
425                 }
426 
427                 // solve triangular system R.x = y
428                 for (int j = rDiag.length - 1; j >= 0; --j) {
429                     final int      jBlock = j / blockSize;
430                     final int      jStart = jBlock * blockSize;
431                     final T   factor = rDiag[j].reciprocal();
432                     final T[] yJ     = y[j];
433                     final T[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
434                     int index = (j - jStart) * kWidth;
435                     for (int k = 0; k < kWidth; ++k) {
436                         yJ[k]           =yJ[k].multiply(factor);
437                         xBlock[index++] = yJ[k];
438                     }
439 
440                     final T[] qrtJ = qrt[j];
441                     for (int i = 0; i < j; ++i) {
442                         final T rIJ  = qrtJ[i];
443                         final T[] yI = y[i];
444                         for (int k = 0; k < kWidth; ++k) {
445                             yI[k] = yI[k].subtract(yJ[k].multiply(rIJ));
446                         }
447                     }
448                 }
449             }
450 
451             return new BlockFieldMatrix<>(n, columns, xBlocks, false);
452         }
453 
454         /**
455          * {@inheritDoc}
456          * @throws MathIllegalArgumentException if the decomposed matrix is singular.
457          */
458         @Override
459         public FieldMatrix<T> getInverse() {
460             return solve(MatrixUtils.createFieldIdentityMatrix(threshold.getField(), qrt[0].length));
461         }
462 
463         /**
464          * Check singularity.
465          *
466          * @param diag Diagonal elements of the R matrix.
467          * @param min Singularity threshold.
468          * @param raise Whether to raise a {@link MathIllegalArgumentException}
469          * if any element of the diagonal fails the check.
470          * @return {@code true} if any element of the diagonal is smaller
471          * or equal to {@code min}.
472          * @throws MathIllegalArgumentException if the matrix is singular and
473          * {@code raise} is {@code true}.
474          */
475         private boolean checkSingular(T[] diag,
476                                              T min,
477                                              boolean raise) {
478             for (final T d : diag) {
479                 if (FastMath.abs(d.getReal()) <= min.getReal()) {
480                     if (raise) {
481                         throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
482                     } else {
483                         return true;
484                     }
485                 }
486             }
487             return false;
488         }
489 
490         /** {@inheritDoc} */
491         @Override
492         public int getRowDimension() {
493             return qrt[0].length;
494         }
495 
496         /** {@inheritDoc} */
497         @Override
498         public int getColumnDimension() {
499             return qrt.length;
500         }
501 
502     }
503 }