ADMMQPKKT.java

/*
 * Licensed to the Hipparchus project under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.hipparchus.optim.nonlinear.vector.constrained;


import org.hipparchus.linear.ArrayRealVector;
import org.hipparchus.linear.DecompositionSolver;
import org.hipparchus.linear.EigenDecompositionSymmetric;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.util.FastMath;

/** Karush–Kuhn–Tucker (KKT) linear system solver used by the Alternating Direction Method
 * of Multipliers (ADMM) for quadratic programming.
 * <p>
 * The instance assembles the symmetric block matrix
 * </p>
 * <pre>
 *     |  H    A^T   |
 *     |  A   -R^-1  |
 * </pre>
 * <p>
 * where {@code H} is the quadratic weights matrix with {@code sigma} added on its diagonal,
 * {@code A} is the constraints matrix and {@code R} is the diagonal penalty matrix built from
 * the step size {@code rho}. Linear systems are solved through an eigen decomposition of
 * this matrix.
 * </p>
 * @since 3.1
 */
public class ADMMQPKKT implements KarushKuhnTuckerSolver<ADMMQPSolution> {

    /** Factor applied to the step size on the rows of the equality constraints. */
    private static final double EQUALITY_PENALTY_FACTOR = 1000.0;

    /** Square matrix of weights for quadratic terms, with sigma already added on its diagonal. */
    private RealMatrix H;

    /** Vector of weights for linear terms. */
    private RealVector q;

    /** Constraints coefficients matrix. */
    private RealMatrix A;

    /** Regularization term sigma for Karush–Kuhn–Tucker solver. */
    private double sigma;

    /** Diagonal penalty matrix (step size per constraint row). */
    private RealMatrix R;

    /** Inverse of R. */
    private RealMatrix Rinv;

    /** Lower bound. */
    private RealVector lb;

    /** Upper bound. */
    private RealVector ub;

    /** Alpha filter for ADMM iteration. */
    private double alpha;

    /** Constrained problem KKT matrix. */
    private RealMatrix M;

    /** Solver for M. */
    private DecompositionSolver dsX;

    /** Simple constructor.
     * <p>
     * BEWARE, nothing is initialized here, it is {@link #initialize(RealMatrix, RealMatrix,
     * RealVector, int, RealVector, RealVector, double, double, double) initialize} <em>must</em>
     * be called before using the instance.
     * </p>
     */
    ADMMQPKKT() {
        // nothing initialized yet!
    }

    /** {@inheritDoc}
     * <p>
     * The two right-hand sides are stacked into one vector, the KKT system is solved
     * and the result is split back using the dimensions of {@code b1} and {@code b2}.
     * </p>
     */
    @Override
    public ADMMQPSolution solve(RealVector b1, final RealVector b2) {
        final int n = b1.getDimension();
        final int m = b2.getDimension();

        // stack through plain arrays so that any RealVector implementation is accepted
        // (a downcast to ArrayRealVector would fail for other implementations)
        final RealVector z = dsX.solve(new ArrayRealVector(b1.toArray(), b2.toArray()));

        return new ADMMQPSolution(z.getSubVector(0, n), z.getSubVector(n, m));
    }

    /** Update steps
     * <p>
     * The regularization and the penalty matrix are replaced and the KKT matrix
     * is assembled and decomposed again.
     * </p>
     * @param newSigma new regularization term sigma for Karush–Kuhn–Tucker solver
     * @param me number of equality constraints
     * @param rho new step size
     */
    public void updateSigmaRho(double newSigma, int me, double rho) {
        // H already holds the previous sigma on its diagonal: shift it by the difference
        // so that the regularization is replaced instead of being accumulated
        final double shift = newSigma - sigma;
        if (shift != 0.0) {
            this.H = shiftDiagonal(H, shift);
        }
        this.sigma = newSigma;

        createPenaltyMatrix(me, rho);
        buildKktSolver();
    }

    /** Initialize problem
     * @param newH square matrix of weights for quadratic term
     * @param newA constraints coefficients matrix
     * @param newQ vector of weights for linear terms
     * @param me number of equality constraints
     * @param newLb lower bound
     * @param newUb upper bound
     * @param rho step size
     * @param newSigma regularization term sigma for Karush–Kuhn–Tucker solver
     * @param newAlpha alpha filter for ADMM iteration
     */
    public void initialize(RealMatrix newH, RealMatrix newA, RealVector newQ,
                           int me, RealVector newLb, RealVector newUb,
                           double rho, double newSigma, double newAlpha) {
        this.lb    = newLb;
        this.ub    = newUb;
        this.alpha = newAlpha;
        this.sigma = newSigma;
        this.H     = shiftDiagonal(newH, newSigma);
        this.A     = newA.copy();
        this.q     = newQ.copy();

        createPenaltyMatrix(me, rho);

        // the KKT matrix must be built from the regularized H (not from newH), because
        // the right-hand side used in iterate is sigma * x - q
        buildKktSolver();
    }

    /** Add a constant to the diagonal of a square matrix.
     * @param matrix matrix to shift (not modified)
     * @param shift value added to each diagonal entry
     * @return new matrix equal to {@code matrix + shift * I}
     */
    private static RealMatrix shiftDiagonal(final RealMatrix matrix, final double shift) {
        return matrix.add(MatrixUtils.createRealIdentityMatrix(matrix.getColumnDimension()).scalarMultiply(shift));
    }

    /** Assemble the KKT matrix from the current H, A and Rinv and decompose it. */
    private void buildKktSolver() {
        final int n = H.getRowDimension();
        final int m = A.getRowDimension();

        M = MatrixUtils.createRealMatrix(n + m, n + m);
        M.setSubMatrix(H.getData(), 0, 0);
        M.setSubMatrix(A.getData(), n, 0);
        M.setSubMatrix(A.transpose().getData(), 0, n);
        M.setSubMatrix(Rinv.scalarMultiply(-1.0).getData(), n, n);

        dsX = new EigenDecompositionSymmetric(M).getSolver();
    }

    /** Build the diagonal penalty matrix R and its inverse.
     * <p>
     * The first {@code me} rows get {@code rho} multiplied by a fixed factor,
     * the remaining rows get {@code rho}.
     * </p>
     * @param me number of equality constraints
     * @param rho step size
     */
    private void createPenaltyMatrix(int me, double rho) {
        final double equalityPenalty = rho * EQUALITY_PENALTY_FACTOR;

        this.R = MatrixUtils.createRealIdentityMatrix(A.getRowDimension());
        for (int i = 0; i < R.getRowDimension(); i++) {
            R.setEntry(i, i, i < me ? equalityPenalty : rho);
        }

        this.Rinv = MatrixUtils.inverse(R);
    }

    /** {@inheritDoc}
     * <p>
     * The elements of {@code previousSol} are read as x (index 0), y (index 1) and z (index 2).
     * BEWARE, the array is updated by this method: elements 0 and 1 are replaced by new
     * vectors and the vector at index 2 is modified in place.
     * </p>
     */
    @Override
    public ADMMQPSolution iterate(RealVector... previousSol) {
        final double oneMinusAlpha = 1.0 - alpha;

        // save old values (z is overwritten in place below)
        final RealVector xold = previousSol[0].copy();
        final RealVector yold = previousSol[1].copy();
        final RealVector zold = previousSol[2].copy();

        // right-hand side of the KKT system
        final RealVector b1 = previousSol[0].mapMultiply(sigma).subtract(q);
        final RealVector b2 = previousSol[2].subtract(Rinv.operate(previousSol[1]));

        // solve KKT system
        final ADMMQPSolution sol = solve(b1, b2);
        final RealVector xtilde = sol.getX();
        final RealVector vtilde = sol.getV();

        // update ztilde
        final RealVector ztilde = zold.add(Rinv.operate(vtilde.subtract(yold)));

        // update x, filtering the new estimate with alpha
        previousSol[0] = xtilde.mapMultiply(alpha).add(xold.mapMultiply(oneMinusAlpha));

        // filtered z before projection
        final RealVector zpartial = ztilde.mapMultiply(alpha).
                                    add(zold.mapMultiply(oneMinusAlpha)).
                                    add(Rinv.operate(yold));

        // project zpartial between the bounds and update z
        for (int j = 0; j < previousSol[2].getDimension(); j++) {
            final double clamped = FastMath.min(FastMath.max(zpartial.getEntry(j), lb.getEntry(j)),
                                                ub.getEntry(j));
            previousSol[2].setEntry(j, clamped);
        }

        // update y
        final RealVector ytilde = ztilde.mapMultiply(alpha).
                                  add(zold.mapMultiply(oneMinusAlpha).subtract(previousSol[2]));
        previousSol[1] = yold.add(R.operate(ytilde));

        return new ADMMQPSolution(previousSol[0], vtilde, previousSol[1], previousSol[2]);
    }

}