package cc.mallet.classify;

import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletProgressMessageLogger;
import cc.mallet.util.Maths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/classify/MaxEntOptimizableByGE.class */
/**
 * Optimizable objective that scores a {@link MaxEnt} classifier against per-feature
 * reference label distributions (the "GE" term reported in the progress log), plus a
 * Gaussian prior on the parameters.
 *
 * <p>Only instances whose target is {@code null} contribute. For every constrained feature
 * the classifier's label scores are accumulated over the instances containing that feature;
 * the objective is the negated, weighted divergence between each reference distribution and
 * the resulting normalized expectation. Value and gradient are computed together in
 * {@link #getValue()} and cached until a parameter is changed.
 */
public class MaxEntOptimizableByGE implements Optimizable.ByGradientValue {
    // NOTE(review): the logger is named after MaxEntOptimizableByLabelLikelihood rather than
    // this class. Kept as-is because logging configuration may key on that name - confirm.
    private static final Logger progressLogger =
            MalletProgressMessageLogger.getLogger(MaxEntOptimizableByLabelLikelihood.class.getName() + "-pl");

    /** Index of the extra "default" feature: one past the last data-alphabet index. */
    private int defaultFeatureIndex;
    private double cachedValue;
    private double gaussianPriorVariance;
    private double[] cachedGradient;
    private double[] parameters;
    private InstanceList trainingList;
    private MaxEnt classifier;
    /** Reference label distribution for each constrained feature index. */
    private HashMap<Integer, double[]> refEx;
    /** Maps a constrained feature index to its dense row in the per-constraint work arrays. */
    private HashMap<Integer, Integer> mapping;
    private boolean cacheStale = true;
    /** When false, every feature occurrence counts as 1.0 regardless of its stored value. */
    private boolean useValues = false;
    private double temperature = 1.0d;
    private double objWeight = 1.0d;

    /**
     * Creates the objective.
     *
     * @param instanceList training instances; only those with a {@code null} target are used
     * @param hashMap reference label distribution keyed by feature index (the default feature
     *     uses the data alphabet size as its key)
     * @param maxEnt classifier to score with, or {@code null} to create a new one over this
     *     object's parameter array
     */
    public MaxEntOptimizableByGE(InstanceList instanceList, HashMap<Integer, double[]> hashMap, MaxEnt maxEnt) {
        this.trainingList = instanceList;
        int numFeatures = instanceList.getDataAlphabet().size();
        this.defaultFeatureIndex = numFeatures;
        int numLabels = instanceList.getTargetAlphabet().size();
        // One row of (numFeatures + 1) weights per label; the extra slot is the default feature.
        this.parameters = new double[(numFeatures + 1) * numLabels];
        this.cachedGradient = new double[(numFeatures + 1) * numLabels];
        this.cachedValue = 0.0d;
        if (maxEnt != null) {
            // NOTE(review): this.parameters is a freshly allocated array here and nothing in
            // this class links it to the supplied classifier's own parameters - confirm that
            // callers keep the two in sync.
            this.classifier = maxEnt;
        } else {
            this.classifier = new MaxEnt(instanceList.getPipe(), this.parameters);
        }
        this.refEx = hashMap;
    }

    /**
     * Sets the variance of the Gaussian prior used by {@link #getRegularization()}.
     *
     * @param d prior variance
     */
    public void setGaussianPriorVariance(double d) {
        this.gaussianPriorVariance = d;
    }

    /**
     * Sets the temperature passed to the classifier when scoring and used to scale the gradient.
     *
     * @param d temperature
     */
    public void setTemperature(double d) {
        this.temperature = d;
    }

    /**
     * Sets the multiplier applied to the GE term; a weight of exactly 0 short-circuits
     * {@link #getValue()}.
     *
     * @param d objective weight
     */
    public void setWeight(double d) {
        this.objWeight = d;
    }

    /**
     * Returns the classifier being scored.
     *
     * @return the classifier supplied to, or created by, the constructor
     */
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    /**
     * Computes (or returns the cached) objective value and refreshes the cached gradient.
     *
     * @return the GE term plus the Gaussian prior, or 0 when the objective weight is 0
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public double getValue() {
        if (!this.cacheStale) {
            return this.cachedValue;
        }
        if (this.objWeight == 0.0d) {
            // NOTE(review): this path leaves the cached value and gradient untouched and the
            // cache marked stale - confirm that is intended.
            return 0.0d;
        }
        Arrays.fill(this.cachedGradient, 0.0d);
        if (this.mapping == null) {
            setMapping();
        }

        final int numConstraints = this.refEx.size();
        final int rowLength = this.trainingList.getDataAlphabet().size() + 1;
        final int numLabels = this.trainingList.getTargetAlphabet().size();
        final double weight = this.objWeight;

        double[][] expectationSums = new double[numConstraints][numLabels];
        double[][] expectations = new double[numConstraints][numLabels];
        double[][] ratios = new double[numConstraints][numLabels];
        double[] featureCounts = new double[numConstraints];
        double[][] scores = new double[this.trainingList.size()][numLabels];

        // Pass 1: score every unlabeled instance and accumulate, per constrained feature,
        // the occurrence count and the instance-weighted label scores.
        int instanceIndex = 0;
        for (Instance instance : this.trainingList) {
            double[] instanceScores = scores[instanceIndex++];
            if (instance.getTarget() != null) {
                continue;
            }
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector) instance.getData();
            this.classifier.getClassificationScoresWithTemperature(instance, this.temperature, instanceScores);
            for (int loc = 0; loc < fv.numLocations(); loc++) {
                int featureIndex = fv.indexAtLocation(loc);
                if (!this.refEx.containsKey(featureIndex)) {
                    continue;
                }
                int row = this.mapping.get(featureIndex);
                double featureValue = this.useValues ? fv.valueAtLocation(loc) : 1.0d;
                featureCounts[row] += featureValue;
                for (int label = 0; label < numLabels; label++) {
                    expectationSums[row][label] += instanceScores[label] * featureValue * instanceWeight;
                }
            }
            if (this.refEx.containsKey(this.defaultFeatureIndex)) {
                // The default feature is treated as present once in every instance.
                int row = this.mapping.get(this.defaultFeatureIndex);
                featureCounts[row] += 1.0d;
                for (int label = 0; label < numLabels; label++) {
                    expectationSums[row][label] += instanceScores[label] * instanceWeight;
                }
            }
        }

        // Normalize the sums into expectations and form reference / (unnormalized sum) ratios.
        for (int featureIndex : this.refEx.keySet()) {
            int row = this.mapping.get(featureIndex);
            if (featureCounts[row] > 0.0d) {
                double[] reference = this.refEx.get(featureIndex);
                for (int label = 0; label < numLabels; label++) {
                    expectations[row][label] = expectationSums[row][label] / featureCounts[row];
                    ratios[row][label] = reference[label] / expectationSums[row][label];
                }
                assert Maths.almostEquals(MatrixOps.sum(expectations[row]), 1.0d);
            }
        }

        // Pass 2: accumulate the gradient contribution of each constrained feature occurrence.
        instanceIndex = 0;
        for (Instance instance : this.trainingList) {
            double[] instanceScores = scores[instanceIndex++];
            if (instance.getTarget() != null) {
                continue;
            }
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector) instance.getData();
            int numLocations = fv.numLocations();
            // A plain for loop guarantees the location advances even when a feature is skipped;
            // the extra iteration at loc == numLocations stands for the default feature.
            for (int loc = 0; loc <= numLocations; loc++) {
                int featureIndex = (loc == numLocations) ? this.defaultFeatureIndex : fv.indexAtLocation(loc);
                if (!this.refEx.containsKey(featureIndex)) {
                    continue;
                }
                int row = this.mapping.get(featureIndex);
                if (MatrixOps.sum(expectations[row]) == 0.0d) {
                    // No expectation mass was collected for this constraint; nothing to add.
                    continue;
                }
                double featureValue = (featureIndex == this.defaultFeatureIndex || !this.useValues)
                        ? 1.0d
                        : fv.valueAtLocation(loc);
                double expectedRatio = 0.0d;
                for (int label = 0; label < numLabels; label++) {
                    expectedRatio += ratios[row][label] * instanceScores[label];
                }
                for (int label = 0; label < numLabels; label++) {
                    if (instanceScores[label] == 0.0d) {
                        continue;
                    }
                    assert !Double.isInfinite(instanceScores[label]);
                    double factor = weight * instanceWeight * this.temperature * featureValue
                            * instanceScores[label] * (ratios[row][label] - expectedRatio);
                    MatrixOps.rowPlusEquals(this.cachedGradient, rowLength, label, fv, factor);
                    this.cachedGradient[(rowLength * label) + this.defaultFeatureIndex] += factor;
                }
            }
        }

        // Value: minus the weighted sum over constraints of sum_l ref_l * log(ref_l / expectation_l).
        double geValue = 0.0d;
        for (int featureIndex : this.refEx.keySet()) {
            int row = this.mapping.get(featureIndex);
            if (MatrixOps.sum(expectations[row]) == 0.0d) {
                continue;
            }
            double[] reference = this.refEx.get(featureIndex);
            double divergence = 0.0d;
            // Labels with zero reference probability are skipped in both sums (0 * log(x) is
            // taken as 0); multiplying through would turn the whole objective into NaN.
            for (int label = 0; label < numLabels; label++) {
                if (reference[label] != 0.0d) {
                    divergence -= (weight * reference[label]) * Math.log(expectations[row][label]);
                }
            }
            for (int label = 0; label < numLabels; label++) {
                if (reference[label] != 0.0d) {
                    divergence += weight * reference[label] * Math.log(reference[label]);
                }
            }
            geValue -= divergence;
        }

        this.cachedValue = geValue;
        this.cacheStale = false;
        // getRegularization() folds the prior into cachedValue and cachedGradient.
        double regularization = getRegularization();
        progressLogger.info("Value (GE=" + geValue + " Gaussian prior= " + regularization + ") = " + this.cachedValue);
        // Return the same quantity a cached call returns, so that the value always includes
        // the prior that the gradient already contains.
        return this.cachedValue;
    }

    /**
     * Computes the Gaussian-prior term and folds it into the cached value and gradient.
     *
     * <p>This method mutates cached state: it adds the returned term to the cached value and
     * subtracts {@code parameter / variance} from every cached gradient entry. Calling it again
     * without an intervening recomputation in {@link #getValue()} applies the prior a second time.
     *
     * @return the prior term that was added to the cached value
     */
    public double getRegularization() {
        double logPrior = !Double.isInfinite(this.gaussianPriorVariance)
                ? Math.log(this.gaussianPriorVariance * Math.sqrt(2.0d * Math.PI))
                : 0.0d;
        for (int i = 0; i < this.parameters.length; i++) {
            double parameter = this.parameters[i];
            logPrior -= (parameter * parameter) / (2.0d * this.gaussianPriorVariance);
            this.cachedGradient[i] -= parameter / this.gaussianPriorVariance;
        }
        this.cachedValue += logPrior;
        return logPrior;
    }

    /**
     * Copies the gradient of the objective into the supplied array, recomputing it first if
     * the cache is stale.
     *
     * @param dArr destination; expected to have one slot per parameter
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public void getValueGradient(double[] dArr) {
        if (this.cacheStale) {
            getValue();
        }
        assert dArr.length == this.cachedGradient.length;
        System.arraycopy(this.cachedGradient, 0, dArr, 0, dArr.length);
    }

    /**
     * Returns the number of parameters.
     *
     * @return (data alphabet size + 1) * number of labels, as allocated by the constructor
     */
    @Override // cc.mallet.optimize.Optimizable
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * Returns a single parameter.
     *
     * @param i parameter index
     * @return the current value at that index
     */
    @Override // cc.mallet.optimize.Optimizable
    public double getParameter(int i) {
        return this.parameters[i];
    }

    /**
     * Copies all parameters into the supplied array.
     *
     * @param dArr destination; expected to have one slot per parameter
     */
    @Override // cc.mallet.optimize.Optimizable
    public void getParameters(double[] dArr) {
        assert dArr.length == this.parameters.length;
        System.arraycopy(this.parameters, 0, dArr, 0, dArr.length);
    }

    /**
     * Sets a single parameter and invalidates the cached value and gradient.
     *
     * @param i parameter index
     * @param d new value
     */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameter(int i, double d) {
        this.cacheStale = true;
        this.parameters[i] = d;
    }

    /**
     * Overwrites all parameters and invalidates the cached value and gradient.
     *
     * @param dArr source; expected to have one slot per parameter
     */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameters(double[] dArr) {
        assert dArr.length == this.parameters.length;
        this.cacheStale = true;
        System.arraycopy(dArr, 0, this.parameters, 0, this.parameters.length);
    }

    /** Assigns each constrained feature index a dense row number, in key-set iteration order. */
    private void setMapping() {
        this.mapping = new HashMap<>();
        int nextRow = 0;
        for (int featureIndex : this.refEx.keySet()) {
            this.mapping.put(featureIndex, nextRow);
            nextRow++;
        }
    }
}
