View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.linear;
23  
24  import org.hipparchus.exception.MathIllegalArgumentException;
25  import org.hipparchus.exception.MathIllegalStateException;
26  import org.hipparchus.exception.MathRuntimeException;
27  import org.hipparchus.util.FastMath;
28  import org.hipparchus.util.IterationEvent;
29  import org.hipparchus.util.IterationListener;
30  import org.junit.jupiter.api.Test;
31  
32  import java.util.Arrays;
33  
34  import static org.junit.jupiter.api.Assertions.assertEquals;
35  import static org.junit.jupiter.api.Assertions.assertNotSame;
36  import static org.junit.jupiter.api.Assertions.assertSame;
37  import static org.junit.jupiter.api.Assertions.assertThrows;
38  import static org.junit.jupiter.api.Assertions.assertTrue;
39  import static org.junit.jupiter.api.Assertions.fail;
40  
41  class ConjugateGradientTest {
42  
43      @Test
44      void testNonSquareOperator() {
45          assertThrows(MathIllegalArgumentException.class, () -> {
46              final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 3);
47              final IterativeLinearSolver solver;
48              solver = new ConjugateGradient(10, 0., false);
49              final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
50              final ArrayRealVector x = new ArrayRealVector(a.getColumnDimension());
51              solver.solve(a, b, x);
52          });
53      }
54  
55      @Test
56      void testDimensionMismatchRightHandSide() {
57          assertThrows(MathIllegalArgumentException.class, () -> {
58              final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
59              final IterativeLinearSolver solver;
60              solver = new ConjugateGradient(10, 0., false);
61              final ArrayRealVector b = new ArrayRealVector(2);
62              final ArrayRealVector x = new ArrayRealVector(3);
63              solver.solve(a, b, x);
64          });
65      }
66  
67      @Test
68      void testDimensionMismatchSolution() {
69          assertThrows(MathIllegalArgumentException.class, () -> {
70              final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
71              final IterativeLinearSolver solver;
72              solver = new ConjugateGradient(10, 0., false);
73              final ArrayRealVector b = new ArrayRealVector(3);
74              final ArrayRealVector x = new ArrayRealVector(2);
75              solver.solve(a, b, x);
76          });
77      }
78  
79      @Test
80      void testNonPositiveDefiniteLinearOperator() {
81          assertThrows(MathIllegalArgumentException.class, () -> {
82              final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
83              a.setEntry(0, 0, -1.);
84              a.setEntry(0, 1, 2.);
85              a.setEntry(1, 0, 3.);
86              a.setEntry(1, 1, 4.);
87              final IterativeLinearSolver solver;
88              solver = new ConjugateGradient(10, 0., true);
89              final ArrayRealVector b = new ArrayRealVector(2);
90              b.setEntry(0, -1.);
91              b.setEntry(1, -1.);
92              final ArrayRealVector x = new ArrayRealVector(2);
93              solver.solve(a, b, x);
94          });
95      }
96  
97      @Test
98      void testUnpreconditionedSolution() {
99          final int n = 5;
100         final int maxIterations = 100;
101         final RealLinearOperator a = new HilbertMatrix(n);
102         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
103         final IterativeLinearSolver solver;
104         solver = new ConjugateGradient(maxIterations, 1E-10, true);
105         final RealVector b = new ArrayRealVector(n);
106         for (int j = 0; j < n; j++) {
107             b.set(0.);
108             b.setEntry(j, 1.);
109             final RealVector x = solver.solve(a, b);
110             for (int i = 0; i < n; i++) {
111                 final double actual = x.getEntry(i);
112                 final double expected = ainv.getEntry(i, j);
113                 final double delta = 1E-10 * FastMath.abs(expected);
114                 final String msg = String.format("entry[%d][%d]", i, j);
115                 assertEquals(expected, actual, delta, msg);
116             }
117         }
118     }
119 
120     @Test
121     void testUnpreconditionedInPlaceSolutionWithInitialGuess() {
122         final int n = 5;
123         final int maxIterations = 100;
124         final RealLinearOperator a = new HilbertMatrix(n);
125         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
126         final IterativeLinearSolver solver;
127         solver = new ConjugateGradient(maxIterations, 1E-10, true);
128         final RealVector b = new ArrayRealVector(n);
129         for (int j = 0; j < n; j++) {
130             b.set(0.);
131             b.setEntry(j, 1.);
132             final RealVector x0 = new ArrayRealVector(n);
133             x0.set(1.);
134             final RealVector x = solver.solveInPlace(a, b, x0);
135             assertSame(x0, x, "x should be a reference to x0");
136             for (int i = 0; i < n; i++) {
137                 final double actual = x.getEntry(i);
138                 final double expected = ainv.getEntry(i, j);
139                 final double delta = 1E-10 * FastMath.abs(expected);
140                 final String msg = String.format("entry[%d][%d)", i, j);
141                 assertEquals(expected, actual, delta, msg);
142             }
143         }
144     }
145 
146     @Test
147     void testUnpreconditionedSolutionWithInitialGuess() {
148         final int n = 5;
149         final int maxIterations = 100;
150         final RealLinearOperator a = new HilbertMatrix(n);
151         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
152         final IterativeLinearSolver solver;
153         solver = new ConjugateGradient(maxIterations, 1E-10, true);
154         final RealVector b = new ArrayRealVector(n);
155         for (int j = 0; j < n; j++) {
156             b.set(0.);
157             b.setEntry(j, 1.);
158             final RealVector x0 = new ArrayRealVector(n);
159             x0.set(1.);
160             final RealVector x = solver.solve(a, b, x0);
161             assertNotSame(x0, x, "x should not be a reference to x0");
162             for (int i = 0; i < n; i++) {
163                 final double actual = x.getEntry(i);
164                 final double expected = ainv.getEntry(i, j);
165                 final double delta = 1E-10 * FastMath.abs(expected);
166                 final String msg = String.format("entry[%d][%d]", i, j);
167                 assertEquals(expected, actual, delta, msg);
168                 assertEquals(1., x0.getEntry(i), Math.ulp(1.), msg);
169             }
170         }
171     }
172 
173     /**
174      * Check whether the estimate of the (updated) residual corresponds to the
175      * exact residual. This fails to be true for a large number of iterations,
176      * due to the loss of orthogonality of the successive search directions.
177      * Therefore, in the present test, the number of iterations is limited.
178      */
179     @Test
180     void testUnpreconditionedResidual() {
181         final int n = 10;
182         final int maxIterations = n;
183         final RealLinearOperator a = new HilbertMatrix(n);
184         final ConjugateGradient solver;
185         solver = new ConjugateGradient(maxIterations, 1E-15, true);
186         final RealVector r = new ArrayRealVector(n);
187         final RealVector x = new ArrayRealVector(n);
188         final IterationListener listener = new IterationListener() {
189 
190             public void terminationPerformed(final IterationEvent e) {
191                 // Do nothing
192             }
193 
194             public void iterationStarted(final IterationEvent e) {
195                 // Do nothing
196             }
197 
198             public void iterationPerformed(final IterationEvent e) {
199                 final IterativeLinearSolverEvent evt;
200                 evt = (IterativeLinearSolverEvent) e;
201                 RealVector v = evt.getResidual();
202                 r.setSubVector(0, v);
203                 v = evt.getSolution();
204                 x.setSubVector(0, v);
205             }
206 
207             public void initializationPerformed(final IterationEvent e) {
208                 // Do nothing
209             }
210         };
211         solver.getIterationManager().addIterationListener(listener);
212         final RealVector b = new ArrayRealVector(n);
213         for (int j = 0; j < n; j++) {
214             b.set(0.);
215             b.setEntry(j, 1.);
216 
217             boolean caught = false;
218             try {
219                 solver.solve(a, b);
220             } catch (MathIllegalStateException e) {
221                 caught = true;
222                 final RealVector y = a.operate(x);
223                 for (int i = 0; i < n; i++) {
224                     final double actual = b.getEntry(i) - y.getEntry(i);
225                     final double expected = r.getEntry(i);
226                     final double delta = 1E-6 * FastMath.abs(expected);
227                     final String msg = String
228                         .format("column %d, residual %d", i, j);
229                     assertEquals(expected, actual, delta, msg);
230                 }
231             }
232             assertTrue(caught,
233                             "MathIllegalStateException should have been caught");
234         }
235     }
236 
237     @Test
238     void testNonSquarePreconditioner() {
239         assertThrows(MathIllegalArgumentException.class, () -> {
240             final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
241             final RealLinearOperator m = new RealLinearOperator() {
242 
243                 @Override
244                 public RealVector operate(final RealVector x) {
245                     throw new UnsupportedOperationException();
246                 }
247 
248                 @Override
249                 public int getRowDimension() {
250                     return 2;
251                 }
252 
253                 @Override
254                 public int getColumnDimension() {
255                     return 3;
256                 }
257             };
258             final PreconditionedIterativeLinearSolver solver;
259             solver = new ConjugateGradient(10, 0d, false);
260             final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
261             solver.solve(a, m, b);
262         });
263     }
264 
265     @Test
266     void testMismatchedOperatorDimensions() {
267         assertThrows(MathIllegalArgumentException.class, () -> {
268             final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
269             final RealLinearOperator m = new RealLinearOperator() {
270 
271                 @Override
272                 public RealVector operate(final RealVector x) {
273                     throw new UnsupportedOperationException();
274                 }
275 
276                 @Override
277                 public int getRowDimension() {
278                     return 3;
279                 }
280 
281                 @Override
282                 public int getColumnDimension() {
283                     return 3;
284                 }
285             };
286             final PreconditionedIterativeLinearSolver solver;
287             solver = new ConjugateGradient(10, 0d, false);
288             final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
289             solver.solve(a, m, b);
290         });
291     }
292 
293     @Test
294     void testNonPositiveDefinitePreconditioner() {
295         assertThrows(MathIllegalArgumentException.class, () -> {
296             final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
297             a.setEntry(0, 0, 1d);
298             a.setEntry(0, 1, 2d);
299             a.setEntry(1, 0, 3d);
300             a.setEntry(1, 1, 4d);
301             final RealLinearOperator m = new RealLinearOperator() {
302 
303                 @Override
304                 public RealVector operate(final RealVector x) {
305                     final ArrayRealVector y = new ArrayRealVector(2);
306                     y.setEntry(0, -x.getEntry(0));
307                     y.setEntry(1, x.getEntry(1));
308                     return y;
309                 }
310 
311                 @Override
312                 public int getRowDimension() {
313                     return 2;
314                 }
315 
316                 @Override
317                 public int getColumnDimension() {
318                     return 2;
319                 }
320             };
321             final PreconditionedIterativeLinearSolver solver;
322             solver = new ConjugateGradient(10, 0d, true);
323             final ArrayRealVector b = new ArrayRealVector(2);
324             b.setEntry(0, -1d);
325             b.setEntry(1, -1d);
326             solver.solve(a, m, b);
327         });
328     }
329 
330     @Test
331     void testPreconditionedSolution() {
332         final int n = 8;
333         final int maxIterations = 100;
334         final RealLinearOperator a = new HilbertMatrix(n);
335         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
336         final RealLinearOperator m = JacobiPreconditioner.create(a);
337         final PreconditionedIterativeLinearSolver solver;
338         solver = new ConjugateGradient(maxIterations, 1E-15, true);
339         final RealVector b = new ArrayRealVector(n);
340         for (int j = 0; j < n; j++) {
341             b.set(0.);
342             b.setEntry(j, 1.);
343             final RealVector x = solver.solve(a, m, b);
344             for (int i = 0; i < n; i++) {
345                 final double actual = x.getEntry(i);
346                 final double expected = ainv.getEntry(i, j);
347                 final double delta = 1E-6 * FastMath.abs(expected);
348                 final String msg = String.format("coefficient (%d, %d)", i, j);
349                 assertEquals(expected, actual, delta, msg);
350             }
351         }
352     }
353 
354     @Test
355     void testPreconditionedResidual() {
356         final int n = 10;
357         final int maxIterations = n;
358         final RealLinearOperator a = new HilbertMatrix(n);
359         final RealLinearOperator m = JacobiPreconditioner.create(a);
360         final ConjugateGradient solver;
361         solver = new ConjugateGradient(maxIterations, 1E-15, true);
362         final RealVector r = new ArrayRealVector(n);
363         final RealVector x = new ArrayRealVector(n);
364         final IterationListener listener = new IterationListener() {
365 
366             public void terminationPerformed(final IterationEvent e) {
367                 // Do nothing
368             }
369 
370             public void iterationStarted(final IterationEvent e) {
371                 // Do nothing
372             }
373 
374             public void iterationPerformed(final IterationEvent e) {
375                 final IterativeLinearSolverEvent evt;
376                 evt = (IterativeLinearSolverEvent) e;
377                 RealVector v = evt.getResidual();
378                 r.setSubVector(0, v);
379                 v = evt.getSolution();
380                 x.setSubVector(0, v);
381             }
382 
383             public void initializationPerformed(final IterationEvent e) {
384                 // Do nothing
385             }
386         };
387         solver.getIterationManager().addIterationListener(listener);
388         final RealVector b = new ArrayRealVector(n);
389 
390         for (int j = 0; j < n; j++) {
391             b.set(0.);
392             b.setEntry(j, 1.);
393 
394             boolean caught = false;
395             try {
396                 solver.solve(a, m, b);
397             } catch (MathIllegalStateException e) {
398                 caught = true;
399                 final RealVector y = a.operate(x);
400                 for (int i = 0; i < n; i++) {
401                     final double actual = b.getEntry(i) - y.getEntry(i);
402                     final double expected = r.getEntry(i);
403                     final double delta = 1E-6 * FastMath.abs(expected);
404                     final String msg = String.format("column %d, residual %d", i, j);
405                     assertEquals(expected, actual, delta, msg);
406                 }
407             }
408             assertTrue(caught, "MathIllegalStateException should have been caught");
409         }
410     }
411 
412     @Test
413     void testPreconditionedSolution2() {
414         final int n = 100;
415         final int maxIterations = 100000;
416         final Array2DRowRealMatrix a = new Array2DRowRealMatrix(n, n);
417         double daux = 1.;
418         for (int i = 0; i < n; i++) {
419             a.setEntry(i, i, daux);
420             daux *= 1.2;
421             for (int j = i + 1; j < n; j++) {
422                 if (i == j) {
423                 } else {
424                     final double value = 1.0;
425                     a.setEntry(i, j, value);
426                     a.setEntry(j, i, value);
427                 }
428             }
429         }
430         final RealLinearOperator m = JacobiPreconditioner.create(a);
431         final PreconditionedIterativeLinearSolver pcg;
432         final IterativeLinearSolver cg;
433         pcg = new ConjugateGradient(maxIterations, 1E-6, true);
434         cg = new ConjugateGradient(maxIterations, 1E-6, true);
435         final RealVector b = new ArrayRealVector(n);
436         final String pattern = "preconditioned gradient (%d iterations) should"
437                                + " have been faster than unpreconditioned (%d iterations)";
438         String msg;
439         for (int j = 0; j < 1; j++) {
440             b.set(0.);
441             b.setEntry(j, 1.);
442             final RealVector px = pcg.solve(a, m, b);
443             final RealVector x = cg.solve(a, b);
444             final int npcg = pcg.getIterationManager().getIterations();
445             final int ncg = cg.getIterationManager().getIterations();
446             msg = String.format(pattern, npcg, ncg);
447             assertTrue(npcg < ncg, msg);
448             for (int i = 0; i < n; i++) {
449                 msg = String.format("row %d, column %d", i, j);
450                 final double expected = x.getEntry(i);
451                 final double actual = px.getEntry(i);
452                 final double delta = 1E-6 * FastMath.abs(expected);
453                 assertEquals(expected, actual, delta, msg);
454             }
455         }
456     }
457 
458     @Test
459     void testEventManagement() {
460         final int n = 5;
461         final int maxIterations = 100;
462         final RealLinearOperator a = new HilbertMatrix(n);
463         final IterativeLinearSolver solver;
464         /*
465          * count[0] = number of calls to initializationPerformed
466          * count[1] = number of calls to iterationStarted
467          * count[2] = number of calls to iterationPerformed
468          * count[3] = number of calls to terminationPerformed
469          */
470         final int[] count = new int[] {0, 0, 0, 0};
471         final IterationListener listener = new IterationListener() {
472             private void doTestVectorsAreUnmodifiable(final IterationEvent e) {
473                 final IterativeLinearSolverEvent evt;
474                 evt = (IterativeLinearSolverEvent) e;
475                 try {
476                     evt.getResidual().set(0.0);
477                     fail("r is modifiable");
478                 } catch (MathRuntimeException exc){
479                     // Expected behavior
480                 }
481                 try {
482                     evt.getRightHandSideVector().set(0.0);
483                     fail("b is modifiable");
484                 } catch (MathRuntimeException exc){
485                     // Expected behavior
486                 }
487                 try {
488                     evt.getSolution().set(0.0);
489                     fail("x is modifiable");
490                 } catch (MathRuntimeException exc){
491                     // Expected behavior
492                 }
493             }
494 
495             public void initializationPerformed(final IterationEvent e) {
496                 ++count[0];
497                 doTestVectorsAreUnmodifiable(e);
498             }
499 
500             public void iterationPerformed(final IterationEvent e) {
501                 ++count[2];
502                 assertEquals(count[2], e.getIterations() - 1, "iteration performed");
503                 doTestVectorsAreUnmodifiable(e);
504             }
505 
506             public void iterationStarted(final IterationEvent e) {
507                 ++count[1];
508                 assertEquals(count[1], e.getIterations() - 1, "iteration started");
509                 doTestVectorsAreUnmodifiable(e);
510             }
511 
512             public void terminationPerformed(final IterationEvent e) {
513                 ++count[3];
514                 doTestVectorsAreUnmodifiable(e);
515             }
516         };
517         solver = new ConjugateGradient(maxIterations, 1E-10, true);
518         solver.getIterationManager().addIterationListener(listener);
519         final RealVector b = new ArrayRealVector(n);
520         for (int j = 0; j < n; j++) {
521             Arrays.fill(count, 0);
522             b.set(0.);
523             b.setEntry(j, 1.);
524             solver.solve(a, b);
525             String msg = String.format("column %d (initialization)", j);
526             assertEquals(1, count[0], msg);
527             msg = String.format("column %d (finalization)", j);
528             assertEquals(1, count[3], msg);
529         }
530     }
531 
532     @Test
533     void testUnpreconditionedNormOfResidual() {
534         final int n = 5;
535         final int maxIterations = 100;
536         final RealLinearOperator a = new HilbertMatrix(n);
537         final IterativeLinearSolver solver;
538         final IterationListener listener = new IterationListener() {
539 
540             private void doTestNormOfResidual(final IterationEvent e) {
541                 final IterativeLinearSolverEvent evt;
542                 evt = (IterativeLinearSolverEvent) e;
543                 final RealVector x = evt.getSolution();
544                 final RealVector b = evt.getRightHandSideVector();
545                 final RealVector r = b.subtract(a.operate(x));
546                 final double rnorm = r.getNorm();
547                 assertEquals(rnorm, evt.getNormOfResidual(),
548                     FastMath.max(1E-5 * rnorm, 1E-10),
549                     "iteration performed (residual)");
550             }
551 
552             public void initializationPerformed(final IterationEvent e) {
553                 doTestNormOfResidual(e);
554             }
555 
556             public void iterationPerformed(final IterationEvent e) {
557                 doTestNormOfResidual(e);
558             }
559 
560             public void iterationStarted(final IterationEvent e) {
561                 doTestNormOfResidual(e);
562             }
563 
564             public void terminationPerformed(final IterationEvent e) {
565                 doTestNormOfResidual(e);
566             }
567         };
568         solver = new ConjugateGradient(maxIterations, 1E-10, true);
569         solver.getIterationManager().addIterationListener(listener);
570         final RealVector b = new ArrayRealVector(n);
571         for (int j = 0; j < n; j++) {
572             b.set(0.);
573             b.setEntry(j, 1.);
574             solver.solve(a, b);
575         }
576     }
577 
578     @Test
579     void testPreconditionedNormOfResidual() {
580         final int n = 5;
581         final int maxIterations = 100;
582         final RealLinearOperator a = new HilbertMatrix(n);
583         final RealLinearOperator m = JacobiPreconditioner.create(a);
584         final PreconditionedIterativeLinearSolver solver;
585         final IterationListener listener = new IterationListener() {
586 
587             private void doTestNormOfResidual(final IterationEvent e) {
588                 final IterativeLinearSolverEvent evt;
589                 evt = (IterativeLinearSolverEvent) e;
590                 final RealVector x = evt.getSolution();
591                 final RealVector b = evt.getRightHandSideVector();
592                 final RealVector r = b.subtract(a.operate(x));
593                 final double rnorm = r.getNorm();
594                 assertEquals(rnorm, evt.getNormOfResidual(),
595                     FastMath.max(1E-5 * rnorm, 1E-10),
596                     "iteration performed (residual)");
597             }
598 
599             public void initializationPerformed(final IterationEvent e) {
600                 doTestNormOfResidual(e);
601             }
602 
603             public void iterationPerformed(final IterationEvent e) {
604                 doTestNormOfResidual(e);
605             }
606 
607             public void iterationStarted(final IterationEvent e) {
608                 doTestNormOfResidual(e);
609             }
610 
611             public void terminationPerformed(final IterationEvent e) {
612                 doTestNormOfResidual(e);
613             }
614         };
615         solver = new ConjugateGradient(maxIterations, 1E-10, true);
616         solver.getIterationManager().addIterationListener(listener);
617         final RealVector b = new ArrayRealVector(n);
618         for (int j = 0; j < n; j++) {
619             b.set(0.);
620             b.setEntry(j, 1.);
621             solver.solve(a, m, b);
622         }
623     }
624 }