package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.TestInstances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/* loaded from: input_file:weka/classifiers/functions/SPegasos.class */
/**
 * Stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of
 * Shalev-Shwartz et al. (2007) for binary-class problems.
 *
 * <p>Before training, missing values are (optionally) replaced, nominal attributes are turned
 * into binary ones and all attributes are (optionally) normalized. The learned coefficients
 * therefore refer to the transformed data. Either the hinge loss (linear SVM) or the log loss
 * (logistic regression) is minimized.
 *
 * <p>The weight vector holds one entry per filtered attribute followed by a trailing bias term.
 * The entry at the class index is never updated. The bias is not regularized: it is neither
 * shrunk nor projected.
 */
public class SPegasos extends AbstractClassifier
        implements TechnicalInformationHandler, UpdateableClassifier, OptionHandler {
    private static final long serialVersionUID = -3732968666673530290L;

    /** Missing-value filter; null when disabled or when there was no training data. */
    protected ReplaceMissingValues m_replaceMissing;
    /** Nominal-to-binary filter; null when every non-class attribute was already numeric. */
    protected NominalToBinary m_nominalToBinary;
    /** Normalization filter; null when disabled or when there was no training data. */
    protected Normalize m_normalize;
    /** One weight per filtered attribute, followed by the bias term in the last slot. */
    protected double[] m_weights;
    /** Iteration counter t; learning rate is 1/(lambda*t) and shrink factor is 1-1/t. */
    protected double m_t;
    /** Empty header of the filtered training data (used for attribute names in toString). */
    protected Instances m_data;

    /** Loss-function id: hinge loss (SVM). */
    protected static final int HINGE = 0;
    /** Loss-function id: log loss (logistic regression). */
    protected static final int LOGLOSS = 1;
    /** Selectable loss functions. */
    public static final Tag[] TAGS_SELECTION = {
        new Tag(HINGE, "Hinge loss (SVM)"), new Tag(LOGLOSS, "Log loss (logistic regression)")
    };

    /** Regularization constant lambda. */
    protected double m_lambda = 1.0E-4d;
    /** Number of passes over the training data in batch mode. */
    protected int m_epochs = 500;
    /** When true, the Normalize filter is not applied. */
    protected boolean m_dontNormalize = false;
    /** When true, the ReplaceMissingValues filter is not applied. */
    protected boolean m_dontReplaceMissing = false;
    /** Selected loss function ({@link #HINGE} or {@link #LOGLOSS}). */
    protected int m_loss = HINGE;

    /**
     * Returns the capabilities: nominal/numeric attributes, missing values, binary class.
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.setMinimumNumberInstances(0);
        return capabilities;
    }

    /** @return tip text for the lambda property */
    public String lambdaTipText() {
        return "The regularization constant. (default = 0.0001)";
    }

    /** @param d the regularization constant lambda */
    public void setLambda(double d) {
        this.m_lambda = d;
    }

    /** @return the regularization constant lambda */
    public double getLambda() {
        return this.m_lambda;
    }

    /** @return tip text for the epochs property */
    public String epochsTipText() {
        return "The number of epochs to perform (batch learning). The total number of iterations is epochs * num instances.";
    }

    /** @param i number of passes over the data in batch training */
    public void setEpochs(int i) {
        this.m_epochs = i;
    }

    /** @return number of passes over the data in batch training */
    public int getEpochs() {
        return this.m_epochs;
    }

    /** @param z true to skip normalization of the attributes */
    public void setDontNormalize(boolean z) {
        this.m_dontNormalize = z;
    }

    /** @return true if normalization is turned off */
    public boolean getDontNormalize() {
        return this.m_dontNormalize;
    }

    /** @return tip text for the dontNormalize property */
    public String dontNormalizeTipText() {
        return "Turn normalization off";
    }

    /** @param z true to skip global replacement of missing values */
    public void setDontReplaceMissing(boolean z) {
        this.m_dontReplaceMissing = z;
    }

    /** @return true if global replacement of missing values is turned off */
    public boolean getDontReplaceMissing() {
        return this.m_dontReplaceMissing;
    }

    /** @return tip text for the dontReplaceMissing property */
    public String dontReplaceMissingTipText() {
        return "Turn off global replacement of missing values";
    }

    /**
     * Sets the loss function. Tags from any array other than {@link #TAGS_SELECTION} are ignored.
     *
     * @param selectedTag the loss function to use
     */
    public void setLossFunction(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_SELECTION) {
            this.m_loss = selectedTag.getSelectedTag().getID();
        }
    }

    /** @return the currently selected loss function */
    public SelectedTag getLossFunction() {
        return new SelectedTag(this.m_loss, TAGS_SELECTION);
    }

    /** @return tip text for the lossFunction property */
    public String lossFunctionTipText() {
        return "The loss function to use. Hinge loss (SVM) or log loss (logistic regression).";
    }

    /** Lists the command-line options (-F, -L, -E, -N, -M). */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.add(new Option("\tSet the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression).\n\t(default = 0)", "F", 1, "-F"));
        options.add(new Option("\tThe lambda regularization constant (default = 0.0001)", "L", 1, "-L <double>"));
        options.add(new Option("\tThe number of epochs to perform (batch learning only, default = 500)", "E", 1, "-E <integer>"));
        options.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
        options.add(new Option("\tDon't replace missing values", "M", 0, "-M"));
        return options.elements();
    }

    /**
     * Parses the command-line options. Resets the model first. An absent -F selects hinge loss.
     * Absent -L/-E keep the current values.
     *
     * @param strArr the options to parse
     * @throws Exception if a numeric option cannot be parsed
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        reset();
        String lossOption = Utils.getOption('F', strArr);
        int loss = lossOption.length() != 0 ? Integer.parseInt(lossOption) : HINGE;
        setLossFunction(new SelectedTag(loss, TAGS_SELECTION));

        String lambdaOption = Utils.getOption('L', strArr);
        if (lambdaOption.length() > 0) {
            setLambda(Double.parseDouble(lambdaOption));
        }
        String epochsOption = Utils.getOption('E', strArr);
        if (epochsOption.length() > 0) {
            setEpochs(Integer.parseInt(epochsOption));
        }
        setDontNormalize(Utils.getFlag('N', strArr));
        setDontReplaceMissing(Utils.getFlag('M', strArr));
    }

    /** @return the current settings as command-line options */
    @Override
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<String>();
        options.add("-F");
        options.add(String.valueOf(getLossFunction().getSelectedTag().getID()));
        options.add("-L");
        options.add(String.valueOf(getLambda()));
        options.add("-E");
        options.add(String.valueOf(getEpochs()));
        if (getDontNormalize()) {
            options.add("-N");
        }
        if (getDontReplaceMissing()) {
            options.add("-M");
        }
        return options.toArray(new String[0]);
    }

    /** @return a description of the classifier for the GUI */
    public String globalInfo() {
        return "Implements the stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of Shalev-Shwartz et al. (2007). This implementation globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data. Can either minimize the hinge loss (SVM) or log loss (logistic regression). For more information, see\n\n" + getTechnicalInformation().toString();
    }

    /** @return the bibliographic reference for the Pegasos paper */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "S. Shalev-Shwartz and Y. Singer and N. Srebro");
        info.setValue(TechnicalInformation.Field.YEAR, "2007");
        info.setValue(TechnicalInformation.Field.TITLE, "Pegasos: Primal Estimated sub-GrAdient SOlver for SVM");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "24th International Conference on Machine Learning");
        info.setValue(TechnicalInformation.Field.PAGES, "807-814");
        return info;
    }

    /** Discards the model: weights, filters and the iteration counter (reset to 1). */
    public void reset() {
        this.m_t = 1.0d;
        this.m_weights = null;
        this.m_normalize = null;
        this.m_replaceMissing = null;
        this.m_nominalToBinary = null;
    }

    /**
     * Fits the filters on a copy of the data (minus instances with a missing class), then runs
     * {@link #m_epochs} passes of Pegasos updates over the filtered data.
     *
     * @param instances the training data
     * @throws Exception if the data is not supported or a filter fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        reset();
        getCapabilities().testWithFail(instances);

        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        if (data.numInstances() > 0 && !this.m_dontReplaceMissing) {
            this.m_replaceMissing = new ReplaceMissingValues();
            this.m_replaceMissing.setInputFormat(data);
            data = Filter.useFilter(data, this.m_replaceMissing);
        }
        if (hasNonNumericAttribute(data)) {
            this.m_nominalToBinary = new NominalToBinary();
            this.m_nominalToBinary.setInputFormat(data);
            data = Filter.useFilter(data, this.m_nominalToBinary);
        }
        if (!this.m_dontNormalize && data.numInstances() > 0) {
            this.m_normalize = new Normalize();
            this.m_normalize.setInputFormat(data);
            data = Filter.useFilter(data, this.m_normalize);
        }

        // Last slot holds the bias term.
        this.m_weights = new double[data.numAttributes() + 1];
        this.m_data = new Instances(data, 0);
        if (data.numInstances() > 0) {
            train(data);
        }
    }

    /** Returns true if some non-class attribute of {@code data} is not numeric. */
    private static boolean hasNonNumericAttribute(Instances data) {
        for (int a = 0; a < data.numAttributes(); a++) {
            if (a != data.classIndex() && !data.attribute(a).isNumeric()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Magnitude of the negative loss derivative at margin {@code d}.
     * Hinge: 1 if d &lt; 1, else 0. Log loss: 1/(1+e^d), evaluated in an overflow-safe form.
     *
     * @param d the margin y * (w.x + b)
     * @return the gradient scale for the update
     */
    protected double dloss(double d) {
        if (this.m_loss == HINGE) {
            return d < 1.0d ? 1.0d : 0.0d;
        }
        if (d < 0.0d) {
            return 1.0d / (Math.exp(d) + 1.0d);
        }
        double expNeg = Math.exp(-d);
        return expNeg / (expNeg + 1.0d);
    }

    /** Runs {@link #m_epochs} passes of Pegasos updates over already-filtered data. */
    private void train(Instances instances) {
        for (int epoch = 0; epoch < this.m_epochs; epoch++) {
            for (int i = 0; i < instances.numInstances(); i++) {
                pegasosStep(instances.instance(i));
            }
        }
    }

    /**
     * Performs one Pegasos step on an already-filtered instance and advances t:
     * shrink w by (1 - 1/t), add the loss sub-gradient when applicable, then project w
     * onto the ball of radius 1/sqrt(lambda). The bias is never shrunk or projected.
     */
    private void pegasosStep(Instance instance) {
        int classIndex = instance.classIndex();
        int biasIndex = this.m_weights.length - 1;
        double learningRate = 1.0d / (this.m_lambda * this.m_t);
        // Equals 1 - learningRate * lambda.
        double shrink = 1.0d - (1.0d / this.m_t);
        double y = instance.classValue() == 0.0d ? -1.0d : 1.0d;
        double margin = y * (dotProd(instance, this.m_weights, classIndex) + this.m_weights[biasIndex]);

        // The regularization shrink applies on every step, not only on margin violations.
        for (int j = 0; j < biasIndex; j++) {
            if (j != classIndex) {
                this.m_weights[j] *= shrink;
            }
        }

        if (this.m_loss == LOGLOSS || margin < 1.0d) {
            double step = learningRate * dloss(margin);
            for (int p = 0; p < instance.numValues(); p++) {
                int index = instance.index(p);
                if (index != classIndex && !instance.isMissingSparse(p)) {
                    this.m_weights[index] += step * instance.valueSparse(p) * y;
                }
            }
            this.m_weights[biasIndex] += step * y;
        }

        double squaredNorm = 0.0d;
        for (int j = 0; j < biasIndex; j++) {
            if (j != classIndex) {
                squaredNorm += this.m_weights[j] * this.m_weights[j];
            }
        }
        // A zero norm gives 1/0 = Infinity, so min(...) yields 1 and no scaling happens.
        double scale = Math.min(1.0d, 1.0d / (Math.sqrt(this.m_lambda) * Math.sqrt(squaredNorm)));
        if (scale < 1.0d) {
            for (int j = 0; j < biasIndex; j++) {
                if (j != classIndex) {
                    this.m_weights[j] *= scale;
                }
            }
        }
        this.m_t += 1.0d;
    }

    /**
     * Sparse dot product of the instance with the non-bias part of the weights. The class
     * attribute and missing values are skipped.
     *
     * @param instance the instance (its sparse indices are walked in order)
     * @param dArr the weights; the last slot (bias) is not used
     * @param i the class index to skip
     * @return the dot product
     */
    protected static double dotProd(Instance instance, double[] dArr, int i) {
        double sum = 0.0d;
        int numValues = instance.numValues();
        int numWeights = dArr.length - 1;
        int valuePos = 0;
        int weightPos = 0;
        while (valuePos < numValues && weightPos < numWeights) {
            int attIndex = instance.index(valuePos);
            if (attIndex == weightPos) {
                if (attIndex != i && !instance.isMissingSparse(valuePos)) {
                    sum += instance.valueSparse(valuePos) * dArr[weightPos];
                }
                valuePos++;
                weightPos++;
            } else if (attIndex > weightPos) {
                weightPos++;
            } else {
                valuePos++;
            }
        }
        return sum;
    }

    /**
     * Incrementally updates the model with one instance. The instance is first passed through
     * the same filters fitted in {@link #buildClassifier}. Instances with a missing class are
     * ignored.
     *
     * @param instance the new training instance
     * @throws Exception if no model has been built or a filter fails
     */
    @Override
    public void updateClassifier(Instance instance) throws Exception {
        if (instance.classIsMissing()) {
            return;
        }
        checkModelBuilt();
        pegasosStep(filterInstance(instance));
    }

    /** Throws if {@link #buildClassifier} has not been called yet. */
    private void checkModelBuilt() {
        if (this.m_weights == null) {
            throw new IllegalStateException("SPegasos: No model built yet.");
        }
    }

    /** Applies the fitted replace-missing, nominal-to-binary and normalize filters, in order. */
    private Instance filterInstance(Instance instance) throws Exception {
        Instance result = instance;
        if (this.m_replaceMissing != null) {
            this.m_replaceMissing.input(result);
            result = this.m_replaceMissing.output();
        }
        if (this.m_nominalToBinary != null) {
            this.m_nominalToBinary.input(result);
            result = this.m_nominalToBinary.output();
        }
        if (this.m_normalize != null) {
            this.m_normalize.input(result);
            result = this.m_normalize.output();
        }
        return result;
    }

    /**
     * Returns the class distribution. Log loss gives sigmoid probabilities. Hinge loss gives a
     * hard 0/1 vote, with a score of exactly 0 going to class 0.
     *
     * @param instance the instance to classify
     * @return a two-element distribution
     * @throws Exception if no model has been built or a filter fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        checkModelBuilt();
        Instance filtered = filterInstance(instance);
        double score = dotProd(filtered, this.m_weights, filtered.classIndex())
                + this.m_weights[this.m_weights.length - 1];
        double[] dist = new double[2];
        if (this.m_loss == LOGLOSS) {
            // Exponentiate the non-positive side only, to avoid overflow.
            if (score <= 0.0d) {
                dist[0] = 1.0d / (1.0d + Math.exp(score));
                dist[1] = 1.0d - dist[0];
            } else {
                dist[1] = 1.0d / (1.0d + Math.exp(-score));
                dist[0] = 1.0d - dist[1];
            }
        } else {
            dist[score <= 0.0d ? 0 : 1] = 1.0d;
        }
        return dist;
    }

    /** @return a textual description of the loss function and the learned linear model */
    @Override
    public String toString() {
        if (this.m_weights == null) {
            return "SPegasos: No model built yet.\n";
        }
        StringBuilder buff = new StringBuilder();
        buff.append("Loss function: ");
        buff.append(this.m_loss == HINGE ? "Hinge loss (SVM)\n\n" : "Log loss (logistic regression)\n\n");

        String normalizedTag = this.m_normalize != null ? "(normalized) " : "";
        boolean first = true;
        for (int a = 0; a < this.m_weights.length - 1; a++) {
            if (a == this.m_data.classIndex()) {
                continue;
            }
            buff.append(first ? "   " : " + ");
            buff.append(Utils.doubleToString(this.m_weights[a], 12, 4))
                    .append(" ")
                    .append(normalizedTag)
                    .append(this.m_data.attribute(a).name())
                    .append("\n");
            first = false;
        }

        double bias = this.m_weights[this.m_weights.length - 1];
        if (bias > 0.0d) {
            buff.append(" + " + Utils.doubleToString(bias, 12, 4));
        } else {
            buff.append(" - " + Utils.doubleToString(-bias, 12, 4));
        }
        return buff.toString();
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 6105 $");
    }

    /**
     * Command-line entry point.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new SPegasos(), strArr);
    }
}
