package weka.filters.supervised.instance;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Vector;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.SupervisedFilter;

/* loaded from: input_file:weka/filters/supervised/instance/CostSensitiveResample.class */
/**
 * Cost-sensitive resampling filter for datasets with a nominal class attribute.
 *
 * <p>Per-class costs are the row sums of a cost matrix read from
 * {@link #getCostMatrixFileName()}; an empty file name yields uniform costs
 * (0 on the diagonal, 1 elsewhere). In oversampling mode (the default) class
 * sizes are scaled in proportion to their cost by resampling with replacement
 * or SMOTE, or instances are kept with probability proportional to their class
 * cost (COSTING). In undersampling mode classes are shrunk in proportion to
 * their cost using nearest-neighbour selection, Tomek-link removal and finally
 * random removal.
 *
 * <p>Only the first batch is resampled; later batches pass through unchanged.
 */
public class CostSensitiveResample extends Filter implements SupervisedFilter, OptionHandler {
    static final long serialVersionUID = 7079064953548300682L;

    /** Seed for every random choice made while resampling. */
    protected int m_RandomSeed = 1;

    /**
     * Strategy last passed to {@link #setResampleType}. That setter also updates
     * {@link #m_SMOTE} and {@link #m_COSTING}, which are the flags the filter consults.
     */
    protected ResampleType m_ResampleType = ResampleType.NORMAL;

    /**
     * Cost matrix file. Empty means uniform costs; this matches the value
     * {@link #setOptions} uses when -C is absent.
     */
    protected String m_CostMatrixFileName = "";

    /** True to oversample (default), false to undersample. */
    protected boolean m_Oversample = true;

    /** When oversampling, generate synthetic instances with SMOTE (takes precedence over COSTING). */
    protected boolean m_SMOTE = false;

    /** When oversampling without SMOTE, use cost-proportionate rejection sampling. */
    protected boolean m_COSTING = false;

    /** Oversampling strategies selectable via {@link #setResampleType}. */
    public enum ResampleType {
        NORMAL,
        SMOTE,
        COSTING;

        /**
         * Returns a fresh array containing every constant.
         *
         * <p>Alias of {@code values()} kept for binary compatibility.
         *
         * @return a new array of all constants, in declaration order
         */
        public static ResampleType[] valuesCustom() {
            return values();
        }
    }

    /**
     * Returns a description of the filter for display in the Weka GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Resamples a dataset according to misclassification costs.\n"
            + "Each class's cost is the sum of its row in the cost matrix; an empty cost matrix "
            + "file name results in uniform class costs. When oversampling (the default), class "
            + "sizes are scaled in proportion to their cost by resampling with replacement or by "
            + "SMOTE; alternatively COSTING keeps each instance with a probability proportional "
            + "to its class cost. When undersampling, classes are reduced in proportion to their "
            + "cost using nearest-neighbour selection, Tomek-link removal and random removal.\n"
            + "The original dataset must fit entirely in memory. The dataset must have a nominal "
            + "class attribute. When used in batch mode (i.e. in the FilteredClassifier), "
            + "subsequent batches are NOT resampled.";
    }

    /**
     * Lists the command-line options understood by {@link #setOptions}.
     *
     * @return an enumeration of {@link Option}s
     */
    @Override // weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tSpecify the random number seed (default 1)", "S", 1, "-S <num>"));
        options.addElement(new Option("\tCost matrix which determines how sampling is done.\n\t(default is 1-I)", "C", 1, "-C <String>"));
        options.addElement(new Option("\tUse SMOTE to resample (default: resampling with replacement).", "SMOTE", 0, "-SMOTE"));
        options.addElement(new Option("\tUse COSTING to resample (default: resampling with replacement).", "COSTING", 0, "-COSTING"));
        options.addElement(new Option("\tDo undersampling instead of oversampling.\n\t(default: oversampling)", "undersample", 0, "-undersample"));
        return options.elements();
    }

    /**
     * Parses the options listed by {@link #listOptions}.
     *
     * <p>Every option not present is reset to its default, so calling this method
     * twice does not carry over settings from the first call.
     *
     * @param options the option array; recognised options are removed from it
     * @throws Exception if an option value cannot be parsed
     */
    @Override // weka.core.OptionHandler
    public void setOptions(String[] options) throws Exception {
        String seed = Utils.getOption('S', options);
        setRandomSeed(seed.length() != 0 ? Integer.parseInt(seed) : 1);

        // getOption yields "" when -C is absent, which selects uniform costs.
        setCostMatrixFileName(Utils.getOption('C', options));

        setSMOTE(Utils.getFlag("SMOTE", options));
        setCOSTING(Utils.getFlag("COSTING", options));
        setOversample(!Utils.getFlag("undersample", options));

        if (getInputFormat() != null) {
            setInputFormat(getInputFormat());
        }
    }

    /**
     * Returns the current settings as a command-line option array.
     *
     * @return the options
     */
    @Override // weka.core.OptionHandler
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        result.add("-C");
        result.add(getCostMatrixFileName());
        result.add("-S");
        result.add(String.valueOf(getRandomSeed()));
        if (!getOversample()) {
            result.add("-undersample");
        }
        if (getSMOTE()) {
            result.add("-SMOTE");
        }
        if (getCOSTING()) {
            result.add("-COSTING");
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Tip text with no matching property on this filter; retained for compatibility.
     *
     * @return the tip text
     */
    public String costSensitiveMatrixTipText() {
        return "Sets the filename which contains the cost matrix for the data set. An empty value results in uniform class costs.";
    }

    /**
     * Returns the GUI tip text for the oversample property.
     *
     * @return the tip text
     */
    public String oversampleTipText() {
        return "If true, classes are oversampled in proportion to their cost; if false, they are undersampled.";
    }

    /**
     * Sets whether to oversample (true) or undersample (false).
     *
     * @param oversample the new mode
     */
    public void setOversample(boolean oversample) {
        this.m_Oversample = oversample;
    }

    /**
     * Returns the GUI tip text for the SMOTE property.
     *
     * @return the tip text
     */
    public String SMOTETipText() {
        return "If true and oversampling, SMOTE generates synthetic instances instead of resampling with replacement.";
    }

    /**
     * Returns the GUI tip text for the COSTING property.
     *
     * @return the tip text
     */
    public String COSTINGTipText() {
        return "If true and oversampling without SMOTE, each instance is kept with a probability proportional to its class cost.";
    }

    /**
     * Enables or disables SMOTE for oversampling.
     *
     * @param smote true to use SMOTE
     */
    public void setSMOTE(boolean smote) {
        this.m_SMOTE = smote;
    }

    /**
     * Enables or disables cost-proportionate rejection sampling for oversampling.
     *
     * @param costing true to use COSTING
     */
    public void setCOSTING(boolean costing) {
        this.m_COSTING = costing;
    }

    /** @return whether SMOTE is used for oversampling */
    public boolean getSMOTE() {
        return this.m_SMOTE;
    }

    /** @return whether COSTING is used for oversampling */
    public boolean getCOSTING() {
        return this.m_COSTING;
    }

    /** @return the cost matrix file name; empty means uniform costs */
    public String getCostMatrixFileName() {
        return this.m_CostMatrixFileName;
    }

    /**
     * Sets the cost matrix file. An empty name (or null) selects uniform costs.
     *
     * @param fileName the file name
     */
    public void setCostMatrixFileName(String fileName) {
        this.m_CostMatrixFileName = fileName;
    }

    /** @return GUI tip text for the randomSeed property */
    public String randomSeedTipText() {
        return "Sets the random number seed for subsampling.";
    }

    /** @return GUI tip text for the costMatrixFileName property */
    public String costMatrixFileNameTipText() {
        return "Sets the name of the file where the cost matrix is stored.";
    }

    /**
     * Returns the random seed.
     *
     * <p>NOTE(review): this returns {@code long} while the setter takes {@code int};
     * the type is kept for backward compatibility.
     *
     * @return the seed
     */
    public long getRandomSeed() {
        return this.m_RandomSeed;
    }

    /**
     * Sets the random seed.
     *
     * @param seed the seed
     */
    public void setRandomSeed(int seed) {
        this.m_RandomSeed = seed;
    }

    /**
     * Tip text with no matching property on this filter; retained for compatibility.
     *
     * @return the tip text
     */
    public String sampleSizePercentTipText() {
        return "The subsample size as a percentage of the original set.";
    }

    /**
     * Tip text with no matching property on this filter; retained for compatibility.
     *
     * @return the tip text
     */
    public String noReplacementTipText() {
        return "Disables the replacement of instances.";
    }

    /** @return true when oversampling, false when undersampling */
    public boolean getOversample() {
        return this.m_Oversample;
    }

    /**
     * Tip text with no matching property on this filter; retained for compatibility.
     *
     * @return the tip text
     */
    public String invertSelectionTipText() {
        return "Inverts the selection (only if instances are drawn WITHOUT replacement).";
    }

    /**
     * Accepts all attribute types and missing values, and requires a nominal class.
     *
     * @return the capabilities of this filter
     */
    @Override // weka.filters.Filter, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enableAllAttributes();
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        return capabilities;
    }

    /**
     * Sets the input format; the output format is identical.
     *
     * @param instances the input header
     * @return always true (output format is available immediately)
     * @throws Exception if the format is rejected by the superclass
     */
    @Override // weka.filters.Filter
    public boolean setInputFormat(Instances instances) throws Exception {
        super.setInputFormat(instances);
        setOutputFormat(instances);
        return true;
    }

    /**
     * Buffers an instance during the first batch, or passes it through unchanged afterwards.
     *
     * @param instance the input instance
     * @return true if the instance is immediately available for output
     * @throws IllegalStateException if no input format has been set
     */
    @Override // weka.filters.Filter
    public boolean input(Instance instance) {
        if (getInputFormat() == null) {
            throw new IllegalStateException("No input instance format defined");
        }
        if (this.m_NewBatch) {
            resetQueue();
            this.m_NewBatch = false;
        }
        if (isFirstBatchDone()) {
            push(instance);
            return true;
        }
        bufferInput(instance);
        return false;
    }

    /**
     * Ends the current batch, resampling the buffered data if this was the first batch.
     *
     * @return true if there are instances pending output
     * @throws IllegalStateException if no input format has been set
     */
    @Override // weka.filters.Filter
    public boolean batchFinished() {
        if (getInputFormat() == null) {
            throw new IllegalStateException("No input instance format defined");
        }
        if (!isFirstBatchDone()) {
            createSubsample();
        }
        flushInput();
        this.m_NewBatch = true;
        this.m_FirstBatchDone = true;
        return numPendingOutput() != 0;
    }

    /**
     * Reads a square cost matrix with one row and column per class.
     *
     * <p>File format: a header line {@code <rows> <columns>} followed by whitespace-separated
     * rows. Dimensions beyond the number of classes are ignored; missing rows or cells stay 0.
     * An empty or null file name yields uniform costs (0 on the diagonal, 1 elsewhere).
     *
     * @param fileName the cost matrix file
     * @return a numClasses x numClasses cost matrix
     * @throws IllegalArgumentException if the file cannot be read or is malformed
     */
    private double[][] readCostMatrix(String fileName) {
        int numClasses = getInputFormat().numClasses();
        double[][] costs = new double[numClasses][numClasses];
        if (fileName == null || fileName.length() == 0) {
            for (int row = 0; row < numClasses; row++) {
                for (int col = 0; col < numClasses; col++) {
                    costs[row][col] = (row == col) ? 0.0d : 1.0d;
                }
            }
            return costs;
        }

        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new InputStreamReader(new FileInputStream(fileName)));
            String header = reader.readLine();
            if (header == null) {
                throw new IllegalArgumentException("Cost matrix file '" + fileName + "' is empty");
            }
            String[] dims = header.trim().split("\\s+");
            if (dims.length < 2) {
                throw new IllegalArgumentException(
                    "Cost matrix file '" + fileName + "' must start with '<rows> <columns>'");
            }
            // Bound by numClasses so an oversized file cannot index past the matrix.
            int rows = Math.min(Integer.parseInt(dims[0]), numClasses);
            int cols = Math.min(Integer.parseInt(dims[1]), numClasses);
            for (int row = 0; row < rows; row++) {
                String line = reader.readLine();
                if (line == null) {
                    break;
                }
                String trimmed = line.trim();
                if (trimmed.length() == 0) {
                    continue; // blank line: leave this row at zero cost
                }
                String[] cells = trimmed.split("\\s+");
                for (int col = 0; col < cells.length && col < cols; col++) {
                    costs[row][col] = Double.parseDouble(cells[col]);
                }
            }
        } catch (IOException e) {
            throw new IllegalArgumentException("Cannot read cost matrix file '" + fileName + "'", e);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Malformed cost matrix file '" + fileName + "'", e);
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException ignored) {
                    // The matrix has already been read; a close failure changes nothing.
                }
            }
        }
        return costs;
    }

    /**
     * Computes each class's cost as the sum of its row in the (square) cost matrix.
     *
     * @param costMatrix a square cost matrix
     * @return the per-class costs
     */
    protected double[] calcClassCost(double[][] costMatrix) {
        int size = costMatrix.length;
        double[] classCosts = new double[size];
        for (int row = 0; row < size; row++) {
            double rowSum = 0.0d;
            for (int col = 0; col < size; col++) {
                rowSum += costMatrix[row][col];
            }
            classCosts[row] = rowSum;
        }
        return classCosts;
    }

    /**
     * Derives class sizes from class start offsets.
     *
     * @param numClasses the number of classes
     * @param classStarts start offset of each class, plus one trailing end offset
     * @return the number of instances in each class
     */
    protected int[] calcClassSizes(int numClasses, int[] classStarts) {
        int[] sizes = new int[numClasses];
        for (int c = 0; c < numClasses; c++) {
            sizes[c] = classStarts[c + 1] - classStarts[c];
        }
        return sizes;
    }

    /**
     * Picks the reference class for oversampling: the non-empty, positive-cost class with
     * the smallest cost-to-size ratio. Scaling every class relative to it never shrinks a class.
     *
     * @param classCosts per-class costs
     * @param numClasses the number of classes
     * @param classSizes per-class sizes
     * @return the reference class index, or -1 if no class qualifies
     */
    protected int calcOversampleLambda(double[] classCosts, int numClasses, int[] classSizes) {
        int best = -1;
        double bestRatio = Double.POSITIVE_INFINITY;
        for (int c = 0; c < numClasses; c++) {
            // Empty or zero-cost classes would make the scaling divide by zero.
            if (classSizes[c] <= 0 || classCosts[c] <= 0.0d) {
                continue;
            }
            double ratio = classCosts[c] / classSizes[c];
            if (ratio < bestRatio) {
                bestRatio = ratio;
                best = c;
            }
        }
        return best;
    }

    /**
     * Computes the oversampled target size of each class.
     *
     * @param classCosts per-class costs
     * @param numClasses the number of classes
     * @param classSizes per-class sizes
     * @return the target size per class (0 for empty classes)
     */
    protected int[] overSampleNumPerClass(double[] classCosts, int numClasses, int[] classSizes) {
        return scaleToReference(classCosts, numClasses, classSizes,
            calcOversampleLambda(classCosts, numClasses, classSizes));
    }

    /**
     * Oversamples the buffered data with the configured strategy and pushes the result.
     * SMOTE takes precedence over COSTING; otherwise resampling with replacement is used.
     *
     * @param random the random source
     * @param numClasses the number of classes
     * @param classStarts start offset of each class in the class-sorted input, plus an end offset
     */
    public void overSample(Random random, int numClasses, int[] classStarts) {
        double[] classCosts = calcClassCost(readCostMatrix(this.m_CostMatrixFileName));
        int[] classSizes = calcClassSizes(numClasses, classStarts);
        int[] targets = overSampleNumPerClass(classCosts, numClasses, classSizes);
        if (getSMOTE()) {
            doSmoteOversample(numClasses, classSizes, targets);
        } else if (getCOSTING()) {
            doCostingOversample(random, classCosts);
        } else {
            doNormalOversample(random, numClasses, classStarts, targets);
        }
    }

    /**
     * Cost-proportionate rejection sampling. Each labelled instance is kept with probability
     * classCost / maxClassCost. Instances with a missing class are skipped because no class
     * cost applies to them.
     *
     * @param random the random source
     * @param classCosts per-class costs
     */
    private void doCostingOversample(Random random, double[] classCosts) {
        Instances data = getInputFormat();
        double maxCost = classCosts[Utils.maxIndex(classCosts)];
        for (int k = 0; k < data.numInstances(); k++) {
            Instance instance = data.instance(k);
            if (instance.classIsMissing()) {
                continue;
            }
            if (random.nextDouble() <= classCosts[(int) instance.classValue()] / maxCost) {
                push((Instance) instance.copy());
            }
        }
    }

    /**
     * Grows each class to its target size with SMOTE, then pushes the whole augmented dataset.
     *
     * @param numClasses the number of classes
     * @param classSizes per-class sizes
     * @param targets per-class target sizes
     * @throws IllegalStateException if SMOTE fails
     */
    private void doSmoteOversample(int numClasses, int[] classSizes, int[] targets) {
        Instances data = getInputFormat();
        for (int c = 0; c < numClasses; c++) {
            if (classSizes[c] <= 0) {
                continue; // SMOTE needs at least one seed instance
            }
            double percentage = ((100.0d * targets[c]) / classSizes[c]) - 100.0d;
            if (percentage <= 0.0d) {
                continue; // target does not exceed current size: nothing to synthesize
            }
            try {
                SMOTE smote = new SMOTE();
                // NOTE(review): passes a 1-based class index as in the original code;
                // confirm against SMOTE's classValue option.
                smote.setClassValue(Integer.toString(c + 1));
                smote.setPercentage(percentage);
                if (classSizes[c] <= 5) {
                    smote.setNearestNeighbors(classSizes[c]);
                }
                smote.setInputFormat(data);
                data = Filter.useFilter(data, smote);
            } catch (Exception e) {
                throw new IllegalStateException("SMOTE oversampling failed for class index " + c, e);
            }
        }
        for (int k = 0; k < data.numInstances(); k++) {
            push((Instance) data.instance(k).copy());
        }
    }

    /**
     * Resampling with replacement. For each class, the existing instances are emitted in
     * order first, then random draws from that class fill up to the target size.
     *
     * @param random the random source
     * @param numClasses the number of classes
     * @param classStarts start offset of each class, plus an end offset
     * @param targets per-class target sizes
     */
    private void doNormalOversample(Random random, int numClasses, int[] classStarts, int[] targets) {
        Instances data = getInputFormat();
        for (int c = 0; c < numClasses; c++) {
            int start = classStarts[c];
            int size = classStarts[c + 1] - start;
            if (size <= 0) {
                continue; // an empty class has nothing to draw from
            }
            for (int k = 0; k < targets[c]; k++) {
                int index = k < size ? start + k : start + random.nextInt(size);
                push((Instance) data.instance(index).copy());
            }
        }
    }

    /**
     * Picks the reference class for undersampling: the non-empty class with the largest
     * strictly positive cost-to-size ratio. Scaling relative to it never grows a class.
     *
     * @param classCosts per-class costs
     * @param numClasses the number of classes
     * @param classSizes per-class sizes
     * @return the reference class index, or -1 if no class qualifies
     */
    protected int calcUndersampleLambda(double[] classCosts, int numClasses, int[] classSizes) {
        int best = -1;
        double bestRatio = 0.0d; // only strictly positive ratios qualify
        for (int c = 0; c < numClasses; c++) {
            // An empty class would give an infinite ratio and wipe out every class.
            if (classSizes[c] <= 0) {
                continue;
            }
            double ratio = classCosts[c] / classSizes[c];
            if (ratio > bestRatio) {
                bestRatio = ratio;
                best = c;
            }
        }
        return best;
    }

    /**
     * Computes the undersampled target size of each class.
     *
     * @param classCosts per-class costs
     * @param numClasses the number of classes
     * @param classSizes per-class sizes
     * @return the target size per class (0 for empty classes)
     */
    protected int[] underSampleNumPerClass(double[] classCosts, int numClasses, int[] classSizes) {
        return scaleToReference(classCosts, numClasses, classSizes,
            calcUndersampleLambda(classCosts, numClasses, classSizes));
    }

    /**
     * Undersamples the buffered data and pushes the result. The pipeline is:
     * nearest-neighbour selection, Tomek-link removal, then random trimming to the target size.
     *
     * @param random the random source
     * @param numClasses the number of classes
     * @param classStarts start offset of each class in the class-sorted input, plus an end offset
     */
    public void underSample(Random random, int numClasses, int[] classStarts) {
        double[] classCosts = calcClassCost(readCostMatrix(this.m_CostMatrixFileName));
        int[] targets = underSampleNumPerClass(classCosts, numClasses, calcClassSizes(numClasses, classStarts));
        Instances[] byClass = partitionInstances(getInputFormat(), numClasses);
        Instances[] afterNN = doNN(byClass, targets);
        Instances[] afterTomek = doTomek(afterNN, targets, random);
        Instances[] result = doFinal(afterTomek, targets, random);
        for (Instances part : result) {
            for (int k = 0; k < part.numInstances(); k++) {
                push((Instance) part.instance(k).copy());
            }
        }
    }

    /**
     * Randomly deletes instances from each class until it is no larger than its target.
     *
     * @param byClass instances partitioned by class (not modified)
     * @param targets per-class target sizes
     * @param random the random source
     * @return trimmed copies of each partition
     */
    private Instances[] doFinal(Instances[] byClass, int[] targets, Random random) {
        Instances[] trimmed = new Instances[byClass.length];
        for (int c = 0; c < trimmed.length; c++) {
            trimmed[c] = new Instances(byClass[c]);
            while (trimmed[c].numInstances() > targets[c]) {
                trimmed[c].delete(random.nextInt(trimmed[c].numInstances()));
            }
        }
        return trimmed;
    }

    /**
     * Removes Tomek-link members from every class that exceeds its target.
     *
     * @param byClass instances partitioned by class (not modified)
     * @param targets per-class target sizes
     * @param random the random source
     * @return the cleaned partitions
     */
    private Instances[] doTomek(Instances[] byClass, int[] targets, Random random) {
        Instances[] cleaned = new Instances[byClass.length];
        for (int c = 0; c < byClass.length; c++) {
            cleaned[c] = removeTomekInstances(byClass, c, targets[c], random);
        }
        return cleaned;
    }

    /**
     * Removes instances of one class that form a Tomek link. Candidates are visited in random
     * order, and removal stops as soon as the class is down to its quota.
     *
     * @param byClass instances partitioned by class
     * @param cls the class to clean
     * @param quota the class's target size
     * @param random the random source
     * @return the original partition if already within quota, otherwise a cleaned copy
     */
    private Instances removeTomekInstances(Instances[] byClass, int cls, int quota, Random random) {
        Instances own = byClass[cls];
        int n = own.numInstances();
        if (n <= quota) {
            return own;
        }
        // Shuffle indices rather than instances so the identity checks in isTomekLink still
        // refer to the stored objects.
        int[] order = new int[n];
        for (int k = 0; k < n; k++) {
            order[k] = k;
        }
        for (int j = n - 1; j > 0; j--) {
            int swap = random.nextInt(j + 1);
            int tmp = order[j];
            order[j] = order[swap];
            order[swap] = tmp;
        }
        Map<Attribute, double[][]> vdm = getVDMMap(mergeAll(byClass));
        boolean[] removed = new boolean[n];
        int remaining = n;
        for (int k = 0; k < n && remaining > quota; k++) {
            int index = order[k];
            if (isTomekLink(byClass, own.instance(index), vdm)) {
                removed[index] = true;
                remaining--;
            }
        }
        Instances kept = new Instances(own, remaining);
        for (int k = 0; k < n; k++) {
            if (!removed[k]) {
                kept.add(own.instance(k));
            }
        }
        return kept;
    }

    /**
     * Tests whether an instance forms a Tomek link: its nearest neighbour has a different
     * class and has the instance as its own nearest neighbour.
     *
     * @param byClass instances partitioned by class (must contain {@code instance} itself)
     * @param instance the instance to test
     * @param vdm per-attribute value distance matrices
     * @return true if the instance is one end of a Tomek link
     */
    private boolean isTomekLink(Instances[] byClass, Instance instance, Map<Attribute, double[][]> vdm) {
        Instance neighbour = getNearestNeighbor(byClass, instance, vdm);
        if (neighbour == null || neighbour.classValue() == instance.classValue()) {
            return false;
        }
        return getNearestNeighbor(byClass, neighbour, vdm) == instance;
    }

    /**
     * Applies nearest-neighbour selection to every class that exceeds its target.
     *
     * @param byClass instances partitioned by class; oversized entries are replaced in place
     * @param targets per-class target sizes
     * @return {@code byClass}
     */
    private Instances[] doNN(Instances[] byClass, int[] targets) {
        for (int c = 0; c < byClass.length; c++) {
            removeNNInstances(byClass, c, targets[c]);
        }
        return byClass;
    }

    /**
     * Condensed-nearest-neighbour style selection for one class.
     *
     * <p>A random half of the quota is kept as seeds. Each remaining instance is then added if
     * its nearest neighbour (among all kept data) has a different class, or if adding it is
     * needed to still reach the quota.
     *
     * @param byClass instances partitioned by class; entry {@code cls} is replaced
     * @param cls the class to reduce
     * @param quota the class's target size
     */
    private void removeNNInstances(Instances[] byClass, int cls, int quota) {
        if (byClass[cls].numInstances() <= quota) {
            return;
        }
        Instances shuffled = new Instances(byClass[cls]);
        shuffled.randomize(new Random(this.m_RandomSeed));
        int seedCount = quota / 2;
        byClass[cls] = new Instances(shuffled, 0, seedCount);
        Map<Attribute, double[][]> vdm = getVDMMap(mergeAll(byClass));
        // Slots still needed to reach the quota; forces the tail in if too few were misclassified.
        int stillNeeded = quota - seedCount;
        for (int k = seedCount; k < shuffled.numInstances(); k++) {
            Instance candidate = shuffled.instance(k);
            boolean mustFill = stillNeeded >= shuffled.numInstances() - k;
            if (mustFill || isNNMisclassified(byClass, candidate, vdm)) {
                byClass[cls].add((Instance) candidate.copy());
                stillNeeded--;
            }
        }
    }

    /**
     * Tests whether an instance's nearest neighbour has a different class.
     *
     * @param byClass instances partitioned by class
     * @param instance the instance to test
     * @param vdm per-attribute value distance matrices
     * @return true if the nearest neighbour exists and its class differs
     */
    private boolean isNNMisclassified(Instances[] byClass, Instance instance, Map<Attribute, double[][]> vdm) {
        Instance neighbour = getNearestNeighbor(byClass, instance, vdm);
        return neighbour != null && neighbour.classValue() != instance.classValue();
    }

    /**
     * Finds the nearest stored instance to {@code query}, excluding {@code query} itself (by
     * identity). Ties go to the first instance found.
     *
     * @param byClass instances partitioned by class
     * @param query the query instance
     * @param vdm per-attribute value distance matrices
     * @return the nearest other instance, or null if there is none
     */
    private Instance getNearestNeighbor(Instances[] byClass, Instance query, Map<Attribute, double[][]> vdm) {
        double bestDistance = Double.MAX_VALUE;
        Instance best = null;
        for (Instances part : byClass) {
            for (int k = 0; k < part.numInstances(); k++) {
                Instance candidate = part.instance(k);
                if (candidate == query) {
                    continue;
                }
                double distance = calcNNDistance(query, candidate, vdm);
                if (distance < bestDistance) {
                    bestDistance = distance;
                    best = candidate;
                }
            }
        }
        return best;
    }

    /**
     * Computes a heterogeneous distance between two instances.
     *
     * <p>Numeric attributes contribute their squared difference. Nominal and string attributes
     * contribute their VDM entry. Attributes missing in either instance, without a VDM matrix,
     * or yielding NaN are skipped. The class attribute is excluded.
     *
     * @param a first instance
     * @param b second instance
     * @param vdm per-attribute value distance matrices
     * @return the square root of the summed contributions
     */
    private double calcNNDistance(Instance a, Instance b, Map<Attribute, double[][]> vdm) {
        Instances header = getInputFormat();
        Attribute classAttribute = header.classAttribute();
        double sum = 0.0d;
        Enumeration attributes = header.enumerateAttributes();
        while (attributes.hasMoreElements()) {
            Attribute attribute = (Attribute) attributes.nextElement();
            if (attribute.equals(classAttribute) || a.isMissing(attribute) || b.isMissing(attribute)) {
                continue;
            }
            double contribution;
            if (attribute.isNumeric()) {
                double diff = a.value(attribute) - b.value(attribute);
                contribution = diff * diff;
            } else {
                double[][] matrix = vdm.get(attribute);
                if (matrix == null) {
                    continue; // attribute type without a VDM matrix
                }
                contribution = matrix[(int) a.value(attribute)][(int) b.value(attribute)];
            }
            if (!Double.isNaN(contribution)) {
                sum += contribution;
            }
        }
        return Math.sqrt(sum);
    }

    /**
     * Splits instances by class value; instances with a missing class are dropped.
     *
     * @param data the data to split
     * @param numClasses the number of classes
     * @return one partition per class, holding copies of the instances
     */
    private Instances[] partitionInstances(Instances data, int numClasses) {
        Instances[] byClass = new Instances[numClasses];
        for (int c = 0; c < numClasses; c++) {
            byClass[c] = new Instances(data, 0);
        }
        for (int k = 0; k < data.numInstances(); k++) {
            Instance instance = data.instance(k);
            if (instance.classIsMissing()) {
                continue;
            }
            byClass[(int) instance.classValue()].add((Instance) instance.copy());
        }
        return byClass;
    }

    /**
     * Concatenates all partitions into one dataset sharing the first partition's header.
     *
     * @param byClass instances partitioned by class
     * @return the merged dataset
     */
    private Instances mergeAll(Instances[] byClass) {
        Instances merged = new Instances(byClass[0], 0);
        for (Instances part : byClass) {
            for (int k = 0; k < part.numInstances(); k++) {
                merged.add(part.instance(k));
            }
        }
        return merged;
    }

    /**
     * Builds Value Difference Metric matrices for every nominal or string attribute.
     *
     * <p>The entry for values (v1, v2) is the sum over classes c of
     * |P(c | v1) - P(c | v2)|. Instances with a missing attribute value or missing class are
     * not counted. Entries for values that never occur are NaN.
     *
     * @param data the data to estimate the probabilities from
     * @return a map from attribute to its value-distance matrix
     */
    private Map<Attribute, double[][]> getVDMMap(Instances data) {
        Map<Attribute, double[][]> vdm = new HashMap<Attribute, double[][]>();
        Attribute classAttribute = data.classAttribute();
        int numClasses = data.numClasses();
        Enumeration attributes = data.enumerateAttributes();
        while (attributes.hasMoreElements()) {
            Attribute attribute = (Attribute) attributes.nextElement();
            if (attribute.equals(classAttribute) || !(attribute.isNominal() || attribute.isString())) {
                continue;
            }
            int numValues = attribute.numValues();
            int[] valueCounts = new int[numValues];
            int[][] classValueCounts = new int[numClasses][numValues];
            Enumeration instances = data.enumerateInstances();
            while (instances.hasMoreElements()) {
                Instance instance = (Instance) instances.nextElement();
                if (instance.isMissing(attribute) || instance.classIsMissing()) {
                    continue;
                }
                int value = (int) instance.value(attribute);
                valueCounts[value]++;
                classValueCounts[(int) instance.classValue()][value]++;
            }
            double[][] distances = new double[numValues][numValues];
            for (int v1 = 0; v1 < numValues; v1++) {
                for (int v2 = 0; v2 < numValues; v2++) {
                    double total = 0.0d;
                    for (int c = 0; c < numClasses; c++) {
                        double p1 = (double) classValueCounts[c][v1] / valueCounts[v1];
                        double p2 = (double) classValueCounts[c][v2] / valueCounts[v2];
                        total += Math.abs(p1 - p2);
                    }
                    distances[v1][v2] = total;
                }
            }
            vdm.put(attribute, distances);
        }
        return vdm;
    }

    /**
     * Sorts the buffered input by class and resamples it, pushing the result to the output queue.
     */
    protected void createSubsample() {
        Instances data = getInputFormat();
        data.sort(data.classIndex());
        int numClasses = data.numClasses();
        // Use the full class count (not just non-empty classes) so class indices stay aligned.
        int[] classStarts = calcClassIndices(data);
        Random random = new Random(this.m_RandomSeed);
        if (getOversample()) {
            overSample(random, numClasses, classStarts);
        } else {
            underSample(random, numClasses, classStarts);
        }
    }

    /**
     * Computes the start offset of each class in class-sorted data.
     *
     * <p>The result has numClasses + 1 entries. Class c occupies [result[c], result[c + 1]).
     * Instances with a missing class (sorted to the end) fall outside every range.
     *
     * @param instances data sorted by class
     * @return the class start offsets followed by the end offset of the last class
     */
    private int[] calcClassIndices(Instances instances) {
        int numClasses = instances.numClasses();
        int[] starts = new int[numClasses + 1];
        int labelledEnd = instances.numInstances();
        int current = 0; // highest class whose start offset has been recorded
        starts[0] = 0;
        for (int k = 0; k < instances.numInstances(); k++) {
            Instance instance = instances.instance(k);
            if (instance.classIsMissing()) {
                labelledEnd = k; // all remaining instances are unlabelled too
                break;
            }
            int cls = (int) instance.classValue();
            while (current < cls) {
                current++;
                starts[current] = k;
            }
        }
        while (current < numClasses) {
            current++;
            starts[current] = labelledEnd;
        }
        return starts;
    }

    /**
     * Scales every non-empty class to cost_c / cost_ref * size_ref, rounded down, with a minimum of 0.
     *
     * @param classCosts per-class costs
     * @param numClasses the number of classes
     * @param classSizes per-class sizes
     * @param reference the reference class, or -1 if none (all targets become 0)
     * @return the per-class target sizes
     */
    private static int[] scaleToReference(double[] classCosts, int numClasses, int[] classSizes, int reference) {
        int[] targets = new int[numClasses];
        if (reference < 0) {
            return targets;
        }
        for (int c = 0; c < numClasses; c++) {
            if (classSizes[c] > 0) {
                int scaled = (int) Math.floor((classCosts[c] / classCosts[reference]) * classSizes[reference]);
                targets[c] = Math.max(0, scaled);
            }
        }
        return targets;
    }

    /** @return the revision string */
    @Override // weka.filters.Filter, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.13 $");
    }

    /**
     * Runs the filter from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        runFilter(new CostSensitiveResample(), args);
    }

    /**
     * Selects the oversampling strategy. This updates the SMOTE and COSTING flags so the choice
     * actually takes effect. A null type selects plain resampling.
     *
     * @param resampleType the strategy
     */
    public void setResampleType(ResampleType resampleType) {
        this.m_ResampleType = resampleType;
        this.m_SMOTE = resampleType == ResampleType.SMOTE;
        this.m_COSTING = resampleType == ResampleType.COSTING;
    }
}
