View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.linear;
23  
24  import org.hipparchus.UnitTestUtils;
25  import org.hipparchus.exception.LocalizedCoreFormats;
26  import org.hipparchus.exception.MathIllegalArgumentException;
27  import org.hipparchus.exception.NullArgumentException;
28  import org.hipparchus.random.RandomGenerator;
29  import org.hipparchus.random.Well1024a;
30  import org.hipparchus.util.Precision;
31  import org.junit.jupiter.api.Test;
32  
33  import static org.junit.jupiter.api.Assertions.assertEquals;
34  import static org.junit.jupiter.api.Assertions.assertFalse;
35  import static org.junit.jupiter.api.Assertions.assertThrows;
36  import static org.junit.jupiter.api.Assertions.assertTrue;
37  import static org.junit.jupiter.api.Assertions.fail;
38  
39  /**
40   * Test cases for the {@link DiagonalMatrix} class.
41   */
42  class DiagonalMatrixTest {
43      @Test
44      void testConstructor1() {
45          final int dim = 3;
46          final DiagonalMatrix m = new DiagonalMatrix(dim);
47          assertEquals(dim, m.getRowDimension());
48          assertEquals(dim, m.getColumnDimension());
49      }
50  
51      @Test
52      void testConstructor2() {
53          final double[] d = { -1.2, 3.4, 5 };
54          final DiagonalMatrix m = new DiagonalMatrix(d);
55          for (int i = 0; i < m.getRowDimension(); i++) {
56              for (int j = 0; j < m.getRowDimension(); j++) {
57                  if (i == j) {
58                      assertEquals(d[i], m.getEntry(i, j), 0d);
59                  } else {
60                      assertEquals(0d, m.getEntry(i, j), 0d);
61                  }
62              }
63          }
64  
65          // Check that the underlying was copied.
66          d[0] = 0;
67          assertFalse(d[0] == m.getEntry(0, 0));
68      }
69  
70      @Test
71      void testConstructor3() {
72          final double[] d = { -1.2, 3.4, 5 };
73          final DiagonalMatrix m = new DiagonalMatrix(d, false);
74          for (int i = 0; i < m.getRowDimension(); i++) {
75              for (int j = 0; j < m.getRowDimension(); j++) {
76                  if (i == j) {
77                      assertEquals(d[i], m.getEntry(i, j), 0d);
78                  } else {
79                      assertEquals(0d, m.getEntry(i, j), 0d);
80                  }
81              }
82          }
83  
84          // Check that the underlying is referenced.
85          d[0] = 0;
86          assertEquals(d[0], m.getEntry(0, 0));
87  
88      }
89  
90      @Test
91      void testCreateError() {
92          assertThrows(MathIllegalArgumentException.class, () -> {
93              final double[] d = {-1.2, 3.4, 5};
94              final DiagonalMatrix m = new DiagonalMatrix(d, false);
95              m.createMatrix(5, 3);
96          });
97      }
98  
99      @Test
100     void testCreate() {
101         final double[] d = { -1.2, 3.4, 5 };
102         final DiagonalMatrix m = new DiagonalMatrix(d, false);
103         final RealMatrix p = m.createMatrix(5, 5);
104         assertTrue(p instanceof DiagonalMatrix);
105         assertEquals(5, p.getRowDimension());
106         assertEquals(5, p.getColumnDimension());
107     }
108 
109     @Test
110     void testCopy() {
111         final double[] d = { -1.2, 3.4, 5 };
112         final DiagonalMatrix m = new DiagonalMatrix(d, false);
113         final DiagonalMatrix p = (DiagonalMatrix) m.copy();
114         for (int i = 0; i < m.getRowDimension(); ++i) {
115             assertEquals(m.getEntry(i, i), p.getEntry(i, i), 1.0e-20);
116         }
117     }
118 
119     @Test
120     void testGetData() {
121         final double[] data = { -1.2, 3.4, 5 };
122         final int dim = 3;
123         final DiagonalMatrix m = new DiagonalMatrix(dim);
124         for (int i = 0; i < dim; i++) {
125             m.setEntry(i, i, data[i]);
126         }
127 
128         final double[][] out = m.getData();
129         assertEquals(dim, out.length);
130         for (int i = 0; i < m.getRowDimension(); i++) {
131             assertEquals(dim, out[i].length);
132             for (int j = 0; j < m.getRowDimension(); j++) {
133                 if (i == j) {
134                     assertEquals(data[i], out[i][j], 0d);
135                 } else {
136                     assertEquals(0d, out[i][j], 0d);
137                 }
138             }
139         }
140     }
141 
142     @Test
143     void testAdd() {
144         final double[] data1 = { -1.2, 3.4, 5 };
145         final DiagonalMatrix m1 = new DiagonalMatrix(data1);
146 
147         final double[] data2 = { 10.1, 2.3, 45 };
148         final DiagonalMatrix m2 = new DiagonalMatrix(data2);
149 
150         final DiagonalMatrix result = m1.add(m2);
151         assertEquals(m1.getRowDimension(), result.getRowDimension());
152         for (int i = 0; i < result.getRowDimension(); i++) {
153             for (int j = 0; j < result.getRowDimension(); j++) {
154                 if (i == j) {
155                     assertEquals(data1[i] + data2[i], result.getEntry(i, j), 0d);
156                 } else {
157                     assertEquals(0d, result.getEntry(i, j), 0d);
158                 }
159             }
160         }
161     }
162 
163     @Test
164     void testSubtract() {
165         final double[] data1 = { -1.2, 3.4, 5 };
166         final DiagonalMatrix m1 = new DiagonalMatrix(data1);
167 
168         final double[] data2 = { 10.1, 2.3, 45 };
169         final DiagonalMatrix m2 = new DiagonalMatrix(data2);
170 
171         final DiagonalMatrix result = m1.subtract(m2);
172         assertEquals(m1.getRowDimension(), result.getRowDimension());
173         for (int i = 0; i < result.getRowDimension(); i++) {
174             for (int j = 0; j < result.getRowDimension(); j++) {
175                 if (i == j) {
176                     assertEquals(data1[i] - data2[i], result.getEntry(i, j), 0d);
177                 } else {
178                     assertEquals(0d, result.getEntry(i, j), 0d);
179                 }
180             }
181         }
182     }
183 
184     @Test
185     void testAddToEntry() {
186         final double[] data = { -1.2, 3.4, 5 };
187         final DiagonalMatrix m = new DiagonalMatrix(data);
188 
189         for (int i = 0; i < m.getRowDimension(); i++) {
190             m.addToEntry(i, i, i);
191             assertEquals(data[i] + i, m.getEntry(i, i), 0d);
192         }
193     }
194 
195     @Test
196     void testMultiplyEntry() {
197         final double[] data = { -1.2, 3.4, 5 };
198         final DiagonalMatrix m = new DiagonalMatrix(data);
199 
200         for (int i = 0; i < m.getRowDimension(); i++) {
201             m.multiplyEntry(i, i, i);
202             assertEquals(data[i] * i, m.getEntry(i, i), 0d);
203         }
204     }
205 
206     @Test
207     void testMultiply1() {
208         final double[] data1 = { -1.2, 3.4, 5 };
209         final DiagonalMatrix m1 = new DiagonalMatrix(data1);
210         final double[] data2 = { 10.1, 2.3, 45 };
211         final DiagonalMatrix m2 = new DiagonalMatrix(data2);
212 
213         final DiagonalMatrix result = (DiagonalMatrix) m1.multiply((RealMatrix) m2);
214         assertEquals(m1.getRowDimension(), result.getRowDimension());
215         for (int i = 0; i < result.getRowDimension(); i++) {
216             for (int j = 0; j < result.getRowDimension(); j++) {
217                 if (i == j) {
218                     assertEquals(data1[i] * data2[i], result.getEntry(i, j), 0d);
219                 } else {
220                     assertEquals(0d, result.getEntry(i, j), 0d);
221                 }
222             }
223         }
224     }
225 
226     @Test
227     void testMultiply2() {
228         final double[] data1 = { -1.2, 3.4, 5 };
229         final DiagonalMatrix diag1 = new DiagonalMatrix(data1);
230 
231         final double[][] data2 = { { -1.2, 3.4 },
232                                    { -5.6, 7.8 },
233                                    {  9.1, 2.3 } };
234         final RealMatrix dense2 = new Array2DRowRealMatrix(data2);
235         final RealMatrix dense1 = new Array2DRowRealMatrix(diag1.getData());
236 
237         final RealMatrix diagResult = diag1.multiply(dense2);
238         final RealMatrix denseResult = dense1.multiply(dense2);
239 
240         for (int i = 0; i < dense1.getRowDimension(); i++) {
241             for (int j = 0; j < dense2.getColumnDimension(); j++) {
242                 assertEquals(denseResult.getEntry(i, j),
243                                     diagResult.getEntry(i, j), 0d);
244             }
245         }
246     }
247 
248     @Test
249     void testMultiplyTransposedDiagonalMatrix() {
250         RandomGenerator randomGenerator = new Well1024a(0x4b20cb5a0440c929l);
251         for (int rows = 1; rows <= 64; rows += 7) {
252             final DiagonalMatrix a = new DiagonalMatrix(rows);
253             for (int i = 0; i < rows; ++i) {
254                 a.setEntry(i, i, randomGenerator.nextDouble());
255             }
256             final DiagonalMatrix b = new DiagonalMatrix(rows);
257             for (int i = 0; i < rows; ++i) {
258                 b.setEntry(i, i, randomGenerator.nextDouble());
259             }
260             assertEquals(0.0,
261                                 a.multiplyTransposed(b).subtract(a.multiply(b.transpose())).getNorm1(),
262                                 1.0e-15);
263         }
264     }
265 
266     @Test
267     void testMultiplyTransposedArray2DRowRealMatrix() {
268         RandomGenerator randomGenerator = new Well1024a(0x0fa7b97d4826cd43l);
269         final RealMatrixChangingVisitor randomSetter = new DefaultRealMatrixChangingVisitor() {
270             public double visit(final int row, final int column, final double value) {
271                 return randomGenerator.nextDouble();
272             }
273         };
274         for (int rows = 1; rows <= 64; rows += 7) {
275             final DiagonalMatrix a = new DiagonalMatrix(rows);
276             for (int i = 0; i < rows; ++i) {
277                 a.setEntry(i, i, randomGenerator.nextDouble());
278             }
279             for (int interm = 1; interm <= 64; interm += 7) {
280                 final Array2DRowRealMatrix b = new Array2DRowRealMatrix(interm, rows);
281                 b.walkInOptimizedOrder(randomSetter);
282                 assertEquals(0.0,
283                                     a.multiplyTransposed(b).subtract(a.multiply(b.transpose())).getNorm1(),
284                                     1.0e-15);
285             }
286         }
287     }
288 
289     @Test
290     void testMultiplyTransposedWrongDimensions() {
291         try {
292             new DiagonalMatrix(3).multiplyTransposed(new DiagonalMatrix(2));
293             fail("an exception should have been thrown");
294         } catch (MathIllegalArgumentException miae) {
295             assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
296             assertEquals(3, ((Integer) miae.getParts()[0]).intValue());
297             assertEquals(2, ((Integer) miae.getParts()[1]).intValue());
298         }
299     }
300 
301     @Test
302     void testTransposeMultiplyDiagonalMatrix() {
303         RandomGenerator randomGenerator = new Well1024a(0x4b20cb5a0440c929l);
304         for (int rows = 1; rows <= 64; rows += 7) {
305             final DiagonalMatrix a = new DiagonalMatrix(rows);
306             for (int i = 0; i < rows; ++i) {
307                 a.setEntry(i, i, randomGenerator.nextDouble());
308             }
309             final DiagonalMatrix b = new DiagonalMatrix(rows);
310             for (int i = 0; i < rows; ++i) {
311                 b.setEntry(i, i, randomGenerator.nextDouble());
312             }
313             assertEquals(0.0,
314                                 a.transposeMultiply(b).subtract(a.transpose().multiply(b)).getNorm1(),
315                                 1.0e-15);
316         }
317     }
318 
319     @Test
320     void testTransposeMultiplyArray2DRowRealMatrix() {
321         RandomGenerator randomGenerator = new Well1024a(0x0fa7b97d4826cd43l);
322         final RealMatrixChangingVisitor randomSetter = new DefaultRealMatrixChangingVisitor() {
323             public double visit(final int row, final int column, final double value) {
324                 return randomGenerator.nextDouble();
325             }
326         };
327         for (int rows = 1; rows <= 64; rows += 7) {
328             final DiagonalMatrix a = new DiagonalMatrix(rows);
329             for (int i = 0; i < rows; ++i) {
330                 a.setEntry(i, i, randomGenerator.nextDouble());
331             }
332             for (int interm = 1; interm <= 64; interm += 7) {
333                 final Array2DRowRealMatrix b = new Array2DRowRealMatrix(rows, interm);
334                 b.walkInOptimizedOrder(randomSetter);
335                 assertEquals(0.0,
336                                     a.transposeMultiply(b).subtract(a.transpose().multiply(b)).getNorm1(),
337                                     1.0e-15);
338             }
339         }
340     }
341 
342     @Test
343     void testTransposeMultiplyWrongDimensions() {
344         try {
345             new DiagonalMatrix(3).transposeMultiply(new DiagonalMatrix(2));
346             fail("an exception should have been thrown");
347         } catch (MathIllegalArgumentException miae) {
348             assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
349             assertEquals(3, ((Integer) miae.getParts()[0]).intValue());
350             assertEquals(2, ((Integer) miae.getParts()[1]).intValue());
351         }
352     }
353 
354     @Test
355     void testOperate() {
356         final double[] data = { -1.2, 3.4, 5 };
357         final DiagonalMatrix diag = new DiagonalMatrix(data);
358         final RealMatrix dense = new Array2DRowRealMatrix(diag.getData());
359 
360         final double[] v = { 6.7, 890.1, 23.4 };
361         final double[] diagResult = diag.operate(v);
362         final double[] denseResult = dense.operate(v);
363 
364         UnitTestUtils.customAssertEquals(diagResult, denseResult, 0d);
365     }
366 
367     @Test
368     void testPreMultiply() {
369         final double[] data = { -1.2, 3.4, 5 };
370         final DiagonalMatrix diag = new DiagonalMatrix(data);
371         final RealMatrix dense = new Array2DRowRealMatrix(diag.getData());
372 
373         final double[] v = { 6.7, 890.1, 23.4 };
374         final double[] diagResult = diag.preMultiply(v);
375         final double[] denseResult = dense.preMultiply(v);
376 
377         UnitTestUtils.customAssertEquals(diagResult, denseResult, 0d);
378     }
379 
380     @Test
381     void testPreMultiplyVector() {
382         final double[] data = { -1.2, 3.4, 5 };
383         final DiagonalMatrix diag = new DiagonalMatrix(data);
384         final RealMatrix dense = new Array2DRowRealMatrix(diag.getData());
385 
386         final double[] v = { 6.7, 890.1, 23.4 };
387         final RealVector vector = MatrixUtils.createRealVector(v);
388         final RealVector diagResult = diag.preMultiply(vector);
389         final RealVector denseResult = dense.preMultiply(vector);
390 
391         UnitTestUtils.customAssertEquals("preMultiply(Vector) returns wrong result", diagResult, denseResult, 0d);
392     }
393 
394     @Test
395     void testSetNonDiagonalEntry() {
396         assertThrows(MathIllegalArgumentException.class, () -> {
397             final DiagonalMatrix diag = new DiagonalMatrix(3);
398             diag.setEntry(1, 2, 3.4);
399         });
400     }
401 
402     @Test
403     void testSetNonDiagonalZero() {
404         final DiagonalMatrix diag = new DiagonalMatrix(3);
405         diag.setEntry(1, 2, 0.0);
406         assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
407     }
408 
409     @Test
410     void testAddNonDiagonalEntry() {
411         assertThrows(MathIllegalArgumentException.class, () -> {
412             final DiagonalMatrix diag = new DiagonalMatrix(3);
413             diag.addToEntry(1, 2, 3.4);
414         });
415     }
416 
417     @Test
418     void testAddNonDiagonalZero() {
419         final DiagonalMatrix diag = new DiagonalMatrix(3);
420         diag.addToEntry(1, 2, 0.0);
421         assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
422     }
423 
424     @Test
425     void testMultiplyNonDiagonalEntry() {
426         final DiagonalMatrix diag = new DiagonalMatrix(3);
427         diag.multiplyEntry(1, 2, 3.4);
428         assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
429     }
430 
431     @Test
432     void testMultiplyNonDiagonalZero() {
433         final DiagonalMatrix diag = new DiagonalMatrix(3);
434         diag.multiplyEntry(1, 2, 0.0);
435         assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
436     }
437 
438     @Test
439     void testSetEntryOutOfRange() {
440         assertThrows(MathIllegalArgumentException.class, () -> {
441             final DiagonalMatrix diag = new DiagonalMatrix(3);
442             diag.setEntry(3, 3, 3.4);
443         });
444     }
445 
446     @Test
447     void testNull() {
448         assertThrows(NullArgumentException.class, () -> {
449             new DiagonalMatrix(null, false);
450         });
451     }
452 
453     @Test
454     void testSetSubMatrixError() {
455         assertThrows(MathIllegalArgumentException.class, () -> {
456             final double[] data = {-1.2, 3.4, 5};
457             final DiagonalMatrix diag = new DiagonalMatrix(data);
458             diag.setSubMatrix(new double[][]{{1.0, 1.0}, {1.0, 1.0}}, 1, 1);
459         });
460     }
461 
462     @Test
463     void testSetSubMatrix() {
464         final double[] data = { -1.2, 3.4, 5 };
465         final DiagonalMatrix diag = new DiagonalMatrix(data);
466         diag.setSubMatrix(new double[][] { {0.0, 5.0, 0.0}, {0.0, 0.0, 6.0}}, 1, 0);
467         assertEquals(-1.2, diag.getEntry(0, 0), 1.0e-20);
468         assertEquals( 5.0, diag.getEntry(1, 1), 1.0e-20);
469         assertEquals( 6.0, diag.getEntry(2, 2), 1.0e-20);
470     }
471 
472     @Test
473     void testInverseError() {
474         assertThrows(MathIllegalArgumentException.class, () -> {
475             final double[] data = {1, 2, 0};
476             final DiagonalMatrix diag = new DiagonalMatrix(data);
477             diag.inverse();
478         });
479     }
480 
481     @Test
482     void testInverseError2() {
483         assertThrows(MathIllegalArgumentException.class, () -> {
484             final double[] data = {1, 2, 1e-6};
485             final DiagonalMatrix diag = new DiagonalMatrix(data);
486             diag.inverse(1e-5);
487         });
488     }
489 
490     @Test
491     void testInverse() {
492         final double[] data = { 1, 2, 3 };
493         final DiagonalMatrix m = new DiagonalMatrix(data);
494         final DiagonalMatrix inverse = m.inverse();
495 
496         final DiagonalMatrix result = m.multiply(inverse);
497         UnitTestUtils.customAssertEquals("DiagonalMatrix.inverse() returns wrong result",
498                                          MatrixUtils.createRealIdentityMatrix(data.length), result, Math.ulp(1d));
499     }
500 
501 }