package weka.classifiers.functions;

import groovy.text.markup.DelegatingIndentWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.io.IOUtils;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;

/* loaded from: input_file:weka/classifiers/functions/NonNegativeLogisticRegression.class */
public class NonNegativeLogisticRegression extends RandomizableClassifier {
    /** Explicit serialization version identifier. */
    static final long serialVersionUID = -1223158323933117974L;
    /** Model weights, indexed by attribute; assigned by buildClassifier and by the optimiser callbacks. */
    protected double[] m_weights;
    /**
     * Working copy of the training data while building. Once the matrix is filled it is replaced
     * by {@code new Instances(m_data, 0)} (presumably a header-only copy -- confirm).
     */
    protected Instances m_data;
    /**
     * Training values laid out as [instance][attribute]. Non-class columns are divided by their
     * largest absolute value; the field is reset to null when training ends.
     */
    protected double[][] m_matrix;
    /** Number of contiguous row chunks (one submitted task each) per objective/gradient evaluation. */
    protected int m_numThreads = 1;
    /** Number of threads in the fixed pool created by buildClassifier. */
    protected int m_poolSize = 1;
    /** Executor that runs the chunk tasks; created in buildClassifier and shut down there. Transient, so not serialized. */
    protected transient ExecutorService m_Pool = null;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:weka/classifiers/functions/NonNegativeLogisticRegression$OptEng.class */
    public class OptEng extends Optimization {
        /** Creates the optimiser; no state is initialised here. */
        protected OptEng() {
        }

        /**
         * Evaluates the objective being minimised: the negative log-likelihood of the training
         * matrix under the supplied weights.
         *
         * <p>The rows of {@code m_matrix} are split into {@code m_numThreads} contiguous chunks
         * (the last chunk absorbs the remainder) and each chunk is summed by a task submitted to
         * {@code m_Pool}. Partial sums are added in chunk order, so the result does not depend on
         * hash ordering of the futures.
         *
         * @param dArr the weight vector to evaluate; it is also stored in {@code m_weights}
         * @return the sum over all rows of {@code s * y + log(1 + exp(-s))}, where {@code s} is the
         *         weighted sum of the row and {@code y} is the value in the class column
         * @throws Exception if a chunk task fails or the calling thread is interrupted while
         *         waiting; a partial sum is never returned
         */
        @Override // weka.core.Optimization
        public double objectiveFunction(double[] dArr) throws Exception {
            NonNegativeLogisticRegression.this.m_weights = dArr;

            // Snapshot everything the worker tasks read so they do not touch mutable fields.
            final double[] weights = dArr;
            final double[][] matrix = NonNegativeLogisticRegression.this.m_matrix;
            final int classIndex = NonNegativeLogisticRegression.this.m_data.classIndex();
            final int numChunks = NonNegativeLogisticRegression.this.m_numThreads;
            final int chunkSize = matrix.length / numChunks;

            // A list (not a set) keeps the futures in submission order for a reproducible sum.
            List<Future<Double>> pending = new ArrayList<Future<Double>>();
            for (int chunk = 0; chunk < numChunks; chunk++) {
                final int from = chunk * chunkSize;
                final int to = chunk < numChunks - 1 ? from + chunkSize : matrix.length;
                pending.add(NonNegativeLogisticRegression.this.m_Pool.submit(new Callable<Double>() {
                    @Override
                    public Double call() {
                        double partial = 0.0d;
                        for (int row = from; row < to; row++) {
                            double[] values = matrix[row];
                            double score = 0.0d;
                            for (int col = 0; col < values.length; col++) {
                                score += weights[col] * values[col];
                            }
                            // NOTE(review): exp(-score) overflows to Infinity for very negative
                            // scores; the expression is kept as-is so results stay bit-identical.
                            partial -= ((-score) * values[classIndex]) - Math.log(1.0d + Math.exp(-score));
                        }
                        return Double.valueOf(partial);
                    }
                }));
            }

            // Let ExecutionException / InterruptedException propagate: handing the optimiser a
            // partial sum would silently corrupt the search.
            double total = 0.0d;
            for (Future<Double> result : pending) {
                total += result.get().doubleValue();
            }
            return total;
        }

        /**
         * Evaluates the gradient of the objective at the supplied weights.
         *
         * <p>Rows of {@code m_matrix} are split into {@code m_numThreads} contiguous chunks (the
         * last chunk absorbs the remainder); each chunk's partial gradient is computed by a task
         * submitted to {@code m_Pool}. Partial gradients are accumulated in chunk order. The
         * class-index component is never updated and therefore stays 0.
         *
         * @param dArr the weight vector to evaluate; it is also stored in {@code m_weights}
         * @return a new array with one entry per attribute of {@code m_data}
         * @throws Exception if a chunk task fails or the calling thread is interrupted while
         *         waiting; a partial gradient is never returned
         */
        @Override // weka.core.Optimization
        public double[] evaluateGradient(double[] dArr) throws Exception {
            NonNegativeLogisticRegression.this.m_weights = dArr;

            // Snapshot everything the worker tasks read so they do not touch mutable fields.
            final double[] weights = dArr;
            final double[][] matrix = NonNegativeLogisticRegression.this.m_matrix;
            final int classIndex = NonNegativeLogisticRegression.this.m_data.classIndex();
            final int numAttributes = NonNegativeLogisticRegression.this.m_data.numAttributes();
            final int numChunks = NonNegativeLogisticRegression.this.m_numThreads;
            final int chunkSize = matrix.length / numChunks;

            // A list (not a set) keeps the futures in submission order for a reproducible sum.
            List<Future<double[]>> pending = new ArrayList<Future<double[]>>();
            for (int chunk = 0; chunk < numChunks; chunk++) {
                final int from = chunk * chunkSize;
                final int to = chunk < numChunks - 1 ? from + chunkSize : matrix.length;
                pending.add(NonNegativeLogisticRegression.this.m_Pool.submit(new Callable<double[]>() {
                    @Override
                    public double[] call() {
                        double[] partial = new double[numAttributes];
                        for (int row = from; row < to; row++) {
                            double[] values = matrix[row];
                            double score = 0.0d;
                            for (int col = 0; col < values.length; col++) {
                                score += weights[col] * values[col];
                            }
                            double target = values[classIndex];
                            // NOTE(review): expNeg becomes Infinity for very negative scores, which
                            // makes expNeg / (1 + expNeg) NaN; kept as-is to stay bit-identical.
                            double expNeg = Math.exp(-score);
                            for (int col = 0; col < values.length; col++) {
                                if (col == classIndex) {
                                    continue;
                                }
                                if (target == KStarConstants.FLOOR) {
                                    partial[col] -= (expNeg / (1.0d + expNeg)) * values[col];
                                } else {
                                    partial[col] += (1.0d / (1.0d + expNeg)) * values[col];
                                }
                            }
                        }
                        return partial;
                    }
                }));
            }

            // Let ExecutionException / InterruptedException propagate: a partial gradient would
            // silently misdirect the optimiser.
            double[] gradient = new double[numAttributes];
            for (Future<double[]> result : pending) {
                double[] partial = result.get();
                for (int col = 0; col < partial.length; col++) {
                    gradient[col] += partial[col];
                }
            }
            return gradient;
        }

        /**
         * Reports the revision of this optimiser class.
         *
         * @return whatever {@code RevisionUtils.extract} yields for the embedded revision keyword
         */
        @Override // weka.core.RevisionHandler
        public String getRevision() {
            final String revisionKeyword = "$Revision: 1111 $";
            return RevisionUtils.extract(revisionKeyword);
        }
    }

    /**
     * Describes this classifier in one short paragraph.
     *
     * @return the human-readable description
     */
    public String globalInfo() {
        final String summary =
            "Class for learning a logistic regression model that has non-negative coefficients. ";
        final String classNote =
            "The first class value is assumed to be the positive class value (i.e. 1.0).";
        return summary + classNote;
    }

    /**
     * Returns a defensive copy of the weight vector, so callers cannot modify the model.
     *
     * @return a fresh array with the same length and contents as {@code m_weights}
     */
    public double[] getCoefficients() {
        double[] snapshot = this.m_weights.clone();
        return snapshot;
    }

    /**
     * Supplies the tip text for the numThreads property.
     *
     * @return the tip text
     */
    public String numThreadsTipText() {
        final String tip = "The number of threads to use, which should be >= size of thread pool.";
        return tip;
    }

    /**
     * Reads the configured number of threads.
     *
     * @return the current value of {@code m_numThreads}
     */
    public int getNumThreads() {
        final int configured = m_numThreads;
        return configured;
    }

    /**
     * Stores the number of threads to use.
     *
     * @param i the new value for {@code m_numThreads}
     */
    public void setNumThreads(int i) {
        m_numThreads = i;
    }

    /**
     * Supplies the tip text for the poolSize property.
     *
     * @return the tip text
     */
    public String poolSizeTipText() {
        final String tip = "The size of the thread pool, for example, the number of cores in the CPU.";
        return tip;
    }

    /**
     * Reads the configured thread-pool size.
     *
     * @return the current value of {@code m_poolSize}
     */
    public int getPoolSize() {
        final int configured = m_poolSize;
        return configured;
    }

    /**
     * Stores the thread-pool size.
     *
     * @param i the new value for {@code m_poolSize}
     */
    public void setPoolSize(int i) {
        m_poolSize = i;
    }

    /**
     * Lists the command-line options: {@code -P} (pool size) and {@code -E} (number of threads),
     * followed by every option reported by the superclass.
     *
     * @return an enumeration over the option descriptions
     */
    @Override // weka.classifiers.RandomizableClassifier, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>(2);

        String poolSizeHelp = DelegatingIndentWriter.TAB + poolSizeTipText() + " (default 1)\n";
        options.addElement(new Option(poolSizeHelp, "P", 1, "-P <int>"));

        String numThreadsHelp = DelegatingIndentWriter.TAB + numThreadsTipText() + " (default 1)\n";
        options.addElement(new Option(numThreadsHelp, "E", 1, "-E <int>"));

        // Inherited options come after the two declared here.
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Parses the command-line options. {@code -P <int>} sets the pool size and {@code -E <int>}
     * the number of threads; each falls back to 1 when absent. Remaining options are passed to
     * the superclass, after which any leftovers are checked.
     *
     * @param strArr the option array
     * @throws Exception if an option value cannot be parsed or an option is rejected
     */
    @Override // weka.classifiers.RandomizableClassifier, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public void setOptions(String[] strArr) throws Exception {
        String poolSizeArg = Utils.getOption('P', strArr);
        setPoolSize(poolSizeArg.isEmpty() ? 1 : Integer.parseInt(poolSizeArg));

        String numThreadsArg = Utils.getOption('E', strArr);
        setNumThreads(numThreadsArg.isEmpty() ? 1 : Integer.parseInt(numThreadsArg));

        super.setOptions(strArr);
        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Reports the current configuration as an option array: {@code -P} and {@code -E} with their
     * values, followed by the options of the superclass.
     *
     * @return the option array
     */
    @Override // weka.classifiers.RandomizableClassifier, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();

        options.add("-P");
        options.add(String.valueOf(getPoolSize()));

        options.add("-E");
        options.add(String.valueOf(getNumThreads()));

        String[] inherited = super.getOptions();
        Collections.addAll(options, inherited);

        return options.toArray(new String[0]);
    }

    /**
     * Declares what this classifier can handle: numeric attributes, a binary class, and missing
     * class values. Every other capability is switched off first.
     *
     * @return the capabilities object
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities supported = super.getCapabilities();
        supported.disableAll();

        Capabilities.Capability[] enabled = {
            Capabilities.Capability.NUMERIC_ATTRIBUTES,
            Capabilities.Capability.BINARY_CLASS,
            Capabilities.Capability.MISSING_CLASS_VALUES
        };
        for (Capabilities.Capability capability : enabled) {
            supported.enable(capability);
        }
        return supported;
    }

    /**
     * Trains the model.
     *
     * <p>Steps: check capabilities and thread settings; copy the data, drop instances with a
     * missing class and shuffle; copy values into {@code m_matrix} while tracking each column's
     * largest absolute value; divide every non-class column by that maximum; start all usable
     * weights at {@code 1 / (numAttributes - 1)}; run the optimiser with a lower bound of 0 on
     * every non-class weight; finally divide the weights by the same column maxima so they apply
     * to unscaled inputs.
     *
     * @param instances the training data
     * @throws IllegalArgumentException if the configured number of threads or pool size is below 1
     * @throws Exception if the data fails the capability check or optimisation fails
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);

        // Fail fast: 0 threads would divide by zero inside the objective, and a negative count
        // would make every evaluation silently return 0.
        if (this.m_numThreads < 1) {
            throw new IllegalArgumentException("Number of threads must be at least 1, but is " + this.m_numThreads);
        }
        if (this.m_poolSize < 1) {
            throw new IllegalArgumentException("Pool size must be at least 1, but is " + this.m_poolSize);
        }

        this.m_data = new Instances(instances);
        this.m_data.deleteWithMissingClass();
        this.m_data.randomize(this.m_data.getRandomNumberGenerator(getSeed()));

        final int numInstances = this.m_data.numInstances();
        final int numAttributes = this.m_data.numAttributes();
        final int classIndex = this.m_data.classIndex();

        // Size the matrix from the filtered copy, not the raw input: otherwise instances removed
        // for a missing class leave trailing all-zero rows that are still evaluated.
        this.m_matrix = new double[numInstances][numAttributes];
        double[] maxAbs = new double[numAttributes];
        for (int row = 0; row < numInstances; row++) {
            Instance instance = this.m_data.instance(row);
            for (int col = 0; col < numAttributes; col++) {
                double value = instance.value(col);
                this.m_matrix[row][col] = value;
                double magnitude = Math.abs(value);
                if (row == 0 || magnitude > maxAbs[col]) {
                    maxAbs[col] = magnitude;
                }
            }
        }

        // Scale each non-class column by its largest absolute value.
        for (int row = 0; row < numInstances; row++) {
            for (int col = 0; col < numAttributes; col++) {
                if (maxAbs[col] > KStarConstants.FLOOR && col != classIndex) {
                    this.m_matrix[row][col] /= maxAbs[col];
                }
            }
        }

        // The instances are no longer needed once the matrix is filled.
        this.m_data = new Instances(this.m_data, 0);

        this.m_weights = new double[numAttributes];
        for (int col = 0; col < numAttributes; col++) {
            if (col != classIndex && maxAbs[col] > KStarConstants.FLOOR) {
                this.m_weights[col] = 1.0d / (this.m_weights.length - 1);
            }
        }

        // Row 0 holds lower bounds, row 1 upper bounds. NaN is presumably read by
        // weka.core.Optimization as "no bound" -- confirm against its documentation.
        double[][] bounds = new double[2][numAttributes];
        for (int col = 0; col < numAttributes; col++) {
            bounds[0][col] = (col == classIndex) ? Double.NaN : 0.0d;
            bounds[1][col] = Double.NaN;
        }

        this.m_Pool = Executors.newFixedThreadPool(this.m_poolSize);
        try {
            OptEng optEng = new OptEng();
            optEng.setDebug(this.m_Debug);
            this.m_weights = optEng.findArgmin(this.m_weights, bounds);
            // A null result means the optimiser stopped early; resume from its current values.
            while (this.m_weights == null) {
                this.m_weights = optEng.getVarbValues();
                if (this.m_Debug) {
                    System.out.println("First set of iterations finished, not enough!");
                }
                this.m_weights = optEng.findArgmin(this.m_weights, bounds);
            }
        } finally {
            // Always release the worker threads, even when optimisation throws.
            this.m_Pool.shutdown();
        }

        // Undo the column scaling so the weights apply to unscaled attribute values.
        for (int col = 0; col < this.m_weights.length; col++) {
            if (col != classIndex && maxAbs[col] > KStarConstants.FLOOR) {
                this.m_weights[col] /= maxAbs[col];
            }
        }
        this.m_matrix = null;
    }

    /**
     * Computes the class distribution for one instance.
     *
     * <p>The score is the weighted sum of all non-class attribute values. The first entry is
     * {@code 1 / (1 + exp(-score))} and the second is its complement.
     *
     * @param instance the instance to classify
     * @return a two-element array holding the probability of the first and second class value
     * @throws Exception declared by the overridden method; not thrown explicitly here
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        final int classIndex = this.m_data.classIndex();
        double score = 0.0d;
        for (int i = 0; i < instance.numAttributes(); i++) {
            if (i != classIndex) {
                score += this.m_weights[i] * instance.value(i);
            }
        }
        // Compute the first probability up front: an array initializer may not read the array
        // it is initialising.
        double firstClassProbability = 1.0d / (1.0d + Math.exp(-score));
        return new double[] {firstClassProbability, 1.0d - firstClassProbability};
    }

    /**
     * Renders the model as a linear expression for the log-odds, one line per non-class attribute
     * in the form {@code name * weight}. Every term after the first printed one is prefixed
     * with "+".
     *
     * @return the textual model, or a placeholder message when no model is available
     */
    public String toString() {
        if (this.m_data == null || this.m_weights == null) {
            return "Classifier not built yet.";
        }
        final int classIndex = this.m_data.classIndex();
        StringBuilder text = new StringBuilder("\nlog(x / (1 - x))\t=\n");
        // Track the first printed term explicitly: testing the attribute index instead puts a
        // stray leading "+" on the first term when the class attribute sits at index 0.
        boolean firstTerm = true;
        for (int i = 0; i < this.m_data.numAttributes(); i++) {
            if (i == classIndex) {
                continue;
            }
            text.append(firstTerm ? "\t   " : "\t+  ");
            firstTerm = false;
            text.append(this.m_data.attribute(i).name())
                .append("   \t* ")
                .append(Utils.doubleToString(this.m_weights[i], 6))
                .append(IOUtils.LINE_SEPARATOR_UNIX);
        }
        return text.toString();
    }

    /**
     * Command-line entry point: hands a fresh classifier and the arguments to
     * {@code runClassifier}.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        NonNegativeLogisticRegression classifier = new NonNegativeLogisticRegression();
        runClassifier(classifier, strArr);
    }
}
