package weka.classifiers.trees;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
import org.apache.commons.io.IOUtils;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.TestInstances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.RemoveUseless;

/* loaded from: input_file:weka/classifiers/trees/AlternatingModelTree.class */
/**
 * Alternating model tree for a numeric (or date) class, grown by minimising squared error.
 *
 * <p>Training proceeds in three phases ({@link #initializeClassifier}, {@link #next}, {@link #done}).
 * The data is filtered with NominalToBinary and RemoveUseless, and the root prediction node fits
 * ZeroR. Each call to {@code next()} then expands the prediction node whose best median split
 * yields the largest reduction in weighted squared error. An expansion attaches a splitter node
 * with three child prediction nodes (value {@code <=} split, value {@code >} split, value missing).
 * Each child holds a SimpleLinearRegression model fitted to the current residuals. A prediction is
 * the root model's output plus the shrunken outputs of every other prediction node the instance
 * reaches.
 */
public class AlternatingModelTree extends AbstractClassifier implements WeightedInstancesHandler, IterativeClassifier, Drawable, TechnicalInformationHandler, RevisionHandler {
    static final long serialVersionUID = -7716785668198681288L;

    /**
     * Filtered training data. While training, each instance's class value holds its current
     * residual; done() reduces this to an empty header.
     */
    protected Instances m_Data;

    /** Number of node expansions attempted by buildClassifier(). */
    protected int m_NumberOfIterations = 10;

    /** Shrinkage applied to the predictions of every model except the root's. */
    protected double m_Shrinkage = 1.0d;

    /** If true, a prediction node receives at most one splitter node (a decision tree). */
    protected boolean m_BuildDecisionTree;

    /** All prediction nodes created so far; element 0 is the root. */
    protected ArrayList<PredictionNode> m_PredictionNodes;

    /** Filter converting nominal attributes to binary ones; also applied at prediction time. */
    private NominalToBinary m_nominalToBinary;

    /** Filter removing useless attributes; also applied at prediction time. */
    private RemoveUseless m_removeUseless;

    /**
     * A prediction node. It holds a model fitted to the residuals of the training instances that
     * reach it, plus the splitter nodes attached below it.
     *
     * <p>NOTE(review): this class declares no serialVersionUID, so its serialized form depends on
     * its non-private signatures. Keep them stable so previously saved models still load.
     */
    public class PredictionNode implements Serializable {

        /** Model at this node: ZeroR at the root and for empty subsets, else SimpleLinearRegression. */
        protected Classifier m_Model;

        /** Splitter nodes attached below this node. */
        protected List<SplitterNode> m_Successors;

        /** Indices into m_Data of the training instances reaching this node; cleared by done(). */
        protected int[] m_Indices;

        /** Number of training instances that reached this node. */
        protected int m_Size;

        /**
         * Renders this node and the subtree below it as indented text.
         *
         * @param str prefix written in front of this node's lines
         * @return the textual representation of the subtree
         */
        protected String toString(String str) {
            StringBuilder text = new StringBuilder();
            text.append(str).append("Pred = ").append(describeModel(m_Model, m_Size)).append('\n');
            for (SplitterNode splitter : m_Successors) {
                text.append(splitter.toString(str));
            }
            return text.toString();
        }

        /**
         * Creates the root node. It fits ZeroR to the data and replaces every class value by its
         * unshrunk residual.
         *
         * @param instances the training data; the stored indices are used against m_Data, so this
         *     is expected to be m_Data itself (the only caller passes m_Data)
         * @throws Exception if ZeroR cannot be built or applied
         */
        protected PredictionNode(Instances instances) throws Exception {
            int[] allIndices = new int[instances.numInstances()];
            for (int i = 0; i < allIndices.length; i++) {
                allIndices[i] = i;
            }
            m_Indices = allIndices;
            m_Model = new ZeroR();
            m_Model.buildClassifier(instances);
            // The root model contributes without shrinkage (classifyInstance mirrors this).
            calculateResiduals(m_Model, allIndices, 1.0d);
            m_Size = allIndices.length;
            m_Successors = new LinkedList<>();
        }

        /**
         * Creates a child node for the given subset of m_Data. It fits a model to the subset's
         * residuals and updates those residuals using the shrunken predictions.
         *
         * @param indices indices into m_Data of the instances reaching this node (may be empty)
         * @throws Exception if the model cannot be built or applied
         */
        protected PredictionNode(int[] indices) throws Exception {
            m_Indices = indices;
            m_Model = buildModel(getData(indices));
            calculateResiduals(m_Model, indices, m_Shrinkage);
            m_Size = indices.length;
            m_Successors = new LinkedList<>();
        }

        /**
         * Collects the given rows of m_Data into a new dataset with the same header.
         *
         * <p>NOTE(review): training rows are UnsafeInstances whose copy() returns {@code this}. If
         * Instances.add copies via copy(), the subset shares instance objects with m_Data.
         *
         * @param indices row indices into m_Data
         * @return the subset
         */
        protected Instances getData(int[] indices) {
            Instances subset = new Instances(m_Data, indices.length);
            for (int index : indices) {
                subset.add(m_Data.instance(index));
            }
            return subset;
        }

        /**
         * Fits the node model. It uses ZeroR for an empty subset and SimpleLinearRegression
         * otherwise.
         *
         * @param instances the data to fit
         * @return the fitted model
         * @throws Exception if the model cannot be built
         */
        protected Classifier buildModel(Instances instances) throws Exception {
            if (instances.numInstances() == 0) {
                ZeroR zeroR = new ZeroR();
                zeroR.buildClassifier(instances);
                return zeroR;
            }
            SimpleLinearRegression regression = new SimpleLinearRegression();
            regression.setSuppressErrorMessage(true);
            regression.setDoNotCheckCapabilities(true);
            regression.buildClassifier(instances);
            return regression;
        }

        /**
         * Fits a model to the data and returns the weighted squared error left after subtracting
         * its shrunken predictions.
         *
         * @param instances the data (class values are residuals)
         * @return the weighted sum of squared remaining residuals
         * @throws Exception if the model cannot be built or applied
         */
        protected double evaluateModel(Instances instances) throws Exception {
            Classifier model = buildModel(instances);
            double sse = 0.0d;
            for (Instance inst : instances) {
                double residual = inst.classValue() - (m_Shrinkage * model.classifyInstance(inst));
                sse += inst.weight() * residual * residual;
            }
            return sse;
        }

        /**
         * Finds the best median split for this node.
         *
         * @return the split with the largest error reduction, or null if the node is empty, may not
         *     be expanded further (decision-tree mode), or no attribute yields a valid split
         * @throws Exception if a candidate model cannot be built
         */
        protected SplitInfo evaluateNodeExpansion() throws Exception {
            if (m_Debug) {
                System.out.println(toString(""));
            }
            if (m_Indices.length == 0) {
                return null;
            }
            if (m_BuildDecisionTree && !m_Successors.isEmpty()) {
                return null;
            }
            SplitInfo best = new SplitInfo();

            // Class values hold residuals, so this is the node's current weighted squared error.
            double currentSSE = 0.0d;
            for (int index : m_Indices) {
                Instance inst = m_Data.instance(index);
                currentSSE += inst.weight() * inst.classValue() * inst.classValue();
            }
            if (m_Debug) {
                System.err.println("Current SSE: " + currentSSE);
            }

            for (int att = 0; att < m_Data.numAttributes(); att++) {
                if (att == m_Data.classIndex()) {
                    continue;
                }
                if (m_Debug) {
                    System.err.println(m_Data.attribute(att));
                }

                // Take the median over non-missing values only. Missing values (NaN) get their own
                // branch, and in the selection they would distort the result or even give NaN.
                double[] present = new double[m_Indices.length];
                int numPresent = 0;
                for (int index : m_Indices) {
                    double value = m_Data.instance(index).value(att);
                    if (!Utils.isMissingValue(value)) {
                        present[numPresent++] = value;
                    }
                }
                if (numPresent == 0) {
                    // Only the missing-value branch would be populated: not a valid split.
                    continue;
                }
                if (numPresent < present.length) {
                    double[] trimmed = new double[numPresent];
                    System.arraycopy(present, 0, trimmed, 0, numPresent);
                    present = trimmed;
                }
                double median = Utils.kthSmallestValue(present, numPresent / 2);
                if (m_Debug) {
                    System.err.println("median: " + median);
                }

                ArrayList<Integer> left = new ArrayList<>(m_Indices.length);
                ArrayList<Integer> right = new ArrayList<>(m_Indices.length);
                ArrayList<Integer> missing = new ArrayList<>(m_Indices.length);
                partitionIndices(m_Indices, att, median, left, right, missing);

                // Only consider splits that send instances down at least two of the three branches.
                int nonEmpty = (left.isEmpty() ? 0 : 1) + (right.isEmpty() ? 0 : 1) + (missing.isEmpty() ? 0 : 1);
                if (nonEmpty < 2) {
                    continue;
                }

                double leftSSE = evaluateModel(getData(toIntArray(left)));
                double rightSSE = evaluateModel(getData(toIntArray(right)));
                double missingSSE = evaluateModel(getData(toIntArray(missing)));
                double errorReduction = currentSSE - ((leftSSE + rightSSE) + missingSSE);
                if (m_Debug) {
                    System.err.println("firstSSE " + leftSSE);
                    System.err.println("secondSSE " + rightSSE);
                    System.err.println("missingSSE " + missingSSE);
                    System.err.println("errorReduction " + errorReduction);
                }
                if (errorReduction > best.m_Worth) {
                    best.m_Worth = errorReduction;
                    best.m_AttributeIndex = att;
                    best.m_Split = median;
                }
            }
            if (best.m_AttributeIndex < 0) {
                return null;
            }
            if (m_Debug) {
                System.err.println(best);
            }
            return best;
        }

        /**
         * Applies a split to this node. It attaches a new splitter node whose three children are
         * fitted to the corresponding subsets.
         *
         * @param splitInfo the split to apply
         * @return the new splitter node
         * @throws Exception if a child model cannot be built
         */
        protected SplitterNode expandNode(SplitInfo splitInfo) throws Exception {
            SplitterNode splitter = new SplitterNode();
            splitter.m_Split = splitInfo.m_Split;
            splitter.m_AttributeIndex = splitInfo.m_AttributeIndex;
            ArrayList<Integer> left = new ArrayList<>();
            ArrayList<Integer> right = new ArrayList<>();
            ArrayList<Integer> missing = new ArrayList<>();
            partitionIndices(m_Indices, splitInfo.m_AttributeIndex, splitInfo.m_Split, left, right, missing);
            splitter.m_Left = new PredictionNode(toIntArray(left));
            splitter.m_Right = new PredictionNode(toIntArray(right));
            splitter.m_Missing = new PredictionNode(toIntArray(missing));
            m_Successors.add(splitter);
            return splitter;
        }

        /**
         * Unboxes a list of indices into an int array.
         *
         * @param arrayList the indices
         * @return the same indices, in order, as a primitive array
         */
        protected int[] toIntArray(ArrayList<Integer> arrayList) {
            int[] result = new int[arrayList.size()];
            for (int i = 0; i < result.length; i++) {
                result[i] = arrayList.get(i);
            }
            return result;
        }
    }

    /**
     * Candidate split found by PredictionNode.evaluateNodeExpansion().
     *
     * <p>NOTE(review): relies on the default serialVersionUID; keep non-private signatures stable.
     */
    public class SplitInfo implements Serializable {

        /** Index of the split attribute, or -1 while no split has been found. */
        protected int m_AttributeIndex = -1;

        /** Split point: values {@code <=} m_Split go left, larger values go right. */
        protected double m_Split = -Double.MAX_VALUE;

        /** Reduction in weighted squared error achieved by the split. */
        protected double m_Worth = -Double.MAX_VALUE;

        protected SplitInfo() {
        }

        /** Returns attribute name, split point and worth, joined by TestInstances.DEFAULT_SEPARATORS. */
        public String toString() {
            return m_Data.attribute(m_AttributeIndex).name() + TestInstances.DEFAULT_SEPARATORS + m_Split + TestInstances.DEFAULT_SEPARATORS + m_Worth;
        }
    }

    /**
     * Splitter node: a test on one attribute with a branch for {@code <=}, {@code >} and missing.
     *
     * <p>NOTE(review): relies on the default serialVersionUID; keep non-private signatures stable.
     */
    public class SplitterNode implements Serializable {

        /** Index of the tested attribute in m_Data. */
        protected int m_AttributeIndex;

        /** Split point of the test. */
        protected double m_Split;

        /** Child for values {@code <=} m_Split. */
        protected PredictionNode m_Left;

        /** Child for values {@code >} m_Split. */
        protected PredictionNode m_Right;

        /** Child for missing values. */
        protected PredictionNode m_Missing;

        protected SplitterNode() {
        }

        /**
         * Renders the three branches of this splitter and their subtrees as indented text.
         *
         * @param str prefix written in front of this splitter's lines
         * @return the textual representation
         */
        protected String toString(String str) {
            String name = m_Data.attribute(m_AttributeIndex).name();
            String childPrefix = str + "  | ";
            StringBuilder text = new StringBuilder();
            text.append(str).append(name).append(" <= ").append(m_Split).append('\n');
            text.append(m_Left.toString(childPrefix));
            text.append(str).append(name).append(" > ").append(m_Split).append('\n');
            text.append(m_Right.toString(childPrefix));
            text.append(str).append(name).append(" = ?\n");
            text.append(m_Missing.toString(childPrefix));
            return text.toString();
        }
    }

    /**
     * Dense instance whose values can be overwritten in place. setValue() writes straight into the
     * value array, and copy() returns this object instead of a copy. Training uses it so that
     * residuals can be stored without copying instances.
     */
    public class UnsafeInstance extends DenseInstance {
        private static final long serialVersionUID = 3210674215118962869L;

        /**
         * Copies the values and weight of the given instance.
         *
         * @param instance the instance to wrap
         */
        public UnsafeInstance(Instance instance) {
            super(instance.numAttributes());
            for (int i = 0; i < instance.numAttributes(); i++) {
                this.m_AttValues[i] = instance.value(i);
            }
            this.m_Weight = instance.weight();
        }

        @Override // weka.core.DenseInstance, weka.core.Instance
        public void setValue(int i, double d) {
            this.m_AttValues[i] = d;
        }

        @Override // weka.core.DenseInstance, weka.core.Copyable
        public Object copy() {
            // Intentionally not a copy: instances are shared between m_Data and its subsets.
            return this;
        }
    }

    /**
     * Formats a node model compactly for toString() and graph().
     *
     * @param model the node model
     * @param size number of training instances that reached the node
     * @return the description followed by {@code " (size)"}
     */
    private static String describeModel(Classifier model, int size) {
        String text;
        if (model instanceof ZeroR) {
            text = model.toString().replaceAll("ZeroR predicts class value: ", "");
        } else {
            text = model.toString()
                .replaceAll("Linear(.*\\n)", "")
                .replaceAll("Predicting", " :")
                .replaceAll("if attribute value is missing.", "for ?")
                .replaceAll("(\\r|\\n)", "");
        }
        return text + " (" + size + ")";
    }

    /**
     * Distributes rows of m_Data into three lists according to one attribute. Missing values go to
     * {@code missing}, values {@code <= split} go to {@code left}, and all others go to
     * {@code right}. Input order is preserved.
     *
     * @param indices row indices into m_Data
     * @param attIndex attribute to test
     * @param split split point
     * @param left receives indices with value {@code <= split}
     * @param right receives indices with value {@code > split}
     * @param missing receives indices with a missing value
     */
    private void partitionIndices(int[] indices, int attIndex, double split, List<Integer> left, List<Integer> right, List<Integer> missing) {
        for (int index : indices) {
            Instance inst = m_Data.instance(index);
            if (inst.isMissing(attIndex)) {
                missing.add(index);
            } else if (inst.value(attIndex) <= split) {
                left.add(index);
            } else {
                right.add(index);
            }
        }
    }

    public String globalInfo() {
        return "Grows an alternating model tree by minimising squared error. Nominal attributes are converted to binary numeric ones before the tree is built, using the supervised version of NominalToBinary.\n\nFor more information see\n\n" + getTechnicalInformation();
    }

    @Override // weka.core.TechnicalInformationHandler
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Eibe Frank, Michael Mayo and Stefan Kramer");
        info.setValue(TechnicalInformation.Field.TITLE, "Alternating Model Trees");
        info.setValue(TechnicalInformation.Field.YEAR, "2015");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the ACM Symposium on Applied Computing, Data Mining Track");
        info.setValue(TechnicalInformation.Field.PUBLISHER, "ACM Press");
        return info;
    }

    /**
     * Lists the options of this classifier (-I, -H, -B) followed by the inherited ones.
     *
     * @return an enumeration of all options
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<>(3);
        options.addElement(new Option("\tSet the number of iterations to perform. (default 10).", "I", 1, "-I <number of iterations>"));
        options.addElement(new Option("\tSet shrinkage parameter (default 1.0).", "H", 1, "-H <double>"));
        options.addElement(new Option("\tBuild a decision tree instead of an alternating tree.", "B", 0, "-B"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Returns the current settings: the inherited options followed by -I, -H and, if set, -B.
     *
     * @return the current options
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public String[] getOptions() {
        Vector<String> options = new Vector<>();
        Collections.addAll(options, super.getOptions());
        options.add("-I");
        options.add("" + getNumberOfIterations());
        options.add("-H");
        options.add("" + getShrinkage());
        if (getBuildDecisionTree()) {
            options.add("-B");
        }
        return options.toArray(new String[0]);
    }

    /**
     * Parses the options. An absent -I means 10, absent -H means 1.0, and -B is a flag.
     *
     * @param options the options to parse
     * @throws Exception if an option is malformed or unrecognised options remain
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public void setOptions(String[] options) throws Exception {
        super.setOptions(options);
        String iterations = Utils.getOption('I', options);
        setNumberOfIterations(iterations.length() != 0 ? Integer.parseInt(iterations) : 10);
        String shrinkage = Utils.getOption('H', options);
        setShrinkage(shrinkage.length() != 0 ? Double.parseDouble(shrinkage) : 1.0d);
        setBuildDecisionTree(Utils.getFlag('B', options));
        Utils.checkForRemainingOptions(options);
    }

    public String numberOfIterationsTipText() {
        return "Sets the number of iterations to perform.";
    }

    public int getNumberOfIterations() {
        return this.m_NumberOfIterations;
    }

    public void setNumberOfIterations(int numberOfIterations) {
        this.m_NumberOfIterations = numberOfIterations;
    }

    public String shrinkageTipText() {
        return "The value of the shrinkage parameter.";
    }

    public double getShrinkage() {
        return this.m_Shrinkage;
    }

    public void setShrinkage(double shrinkage) {
        this.m_Shrinkage = shrinkage;
    }

    public String buildDecisionTreeTipText() {
        return "Set to true if a decision tree is to be built.";
    }

    public boolean getBuildDecisionTree() {
        return this.m_BuildDecisionTree;
    }

    public void setBuildDecisionTree(boolean buildDecisionTree) {
        this.m_BuildDecisionTree = buildDecisionTree;
    }

    /**
     * Returns the capabilities. Nominal, numeric and date attributes are supported, and the class
     * must be numeric or date; missing values are allowed in both attributes and class.
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.DATE_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return capabilities;
    }

    /**
     * Replaces the class value of each selected training instance with
     * {@code classValue - scale * model(instance)}.
     *
     * @param classifier the model whose scaled prediction is subtracted
     * @param indices indices into m_Data of the instances to update
     * @param scale factor applied to the prediction (1.0 for the root, otherwise the shrinkage)
     * @throws Exception if the model cannot classify an instance
     */
    protected void calculateResiduals(Classifier classifier, int[] indices, double scale) throws Exception {
        for (int index : indices) {
            Instance inst = m_Data.instance(index);
            inst.setClassValue(inst.classValue() - (scale * classifier.classifyInstance(inst)));
        }
    }

    @Override // weka.core.Drawable
    public int graphType() {
        return 1;
    }

    /**
     * Returns the tree in dot format.
     *
     * @return the dot description of the tree
     * @throws IllegalStateException if no model has been built yet
     * @throws Exception if the traversal fails
     */
    @Override // weka.core.Drawable
    public String graph() throws Exception {
        if (m_PredictionNodes == null) {
            throw new IllegalStateException("No model built yet.");
        }
        StringBuffer text = new StringBuffer("digraph AMTree {\n");
        graphTraverse(m_PredictionNodes.get(0), text, "P");
        return text.append("}\n").toString();
    }

    /**
     * Appends the dot description of a prediction node and its subtree.
     *
     * @param predictionNode the node to describe
     * @param stringBuffer the buffer to append to
     * @param str dot identifier of the node; child identifiers are derived from it
     * @throws Exception if the traversal fails
     */
    protected void graphTraverse(PredictionNode predictionNode, StringBuffer stringBuffer, String str) throws Exception {
        // Escape the label like the splitter labels below: it may embed attribute names.
        stringBuffer.append(str).append(" [label=\"")
            .append(Utils.backQuoteChars(describeModel(predictionNode.m_Model, predictionNode.m_Size)))
            .append("\" shape=box style=filled]\n");
        int splitterCount = 0;
        for (SplitterNode splitter : predictionNode.m_Successors) {
            String splitterId = str + "S" + splitterCount++;
            stringBuffer.append(str + "->" + splitterId + " [style=dotted]\n");
            stringBuffer.append(splitterId + " [label=\"" + Utils.backQuoteChars(m_Data.attribute(splitter.m_AttributeIndex).name()) + "\"]\n");
            if (splitter.m_Left != null) {
                stringBuffer.append(splitterId + "->" + splitterId + "1P [label=\" <= " + splitter.m_Split + "\"]\n");
                graphTraverse(splitter.m_Left, stringBuffer, splitterId + "1P");
            }
            if (splitter.m_Right != null) {
                stringBuffer.append(splitterId + "->" + splitterId + "2P [label=\" > " + splitter.m_Split + "\"]\n");
                graphTraverse(splitter.m_Right, stringBuffer, splitterId + "2P");
            }
            if (splitter.m_Missing != null) {
                stringBuffer.append(splitterId + "->" + splitterId + "3P [label=\" == ? \"]\n");
                graphTraverse(splitter.m_Missing, stringBuffer, splitterId + "3P");
            }
        }
    }

    /** Returns the textual tree, or "No model built yet." before training. */
    public String toString() {
        return m_PredictionNodes == null ? "No model built yet." : m_PredictionNodes.get(0).toString("");
    }

    /**
     * Predicts the class value. The instance is filtered like the training data, and then the
     * outputs of all prediction nodes it reaches are summed, scaling all but the root's by the
     * shrinkage.
     *
     * @param instance the instance to classify
     * @return the predicted value
     * @throws IllegalStateException if no model has been built yet
     * @throws Exception if filtering or a node model fails
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier
    public double classifyInstance(Instance instance) throws Exception {
        if (m_PredictionNodes == null) {
            throw new IllegalStateException("No model built yet.");
        }
        m_nominalToBinary.input(instance);
        m_removeUseless.input(m_nominalToBinary.output());
        Instance filtered = m_removeUseless.output();

        LinkedList<PredictionNode> queue = new LinkedList<>();
        queue.add(m_PredictionNodes.get(0));
        double prediction = 0.0d;
        // The root contributes unshrunk, matching how its residuals were computed.
        double scale = 1.0d;
        PredictionNode node;
        while ((node = queue.poll()) != null) {
            prediction += scale * node.m_Model.classifyInstance(filtered);
            for (SplitterNode splitter : node.m_Successors) {
                if (filtered.isMissing(splitter.m_AttributeIndex)) {
                    queue.add(splitter.m_Missing);
                } else if (filtered.value(splitter.m_AttributeIndex) <= splitter.m_Split) {
                    queue.add(splitter.m_Left);
                } else {
                    queue.add(splitter.m_Right);
                }
            }
            scale = m_Shrinkage;
        }
        return prediction;
    }

    /**
     * Prepares training. The capabilities are checked, instances with a missing class are
     * dropped, and the data is passed through NominalToBinary and RemoveUseless. Each instance is
     * wrapped in an UnsafeInstance, and the root node is created.
     *
     * @param instances the training data (not modified)
     * @throws Exception if the data is unsuitable or filtering fails
     */
    @Override // weka.classifiers.IterativeClassifier
    public void initializeClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();
        m_nominalToBinary = new NominalToBinary();
        m_nominalToBinary.setInputFormat(data);
        data = Filter.useFilter(data, m_nominalToBinary);
        m_removeUseless = new RemoveUseless();
        m_removeUseless.setInputFormat(data);
        data = Filter.useFilter(data, m_removeUseless);
        m_Data = new Instances(data, data.numInstances());
        for (Instance inst : data) {
            m_Data.add(new UnsafeInstance(inst));
        }
        m_PredictionNodes = new ArrayList<>();
        m_PredictionNodes.add(new PredictionNode(m_Data));
    }

    /**
     * Performs one iteration by expanding the prediction node with the most valuable split.
     *
     * @return true if a node was expanded, false if no node can be expanded
     * @throws Exception if evaluating or building models fails
     */
    @Override // weka.classifiers.IterativeClassifier
    public boolean next() throws Exception {
        SplitInfo bestSplit = null;
        PredictionNode bestNode = null;
        for (PredictionNode candidate : m_PredictionNodes) {
            SplitInfo split = candidate.evaluateNodeExpansion();
            if (split != null && (bestSplit == null || split.m_Worth > bestSplit.m_Worth)) {
                bestSplit = split;
                bestNode = candidate;
            }
        }
        if (bestSplit == null) {
            return false;
        }
        SplitterNode splitter = bestNode.expandNode(bestSplit);
        m_PredictionNodes.add(splitter.m_Left);
        m_PredictionNodes.add(splitter.m_Right);
        m_PredictionNodes.add(splitter.m_Missing);
        return true;
    }

    /** Frees training state: keeps only the data header and drops the node index arrays. */
    @Override // weka.classifiers.IterativeClassifier
    public void done() throws Exception {
        m_Data = new Instances(m_Data, 0);
        for (PredictionNode node : m_PredictionNodes) {
            node.m_Indices = null;
        }
    }

    /**
     * Builds the model with up to m_NumberOfIterations expansions. It stops early once no node can
     * be expanded, because further next() calls would change nothing.
     *
     * @param instances the training data
     * @throws Exception if training fails
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        initializeClassifier(instances);
        if (instances.numAttributes() != 1) {
            for (int i = 0; i < m_NumberOfIterations; i++) {
                if (!next()) {
                    break;
                }
            }
        }
        done();
    }

    @Override // weka.classifiers.AbstractClassifier, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: $");
    }

    public static void main(String[] strArr) {
        runClassifier(new AlternatingModelTree(), strArr);
    }
}
