View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.linear;
23  
24  import java.util.Arrays;
25  
26  import org.hipparchus.exception.MathIllegalArgumentException;
27  import org.hipparchus.exception.MathIllegalStateException;
28  import org.hipparchus.exception.MathRuntimeException;
29  import org.hipparchus.util.FastMath;
30  import org.hipparchus.util.IterationEvent;
31  import org.hipparchus.util.IterationListener;
32  import org.junit.Assert;
33  import org.junit.Test;
34  
35  public class ConjugateGradientTest {
36  
37      @Test(expected = MathIllegalArgumentException.class)
38      public void testNonSquareOperator() {
39          final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 3);
40          final IterativeLinearSolver solver;
41          solver = new ConjugateGradient(10, 0., false);
42          final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
43          final ArrayRealVector x = new ArrayRealVector(a.getColumnDimension());
44          solver.solve(a, b, x);
45      }
46  
47      @Test(expected = MathIllegalArgumentException.class)
48      public void testDimensionMismatchRightHandSide() {
49          final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
50          final IterativeLinearSolver solver;
51          solver = new ConjugateGradient(10, 0., false);
52          final ArrayRealVector b = new ArrayRealVector(2);
53          final ArrayRealVector x = new ArrayRealVector(3);
54          solver.solve(a, b, x);
55      }
56  
57      @Test(expected = MathIllegalArgumentException.class)
58      public void testDimensionMismatchSolution() {
59          final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
60          final IterativeLinearSolver solver;
61          solver = new ConjugateGradient(10, 0., false);
62          final ArrayRealVector b = new ArrayRealVector(3);
63          final ArrayRealVector x = new ArrayRealVector(2);
64          solver.solve(a, b, x);
65      }
66  
67      @Test(expected = MathIllegalArgumentException.class)
68      public void testNonPositiveDefiniteLinearOperator() {
69          final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
70          a.setEntry(0, 0, -1.);
71          a.setEntry(0, 1, 2.);
72          a.setEntry(1, 0, 3.);
73          a.setEntry(1, 1, 4.);
74          final IterativeLinearSolver solver;
75          solver = new ConjugateGradient(10, 0., true);
76          final ArrayRealVector b = new ArrayRealVector(2);
77          b.setEntry(0, -1.);
78          b.setEntry(1, -1.);
79          final ArrayRealVector x = new ArrayRealVector(2);
80          solver.solve(a, b, x);
81      }
82  
83      @Test
84      public void testUnpreconditionedSolution() {
85          final int n = 5;
86          final int maxIterations = 100;
87          final RealLinearOperator a = new HilbertMatrix(n);
88          final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
89          final IterativeLinearSolver solver;
90          solver = new ConjugateGradient(maxIterations, 1E-10, true);
91          final RealVector b = new ArrayRealVector(n);
92          for (int j = 0; j < n; j++) {
93              b.set(0.);
94              b.setEntry(j, 1.);
95              final RealVector x = solver.solve(a, b);
96              for (int i = 0; i < n; i++) {
97                  final double actual = x.getEntry(i);
98                  final double expected = ainv.getEntry(i, j);
99                  final double delta = 1E-10 * FastMath.abs(expected);
100                 final String msg = String.format("entry[%d][%d]", i, j);
101                 Assert.assertEquals(msg, expected, actual, delta);
102             }
103         }
104     }
105 
106     @Test
107     public void testUnpreconditionedInPlaceSolutionWithInitialGuess() {
108         final int n = 5;
109         final int maxIterations = 100;
110         final RealLinearOperator a = new HilbertMatrix(n);
111         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
112         final IterativeLinearSolver solver;
113         solver = new ConjugateGradient(maxIterations, 1E-10, true);
114         final RealVector b = new ArrayRealVector(n);
115         for (int j = 0; j < n; j++) {
116             b.set(0.);
117             b.setEntry(j, 1.);
118             final RealVector x0 = new ArrayRealVector(n);
119             x0.set(1.);
120             final RealVector x = solver.solveInPlace(a, b, x0);
121             Assert.assertSame("x should be a reference to x0", x0, x);
122             for (int i = 0; i < n; i++) {
123                 final double actual = x.getEntry(i);
124                 final double expected = ainv.getEntry(i, j);
125                 final double delta = 1E-10 * FastMath.abs(expected);
126                 final String msg = String.format("entry[%d][%d)", i, j);
127                 Assert.assertEquals(msg, expected, actual, delta);
128             }
129         }
130     }
131 
132     @Test
133     public void testUnpreconditionedSolutionWithInitialGuess() {
134         final int n = 5;
135         final int maxIterations = 100;
136         final RealLinearOperator a = new HilbertMatrix(n);
137         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
138         final IterativeLinearSolver solver;
139         solver = new ConjugateGradient(maxIterations, 1E-10, true);
140         final RealVector b = new ArrayRealVector(n);
141         for (int j = 0; j < n; j++) {
142             b.set(0.);
143             b.setEntry(j, 1.);
144             final RealVector x0 = new ArrayRealVector(n);
145             x0.set(1.);
146             final RealVector x = solver.solve(a, b, x0);
147             Assert.assertNotSame("x should not be a reference to x0", x0, x);
148             for (int i = 0; i < n; i++) {
149                 final double actual = x.getEntry(i);
150                 final double expected = ainv.getEntry(i, j);
151                 final double delta = 1E-10 * FastMath.abs(expected);
152                 final String msg = String.format("entry[%d][%d]", i, j);
153                 Assert.assertEquals(msg, expected, actual, delta);
154                 Assert.assertEquals(msg, x0.getEntry(i), 1., Math.ulp(1.));
155             }
156         }
157     }
158 
159     /**
160      * Check whether the estimate of the (updated) residual corresponds to the
161      * exact residual. This fails to be true for a large number of iterations,
162      * due to the loss of orthogonality of the successive search directions.
163      * Therefore, in the present test, the number of iterations is limited.
164      */
165     @Test
166     public void testUnpreconditionedResidual() {
167         final int n = 10;
168         final int maxIterations = n;
169         final RealLinearOperator a = new HilbertMatrix(n);
170         final ConjugateGradient solver;
171         solver = new ConjugateGradient(maxIterations, 1E-15, true);
172         final RealVector r = new ArrayRealVector(n);
173         final RealVector x = new ArrayRealVector(n);
174         final IterationListener listener = new IterationListener() {
175 
176             public void terminationPerformed(final IterationEvent e) {
177                 // Do nothing
178             }
179 
180             public void iterationStarted(final IterationEvent e) {
181                 // Do nothing
182             }
183 
184             public void iterationPerformed(final IterationEvent e) {
185                 final IterativeLinearSolverEvent evt;
186                 evt = (IterativeLinearSolverEvent) e;
187                 RealVector v = evt.getResidual();
188                 r.setSubVector(0, v);
189                 v = evt.getSolution();
190                 x.setSubVector(0, v);
191             }
192 
193             public void initializationPerformed(final IterationEvent e) {
194                 // Do nothing
195             }
196         };
197         solver.getIterationManager().addIterationListener(listener);
198         final RealVector b = new ArrayRealVector(n);
199         for (int j = 0; j < n; j++) {
200             b.set(0.);
201             b.setEntry(j, 1.);
202 
203             boolean caught = false;
204             try {
205                 solver.solve(a, b);
206             } catch (MathIllegalStateException e) {
207                 caught = true;
208                 final RealVector y = a.operate(x);
209                 for (int i = 0; i < n; i++) {
210                     final double actual = b.getEntry(i) - y.getEntry(i);
211                     final double expected = r.getEntry(i);
212                     final double delta = 1E-6 * FastMath.abs(expected);
213                     final String msg = String
214                         .format("column %d, residual %d", i, j);
215                     Assert.assertEquals(msg, expected, actual, delta);
216                 }
217             }
218             Assert
219                 .assertTrue("MathIllegalStateException should have been caught",
220                             caught);
221         }
222     }
223 
224     @Test(expected = MathIllegalArgumentException.class)
225     public void testNonSquarePreconditioner() {
226         final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
227         final RealLinearOperator m = new RealLinearOperator() {
228 
229             @Override
230             public RealVector operate(final RealVector x) {
231                 throw new UnsupportedOperationException();
232             }
233 
234             @Override
235             public int getRowDimension() {
236                 return 2;
237             }
238 
239             @Override
240             public int getColumnDimension() {
241                 return 3;
242             }
243         };
244         final PreconditionedIterativeLinearSolver solver;
245         solver = new ConjugateGradient(10, 0d, false);
246         final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
247         solver.solve(a, m, b);
248     }
249 
250     @Test(expected = MathIllegalArgumentException.class)
251     public void testMismatchedOperatorDimensions() {
252         final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
253         final RealLinearOperator m = new RealLinearOperator() {
254 
255             @Override
256             public RealVector operate(final RealVector x) {
257                 throw new UnsupportedOperationException();
258             }
259 
260             @Override
261             public int getRowDimension() {
262                 return 3;
263             }
264 
265             @Override
266             public int getColumnDimension() {
267                 return 3;
268             }
269         };
270         final PreconditionedIterativeLinearSolver solver;
271         solver = new ConjugateGradient(10, 0d, false);
272         final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
273         solver.solve(a, m, b);
274     }
275 
276     @Test(expected = MathIllegalArgumentException.class)
277     public void testNonPositiveDefinitePreconditioner() {
278         final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
279         a.setEntry(0, 0, 1d);
280         a.setEntry(0, 1, 2d);
281         a.setEntry(1, 0, 3d);
282         a.setEntry(1, 1, 4d);
283         final RealLinearOperator m = new RealLinearOperator() {
284 
285             @Override
286             public RealVector operate(final RealVector x) {
287                 final ArrayRealVector y = new ArrayRealVector(2);
288                 y.setEntry(0, -x.getEntry(0));
289                 y.setEntry(1, x.getEntry(1));
290                 return y;
291             }
292 
293             @Override
294             public int getRowDimension() {
295                 return 2;
296             }
297 
298             @Override
299             public int getColumnDimension() {
300                 return 2;
301             }
302         };
303         final PreconditionedIterativeLinearSolver solver;
304         solver = new ConjugateGradient(10, 0d, true);
305         final ArrayRealVector b = new ArrayRealVector(2);
306         b.setEntry(0, -1d);
307         b.setEntry(1, -1d);
308         solver.solve(a, m, b);
309     }
310 
311     @Test
312     public void testPreconditionedSolution() {
313         final int n = 8;
314         final int maxIterations = 100;
315         final RealLinearOperator a = new HilbertMatrix(n);
316         final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
317         final RealLinearOperator m = JacobiPreconditioner.create(a);
318         final PreconditionedIterativeLinearSolver solver;
319         solver = new ConjugateGradient(maxIterations, 1E-15, true);
320         final RealVector b = new ArrayRealVector(n);
321         for (int j = 0; j < n; j++) {
322             b.set(0.);
323             b.setEntry(j, 1.);
324             final RealVector x = solver.solve(a, m, b);
325             for (int i = 0; i < n; i++) {
326                 final double actual = x.getEntry(i);
327                 final double expected = ainv.getEntry(i, j);
328                 final double delta = 1E-6 * FastMath.abs(expected);
329                 final String msg = String.format("coefficient (%d, %d)", i, j);
330                 Assert.assertEquals(msg, expected, actual, delta);
331             }
332         }
333     }
334 
335     @Test
336     public void testPreconditionedResidual() {
337         final int n = 10;
338         final int maxIterations = n;
339         final RealLinearOperator a = new HilbertMatrix(n);
340         final RealLinearOperator m = JacobiPreconditioner.create(a);
341         final ConjugateGradient solver;
342         solver = new ConjugateGradient(maxIterations, 1E-15, true);
343         final RealVector r = new ArrayRealVector(n);
344         final RealVector x = new ArrayRealVector(n);
345         final IterationListener listener = new IterationListener() {
346 
347             public void terminationPerformed(final IterationEvent e) {
348                 // Do nothing
349             }
350 
351             public void iterationStarted(final IterationEvent e) {
352                 // Do nothing
353             }
354 
355             public void iterationPerformed(final IterationEvent e) {
356                 final IterativeLinearSolverEvent evt;
357                 evt = (IterativeLinearSolverEvent) e;
358                 RealVector v = evt.getResidual();
359                 r.setSubVector(0, v);
360                 v = evt.getSolution();
361                 x.setSubVector(0, v);
362             }
363 
364             public void initializationPerformed(final IterationEvent e) {
365                 // Do nothing
366             }
367         };
368         solver.getIterationManager().addIterationListener(listener);
369         final RealVector b = new ArrayRealVector(n);
370 
371         for (int j = 0; j < n; j++) {
372             b.set(0.);
373             b.setEntry(j, 1.);
374 
375             boolean caught = false;
376             try {
377                 solver.solve(a, m, b);
378             } catch (MathIllegalStateException e) {
379                 caught = true;
380                 final RealVector y = a.operate(x);
381                 for (int i = 0; i < n; i++) {
382                     final double actual = b.getEntry(i) - y.getEntry(i);
383                     final double expected = r.getEntry(i);
384                     final double delta = 1E-6 * FastMath.abs(expected);
385                     final String msg = String.format("column %d, residual %d", i, j);
386                     Assert.assertEquals(msg, expected, actual, delta);
387                 }
388             }
389             Assert.assertTrue("MathIllegalStateException should have been caught", caught);
390         }
391     }
392 
393     @Test
394     public void testPreconditionedSolution2() {
395         final int n = 100;
396         final int maxIterations = 100000;
397         final Array2DRowRealMatrix a = new Array2DRowRealMatrix(n, n);
398         double daux = 1.;
399         for (int i = 0; i < n; i++) {
400             a.setEntry(i, i, daux);
401             daux *= 1.2;
402             for (int j = i + 1; j < n; j++) {
403                 if (i == j) {
404                 } else {
405                     final double value = 1.0;
406                     a.setEntry(i, j, value);
407                     a.setEntry(j, i, value);
408                 }
409             }
410         }
411         final RealLinearOperator m = JacobiPreconditioner.create(a);
412         final PreconditionedIterativeLinearSolver pcg;
413         final IterativeLinearSolver cg;
414         pcg = new ConjugateGradient(maxIterations, 1E-6, true);
415         cg = new ConjugateGradient(maxIterations, 1E-6, true);
416         final RealVector b = new ArrayRealVector(n);
417         final String pattern = "preconditioned gradient (%d iterations) should"
418                                + " have been faster than unpreconditioned (%d iterations)";
419         String msg;
420         for (int j = 0; j < 1; j++) {
421             b.set(0.);
422             b.setEntry(j, 1.);
423             final RealVector px = pcg.solve(a, m, b);
424             final RealVector x = cg.solve(a, b);
425             final int npcg = pcg.getIterationManager().getIterations();
426             final int ncg = cg.getIterationManager().getIterations();
427             msg = String.format(pattern, npcg, ncg);
428             Assert.assertTrue(msg, npcg < ncg);
429             for (int i = 0; i < n; i++) {
430                 msg = String.format("row %d, column %d", i, j);
431                 final double expected = x.getEntry(i);
432                 final double actual = px.getEntry(i);
433                 final double delta = 1E-6 * FastMath.abs(expected);
434                 Assert.assertEquals(msg, expected, actual, delta);
435             }
436         }
437     }
438 
439     @Test
440     public void testEventManagement() {
441         final int n = 5;
442         final int maxIterations = 100;
443         final RealLinearOperator a = new HilbertMatrix(n);
444         final IterativeLinearSolver solver;
445         /*
446          * count[0] = number of calls to initializationPerformed
447          * count[1] = number of calls to iterationStarted
448          * count[2] = number of calls to iterationPerformed
449          * count[3] = number of calls to terminationPerformed
450          */
451         final int[] count = new int[] {0, 0, 0, 0};
452         final IterationListener listener = new IterationListener() {
453             private void doTestVectorsAreUnmodifiable(final IterationEvent e) {
454                 final IterativeLinearSolverEvent evt;
455                 evt = (IterativeLinearSolverEvent) e;
456                 try {
457                     evt.getResidual().set(0.0);
458                     Assert.fail("r is modifiable");
459                 } catch (MathRuntimeException exc){
460                     // Expected behavior
461                 }
462                 try {
463                     evt.getRightHandSideVector().set(0.0);
464                     Assert.fail("b is modifiable");
465                 } catch (MathRuntimeException exc){
466                     // Expected behavior
467                 }
468                 try {
469                     evt.getSolution().set(0.0);
470                     Assert.fail("x is modifiable");
471                 } catch (MathRuntimeException exc){
472                     // Expected behavior
473                 }
474             }
475 
476             public void initializationPerformed(final IterationEvent e) {
477                 ++count[0];
478                 doTestVectorsAreUnmodifiable(e);
479             }
480 
481             public void iterationPerformed(final IterationEvent e) {
482                 ++count[2];
483                 Assert.assertEquals("iteration performed",
484                     count[2], e.getIterations() - 1);
485                 doTestVectorsAreUnmodifiable(e);
486             }
487 
488             public void iterationStarted(final IterationEvent e) {
489                 ++count[1];
490                 Assert.assertEquals("iteration started",
491                     count[1], e.getIterations() - 1);
492                 doTestVectorsAreUnmodifiable(e);
493             }
494 
495             public void terminationPerformed(final IterationEvent e) {
496                 ++count[3];
497                 doTestVectorsAreUnmodifiable(e);
498             }
499         };
500         solver = new ConjugateGradient(maxIterations, 1E-10, true);
501         solver.getIterationManager().addIterationListener(listener);
502         final RealVector b = new ArrayRealVector(n);
503         for (int j = 0; j < n; j++) {
504             Arrays.fill(count, 0);
505             b.set(0.);
506             b.setEntry(j, 1.);
507             solver.solve(a, b);
508             String msg = String.format("column %d (initialization)", j);
509             Assert.assertEquals(msg, 1, count[0]);
510             msg = String.format("column %d (finalization)", j);
511             Assert.assertEquals(msg, 1, count[3]);
512         }
513     }
514 
515     @Test
516     public void testUnpreconditionedNormOfResidual() {
517         final int n = 5;
518         final int maxIterations = 100;
519         final RealLinearOperator a = new HilbertMatrix(n);
520         final IterativeLinearSolver solver;
521         final IterationListener listener = new IterationListener() {
522 
523             private void doTestNormOfResidual(final IterationEvent e) {
524                 final IterativeLinearSolverEvent evt;
525                 evt = (IterativeLinearSolverEvent) e;
526                 final RealVector x = evt.getSolution();
527                 final RealVector b = evt.getRightHandSideVector();
528                 final RealVector r = b.subtract(a.operate(x));
529                 final double rnorm = r.getNorm();
530                 Assert.assertEquals("iteration performed (residual)",
531                     rnorm, evt.getNormOfResidual(),
532                     FastMath.max(1E-5 * rnorm, 1E-10));
533             }
534 
535             public void initializationPerformed(final IterationEvent e) {
536                 doTestNormOfResidual(e);
537             }
538 
539             public void iterationPerformed(final IterationEvent e) {
540                 doTestNormOfResidual(e);
541             }
542 
543             public void iterationStarted(final IterationEvent e) {
544                 doTestNormOfResidual(e);
545             }
546 
547             public void terminationPerformed(final IterationEvent e) {
548                 doTestNormOfResidual(e);
549             }
550         };
551         solver = new ConjugateGradient(maxIterations, 1E-10, true);
552         solver.getIterationManager().addIterationListener(listener);
553         final RealVector b = new ArrayRealVector(n);
554         for (int j = 0; j < n; j++) {
555             b.set(0.);
556             b.setEntry(j, 1.);
557             solver.solve(a, b);
558         }
559     }
560 
561     @Test
562     public void testPreconditionedNormOfResidual() {
563         final int n = 5;
564         final int maxIterations = 100;
565         final RealLinearOperator a = new HilbertMatrix(n);
566         final RealLinearOperator m = JacobiPreconditioner.create(a);
567         final PreconditionedIterativeLinearSolver solver;
568         final IterationListener listener = new IterationListener() {
569 
570             private void doTestNormOfResidual(final IterationEvent e) {
571                 final IterativeLinearSolverEvent evt;
572                 evt = (IterativeLinearSolverEvent) e;
573                 final RealVector x = evt.getSolution();
574                 final RealVector b = evt.getRightHandSideVector();
575                 final RealVector r = b.subtract(a.operate(x));
576                 final double rnorm = r.getNorm();
577                 Assert.assertEquals("iteration performed (residual)",
578                     rnorm, evt.getNormOfResidual(),
579                     FastMath.max(1E-5 * rnorm, 1E-10));
580             }
581 
582             public void initializationPerformed(final IterationEvent e) {
583                 doTestNormOfResidual(e);
584             }
585 
586             public void iterationPerformed(final IterationEvent e) {
587                 doTestNormOfResidual(e);
588             }
589 
590             public void iterationStarted(final IterationEvent e) {
591                 doTestNormOfResidual(e);
592             }
593 
594             public void terminationPerformed(final IterationEvent e) {
595                 doTestNormOfResidual(e);
596             }
597         };
598         solver = new ConjugateGradient(maxIterations, 1E-10, true);
599         solver.getIterationManager().addIterationListener(listener);
600         final RealVector b = new ArrayRealVector(n);
601         for (int j = 0; j < n; j++) {
602             b.set(0.);
603             b.setEntry(j, 1.);
604             solver.solve(a, m, b);
605         }
606     }
607 }