package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import org.apache.commons.io.IOUtils;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TestInstances;
import weka.core.Utils;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveUseless;

/* loaded from: input_file:weka/classifiers/functions/FLDA.class */
/**
 * Fisher's linear discriminant for a binary class and numeric attributes.
 *
 * <p>Training removes useless (e.g. constant) attributes with
 * {@link RemoveUseless}, computes the two class centroids and the pooled
 * within-class scatter matrix (with a ridge added to its diagonal), and derives
 * a unit-length weight vector. The decision threshold is the projection of the
 * midpoint between the two centroids onto that weight vector.
 *
 * <p>Options: {@code -R <ridge>} sets the ridge value (default 1e-6), plus the
 * options of {@link AbstractClassifier}.
 */
public class FLDA extends AbstractClassifier {
    /** For serialization. */
    static final long serialVersionUID = -9212385698193681291L;

    /** Ridge value used when none is given on the command line. */
    private static final double DEFAULT_RIDGE = 1.0E-6d;

    /** Header-only copy of the filtered training data; set by {@link #buildClassifier}. */
    protected Instances m_Data;

    /** Column vector of attribute weights (normalised to unit length); null until a model is built. */
    protected Matrix m_Weights;

    /** Projection of the midpoint between the two class centroids onto the weight vector. */
    protected double m_Threshold;

    /** Value added to the diagonal of the scatter matrix before it is inverted. */
    protected double m_Ridge = DEFAULT_RIDGE;

    /** Filter applied to the training data and to every instance passed in for prediction. */
    protected RemoveUseless m_RemoveUseless;

    /**
     * Returns a description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Builds Fisher's Linear Discriminant function. The threshold is selected so that the separator is half-way between centroids. The class must be binary and all other attributes must be numeric. Missing values are not permitted. Constant attributes are removed using RemoveUseless. No standardization or normalization of attributes is performed.";
    }

    /**
     * Returns the capabilities of this classifier: numeric attributes, a binary
     * class, missing class values allowed, and no minimum number of instances.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Computes the centroid (mean of every non-class attribute) of each of the
     * two classes.
     *
     * @param data   the training data; every instance must have its class set
     * @param counts array of length 2 that is incremented with the number of
     *               instances seen per class (callers pass it zero-filled)
     * @return two 1 x (numAttributes - 1) row vectors, one centroid per class
     */
    protected Matrix[] getClassMeans(Instances data, int[] counts) {
        final int classIndex = data.classIndex();
        final int numAttributes = data.numAttributes();

        // Per-class running sums of the non-class attribute values.
        double[][] sums = new double[2][numAttributes - 1];
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            int cls = (int) inst.classValue();
            int col = 0;
            for (int att = 0; att < numAttributes; att++) {
                if (att != classIndex) {
                    sums[cls][col++] += inst.value(att);
                }
            }
            counts[cls]++;
        }

        // Turn sums into means. If a class has no instances the scale factor is
        // infinite and that centroid will not contain usable values.
        Matrix[] centroids = new Matrix[2];
        for (int cls = 0; cls < 2; cls++) {
            centroids[cls] = new Matrix(sums[cls], 1);
            centroids[cls].timesEquals(1.0d / counts[cls]);
        }

        if (this.m_Debug) {
            System.out.println("Count for class 0: " + counts[0]);
            System.out.println("Centroid 0:" + centroids[0]);
            System.out.println("Count for class 1: " + counts[1]);
            System.out.println("Centroid 1:" + centroids[1]);
        }
        return centroids;
    }

    /**
     * Builds, for each class, a matrix holding that class's instances with the
     * class centroid subtracted from every row.
     *
     * @param data      the training data; every instance must have its class set
     * @param counts    number of instances per class (used to size the matrices)
     * @param centroids the per-class centroids as row vectors
     * @return two matrices of size counts[c] x (numAttributes - 1)
     */
    protected Matrix[] getCenteredData(Instances data, int[] counts, Matrix[] centroids) {
        final int classIndex = data.classIndex();
        final int numAttributes = data.numAttributes();

        Matrix[] centered = new Matrix[2];
        for (int cls = 0; cls < 2; cls++) {
            centered[cls] = new Matrix(counts[cls], numAttributes - 1);
        }

        // Next free row in each class's matrix.
        int[] nextRow = new int[2];
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            int cls = (int) inst.classValue();
            int col = 0;
            for (int att = 0; att < numAttributes; att++) {
                if (att != classIndex) {
                    centered[cls].set(nextRow[cls], col, inst.value(att) - centroids[cls].get(0, col));
                    col++;
                }
            }
            nextRow[cls]++;
        }

        if (this.m_Debug) {
            System.out.println("Centered data for class 0:\n" + centered[0]);
            System.out.println("Centered data for class 1:\n" + centered[1]);
        }
        return centered;
    }

    /**
     * Builds the discriminant from the given training data.
     *
     * @param instances the training data
     * @throws Exception if the data fails the capabilities test or filtering fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);

        // Drop useless attributes; the same filter is reused at prediction time.
        this.m_RemoveUseless = new RemoveUseless();
        this.m_RemoveUseless.setInputFormat(instances);
        Instances filtered = Filter.useFilter(instances, this.m_RemoveUseless);
        filtered.deleteWithMissingClass();

        int[] counts = new int[2];
        Matrix[] centroids = getClassMeans(filtered, counts);
        Matrix centroidDiff = centroids[0].minus(centroids[1]);

        // Pooled within-class scatter: sum over classes of X_c' * X_c, plus ridge * I.
        Matrix[] centered = getCenteredData(filtered, counts, centroids);
        Matrix scatter = centered[0].transpose().times(centered[0])
                .plus(centered[1].transpose().times(centered[1]));
        int dim = filtered.numAttributes() - 1;
        scatter.plusEquals(Matrix.identity(dim, dim).timesEquals(this.m_Ridge));
        if (this.m_Debug) {
            System.out.println("Scatter:\n" + scatter);
        }

        // w = S^-1 (m0 - m1)', scaled to unit Euclidean length.
        this.m_Weights = scatter.inverse().times(centroidDiff.transpose());
        double squaredNorm = this.m_Weights.transpose().times(this.m_Weights).get(0, 0);
        this.m_Weights.timesEquals(1.0d / Math.sqrt(squaredNorm));

        // Threshold is the projection of the midpoint between the centroids.
        Matrix centroidSum = centroids[0].transpose().plus(centroids[1].transpose());
        this.m_Threshold = 0.5d * this.m_Weights.transpose().times(centroidSum).get(0, 0);

        // Keep only the header of the filtered data.
        this.m_Data = new Instances(filtered, 0);
    }

    /**
     * Computes the class distribution for an instance. The estimate for class 1
     * is {@code 1 / (1 + exp(x.w - threshold))}; class 0 receives the remainder.
     *
     * <p>Requires {@link #buildClassifier} to have been called first.
     *
     * @param instance the instance to classify
     * @return array of length 2 with the estimates for class 0 and class 1
     * @throws Exception if the instance cannot be filtered
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        this.m_RemoveUseless.input(instance);
        Instance filtered = this.m_RemoveUseless.output();

        // Copy the non-class attribute values into a row vector.
        Matrix row = new Matrix(1, filtered.numAttributes() - 1);
        int col = 0;
        for (int att = 0; att < filtered.numAttributes(); att++) {
            if (att != this.m_Data.classIndex()) {
                row.set(0, col++, filtered.value(att));
            }
        }

        // Class 1 must be computed first because class 0 is defined as its complement.
        double[] dist = new double[2];
        dist[1] = 1.0d / (1.0d + Math.exp(row.times(this.m_Weights).get(0, 0) - this.m_Threshold));
        dist[0] = 1.0d - dist[1];
        return dist;
    }

    /**
     * Returns a textual description of the model: the threshold followed by one
     * weight per non-class attribute.
     *
     * @return the model description, or a notice if no model has been built
     */
    @Override
    public String toString() {
        if (this.m_Weights == null) {
            return "No model has been built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("Fisher's Linear Discriminant Analysis\n\n");
        text.append("Threshold: ").append(this.m_Threshold).append("\n\n");
        text.append("Weights:\n\n");
        int row = 0;
        for (int att = 0; att < this.m_Data.numAttributes(); att++) {
            if (att != this.m_Data.classIndex()) {
                text.append(this.m_Data.attribute(att).name()).append(": \t");
                double weight = this.m_Weights.get(row++, 0);
                // Pad non-negative values so they line up with negative ones.
                if (weight >= 0.0d) {
                    text.append(" ");
                }
                text.append(weight).append("\n");
            }
        }
        return text.toString();
    }

    /**
     * Returns the tip text for the ridge property.
     *
     * @return the tip text
     */
    public String ridgeTipText() {
        return "The value of the ridge parameter.";
    }

    /**
     * Gets the ridge value.
     *
     * @return the ridge value
     */
    public double getRidge() {
        return this.m_Ridge;
    }

    /**
     * Sets the ridge value.
     *
     * @param d the new ridge value
     */
    public void setRidge(double d) {
        this.m_Ridge = d;
    }

    /**
     * Lists the available options: {@code -R} followed by the superclass options.
     *
     * @return an enumeration of the options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>(7);
        // -R consumes one value (see setOptions), so it is declared with one argument.
        options.addElement(new Option("\tThe ridge parameter.\n\t(default is 1e-6)", "R", 1, "-R <ridge>"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Parses the given options. {@code -R <ridge>} sets the ridge value; when it
     * is absent the ridge is reset to its default. Remaining options are passed
     * to the superclass.
     *
     * @param strArr the options to parse
     * @throws Exception if an option is invalid or unrecognised options remain
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        String ridge = Utils.getOption('R', strArr);
        if (ridge.length() != 0) {
            setRidge(Double.parseDouble(ridge));
        } else {
            setRidge(DEFAULT_RIDGE);
        }
        super.setOptions(strArr);
        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Returns the current option settings: {@code -R <ridge>} followed by the
     * superclass options.
     *
     * @return the options as an array of strings
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-R");
        options.add("" + getRidge());
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10382 $");
    }

    /**
     * Runs this classifier from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new FLDA(), strArr);
    }
}
