/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.fsm.TransducerGraph;
import edu.stanford.nlp.parser.lexparser.BinaryGrammar;
import edu.stanford.nlp.parser.lexparser.BinaryRule;
import edu.stanford.nlp.parser.lexparser.Rule;
import edu.stanford.nlp.parser.lexparser.Train;
import edu.stanford.nlp.parser.lexparser.UnaryGrammar;
import edu.stanford.nlp.parser.lexparser.UnaryRule;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Distribution;
import edu.stanford.nlp.util.Numberer;
import edu.stanford.nlp.util.Pair;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Base class for grammar compactors.
 *
 * <p>{@link #compactGrammar(Pair, Map, Map)} takes a (UnaryGrammar, BinaryGrammar) pair and
 * proceeds in three steps:
 * <ol>
 *   <li>Every rule that mentions a synthetic state (a state name starting with {@code '@'})
 *       becomes an arc in one {@link TransducerGraph} per category.</li>
 *   <li>Each graph is handed to the subclass's {@link #doCompaction}.</li>
 *   <li>The compacted graphs are converted back into a new (UnaryGrammar, BinaryGrammar) pair
 *       over a freshly numbered state space.</li>
 * </ol>
 * Rules that mention no synthetic state bypass the graphs and are only renumbered.
 *
 * <p>Compaction mutates this object's fields. It also replaces the global {@link Numberer}
 * registered under the grammar's state space (see {@link #convertGraphsToGrammar}).
 */
public abstract class GrammarCompactor {
    /** Graphs produced by the most recent {@link #compactGrammar} call, one per compacted category. */
    Set compactedGraphs;
    /** Output type: rule scores are raw counts, normalized to log probabilities when rebuilding the grammar. */
    public static final Object RAW_COUNTS = new Object();
    /** Output type: rule scores are log probabilities; they are stored negated on graph arcs. */
    public static final Object NORMALIZED_LOG_PROBABILITIES = new Object();
    /** How rule scores are interpreted; compared by identity against the two constants above. */
    public Object outputType = RAW_COUNTS;
    /** Numberer of the input grammar's state space (the global numberer for {@link #stateSpace}). */
    protected Numberer stateNumberer;
    /** Numberer created for the compacted grammar's state space. */
    protected Numberer newStateNumberer;
    /** Name of the state space, taken from the input BinaryGrammar. */
    protected String stateSpace;
    // The next three fields are not read in this class; presumably used by subclasses when dumping graphs.
    String rawBaseDir = "raw";
    String compactedBaseDir = "compacted";
    boolean writeToFile = false;
    /** Smoothed distribution over arc inputs, estimated from the training paths. */
    protected Distribution inputPrior;
    // Arc input marking the step from a synthetic state to the category it completes (see addOneUnaryRule).
    private static final Object END = "END";
    // Arc input treated like END when rebuilding rules; not produced here, presumably by subclass compaction.
    private static final Object EPSILON = "EPSILON";
    /** When true, progress and statistics are printed to standard output. */
    protected boolean verbose = false;

    /**
     * Compacts the graph of a single category.
     *
     * @param graph the graph built from the rules of one category
     * @param trainPaths training paths for that category (empty when none were supplied)
     * @param testPaths test paths for that category (empty when none were supplied)
     * @return the compacted graph
     */
    protected abstract TransducerGraph doCompaction(TransducerGraph graph, List trainPaths, List testPaths);

    /**
     * Compacts {@code grammar} without any training or test paths.
     *
     * @param grammar a Pair whose first element is a UnaryGrammar and second a BinaryGrammar
     * @return a new Pair of (UnaryGrammar, BinaryGrammar)
     */
    public Pair compactGrammar(Pair grammar) {
        return this.compactGrammar(grammar, new HashMap(), new HashMap());
    }

    /**
     * Compacts {@code grammar}, giving each category's training and test paths to {@link #doCompaction}.
     *
     * <p>The entries for every compacted category are REMOVED from {@code allTrainPaths} and
     * {@code allTestPaths}. The input prior is computed from {@code allTrainPaths} before any removal.
     *
     * @param grammar a Pair whose first element is a UnaryGrammar and second a BinaryGrammar
     * @param allTrainPaths category name to list of training paths (each path a List of inputs)
     * @param allTestPaths category name to list of test paths
     * @return a new Pair of (UnaryGrammar, BinaryGrammar) built from the compacted graphs
     */
    public Pair compactGrammar(Pair grammar, Map allTrainPaths, Map allTestPaths) {
        this.inputPrior = this.computeInputPrior(allTrainPaths);
        BinaryGrammar bg = (BinaryGrammar) grammar.second;
        this.stateSpace = bg.stateSpace();
        this.stateNumberer = Numberer.getGlobalNumberer(this.stateSpace);
        HashSet unaryRules = new HashSet();
        HashSet binaryRules = new HashSet();
        Map graphs = this.convertGrammarToGraphs(grammar, unaryRules, binaryRules);
        this.compactedGraphs = new HashSet();
        if (this.verbose) {
            System.out.println("There are " + graphs.size() + " categories to compact.");
        }
        int i = 0;
        for (Iterator graphIter = graphs.entrySet().iterator(); graphIter.hasNext(); ) {
            Map.Entry entry = (Map.Entry) graphIter.next();
            String cat = (String) entry.getKey();
            TransducerGraph graph = (TransducerGraph) entry.getValue();
            if (this.verbose) {
                System.out.println("About to compact grammar for " + cat + " with numNodes=" + graph.getNodes().size());
            }
            List trainPaths = removePaths(allTrainPaths, cat);
            List testPaths = removePaths(allTestPaths, cat);
            TransducerGraph compactedGraph = this.doCompaction(graph, trainPaths, testPaths);
            ++i;
            if (this.verbose) {
                System.out.println(i + ". Compacted grammar for " + cat + " from " + graph.getArcs().size() + " arcs to " + compactedGraph.getArcs().size() + " arcs.");
            }
            // Drop the uncompacted graph once it is processed (presumably to release memory early).
            graphIter.remove();
            this.compactedGraphs.add(compactedGraph);
        }
        return this.convertGraphsToGrammar(this.compactedGraphs, unaryRules, binaryRules);
    }

    /**
     * Removes and returns the path list stored under {@code cat}.
     * Returns a new empty list when there is none.
     * Accepts any List implementation; the decompiled original required ArrayList.
     */
    private static List removePaths(Map pathMap, String cat) {
        List paths = (List) pathMap.remove(cat);
        return paths == null ? new ArrayList() : paths;
    }

    /**
     * Counts every input symbol occurring in every training path.
     *
     * @param trainPathMap category to list of paths, each path a List of input symbols
     * @return a Laplace-smoothed distribution over the counted inputs. The arguments
     *     {@code 2 * size} and {@code 0.5} are presumably the assumed number of keys and the
     *     pseudo-count; confirm against Distribution.
     */
    protected Distribution computeInputPrior(Map trainPathMap) {
        Counter result = new Counter();
        for (Object pathList : trainPathMap.values()) {
            for (Object path : (List) pathList) {
                for (Object input : (List) path) {
                    result.incrementCount(input);
                }
            }
        }
        return Distribution.laplaceSmoothedDistribution(result, result.size() * 2, 0.5);
    }

    /**
     * Negates {@code output} when scores are NORMALIZED_LOG_PROBABILITIES; otherwise returns it unchanged.
     * The operation is its own inverse, so it is used both when building arcs and when rebuilding rules.
     */
    private double smartNegate(double output) {
        if (this.outputType == NORMALIZED_LOG_PROBABILITIES) {
            return -output;
        }
        return output;
    }

    /**
     * Writes {@code graph} in DOT format to {@code dir/name.dot}, creating {@code dir} if needed.
     *
     * @return true on success; false if the directory is unusable, the file cannot be opened,
     *     writing fails, or any other exception occurs (its stack trace is printed)
     */
    public static boolean writeFile(TransducerGraph graph, String dir, String name) {
        try {
            File baseDir = new File(dir);
            boolean dirReady = baseDir.exists() ? baseDir.isDirectory() : baseDir.mkdirs();
            if (!dirReady) {
                return false;
            }
            File file = new File(baseDir, name + ".dot");
            PrintWriter w = null;
            try {
                w = new PrintWriter(new FileWriter(file));
                w.print(graph.asDOTString());
                // PrintWriter swallows IOExceptions; checkError() flushes and reports whether one occurred.
                if (w.checkError()) {
                    System.err.println("Failed to write file in writeToDOTfile: " + file);
                    return false;
                }
            } catch (IOException e) {
                System.err.println("Failed to open file in writeToDOTfile: " + file);
                return false;
            } finally {
                // Close on every path, including when asDOTString() throws.
                if (w != null) {
                    w.close();
                }
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Routes every rule of {@code grammar} either into a per-category graph or into the bypass sets.
     *
     * @param grammar a Pair whose first element is a UnaryGrammar and second a BinaryGrammar
     * @param unaryRules receives the unary rules that could not be added to a graph
     * @param binaryRules receives the binary rules that could not be added to a graph
     * @return map from category key to the graph built for it
     */
    public Map convertGrammarToGraphs(Pair grammar, Set unaryRules, Set binaryRules) {
        int numRules = 0;
        UnaryGrammar ug = (UnaryGrammar) grammar.first;
        BinaryGrammar bg = (BinaryGrammar) grammar.second;
        HashMap graphs = new HashMap();
        for (Object o : bg) {
            BinaryRule rule = (BinaryRule) o;
            ++numRules;
            if (!this.addOneBinaryRule(rule, graphs)) {
                binaryRules.add(rule);
            }
        }
        for (Object o : ug) {
            UnaryRule rule = (UnaryRule) o;
            ++numRules;
            if (!this.addOneUnaryRule(rule, graphs)) {
                unaryRules.add(rule);
            }
        }
        if (this.verbose) {
            System.out.println("Number of raw rules: " + numRules);
            System.out.println("Number of raw states: " + this.stateNumberer.total());
        }
        return graphs;
    }

    /**
     * Returns the graph stored under key {@code o} in {@code m}.
     * If there is none, creates one whose end node is {@code o}, stores it and returns it.
     */
    protected TransducerGraph getGraphFromMap(Map m, Object o) {
        TransducerGraph graph = (TransducerGraph) m.get(o);
        if (graph == null) {
            graph = new TransducerGraph();
            graph.setEndNode(o);
            m.put(o, graph);
        }
        return graph;
    }

    /**
     * Extracts the top category from a synthetic state name.
     *
     * @return the text between the leading '@' and the first '|', or null if {@code s} does not
     *     start with '@' (including the empty string)
     * @throws RuntimeException if {@code s} starts with '@' but contains no '|'
     */
    protected String getTopCategoryOfSyntheticState(String s) {
        if (!s.startsWith("@")) {
            return null;
        }
        int bar = s.indexOf('|');
        if (bar < 0) {
            throw new RuntimeException("Grammar format error. Expected bar in state name: " + s);
        }
        return s.substring(1, bar);
    }

    /**
     * Adds a unary rule to a graph if it touches a synthetic state.
     *
     * <ul>
     *   <li>Synthetic parent: adds an arc from the start node of the parent's top-category graph
     *       to the parent, with the child as input.</li>
     *   <li>Synthetic child: adds an arc from the child to the parent, with input END, in the
     *       parent's graph, and marks the parent as end node.</li>
     * </ul>
     *
     * @return true if the rule was added to a graph, false if neither state is synthetic
     */
    protected boolean addOneUnaryRule(UnaryRule rule, Map graphs) {
        String parentString = (String) this.stateNumberer.object(rule.parent);
        String childString = (String) this.stateNumberer.object(rule.child);
        if (this.isSyntheticState(parentString)) {
            String topcat = this.getTopCategoryOfSyntheticState(parentString);
            TransducerGraph graph = this.getGraphFromMap(graphs, topcat);
            Double output = Double.valueOf(this.smartNegate(rule.score()));
            graph.addArc(graph.getStartNode(), parentString, childString, output);
            return true;
        }
        if (this.isSyntheticState(childString)) {
            TransducerGraph graph = this.getGraphFromMap(graphs, parentString);
            Double output = Double.valueOf(this.smartNegate(rule.score()));
            graph.addArc(childString, parentString, END, output);
            graph.setEndNode(parentString);
            return true;
        }
        return false;
    }

    /**
     * Adds a binary rule with a synthetic child to the graph of that child's top category.
     *
     * <p>The arc runs from the synthetic child to the parent. Its input is the other child's
     * name followed by a direction marker:
     * <ul>
     *   <li>'>' when the synthetic state is the left child;</li>
     *   <li>'<' when it is the right child;</li>
     *   <li>when {@code Train.markFinalStates} is set, the parent name's last character instead.</li>
     * </ul>
     *
     * @return true if the rule was added to a graph, false if neither child is synthetic
     * @throws RuntimeException if the synthetic child yields no top category
     */
    protected boolean addOneBinaryRule(BinaryRule rule, Map graphs) {
        String parentString = (String) this.stateNumberer.object(rule.parent);
        String leftString = (String) this.stateNumberer.object(rule.leftChild);
        String rightString = (String) this.stateNumberer.object(rule.rightChild);
        String bracket = null;
        if (Train.markFinalStates) {
            bracket = parentString.substring(parentString.length() - 1, parentString.length());
        }
        String source;
        String input;
        if (this.isSyntheticState(leftString)) {
            source = leftString;
            input = rightString + (bracket == null ? ">" : bracket);
        } else if (this.isSyntheticState(rightString)) {
            source = rightString;
            input = leftString + (bracket == null ? "<" : bracket);
        } else {
            return false;
        }
        String target = parentString;
        Double output = Double.valueOf(this.smartNegate(rule.score()));
        String topcat = this.getTopCategoryOfSyntheticState(source);
        if (topcat == null) {
            throw new RuntimeException("can't have null topcat");
        }
        TransducerGraph graph = this.getGraphFromMap(graphs, topcat);
        graph.addArc(source, target, input, output);
        return true;
    }

    /** Returns true iff {@code state} starts with '@'; false for the empty string. */
    protected boolean isSyntheticState(String state) {
        return state.startsWith("@");
    }

    /**
     * Builds a new grammar from the compacted graphs plus the bypass rules.
     *
     * <p>Steps:
     * <ol>
     *   <li>Renumber the bypass rules into a new Numberer.</li>
     *   <li>Install that Numberer globally under {@link #stateSpace}.</li>
     *   <li>Turn every arc into a rule:
     *     <ul>
     *       <li>start-node arcs and END/EPSILON arcs become unary rules;</li>
     *       <li>all other arcs become binary rules, with the last input character
     *           ('<'/'[' or '>'/']') giving the side of the synthetic state.</li>
     *     </ul>
     *   </li>
     *   <li>For RAW_COUNTS, convert scores to log relative frequencies per parent.</li>
     * </ol>
     *
     * <p>{@code unaryRules} and {@code binaryRules} are mutated: existing rules are renumbered in
     * place and the graph-derived rules are added to them.
     *
     * @return a new Pair of (UnaryGrammar, BinaryGrammar)
     * @throws RuntimeException if an arc input ends in an unexpected direction marker
     */
    public Pair convertGraphsToGrammar(Set graphs, Set unaryRules, Set binaryRules) {
        this.newStateNumberer = new Numberer();
        // NOTE(review): these rules are mutated while inside the (Hash)Sets. If Rule equality/hashCode
        // depends on parent/child (TODO confirm), the sets' lookups become stale for later add() calls.
        for (Object o : unaryRules) {
            UnaryRule rule = (UnaryRule) o;
            rule.parent = this.newStateNumberer.number(this.stateNumberer.object(rule.parent));
            rule.child = this.newStateNumberer.number(this.stateNumberer.object(rule.child));
        }
        for (Object o : binaryRules) {
            BinaryRule rule = (BinaryRule) o;
            rule.parent = this.newStateNumberer.number(this.stateNumberer.object(rule.parent));
            rule.leftChild = this.newStateNumberer.number(this.stateNumberer.object(rule.leftChild));
            rule.rightChild = this.newStateNumberer.number(this.stateNumberer.object(rule.rightChild));
        }
        // Replace the globally registered numberer for this state space with the new one.
        Map numbs = Numberer.getNumberers();
        numbs.put(this.stateSpace, this.newStateNumberer);
        for (Object g : graphs) {
            TransducerGraph graph = (TransducerGraph) g;
            Object startNode = graph.getStartNode();
            for (Object a : graph.getArcs()) {
                TransducerGraph.Arc arc = (TransducerGraph.Arc) a;
                Object source = arc.getSourceNode();
                Object target = arc.getTargetNode();
                String inputString = arc.getInput().toString();
                double output = (Double) arc.getOutput();
                if (source.equals(startNode)) {
                    // Start arc: target -> input.
                    UnaryRule ur = new UnaryRule(this.newStateNumberer.number(target), this.newStateNumberer.number(inputString), this.smartNegate(output));
                    unaryRules.add(ur);
                    continue;
                }
                if (inputString.equals(END) || inputString.equals(EPSILON)) {
                    // Exit arc: target -> source.
                    UnaryRule ur = new UnaryRule(this.newStateNumberer.number(target), this.newStateNumberer.number(source), this.smartNegate(output));
                    unaryRules.add(ur);
                    continue;
                }
                int length = inputString.length();
                char leftOrRight = inputString.charAt(length - 1);
                inputString = inputString.substring(0, length - 1);
                BinaryRule br;
                if (leftOrRight == '<' || leftOrRight == '[') {
                    // Other child on the left, synthetic source on the right.
                    br = new BinaryRule(this.newStateNumberer.number(target), this.newStateNumberer.number(inputString), this.newStateNumberer.number(source), this.smartNegate(output));
                } else if (leftOrRight == '>' || leftOrRight == ']') {
                    // Synthetic source on the left, other child on the right.
                    br = new BinaryRule(this.newStateNumberer.number(target), this.newStateNumberer.number(source), this.newStateNumberer.number(inputString), this.smartNegate(output));
                } else {
                    throw new RuntimeException("Arc input is in unexpected format: " + arc);
                }
                binaryRules.add(br);
            }
        }
        // For raw counts, total the scores of all rules per parent symbol to normalize by.
        Counter<Object> symbolCounter = new Counter<Object>();
        if (this.outputType == RAW_COUNTS) {
            for (Object o : unaryRules) {
                UnaryRule rule = (UnaryRule) o;
                symbolCounter.incrementCount(this.newStateNumberer.object(rule.parent), rule.score);
            }
            for (Object o : binaryRules) {
                BinaryRule rule = (BinaryRule) o;
                symbolCounter.incrementCount(this.newStateNumberer.object(rule.parent), rule.score);
            }
        }
        int numStates = this.newStateNumberer.total();
        int numRules = 0;
        UnaryGrammar ug = new UnaryGrammar(numStates);
        BinaryGrammar bg = new BinaryGrammar(numStates);
        for (Object o : unaryRules) {
            UnaryRule rule = (UnaryRule) o;
            if (this.outputType == RAW_COUNTS) {
                double count = symbolCounter.getCount(this.newStateNumberer.object(rule.parent));
                rule.score = (float) Math.log((double) rule.score / count);
            }
            ug.addRule(rule);
            ++numRules;
        }
        for (Object o : binaryRules) {
            BinaryRule rule = (BinaryRule) o;
            if (this.outputType == RAW_COUNTS) {
                double count = symbolCounter.getCount(this.newStateNumberer.object(rule.parent));
                // NOTE(review): Train.ruleDiscount is subtracted for binary rules only, not unary
                // rules and not the denominator; kept as-is, confirm this asymmetry is intended.
                rule.score = (float) Math.log(((double) rule.score - Train.ruleDiscount) / count);
            }
            bg.addRule(rule);
            ++numRules;
        }
        if (this.verbose) {
            System.out.println("Number of minimized rules: " + numRules);
            System.out.println("Number of minimized states: " + this.newStateNumberer.total());
        }
        ug.purgeRules();
        bg.splitRules();
        return new Pair<UnaryGrammar, BinaryGrammar>(ug, bg);
    }
}

