package cc.mallet.grmm.inference;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.UndirectedGrid;
import cc.mallet.grmm.types.UndirectedModel;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.Dirichlet;
import cc.mallet.util.Randoms;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/**
 * Static helpers that build random and uniform factor graphs over binary variables: chains,
 * square grids, square grids with one observed variable per cell, and random trees.
 *
 * <p>Every random builder consumes its {@link Random} (or {@link Randoms}) in a fixed order, so a
 * seeded generator always reproduces the same model.
 */
public class RandomGraphs {

    /**
     * Multiplier applied to a standard-normal draw to get the node log-potential in
     * {@link #randomAttractiveGrid}. This gives a much weaker field than
     * {@link #randomNodePotential}, which uses the unscaled draw.
     */
    private static final double ATTRACTIVE_GRID_NODE_SCALE = 0.0625;

    /** Supplies a factor over a given set of variables; used by the grid builders. */
    public interface FactorGenerator {
        /**
         * Creates a factor over {@code varSet}.
         *
         * @param varSet the variables the new factor ranges over
         * @return a factor over {@code varSet}
         */
        Factor nextFactor(VarSet varSet);
    }

    /** A {@link FactorGenerator} whose factors contain 1.0 in every table entry. */
    public static class UniformFactorGenerator implements FactorGenerator {
        @Override
        public Factor nextFactor(VarSet varSet) {
            // NOTE(review): varSet.weight() is presumably the number of joint assignments.
            return new TableFactor(varSet, onesArray(varSet.weight()));
        }
    }

    /**
     * Returns the four entries of a random pairwise potential whose coupling sign follows
     * {@code strength}.
     *
     * <p>The coupling is {@code a = |g| * strength} with {@code g} drawn from
     * {@link Random#nextGaussian()}. The result is {@code {e^a, e^-a, e^-a, e^a}}. Entries 0 and 3
     * are therefore at least as large as entries 1 and 2 when {@code strength >= 0}. A negative
     * {@code strength} reverses that. NOTE(review): assuming a row-major 2x2 layout, entries 0
     * and 3 are the assignments where both variables agree.
     *
     * @param random source of exactly one {@code nextGaussian()} draw
     * @param strength scale applied to the absolute Gaussian draw
     * @return a new four-element array
     */
    public static double[] generateAttractivePotentialValues(Random random, double strength) {
        double coupling = Math.abs(random.nextGaussian()) * strength;
        return symmetricPairValues(coupling);
    }

    /**
     * Returns the four entries of a random pairwise potential whose coupling sign is itself
     * random.
     *
     * <p>The coupling is {@code a = g * strength} with {@code g} drawn from
     * {@link Random#nextGaussian()}, and the result is {@code {e^a, e^-a, e^-a, e^a}}. Unlike
     * {@link #generateAttractivePotentialValues}, the sign of {@code g} is kept, so edges may favor
     * either the diagonal or the off-diagonal entries.
     *
     * @param random source of exactly one {@code nextGaussian()} draw
     * @param strength scale applied to the Gaussian draw
     * @return a new four-element array
     */
    public static double[] generateMixedPotentialValues(Random random, double strength) {
        double coupling = random.nextGaussian() * strength;
        return symmetricPairValues(coupling);
    }

    /**
     * Builds a {@code width x width} grid of binary variables with attractive edge potentials and
     * weak random node potentials.
     *
     * <p>Each neighbor pair gets a factor from {@link #generateAttractivePotentialValues}. Each
     * cell then gets a node factor {@code {e^a, e^-a}} with {@code a = g * 0.0625} and
     * {@code g ~ nextGaussian()}.
     *
     * @param width number of cells along each side
     * @param strength edge coupling scale (a negative value yields repulsive edges)
     * @param random source of randomness
     * @return the new grid
     */
    public static UndirectedGrid randomAttractiveGrid(int width, double strength, Random random) {
        UndirectedGrid grid = new UndirectedGrid(width, width, 2);
        addRandomGridEdges(grid, width, strength, random, true);
        // Node potentials are visited in (i, j) order so the random stream stays deterministic.
        for (int i = 0; i < width; i++) {
            for (int j = 0; j < width; j++) {
                double logRatio = random.nextGaussian() * ATTRACTIVE_GRID_NODE_SCALE;
                grid.addFactor(binaryNodeFactor(grid.get(i, j), logRatio));
            }
        }
        return grid;
    }

    /**
     * Builds a grid like {@link #randomAttractiveGrid} but with {@code strength} negated, so every
     * edge favors the off-diagonal table entries. The weak node potentials are unaffected.
     *
     * @param width number of cells along each side
     * @param strength magnitude of the repulsive coupling scale
     * @param random source of randomness
     * @return the new grid
     */
    public static UndirectedGrid randomRepulsiveGrid(int width, double strength, Random random) {
        return randomAttractiveGrid(width, -strength, random);
    }

    /**
     * Builds a {@code width x width} grid of binary variables whose edge potentials have randomly
     * signed couplings (see {@link #generateMixedPotentialValues}), followed by unit-scale random
     * node potentials from {@link #addRandomNodePotentials}.
     *
     * @param width number of cells along each side
     * @param strength edge coupling scale
     * @param random source of randomness
     * @return the new grid
     */
    public static UndirectedGrid randomFrustratedGrid(int width, double strength, Random random) {
        UndirectedGrid grid = new UndirectedGrid(width, width, 2);
        addRandomGridEdges(grid, width, strength, random, false);
        addRandomNodePotentials(random, grid);
        return grid;
    }

    /**
     * Grows a random tree of binary variables with randomly signed edge couplings.
     *
     * <p>Starting from a single root, the method repeatedly removes a uniformly chosen node from
     * the frontier. It attaches between 1 and {@code maxChildren} new children to that node, each
     * connected by a {@link #generateMixedPotentialValues} factor. Growth stops as soon as the
     * model reports at least {@code numNodes} variables. A single expansion can add several
     * children, so the result may contain up to {@code maxChildren - 1} more variables than
     * requested. Random node potentials are added at the end.
     *
     * @param numNodes minimum number of variables; if it is {@code <= 0}, no edges are added
     * @param maxChildren maximum children per expansion. It must be positive whenever an expansion
     *     happens, otherwise {@link Random#nextInt(int)} throws {@link IllegalArgumentException}
     * @param strength edge coupling scale
     * @param random source of randomness
     * @return the new tree-structured model
     */
    public static UndirectedModel randomFrustratedTree(
            int numNodes, int maxChildren, double strength, Random random) {
        UndirectedModel model = new UndirectedModel();
        // The root is handed to the model only through its first edge factor.
        List<Variable> frontier = new ArrayList<Variable>();
        frontier.add(new Variable(2));
        while (model.numVariables() < numNodes) {
            // Each expansion removes one node and adds at least one, so the frontier never empties.
            Variable parent = removeRandomElement(frontier, random);
            int numChildren = random.nextInt(maxChildren) + 1;
            for (int c = 0; c < numChildren; c++) {
                Variable child = new Variable(2);
                model.addFactor(parent, child, generateMixedPotentialValues(random, strength));
                frontier.add(child);
            }
        }
        addRandomNodePotentials(random, model);
        return model;
    }

    /** Removes and returns a uniformly chosen element of the non-empty {@code list}. */
    private static <T> T removeRandomElement(List<T> list, Random random) {
        return list.remove(random.nextInt(list.size()));
    }

    /**
     * Adds one {@link #randomNodePotential} factor to every variable of {@code factorGraph},
     * visiting variables by index. The variable count is read once, before any factor is added.
     *
     * @param random source of randomness (one Gaussian draw per variable)
     * @param factorGraph the graph to modify in place
     */
    public static void addRandomNodePotentials(Random random, FactorGraph factorGraph) {
        int numVariables = factorGraph.numVariables();
        for (int i = 0; i < numVariables; i++) {
            factorGraph.addFactor(randomNodePotential(random, factorGraph.get(i)));
        }
    }

    /**
     * Creates a two-entry table factor {@code {e^g, e^-g}} over {@code variable}, with {@code g}
     * drawn from {@link Random#nextGaussian()}. The table has exactly two entries, so
     * {@code variable} is expected to be binary.
     *
     * @param random source of one Gaussian draw
     * @param variable the (binary) variable the factor ranges over
     * @return the new factor
     */
    public static TableFactor randomNodePotential(Random random, Variable variable) {
        return binaryNodeFactor(variable, random.nextGaussian());
    }

    /**
     * Builds a chain of {@code length} binary variables in which each consecutive pair is joined
     * by an all-ones pairwise factor.
     *
     * @param length number of variables in the chain
     * @return the new chain model
     */
    public static FactorGraph createUniformChain(int length) {
        Variable[] vars = newBinaryVariables(length);
        UndirectedModel model = new UndirectedModel(vars);
        for (int k = 0; k < length - 1; k++) {
            model.addFactor(vars[k], vars[k + 1], onesArray(4));
        }
        return model;
    }

    /**
     * Builds a {@code width x width} grid whose edge factors are all-ones tables.
     *
     * @param width number of cells along each side
     * @return the new grid model
     */
    public static FactorGraph createUniformGrid(int width) {
        return createGrid(new UniformFactorGenerator(), width);
    }

    /**
     * Builds a chain of {@code length} binary variables whose pairwise tables are drawn from a
     * Dirichlet distribution with all four concentration parameters equal to 1.
     *
     * @param randoms source of randomness for the Dirichlet draws
     * @param length number of variables in the chain
     * @return the new chain model
     */
    public static FactorGraph createRandomChain(Randoms randoms, int length) {
        Variable[] vars = newBinaryVariables(length);
        Dirichlet dirichlet = new Dirichlet(new double[]{1.0, 1.0, 1.0, 1.0});
        FactorGraph graph = new FactorGraph(vars);
        for (int k = 0; k < length - 1; k++) {
            graph.addFactor(vars[k], vars[k + 1], dirichlet.randomMultinomial(randoms).getValues());
        }
        return graph;
    }

    /**
     * Builds a {@code width x width} grid of binary variables. It asks {@code factorGenerator} for
     * one factor per neighbor pair: first all pairs {@code (i, j)-(i, j+1)}, then all pairs
     * {@code (i, j)-(i+1, j)}.
     *
     * @param factorGenerator supplies each edge factor
     * @param width number of cells along each side
     * @return the new grid model
     */
    public static UndirectedModel createGrid(FactorGenerator factorGenerator, int width) {
        UndirectedGrid grid = new UndirectedGrid(width, width, 2);
        for (int i = 0; i < width; i++) {
            for (int j = 0; j < width - 1; j++) {
                VarSet pair = new HashVarSet(new Variable[]{grid.get(i, j), grid.get(i, j + 1)});
                grid.addFactor(factorGenerator.nextFactor(pair));
            }
        }
        for (int i = 0; i < width - 1; i++) {
            for (int j = 0; j < width; j++) {
                VarSet pair = new HashVarSet(new Variable[]{grid.get(i, j), grid.get(i + 1, j)});
                grid.addFactor(factorGenerator.nextFactor(pair));
            }
        }
        return grid;
    }

    /**
     * Builds a {@code width x width} grid of binary hidden variables, each paired with its own
     * binary observed variable.
     *
     * <p>Hidden variables are labelled {@code GRID[i][j]} and observed ones {@code OBS[i][j]}.
     * Neighboring hidden variables are linked by factors from {@code factorGenerator}. Each
     * hidden/observed pair is linked by a factor from {@code factorGenerator2}.
     *
     * @param factorGenerator supplies hidden-hidden edge factors
     * @param factorGenerator2 supplies hidden-observed factors
     * @param width number of cells along each side
     * @return the new model containing all {@code 2 * width * width} variables
     */
    public static FactorGraph createGridWithObs(
            FactorGenerator factorGenerator, FactorGenerator factorGenerator2, int width) {
        List<Variable> allVars = new ArrayList<Variable>(2 * width * width);
        Variable[][] hidden = new Variable[width][width];
        Variable[][] observed = new Variable[width][width];
        for (int i = 0; i < width; i++) {
            for (int j = 0; j < width; j++) {
                hidden[i][j] = new Variable(2);
                hidden[i][j].setLabel("GRID[" + i + "][" + j + "]");
                observed[i][j] = new Variable(2);
                observed[i][j].setLabel("OBS[" + i + "][" + j + "]");
                allVars.add(hidden[i][j]);
                allVars.add(observed[i][j]);
            }
        }

        FactorGraph graph = new FactorGraph(allVars.toArray(new Variable[0]));
        for (int i = 0; i < width; i++) {
            for (int j = 0; j < width; j++) {
                Variable cell = hidden[i][j];
                if (i < width - 1) {
                    VarSet pair = new HashVarSet(new Variable[]{cell, hidden[i + 1][j]});
                    graph.addFactor(factorGenerator.nextFactor(pair));
                }
                if (j < width - 1) {
                    VarSet pair = new HashVarSet(new Variable[]{cell, hidden[i][j + 1]});
                    graph.addFactor(factorGenerator.nextFactor(pair));
                }
            }
        }
        for (int i = 0; i < width; i++) {
            for (int j = 0; j < width; j++) {
                VarSet pair = new HashVarSet(new Variable[]{hidden[i][j], observed[i][j]});
                graph.addFactor(factorGenerator2.nextFactor(pair));
            }
        }
        return graph;
    }

    /**
     * Adds one random pairwise factor for every neighbor pair of a {@code width x width} grid.
     * Shared by {@link #randomAttractiveGrid} and {@link #randomFrustratedGrid}. Edges are added,
     * and random values drawn, in a fixed order, so a seeded {@link Random} reproduces the model.
     *
     * @param attractive {@code true} to use {@link #generateAttractivePotentialValues},
     *     {@code false} to use {@link #generateMixedPotentialValues}
     */
    private static void addRandomGridEdges(
            UndirectedGrid grid, int width, double strength, Random random, boolean attractive) {
        int last = width - 1;
        // Cells with both indices < width-1: link to the (i+1, j) neighbor, then to (i, j+1).
        for (int i = 0; i < last; i++) {
            for (int j = 0; j < last; j++) {
                Variable cell = grid.get(i, j);
                grid.addFactor(cell, grid.get(i + 1, j), edgeValues(random, strength, attractive));
                grid.addFactor(cell, grid.get(i, j + 1), edgeValues(random, strength, attractive));
            }
        }
        // Pairs along j == width-1, which the loop above never starts from.
        for (int i = 0; i < last; i++) {
            grid.addFactor(grid.get(i, last), grid.get(i + 1, last),
                    edgeValues(random, strength, attractive));
        }
        // Pairs along i == width-1.
        for (int j = 0; j < last; j++) {
            grid.addFactor(grid.get(last, j), grid.get(last, j + 1),
                    edgeValues(random, strength, attractive));
        }
    }

    /** Draws one set of edge values from the attractive or the mixed generator. */
    private static double[] edgeValues(Random random, double strength, boolean attractive) {
        return attractive
                ? generateAttractivePotentialValues(random, strength)
                : generateMixedPotentialValues(random, strength);
    }

    /** Returns {@code {e^a, e^-a, e^-a, e^a}} for the coupling {@code a}. */
    private static double[] symmetricPairValues(double coupling) {
        double same = Math.exp(coupling);
        double differ = Math.exp(-coupling);
        return new double[]{same, differ, differ, same};
    }

    /** Returns a two-entry table factor {@code {e^a, e^-a}} over {@code variable}. */
    private static TableFactor binaryNodeFactor(Variable variable, double logRatio) {
        return new TableFactor(variable, new double[]{Math.exp(logRatio), Math.exp(-logRatio)});
    }

    /** Creates {@code count} fresh binary variables, in index order. */
    private static Variable[] newBinaryVariables(int count) {
        Variable[] vars = new Variable[count];
        for (int k = 0; k < count; k++) {
            vars[k] = new Variable(2);
        }
        return vars;
    }

    /** Returns a new array of {@code size} entries, all equal to 1.0. */
    private static double[] onesArray(int size) {
        double[] values = new double[size];
        Arrays.fill(values, 1.0);
        return values;
    }
}
