1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.linear;
24
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.util.FastMath;
import org.junit.jupiter.api.Test;

import java.util.Arrays;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
35
36 class TriDiagonalTransformerTest {
37
38 private double[][] testSquare5 = {
39 { 1, 2, 3, 1, 1 },
40 { 2, 1, 1, 3, 1 },
41 { 3, 1, 1, 1, 2 },
42 { 1, 3, 1, 2, 1 },
43 { 1, 1, 2, 1, 3 }
44 };
45
46 private double[][] testSquare3 = {
47 { 1, 3, 4 },
48 { 3, 2, 2 },
49 { 4, 2, 0 }
50 };
51
52 @Test
53 void testNonSquare() {
54 try {
55 new TriDiagonalTransformer(MatrixUtils.createRealMatrix(new double[3][2]));
56 fail("an exception should have been thrown");
57 } catch (MathIllegalArgumentException ime) {
58 assertEquals(LocalizedCoreFormats.NON_SQUARE_MATRIX, ime.getSpecifier());
59 }
60 }
61
62 @Test
63 void testAEqualQTQt() {
64 checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare5));
65 checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare3));
66 }
67
68 private void checkAEqualQTQt(RealMatrix matrix) {
69 TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
70 RealMatrix q = transformer.getQ();
71 RealMatrix qT = transformer.getQT();
72 RealMatrix t = transformer.getT();
73 double norm = q.multiply(t).multiply(qT).subtract(matrix).getNorm1();
74 assertEquals(0, norm, 4.0e-15);
75 }
76
77 @Test
78 void testNoAccessBelowDiagonal() {
79 checkNoAccessBelowDiagonal(testSquare5);
80 checkNoAccessBelowDiagonal(testSquare3);
81 }
82
83 private void checkNoAccessBelowDiagonal(double[][] data) {
84 double[][] modifiedData = new double[data.length][];
85 for (int i = 0; i < data.length; ++i) {
86 modifiedData[i] = data[i].clone();
87 Arrays.fill(modifiedData[i], 0, i, Double.NaN);
88 }
89 RealMatrix matrix = MatrixUtils.createRealMatrix(modifiedData);
90 TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
91 RealMatrix q = transformer.getQ();
92 RealMatrix qT = transformer.getQT();
93 RealMatrix t = transformer.getT();
94 double norm = q.multiply(t).multiply(qT).subtract(MatrixUtils.createRealMatrix(data)).getNorm1();
95 assertEquals(0, norm, 4.0e-15);
96 }
97
98 @Test
99 void testQOrthogonal() {
100 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQ());
101 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQ());
102 }
103
104 @Test
105 void testQTOrthogonal() {
106 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQT());
107 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQT());
108 }
109
110 private void checkOrthogonal(RealMatrix m) {
111 RealMatrix mTm = m.transposeMultiply(m);
112 RealMatrix id = MatrixUtils.createRealIdentityMatrix(mTm.getRowDimension());
113 assertEquals(0, mTm.subtract(id).getNorm1(), 1.0e-15);
114 }
115
116 @Test
117 void testTTriDiagonal() {
118 checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getT());
119 checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getT());
120 }
121
122 private void checkTriDiagonal(RealMatrix m) {
123 final int rows = m.getRowDimension();
124 final int cols = m.getColumnDimension();
125 for (int i = 0; i < rows; ++i) {
126 for (int j = 0; j < cols; ++j) {
127 if ((i < j - 1) || (i > j + 1)) {
128 assertEquals(0, m.getEntry(i, j), 1.0e-16);
129 }
130 }
131 }
132 }
133
134 @Test
135 void testMatricesValues5() {
136 checkMatricesValues(testSquare5,
137 new double[][] {
138 { 1.0, 0.0, 0.0, 0.0, 0.0 },
139 { 0.0, -0.5163977794943222, 0.016748280772542083, 0.839800693771262, 0.16669620021405473 },
140 { 0.0, -0.7745966692414833, -0.4354553000860955, -0.44989322880603355, -0.08930153582895772 },
141 { 0.0, -0.2581988897471611, 0.6364346693566014, -0.30263204032131164, 0.6608313651342882 },
142 { 0.0, -0.2581988897471611, 0.6364346693566009, -0.027289660803112598, -0.7263191580755246 }
143 },
144 new double[] { 1, 4.4, 1.433099579242636, -0.89537362758743, 2.062274048344794 },
145 new double[] { -FastMath.sqrt(15), -3.0832882879592476, 0.6082710842351517, 1.1786086405912128 });
146 }
147
148 @Test
149 void testMatricesValues3() {
150 checkMatricesValues(testSquare3,
151 new double[][] {
152 { 1.0, 0.0, 0.0 },
153 { 0.0, -0.6, 0.8 },
154 { 0.0, -0.8, -0.6 },
155 },
156 new double[] { 1, 2.64, -0.64 },
157 new double[] { -5, -1.52 });
158 }
159
160 private void checkMatricesValues(double[][] matrix, double[][] qRef,
161 double[] mainDiagnonal,
162 double[] secondaryDiagonal) {
163 TriDiagonalTransformer transformer =
164 new TriDiagonalTransformer(MatrixUtils.createRealMatrix(matrix));
165
166
167 RealMatrix q = transformer.getQ();
168 assertEquals(0, q.subtract(MatrixUtils.createRealMatrix(qRef)).getNorm1(), 1.0e-14);
169
170 RealMatrix t = transformer.getT();
171 double[][] tData = new double[mainDiagnonal.length][mainDiagnonal.length];
172 for (int i = 0; i < mainDiagnonal.length; ++i) {
173 tData[i][i] = mainDiagnonal[i];
174 if (i > 0) {
175 tData[i][i - 1] = secondaryDiagonal[i - 1];
176 }
177 if (i < secondaryDiagonal.length) {
178 tData[i][i + 1] = secondaryDiagonal[i];
179 }
180 }
181 assertEquals(0, t.subtract(MatrixUtils.createRealMatrix(tData)).getNorm1(), 1.0e-14);
182
183
184 assertTrue(q == transformer.getQ());
185 assertTrue(t == transformer.getT());
186 }
187 }