1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.linear;
23
24 import org.hipparchus.exception.MathIllegalArgumentException;
25 import org.hipparchus.exception.MathIllegalStateException;
26 import org.hipparchus.exception.MathRuntimeException;
27 import org.hipparchus.util.FastMath;
28 import org.hipparchus.util.IterationEvent;
29 import org.hipparchus.util.IterationListener;
30 import org.junit.jupiter.api.Test;
31
32 import java.util.Arrays;
33
34 import static org.junit.jupiter.api.Assertions.assertEquals;
35 import static org.junit.jupiter.api.Assertions.assertNotSame;
36 import static org.junit.jupiter.api.Assertions.assertSame;
37 import static org.junit.jupiter.api.Assertions.assertThrows;
38 import static org.junit.jupiter.api.Assertions.assertTrue;
39 import static org.junit.jupiter.api.Assertions.fail;
40
41 class ConjugateGradientTest {
42
43 @Test
44 void testNonSquareOperator() {
45 assertThrows(MathIllegalArgumentException.class, () -> {
46 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 3);
47 final IterativeLinearSolver solver;
48 solver = new ConjugateGradient(10, 0., false);
49 final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
50 final ArrayRealVector x = new ArrayRealVector(a.getColumnDimension());
51 solver.solve(a, b, x);
52 });
53 }
54
55 @Test
56 void testDimensionMismatchRightHandSide() {
57 assertThrows(MathIllegalArgumentException.class, () -> {
58 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
59 final IterativeLinearSolver solver;
60 solver = new ConjugateGradient(10, 0., false);
61 final ArrayRealVector b = new ArrayRealVector(2);
62 final ArrayRealVector x = new ArrayRealVector(3);
63 solver.solve(a, b, x);
64 });
65 }
66
67 @Test
68 void testDimensionMismatchSolution() {
69 assertThrows(MathIllegalArgumentException.class, () -> {
70 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
71 final IterativeLinearSolver solver;
72 solver = new ConjugateGradient(10, 0., false);
73 final ArrayRealVector b = new ArrayRealVector(3);
74 final ArrayRealVector x = new ArrayRealVector(2);
75 solver.solve(a, b, x);
76 });
77 }
78
79 @Test
80 void testNonPositiveDefiniteLinearOperator() {
81 assertThrows(MathIllegalArgumentException.class, () -> {
82 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
83 a.setEntry(0, 0, -1.);
84 a.setEntry(0, 1, 2.);
85 a.setEntry(1, 0, 3.);
86 a.setEntry(1, 1, 4.);
87 final IterativeLinearSolver solver;
88 solver = new ConjugateGradient(10, 0., true);
89 final ArrayRealVector b = new ArrayRealVector(2);
90 b.setEntry(0, -1.);
91 b.setEntry(1, -1.);
92 final ArrayRealVector x = new ArrayRealVector(2);
93 solver.solve(a, b, x);
94 });
95 }
96
97 @Test
98 void testUnpreconditionedSolution() {
99 final int n = 5;
100 final int maxIterations = 100;
101 final RealLinearOperator a = new HilbertMatrix(n);
102 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
103 final IterativeLinearSolver solver;
104 solver = new ConjugateGradient(maxIterations, 1E-10, true);
105 final RealVector b = new ArrayRealVector(n);
106 for (int j = 0; j < n; j++) {
107 b.set(0.);
108 b.setEntry(j, 1.);
109 final RealVector x = solver.solve(a, b);
110 for (int i = 0; i < n; i++) {
111 final double actual = x.getEntry(i);
112 final double expected = ainv.getEntry(i, j);
113 final double delta = 1E-10 * FastMath.abs(expected);
114 final String msg = String.format("entry[%d][%d]", i, j);
115 assertEquals(expected, actual, delta, msg);
116 }
117 }
118 }
119
120 @Test
121 void testUnpreconditionedInPlaceSolutionWithInitialGuess() {
122 final int n = 5;
123 final int maxIterations = 100;
124 final RealLinearOperator a = new HilbertMatrix(n);
125 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
126 final IterativeLinearSolver solver;
127 solver = new ConjugateGradient(maxIterations, 1E-10, true);
128 final RealVector b = new ArrayRealVector(n);
129 for (int j = 0; j < n; j++) {
130 b.set(0.);
131 b.setEntry(j, 1.);
132 final RealVector x0 = new ArrayRealVector(n);
133 x0.set(1.);
134 final RealVector x = solver.solveInPlace(a, b, x0);
135 assertSame(x0, x, "x should be a reference to x0");
136 for (int i = 0; i < n; i++) {
137 final double actual = x.getEntry(i);
138 final double expected = ainv.getEntry(i, j);
139 final double delta = 1E-10 * FastMath.abs(expected);
140 final String msg = String.format("entry[%d][%d)", i, j);
141 assertEquals(expected, actual, delta, msg);
142 }
143 }
144 }
145
146 @Test
147 void testUnpreconditionedSolutionWithInitialGuess() {
148 final int n = 5;
149 final int maxIterations = 100;
150 final RealLinearOperator a = new HilbertMatrix(n);
151 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
152 final IterativeLinearSolver solver;
153 solver = new ConjugateGradient(maxIterations, 1E-10, true);
154 final RealVector b = new ArrayRealVector(n);
155 for (int j = 0; j < n; j++) {
156 b.set(0.);
157 b.setEntry(j, 1.);
158 final RealVector x0 = new ArrayRealVector(n);
159 x0.set(1.);
160 final RealVector x = solver.solve(a, b, x0);
161 assertNotSame(x0, x, "x should not be a reference to x0");
162 for (int i = 0; i < n; i++) {
163 final double actual = x.getEntry(i);
164 final double expected = ainv.getEntry(i, j);
165 final double delta = 1E-10 * FastMath.abs(expected);
166 final String msg = String.format("entry[%d][%d]", i, j);
167 assertEquals(expected, actual, delta, msg);
168 assertEquals(1., x0.getEntry(i), Math.ulp(1.), msg);
169 }
170 }
171 }
172
173
174
175
176
177
178
179 @Test
180 void testUnpreconditionedResidual() {
181 final int n = 10;
182 final int maxIterations = n;
183 final RealLinearOperator a = new HilbertMatrix(n);
184 final ConjugateGradient solver;
185 solver = new ConjugateGradient(maxIterations, 1E-15, true);
186 final RealVector r = new ArrayRealVector(n);
187 final RealVector x = new ArrayRealVector(n);
188 final IterationListener listener = new IterationListener() {
189
190 public void terminationPerformed(final IterationEvent e) {
191
192 }
193
194 public void iterationStarted(final IterationEvent e) {
195
196 }
197
198 public void iterationPerformed(final IterationEvent e) {
199 final IterativeLinearSolverEvent evt;
200 evt = (IterativeLinearSolverEvent) e;
201 RealVector v = evt.getResidual();
202 r.setSubVector(0, v);
203 v = evt.getSolution();
204 x.setSubVector(0, v);
205 }
206
207 public void initializationPerformed(final IterationEvent e) {
208
209 }
210 };
211 solver.getIterationManager().addIterationListener(listener);
212 final RealVector b = new ArrayRealVector(n);
213 for (int j = 0; j < n; j++) {
214 b.set(0.);
215 b.setEntry(j, 1.);
216
217 boolean caught = false;
218 try {
219 solver.solve(a, b);
220 } catch (MathIllegalStateException e) {
221 caught = true;
222 final RealVector y = a.operate(x);
223 for (int i = 0; i < n; i++) {
224 final double actual = b.getEntry(i) - y.getEntry(i);
225 final double expected = r.getEntry(i);
226 final double delta = 1E-6 * FastMath.abs(expected);
227 final String msg = String
228 .format("column %d, residual %d", i, j);
229 assertEquals(expected, actual, delta, msg);
230 }
231 }
232 assertTrue(caught,
233 "MathIllegalStateException should have been caught");
234 }
235 }
236
237 @Test
238 void testNonSquarePreconditioner() {
239 assertThrows(MathIllegalArgumentException.class, () -> {
240 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
241 final RealLinearOperator m = new RealLinearOperator() {
242
243 @Override
244 public RealVector operate(final RealVector x) {
245 throw new UnsupportedOperationException();
246 }
247
248 @Override
249 public int getRowDimension() {
250 return 2;
251 }
252
253 @Override
254 public int getColumnDimension() {
255 return 3;
256 }
257 };
258 final PreconditionedIterativeLinearSolver solver;
259 solver = new ConjugateGradient(10, 0d, false);
260 final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
261 solver.solve(a, m, b);
262 });
263 }
264
265 @Test
266 void testMismatchedOperatorDimensions() {
267 assertThrows(MathIllegalArgumentException.class, () -> {
268 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
269 final RealLinearOperator m = new RealLinearOperator() {
270
271 @Override
272 public RealVector operate(final RealVector x) {
273 throw new UnsupportedOperationException();
274 }
275
276 @Override
277 public int getRowDimension() {
278 return 3;
279 }
280
281 @Override
282 public int getColumnDimension() {
283 return 3;
284 }
285 };
286 final PreconditionedIterativeLinearSolver solver;
287 solver = new ConjugateGradient(10, 0d, false);
288 final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
289 solver.solve(a, m, b);
290 });
291 }
292
293 @Test
294 void testNonPositiveDefinitePreconditioner() {
295 assertThrows(MathIllegalArgumentException.class, () -> {
296 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
297 a.setEntry(0, 0, 1d);
298 a.setEntry(0, 1, 2d);
299 a.setEntry(1, 0, 3d);
300 a.setEntry(1, 1, 4d);
301 final RealLinearOperator m = new RealLinearOperator() {
302
303 @Override
304 public RealVector operate(final RealVector x) {
305 final ArrayRealVector y = new ArrayRealVector(2);
306 y.setEntry(0, -x.getEntry(0));
307 y.setEntry(1, x.getEntry(1));
308 return y;
309 }
310
311 @Override
312 public int getRowDimension() {
313 return 2;
314 }
315
316 @Override
317 public int getColumnDimension() {
318 return 2;
319 }
320 };
321 final PreconditionedIterativeLinearSolver solver;
322 solver = new ConjugateGradient(10, 0d, true);
323 final ArrayRealVector b = new ArrayRealVector(2);
324 b.setEntry(0, -1d);
325 b.setEntry(1, -1d);
326 solver.solve(a, m, b);
327 });
328 }
329
330 @Test
331 void testPreconditionedSolution() {
332 final int n = 8;
333 final int maxIterations = 100;
334 final RealLinearOperator a = new HilbertMatrix(n);
335 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
336 final RealLinearOperator m = JacobiPreconditioner.create(a);
337 final PreconditionedIterativeLinearSolver solver;
338 solver = new ConjugateGradient(maxIterations, 1E-15, true);
339 final RealVector b = new ArrayRealVector(n);
340 for (int j = 0; j < n; j++) {
341 b.set(0.);
342 b.setEntry(j, 1.);
343 final RealVector x = solver.solve(a, m, b);
344 for (int i = 0; i < n; i++) {
345 final double actual = x.getEntry(i);
346 final double expected = ainv.getEntry(i, j);
347 final double delta = 1E-6 * FastMath.abs(expected);
348 final String msg = String.format("coefficient (%d, %d)", i, j);
349 assertEquals(expected, actual, delta, msg);
350 }
351 }
352 }
353
354 @Test
355 void testPreconditionedResidual() {
356 final int n = 10;
357 final int maxIterations = n;
358 final RealLinearOperator a = new HilbertMatrix(n);
359 final RealLinearOperator m = JacobiPreconditioner.create(a);
360 final ConjugateGradient solver;
361 solver = new ConjugateGradient(maxIterations, 1E-15, true);
362 final RealVector r = new ArrayRealVector(n);
363 final RealVector x = new ArrayRealVector(n);
364 final IterationListener listener = new IterationListener() {
365
366 public void terminationPerformed(final IterationEvent e) {
367
368 }
369
370 public void iterationStarted(final IterationEvent e) {
371
372 }
373
374 public void iterationPerformed(final IterationEvent e) {
375 final IterativeLinearSolverEvent evt;
376 evt = (IterativeLinearSolverEvent) e;
377 RealVector v = evt.getResidual();
378 r.setSubVector(0, v);
379 v = evt.getSolution();
380 x.setSubVector(0, v);
381 }
382
383 public void initializationPerformed(final IterationEvent e) {
384
385 }
386 };
387 solver.getIterationManager().addIterationListener(listener);
388 final RealVector b = new ArrayRealVector(n);
389
390 for (int j = 0; j < n; j++) {
391 b.set(0.);
392 b.setEntry(j, 1.);
393
394 boolean caught = false;
395 try {
396 solver.solve(a, m, b);
397 } catch (MathIllegalStateException e) {
398 caught = true;
399 final RealVector y = a.operate(x);
400 for (int i = 0; i < n; i++) {
401 final double actual = b.getEntry(i) - y.getEntry(i);
402 final double expected = r.getEntry(i);
403 final double delta = 1E-6 * FastMath.abs(expected);
404 final String msg = String.format("column %d, residual %d", i, j);
405 assertEquals(expected, actual, delta, msg);
406 }
407 }
408 assertTrue(caught, "MathIllegalStateException should have been caught");
409 }
410 }
411
412 @Test
413 void testPreconditionedSolution2() {
414 final int n = 100;
415 final int maxIterations = 100000;
416 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(n, n);
417 double daux = 1.;
418 for (int i = 0; i < n; i++) {
419 a.setEntry(i, i, daux);
420 daux *= 1.2;
421 for (int j = i + 1; j < n; j++) {
422 if (i == j) {
423 } else {
424 final double value = 1.0;
425 a.setEntry(i, j, value);
426 a.setEntry(j, i, value);
427 }
428 }
429 }
430 final RealLinearOperator m = JacobiPreconditioner.create(a);
431 final PreconditionedIterativeLinearSolver pcg;
432 final IterativeLinearSolver cg;
433 pcg = new ConjugateGradient(maxIterations, 1E-6, true);
434 cg = new ConjugateGradient(maxIterations, 1E-6, true);
435 final RealVector b = new ArrayRealVector(n);
436 final String pattern = "preconditioned gradient (%d iterations) should"
437 + " have been faster than unpreconditioned (%d iterations)";
438 String msg;
439 for (int j = 0; j < 1; j++) {
440 b.set(0.);
441 b.setEntry(j, 1.);
442 final RealVector px = pcg.solve(a, m, b);
443 final RealVector x = cg.solve(a, b);
444 final int npcg = pcg.getIterationManager().getIterations();
445 final int ncg = cg.getIterationManager().getIterations();
446 msg = String.format(pattern, npcg, ncg);
447 assertTrue(npcg < ncg, msg);
448 for (int i = 0; i < n; i++) {
449 msg = String.format("row %d, column %d", i, j);
450 final double expected = x.getEntry(i);
451 final double actual = px.getEntry(i);
452 final double delta = 1E-6 * FastMath.abs(expected);
453 assertEquals(expected, actual, delta, msg);
454 }
455 }
456 }
457
458 @Test
459 void testEventManagement() {
460 final int n = 5;
461 final int maxIterations = 100;
462 final RealLinearOperator a = new HilbertMatrix(n);
463 final IterativeLinearSolver solver;
464
465
466
467
468
469
470 final int[] count = new int[] {0, 0, 0, 0};
471 final IterationListener listener = new IterationListener() {
472 private void doTestVectorsAreUnmodifiable(final IterationEvent e) {
473 final IterativeLinearSolverEvent evt;
474 evt = (IterativeLinearSolverEvent) e;
475 try {
476 evt.getResidual().set(0.0);
477 fail("r is modifiable");
478 } catch (MathRuntimeException exc){
479
480 }
481 try {
482 evt.getRightHandSideVector().set(0.0);
483 fail("b is modifiable");
484 } catch (MathRuntimeException exc){
485
486 }
487 try {
488 evt.getSolution().set(0.0);
489 fail("x is modifiable");
490 } catch (MathRuntimeException exc){
491
492 }
493 }
494
495 public void initializationPerformed(final IterationEvent e) {
496 ++count[0];
497 doTestVectorsAreUnmodifiable(e);
498 }
499
500 public void iterationPerformed(final IterationEvent e) {
501 ++count[2];
502 assertEquals(count[2], e.getIterations() - 1, "iteration performed");
503 doTestVectorsAreUnmodifiable(e);
504 }
505
506 public void iterationStarted(final IterationEvent e) {
507 ++count[1];
508 assertEquals(count[1], e.getIterations() - 1, "iteration started");
509 doTestVectorsAreUnmodifiable(e);
510 }
511
512 public void terminationPerformed(final IterationEvent e) {
513 ++count[3];
514 doTestVectorsAreUnmodifiable(e);
515 }
516 };
517 solver = new ConjugateGradient(maxIterations, 1E-10, true);
518 solver.getIterationManager().addIterationListener(listener);
519 final RealVector b = new ArrayRealVector(n);
520 for (int j = 0; j < n; j++) {
521 Arrays.fill(count, 0);
522 b.set(0.);
523 b.setEntry(j, 1.);
524 solver.solve(a, b);
525 String msg = String.format("column %d (initialization)", j);
526 assertEquals(1, count[0], msg);
527 msg = String.format("column %d (finalization)", j);
528 assertEquals(1, count[3], msg);
529 }
530 }
531
532 @Test
533 void testUnpreconditionedNormOfResidual() {
534 final int n = 5;
535 final int maxIterations = 100;
536 final RealLinearOperator a = new HilbertMatrix(n);
537 final IterativeLinearSolver solver;
538 final IterationListener listener = new IterationListener() {
539
540 private void doTestNormOfResidual(final IterationEvent e) {
541 final IterativeLinearSolverEvent evt;
542 evt = (IterativeLinearSolverEvent) e;
543 final RealVector x = evt.getSolution();
544 final RealVector b = evt.getRightHandSideVector();
545 final RealVector r = b.subtract(a.operate(x));
546 final double rnorm = r.getNorm();
547 assertEquals(rnorm, evt.getNormOfResidual(),
548 FastMath.max(1E-5 * rnorm, 1E-10),
549 "iteration performed (residual)");
550 }
551
552 public void initializationPerformed(final IterationEvent e) {
553 doTestNormOfResidual(e);
554 }
555
556 public void iterationPerformed(final IterationEvent e) {
557 doTestNormOfResidual(e);
558 }
559
560 public void iterationStarted(final IterationEvent e) {
561 doTestNormOfResidual(e);
562 }
563
564 public void terminationPerformed(final IterationEvent e) {
565 doTestNormOfResidual(e);
566 }
567 };
568 solver = new ConjugateGradient(maxIterations, 1E-10, true);
569 solver.getIterationManager().addIterationListener(listener);
570 final RealVector b = new ArrayRealVector(n);
571 for (int j = 0; j < n; j++) {
572 b.set(0.);
573 b.setEntry(j, 1.);
574 solver.solve(a, b);
575 }
576 }
577
578 @Test
579 void testPreconditionedNormOfResidual() {
580 final int n = 5;
581 final int maxIterations = 100;
582 final RealLinearOperator a = new HilbertMatrix(n);
583 final RealLinearOperator m = JacobiPreconditioner.create(a);
584 final PreconditionedIterativeLinearSolver solver;
585 final IterationListener listener = new IterationListener() {
586
587 private void doTestNormOfResidual(final IterationEvent e) {
588 final IterativeLinearSolverEvent evt;
589 evt = (IterativeLinearSolverEvent) e;
590 final RealVector x = evt.getSolution();
591 final RealVector b = evt.getRightHandSideVector();
592 final RealVector r = b.subtract(a.operate(x));
593 final double rnorm = r.getNorm();
594 assertEquals(rnorm, evt.getNormOfResidual(),
595 FastMath.max(1E-5 * rnorm, 1E-10),
596 "iteration performed (residual)");
597 }
598
599 public void initializationPerformed(final IterationEvent e) {
600 doTestNormOfResidual(e);
601 }
602
603 public void iterationPerformed(final IterationEvent e) {
604 doTestNormOfResidual(e);
605 }
606
607 public void iterationStarted(final IterationEvent e) {
608 doTestNormOfResidual(e);
609 }
610
611 public void terminationPerformed(final IterationEvent e) {
612 doTestNormOfResidual(e);
613 }
614 };
615 solver = new ConjugateGradient(maxIterations, 1E-10, true);
616 solver.getIterationManager().addIterationListener(listener);
617 final RealVector b = new ArrayRealVector(n);
618 for (int j = 0; j < n; j++) {
619 b.set(0.);
620 b.setEntry(j, 1.);
621 solver.solve(a, m, b);
622 }
623 }
624 }