package cc.mallet.fst;

import cc.mallet.fst.Transducer;
import cc.mallet.types.ArraySequence;
import cc.mallet.types.Sequence;
import cc.mallet.types.SequencePairAlignment;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.search.AStar;
import cc.mallet.util.search.AStarNode;
import cc.mallet.util.search.AStarState;
import cc.mallet.util.search.SearchState;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.log4j.spi.Configurator;

/* loaded from: input_file:cc/mallet/fst/MaxLatticeDefault.class */
public class MaxLatticeDefault implements MaxLattice {
    private static Logger logger;
    private Transducer t;
    private Sequence<Object> input;
    private Sequence<Object> providedOutput;
    private int latticeLength;
    private ViterbiNode[][] lattice;
    private WeightCache first;
    private WeightCache last;
    private WeightCache[] caches;
    private int numCaches;
    private int maxCaches;
    private List<SequencePairAlignment<Object, ViterbiNode>> viterbiNodeAlignmentCache;
    private List<SequencePairAlignment<Object, Transducer.State>> stateAlignmentCache;
    private List<SequencePairAlignment<Object, Object>> outputAlignmentCache;
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:cc/mallet/fst/MaxLatticeDefault$Factory.class */
    /** Serializable factory that produces {@link MaxLatticeDefault} instances. */
    public static class Factory extends MaxLatticeFactory implements Serializable {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        /** Builds a lattice for {@code sequence}, passing {@code sequence2} as the provided output. */
        @Override // cc.mallet.fst.MaxLatticeFactory
        public MaxLattice newMaxLattice(Transducer transducer, Sequence sequence, Sequence sequence2) {
            final MaxLattice lattice = new MaxLatticeDefault(transducer, sequence, sequence2);
            return lattice;
        }

        /** Serializes only the format version number; the factory carries no other state. */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        /** Reads and discards the format version number written by {@link #writeObject}. */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.readInt();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:cc/mallet/fst/MaxLatticeDefault$ViterbiNode.class */
    /**
     * One cell of the Viterbi lattice: a transducer state at a given input position, plus the best
     * forward score and predecessor found for it. Also serves as an A* search state that walks the
     * lattice backwards towards column 0.
     */
    public class ViterbiNode implements AStarState {
        /** Lattice column (input position) of this node. */
        int inputPosition;
        /** Transducer state this node stands for. */
        Transducer.State state;
        /** Transition output stored here by the forward pass; never assigned for column-0 nodes. */
        Object output;
        /** Best score found so far for reaching this node; negative infinity until reached. */
        double delta = Double.NEGATIVE_INFINITY;
        /** Node in the previous column on the best-scoring path into this one, or null. */
        ViterbiNode maxWeightPredecessor = null;

        /**
         * Yields the nodes of column {@code inputPosition - 1} whose cached transition weight into
         * this node's state is above negative infinity. The A* step cost is the negated weight.
         */
        private class PreviousStateIterator extends SearchState.NextStateIterator {
            /** Index of the next candidate source state. */
            private int prev;
            /** True while {@link #prev} already points at an acceptable source state. */
            private boolean found;
            /** Weight of the transition handed out by the latest {@link #nextState()}. */
            private double weight;
            /** Incoming transition weights indexed by source state; stays null in column 0. */
            private double[] weights;

            private PreviousStateIterator() {
                prev = 0;
                if (ViterbiNode.this.inputPosition <= 0) {
                    return; // column 0 has no predecessors
                }
                final MaxLatticeDefault outer = MaxLatticeDefault.this;
                final int target = ViterbiNode.this.state.getIndex();
                final WeightCache column = outer.getCache(ViterbiNode.this.inputPosition - 1);
                weights = new double[outer.t.numStates()];
                for (int source = 0; source < outer.t.numStates(); source++) {
                    weights[source] = column.weight[source][target];
                }
            }

            /** Advances {@link #prev} to the next source whose weight is above negative infinity. */
            private void lookAhead() {
                if (weights != null && !found) {
                    final int stateCount = MaxLatticeDefault.this.t.numStates();
                    for (; prev < stateCount; prev++) {
                        if (weights[prev] > Double.NEGATIVE_INFINITY) {
                            found = true;
                            break;
                        }
                    }
                }
            }

            @Override // cc.mallet.util.search.SearchState.NextStateIterator, java.util.Iterator
            public boolean hasNext() {
                lookAhead();
                if (weights == null) {
                    return false;
                }
                return prev < MaxLatticeDefault.this.t.numStates();
            }

            @Override // cc.mallet.util.search.SearchState.NextStateIterator
            public SearchState nextState() {
                lookAhead();
                final int source = prev++;
                weight = weights[source];
                found = false;
                return MaxLatticeDefault.this.getViterbiNode(ViterbiNode.this.inputPosition - 1, source);
            }

            @Override // cc.mallet.util.search.SearchState.NextStateIterator
            public double cost() {
                return -weight;
            }

            /** Returns the transition weight behind the latest {@link #nextState()}. */
            public double weight() {
                return weight;
            }
        }

        ViterbiNode(int position, Transducer.State nodeState) {
            inputPosition = position;
            state = nodeState;
        }

        /** A* heuristic: the negated forward score of this node. */
        @Override // cc.mallet.util.search.AStarState
        public double completionCost() {
            final double remaining = -delta;
            return remaining;
        }

        /** The backward search ends on a column-0 node whose state has an initial weight above negative infinity. */
        @Override // cc.mallet.util.search.SearchState
        public boolean isFinal() {
            if (inputPosition != 0) {
                return false;
            }
            return state.getInitialWeight() > Double.NEGATIVE_INFINITY;
        }

        /** Returns an iterator over this node's predecessors in the previous column. */
        @Override // cc.mallet.util.search.SearchState
        public SearchState.NextStateIterator getNextStates() {
            final SearchState.NextStateIterator predecessors = new PreviousStateIterator();
            return predecessors;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:cc/mallet/fst/MaxLatticeDefault$WeightCache.class */
    /**
     * Transition weights leaving one input position, kept as a node of the most-recently-used list
     * maintained by {@code getCache}.
     */
    public class WeightCache {
        /** Neighbour towards the most-recently-used end of the list. */
        private WeightCache prev;
        /** Neighbour towards the least-recently-used end of the list. */
        private WeightCache next;
        /** Transition weights indexed by [source state][destination state]. */
        private double[][] weight;
        /** Input position whose transitions this cache currently holds. */
        private int position;

        private WeightCache(int i) {
            final int stateCount = MaxLatticeDefault.this.t.numStates();
            weight = new double[stateCount][stateCount];
            init(i);
        }

        /** Re-targets this cache at position {@code i} and resets every weight to negative infinity. */
        public void init(int i) {
            position = i;
            final int stateCount = MaxLatticeDefault.this.t.numStates();
            for (int from = 0; from < stateCount; from++) {
                final double[] row = weight[from];
                for (int to = 0; to < stateCount; to++) {
                    row[to] = Double.NEGATIVE_INFINITY;
                }
            }
        }
    }

    /** Returns the transducer this lattice was built from. */
    @Override // cc.mallet.fst.MaxLattice
    public Transducer getTransducer() {
        return t;
    }

    /** Returns the input sequence the lattice was built over. */
    public Sequence getInput() {
        return input;
    }

    /** Returns the output sequence given at construction, or null if none was given. */
    public Sequence getProvidedOutput() {
        return providedOutput;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Returns the transition-weight cache for input position {@code i}, building it if needed.
     *
     * <p>Caches form a doubly linked list ordered from most recently used ({@code first}) to least
     * recently used ({@code last}). At most {@code maxCaches} caches are allocated; after that the
     * least recently used one is recycled for the requested position. The returned cache is moved
     * to the front of the list.
     *
     * @param i input position (lattice column) whose outgoing transitions are wanted
     * @return cache whose {@code weight[src][dest]} is the weight of the transition from state
     *     {@code src} at position {@code i} to state {@code dest} when {@code src} was reached by
     *     the forward pass, and negative infinity otherwise
     */
    public WeightCache getCache(int i) {
        WeightCache cache = caches[i];
        if (cache == null) {
            if (numCaches < maxCaches) {
                cache = new WeightCache(i);
                if (numCaches++ == 0) {
                    first = cache;
                    last = cache;
                }
            } else {
                // Recycle the least recently used cache for this position.
                cache = last;
                caches[cache.position] = null;
                cache.init(i);
            }
            fillCacheWeights(cache, i);
            caches[i] = cache;
        }
        moveCacheToFront(cache);
        return cache;
    }

    /** Records the weights of all transitions leaving reached nodes at position {@code i}. */
    private void fillCacheWeights(WeightCache cache, int i) {
        for (int src = 0; src < t.numStates(); src++) {
            ViterbiNode node = lattice[i][src];
            if (node == null || node.delta == Double.NEGATIVE_INFINITY) {
                continue; // unreached source: its row stays at negative infinity
            }
            Transducer.TransitionIterator it = t.getState(src).transitionIterator(input, i, providedOutput, i);
            while (it.hasNext()) {
                cache.weight[src][it.next().getIndex()] = it.getWeight();
            }
        }
    }

    /** Unlinks {@code cache} from the recency list (if linked) and re-inserts it at the head. */
    private void moveCacheToFront(WeightCache cache) {
        if (cache == first) {
            return;
        }
        if (cache == last) {
            last = cache.prev;
        }
        if (cache.prev != null) {
            cache.prev.next = cache.next;
        }
        // Without this, the successor keeps a back-link to the moved node, which corrupts the
        // LRU order and hence which cache is evicted later.
        if (cache.next != null) {
            cache.next.prev = cache.prev;
        }
        cache.next = first;
        cache.prev = null;
        first.prev = cache;
        first = cache;
    }

    /**
     * Returns the lattice node for state {@code i2} at column {@code i}, creating it (with delta
     * negative infinity) on first access.
     */
    protected ViterbiNode getViterbiNode(int i, int i2) {
        ViterbiNode node = lattice[i][i2];
        if (node == null) {
            node = new ViterbiNode(i, t.getState(i2));
            lattice[i][i2] = node;
        }
        return node;
    }

    /**
     * Builds the lattice for {@code sequence} with no provided output sequence and the default
     * limit of 100000 cached positions.
     */
    public MaxLatticeDefault(Transducer transducer, Sequence sequence) {
        this(transducer, sequence, (Sequence) null);
    }

    /**
     * Builds the lattice for {@code sequence}, handing {@code sequence2} to transition iterators as
     * the output sequence, with the default limit of 100000 cached positions.
     */
    public MaxLatticeDefault(Transducer transducer, Sequence sequence, Sequence sequence2) {
        this(transducer, sequence, sequence2, 100000);
    }

    /**
     * Runs the Viterbi forward pass of {@code transducer} over {@code sequence}.
     *
     * @param transducer the transducer to decode with
     * @param sequence the input sequence; must not be null
     * @param sequence2 output sequence handed to every transition iterator, or null
     * @param i maximum number of per-position weight caches kept at once (values below 1 become 1)
     */
    public MaxLatticeDefault(Transducer transducer, Sequence sequence, Sequence sequence2, int i) {
        this.viterbiNodeAlignmentCache = null;
        this.stateAlignmentCache = null;
        this.outputAlignmentCache = null;
        this.t = transducer;
        this.maxCaches = i < 1 ? 1 : i;
        if (!$assertionsDisabled && sequence == null) {
            throw new AssertionError();
        }
        if (logger.isLoggable(Level.FINE)) {
            logStartingSequences(sequence, sequence2);
        }
        this.input = sequence;
        this.providedOutput = sequence2;
        this.latticeLength = this.input.size() + 1;
        int numStates = transducer.numStates();
        this.lattice = new ViterbiNode[this.latticeLength][numStates];
        this.caches = new WeightCache[this.latticeLength - 1];
        logger.fine("Starting Viterbi");
        if (!seedInitialStates(numStates)) {
            logger.warning("Viterbi: No initial states!");
        }
        for (int ip = 0; ip < this.latticeLength - 1; ip++) {
            relaxColumn(ip, numStates);
        }
    }

    /** Logs the input and the provided output at FINE level; the caller checks the level. */
    private static void logStartingSequences(Sequence sequence, Sequence sequence2) {
        logger.fine("Starting ViterbiLattice");
        logger.fine("Input: ");
        for (int k = 0; k < sequence.size(); k++) {
            logger.fine(" " + sequence.get(k));
        }
        logger.fine("\nOutput: ");
        if (sequence2 == null) {
            logger.fine("null");
        } else {
            for (int k = 0; k < sequence2.size(); k++) {
                logger.fine(" " + sequence2.get(k));
            }
        }
        logger.fine("\n");
    }

    /**
     * Gives each state whose initial weight is above negative infinity a column-0 node with that
     * weight as its delta.
     *
     * @return whether at least one such state exists
     */
    private boolean seedInitialStates(int numStates) {
        boolean anyInitial = false;
        for (int s = 0; s < numStates; s++) {
            double initialWeight = t.getState(s).getInitialWeight();
            if (initialWeight > Double.NEGATIVE_INFINITY) {
                getViterbiNode(0, s).delta = initialWeight;
                anyInitial = true;
            }
        }
        return anyInitial;
    }

    /**
     * Extends every reached node in column {@code ip} along its outgoing transitions, updating
     * delta, best predecessor and output of the destination nodes in column {@code ip + 1}. On the
     * last step the destination state's final weight is added.
     */
    private void relaxColumn(int ip, int numStates) {
        boolean lastStep = ip == this.latticeLength - 2;
        for (int s = 0; s < numStates; s++) {
            ViterbiNode source = this.lattice[ip][s];
            if (source == null || source.delta == Double.NEGATIVE_INFINITY) {
                continue;
            }
            Transducer.State state = t.getState(s);
            Transducer.TransitionIterator transitionIterator = state.transitionIterator(this.input, ip, this.providedOutput, ip);
            if (logger.isLoggable(Level.FINE)) {
                logger.fine(" Starting Viterbi transition iteration from state " + state.getName() + " on input " + this.input.get(ip));
            }
            while (transitionIterator.hasNext()) {
                Transducer.State next = transitionIterator.next();
                if (logger.isLoggable(Level.FINE)) {
                    logger.fine("Viterbi[inputPos=" + ip + "][source=" + state.getName() + "][dest=" + next.getName() + "]");
                }
                ViterbiNode destination = getViterbiNode(ip + 1, next.getIndex());
                double weight = source.delta + transitionIterator.getWeight();
                if (lastStep) {
                    weight += next.getFinalWeight();
                }
                if (weight > destination.delta) {
                    if (logger.isLoggable(Level.FINE)) {
                        logger.fine("Viterbi[inputPos=" + ip + "][source][dest=" + next.getName() + "] weight increased to " + weight + " by source=" + state.getName());
                    }
                    destination.delta = weight;
                    destination.maxWeightPredecessor = source;
                    // Keep the output of the winning transition, not of whichever incoming
                    // transition happened to be enumerated last.
                    destination.output = transitionIterator.getOutput();
                }
            }
        }
    }

    /**
     * Returns the Viterbi score of state {@code i2} at column {@code i}; an untouched cell is
     * created and reports negative infinity.
     */
    @Override // cc.mallet.fst.MaxLattice
    public double getDelta(int i, int i2) {
        if (lattice == null) {
            throw new RuntimeException("Attempt to called getDelta() when lattice not stored.");
        }
        final ViterbiNode node = getViterbiNode(i, i2);
        return node.delta;
    }

    /**
     * Returns up to {@code i} highest-weight node paths, found by A* search starting from the
     * final-column nodes whose score is above negative infinity.
     *
     * <p>Results are memoized: when an earlier call already produced at least {@code i} paths, that
     * (possibly longer) list is returned unchanged. The list is shorter than {@code i} when the
     * search runs out of complete paths.
     *
     * @param i maximum number of paths wanted
     * @return alignments of the input with node paths of length {@code input.size() + 1}, each
     *     carrying its path weight
     */
    public List<SequencePairAlignment<Object, ViterbiNode>> bestViterbiNodeSequences(int i) {
        if (viterbiNodeAlignmentCache != null && viterbiNodeAlignmentCache.size() >= i) {
            return viterbiNodeAlignmentCache;
        }
        List<ViterbiNode> startNodes = new ArrayList<ViterbiNode>();
        for (ViterbiNode node : lattice[latticeLength - 1]) {
            if (node != null && node.delta > Double.NEGATIVE_INFINITY) {
                startNodes.add(node);
            }
        }
        AStar aStar = new AStar(startNodes.toArray(new ViterbiNode[startNodes.size()]), latticeLength * t.numStates());
        List<SequencePairAlignment<Object, ViterbiNode>> result = new ArrayList<SequencePairAlignment<Object, ViterbiNode>>();
        for (int k = 0; k < i && aStar.hasNext(); k++) {
            AStarNode searchNode = aStar.next();
            double weight = -searchNode.getCost();
            // The search ran backwards, so following parents visits columns 0, 1, 2, ...
            ViterbiNode[] path = new ViterbiNode[latticeLength];
            for (int pos = 0; pos < latticeLength; pos++) {
                ViterbiNode node = (ViterbiNode) searchNode.getState();
                if (!$assertionsDisabled && node.inputPosition != pos) {
                    throw new AssertionError();
                }
                path[pos] = node;
                searchNode = searchNode.getParent();
            }
            result.add(new SequencePairAlignment<Object, ViterbiNode>(input, new ArraySequence<ViterbiNode>(path), weight));
        }
        viterbiNodeAlignmentCache = result;
        return result;
    }

    /**
     * Returns up to {@code i} highest-weight state alignments (memoized like
     * {@code bestViterbiNodeSequences}). The list is shorter than {@code i} when fewer complete
     * paths exist; previously that case threw IndexOutOfBoundsException.
     *
     * @param i maximum number of alignments wanted
     * @return alignments of the input with state sequences of length {@code input.size() + 1}
     */
    public List<SequencePairAlignment<Object, Transducer.State>> bestStateAlignments(int i) {
        if (stateAlignmentCache != null && stateAlignmentCache.size() >= i) {
            return stateAlignmentCache;
        }
        List<SequencePairAlignment<Object, ViterbiNode>> paths = bestViterbiNodeSequences(i);
        int count = Math.min(i, paths.size());
        List<SequencePairAlignment<Object, Transducer.State>> result = new ArrayList<SequencePairAlignment<Object, Transducer.State>>(Math.max(count, 0));
        for (int k = 0; k < count; k++) {
            SequencePairAlignment<Object, ViterbiNode> path = paths.get(k);
            Sequence<ViterbiNode> nodes = path.output();
            Transducer.State[] states = new Transducer.State[latticeLength];
            for (int pos = 0; pos < latticeLength; pos++) {
                states[pos] = nodes.get(pos).state;
            }
            result.add(new SequencePairAlignment<Object, Transducer.State>(input, new ArraySequence<Transducer.State>(states), path.getWeight()));
        }
        stateAlignmentCache = result;
        return result;
    }

    /** Returns the single highest-weight state alignment. */
    public SequencePairAlignment<Object, Transducer.State> bestStateAlignment() {
        final List<SequencePairAlignment<Object, Transducer.State>> best = bestStateAlignments(1);
        return best.get(0);
    }

    @Override // cc.mallet.fst.MaxLattice
    public List<Sequence<Transducer.State>> bestStateSequences(int i) {
        List<SequencePairAlignment<Object, Transducer.State>> bestStateAlignments = bestStateAlignments(i);
        ArrayList arrayList = new ArrayList(i);
        for (int i2 = 0; i2 < i; i2++) {
            arrayList.add(bestStateAlignments.get(i2).output());
        }
        return arrayList;
    }

    @Override // cc.mallet.fst.MaxLattice
    /** Returns the state sequence of the single highest-weight path. */
    public Sequence<Transducer.State> bestStateSequence() {
        final SequencePairAlignment<Object, Transducer.State> best = bestStateAlignments(1).get(0);
        return best.output();
    }

    /**
     * Returns up to {@code i} highest-weight output alignments (memoized). The list is shorter than
     * {@code i} when fewer complete paths exist; previously that case threw
     * IndexOutOfBoundsException.
     *
     * @param i maximum number of alignments wanted
     * @return alignments of the input with output sequences of length {@code input.size()}
     */
    public List<SequencePairAlignment<Object, Object>> bestOutputAlignments(int i) {
        if (outputAlignmentCache != null && outputAlignmentCache.size() >= i) {
            return outputAlignmentCache;
        }
        List<SequencePairAlignment<Object, ViterbiNode>> paths = bestViterbiNodeSequences(i);
        int count = Math.min(i, paths.size());
        int outputLength = latticeLength - 1;
        List<SequencePairAlignment<Object, Object>> result = new ArrayList<SequencePairAlignment<Object, Object>>(Math.max(count, 0));
        for (int k = 0; k < count; k++) {
            SequencePairAlignment<Object, ViterbiNode> path = paths.get(k);
            Sequence<ViterbiNode> nodes = path.output();
            Object[] outputs = new Object[outputLength];
            for (int pos = 0; pos < outputLength; pos++) {
                // The output for input position pos is stored on the destination node pos + 1.
                outputs[pos] = nodes.get(pos + 1).output;
            }
            result.add(new SequencePairAlignment<Object, Object>(input, new ArraySequence<Object>(outputs), path.getWeight()));
        }
        outputAlignmentCache = result;
        return result;
    }

    /** Returns the single highest-weight output alignment. */
    public SequencePairAlignment<Object, Object> bestOutputAlignment() {
        final List<SequencePairAlignment<Object, Object>> best = bestOutputAlignments(1);
        return best.get(0);
    }

    @Override // cc.mallet.fst.MaxLattice
    public List<Sequence<Object>> bestOutputSequences(int i) {
        bestOutputAlignments(i);
        ArrayList arrayList = new ArrayList(i);
        for (int i2 = 0; i2 < i; i2++) {
            arrayList.add(this.outputAlignmentCache.get(i2).output());
        }
        return arrayList;
    }

    @Override // cc.mallet.fst.MaxLattice
    /** Returns the output sequence of the single highest-weight path. */
    public Sequence<Object> bestOutputSequence() {
        final SequencePairAlignment<Object, Object> best = bestOutputAlignments(1).get(0);
        return best.output();
    }

    /** Returns the weight of the single highest-weight path. */
    public double bestWeight() {
        final SequencePairAlignment<Object, Object> best = bestOutputAlignments(1).get(0);
        return best.getWeight();
    }

    /**
     * Increments {@code incrementor} along the single highest-weight path: its initial state, its
     * final state, and every transition the path takes.
     *
     * <p>A transition is identified by its destination state together with its output (the output
     * the forward pass stored on the destination node). This assumes that pair is unique among the
     * transitions leaving a state at a given position.
     *
     * @param incrementor receives the counts
     * @throws IllegalStateException if zero or several transitions match a step of the path
     */
    public void incrementTransducer(Transducer.Incrementor incrementor) {
        SequencePairAlignment<Object, ViterbiNode> best = bestViterbiNodeSequences(1).get(0);
        Sequence<ViterbiNode> path = best.output();
        int size = path.size();
        // The path has one node per lattice column, i.e. one more node than input elements.
        if (!$assertionsDisabled && size != best.input().size() + 1) {
            throw new AssertionError();
        }
        incrementor.incrementInitialState(path.get(0).state, 1.0d);
        incrementor.incrementFinalState(path.get(size - 1).state, 1.0d);
        // Step i consumes input i and moves from column i to column i + 1.
        for (int i = 0; i < size - 1; i++) {
            ViterbiNode from = path.get(i);
            ViterbiNode to = path.get(i + 1);
            Transducer.TransitionIterator transitionIterator = from.state.transitionIterator(this.input, i, this.providedOutput, i);
            int matches = 0;
            while (transitionIterator.hasNext()) {
                if (!transitionIterator.next().equals(to.state)) {
                    continue;
                }
                Object out = transitionIterator.getOutput();
                boolean sameOutput = out == null ? to.output == null : out.equals(to.output);
                if (sameOutput) {
                    incrementor.incrementTransition(transitionIterator, 1.0d);
                    matches++;
                }
            }
            if (matches > 1) {
                throw new IllegalStateException("More than one satisfying transition found.");
            }
            if (matches == 0) {
                throw new IllegalStateException("No satisfying transition found.");
            }
        }
    }

    @Override // cc.mallet.fst.MaxLattice
    /**
     * Returns the fraction of positions where the best output's {@code toString()} equals the
     * reference element's {@code toString()}, and logs the count at INFO level.
     *
     * @param sequence reference output, expected to have the same length as the best output
     * @return accuracy in [0, 1]; NaN when the sequences are empty
     */
    @Override // cc.mallet.fst.MaxLattice
    public double elementwiseAccuracy(Sequence sequence) {
        Sequence<Object> predicted = bestOutputSequence();
        if (!$assertionsDisabled && sequence.size() != predicted.size()) {
            throw new AssertionError();
        }
        int correct = 0;
        for (int k = 0; k < predicted.size(); k++) {
            if (sequence.get(k).toString().equals(predicted.get(k).toString())) {
                correct++;
            }
        }
        logger.info("Number correct: " + correct + " out of " + predicted.size());
        // Floating-point division: int / int truncated every partial accuracy to 0.
        return (double) correct / predicted.size();
    }

    /**
     * Returns the fraction of positions where the best output's {@code toString()} equals the
     * reference element's {@code toString()}, optionally printing each predicted label.
     *
     * @param sequence reference output, expected to have the same length as the best output
     * @param printWriter if non-null, receives each predicted label on its own line
     * @return accuracy in [0, 1]; NaN when the sequences are empty
     */
    public double tokenAccuracy(Sequence sequence, PrintWriter printWriter) {
        Sequence<Object> predicted = bestOutputSequence();
        if (!$assertionsDisabled && sequence.size() != predicted.size()) {
            throw new AssertionError();
        }
        int correct = 0;
        for (int k = 0; k < predicted.size(); k++) {
            String label = predicted.get(k).toString();
            if (printWriter != null) {
                printWriter.println(label);
            }
            if (sequence.get(k).toString().equals(label)) {
                correct++;
            }
        }
        logger.info("Number correct: " + correct + " out of " + predicted.size());
        // Floating-point division: int / int truncated every partial accuracy to 0.
        return (double) correct / predicted.size();
    }

    // Class initialization: obtain the logger and capture whether assertions are enabled.
    static {
        logger = MalletLogger.getLogger(MaxLatticeDefault.class.getName());
        $assertionsDisabled = !MaxLatticeDefault.class.desiredAssertionStatus();
    }
}
