1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.linear;
23
24 import org.hipparchus.UnitTestUtils;
25 import org.hipparchus.exception.LocalizedCoreFormats;
26 import org.hipparchus.exception.MathIllegalArgumentException;
27 import org.hipparchus.exception.NullArgumentException;
28 import org.hipparchus.random.RandomGenerator;
29 import org.hipparchus.random.Well1024a;
30 import org.hipparchus.util.Precision;
31 import org.junit.jupiter.api.Test;
32
33 import static org.junit.jupiter.api.Assertions.assertEquals;
34 import static org.junit.jupiter.api.Assertions.assertFalse;
35 import static org.junit.jupiter.api.Assertions.assertThrows;
36 import static org.junit.jupiter.api.Assertions.assertTrue;
37 import static org.junit.jupiter.api.Assertions.fail;
38
39
40
41
42 class DiagonalMatrixTest {
43 @Test
44 void testConstructor1() {
45 final int dim = 3;
46 final DiagonalMatrix m = new DiagonalMatrix(dim);
47 assertEquals(dim, m.getRowDimension());
48 assertEquals(dim, m.getColumnDimension());
49 }
50
51 @Test
52 void testConstructor2() {
53 final double[] d = { -1.2, 3.4, 5 };
54 final DiagonalMatrix m = new DiagonalMatrix(d);
55 for (int i = 0; i < m.getRowDimension(); i++) {
56 for (int j = 0; j < m.getRowDimension(); j++) {
57 if (i == j) {
58 assertEquals(d[i], m.getEntry(i, j), 0d);
59 } else {
60 assertEquals(0d, m.getEntry(i, j), 0d);
61 }
62 }
63 }
64
65
66 d[0] = 0;
67 assertFalse(d[0] == m.getEntry(0, 0));
68 }
69
70 @Test
71 void testConstructor3() {
72 final double[] d = { -1.2, 3.4, 5 };
73 final DiagonalMatrix m = new DiagonalMatrix(d, false);
74 for (int i = 0; i < m.getRowDimension(); i++) {
75 for (int j = 0; j < m.getRowDimension(); j++) {
76 if (i == j) {
77 assertEquals(d[i], m.getEntry(i, j), 0d);
78 } else {
79 assertEquals(0d, m.getEntry(i, j), 0d);
80 }
81 }
82 }
83
84
85 d[0] = 0;
86 assertEquals(d[0], m.getEntry(0, 0));
87
88 }
89
90 @Test
91 void testCreateError() {
92 assertThrows(MathIllegalArgumentException.class, () -> {
93 final double[] d = {-1.2, 3.4, 5};
94 final DiagonalMatrix m = new DiagonalMatrix(d, false);
95 m.createMatrix(5, 3);
96 });
97 }
98
99 @Test
100 void testCreate() {
101 final double[] d = { -1.2, 3.4, 5 };
102 final DiagonalMatrix m = new DiagonalMatrix(d, false);
103 final RealMatrix p = m.createMatrix(5, 5);
104 assertTrue(p instanceof DiagonalMatrix);
105 assertEquals(5, p.getRowDimension());
106 assertEquals(5, p.getColumnDimension());
107 }
108
109 @Test
110 void testCopy() {
111 final double[] d = { -1.2, 3.4, 5 };
112 final DiagonalMatrix m = new DiagonalMatrix(d, false);
113 final DiagonalMatrix p = (DiagonalMatrix) m.copy();
114 for (int i = 0; i < m.getRowDimension(); ++i) {
115 assertEquals(m.getEntry(i, i), p.getEntry(i, i), 1.0e-20);
116 }
117 }
118
119 @Test
120 void testGetData() {
121 final double[] data = { -1.2, 3.4, 5 };
122 final int dim = 3;
123 final DiagonalMatrix m = new DiagonalMatrix(dim);
124 for (int i = 0; i < dim; i++) {
125 m.setEntry(i, i, data[i]);
126 }
127
128 final double[][] out = m.getData();
129 assertEquals(dim, out.length);
130 for (int i = 0; i < m.getRowDimension(); i++) {
131 assertEquals(dim, out[i].length);
132 for (int j = 0; j < m.getRowDimension(); j++) {
133 if (i == j) {
134 assertEquals(data[i], out[i][j], 0d);
135 } else {
136 assertEquals(0d, out[i][j], 0d);
137 }
138 }
139 }
140 }
141
142 @Test
143 void testAdd() {
144 final double[] data1 = { -1.2, 3.4, 5 };
145 final DiagonalMatrix m1 = new DiagonalMatrix(data1);
146
147 final double[] data2 = { 10.1, 2.3, 45 };
148 final DiagonalMatrix m2 = new DiagonalMatrix(data2);
149
150 final DiagonalMatrix result = m1.add(m2);
151 assertEquals(m1.getRowDimension(), result.getRowDimension());
152 for (int i = 0; i < result.getRowDimension(); i++) {
153 for (int j = 0; j < result.getRowDimension(); j++) {
154 if (i == j) {
155 assertEquals(data1[i] + data2[i], result.getEntry(i, j), 0d);
156 } else {
157 assertEquals(0d, result.getEntry(i, j), 0d);
158 }
159 }
160 }
161 }
162
163 @Test
164 void testSubtract() {
165 final double[] data1 = { -1.2, 3.4, 5 };
166 final DiagonalMatrix m1 = new DiagonalMatrix(data1);
167
168 final double[] data2 = { 10.1, 2.3, 45 };
169 final DiagonalMatrix m2 = new DiagonalMatrix(data2);
170
171 final DiagonalMatrix result = m1.subtract(m2);
172 assertEquals(m1.getRowDimension(), result.getRowDimension());
173 for (int i = 0; i < result.getRowDimension(); i++) {
174 for (int j = 0; j < result.getRowDimension(); j++) {
175 if (i == j) {
176 assertEquals(data1[i] - data2[i], result.getEntry(i, j), 0d);
177 } else {
178 assertEquals(0d, result.getEntry(i, j), 0d);
179 }
180 }
181 }
182 }
183
184 @Test
185 void testAddToEntry() {
186 final double[] data = { -1.2, 3.4, 5 };
187 final DiagonalMatrix m = new DiagonalMatrix(data);
188
189 for (int i = 0; i < m.getRowDimension(); i++) {
190 m.addToEntry(i, i, i);
191 assertEquals(data[i] + i, m.getEntry(i, i), 0d);
192 }
193 }
194
195 @Test
196 void testMultiplyEntry() {
197 final double[] data = { -1.2, 3.4, 5 };
198 final DiagonalMatrix m = new DiagonalMatrix(data);
199
200 for (int i = 0; i < m.getRowDimension(); i++) {
201 m.multiplyEntry(i, i, i);
202 assertEquals(data[i] * i, m.getEntry(i, i), 0d);
203 }
204 }
205
206 @Test
207 void testMultiply1() {
208 final double[] data1 = { -1.2, 3.4, 5 };
209 final DiagonalMatrix m1 = new DiagonalMatrix(data1);
210 final double[] data2 = { 10.1, 2.3, 45 };
211 final DiagonalMatrix m2 = new DiagonalMatrix(data2);
212
213 final DiagonalMatrix result = (DiagonalMatrix) m1.multiply((RealMatrix) m2);
214 assertEquals(m1.getRowDimension(), result.getRowDimension());
215 for (int i = 0; i < result.getRowDimension(); i++) {
216 for (int j = 0; j < result.getRowDimension(); j++) {
217 if (i == j) {
218 assertEquals(data1[i] * data2[i], result.getEntry(i, j), 0d);
219 } else {
220 assertEquals(0d, result.getEntry(i, j), 0d);
221 }
222 }
223 }
224 }
225
226 @Test
227 void testMultiply2() {
228 final double[] data1 = { -1.2, 3.4, 5 };
229 final DiagonalMatrix diag1 = new DiagonalMatrix(data1);
230
231 final double[][] data2 = { { -1.2, 3.4 },
232 { -5.6, 7.8 },
233 { 9.1, 2.3 } };
234 final RealMatrix dense2 = new Array2DRowRealMatrix(data2);
235 final RealMatrix dense1 = new Array2DRowRealMatrix(diag1.getData());
236
237 final RealMatrix diagResult = diag1.multiply(dense2);
238 final RealMatrix denseResult = dense1.multiply(dense2);
239
240 for (int i = 0; i < dense1.getRowDimension(); i++) {
241 for (int j = 0; j < dense2.getColumnDimension(); j++) {
242 assertEquals(denseResult.getEntry(i, j),
243 diagResult.getEntry(i, j), 0d);
244 }
245 }
246 }
247
248 @Test
249 void testMultiplyTransposedDiagonalMatrix() {
250 RandomGenerator randomGenerator = new Well1024a(0x4b20cb5a0440c929l);
251 for (int rows = 1; rows <= 64; rows += 7) {
252 final DiagonalMatrix a = new DiagonalMatrix(rows);
253 for (int i = 0; i < rows; ++i) {
254 a.setEntry(i, i, randomGenerator.nextDouble());
255 }
256 final DiagonalMatrix b = new DiagonalMatrix(rows);
257 for (int i = 0; i < rows; ++i) {
258 b.setEntry(i, i, randomGenerator.nextDouble());
259 }
260 assertEquals(0.0,
261 a.multiplyTransposed(b).subtract(a.multiply(b.transpose())).getNorm1(),
262 1.0e-15);
263 }
264 }
265
266 @Test
267 void testMultiplyTransposedArray2DRowRealMatrix() {
268 RandomGenerator randomGenerator = new Well1024a(0x0fa7b97d4826cd43l);
269 final RealMatrixChangingVisitor randomSetter = new DefaultRealMatrixChangingVisitor() {
270 public double visit(final int row, final int column, final double value) {
271 return randomGenerator.nextDouble();
272 }
273 };
274 for (int rows = 1; rows <= 64; rows += 7) {
275 final DiagonalMatrix a = new DiagonalMatrix(rows);
276 for (int i = 0; i < rows; ++i) {
277 a.setEntry(i, i, randomGenerator.nextDouble());
278 }
279 for (int interm = 1; interm <= 64; interm += 7) {
280 final Array2DRowRealMatrix b = new Array2DRowRealMatrix(interm, rows);
281 b.walkInOptimizedOrder(randomSetter);
282 assertEquals(0.0,
283 a.multiplyTransposed(b).subtract(a.multiply(b.transpose())).getNorm1(),
284 1.0e-15);
285 }
286 }
287 }
288
289 @Test
290 void testMultiplyTransposedWrongDimensions() {
291 try {
292 new DiagonalMatrix(3).multiplyTransposed(new DiagonalMatrix(2));
293 fail("an exception should have been thrown");
294 } catch (MathIllegalArgumentException miae) {
295 assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
296 assertEquals(3, ((Integer) miae.getParts()[0]).intValue());
297 assertEquals(2, ((Integer) miae.getParts()[1]).intValue());
298 }
299 }
300
301 @Test
302 void testTransposeMultiplyDiagonalMatrix() {
303 RandomGenerator randomGenerator = new Well1024a(0x4b20cb5a0440c929l);
304 for (int rows = 1; rows <= 64; rows += 7) {
305 final DiagonalMatrix a = new DiagonalMatrix(rows);
306 for (int i = 0; i < rows; ++i) {
307 a.setEntry(i, i, randomGenerator.nextDouble());
308 }
309 final DiagonalMatrix b = new DiagonalMatrix(rows);
310 for (int i = 0; i < rows; ++i) {
311 b.setEntry(i, i, randomGenerator.nextDouble());
312 }
313 assertEquals(0.0,
314 a.transposeMultiply(b).subtract(a.transpose().multiply(b)).getNorm1(),
315 1.0e-15);
316 }
317 }
318
319 @Test
320 void testTransposeMultiplyArray2DRowRealMatrix() {
321 RandomGenerator randomGenerator = new Well1024a(0x0fa7b97d4826cd43l);
322 final RealMatrixChangingVisitor randomSetter = new DefaultRealMatrixChangingVisitor() {
323 public double visit(final int row, final int column, final double value) {
324 return randomGenerator.nextDouble();
325 }
326 };
327 for (int rows = 1; rows <= 64; rows += 7) {
328 final DiagonalMatrix a = new DiagonalMatrix(rows);
329 for (int i = 0; i < rows; ++i) {
330 a.setEntry(i, i, randomGenerator.nextDouble());
331 }
332 for (int interm = 1; interm <= 64; interm += 7) {
333 final Array2DRowRealMatrix b = new Array2DRowRealMatrix(rows, interm);
334 b.walkInOptimizedOrder(randomSetter);
335 assertEquals(0.0,
336 a.transposeMultiply(b).subtract(a.transpose().multiply(b)).getNorm1(),
337 1.0e-15);
338 }
339 }
340 }
341
342 @Test
343 void testTransposeMultiplyWrongDimensions() {
344 try {
345 new DiagonalMatrix(3).transposeMultiply(new DiagonalMatrix(2));
346 fail("an exception should have been thrown");
347 } catch (MathIllegalArgumentException miae) {
348 assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
349 assertEquals(3, ((Integer) miae.getParts()[0]).intValue());
350 assertEquals(2, ((Integer) miae.getParts()[1]).intValue());
351 }
352 }
353
354 @Test
355 void testOperate() {
356 final double[] data = { -1.2, 3.4, 5 };
357 final DiagonalMatrix diag = new DiagonalMatrix(data);
358 final RealMatrix dense = new Array2DRowRealMatrix(diag.getData());
359
360 final double[] v = { 6.7, 890.1, 23.4 };
361 final double[] diagResult = diag.operate(v);
362 final double[] denseResult = dense.operate(v);
363
364 UnitTestUtils.customAssertEquals(diagResult, denseResult, 0d);
365 }
366
367 @Test
368 void testPreMultiply() {
369 final double[] data = { -1.2, 3.4, 5 };
370 final DiagonalMatrix diag = new DiagonalMatrix(data);
371 final RealMatrix dense = new Array2DRowRealMatrix(diag.getData());
372
373 final double[] v = { 6.7, 890.1, 23.4 };
374 final double[] diagResult = diag.preMultiply(v);
375 final double[] denseResult = dense.preMultiply(v);
376
377 UnitTestUtils.customAssertEquals(diagResult, denseResult, 0d);
378 }
379
380 @Test
381 void testPreMultiplyVector() {
382 final double[] data = { -1.2, 3.4, 5 };
383 final DiagonalMatrix diag = new DiagonalMatrix(data);
384 final RealMatrix dense = new Array2DRowRealMatrix(diag.getData());
385
386 final double[] v = { 6.7, 890.1, 23.4 };
387 final RealVector vector = MatrixUtils.createRealVector(v);
388 final RealVector diagResult = diag.preMultiply(vector);
389 final RealVector denseResult = dense.preMultiply(vector);
390
391 UnitTestUtils.customAssertEquals("preMultiply(Vector) returns wrong result", diagResult, denseResult, 0d);
392 }
393
394 @Test
395 void testSetNonDiagonalEntry() {
396 assertThrows(MathIllegalArgumentException.class, () -> {
397 final DiagonalMatrix diag = new DiagonalMatrix(3);
398 diag.setEntry(1, 2, 3.4);
399 });
400 }
401
402 @Test
403 void testSetNonDiagonalZero() {
404 final DiagonalMatrix diag = new DiagonalMatrix(3);
405 diag.setEntry(1, 2, 0.0);
406 assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
407 }
408
409 @Test
410 void testAddNonDiagonalEntry() {
411 assertThrows(MathIllegalArgumentException.class, () -> {
412 final DiagonalMatrix diag = new DiagonalMatrix(3);
413 diag.addToEntry(1, 2, 3.4);
414 });
415 }
416
417 @Test
418 void testAddNonDiagonalZero() {
419 final DiagonalMatrix diag = new DiagonalMatrix(3);
420 diag.addToEntry(1, 2, 0.0);
421 assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
422 }
423
424 @Test
425 void testMultiplyNonDiagonalEntry() {
426 final DiagonalMatrix diag = new DiagonalMatrix(3);
427 diag.multiplyEntry(1, 2, 3.4);
428 assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
429 }
430
431 @Test
432 void testMultiplyNonDiagonalZero() {
433 final DiagonalMatrix diag = new DiagonalMatrix(3);
434 diag.multiplyEntry(1, 2, 0.0);
435 assertEquals(0.0, diag.getEntry(1, 2), Precision.SAFE_MIN);
436 }
437
438 @Test
439 void testSetEntryOutOfRange() {
440 assertThrows(MathIllegalArgumentException.class, () -> {
441 final DiagonalMatrix diag = new DiagonalMatrix(3);
442 diag.setEntry(3, 3, 3.4);
443 });
444 }
445
446 @Test
447 void testNull() {
448 assertThrows(NullArgumentException.class, () -> {
449 new DiagonalMatrix(null, false);
450 });
451 }
452
453 @Test
454 void testSetSubMatrixError() {
455 assertThrows(MathIllegalArgumentException.class, () -> {
456 final double[] data = {-1.2, 3.4, 5};
457 final DiagonalMatrix diag = new DiagonalMatrix(data);
458 diag.setSubMatrix(new double[][]{{1.0, 1.0}, {1.0, 1.0}}, 1, 1);
459 });
460 }
461
462 @Test
463 void testSetSubMatrix() {
464 final double[] data = { -1.2, 3.4, 5 };
465 final DiagonalMatrix diag = new DiagonalMatrix(data);
466 diag.setSubMatrix(new double[][] { {0.0, 5.0, 0.0}, {0.0, 0.0, 6.0}}, 1, 0);
467 assertEquals(-1.2, diag.getEntry(0, 0), 1.0e-20);
468 assertEquals( 5.0, diag.getEntry(1, 1), 1.0e-20);
469 assertEquals( 6.0, diag.getEntry(2, 2), 1.0e-20);
470 }
471
472 @Test
473 void testInverseError() {
474 assertThrows(MathIllegalArgumentException.class, () -> {
475 final double[] data = {1, 2, 0};
476 final DiagonalMatrix diag = new DiagonalMatrix(data);
477 diag.inverse();
478 });
479 }
480
481 @Test
482 void testInverseError2() {
483 assertThrows(MathIllegalArgumentException.class, () -> {
484 final double[] data = {1, 2, 1e-6};
485 final DiagonalMatrix diag = new DiagonalMatrix(data);
486 diag.inverse(1e-5);
487 });
488 }
489
490 @Test
491 void testInverse() {
492 final double[] data = { 1, 2, 3 };
493 final DiagonalMatrix m = new DiagonalMatrix(data);
494 final DiagonalMatrix inverse = m.inverse();
495
496 final DiagonalMatrix result = m.multiply(inverse);
497 UnitTestUtils.customAssertEquals("DiagonalMatrix.inverse() returns wrong result",
498 MatrixUtils.createRealIdentityMatrix(data.length), result, Math.ulp(1d));
499 }
500
501 }