View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.linear;
24  
25  import org.hipparchus.exception.LocalizedCoreFormats;
26  import org.hipparchus.exception.MathIllegalArgumentException;
27  import org.hipparchus.util.FastMath;
28  import org.junit.jupiter.api.Test;
29  
30  import java.util.Arrays;
31  
32  import static org.junit.jupiter.api.Assertions.assertEquals;
33  import static org.junit.jupiter.api.Assertions.assertTrue;
34  import static org.junit.jupiter.api.Assertions.fail;
35  
36  class TriDiagonalTransformerTest {
37  
38      private double[][] testSquare5 = {
39              { 1, 2, 3, 1, 1 },
40              { 2, 1, 1, 3, 1 },
41              { 3, 1, 1, 1, 2 },
42              { 1, 3, 1, 2, 1 },
43              { 1, 1, 2, 1, 3 }
44      };
45  
46      private double[][] testSquare3 = {
47              { 1, 3, 4 },
48              { 3, 2, 2 },
49              { 4, 2, 0 }
50      };
51  
52      @Test
53      void testNonSquare() {
54          try {
55              new TriDiagonalTransformer(MatrixUtils.createRealMatrix(new double[3][2]));
56              fail("an exception should have been thrown");
57          } catch (MathIllegalArgumentException ime) {
58              assertEquals(LocalizedCoreFormats.NON_SQUARE_MATRIX, ime.getSpecifier());
59          }
60      }
61  
62      @Test
63      void testAEqualQTQt() {
64          checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare5));
65          checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare3));
66      }
67  
68      private void checkAEqualQTQt(RealMatrix matrix) {
69          TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
70          RealMatrix q  = transformer.getQ();
71          RealMatrix qT = transformer.getQT();
72          RealMatrix t  = transformer.getT();
73          double norm = q.multiply(t).multiply(qT).subtract(matrix).getNorm1();
74          assertEquals(0, norm, 4.0e-15);
75      }
76  
77      @Test
78      void testNoAccessBelowDiagonal() {
79          checkNoAccessBelowDiagonal(testSquare5);
80          checkNoAccessBelowDiagonal(testSquare3);
81      }
82  
83      private void checkNoAccessBelowDiagonal(double[][] data) {
84          double[][] modifiedData = new double[data.length][];
85          for (int i = 0; i < data.length; ++i) {
86              modifiedData[i] = data[i].clone();
87              Arrays.fill(modifiedData[i], 0, i, Double.NaN);
88          }
89          RealMatrix matrix = MatrixUtils.createRealMatrix(modifiedData);
90          TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
91          RealMatrix q  = transformer.getQ();
92          RealMatrix qT = transformer.getQT();
93          RealMatrix t  = transformer.getT();
94          double norm = q.multiply(t).multiply(qT).subtract(MatrixUtils.createRealMatrix(data)).getNorm1();
95          assertEquals(0, norm, 4.0e-15);
96      }
97  
98      @Test
99      void testQOrthogonal() {
100         checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQ());
101         checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQ());
102     }
103 
104     @Test
105     void testQTOrthogonal() {
106         checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQT());
107         checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQT());
108     }
109 
110     private void checkOrthogonal(RealMatrix m) {
111         RealMatrix mTm = m.transposeMultiply(m);
112         RealMatrix id  = MatrixUtils.createRealIdentityMatrix(mTm.getRowDimension());
113         assertEquals(0, mTm.subtract(id).getNorm1(), 1.0e-15);
114     }
115 
116     @Test
117     void testTTriDiagonal() {
118         checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getT());
119         checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getT());
120     }
121 
122     private void checkTriDiagonal(RealMatrix m) {
123         final int rows = m.getRowDimension();
124         final int cols = m.getColumnDimension();
125         for (int i = 0; i < rows; ++i) {
126             for (int j = 0; j < cols; ++j) {
127                 if ((i < j - 1) || (i > j + 1)) {
128                     assertEquals(0, m.getEntry(i, j), 1.0e-16);
129                 }
130             }
131         }
132     }
133 
134     @Test
135     void testMatricesValues5() {
136         checkMatricesValues(testSquare5,
137                             new double[][] {
138                                 { 1.0,  0.0,                 0.0,                  0.0,                   0.0 },
139                                 { 0.0, -0.5163977794943222,  0.016748280772542083, 0.839800693771262,     0.16669620021405473 },
140                                 { 0.0, -0.7745966692414833, -0.4354553000860955,  -0.44989322880603355,  -0.08930153582895772 },
141                                 { 0.0, -0.2581988897471611,  0.6364346693566014,  -0.30263204032131164,   0.6608313651342882 },
142                                 { 0.0, -0.2581988897471611,  0.6364346693566009,  -0.027289660803112598, -0.7263191580755246 }
143                             },
144                             new double[] { 1, 4.4, 1.433099579242636, -0.89537362758743, 2.062274048344794 },
145                             new double[] { -FastMath.sqrt(15), -3.0832882879592476, 0.6082710842351517, 1.1786086405912128 });
146     }
147 
148     @Test
149     void testMatricesValues3() {
150         checkMatricesValues(testSquare3,
151                             new double[][] {
152                                 {  1.0,  0.0,  0.0 },
153                                 {  0.0, -0.6,  0.8 },
154                                 {  0.0, -0.8, -0.6 },
155                             },
156                             new double[] { 1, 2.64, -0.64 },
157                             new double[] { -5, -1.52 });
158     }
159 
160     private void checkMatricesValues(double[][] matrix, double[][] qRef,
161                                      double[] mainDiagnonal,
162                                      double[] secondaryDiagonal) {
163         TriDiagonalTransformer transformer =
164             new TriDiagonalTransformer(MatrixUtils.createRealMatrix(matrix));
165 
166         // check values against known references
167         RealMatrix q = transformer.getQ();
168         assertEquals(0, q.subtract(MatrixUtils.createRealMatrix(qRef)).getNorm1(), 1.0e-14);
169 
170         RealMatrix t = transformer.getT();
171         double[][] tData = new double[mainDiagnonal.length][mainDiagnonal.length];
172         for (int i = 0; i < mainDiagnonal.length; ++i) {
173             tData[i][i] = mainDiagnonal[i];
174             if (i > 0) {
175                 tData[i][i - 1] = secondaryDiagonal[i - 1];
176             }
177             if (i < secondaryDiagonal.length) {
178                 tData[i][i + 1] = secondaryDiagonal[i];
179             }
180         }
181         assertEquals(0, t.subtract(MatrixUtils.createRealMatrix(tData)).getNorm1(), 1.0e-14);
182 
183         // check the same cached instance is returned the second time
184         assertTrue(q == transformer.getQ());
185         assertTrue(t == transformer.getT());
186     }
187 }