package ch.ethz.bicp;

import JSci.maths.ArrayMath;
import java.util.Random;
import org.apache.commons.math.DimensionMismatchException;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.BetaDistributionImpl;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.random.CorrelatedRandomVectorGenerator;
import org.apache.commons.math.random.GaussianRandomGenerator;
import org.apache.commons.math.random.JDKRandomGenerator;
import org.apache.commons.math.special.Gamma;

/* loaded from: input_file:ch/ethz/bicp/BetaBinomialModel.class */
/**
 * Beta-binomial model whose two hyper-parameters (alpha, beta) are sampled
 * with a random-walk Metropolis scheme.
 *
 * <p>The chain runs on the transformed scale
 * {@code (log(alpha / beta), log(alpha + beta))}; every retained state is
 * stored in {@link #get_omegas()} and, converted back to the original scale,
 * in {@link #get_alphas()} / {@link #get_betas()}.
 *
 * <p>The input arrays are stored by reference (no defensive copy) and the
 * getters expose the internal arrays directly. Instances are mutable.
 */
public class BetaBinomialModel {
    /** Variance of each component of the (uncorrelated) Gaussian random-walk proposal. */
    private static final double PROPOSAL_VARIANCE = 0.125d;

    /** Relative threshold handed to the covariance decomposition of the proposal generator. */
    private static final double DECOMPOSITION_TOLERANCE = 1.0E-12d;

    /** Scale applied to the absolute standard-normal draws used as the starting alpha and beta. */
    private static final double INITIAL_SCALE = 10.0d;

    /** Per-group counts k; presumably the number of successes (NOTE(review): confirm with callers). */
    double[] ks;
    /** Per-group counts n; presumably the number of trials (NOTE(review): confirm with callers). */
    double[] ns;
    /** Retained chain states on the transformed scale: row 0 = log(alpha/beta), row 1 = log(alpha+beta). */
    double[][] omegas;
    /** Retained alpha values (original scale), one per retained iteration. */
    double[] alphas;
    /** Retained beta values (original scale), one per retained iteration. */
    double[] betas;
    /** Per-group draws from Beta(alpha + k, beta + n - k), indexed [group][iteration]; filled only if enabled. */
    double[][] pijs;
    /** Draws from Beta(alpha, beta), one per retained iteration; filled only if enabled. */
    double[] pis;
    boolean includePijs = false;
    boolean includePis = false;

    public double[] get_ks() {
        return this.ks;
    }

    public double[] get_ns() {
        return this.ns;
    }

    public double[][] get_omegas() {
        return this.omegas;
    }

    public double[] get_alphas() {
        return this.alphas;
    }

    public double[] get_betas() {
        return this.betas;
    }

    public double[][] get_pijs() {
        return this.pijs;
    }

    public double[] get_pis() {
        return this.pis;
    }

    public boolean isIncludePijs() {
        return this.includePijs;
    }

    public boolean isIncludePis() {
        return this.includePis;
    }

    public void setIncludePijs(boolean z) {
        this.includePijs = z;
    }

    public void setIncludePis(boolean z) {
        this.includePis = z;
    }

    /**
     * Creates a model for the given data. Both arrays are kept by reference
     * and are expected to have the same length (this is not validated here).
     *
     * @param dArr  the {@code ks} array (one entry per group)
     * @param dArr2 the {@code ns} array (one entry per group)
     */
    public BetaBinomialModel(double[] dArr, double[] dArr2) {
        this.ks = dArr;
        this.ns = dArr2;
    }

    /**
     * Runs the Metropolis sampler and stores the retained draws in the
     * instance fields, replacing the result of any previous run.
     *
     * <p>The chain performs {@code burnIn + nSamples} iterations; only the
     * last {@code nSamples} are recorded. The random sources are not seeded,
     * so repeated calls give different results.
     *
     * <p>If the proposal generator cannot be constructed, the stack trace is
     * printed and the method returns early: {@code omegas} is then only
     * partially filled and {@code alphas}/{@code betas} stay all-zero. If a
     * Beta draw fails, the stack trace is printed and the corresponding
     * entry keeps its initial value of zero.
     *
     * @param nSamples number of iterations to retain (length of the result arrays)
     * @param burnIn   number of initial iterations to discard
     */
    public void sample(int nSamples, int burnIn) {
        Random random = new Random();
        BetaDistributionImpl betaDistribution = new BetaDistributionImpl(1.0d, 1.0d);
        GaussianRandomGenerator gaussian = new GaussianRandomGenerator(new JDKRandomGenerator());
        int groupCount = this.ks.length;

        this.omegas = new double[2][nSamples];
        this.alphas = new double[nSamples];
        this.betas = new double[nSamples];
        this.pijs = new double[groupCount][nSamples];
        this.pis = new double[nSamples];

        RealMatrix proposalCovariance = MatrixUtils.createRealMatrix(new double[][]{
            {PROPOSAL_VARIANCE, 0.0d},
            {0.0d, PROPOSAL_VARIANCE}
        });
        // Loop-invariant: depends only on the fixed proposal covariance.
        double small = DECOMPOSITION_TOLERANCE * proposalCovariance.getNorm();

        // Random starting point, drawn on the original scale and then transformed.
        double[] current = orig2trans(
            Math.abs(random.nextGaussian() * INITIAL_SCALE),
            Math.abs(random.nextGaussian() * INITIAL_SCALE));
        // Cache the log-density terms of the current state so they are only
        // re-evaluated when the state actually changes.
        double currentLogLikelihood = loglikelihood(current, this.ks, this.ns);
        double currentLogPrior = logprior(current);

        for (int iter = -burnIn; iter < nSamples; iter++) {
            double[] proposal;
            try {
                // The proposal is centred on the current state, hence a new generator per step.
                proposal = new CorrelatedRandomVectorGenerator(current, proposalCovariance, small, gaussian).nextVector();
            } catch (DimensionMismatchException e) {
                e.printStackTrace();
                return;
            } catch (NotPositiveDefiniteMatrixException e) {
                e.printStackTrace();
                return;
            }

            // Symmetric proposal: accept with probability min(1, posterior ratio).
            double proposalLogLikelihood = loglikelihood(proposal, this.ks, this.ns);
            double proposalLogPrior = logprior(proposal);
            double logRatio = ((proposalLogLikelihood + proposalLogPrior) - currentLogLikelihood) - currentLogPrior;
            if (random.nextDouble() < Math.min(1.0d, Math.exp(logRatio))) {
                current = proposal;
                currentLogLikelihood = proposalLogLikelihood;
                currentLogPrior = proposalLogPrior;
            }

            if (iter < 0) {
                continue; // burn-in: nothing is recorded
            }

            this.omegas[0][iter] = current[0];
            this.omegas[1][iter] = current[1];

            boolean wantPijs = isIncludePijs();
            boolean wantPis = isIncludePis();
            if (!wantPijs && !wantPis) {
                continue;
            }

            double[] alphaBeta = toOriginalScale(current);
            double alpha = alphaBeta[0];
            double beta = alphaBeta[1];

            if (wantPijs) {
                for (int group = 0; group < groupCount; group++) {
                    betaDistribution.setAlpha(alpha + this.ks[group]);
                    betaDistribution.setBeta((beta + this.ns[group]) - this.ks[group]);
                    try {
                        this.pijs[group][iter] = betaDistribution.sample();
                    } catch (MathException e) {
                        e.printStackTrace();
                    }
                }
            }
            if (wantPis) {
                betaDistribution.setAlpha(alpha);
                betaDistribution.setBeta(beta);
                try {
                    this.pis[iter] = betaDistribution.sample();
                } catch (MathException e) {
                    e.printStackTrace();
                }
            }
        }

        double[][] originalScale = trans2orig(this.omegas);
        this.alphas = originalScale[0];
        this.betas = originalScale[1];
    }

    /**
     * Log of the (unnormalised) prior at a point given on the transformed
     * scale: {@code -2.5 * log(alpha + beta + 1) + log(alpha) + log(beta)}.
     *
     * @param omega two-element point {@code (log(alpha/beta), log(alpha+beta))}
     * @return the log prior value
     */
    private double logprior(double[] omega) {
        double[] alphaBeta = toOriginalScale(omega);
        double alpha = alphaBeta[0];
        double beta = alphaBeta[1];
        return ((-2.5d) * Math.log(alpha + beta + 1.0d)) + Math.log(alpha) + Math.log(beta);
    }

    /**
     * Log-likelihood of the data at a point given on the transformed scale.
     * Per group it adds up
     * {@code lnG(k + alpha) + lnG(alpha + beta) + lnG(n - k + beta)
     * - lnG(alpha) - lnG(beta) - lnG(n + alpha + beta)}
     * (with lnG the log-gamma function) and reduces the per-group terms with
     * {@code Stats.sum}.
     *
     * @param omega two-element point {@code (log(alpha/beta), log(alpha+beta))}
     * @param k     per-group {@code ks} values
     * @param n     per-group {@code ns} values
     * @return the value returned by {@code Stats.sum} over the per-group terms
     */
    private double loglikelihood(double[] omega, double[] k, double[] n) {
        double[] alphaBeta = toOriginalScale(omega);
        double alpha = alphaBeta[0];
        double beta = alphaBeta[1];

        double[] terms = ArrayMath.add(lngamma(ArrayMath.add(k, alpha)), Gamma.logGamma(alpha + beta));
        terms = ArrayMath.add(terms, lngamma(ArrayMath.add(ArrayMath.subtract(n, k), beta)));
        terms = ArrayMath.add(terms, (-Gamma.logGamma(alpha)) - Gamma.logGamma(beta));
        terms = ArrayMath.subtract(terms, lngamma(ArrayMath.add(n, alpha + beta)));
        return Stats.sum(terms);
    }

    /**
     * Maps (alpha, beta) to the transformed scale.
     *
     * @param d  alpha
     * @param d2 beta
     * @return {@code {log(alpha / beta), log(alpha + beta)}}
     */
    private double[] orig2trans(double d, double d2) {
        return new double[]{Math.log(d / d2), Math.log(d + d2)};
    }

    /**
     * Converts a single transformed point back to the original scale by
     * delegating to {@link #trans2orig(double[][])}.
     *
     * @param omega two-element point {@code (log(alpha/beta), log(alpha+beta))}
     * @return {@code {alpha, beta}}
     */
    private double[] toOriginalScale(double[] omega) {
        double[][] converted = trans2orig(new double[][]{{omega[0]}, {omega[1]}});
        return new double[]{converted[0][0], converted[1][0]};
    }

    /**
     * Inverse of {@link #orig2trans(double, double)}, applied column-wise.
     * With {@code u = dArr[0]} and {@code v = dArr[1]} it computes
     * {@code alpha = e^v * e^u / (e^u + 1)} and {@code beta = e^v / (e^u + 1)}
     * (assuming the {@code Stats} helpers operate element-wise).
     *
     * @param dArr row 0 holds log(alpha/beta) values, row 1 holds log(alpha+beta) values
     * @return a two-row array: row 0 = alphas, row 1 = betas
     */
    private double[][] trans2orig(double[][] dArr) {
        double[] expLogRatio = Stats.exp(dArr[0]);
        double[] expLogSum = Stats.exp(dArr[1]);
        double[] alpha = Stats.div(Stats.mult(expLogSum, expLogRatio), ArrayMath.add(expLogRatio, 1.0d));
        double[] beta = Stats.div(expLogSum, ArrayMath.add(expLogRatio, 1.0d));
        return new double[][]{alpha, beta};
    }

    /**
     * Element-wise log-gamma.
     *
     * @param values input values
     * @return a new array with {@code Gamma.logGamma} applied to each element
     */
    private double[] lngamma(double[] values) {
        double[] result = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            result[i] = Gamma.logGamma(values[i]);
        }
        return result;
    }
}
