View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.linear;
24  
25  import java.util.function.Predicate;
26  
27  import org.hipparchus.Field;
28  import org.hipparchus.FieldElement;
29  import org.hipparchus.exception.LocalizedCoreFormats;
30  import org.hipparchus.exception.MathIllegalArgumentException;
31  import org.hipparchus.util.FastMath;
32  import org.hipparchus.util.MathArrays;
33  
34  /**
35   * Calculates the LUP-decomposition of a square matrix.
36   * <p>The LUP-decomposition of a matrix A consists of three matrices
37   * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
38   * upper triangular and P is a permutation matrix. All matrices are
39   * m&times;m.</p>
40   * <p>This class is based on the class with similar name from the
41   * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the following changes:</p>
42   * <ul>
43   *   <li>a {@link #getP() getP} method has been added,</li>
44   *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
45   *   getDeterminant},</li>
46   *   <li>the {@code getDoublePivot} method has been removed (but the int based
47   *   {@link #getPivot() getPivot} method has been kept),</li>
48   *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
49   *   by a {@link #getSolver() getSolver} method and the equivalent methods
50   *   provided by the returned {@link FieldDecompositionSolver}.</li>
51   * </ul>
52   *
53   * @param <T> the type of the field elements
54   * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
55   * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
56   */
57  public class FieldLUDecomposition<T extends FieldElement<T>> {
58  
59      /** Field to which the elements belong. */
60      private final Field<T> field;
61  
62      /** Entries of LU decomposition. */
63      private T[][] lu;
64  
65      /** Pivot permutation associated with LU decomposition. */
66      private int[] pivot;
67  
68      /** Parity of the permutation associated with the LU decomposition. */
69      private boolean even;
70  
71      /** Singularity indicator. */
72      private boolean singular;
73  
74      /** Cached value of L. */
75      private FieldMatrix<T> cachedL;
76  
77      /** Cached value of U. */
78      private FieldMatrix<T> cachedU;
79  
80      /** Cached value of P. */
81      private FieldMatrix<T> cachedP;
82  
83      /**
84       * Calculates the LU-decomposition of the given matrix.
85       * <p>
86       * By default, <code>numericPermutationChoice</code> is set to <code>true</code>.
87       * </p>
88       * @param matrix The matrix to decompose.
89       * @throws MathIllegalArgumentException if matrix is not square
90       * @see #FieldLUDecomposition(FieldMatrix, Predicate)
91       * @see #FieldLUDecomposition(FieldMatrix, Predicate, boolean)
92       */
93      public FieldLUDecomposition(FieldMatrix<T> matrix) {
94          this(matrix, e -> e.isZero());
95      }
96  
97      /**
98       * Calculates the LU-decomposition of the given matrix.
99       * <p>
100      * By default, <code>numericPermutationChoice</code> is set to <code>true</code>.
101      * </p>
102      * @param matrix The matrix to decompose.
103      * @param zeroChecker checker for zero elements
104      * @throws MathIllegalArgumentException if matrix is not square
105      * @see #FieldLUDecomposition(FieldMatrix, Predicate, boolean)
106      */
107     public FieldLUDecomposition(FieldMatrix<T> matrix, final Predicate<T> zeroChecker ) {
108         this(matrix, zeroChecker, true);
109     }
110 
111     /**
112      * Calculates the LU-decomposition of the given matrix.
113      * @param matrix The matrix to decompose.
114      * @param zeroChecker checker for zero elements
115      * @param numericPermutationChoice if <code>true</code> choose permutation index with numeric calculations, otherwise choose with <code>zeroChecker</code>
116      * @throws MathIllegalArgumentException if matrix is not square
117      */
118     public FieldLUDecomposition(FieldMatrix<T> matrix, final Predicate<T> zeroChecker, boolean numericPermutationChoice) {
119         if (!matrix.isSquare()) {
120             throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SQUARE_MATRIX,
121                                                    matrix.getRowDimension(), matrix.getColumnDimension());
122         }
123 
124         final int m = matrix.getColumnDimension();
125         field = matrix.getField();
126         lu = matrix.getData();
127         pivot = new int[m];
128         cachedL = null;
129         cachedU = null;
130         cachedP = null;
131 
132         // Initialize permutation array and parity
133         for (int row = 0; row < m; row++) {
134             pivot[row] = row;
135         }
136         even     = true;
137         singular = false;
138 
139         // Loop over columns
140         for (int col = 0; col < m; col++) {
141 
142             // upper
143             for (int row = 0; row < col; row++) {
144                 final T[] luRow = lu[row];
145                 T sum = luRow[col];
146                 for (int i = 0; i < row; i++) {
147                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
148                 }
149                 luRow[col] = sum;
150             }
151 
152             int max = col; // permutation row
153             if (numericPermutationChoice) {
154 
155                 // lower
156                 double largest = Double.NEGATIVE_INFINITY;
157 
158                 for (int row = col; row < m; row++) {
159                     final T[] luRow = lu[row];
160                     T sum = luRow[col];
161                     for (int i = 0; i < col; i++) {
162                         sum = sum.subtract(luRow[i].multiply(lu[i][col]));
163                     }
164                     luRow[col] = sum;
165 
166                     // maintain best permutation choice
167                     double absSum = FastMath.abs(sum.getReal());
168                     if (absSum > largest) {
169                         largest = absSum;
170                         max = row;
171                     }
172                 }
173 
174             } else {
175 
176                 // lower
177                 int nonZero = col; // permutation row
178                 for (int row = col; row < m; row++) {
179                     final T[] luRow = lu[row];
180                     T sum = luRow[col];
181                     for (int i = 0; i < col; i++) {
182                         sum = sum.subtract(luRow[i].multiply(lu[i][col]));
183                     }
184                     luRow[col] = sum;
185 
186                     if (zeroChecker.test(lu[nonZero][col])) {
187                         // try to select a better permutation choice
188                         ++nonZero;
189                     }
190                 }
191                 max = FastMath.min(m - 1, nonZero);
192 
193             }
194 
195             // Singularity check
196             if (zeroChecker.test(lu[max][col])) {
197                 singular = true;
198                 return;
199             }
200 
201             // Pivot if necessary
202             if (max != col) {
203                 final T[] luMax = lu[max];
204                 final T[] luCol = lu[col];
205                 for (int i = 0; i < m; i++) {
206                     final T tmp = luMax[i];
207                     luMax[i] = luCol[i];
208                     luCol[i] = tmp;
209                 }
210                 int temp = pivot[max];
211                 pivot[max] = pivot[col];
212                 pivot[col] = temp;
213                 even = !even;
214             }
215 
216             // Divide the lower elements by the "winning" diagonal elt.
217             final T luDiag = lu[col][col];
218             for (int row = col + 1; row < m; row++) {
219                 lu[row][col] = lu[row][col].divide(luDiag);
220             }
221         }
222 
223     }
224 
225     /**
226      * Returns the matrix L of the decomposition.
227      * <p>L is a lower-triangular matrix</p>
228      * @return the L matrix (or null if decomposed matrix is singular)
229      */
230     public FieldMatrix<T> getL() {
231         if ((cachedL == null) && !singular) {
232             final int m = pivot.length;
233             cachedL = new Array2DRowFieldMatrix<>(field, m, m);
234             for (int i = 0; i < m; ++i) {
235                 final T[] luI = lu[i];
236                 for (int j = 0; j < i; ++j) {
237                     cachedL.setEntry(i, j, luI[j]);
238                 }
239                 cachedL.setEntry(i, i, field.getOne());
240             }
241         }
242         return cachedL;
243     }
244 
245     /**
246      * Returns the matrix U of the decomposition.
247      * <p>U is an upper-triangular matrix</p>
248      * @return the U matrix (or null if decomposed matrix is singular)
249      */
250     public FieldMatrix<T> getU() {
251         if ((cachedU == null) && !singular) {
252             final int m = pivot.length;
253             cachedU = new Array2DRowFieldMatrix<>(field, m, m);
254             for (int i = 0; i < m; ++i) {
255                 final T[] luI = lu[i];
256                 for (int j = i; j < m; ++j) {
257                     cachedU.setEntry(i, j, luI[j]);
258                 }
259             }
260         }
261         return cachedU;
262     }
263 
264     /**
265      * Returns the P rows permutation matrix.
266      * <p>P is a sparse matrix with exactly one element set to 1.0 in
267      * each row and each column, all other elements being set to 0.0.</p>
268      * <p>The positions of the 1 elements are given by the {@link #getPivot()
269      * pivot permutation vector}.</p>
270      * @return the P rows permutation matrix (or null if decomposed matrix is singular)
271      * @see #getPivot()
272      */
273     public FieldMatrix<T> getP() {
274         if ((cachedP == null) && !singular) {
275             final int m = pivot.length;
276             cachedP = new Array2DRowFieldMatrix<>(field, m, m);
277             for (int i = 0; i < m; ++i) {
278                 cachedP.setEntry(i, pivot[i], field.getOne());
279             }
280         }
281         return cachedP;
282     }
283 
284     /**
285      * Returns the pivot permutation vector.
286      * @return the pivot permutation vector
287      * @see #getP()
288      */
289     public int[] getPivot() {
290         return pivot.clone();
291     }
292 
293     /**
294      * Return the determinant of the matrix.
295      * @return determinant of the matrix
296      */
297     public T getDeterminant() {
298         if (singular) {
299             return field.getZero();
300         } else {
301             final int m = pivot.length;
302             T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
303             for (int i = 0; i < m; i++) {
304                 determinant = determinant.multiply(lu[i][i]);
305             }
306             return determinant;
307         }
308     }
309 
310     /**
311      * Get a solver for finding the A &times; X = B solution in exact linear sense.
312      * @return a solver
313      */
314     public FieldDecompositionSolver<T> getSolver() {
315         return new Solver();
316     }
317 
318     /** Specialized solver.
319      */
320     private class Solver implements FieldDecompositionSolver<T> {
321 
322         /** {@inheritDoc} */
323         @Override
324         public boolean isNonSingular() {
325             return !singular;
326         }
327 
328         /** {@inheritDoc} */
329         @Override
330         public FieldVector<T> solve(FieldVector<T> b) {
331             if (b instanceof ArrayFieldVector) {
332                 return solve((ArrayFieldVector<T>) b);
333             } else {
334 
335                 final int m = pivot.length;
336                 if (b.getDimension() != m) {
337                     throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
338                                                            b.getDimension(), m);
339                 }
340                 if (singular) {
341                     throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
342                 }
343 
344                 // Apply permutations to b
345                 final T[] bp = MathArrays.buildArray(field, m);
346                 for (int row = 0; row < m; row++) {
347                     bp[row] = b.getEntry(pivot[row]);
348                 }
349 
350                 // Solve LY = b
351                 for (int col = 0; col < m; col++) {
352                     final T bpCol = bp[col];
353                     for (int i = col + 1; i < m; i++) {
354                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
355                     }
356                 }
357 
358                 // Solve UX = Y
359                 for (int col = m - 1; col >= 0; col--) {
360                     bp[col] = bp[col].divide(lu[col][col]);
361                     final T bpCol = bp[col];
362                     for (int i = 0; i < col; i++) {
363                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
364                     }
365                 }
366 
367                 return new ArrayFieldVector<T>(field, bp, false);
368 
369             }
370         }
371 
372         /** Solve the linear equation A &times; X = B.
373          * <p>The A matrix is implicit here. It is </p>
374          * @param b right-hand side of the equation A &times; X = B
375          * @return a vector X such that A &times; X = B
376          * @throws MathIllegalArgumentException if the matrices dimensions do not match.
377          * @throws MathIllegalArgumentException if the decomposed matrix is singular.
378          */
379         public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
380             final int m = pivot.length;
381             final int length = b.getDimension();
382             if (length != m) {
383                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
384                                                        length, m);
385             }
386             if (singular) {
387                 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
388             }
389 
390             // Apply permutations to b
391             final T[] bp = MathArrays.buildArray(field, m);
392             for (int row = 0; row < m; row++) {
393                 bp[row] = b.getEntry(pivot[row]);
394             }
395 
396             // Solve LY = b
397             for (int col = 0; col < m; col++) {
398                 final T bpCol = bp[col];
399                 for (int i = col + 1; i < m; i++) {
400                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
401                 }
402             }
403 
404             // Solve UX = Y
405             for (int col = m - 1; col >= 0; col--) {
406                 bp[col] = bp[col].divide(lu[col][col]);
407                 final T bpCol = bp[col];
408                 for (int i = 0; i < col; i++) {
409                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
410                 }
411             }
412 
413             return new ArrayFieldVector<T>(bp, false);
414         }
415 
416         /** {@inheritDoc} */
417         @Override
418         public FieldMatrix<T> solve(FieldMatrix<T> b) {
419             final int m = pivot.length;
420             if (b.getRowDimension() != m) {
421                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
422                                                        b.getRowDimension(), m);
423             }
424             if (singular) {
425                 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
426             }
427 
428             final int nColB = b.getColumnDimension();
429 
430             // Apply permutations to b
431             final T[][] bp = MathArrays.buildArray(field, m, nColB);
432             for (int row = 0; row < m; row++) {
433                 final T[] bpRow = bp[row];
434                 final int pRow = pivot[row];
435                 for (int col = 0; col < nColB; col++) {
436                     bpRow[col] = b.getEntry(pRow, col);
437                 }
438             }
439 
440             // Solve LY = b
441             for (int col = 0; col < m; col++) {
442                 final T[] bpCol = bp[col];
443                 for (int i = col + 1; i < m; i++) {
444                     final T[] bpI = bp[i];
445                     final T luICol = lu[i][col];
446                     for (int j = 0; j < nColB; j++) {
447                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
448                     }
449                 }
450             }
451 
452             // Solve UX = Y
453             for (int col = m - 1; col >= 0; col--) {
454                 final T[] bpCol = bp[col];
455                 final T luDiag = lu[col][col];
456                 for (int j = 0; j < nColB; j++) {
457                     bpCol[j] = bpCol[j].divide(luDiag);
458                 }
459                 for (int i = 0; i < col; i++) {
460                     final T[] bpI = bp[i];
461                     final T luICol = lu[i][col];
462                     for (int j = 0; j < nColB; j++) {
463                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
464                     }
465                 }
466             }
467 
468             return new Array2DRowFieldMatrix<T>(field, bp, false);
469 
470         }
471 
472         /** {@inheritDoc} */
473         @Override
474         public FieldMatrix<T> getInverse() {
475             return solve(MatrixUtils.createFieldIdentityMatrix(field, pivot.length));
476         }
477 
478         /** {@inheritDoc} */
479         @Override
480         public int getRowDimension() {
481             return lu.length;
482         }
483 
484         /** {@inheritDoc} */
485         @Override
486         public int getColumnDimension() {
487             return lu[0].length;
488         }
489 
490     }
491 }