1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.linear;
24
25 import java.util.Arrays;
26 import java.util.function.Predicate;
27
28 import org.hipparchus.CalculusFieldElement;
29 import org.hipparchus.FieldElement;
30 import org.hipparchus.exception.LocalizedCoreFormats;
31 import org.hipparchus.exception.MathIllegalArgumentException;
32 import org.hipparchus.util.FastMath;
33 import org.hipparchus.util.MathArrays;
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52 public class FieldQRDecomposition<T extends CalculusFieldElement<T>> {
53
54
55
56
57
58
59 private T[][] qrt;
60
61 private T[] rDiag;
62
63 private FieldMatrix<T> cachedQ;
64
65 private FieldMatrix<T> cachedQT;
66
67 private FieldMatrix<T> cachedR;
68
69 private FieldMatrix<T> cachedH;
70
71 private final T threshold;
72
73 private final Predicate<T> zeroChecker;
74
75
76
77
78
79
80
81
82
83 public FieldQRDecomposition(FieldMatrix<T> matrix) {
84 this(matrix, matrix.getField().getZero());
85 }
86
87
88
89
90
91
92
93 public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold) {
94 this(matrix, threshold, FieldElement::isZero);
95 }
96
97
98
99
100
101
102
103
104 public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold, Predicate<T> zeroChecker) {
105 this.threshold = threshold;
106 this.zeroChecker = zeroChecker;
107
108 final int m = matrix.getRowDimension();
109 final int n = matrix.getColumnDimension();
110 qrt = matrix.transpose().getData();
111 rDiag = MathArrays.buildArray(threshold.getField(),FastMath.min(m, n));
112 cachedQ = null;
113 cachedQT = null;
114 cachedR = null;
115 cachedH = null;
116
117 decompose(qrt);
118
119 }
120
121
122
123
124 protected void decompose(T[][] matrix) {
125 for (int minor = 0; minor < FastMath.min(matrix.length, matrix[0].length); minor++) {
126 performHouseholderReflection(minor, matrix);
127 }
128 }
129
130
131
132
133
134 protected void performHouseholderReflection(int minor, T[][] matrix) {
135
136 final T[] qrtMinor = matrix[minor];
137 final T zero = threshold.getField().getZero();
138
139
140
141
142
143
144
145 T xNormSqr = zero;
146 for (int row = minor; row < qrtMinor.length; row++) {
147 final T c = qrtMinor[row];
148 xNormSqr = xNormSqr.add(c.square());
149 }
150 final T a = (qrtMinor[minor].getReal() > 0) ? xNormSqr.sqrt().negate() : xNormSqr.sqrt();
151 rDiag[minor] = a;
152
153 if (!zeroChecker.test(a)) {
154
155
156
157
158
159
160
161
162
163 qrtMinor[minor] = qrtMinor[minor].subtract(a);
164
165
166
167
168
169
170
171
172
173
174
175
176
177 for (int col = minor+1; col < matrix.length; col++) {
178 final T[] qrtCol = matrix[col];
179 T alpha = zero;
180 for (int row = minor; row < qrtCol.length; row++) {
181 alpha = alpha.subtract(qrtCol[row].multiply(qrtMinor[row]));
182 }
183 alpha = alpha.divide(a.multiply(qrtMinor[minor]));
184
185
186 for (int row = minor; row < qrtCol.length; row++) {
187 qrtCol[row] = qrtCol[row].subtract(alpha.multiply(qrtMinor[row]));
188 }
189 }
190 }
191 }
192
193
194
195
196
197
198
199 public FieldMatrix<T> getR() {
200
201 if (cachedR == null) {
202
203
204 final int n = qrt.length;
205 final int m = qrt[0].length;
206 T[][] ra = MathArrays.buildArray(threshold.getField(), m, n);
207
208 for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
209 ra[row][row] = rDiag[row];
210 for (int col = row + 1; col < n; col++) {
211 ra[row][col] = qrt[col][row];
212 }
213 }
214 cachedR = MatrixUtils.createFieldMatrix(ra);
215 }
216
217
218 return cachedR;
219 }
220
221
222
223
224
225
226 public FieldMatrix<T> getQ() {
227 if (cachedQ == null) {
228 cachedQ = getQT().transpose();
229 }
230 return cachedQ;
231 }
232
233
234
235
236
237
238 public FieldMatrix<T> getQT() {
239 if (cachedQT == null) {
240
241
242 final int n = qrt.length;
243 final int m = qrt[0].length;
244 T[][] qta = MathArrays.buildArray(threshold.getField(), m, m);
245
246
247
248
249
250
251 for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
252 qta[minor][minor] = threshold.getField().getOne();
253 }
254
255 for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
256 final T[] qrtMinor = qrt[minor];
257 qta[minor][minor] = threshold.getField().getOne();
258 if (!qrtMinor[minor].isZero()) {
259 for (int col = minor; col < m; col++) {
260 T alpha = threshold.getField().getZero();
261 for (int row = minor; row < m; row++) {
262 alpha = alpha.subtract(qta[col][row].multiply(qrtMinor[row]));
263 }
264 alpha = alpha.divide(rDiag[minor].multiply(qrtMinor[minor]));
265
266 for (int row = minor; row < m; row++) {
267 qta[col][row] = qta[col][row].add(alpha.negate().multiply(qrtMinor[row]));
268 }
269 }
270 }
271 }
272 cachedQT = MatrixUtils.createFieldMatrix(qta);
273 }
274
275
276 return cachedQT;
277 }
278
279
280
281
282
283
284
285
286 public FieldMatrix<T> getH() {
287 if (cachedH == null) {
288
289 final int n = qrt.length;
290 final int m = qrt[0].length;
291 T[][] ha = MathArrays.buildArray(threshold.getField(), m, n);
292 for (int i = 0; i < m; ++i) {
293 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
294 ha[i][j] = qrt[j][i].divide(rDiag[j].negate());
295 }
296 }
297 cachedH = MatrixUtils.createFieldMatrix(ha);
298 }
299
300
301 return cachedH;
302 }
303
304
305
306
307
308
309
310
311
312
313
314
315
316 public FieldDecompositionSolver<T> getSolver() {
317 return new FieldSolver();
318 }
319
320
321
322
323 private class FieldSolver implements FieldDecompositionSolver<T>{
324
325
326 @Override
327 public boolean isNonSingular() {
328 return !checkSingular(rDiag, threshold, false);
329 }
330
331
332 @Override
333 public FieldVector<T> solve(FieldVector<T> b) {
334 final int n = qrt.length;
335 final int m = qrt[0].length;
336 if (b.getDimension() != m) {
337 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
338 b.getDimension(), m);
339 }
340 checkSingular(rDiag, threshold, true);
341
342 final T[] x =MathArrays.buildArray(threshold.getField(),n);
343 final T[] y = b.toArray();
344
345
346 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
347
348 final T[] qrtMinor = qrt[minor];
349 T dotProduct = threshold.getField().getZero();
350 for (int row = minor; row < m; row++) {
351 dotProduct = dotProduct.add(y[row].multiply(qrtMinor[row]));
352 }
353 dotProduct = dotProduct.divide(rDiag[minor].multiply(qrtMinor[minor]));
354
355 for (int row = minor; row < m; row++) {
356 y[row] = y[row].add(dotProduct.multiply(qrtMinor[row]));
357 }
358 }
359
360
361 for (int row = rDiag.length - 1; row >= 0; --row) {
362 y[row] = y[row].divide(rDiag[row]);
363 final T yRow = y[row];
364 final T[] qrtRow = qrt[row];
365 x[row] = yRow;
366 for (int i = 0; i < row; i++) {
367 y[i] = y[i].subtract(yRow.multiply(qrtRow[i]));
368 }
369 }
370
371 return new ArrayFieldVector<>(x, false);
372 }
373
374
375 @Override
376 public FieldMatrix<T> solve(FieldMatrix<T> b) {
377 final int n = qrt.length;
378 final int m = qrt[0].length;
379 if (b.getRowDimension() != m) {
380 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
381 b.getRowDimension(), m);
382 }
383 checkSingular(rDiag, threshold, true);
384
385 final int columns = b.getColumnDimension();
386 final int blockSize = BlockFieldMatrix.BLOCK_SIZE;
387 final int cBlocks = (columns + blockSize - 1) / blockSize;
388 final T[][] xBlocks = BlockFieldMatrix.createBlocksLayout(threshold.getField(),n, columns);
389 final T[][] y = MathArrays.buildArray(threshold.getField(), b.getRowDimension(), blockSize);
390 final T[] alpha = MathArrays.buildArray(threshold.getField(), blockSize);
391
392 for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
393 final int kStart = kBlock * blockSize;
394 final int kEnd = FastMath.min(kStart + blockSize, columns);
395 final int kWidth = kEnd - kStart;
396
397
398 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
399
400
401 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
402 final T[] qrtMinor = qrt[minor];
403 final T factor = rDiag[minor].multiply(qrtMinor[minor]).reciprocal();
404
405 Arrays.fill(alpha, 0, kWidth, threshold.getField().getZero());
406 for (int row = minor; row < m; ++row) {
407 final T d = qrtMinor[row];
408 final T[] yRow = y[row];
409 for (int k = 0; k < kWidth; ++k) {
410 alpha[k] = alpha[k].add(d.multiply(yRow[k]));
411 }
412 }
413
414 for (int k = 0; k < kWidth; ++k) {
415 alpha[k] = alpha[k].multiply(factor);
416 }
417
418 for (int row = minor; row < m; ++row) {
419 final T d = qrtMinor[row];
420 final T[] yRow = y[row];
421 for (int k = 0; k < kWidth; ++k) {
422 yRow[k] = yRow[k].add(alpha[k].multiply(d));
423 }
424 }
425 }
426
427
428 for (int j = rDiag.length - 1; j >= 0; --j) {
429 final int jBlock = j / blockSize;
430 final int jStart = jBlock * blockSize;
431 final T factor = rDiag[j].reciprocal();
432 final T[] yJ = y[j];
433 final T[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
434 int index = (j - jStart) * kWidth;
435 for (int k = 0; k < kWidth; ++k) {
436 yJ[k] =yJ[k].multiply(factor);
437 xBlock[index++] = yJ[k];
438 }
439
440 final T[] qrtJ = qrt[j];
441 for (int i = 0; i < j; ++i) {
442 final T rIJ = qrtJ[i];
443 final T[] yI = y[i];
444 for (int k = 0; k < kWidth; ++k) {
445 yI[k] = yI[k].subtract(yJ[k].multiply(rIJ));
446 }
447 }
448 }
449 }
450
451 return new BlockFieldMatrix<>(n, columns, xBlocks, false);
452 }
453
454
455
456
457
458 @Override
459 public FieldMatrix<T> getInverse() {
460 return solve(MatrixUtils.createFieldIdentityMatrix(threshold.getField(), qrt[0].length));
461 }
462
463
464
465
466
467
468
469
470
471
472
473
474
475 private boolean checkSingular(T[] diag,
476 T min,
477 boolean raise) {
478 for (final T d : diag) {
479 if (FastMath.abs(d.getReal()) <= min.getReal()) {
480 if (raise) {
481 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
482 } else {
483 return true;
484 }
485 }
486 }
487 return false;
488 }
489
490
491 @Override
492 public int getRowDimension() {
493 return qrt[0].length;
494 }
495
496
497 @Override
498 public int getColumnDimension() {
499 return qrt.length;
500 }
501
502 }
503 }