/*
 * Decompiled with CFR 0.152.
 */
package beast.core;

import beast.core.BEASTObject;
import beast.core.Description;
import beast.core.Input;
import beast.core.Operator;
import beast.core.util.Log;
import beast.util.Randomizer;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Formatter;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

@Description(value="Specify operator selection and optimisation schedule")
public class OperatorSchedule
extends BEASTObject {
    public final Input<OptimisationTransform> transformInput = new Input<OptimisationTransform>("transform", "transform optimisation schedule (default none) This can be " + Arrays.toString((Object[])OptimisationTransform.values()) + " (default 'none')", OptimisationTransform.none, OptimisationTransform.values());
    public final Input<Boolean> autoOptimiseInput = new Input<Boolean>("autoOptimize", "whether to automatically optimise operator settings", true);
    public final Input<Boolean> detailedRejectionInput = new Input<Boolean>("detailedRejection", "true if detailed rejection statistics should be included. (default=false)", false);
    public final Input<Integer> autoOptimizeDelayInput = new Input<Integer>("autoOptimizeDelay", "number of samples to skip before auto optimisation kicks in (default=10000)", 10000);
    public final Input<List<Operator>> operatorsInput = new Input("operator", "operator that the schedule can choose from. Any operators added by other classes (e.g. MCMC) will be added if there are no duplicates.", new ArrayList());
    public final Input<List<OperatorSchedule>> subschedulesInput = new Input("subschedule", "operator schedule representing a subset ofthe weight of the operators it contains.", new ArrayList());
    public final Input<Double> weightInput = new Input<Double>("weight", "weight with which this operator schedule is selected. Only used when this operator schedule is nested inside other schedules. This weight is relative to other operators and operator schedules of the parent schedule.", 100.0);
    public final Input<Boolean> weightIsPercentageInput = new Input<Boolean>("weightIsPercentage", "indicates weight is a percentage of total weight instead of a relative weight", false);
    public final Input<String> operatorPatternInput = new Input("operatorPattern", "Regular expression matching operator IDs of operators of parent schedule");
    public List<Operator> operators = new ArrayList<Operator>();
    double totalWeight = 0.0;
    double[] cumulativeProbs;
    String stateFileName;
    protected int autoOptimizeDelay = 10000;
    protected int autoOptimizeDelayCount = 0;
    OptimisationTransform transform = OptimisationTransform.none;
    boolean autoOptimise = true;
    boolean detailedRejection = false;
    private boolean reweighted = false;
    private static final String TUNING = "Tuning";
    private static final String NUM_ACCEPT = "#accept";
    private static final String NUM_REJECT = "#reject";
    private static final String PR_M = "Pr(m)";
    private static final String PR_ACCEPT = "Pr(acc|m)";

    @Override
    public void initAndValidate() {
        this.transform = this.transformInput.get();
        this.autoOptimise = this.autoOptimiseInput.get();
        this.autoOptimizeDelay = this.autoOptimizeDelayInput.get();
        this.detailedRejection = this.detailedRejectionInput.get();
        this.operators.addAll((Collection<Operator>)this.operatorsInput.get());
        for (Operator operator : this.operators) {
            operator.setOperatorSchedule(this);
        }
        double d = 0.0;
        for (OperatorSchedule object : this.subschedulesInput.get()) {
            if (!object.weightIsPercentageInput.get().booleanValue()) continue;
            d += object.weightInput.get().doubleValue();
        }
        if (d > 100.0) {
            throw new IllegalArgumentException("Sum of percentages of subschedules should not exceed 100%. Reduce the weight of subschedules.");
        }
        if (Math.abs(d - 100.0) < 1.0E-6 && this.operators.size() > 0) {
            throw new IllegalArgumentException("Sum of percentages of subschedules add to 100%, so operators in main schedule will be ignored. Reduce the weight of subschedules.");
        }
        LinkedHashSet linkedHashSet = new LinkedHashSet();
        linkedHashSet.addAll(this.operators);
        for (OperatorSchedule operatorSchedule : this.subschedulesInput.get()) {
            for (Operator operator : operatorSchedule.operators) {
                if (linkedHashSet.contains(operator)) {
                    Log.warning("WARNING: Operator " + operator.getID() + " is contained in multiple operator schedules.\n" + "Operator weighting may not work as expected.");
                }
                linkedHashSet.add(operator);
            }
        }
    }

    public void setStateFileName(String string) {
        this.stateFileName = string;
    }

    public void addOperator(Operator operator) {
        for (Operator operator2 : this.operators) {
            if (operator2 != operator) continue;
            return;
        }
        this.operators.add(operator);
        operator.setOperatorSchedule(this);
        this.reweighted = false;
        this.totalWeight += operator.getWeight();
    }

    protected void addOperators(Collection<Operator> collection) {
        if (this.operatorPatternInput.get() == null || this.operatorPatternInput.get().trim().equals("")) {
            return;
        }
        String string = this.operatorPatternInput.get();
        for (Operator operator : collection) {
            if (operator.getID() == null || !operator.getID().matches(string)) continue;
            for (Operator operator2 : this.operators) {
                if (operator2 != operator) continue;
                return;
            }
            this.operators.add(operator);
        }
        this.reweighted = false;
    }

    public Operator selectOperator() {
        if (!this.reweighted) {
            this.reweightOperators();
            this.reweighted = true;
        }
        int n = Randomizer.randomChoice(this.cumulativeProbs);
        return this.operators.get(n);
    }

    public void showOperatorRates(PrintStream printStream) {
        Formatter formatter = new Formatter(printStream);
        int n = 0;
        for (Operator object2 : this.operators) {
            if (object2.getName().length() <= n) continue;
            n = object2.getName().length();
        }
        formatter.format("%-" + n + "s", "Operator");
        int n2 = 10;
        String string = " %" + n2 + "s";
        formatter.format(string, TUNING);
        formatter.format(string, NUM_ACCEPT);
        formatter.format(string, NUM_REJECT);
        if (this.detailedRejection) {
            formatter.format(string, "rej.inv");
            formatter.format(string, "rej.op");
        }
        formatter.format(string, PR_M);
        formatter.format(string, PR_ACCEPT);
        printStream.println();
        for (Operator operator : this.operators) {
            printStream.println(OperatorSchedule.prettyPrintOperator(operator, n, n2, 4, this.totalWeight, this.detailedRejection));
        }
        printStream.println();
        formatter.format(string, TUNING);
        printStream.println(": The value of the operator's tuning parameter, or '-' if the operator can't be optimized.");
        formatter.format(string, NUM_ACCEPT);
        printStream.println(": The total number of times a proposal by this operator has been accepted.");
        formatter.format(string, NUM_REJECT);
        printStream.println(": The total number of times a proposal by this operator has been rejected.");
        formatter.format(string, PR_M);
        printStream.println(": The probability this operator is chosen in a step of the MCMC (i.e. the normalized weight).");
        formatter.format(string, PR_ACCEPT);
        printStream.println(": The acceptance probability (#accept as a fraction of the total proposals for this operator).");
        printStream.println();
    }

    protected static String prettyPrintOperator(Operator operator, int n, int n2, int n3, double d, boolean bl) {
        double d2 = operator.getCoercableParameterValue();
        double d3 = (double)operator.m_nNrAccepted / (double)(operator.m_nNrAccepted + operator.m_nNrRejected);
        StringBuilder stringBuilder = new StringBuilder();
        Formatter formatter = new Formatter(stringBuilder);
        String string = " %" + n2 + "d";
        String string2 = " %" + n2 + "." + n3 + "f";
        formatter.format("%-" + n + "s", operator.getName());
        if (!Double.isNaN(d2)) {
            formatter.format(string2, d2);
        } else {
            formatter.format(" %" + n2 + "s", "-");
        }
        formatter.format(string, operator.m_nNrAccepted);
        formatter.format(string, operator.m_nNrRejected);
        if (bl) {
            formatter.format(string2, (double)operator.m_nNrRejectedInvalid / (double)operator.m_nNrRejected);
            formatter.format(string2, (double)operator.m_nNrRejectedOperator / (double)operator.m_nNrRejected);
        }
        if (d > 0.0) {
            formatter.format(string2, operator.getWeight() / d);
        }
        formatter.format(string2, d3);
        stringBuilder.append(" " + operator.getPerformanceSuggestion());
        formatter.close();
        return stringBuilder.toString();
    }

    public void storeToFile() throws IOException {
        File file = new File(this.stateFileName);
        PrintWriter printWriter = new PrintWriter(new FileWriter(file, true));
        printWriter.println("<!--");
        printWriter.println("{\"operators\":[");
        int n = 0;
        for (Operator operator : this.operators) {
            operator.storeToFile(printWriter);
            if (n++ >= this.operators.size() - 1) continue;
            printWriter.println(",");
        }
        printWriter.println("\n]}");
        printWriter.println("-->");
        printWriter.flush();
        printWriter.close();
    }

    public void restoreFromFile() throws IOException {
        String string = "";
        BufferedReader bufferedReader = new BufferedReader(new FileReader(this.stateFileName));
        while (bufferedReader.ready()) {
            string = string + bufferedReader.readLine() + "\n";
        }
        bufferedReader.close();
        int n = string.indexOf("</itsabeastystatewerein>") + 25 + 5;
        if (n >= string.length() - 4) {
            return;
        }
        string = string.substring(string.indexOf("</itsabeastystatewerein>") + 25 + 5, string.length() - 4);
        try {
            JSONObject jSONObject = new JSONObject(string);
            JSONArray jSONArray = jSONObject.getJSONArray("operators");
            this.autoOptimizeDelayCount = 0;
            for (int i = 0; i < jSONArray.length(); ++i) {
                JSONObject object = jSONArray.getJSONObject(i);
                String string2 = object.getString("id");
                boolean bl = false;
                if (!string2.equals("null")) {
                    for (Operator operator : this.operators) {
                        if (!string2.equals(operator.getID())) continue;
                        operator.restoreFromFile(object);
                        this.autoOptimizeDelayCount += operator.m_nNrAccepted + operator.m_nNrRejected;
                        bl = true;
                        break;
                    }
                }
                if (bl) continue;
                Log.warning.println("Operator (" + string2 + ") found in state file that is not in operator list any more");
            }
            for (Operator operator : this.operators) {
                if (operator.getID() != null) continue;
                Log.warning.println("Operator (" + operator.getClass() + ") found in BEAST file that could not be restored because it has not ID");
            }
        }
        catch (JSONException jSONException) {
            String[] stringArray = string.split("\n");
            this.autoOptimizeDelayCount = 0;
            for (int i = 0; i < this.operators.size() && i + 2 < stringArray.length; ++i) {
                String[] stringArray2 = stringArray[i + 1].split(" ");
                Operator operator = this.operators.get(i);
                if (operator.getID() == null && stringArray2[0].equals("null") || operator.getID().equals(stringArray2[0])) {
                    this.cumulativeProbs[i] = Double.parseDouble(stringArray2[1]);
                    if (!stringArray2[2].equals("NaN")) {
                        operator.setCoercableParameterValue(Double.parseDouble(stringArray2[2]));
                    }
                    operator.m_nNrAccepted = Integer.parseInt(stringArray2[3]);
                    operator.m_nNrRejected = Integer.parseInt(stringArray2[4]);
                    this.autoOptimizeDelayCount += operator.m_nNrAccepted + operator.m_nNrRejected;
                } else {
                    throw new RuntimeException("Cannot resume: operator order or set changed from previous run");
                }
                operator.m_nNrAcceptedForCorrection = Integer.parseInt(stringArray2[5]);
                operator.m_nNrRejectedForCorrection = Integer.parseInt(stringArray2[6]);
            }
        }
        this.showOperatorRates(System.err);
    }

    public double calcDelta(Operator operator, double d) {
        if (this.autoOptimizeDelayCount < this.autoOptimizeDelay || !this.autoOptimise) {
            ++this.autoOptimizeDelayCount;
            return 0.0;
        }
        double d2 = operator.getTargetAcceptanceProbability();
        double d3 = (double)(operator.m_nNrRejectedForCorrection + operator.m_nNrAcceptedForCorrection) + 1.0;
        switch (this.transform) {
            case log: {
                d3 = Math.log(d3 + 1.0);
                break;
            }
            case sqrt: {
                d3 = Math.sqrt(d3);
                break;
            }
            case none: {
                break;
            }
        }
        double d4 = 1.0 / d3 * (Math.exp(Math.min(d, 0.0)) - d2);
        if (d4 > -1.7976931348623157E308 && d4 < Double.MAX_VALUE) {
            return d4;
        }
        return 0.0;
    }

    private void reweightOperators() {
        LinkedHashSet<Operator> linkedHashSet = new LinkedHashSet<Operator>();
        LinkedHashSet<Operator> linkedHashSet2 = new LinkedHashSet<Operator>();
        linkedHashSet.addAll(this.operators);
        for (OperatorSchedule iterator22 : this.subschedulesInput.get()) {
            linkedHashSet.addAll(iterator22.operators);
        }
        for (OperatorSchedule operatorSchedule : this.subschedulesInput.get()) {
            operatorSchedule.addOperators(linkedHashSet);
            linkedHashSet2.addAll(operatorSchedule.operators);
        }
        linkedHashSet.addAll(linkedHashSet2);
        LinkedHashSet linkedHashSet3 = new LinkedHashSet();
        linkedHashSet3.addAll(linkedHashSet);
        linkedHashSet3.removeAll(linkedHashSet2);
        this.operators.clear();
        Iterator iterator = linkedHashSet3.iterator();
        while (iterator.hasNext()) {
            Operator operator = (Operator)iterator.next();
            this.operators.add(operator);
        }
        for (OperatorSchedule operatorSchedule : this.subschedulesInput.get()) {
            for (Operator operator : operatorSchedule.operators) {
                this.operators.add(operator);
            }
        }
        int n = this.operators.size();
        double[] dArray = new double[n];
        int n2 = 0;
        Iterator iterator2 = linkedHashSet3.iterator();
        while (iterator2.hasNext()) {
            Operator operator = (Operator)iterator2.next();
            dArray[n2++] = operator.getWeight();
        }
        for (OperatorSchedule operatorSchedule : this.subschedulesInput.get()) {
            for (Operator operator : operatorSchedule.operators) {
                dArray[n2++] = operator.getWeight();
            }
        }
        double d = 0.0;
        Iterator<Object> iterator3 = linkedHashSet3.iterator();
        while (iterator3.hasNext()) {
            Operator operator;
            operator = (Operator)iterator3.next();
            d += operator.getWeight();
        }
        double d2 = 0.0;
        double d3 = 0.0;
        for (OperatorSchedule operatorSchedule : this.subschedulesInput.get()) {
            if (operatorSchedule.weightIsPercentageInput.get().booleanValue()) {
                d2 += operatorSchedule.weightInput.get().doubleValue();
                continue;
            }
            d3 += operatorSchedule.weightInput.get().doubleValue();
        }
        double d4 = d2 >= 100.0 ? 100.0 : (d + d3) * 100.0 / (100.0 - d2);
        double d5 = 1.0 / d4;
        n2 = 0;
        Iterator<Object> iterator4 = linkedHashSet3.iterator();
        while (iterator4.hasNext()) {
            Operator operator = (Operator)iterator4.next();
            int n3 = n2++;
            dArray[n3] = dArray[n3] * d5;
        }
        for (OperatorSchedule operatorSchedule : this.subschedulesInput.get()) {
            d = 0.0;
            for (Operator operator : operatorSchedule.operators) {
                d += operator.getWeight();
            }
            double d6 = operatorSchedule.weightIsPercentageInput.get() == false ? operatorSchedule.weightInput.get() / d * (1.0 / d4) : operatorSchedule.weightInput.get() / 100.0 * 1.0 / d;
            for (Operator operator : operatorSchedule.operators) {
                int n4 = n2++;
                dArray[n4] = dArray[n4] * d6;
            }
        }
        this.cumulativeProbs = new double[dArray.length];
        this.cumulativeProbs[0] = dArray[0];
        for (n2 = 1; n2 < this.operators.size(); ++n2) {
            this.cumulativeProbs[n2] = dArray[n2] + this.cumulativeProbs[n2 - 1];
        }
    }

    public double[] getCummulativeProbs() {
        return (double[])this.cumulativeProbs.clone();
    }

    static enum OptimisationTransform {
        none,
        log,
        sqrt;

    }
}

