package cc.mallet.optimize.tests;

import cc.mallet.optimize.Optimizable;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.util.Random;
import java.util.logging.Logger;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * Unit tests and reusable checking utilities for {@link Optimizable} implementations.
 *
 * <p>The static {@code test*} helpers check that parameters round-trip through the get/set
 * accessors and compare an optimizable's analytic gradient against a forward-difference estimate
 * computed from {@link Optimizable.ByGradientValue#getValue()}. Diagnostics are printed to
 * {@code System.out} and to the class logger.
 */
public class TestOptimizable extends TestCase {
    private static final Logger logger = MalletLogger.getLogger(TestOptimizable.class.getName());

    /** Scale applied to the gradient to take the second step in {@link #testValueAndGradient}. */
    private static final double GRADIENT_STEP = -1.0E-4d;

    /**
     * Roughly how many components {@link #testValueAndGradientCurrentParameters} checks; a
     * non-positive value means every component is checked.
     */
    private static int numComponents = -1;

    /** The one-variable polynomial f(x) = 3x^2 - 5x + 2 together with its exact derivative. */
    static class SimplePoly implements Optimizable.ByGradientValue {
        double[] params = new double[1];

        SimplePoly() {
        }

        @Override
        public void getParameters(double[] buffer) {
            buffer[0] = params[0];
        }

        @Override
        public int getNumParameters() {
            return 1;
        }

        @Override
        public double getParameter(int index) {
            return params[index];
        }

        @Override
        public void setParameters(double[] newParams) {
            params[0] = newParams[0];
        }

        @Override
        public void setParameter(int index, double value) {
            params[index] = value;
        }

        @Override
        public double getValue() {
            double x = params[0];
            return 3.0d * x * x - 5.0d * x + 2.0d;
        }

        /** Writes f'(x) = 6x - 5 (previously 3x - 5, which is not the derivative of getValue). */
        @Override
        public void getValueGradient(double[] buffer) {
            buffer[0] = 6.0d * params[0] - 5.0d;
        }
    }

    /**
     * Same value as {@link SimplePoly} but with a deliberately wrong gradient (3x). At x = 0 the
     * gradient is zero, so the gradient check produces a NaN angle and must reject it.
     */
    static class WrongSimplePoly extends SimplePoly {
        WrongSimplePoly() {
        }

        @Override
        public void getValueGradient(double[] buffer) {
            buffer[0] = 3.0d * params[0];
        }
    }

    public TestOptimizable(String name) {
        super(name);
    }

    /**
     * Sets how many components {@link #testValueAndGradientCurrentParameters} should roughly check.
     *
     * @param count target number of components; non-positive means check all of them
     */
    public static void setNumComponents(int count) {
        numComponents = count;
    }

    /**
     * Writes 0, 1, 2, ... into the parameters, reads them back, and asserts they match exactly.
     *
     * @param optimizable the object under test; its parameters are left set to 0, 1, 2, ...
     * @return always {@code true} (failures are reported through JUnit assertions)
     */
    public static boolean testGetSetParameters(Optimizable optimizable) {
        System.out.println("TestMaximizable testGetSetParameters");
        double[] values = new double[optimizable.getNumParameters()];
        for (int i = 0; i < values.length; i++) {
            values[i] = i;
        }
        optimizable.setParameters(values);
        // Clear the buffer so a no-op getParameters cannot pass by accident.
        MatrixOps.setAll(values, 0.0d);
        optimizable.getParameters(values);
        for (int i = 0; i < values.length; i++) {
            assertEquals("parameter " + i, (double) i, values[i], 0.0d);
        }
        return true;
    }

    /**
     * Checks the analytic gradient along one direction using a forward difference.
     *
     * <p>The direction is L1-normalized, and the optimizable is stepped along it by
     * {@code epsilon = 0.1 / |gradient|_1}. The resulting value slope is compared with the
     * directional derivative {@code gradient . direction}. The optimizable's parameters are
     * restored before any check runs.
     *
     * @param optimizable the function under test; its current parameters are the base point
     * @param direction direction to probe; one entry per parameter; not modified
     * @return the absolute difference between the empirical slope and the directional derivative
     * @throws AssertionError (only when assertions are enabled) if the lengths differ, the slope
     *     is NaN, or the difference is not below {@code 1e-5 * |directional derivative|}
     */
    public static double testValueAndGradientInDirection(Optimizable.ByGradientValue optimizable, double[] direction) {
        int numParameters = optimizable.getNumParameters();
        assert numParameters == direction.length;
        double[] unitDirection = direction.clone();
        MatrixOps.absNormalize(unitDirection);

        double value = optimizable.getValue();
        double[] parameters = new double[numParameters];
        double[] oldParameters = new double[numParameters];
        double[] analyticGradient = new double[numParameters];
        optimizable.getParameters(parameters);
        optimizable.getParameters(oldParameters);
        optimizable.getValueGradient(analyticGradient);

        double directionalGradient = MatrixOps.dotProduct(analyticGradient, unitDirection);
        double epsilon = 0.1d / MatrixOps.absNorm(analyticGradient);
        // Use the magnitude: a negative directional derivative would otherwise give a negative
        // tolerance that no slope difference can ever be below.
        double tolerance = 1.0E-5d * Math.abs(directionalGradient);
        System.out.println("epsilon = " + epsilon + " tolerance=" + tolerance);

        MatrixOps.plusEquals(parameters, unitDirection, epsilon);
        optimizable.setParameters(parameters);
        double epsValue = optimizable.getValue();
        // Restore before any assertion can fire so a failure does not leave the object perturbed.
        optimizable.setParameters(oldParameters);

        double slope = (epsValue - value) / epsilon;
        System.out.println("value=" + value + " epsilon=" + epsilon + " epsValue=" + epsValue + " slope = " + slope + " gradient=" + directionalGradient);
        assert !Double.isNaN(slope);
        double slopeDifference = Math.abs(slope - directionalGradient);
        logger.info("TestMaximizable : slope tolerance = " + tolerance + ": gradient slope = " + directionalGradient + ", value+epsilon slope = " + slope + ": slope difference = " + slopeDifference);
        assert slopeDifference < tolerance : "Slope difference " + slopeDifference + " is greater than tolerance " + tolerance;
        return slopeDifference;
    }

    /**
     * Compares the analytic gradient at the current parameters with a forward-difference estimate.
     *
     * <p>Each checked component is perturbed by {@code epsilon = 0.1 / max(0.1, |gradient|_1)} and
     * its value slope recorded. When {@link #setNumComponents} was given a positive count, only
     * every {@code max(1, numParameters / numComponents)}-th component is checked; the others keep
     * their analytic value. The angle between the normalized analytic and empirical gradients must
     * not exceed {@code 5 * epsilon}. The optimizable's parameters are restored before the angle
     * is evaluated.
     *
     * @param optimizable the function under test
     * @return the angle in radians between the two gradients, or 0 when their cosine is almost 1
     * @throws IllegalStateException if the angle is NaN (e.g. a zero gradient vector) or exceeds
     *     the tolerance
     * @throws AssertionError (only when assertions are enabled) if an empirical slope is NaN
     */
    public static double testValueAndGradientCurrentParameters(Optimizable.ByGradientValue optimizable) {
        int numParameters = optimizable.getNumParameters();
        double[] parameters = new double[numParameters];
        double[] analyticGradient = new double[numParameters];
        double value = optimizable.getValue();
        optimizable.getParameters(parameters);
        optimizable.getValueGradient(analyticGradient);
        // Components that are sampled out keep their analytic value, so they agree by construction.
        double[] empiricalGradient = analyticGradient.clone();

        double epsilon = 0.1d / Math.max(0.1d, MatrixOps.absNorm(analyticGradient));
        double tolerance = epsilon * 5.0d;
        System.out.println("epsilon = " + epsilon + " tolerance=" + tolerance);

        int stride = 1;
        if (numComponents > 0) {
            stride = Math.max(1, numParameters / numComponents);
            logger.info("Will check every " + stride + "-th component.");
        }

        try {
            for (int i = 0; i < numParameters; i += stride) {
                double original = parameters[i];
                parameters[i] = original + epsilon;
                optimizable.setParameters(parameters);
                double epsValue = optimizable.getValue();
                parameters[i] = original;
                double slope = (epsValue - value) / epsilon;
                System.out.println("value=" + value + " epsValue=" + epsValue + " slope[" + i + "] = " + slope + " gradient[]=" + analyticGradient[i]);
                assert !Double.isNaN(slope);
                logger.info("TestMaximizable checking singleIndex " + i + ": gradient slope = " + analyticGradient[i] + ", value+epsilon slope = " + slope + ": slope difference = " + (slope - analyticGradient[i]));
                empiricalGradient[i] = slope;
            }
        } finally {
            // The loop only resets its local copy; push the original point back to the optimizable.
            optimizable.setParameters(parameters);
        }

        double analyticNorm = MatrixOps.twoNorm(analyticGradient);
        double empiricalNorm = MatrixOps.twoNorm(empiricalGradient);
        System.out.println("analyticGradient.twoNorm = " + analyticNorm);
        System.out.println("empiricalGradient.twoNorm = " + empiricalNorm);
        MatrixOps.timesEquals(analyticGradient, 1.0d / analyticNorm);
        MatrixOps.timesEquals(empiricalGradient, 1.0d / empiricalNorm);
        double cosine = MatrixOps.dotProduct(analyticGradient, empiricalGradient);
        if (Maths.almostEquals(cosine, 1.0d)) {
            logger.info("TestMaximizable angle is zero.");
            return 0.0d;
        }
        double angle = Math.acos(cosine);
        logger.info("TestMaximizable angle = " + angle);
        // Test NaN first: NaN compares false against the tolerance and would slip past it.
        if (Double.isNaN(angle)) {
            throw new IllegalStateException("Gradient/Value error: angle is NaN!");
        }
        if (Math.abs(angle) > tolerance) {
            throw new IllegalStateException("Gradient/Value mismatch: angle=" + angle + " tol: " + tolerance);
        }
        return angle;
    }

    /**
     * Runs {@link #testValueAndGradientCurrentParameters} at the origin. It then runs it again
     * after moving the parameters by {@code -1e-4} times the gradient taken at the origin.
     *
     * @param optimizable the function under test; left at the second evaluation point
     * @return always {@code true}; failures propagate as exceptions from the gradient check
     */
    public static boolean testValueAndGradient(Optimizable.ByGradientValue optimizable) {
        double[] parameters = new double[optimizable.getNumParameters()];
        MatrixOps.setAll(parameters, 0.0d);
        optimizable.setParameters(parameters);
        testValueAndGradientCurrentParameters(optimizable);

        double[] gradient = new double[optimizable.getNumParameters()];
        optimizable.getValueGradient(gradient);
        logger.info("Gradient two-Norm = " + MatrixOps.twoNorm(gradient));
        // Report the largest per-component change the step below actually applies.
        logger.info("  max parameter change = " + (MatrixOps.infinityNorm(gradient) * Math.abs(GRADIENT_STEP)));
        MatrixOps.timesEquals(gradient, GRADIENT_STEP);
        MatrixOps.plusEquals(parameters, gradient);
        optimizable.setParameters(parameters);
        testValueAndGradientCurrentParameters(optimizable);
        return true;
    }

    /**
     * Sets each parameter to a value in (-1, 1) drawn from {@code random}, then runs
     * {@link #testValueAndGradientCurrentParameters}.
     *
     * @param optimizable the function under test; left at the random point
     * @param random source of the random parameter values
     * @return always {@code true}; failures propagate as exceptions from the gradient check
     */
    public static boolean testValueAndGradientRandomParameters(Optimizable.ByGradientValue optimizable, Random random) {
        double[] parameters = new double[optimizable.getNumParameters()];
        for (int i = 0; i < parameters.length; i++) {
            double magnitude = random.nextDouble();
            parameters[i] = random.nextBoolean() ? -magnitude : magnitude;
        }
        optimizable.setParameters(parameters);
        testValueAndGradientCurrentParameters(optimizable);
        return true;
    }

    /** The gradient check must accept a correct gradient and reject a wrong one. */
    public void testTestValueAndGradient() {
        testValueAndGradient(new SimplePoly());
        try {
            testValueAndGradient(new WrongSimplePoly());
            fail("WrongSimplyPoly should fail testMaxmiziable!");
        } catch (IllegalStateException expected) {
            // Expected: the gradient check reports the mismatch with IllegalStateException.
        }
    }

    public static Test suite() {
        return new TestSuite(TestOptimizable.class);
    }

    @Override
    protected void setUp() {
    }

    public static void main(String[] args) {
        TestRunner.run(suite());
    }
}
