/*
 * Decompiled with CFR 0.152.
 */
package jdplus.toolkit.base.core.ssf.arima;

import jdplus.toolkit.base.api.data.DoubleSeqCursor;
import jdplus.toolkit.base.core.arima.ArimaModel;
import jdplus.toolkit.base.core.arima.IArimaModel;
import jdplus.toolkit.base.core.arima.StationaryTransformation;
import jdplus.toolkit.base.core.data.DataBlock;
import jdplus.toolkit.base.core.data.DataBlockIterator;
import jdplus.toolkit.base.core.data.DataWindow;
import jdplus.toolkit.base.core.math.linearfilters.BackFilter;
import jdplus.toolkit.base.core.math.matrices.FastMatrix;
import jdplus.toolkit.base.core.math.matrices.SymmetricMatrix;
import jdplus.toolkit.base.core.math.polynomials.Polynomial;
import jdplus.toolkit.base.core.math.polynomials.RationalFunction;
import jdplus.toolkit.base.core.math.polynomials.UnitRoots;
import jdplus.toolkit.base.core.ssf.ISsfDynamics;
import jdplus.toolkit.base.core.ssf.ISsfInitialization;
import jdplus.toolkit.base.core.ssf.ISsfLoading;
import jdplus.toolkit.base.core.ssf.SsfException;
import jdplus.toolkit.base.core.ssf.StateComponent;
import jdplus.toolkit.base.core.ssf.UpdateInformation;
import jdplus.toolkit.base.core.ssf.arima.Rw;
import jdplus.toolkit.base.core.ssf.basic.Loading;
import jdplus.toolkit.base.core.ssf.ckms.CkmsDiffuseInitializer;
import jdplus.toolkit.base.core.ssf.ckms.CkmsFilter;
import jdplus.toolkit.base.core.ssf.ckms.CkmsState;
import jdplus.toolkit.base.core.ssf.univariate.ISsf;
import jdplus.toolkit.base.core.ssf.univariate.ISsfData;
import jdplus.toolkit.base.core.ssf.univariate.OrdinaryFilter;
import jdplus.toolkit.base.core.ssf.univariate.Ssf;
import lombok.Generated;

public final class SsfArima {
    /**
     * Returns the larger of the model's AR order and its MA order plus one.
     *
     * @param arima the model whose AR and MA orders are read
     * @return {@code max(arOrder, maOrder + 1)}
     */
    public static int dim(IArimaModel arima) {
        int arOrder = arima.getArOrder();
        int maOrderPlusOne = arima.getMaOrder() + 1;
        return arOrder >= maOrderPlusOne ? arOrder : maOrderPlusOne;
    }

    /**
     * Returns the loading built by {@code Loading.fromPosition} for position 0.
     *
     * @return the loading on the first position
     */
    public static ISsfLoading defaultLoading() {
        final int firstPosition = 0;
        return Loading.fromPosition(firstPosition);
    }

    /**
     * Builds a state component for a differencing of order {@code d}.
     * <ul>
     * <li>{@code d <= 0}: returns {@code null};</li>
     * <li>{@code d == 1}: returns {@code Rw.of(var, false)};</li>
     * <li>{@code d > 1}: builds an {@link ArimaModel} whose middle filter is
     * {@code UnitRoots.D(1, d)} and whose other filters are
     * {@code BackFilter.ONE}, then wraps it as a non-stationary component.</li>
     * </ul>
     *
     * @param d   the differencing order
     * @param var the variance passed to the underlying model
     * @return the state component, or {@code null} when {@code d <= 0}
     */
    public static StateComponent differencingSsf(int d, double var) {
        if (d > 1) {
            BackFilter differencing = new BackFilter(UnitRoots.D(1, d));
            ArimaModel model = new ArimaModel(BackFilter.ONE, differencing, BackFilter.ONE, var);
            return SsfArima.ofNonStationary(model);
        }
        return d == 1 ? Rw.of(var, false) : null;
    }

    /**
     * Builds the state component of the given model, choosing the stationary
     * or the non-stationary construction according to
     * {@code arima.isStationary()}.
     *
     * @param arima the model to represent
     * @return the corresponding state component
     */
    public static StateComponent stateComponent(IArimaModel arima) {
        return arima.isStationary()
                ? SsfArima.ofStationary(arima)
                : SsfArima.ofNonStationary(arima);
    }

    /**
     * Builds the state space form of the given model by combining its state
     * component with the default loading.
     *
     * @param arima the model to represent
     * @return the state space form
     * @throws SsfException if the innovation variance of the model is exactly 0
     */
    public static Ssf ssf(IArimaModel arima) {
        if (arima.getInnovationVariance() == 0.0) {
            throw new SsfException("Invalid stochastic model");
        }
        StateComponent component = SsfArima.stateComponent(arima);
        ISsfLoading loading = SsfArima.defaultLoading();
        return Ssf.of(component, loading);
    }

    /**
     * Returns a fast-filter initializer for the given model. Each time the
     * initializer is invoked it checks {@code arima.isStationary()} and
     * delegates to the stationary or to the diffuse initialization.
     *
     * @param arima the model captured by the returned initializer
     * @return the initializer
     */
    public static CkmsFilter.IFastFilterInitializer fastInitializer(IArimaModel arima) {
        return (state, upd, ssf, data) -> arima.isStationary()
                ? SsfArima.stInitialize(state, upd, arima, ssf, data)
                : SsfArima.dInitialize(state, upd, arima, ssf, data);
    }

    /**
     * Initializes the fast filter for a stationary model.
     * <p>
     * The first {@code ssf.getStateDim()} auto-covariances of the model are
     * copied into {@code upd.M()}; {@code state.l()} receives a copy of that
     * block, transformed in place by {@code ssf.dynamics().TX(0, ...)}; the
     * variance of {@code upd} is set to the first auto-covariance.
     *
     * @param state the filter state whose {@code l()} block is written
     * @param upd   the update information whose {@code M()} block and variance are written
     * @param arima the model providing the auto-covariance function
     * @param ssf   the state space form providing the dimension and the dynamics
     * @param data  not used by this method
     * @return always 0
     */
    private static int stInitialize(CkmsState state, UpdateInformation upd, IArimaModel arima, ISsf ssf, ISsfData data) {
        int stateDim = ssf.getStateDim();
        double[] acgf = arima.getAutoCovarianceFunction().values(stateDim);
        DataBlock m = upd.M();
        m.copyFrom(acgf, 0);
        DataBlock l = state.l();
        l.copy(m);
        ssf.dynamics().TX(0, l);
        upd.setVariance(acgf[0]);
        return 0;
    }

    /**
     * Initializes the fast filter for a non-stationary model by wrapping the
     * diffuse initializer of the model in a {@link CkmsDiffuseInitializer} and
     * delegating to its {@code initializeFilter}.
     *
     * @param state the filter state, forwarded unchanged
     * @param upd   the update information, forwarded unchanged
     * @param arima the model used to build the diffuse initializer
     * @param ssf   the state space form, forwarded unchanged
     * @param data  the observations, forwarded unchanged
     * @return the value returned by {@code initializeFilter}
     */
    private static int dInitialize(CkmsState state, UpdateInformation upd, IArimaModel arima, ISsf ssf, ISsfData data) {
        OrdinaryFilter.Initializer diffuse = SsfArima.diffuseInitializer(arima);
        return new CkmsDiffuseInitializer(diffuse).initializeFilter(state, upd, ssf, data);
    }

    /**
     * Returns an initializer of the ordinary filter for a non-stationary model.
     * <p>
     * The returned initializer expects {@code ssf.initialization()} to be an
     * {@link ArimaInitialization} (it is cast without check). With {@code nd}
     * the diffuse dimension and {@code nr} the state dimension, it:
     * <ol>
     * <li>builds an {@code (nr + nd) x nd} matrix whose first {@code nd} rows
     * form an identity and whose following rows are obtained by applying the
     * recursion defined by the non-stationary AR coefficients of
     * {@code arima};</li>
     * <li>sets each element of {@code state.a()} to the combination of the
     * first {@code nd} observations weighted by rows {@code nd..nd+nr-1} of
     * that matrix;</li>
     * <li>fills {@code state.P()} through {@code SymmetricMatrix.XSXt} from the
     * matrices produced by {@code ArimaInitialization.stVar} and
     * {@code ArimaInitialization.sigma};</li>
     * <li>returns {@code nd}.</li>
     * </ol>
     *
     * @param arima the model providing the non-stationary AR polynomial
     * @return the initializer
     */
    private static OrdinaryFilter.Initializer diffuseInitializer(IArimaModel arima) {
        return (state, ssf, data) -> {
            ArimaInitialization init = (ArimaInitialization)ssf.initialization();
            int stateDim = ssf.getStateDim();
            int diffuseDim = init.getDiffuseDim();
            int rowCount = stateDim + diffuseDim;
            FastMatrix A = FastMatrix.make(rowCount, diffuseDim);
            double[] delta = arima.getNonStationaryAr().asPolynomial().toArray();
            for (int col = 0; col < diffuseDim; ++col) {
                // Unit start value for this column, then extend it with the recursion.
                A.set(col, col, 1.0);
                for (int row = diffuseDim; row < rowCount; ++row) {
                    double acc = 0.0;
                    for (int lag = 1; lag <= diffuseDim; ++lag) {
                        acc -= delta[lag] * A.get(row - lag, col);
                    }
                    A.set(row, col, acc);
                }
            }
            // Initial state: rows below the identity applied to the first observations.
            for (int pos = 0; pos < stateDim; ++pos) {
                double acc = 0.0;
                for (int obs = 0; obs < diffuseDim; ++obs) {
                    acc += A.get(pos + diffuseDim, obs) * data.get(obs);
                }
                state.a().set(pos, acc);
            }
            FastMatrix stationaryVar = FastMatrix.square(stateDim);
            ArimaInitialization.stVar(stationaryVar, init.stpsi, init.stacgf, init.data.var);
            FastMatrix lambda = FastMatrix.square(stateDim);
            ArimaInitialization.sigma(lambda, init.dif);
            SymmetricMatrix.XSXt(stationaryVar, lambda, state.P());
            return diffuseDim;
        };
    }

    /**
     * Builds the state component of a model using an
     * {@link ArmaInitialization} and dynamics sharing the same
     * {@link ArimaData}.
     *
     * @param arima the model to represent
     * @return the state component
     */
    private static StateComponent ofStationary(IArimaModel arima) {
        ArmaInitialization init = new ArmaInitialization(arima);
        return new StateComponent(init, new ArimaDynamics(init.data));
    }

    /**
     * Builds the state component of a model using an
     * {@link ArimaInitialization} and dynamics sharing the same
     * {@link ArimaData}.
     *
     * @param arima the model to represent
     * @return the state component
     */
    private static StateComponent ofNonStationary(IArimaModel arima) {
        ArimaInitialization init = new ArimaInitialization(arima);
        return new StateComponent(init, new ArimaDynamics(init.data));
    }

    @Generated
    private SsfArima() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }

    /**
     * Non-diffuse initialization: the initial covariance is computed once at
     * construction time and copied on request; the initial state and the
     * diffuse parts are left untouched.
     */
    static class ArmaInitialization
    implements ISsfInitialization {
        /** Model data shared with the dynamics. */
        final ArimaData data;
        /** First {@code data.dim} auto-covariances of the model. */
        private final DataBlock acgf;
        /** Initial covariance matrix, computed by {@link #p0}. */
        private final FastMatrix P0;

        /**
         * @param arima the model providing the auto-covariance function and
         * the data stored in {@link ArimaData}
         */
        ArmaInitialization(IArimaModel arima) {
            ArimaData armaData = new ArimaData(arima);
            this.data = armaData;
            this.acgf = DataBlock.of(arima.getAutoCovarianceFunction().values(armaData.dim));
            this.P0 = ArmaInitialization.p0(armaData.var, this.acgf, armaData.psi);
        }

        /**
         * Returns the result of {@code SymmetricMatrix.xxt(psi)} multiplied by
         * {@code var}.
         *
         * @param var the scaling factor
         * @param psi the block passed to {@code xxt}
         * @return a new, scaled matrix
         */
        static FastMatrix v(double var, DataBlock psi) {
            FastMatrix scaled = SymmetricMatrix.xxt(psi);
            scaled.mul(var);
            return scaled;
        }

        /**
         * Builds a square matrix of size {@code acgf.length()}.
         * <p>
         * The first column is {@code acgf}. Each following lower-triangular
         * cell {@code (r, c)} with {@code c >= 1} is set to
         * {@code P(r-1, c-1) - psi(r-1) * psi(c-1) * var}. The upper part is
         * then filled by {@code SymmetricMatrix.fromLower}.
         *
         * @param var  the variance used in the recursion
         * @param acgf the first column of the result
         * @param psi  the weights used in the recursion
         * @return the new matrix
         */
        private static FastMatrix p0(double var, DataBlock acgf, DataBlock psi) {
            int n = acgf.length();
            FastMatrix cov = FastMatrix.square(n);
            cov.column(0).copy(acgf);
            for (int row = 1; row < n; ++row) {
                double weight = psi.get(row - 1);
                cov.set(row, row, cov.get(row - 1, row - 1) - weight * weight * var);
                for (int col = 1; col < row; ++col) {
                    cov.set(row, col, cov.get(row - 1, col - 1) - weight * psi.get(col - 1) * var);
                }
            }
            SymmetricMatrix.fromLower(cov);
            return cov;
        }

        /** @return always {@code false} */
        @Override
        public boolean isDiffuse() {
            return false;
        }

        /** @return always 0 */
        @Override
        public int getDiffuseDim() {
            return 0;
        }

        /** Leaves {@code b} untouched. */
        @Override
        public void diffuseConstraints(FastMatrix b) {
            // nothing to write
        }

        /** Leaves {@code a0} untouched. */
        @Override
        public void a0(DataBlock a0) {
            // nothing to write
        }

        /** Copies the precomputed initial covariance into {@code pf0}. */
        @Override
        public void Pf0(FastMatrix pf0) {
            pf0.copy(this.P0);
        }

        /** Leaves {@code pi0} untouched. */
        @Override
        public void Pi0(FastMatrix pi0) {
            // nothing to write
        }

        /** @return the dimension stored in the model data */
        @Override
        public int getStateDim() {
            return this.data.dim;
        }
    }

    /**
     * Time-invariant dynamics driven by an {@link ArimaData}: the transition
     * shifts the state by one position and fills the last position from the
     * AR coefficients {@code phi}; the single innovation is loaded through
     * {@code psi} scaled by {@code se}.
     */
    static class ArimaDynamics
    implements ISsfDynamics {
        /** Model data (dimension, AR coefficients, psi block, variance). */
        private final ArimaData data;
        /** Work buffer of length {@code data.dim}, used by {@link #TVT}. */
        private final DataBlock z;
        /** Innovation covariance, built by {@code ArmaInitialization.v}. */
        private final FastMatrix V;

        /**
         * @param data the model data; kept by reference
         */
        public ArimaDynamics(ArimaData data) {
            this.data = data;
            this.z = DataBlock.make(data.dim);
            this.V = ArmaInitialization.v(data.var, data.psi);
        }

        /** Fills {@code tr} with the transition matrix; {@code pos} is ignored. */
        @Override
        public void T(int pos, FastMatrix tr) {
            this.T(tr);
        }

        /**
         * Fills {@code tr} with the transition matrix: zeros everywhere, ones
         * on the first super-diagonal, and {@code -phi[i]} in the last row at
         * column {@code dim - i} for {@code i >= 1}.
         *
         * @param tr the matrix to overwrite
         */
        public void T(FastMatrix tr) {
            final int n = this.data.dim;
            final double[] phi = this.data.phi;
            tr.set(0.0);
            for (int col = 1; col < n; ++col) {
                tr.set(col - 1, col, 1.0);
            }
            final int lastRow = n - 1;
            for (int lag = 1; lag < phi.length; ++lag) {
                tr.set(lastRow, n - lag, -phi[lag]);
            }
        }

        /**
         * Transforms {@code vm} in place. The matrix is shifted with
         * {@code upLeftShift(1)} and its last row and column are overwritten,
         * with zeros when {@code phi} has a single coefficient, otherwise with
         * the work buffer computed from the columns of {@code vm} (taken in
         * reverse order) and passed through {@link #TX}.
         */
        @Override
        public void TVT(int pos, FastMatrix vm) {
            final double[] phi = this.data.phi;
            final int last = this.data.dim - 1;
            if (phi.length == 1) {
                vm.upLeftShift(1);
                vm.column(last).set(0.0);
                vm.row(last).set(0.0);
                return;
            }
            // The combination must be computed before vm is shifted.
            this.z.set(0.0);
            DataBlockIterator reversedColumns = vm.reverseColumnsIterator();
            for (int lag = 1; lag < phi.length; ++lag) {
                this.z.addAY(-phi[lag], reversedColumns.next());
            }
            this.TX(pos, this.z);
            vm.upLeftShift(1);
            vm.column(last).copy(this.z);
            vm.row(last).copy(this.z);
        }

        /**
         * Transforms {@code x} in place: the new last element is minus the sum
         * of {@code phi[i]} times the elements of {@code x} read from the end
         * (computed before the shift); the other elements come from
         * {@code x.bshift(1)}.
         */
        @Override
        public void TX(int pos, DataBlock x) {
            final double[] phi = this.data.phi;
            double tail = 0.0;
            if (phi.length > 1) {
                DoubleSeqCursor cursor = x.reverseReader();
                int lag = 1;
                while (lag < phi.length) {
                    tail -= phi[lag] * cursor.getAndNext();
                    ++lag;
                }
            }
            x.bshift(1);
            x.set(this.data.dim - 1, tail);
        }

        /**
         * Transforms {@code x} in place: after {@code x.fshift(1)} the first
         * element is set to 0 and, when the former last element is non-zero,
         * minus that element times {@code phi[i]} is added at position
         * {@code dim - i} for each non-zero {@code phi[i]}, {@code i >= 1}.
         */
        @Override
        public void XT(int pos, DataBlock x) {
            final double[] phi = this.data.phi;
            final int last = this.data.dim - 1;
            double negatedLast = -x.get(last);
            x.fshift(1);
            x.set(0, 0.0);
            if (negatedLast == 0.0) {
                return;
            }
            for (int lag = 1, target = last; lag < phi.length; ++lag, --target) {
                if (phi[lag] != 0.0) {
                    x.add(target, negatedLast * phi[lag]);
                }
            }
        }

        /** @return always {@code true} */
        @Override
        public boolean isTimeInvariant() {
            return true;
        }

        /** @return always {@code true} */
        @Override
        public boolean areInnovationsTimeInvariant() {
            return true;
        }

        /** @return always 1 */
        @Override
        public int getInnovationsDim() {
            return 1;
        }

        /** Copies the innovation covariance into {@code qm}; {@code pos} is ignored. */
        @Override
        public void V(int pos, FastMatrix qm) {
            qm.copy(this.V);
        }

        /**
         * Copies {@code psi} into the first column of {@code sm}, then
         * multiplies the whole matrix by {@code se} unless {@code se} is
         * exactly 1.
         */
        @Override
        public void S(int pos, FastMatrix sm) {
            sm.column(0).copy(this.data.psi);
            double scale = this.data.se;
            if (scale == 1.0) {
                return;
            }
            sm.mul(scale);
        }

        /** @return always {@code true} */
        @Override
        public boolean hasInnovations(int pos) {
            return true;
        }

        /** Adds the innovation covariance to {@code p}; {@code pos} is ignored. */
        @Override
        public void addV(int pos, FastMatrix p) {
            p.add(this.V);
        }

        /** Stores {@code x.dot(psi) * se} in the first element of {@code sx}. */
        @Override
        public void XS(int pos, DataBlock x, DataBlock sx) {
            sx.set(0, x.dot(this.data.psi) * this.data.se);
        }

        /** Adds {@code u.get(0) * se} times {@code psi} to {@code x}. */
        @Override
        public void addSU(int pos, DataBlock x, DataBlock u) {
            double scaledShock = u.get(0) * this.data.se;
            x.addAY(scaledShock, this.data.psi);
        }
    }

    /**
     * Quantities extracted once from a model and shared by the initialization
     * and the dynamics.
     */
    static class ArimaData {
        /** {@code max(degree(ar), degree(ma) + 1)}. */
        final int dim;
        /** Innovation variance of the model. */
        final double var;
        /** Square root of {@link #var}. */
        final double se;
        /** Coefficients of the AR polynomial, as returned by {@code toArray()}. */
        final double[] phi;
        /** First {@link #dim} coefficients of {@code RationalFunction.of(ma, ar)}. */
        final DataBlock psi;

        /**
         * @param arima the model providing the innovation variance and the AR
         * and MA polynomials
         */
        ArimaData(IArimaModel arima) {
            double innovationVariance = arima.getInnovationVariance();
            Polynomial arPolynomial = arima.getAr().asPolynomial();
            Polynomial maPolynomial = arima.getMa().asPolynomial();
            double[] arCoefficients = arPolynomial.toArray();
            int stateDim = Math.max(arPolynomial.degree(), maPolynomial.degree() + 1);
            double[] psiWeights = RationalFunction.of(maPolynomial, arPolynomial).coefficients(stateDim);

            this.var = innovationVariance;
            this.se = Math.sqrt(innovationVariance);
            this.phi = arCoefficients;
            this.dim = stateDim;
            this.psi = DataBlock.of(psiWeights);
        }
    }

    /**
     * Initialization for a model that may contain unit roots. The diffuse
     * dimension is the degree of the unit-roots polynomial given by the
     * stationary transformation of the model; the non-diffuse covariance is
     * computed once at construction time.
     */
    static class ArimaInitialization
    implements ISsfInitialization {
        /** Model data shared with the dynamics. */
        final ArimaData data;
        /** Coefficients of the unit-roots polynomial, as returned by {@code toArray()}. */
        final double[] dif;
        /** First {@code data.dim} coefficients of the psi-weights of the stationary model. */
        private final DataBlock stpsi;
        /** First {@code data.dim} auto-covariances of the stationary model. */
        private final DataBlock stacgf;
        /** Non-diffuse initial covariance, copied by {@link #Pf0}. */
        private final FastMatrix P0;

        /**
         * @param arima the model; its stationary transformation provides the
         * unit roots and the stationary model
         */
        ArimaInitialization(IArimaModel arima) {
            this.data = new ArimaData(arima);
            int n = this.data.dim;
            StationaryTransformation starima = arima.stationaryTransformation();
            this.dif = starima.getUnitRoots().asPolynomial().toArray();
            IArimaModel stationary = (IArimaModel)starima.getStationaryModel();
            this.stacgf = DataBlock.of(stationary.getAutoCovarianceFunction().values(n));
            RationalFunction stationaryPsi = stationary.getPsiWeights().getRationalFunction();
            this.stpsi = DataBlock.of(stationaryPsi.coefficients(n));
            FastMatrix stationaryVar = ArmaInitialization.p0(this.data.var, this.stacgf, this.stpsi);
            FastMatrix lambda = FastMatrix.square(n);
            ArimaInitialization.sigma(lambda, this.dif);
            this.P0 = SymmetricMatrix.XSXt(stationaryVar, lambda);
        }

        /**
         * Fills {@code b} with the diffuse constraints defined by {@code d}.
         * <p>
         * With {@code nd = d.length - 1}: nothing is written when
         * {@code nd == 0}. Otherwise the diagonal of {@code b} is set to 1
         * and, when {@code b} has more than {@code nd} rows, each of the first
         * {@code nd} columns is extended below row {@code nd - 1}: every new
         * element is minus the dot product of the {@code nd} preceding
         * elements of the column with {@code d[nd], ..., d[1]}.
         *
         * @param b the matrix to fill
         * @param d the coefficients of the differencing polynomial
         */
        static void B0(FastMatrix b, double[] d) {
            int diffuseDim = d.length - 1;
            if (diffuseDim == 0) {
                return;
            }
            int rowCount = b.getRowsCount();
            b.diagonal().set(1.0);
            if (diffuseDim == rowCount) {
                return;
            }
            // d read backwards, from its last coefficient down to d[1].
            DataBlock reversed = DataBlock.of(d, d.length - 1, 0, -1);
            for (int col = 0; col < diffuseDim; ++col) {
                DataBlock column = b.column(col);
                // Window over the diffuseDim elements preceding the one being computed.
                DataWindow window = column.window(0, diffuseDim);
                column.set(diffuseDim, -window.get().dot(reversed));
                for (int row = diffuseDim + 1; row < rowCount; ++row) {
                    column.set(row, -window.move(1).dot(reversed));
                }
            }
        }

        /**
         * Writes the lower triangle of {@code X} so that cell
         * {@code (j, k)}, {@code k <= j}, holds coefficient {@code j - k} of
         * {@code RationalFunction.of(Polynomial.ONE, Polynomial.of(dif))}.
         * Cells above the diagonal are not written.
         *
         * @param X   the square matrix to fill
         * @param dif the coefficients of the denominator polynomial
         */
        static void sigma(FastMatrix X, double[] dif) {
            int size = X.getRowsCount();
            double[] weights = RationalFunction.of(Polynomial.ONE, Polynomial.of(dif)).coefficients(size);
            // Fill one sub-diagonal at a time: all its cells share the same weight.
            for (int lag = 0; lag < size; ++lag) {
                double weight = weights[lag];
                for (int row = lag; row < size; ++row) {
                    X.set(row, row - lag, weight);
                }
            }
        }

        /**
         * Fills {@code stV} in place.
         * <p>
         * The first column is {@code stacgf}. Each following lower-triangular
         * cell {@code (r, c)} with {@code c >= 1} is set to
         * {@code stV(r-1, c-1) - stpsi(r-1) * stpsi(c-1) * var}. The upper
         * part is then filled by {@code SymmetricMatrix.fromLower}.
         *
         * @param stV    the square matrix to fill
         * @param stpsi  the weights used in the recursion
         * @param stacgf the first column of the result
         * @param var    the variance used in the recursion
         */
        static void stVar(FastMatrix stV, DataBlock stpsi, DataBlock stacgf, double var) {
            int size = stV.getRowsCount();
            stV.column(0).copy(stacgf);
            for (int row = 1; row < size; ++row) {
                double weight = stpsi.get(row - 1);
                stV.set(row, row, stV.get(row - 1, row - 1) - weight * weight * var);
                for (int col = 1; col < row; ++col) {
                    stV.set(row, col, stV.get(row - 1, col - 1) - weight * stpsi.get(col - 1) * var);
                }
            }
            SymmetricMatrix.fromLower(stV);
        }

        /** @return {@code true} when the unit-roots polynomial has more than one coefficient */
        @Override
        public boolean isDiffuse() {
            return this.dif.length >= 2;
        }

        /** @return the number of coefficients of the unit-roots polynomial minus one */
        @Override
        public int getDiffuseDim() {
            return this.dif.length - 1;
        }

        /** Fills {@code b} through {@link #B0}; nothing is written when there is no unit root. */
        @Override
        public void diffuseConstraints(FastMatrix b) {
            if (this.dif.length - 1 != 0) {
                ArimaInitialization.B0(b, this.dif);
            }
        }

        /** Leaves {@code a0} untouched. */
        @Override
        public void a0(DataBlock a0) {
            // nothing to write
        }

        /** Copies the precomputed non-diffuse covariance into {@code pf0}. */
        @Override
        public void Pf0(FastMatrix pf0) {
            pf0.copy(this.P0);
        }

        /**
         * Builds the constraints matrix with {@link #B0} and passes it with
         * {@code pi0} to {@code SymmetricMatrix.XXt}.
         */
        @Override
        public void Pi0(FastMatrix pi0) {
            int diffuseDim = this.dif.length - 1;
            FastMatrix constraints = FastMatrix.make(this.data.dim, diffuseDim);
            ArimaInitialization.B0(constraints, this.dif);
            SymmetricMatrix.XXt(constraints, pi0);
        }

        /** @return the dimension stored in the model data */
        @Override
        public int getStateDim() {
            return this.data.dim;
        }
    }
}

