package cc.mallet.grmm.test;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.types.tests.TestSerializable;
import cc.mallet.util.ArrayUtils;
import cc.mallet.util.Maths;
import cc.mallet.util.Randoms;
import java.io.IOException;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/* loaded from: input_file:cc/mallet/grmm/test/TestTableFactor.class */
/**
 * JUnit tests for {@link TableFactor} (and, in a few cases, {@link LogTableFactor}):
 * products, sums, exponentiation, entropy, marginalization, slicing, sampling and
 * sparse-backed value tables.
 *
 * <p>Value arrays passed to the constructors are laid out with the last variable varying
 * fastest (see {@link #testTwoVarSlice()}).
 */
public class TestTableFactor extends TestCase {

    /** Absolute tolerance used when comparing individual factor entries. */
    private static final double VALUE_TOLERANCE = 1.0e-10d;

    public TestTableFactor(String str) {
        super(str);
    }

    /** {@code multiply} (returns a new factor) and {@code multiplyBy} (in place) must agree. */
    public void testMultiplyMultiplyBy() {
        Variable variable = new Variable(4);
        TableFactor expected = new TableFactor(variable, new double[]{1.0d, 2.0d, 3.0d, 4.0d});
        TableFactor inPlace = new TableFactor(variable, new double[]{2.0d, 4.0d, 6.0d, 8.0d});
        TableFactor halves = new TableFactor(variable, new double[]{0.5d, 0.5d, 0.5d, 0.5d});
        // Compute the out-of-place product first, before inPlace is mutated below.
        Factor product = inPlace.multiply(halves);
        inPlace.multiplyBy(halves);
        assertTrue(expected.almostEquals(inPlace));
        assertTrue(expected.almostEquals(product));
    }

    /** Element-wise addition of two table factors over the same variable. */
    public void testTblTblPlusEquals() {
        Variable variable = new Variable(4);
        TableFactor expected = new TableFactor(variable, new double[]{2.25d, 4.5d, 6.75d, 9.0d});
        TableFactor sum = new TableFactor(variable, new double[]{2.0d, 4.0d, 6.0d, 8.0d});
        sum.plusEquals(new TableFactor(variable, new double[]{0.25d, 0.5d, 0.75d, 1.0d}));
        assertTrue(expected.almostEquals(sum));
    }

    /** Entropy (in nats) of the distribution (0.3, 0.7), for both table representations. */
    public void testEntropy() {
        Variable variable = new Variable(2);
        assertEquals(0.61086d, new TableFactor(variable, new double[]{0.3d, 0.7d}).entropy(), 0.001d);
        assertEquals(0.61086d, LogTableFactor.makeFromValues(variable, new double[]{0.3d, 0.7d}).entropy(), 0.001d);
    }

    /**
     * Round-trips a factor through Java serialization. Disabled (no {@code test} prefix), so
     * JUnit 3 does not run it automatically.
     */
    public void ignoreTestSerialization() throws IOException, ClassNotFoundException {
        Variable variable = new Variable(2);
        TableFactor original = new TableFactor(new Variable[]{variable, new Variable(3)}, new double[]{2.0d, 4.0d, 6.0d, 3.0d, 5.0d, 7.0d});
        TableFactor clone = (TableFactor) TestSerializable.cloneViaSerialization(original);
        // The clone carries deserialized copies of the variables, not the original instances.
        assertTrue(!original.varSet().contains(clone.varSet()));
        comparePotentialValues(original, clone);
        comparePotentialValues((TableFactor) original.marginalize(variable), (TableFactor) clone.marginalize(clone.findVariable(variable.getLabel())));
    }

    /**
     * Asserts that two factors hold the same values, assignment by assignment, in iteration
     * order, and that they enumerate the same number of assignments.
     *
     * @param actual   the factor under test
     * @param expected the factor holding the expected values
     */
    private void comparePotentialValues(TableFactor actual, TableFactor expected) {
        AssignmentIterator actualIt = actual.assignmentIterator();
        AssignmentIterator expectedIt = expected.assignmentIterator();
        int position = 0;
        while (actualIt.hasNext()) {
            assertTrue("Expected factor has fewer assignments than actual factor " + actual, expectedIt.hasNext());
            assertEquals("Value mismatch at assignment position " + position + " of " + actual,
                    expected.value(expectedIt), actual.value(actualIt), VALUE_TOLERANCE);
            actualIt.advance();
            expectedIt.advance();
            position++;
        }
        assertTrue("Expected factor has more assignments than actual factor " + actual, !expectedIt.hasNext());
    }

    /**
     * Draws 100 samples with a fixed seed and checks that the empirical frequencies are within
     * 0.1 of the normalized table values.
     */
    public void testSample() {
        double[] weights = {1.0d, 3.0d, 2.0d};
        TableFactor tableFactor = new TableFactor(new Variable(3), weights);
        int[] samples = new int[100];
        Randoms randoms = new Randoms(32423);
        for (int i = 0; i < samples.length; i++) {
            samples[i] = tableFactor.sampleLocation(randoms);
        }
        double total = MatrixOps.sum(weights);
        double[] counts = new double[weights.length];
        for (int loc = 0; loc < weights.length; loc++) {
            counts[loc] = ArrayUtils.count(samples, loc);
        }
        MatrixOps.print(counts);
        for (int loc = 0; loc < weights.length; loc++) {
            assertEquals(weights[loc] / total, counts[loc] / samples.length, 0.1d);
        }
    }

    /** {@code marginalize(v)} keeps only {@code v}, summing out the other variable. */
    public void testMarginalize() {
        Variable[] variableArr = {new Variable(2), new Variable(2)};
        TableFactor tableFactor = (TableFactor) new TableFactor(variableArr, new double[]{1.0d, 2.0d, 3.0d, 4.0d}).marginalize(variableArr[1]);
        assertEquals("FAILURE: Potential has too many vars.\n  " + tableFactor, 1, tableFactor.varSet().size());
        assertTrue("FAILURE: Potential does not contain " + variableArr[1] + ":\n  " + tableFactor, tableFactor.varSet().contains(variableArr[1]));
        double[] expected = {4.0d, 6.0d};
        assertTrue("FAILURE: Potential has incorrect values.  Expected " + ArrayUtils.toString(expected) + " was " + tableFactor, Maths.almostEquals(tableFactor.toValueArray(), expected, 1.0E-5d));
    }

    /** {@code marginalizeOut(v)} sums out {@code v}; equivalent to {@link #testMarginalize()}. */
    public void testMarginalizeOut() {
        Variable[] variableArr = {new Variable(2), new Variable(2)};
        TableFactor tableFactor = (TableFactor) new TableFactor(variableArr, new double[]{1.0d, 2.0d, 3.0d, 4.0d}).marginalizeOut(variableArr[0]);
        assertEquals("FAILURE: Potential has too many vars.\n  " + tableFactor, 1, tableFactor.varSet().size());
        assertTrue("FAILURE: Potential does not contain " + variableArr[1] + ":\n  " + tableFactor, tableFactor.varSet().contains(variableArr[1]));
        double[] expected = {4.0d, 6.0d};
        assertTrue("FAILURE: Potential has incorrect values.  Expected " + ArrayUtils.toString(expected) + " was " + tableFactor, Maths.almostEquals(tableFactor.toValueArray(), expected, 1.0E-5d));
    }

    /**
     * Slicing a two-variable factor on its first variable keeps the selected row verbatim.
     * The entries happen to be ln 1, ln 4, ln 2 and ln 6, but TableFactor stores them as
     * given (it does not exponentiate), so the expected slice is the raw first row.
     */
    public void testOneVarSlice() {
        Variable variable = new Variable(2);
        Variable variable2 = new Variable(2);
        TableFactor tableFactor = new TableFactor(new Variable[]{variable, variable2}, new double[]{0.0d, 1.3862943611198906d, 0.6931471805599453d, 1.791759469228055d});
        TableFactor sliced = (TableFactor) tableFactor.slice(new Assignment(variable, 0));
        comparePotentialValues(sliced, new TableFactor(variable2, new double[]{0.0d, 1.3862943611198906d}));
    }

    /** Slicing the last (fastest-varying) of three variables keeps every other entry. */
    public void testTwoVarSlice() {
        Variable variable = new Variable(2);
        Variable variable2 = new Variable(2);
        Variable variable3 = new Variable(2);
        TableFactor tableFactor = new TableFactor(new Variable[]{variable, variable2, variable3}, new double[]{0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d});
        TableFactor sliced = (TableFactor) tableFactor.slice(new Assignment(variable3, 0));
        comparePotentialValues(sliced, new TableFactor(new Variable[]{variable, variable2}, new double[]{0.0d, 2.0d, 4.0d, 6.0d}));
    }

    /** Slicing the last of four variables of a 16-entry table. */
    public void testMultiVarSlice() {
        Variable variable = new Variable(2);
        Variable variable2 = new Variable(2);
        Variable variable3 = new Variable(2);
        Variable variable4 = new Variable(2);
        TableFactor tableFactor = new TableFactor(new Variable[]{variable, variable2, variable3, variable4}, new double[]{0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d, 10.0d, 11.0d, 12.0d, 13.0d, 14.0d, 15.0d});
        System.out.println(tableFactor);
        TableFactor sliced = (TableFactor) tableFactor.slice(new Assignment(variable4, 0));
        System.out.println(new TableFactor(sliced));
        comparePotentialValues(sliced, new TableFactor(new Variable[]{variable, variable2, variable3}, new double[]{0.0d, 2.0d, 4.0d, 6.0d, 8.0d, 10.0d, 12.0d, 14.0d}));
    }

    /** Same as {@link #testMultiVarSlice()} but for a {@link LogTableFactor}. */
    public void testLogMultiVarSlice() {
        Variable variable = new Variable(2);
        Variable variable2 = new Variable(2);
        Variable variable3 = new Variable(2);
        Variable variable4 = new Variable(2);
        LogTableFactor full = LogTableFactor.makeFromValues(new Variable[]{variable, variable2, variable3, variable4}, new double[]{0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d, 10.0d, 11.0d, 12.0d, 13.0d, 14.0d, 15.0d});
        System.out.println(full.dumpToString());
        LogTableFactor sliced = (LogTableFactor) full.slice(new Assignment(variable4, 0));
        LogTableFactor expected = LogTableFactor.makeFromValues(new Variable[]{variable, variable2, variable3}, new double[]{0.0d, 2.0d, 4.0d, 6.0d, 8.0d, 10.0d, 12.0d, 14.0d});
        assertTrue("Test failed. Expected: " + expected.dumpToString() + "\nActual: " + sliced.dumpToString(), expected.almostEquals(sliced));
    }

    /** Product of two sparse-backed factors; entries absent from either operand become 0. */
    public void testSparseMultiply() {
        Variable[] variableArr = {new Variable(2), new Variable(2)};
        int[] sizes = {2, 2};
        int[] indices = {0, 1, 3};
        TableFactor left = new TableFactor(variableArr);
        left.setValues(new SparseMatrixn(sizes, indices, new double[]{2.0d, 4.0d, 8.0d}));
        TableFactor right = new TableFactor(variableArr);
        right.setValues(new SparseMatrixn(sizes, new int[]{0, 3}, new double[]{0.5d, 0.5d}));
        TableFactor expected = new TableFactor(variableArr);
        expected.setValues(new SparseMatrixn(sizes, indices, new double[]{1.0d, 0.0d, 4.0d}));
        Factor product = left.multiply(right);
        assertTrue("Test failed! Expected: " + expected + " Actual: " + product, expected.almostEquals(product));
    }

    /**
     * In-place division of sparse-backed factors. Index 1 is 4 / (absent entry); the test
     * expects the result there to be 0.
     */
    public void testSparseDivide() {
        Variable[] variableArr = {new Variable(2), new Variable(2)};
        int[] sizes = {2, 2};
        int[] indices = {0, 1, 3};
        TableFactor quotient = new TableFactor(variableArr);
        quotient.setValues(new SparseMatrixn(sizes, indices, new double[]{2.0d, 4.0d, 8.0d}));
        TableFactor divisor = new TableFactor(variableArr);
        divisor.setValues(new SparseMatrixn(sizes, new int[]{0, 3}, new double[]{0.5d, 0.5d}));
        TableFactor expected = new TableFactor(variableArr);
        expected.setValues(new SparseMatrixn(sizes, indices, new double[]{4.0d, 0.0d, 16.0d}));
        quotient.divideBy(divisor);
        assertTrue("Test failed! Expected: " + expected + " Actual: " + quotient, expected.almostEquals(quotient));
    }

    /** Marginalizing a sparse-backed factor onto its first variable sums each row. */
    public void testSparseMarginalize() {
        Variable[] variableArr = {new Variable(2), new Variable(2)};
        TableFactor tableFactor = new TableFactor(variableArr);
        tableFactor.setValues(new SparseMatrixn(new int[]{2, 2}, new int[]{0, 1, 3}, new double[]{2.0d, 4.0d, 8.0d}));
        TableFactor expected = new TableFactor(variableArr[0], new double[]{6.0d, 8.0d});
        Factor marginal = tableFactor.marginalize(variableArr[0]);
        assertTrue("Test failed! Expected: " + expected + " Actual: " + marginal + " Orig: " + tableFactor, expected.almostEquals(marginal));
    }

    /** {@code extractMax} onto the first variable takes the maximum of each row. */
    public void testSparseExtractMax() {
        Variable[] variableArr = {new Variable(2), new Variable(2)};
        TableFactor tableFactor = new TableFactor(variableArr);
        tableFactor.setValues(new SparseMatrixn(new int[]{2, 2}, new int[]{0, 1, 3}, new double[]{2.0d, 4.0d, 8.0d}));
        TableFactor expected = new TableFactor(variableArr[0], new double[]{4.0d, 8.0d});
        Factor maxed = tableFactor.extractMax(variableArr[0]);
        assertTrue("Test failed! Expected: " + expected + " Actual: " + maxed + " Orig: " + tableFactor, expected.almostEquals(maxed));
    }

    /** With log weights (-30, 0), the seeded sampler picks location 1. */
    public void testLogSample() {
        assertEquals(1, LogTableFactor.makeFromLogValues(new Variable(2), new double[]{-30.0d, 0.0d}).sampleLocation(new Randoms(43)));
    }

    /** {@code exponentiate(2)} squares every entry in place. */
    public void testExp() {
        Variable variable = new Variable(4);
        TableFactor expected = new TableFactor(variable, new double[]{4.0d, 16.0d, 36.0d, 64.0d});
        TableFactor squared = new TableFactor(variable, new double[]{2.0d, 4.0d, 6.0d, 8.0d});
        squared.exponentiate(2.0d);
        assertTrue("Error: expected " + expected.dumpToString() + " but was " + squared.dumpToString(), squared.almostEquals(expected));
    }

    /** {@code plusEquals(double)} adds a scalar to every entry in place. */
    public void testPlusEquals() {
        Variable variable = new Variable(4);
        TableFactor shifted = new TableFactor(variable, new double[]{2.0d, 4.0d, 6.0d, 8.0d});
        shifted.plusEquals(0.1d);
        TableFactor expected = new TableFactor(variable, new double[]{2.1d, 4.1d, 6.1d, 8.1d});
        assertTrue("Error: expected " + expected.dumpToString() + " but was " + shifted.dumpToString(), shifted.almostEquals(expected));
    }

    /**
     * {@code multiplyAll} must produce the correct values with the variables in their
     * original order. The check is repeated 100 times with fresh variables; presumably this
     * guards against order-dependent (e.g. hash-based) variable ordering.
     */
    public void testMultiplyAll() {
        for (int trial = 0; trial < 100; trial++) {
            Variable[] variableArr = {new Variable(2), new Variable(2)};
            TableFactor expected = new TableFactor(variableArr, new double[]{1.0d, 2.0d, 3.0d, 4.0d});
            DiscreteFactor product = TableFactor.multiplyAll(new Factor[]{new TableFactor(variableArr, new double[]{2.0d, 4.0d, 6.0d, 8.0d}), new TableFactor(variableArr, new double[]{0.5d, 0.5d, 0.5d, 0.5d})});
            VarSet varSet = product.varSet();
            for (int v = 0; v < variableArr.length; v++) {
                assertEquals(variableArr[v], varSet.get(v));
            }
            assertTrue(expected.almostEquals(product));
        }
    }

    /**
     * {@code multiplyBy} with a factor over a new variable expands the table to contain it,
     * duplicating each entry across the new (fastest-varying) dimension.
     */
    public void testExpandToContain() {
        Variable variable = new Variable(2);
        Variable variable2 = new Variable(2);
        Variable variable3 = new Variable(2);
        TableFactor expanded = new TableFactor(new Variable[]{variable, variable2}, new double[]{2.0d, 4.0d, 6.0d, 8.0d});
        expanded.multiplyBy(new TableFactor(variable3, new double[]{0.5d, 0.5d}));
        TableFactor expected = new TableFactor(new Variable[]{variable, variable2, variable3}, new double[]{1.0d, 1.0d, 2.0d, 2.0d, 3.0d, 3.0d, 4.0d, 4.0d});
        System.out.println(expanded.dumpToString());
        System.out.println(expected.dumpToString());
        assertTrue(expected.almostEquals(expanded));
    }

    /** Builds a suite containing every {@code test*} method of this class. */
    public static Test suite() {
        return new TestSuite(TestTableFactor.class);
    }

    /**
     * Runs the tests from the command line: the named test methods if arguments are given,
     * otherwise the whole suite.
     */
    public static void main(String[] strArr) throws Throwable {
        TestSuite testSuite;
        if (strArr.length > 0) {
            testSuite = new TestSuite();
            for (String testName : strArr) {
                testSuite.addTest(new TestTableFactor(testName));
            }
        } else {
            testSuite = (TestSuite) suite();
        }
        TestRunner.run(testSuite);
    }
}
