/*
 * Decompiled with CFR 0.152.
 */
package no.uib.cipr.matrix.distributed.test;

import java.util.Arrays;
import junit.framework.TestCase;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.UpperSymmDenseMatrix;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.distributed.CollectiveCommunications;
import no.uib.cipr.matrix.distributed.Communicator;
import no.uib.cipr.matrix.distributed.DistColMatrix;
import no.uib.cipr.matrix.distributed.DistRowMatrix;
import no.uib.cipr.matrix.distributed.DistVector;
import no.uib.cipr.matrix.sparse.BiCGstab;
import no.uib.cipr.matrix.sparse.CG;
import no.uib.cipr.matrix.sparse.DefaultIterationMonitor;
import no.uib.cipr.matrix.sparse.GMRES;
import no.uib.cipr.matrix.sparse.IterativeSolver;
import no.uib.cipr.matrix.sparse.IterativeSolverNotConvergedException;
import no.uib.cipr.matrix.test.Utilities;

public class DistIterativeSolverTest
extends TestCase {
    /** Collective-communication hub shared by every rank thread; created in setUp. */
    CollectiveCommunications coll;
    /** Unsymmetric system matrix, diagonal-shifted until non-singular (GMRES/BiCGstab tests). */
    DenseMatrix A_unsymm;
    /** Symmetric system matrix, diagonal-shifted until SPD (CG tests). */
    UpperSymmDenseMatrix A_symm;
    /** Random exact solution shared by both systems. */
    DenseVector x;
    /** Right-hand side {@code A_unsymm * x}. */
    DenseVector b_unsymm;
    /** Right-hand side {@code A_symm * x}. */
    DenseVector b_symm;
    /** Number of vector entries owned by each rank. */
    int[] localLength;
    /** Solution assembled from the slices written by every rank thread. */
    double[] output;

    /**
     * Builds the random problems shared by every solver test.
     *
     * <p>Picks a random rank count {@code size} and problem size {@code n} via
     * {@code Utilities.getInt}, then creates:
     * <ul>
     * <li>{@code A_unsymm}: a fully populated matrix whose diagonal is shifted
     *     until {@code Utilities.singular} no longer reports it singular;</li>
     * <li>{@code A_symm}: a symmetric matrix with a populated upper triangle,
     *     diagonal-shifted until {@code Utilities.spd} accepts it;</li>
     * <li>a random exact solution {@code x} and the right-hand sides
     *     {@code b_unsymm = A_unsymm x} and {@code b_symm = A_symm x};</li>
     * <li>{@code localLength}: {@code n / size} entries per rank, with the
     *     remainder assigned to the last rank.</li>
     * </ul>
     */
    @Override
    protected void setUp() throws Exception {
        int size = Utilities.getInt(1, 8);
        this.coll = new CollectiveCommunications(size);
        int n = Utilities.getInt(size, 250);

        this.A_unsymm = new DenseMatrix(n, n);
        this.A_symm = new UpperSymmDenseMatrix(n);
        Utilities.populate((Matrix) this.A_unsymm);
        // Populate the symmetric matrix itself. Re-populating A_unsymm here would
        // leave A_symm purely diagonal and make every CG test trivial.
        Utilities.upperPopulate((Matrix) this.A_symm);

        double shift = 10.0;
        do {
            Utilities.addDiagonal((Matrix) this.A_unsymm, shift);
        } while (Utilities.singular((Matrix) this.A_unsymm));
        do {
            Utilities.addDiagonal((Matrix) this.A_symm, shift);
        } while (!Utilities.spd((Matrix) this.A_symm));

        this.x = new DenseVector(n);
        Utilities.populate((Vector) this.x);
        this.b_unsymm = new DenseVector(n);
        this.b_symm = new DenseVector(n);
        this.A_unsymm.mult(this.x, this.b_unsymm);
        this.A_symm.mult(this.x, this.b_symm);

        this.output = new double[n];
        this.localLength = new int[size];
        Arrays.fill(this.localLength, n / size);
        // The last rank absorbs the entries left over by the even split.
        this.localLength[size - 1] += n % size;
    }

    /** Row-distributed GMRES on the unsymmetric system; monitor uses the 1-norm. */
    public void testRowGMRES_1() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new GMRESRowDistIterativeSolver(rank, Vector.Norm.One));
        }
        compare(workers);
    }

    /** Row-distributed GMRES on the unsymmetric system; monitor uses the 2-norm. */
    public void testRowGMRES_2() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new GMRESRowDistIterativeSolver(rank, Vector.Norm.Two));
        }
        compare(workers);
    }

    /** Row-distributed GMRES on the unsymmetric system; monitor uses the infinity-norm. */
    public void testRowGMRES_inf() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new GMRESRowDistIterativeSolver(rank, Vector.Norm.Infinity));
        }
        compare(workers);
    }

    /** Row-distributed BiCGstab on the unsymmetric system; monitor uses the 1-norm. */
    public void testRowBiCGstab_1() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new BiCGstabRowDistIterativeSolver(rank, Vector.Norm.One));
        }
        compare(workers);
    }

    /** Row-distributed BiCGstab on the unsymmetric system; monitor uses the 2-norm. */
    public void testRowBiCGstab_2() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new BiCGstabRowDistIterativeSolver(rank, Vector.Norm.Two));
        }
        compare(workers);
    }

    /** Row-distributed BiCGstab on the unsymmetric system; monitor uses the infinity-norm. */
    public void testRowBiCGstab_inf() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new BiCGstabRowDistIterativeSolver(rank, Vector.Norm.Infinity));
        }
        compare(workers);
    }

    /** Column-distributed GMRES on the unsymmetric system; monitor uses the 1-norm. */
    public void testColumnGMRES_1() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new GMRESColumnDistIterativeSolver(rank, Vector.Norm.One));
        }
        compare(workers);
    }

    /** Column-distributed GMRES on the unsymmetric system; monitor uses the 2-norm. */
    public void testColumnGMRES_2() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new GMRESColumnDistIterativeSolver(rank, Vector.Norm.Two));
        }
        compare(workers);
    }

    /** Column-distributed GMRES on the unsymmetric system; monitor uses the infinity-norm. */
    public void testColumnGMRES_inf() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new GMRESColumnDistIterativeSolver(rank, Vector.Norm.Infinity));
        }
        compare(workers);
    }

    /** Column-distributed BiCGstab on the unsymmetric system; monitor uses the 1-norm. */
    public void testColumnBiCGstab_1() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new BiCGstabColumnDistIterativeSolver(rank, Vector.Norm.One));
        }
        compare(workers);
    }

    /** Column-distributed BiCGstab on the unsymmetric system; monitor uses the 2-norm. */
    public void testColumnBiCGstab_2() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new BiCGstabColumnDistIterativeSolver(rank, Vector.Norm.Two));
        }
        compare(workers);
    }

    /** Column-distributed BiCGstab on the unsymmetric system; monitor uses the infinity-norm. */
    public void testColumnBiCGstab_inf() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new BiCGstabColumnDistIterativeSolver(rank, Vector.Norm.Infinity));
        }
        compare(workers);
    }

    /** Row-distributed CG on the symmetric system; monitor uses the 1-norm. */
    public void testRowCG_1() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new CGRowDistIterativeSolver(rank, Vector.Norm.One));
        }
        compare(workers);
    }

    /** Row-distributed CG on the symmetric system; monitor uses the 2-norm. */
    public void testRowCG_2() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new CGRowDistIterativeSolver(rank, Vector.Norm.Two));
        }
        compare(workers);
    }

    /** Row-distributed CG on the symmetric system; monitor uses the infinity-norm. */
    public void testRowCG_inf() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new CGRowDistIterativeSolver(rank, Vector.Norm.Infinity));
        }
        compare(workers);
    }

    /** Column-distributed CG on the symmetric system; monitor uses the 1-norm. */
    public void testColumnCG_1() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new CGColumnDistIterativeSolver(rank, Vector.Norm.One));
        }
        compare(workers);
    }

    /** Column-distributed CG on the symmetric system; monitor uses the 2-norm. */
    public void testColumnCG_2() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new CGColumnDistIterativeSolver(rank, Vector.Norm.Two));
        }
        compare(workers);
    }

    /** Column-distributed CG on the symmetric system; monitor uses the infinity-norm. */
    public void testColumnCG_inf() throws InterruptedException {
        Thread[] workers = new Thread[this.coll.size()];
        for (int rank = 0; rank < workers.length; rank++) {
            workers[rank] = new Thread(new CGColumnDistIterativeSolver(rank, Vector.Norm.Infinity));
        }
        compare(workers);
    }

    /**
     * Runs one solver thread per rank, waits for all of them, then checks the
     * assembled {@code output} against the exact solution {@code x}.
     *
     * <p>Any exception escaping a worker thread is recorded and rethrown here.
     * A crashing rank therefore fails the test with its real cause, instead of
     * surfacing only as a mismatched entry of {@code output}.
     * NOTE(review): if a rank dies in the middle of a collective operation the
     * other ranks may block and {@code join} would not return; a join timeout
     * may be worth adding if that is observed.
     *
     * @param t one not-yet-started thread per rank
     * @throws InterruptedException if interrupted while waiting for a worker
     */
    private void compare(Thread[] t) throws InterruptedException {
        final Throwable[] firstFailure = new Throwable[1];
        Thread.UncaughtExceptionHandler recorder = new Thread.UncaughtExceptionHandler() {
            public void uncaughtException(Thread thread, Throwable e) {
                synchronized (firstFailure) {
                    if (firstFailure[0] == null) {
                        firstFailure[0] = e;
                    }
                }
            }
        };
        for (Thread ti : t) {
            ti.setUncaughtExceptionHandler(recorder);
            ti.start();
        }
        for (Thread ti : t) {
            ti.join();
        }

        Throwable failure;
        synchronized (firstFailure) {
            failure = firstFailure[0];
        }
        if (failure instanceof Error) {
            // Includes JUnit assertion failures; rethrow unchanged.
            throw (Error) failure;
        }
        if (failure != null) {
            throw new RuntimeException("A solver thread failed", failure);
        }

        for (int i = 0; i < this.x.size(); ++i) {
            assertEquals("x[" + i + "]", this.x.get(i), this.output[i], 1.0E-10);
        }
    }

    private class CGColumnDistIterativeSolver
    extends SymmColumnDistIterativeSolver {
        public CGColumnDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected IterativeSolver createSolver(Vector x) {
            return new CG(x);
        }
    }

    private class CGRowDistIterativeSolver
    extends SymmRowDistIterativeSolver {
        public CGRowDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected IterativeSolver createSolver(Vector x) {
            return new CG(x);
        }
    }

    private class BiCGstabColumnDistIterativeSolver
    extends UnSymmColumnDistIterativeSolver {
        public BiCGstabColumnDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected IterativeSolver createSolver(Vector x) {
            return new BiCGstab(x);
        }
    }

    private class GMRESColumnDistIterativeSolver
    extends UnSymmColumnDistIterativeSolver {
        public GMRESColumnDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected IterativeSolver createSolver(Vector x) {
            return new GMRES(x);
        }
    }

    private class BiCGstabRowDistIterativeSolver
    extends UnSymmRowDistIterativeSolver {
        public BiCGstabRowDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected IterativeSolver createSolver(Vector x) {
            return new BiCGstab(x);
        }
    }

    private class GMRESRowDistIterativeSolver
    extends UnSymmRowDistIterativeSolver {
        public GMRESRowDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected IterativeSolver createSolver(Vector x) {
            return new GMRES(x);
        }
    }

    private abstract class UnSymmColumnDistIterativeSolver
    extends ColumnDistIterativeSolver {
        public UnSymmColumnDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected void populateMatrix(Matrix A) {
            int[] m = this.getColumnOwnerships(A);
            for (int i = 0; i < A.numRows(); ++i) {
                for (int j = m[this.rank]; j < m[this.rank + 1]; ++j) {
                    A.set(i, j, DistIterativeSolverTest.this.A_unsymm.get(i, j));
                }
            }
        }

        protected double getVectorEntry(int i) {
            return DistIterativeSolverTest.this.b_unsymm.get(i);
        }
    }

    private abstract class SymmColumnDistIterativeSolver
    extends ColumnDistIterativeSolver {
        public SymmColumnDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected void populateMatrix(Matrix A) {
            int[] m = this.getColumnOwnerships(A);
            for (int i = 0; i < A.numRows(); ++i) {
                for (int j = m[this.rank]; j < m[this.rank + 1]; ++j) {
                    A.set(i, j, DistIterativeSolverTest.this.A_symm.get(i, j));
                }
            }
        }

        protected double getVectorEntry(int i) {
            return DistIterativeSolverTest.this.b_symm.get(i);
        }
    }

    private abstract class UnSymmRowDistIterativeSolver
    extends RowDistIterativeSolver {
        public UnSymmRowDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected void populateMatrix(Matrix A) {
            int[] n = this.getRowOwnerships(A);
            for (int i = n[this.rank]; i < n[this.rank + 1]; ++i) {
                for (int j = 0; j < A.numColumns(); ++j) {
                    A.set(i, j, DistIterativeSolverTest.this.A_unsymm.get(i, j));
                }
            }
        }

        protected double getVectorEntry(int i) {
            return DistIterativeSolverTest.this.b_unsymm.get(i);
        }
    }

    private abstract class SymmRowDistIterativeSolver
    extends RowDistIterativeSolver {
        public SymmRowDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected void populateMatrix(Matrix A) {
            int[] n = this.getRowOwnerships(A);
            for (int i = n[this.rank]; i < n[this.rank + 1]; ++i) {
                for (int j = 0; j < A.numColumns(); ++j) {
                    A.set(i, j, DistIterativeSolverTest.this.A_symm.get(i, j));
                }
            }
        }

        protected double getVectorEntry(int i) {
            return DistIterativeSolverTest.this.b_symm.get(i);
        }
    }

    private abstract class ColumnDistIterativeSolver
    extends DistIterativeSolver {
        public ColumnDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected Matrix createMatrix(Communicator comm) {
            int n = DistIterativeSolverTest.this.x.size();
            DenseMatrix Al = new DenseMatrix(DistIterativeSolverTest.this.localLength[this.rank], DistIterativeSolverTest.this.localLength[this.rank]);
            DenseMatrix Bl = new DenseMatrix(n, DistIterativeSolverTest.this.localLength[this.rank]);
            return new DistColMatrix(n, n, comm, Al, Bl);
        }

        protected int[] getColumnOwnerships(Matrix A) {
            return ((DistColMatrix)A).getColumnOwnerships();
        }

        protected int[] getRowOwnerships(Matrix A) {
            return ((DistColMatrix)A).getRowOwnerships();
        }
    }

    private abstract class RowDistIterativeSolver
    extends DistIterativeSolver {
        public RowDistIterativeSolver(int rank, Vector.Norm norm) {
            super(rank, norm);
        }

        protected Matrix createMatrix(Communicator comm) {
            int n = DistIterativeSolverTest.this.x.size();
            DenseMatrix Al = new DenseMatrix(DistIterativeSolverTest.this.localLength[this.rank], DistIterativeSolverTest.this.localLength[this.rank]);
            DenseMatrix Bl = new DenseMatrix(DistIterativeSolverTest.this.localLength[this.rank], n);
            return new DistRowMatrix(n, n, comm, Al, Bl);
        }

        protected int[] getColumnOwnerships(Matrix A) {
            return ((DistRowMatrix)A).getColumnOwnerships();
        }

        protected int[] getRowOwnerships(Matrix A) {
            return ((DistRowMatrix)A).getRowOwnerships();
        }
    }

    /**
     * One rank of a distributed solve, executed on its own thread.
     *
     * <p>{@link #run()} builds this rank's part of the distributed matrix and
     * right-hand side, then solves with the solver returned by
     * {@link #createSolver}. Finally it writes the entries of the solution owned
     * by this rank into {@code output}. Subclasses choose the distribution
     * layout, the source data and the Krylov method.
     */
    private abstract class DistIterativeSolver implements Runnable {
        /** The rank this worker plays. */
        protected int rank;

        /** Norm type handed to the iteration monitor. */
        protected Vector.Norm norm;

        public DistIterativeSolver(int rank, Vector.Norm norm) {
            this.rank = rank;
            this.norm = norm;
        }

        /** Creates this rank's (not yet populated) distributed matrix on {@code comm}. */
        protected abstract Matrix createMatrix(Communicator comm);

        /** Row ownership bounds: rank {@code r} owns rows {@code [b[r], b[r + 1])}. */
        protected abstract int[] getRowOwnerships(Matrix A);

        /** Column ownership bounds: rank {@code r} owns columns {@code [b[r], b[r + 1])}. */
        protected abstract int[] getColumnOwnerships(Matrix A);

        /** Writes this rank's entries into {@code A}. */
        protected abstract void populateMatrix(Matrix A);

        /** Returns global right-hand-side entry {@code i}. */
        protected abstract double getVectorEntry(int i);

        /** Creates the iterative solver from a template vector. */
        protected abstract IterativeSolver createSolver(Vector template);

        @Override
        public void run() {
            Communicator comm = DistIterativeSolverTest.this.coll.createCommunicator(rank);
            Matrix A = createMatrix(comm);
            populateMatrix(A);

            DenseVector localPart = new DenseVector(DistIterativeSolverTest.this.localLength[rank]);
            DistVector rhs = new DistVector(DistIterativeSolverTest.this.x.size(), comm, localPart);
            // Copied before rhs is filled, so the initial guess is the still-unfilled vector.
            DistVector solution = rhs.copy();

            int[] rowBounds = getRowOwnerships(A);
            for (int row = rowBounds[rank]; row < rowBounds[rank + 1]; row++) {
                rhs.set(row, getVectorEntry(row));
            }

            IterativeSolver solver = createSolver(rhs);
            // NOTE(review): arguments presumably (maxIter, rtol, atol, dtol);
            // confirm against DefaultIterationMonitor.
            DefaultIterationMonitor monitor = new DefaultIterationMonitor(1000, 1.0E-50, 1.0E-12, 100000.0);
            monitor.setNormType(norm);
            solver.setIterationMonitor(monitor);
            try {
                solver.solve(A, rhs, solution);
            } catch (IterativeSolverNotConvergedException ignored) {
                // Non-convergence is tolerated. The iterate reached so far is
                // still published below, and compare() then judges it against
                // the exact solution with its element-wise tolerance.
            }

            for (int row = rowBounds[rank]; row < rowBounds[rank + 1]; row++) {
                DistIterativeSolverTest.this.output[row] = solution.get(row);
            }
        }
    }
}

