package weka.classifiers.sklearn;

import java.util.List;
import org.apache.commons.io.IOUtils;
import org.math.plot.plotObjects.Base;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.rules.ZeroR;
import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WekaException;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.python.PythonSession;

/* loaded from: input_file:weka/classifiers/sklearn/ScikitLearnClassifier.class */
public class ScikitLearnClassifier extends AbstractClassifier implements BatchPredictor, CapabilitiesHandler {
    /** Identifier handed to the python session when the training data is transferred. */
    protected static final String TRAINING_DATA_ID = "scikit_classifier_training";
    /** Identifier handed to the python session when the test data is transferred. */
    protected static final String TEST_DATA_ID = "scikit_classifier_test";
    /** Prefix of the python variable holding the model; {@link #m_modelHash} is appended to it. */
    protected static final String MODEL_ID = "weka_scikit_learner";
    /** Serialization version identifier. */
    private static final long serialVersionUID = -6212485658537766441L;
    /** One tag per {@link Learner} constant; populated by the static initializer of this class. */
    public static final Tag[] TAGS_LEARNER = new Tag[Learner.values().length];
    /** Whether the supervised (rather than unsupervised) nominal-to-binary filter is used. */
    protected boolean m_useSupervisedNominalToBinary;
    /** Nominal-to-binary filter; created and configured in buildClassifier(). */
    protected Filter m_nominalToBinary;
    /** Pickled model retrieved from python after training; stays null when retrieval is disabled. */
    protected String m_pickledModel;
    /** If true, the trained model is not retrieved from python. */
    protected boolean m_dontFetchModelFromPython;
    /** String form of hashCode(), used as suffix of the python model variable name. */
    protected String m_modelHash;
    /** For a nominal class: true at the index of every class value that had zero weight in the training data. */
    protected boolean[] m_nominalEmptyClassIndexes;
    /** Fallback model used when there is no usable training data or only the class attribute is present. */
    protected ZeroR m_zeroR;
    /** If true, output on the script's error stream is printed instead of raised as an exception. */
    protected boolean m_continueOnSysErr;
    /** The scikit-learn learner to use. */
    protected Learner m_learner = Learner.DecisionTreeClassifier;
    /** Text placed between the parentheses of the learner's constructor call in the generated python. */
    protected String m_learnerOpts = "";
    /** Filter applied to the data before the nominal-to-binary conversion. */
    protected Filter m_replaceMissing = new ReplaceMissingValues();
    /** Plain-string value of the model variable as retrieved from python; returned by toString(). */
    protected String m_learnerToString = "";
    /** Preferred batch size for batch prediction, kept as a string. */
    protected String m_batchPredictSize = "100";

    /* loaded from: input_file:weka/classifiers/sklearn/ScikitLearnClassifier$Learner.class */
    public enum Learner {
        DecisionTreeClassifier("tree", true, false, true, "\tclass_weight=None, criterion='gini', max_depth=None,\n\tmax_features=None, max_leaf_nodes=None, min_samples_leaf=1,\n\tmin_samples_split=2, min_weight_fraction_leaf=0.0,\n\trandom_state=None, splitter='best'"),
        DecisionTreeRegressor("tree", false, true, false, "\tcriterion='mse', max_depth=None, max_features=None,\n\tmax_leaf_nodes=None, min_samples_leaf=1, min_samples_split=2,\n\tmin_weight_fraction_leaf=0.0, random_state=None,\n\tsplitter='best'"),
        GaussianNB("naive_bayes", true, false, true, ""),
        MultinomialNB("naive_bayes", true, false, true, "alpha=1.0, class_prior=None, fit_prior=True"),
        BernoulliNB("naive_bayes", true, false, true, "alpha=1.0, binarize=0.0, class_prior=None, fit_prior=True"),
        LDA("lda", true, false, true, "\tn_components=None, priors=None, shrinkage=None, solver='svd',\n\tstore_covariance=False, tol=0.0001"),
        QDA("qda", true, false, true, "\tpriors=None, reg_param=0.0"),
        LogisticRegression("linear_model", true, false, true, "\tC=1.0, class_weight=None, dual=False, fit_intercept=True,\n\tintercept_scaling=1, max_iter=100, multi_class='ovr',\n\tpenalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n\tverbose=0"),
        LogisticRegressionCV("linear_model", true, false, true, "\tCs=10, class_weight=None, cv=None, dual=False,\n\tfit_intercept=True, intercept_scaling=1.0, max_iter=100,\n\tmulti_class='ovr', n_jobs=1, penalty='l2', refit=True,\n\tscoring=None, solver='lbfgs', tol=0.0001, verbose=0"),
        LinearRegression("linear_model", false, true, false, "\tcopy_X=True, fit_intercept=True, n_jobs=1, normalize=False"),
        ARDRegression("linear_model", false, true, false, "\talpha_1=1e-06, alpha_2=1e-06, compute_score=False, copy_X=True,\n\tfit_intercept=True, lambda_1=1e-06, lambda_2=1e-06, n_iter=300,\n\tnormalize=False, threshold_lambda=10000.0, tol=0.001, verbose=False"),
        BayesianRidge("linear_model", false, true, false, "\talpha_1=1e-06, alpha_2=1e-06, compute_score=False, copy_X=True,\n\tfit_intercept=True, lambda_1=1e-06, lambda_2=1e-06, n_iter=300,\n\tnormalize=False, tol=0.001, verbose=False"),
        ElasticNet("linear_model", false, true, false, "\talpha=1.0, copy_X=True, fit_intercept=True, l1_ratio=0.5,\n\tmax_iter=1000, normalize=False, positive=False, precompute=False,\n\trandom_state=None, selection='cyclic', tol=0.0001, warm_start=False"),
        Lars("linear_model", false, true, false, "\tcopy_X=True, eps=2.2204460492503131e-16, fit_intercept=True,\n\tfit_path=True, n_nonzero_coefs=500, normalize=True, precompute='auto',\n\tverbose=False"),
        LarsCV("linear_model", false, true, false, "\tcopy_X=True, cv=None, eps=2.2204460492503131e-16, fit_intercept=True,\n\tmax_iter=500, max_n_alphas=1000, n_jobs=1, normalize=True,\n\tprecompute='auto', verbose=False"),
        Lasso("linear_model", false, true, false, "\talpha=1.0, copy_X=True, fit_intercept=True, max_iter=1000,\n\tnormalize=False, positive=False, precompute=False, random_state=None,\n\tselection='cyclic', tol=0.0001, warm_start=False"),
        LassoCV("linear_model", false, true, false, "\talphas=None, copy_X=True, cv=None, eps=0.001, fit_intercept=True,\n\tmax_iter=1000, n_alphas=100, n_jobs=1, normalize=False, positive=False,\n\tprecompute='auto', random_state=None, selection='cyclic', tol=0.0001,\n\tverbose=False"),
        LassoLars("linear_model", false, true, false, "\talpha=1.0, copy_X=True, eps=2.2204460492503131e-16,\n\tfit_intercept=True, fit_path=True, max_iter=500, normalize=True,\n\tprecompute='auto', verbose=False"),
        LassoLarsCV("linear_model", false, true, false, "\tcopy_X=True, cv=None, eps=2.2204460492503131e-16,\n\tfit_intercept=True, max_iter=500, max_n_alphas=1000, n_jobs=1,\n\tnormalize=True, precompute='auto', verbose=False"),
        LassoLarsIC("linear_model", false, true, false, "\tcopy_X=True, criterion='aic', eps=2.2204460492503131e-16,\n\tfit_intercept=True, max_iter=500, normalize=True, precompute='auto',\n\tverbose=False"),
        OrthogonalMatchingPursuit("linear_model", false, true, false, "\tfit_intercept=True, n_nonzero_coefs=None,\n\tnormalize=True, precompute='auto', tol=None"),
        OrthogonalMatchingPursuitCV("linear_model", false, true, false, "\tcopy=True, cv=None, fit_intercept=True,\n\tmax_iter=None, n_jobs=1, normalize=True, verbose=False"),
        PassiveAggressiveClassifier("linear_model", true, false, false, "\tC=1.0, fit_intercept=True, loss='hinge', n_iter=5,\n\tn_jobs=1, random_state=None, shuffle=True, verbose=0,\n\twarm_start=False"),
        PassiveAggressiveRegressor("linear_model", false, true, false, "\tC=1.0, class_weight=None, epsilon=0.1,\n\tfit_intercept=True, loss='epsilon_insensitive', n_iter=5,\n\trandom_state=None, shuffle=True, verbose=0, warm_start=False"),
        Perceptron("linear_model", true, false, false, "\talpha=0.0001, class_weight=None, eta0=1.0, fit_intercept=True,\n\tn_iter=5, n_jobs=1, penalty=None, random_state=0, shuffle=True,\n\tverbose=0, warm_start=False"),
        RANSACRegressor("linear_model", false, true, false, "\tbase_estimator=None, is_data_valid=None, is_model_valid=None,\n\tmax_trials=100, min_samples=None, random_state=None,\n\tresidual_metric=None, residual_threshold=None, stop_n_inliers=inf,\n\tstop_probability=0.99, stop_score=inf"),
        Ridge("linear_model", false, true, false, "\talpha=1.0, copy_X=True, fit_intercept=True, max_iter=None,\n\tnormalize=False, solver='auto', tol=0.001"),
        RidgeClassifier("linear_model", true, false, false, "\talpha=1.0, class_weight=None, copy_X=True, fit_intercept=True,\n\tmax_iter=None, normalize=False, solver='auto', tol=0.001"),
        RidgeClassifierCV("linear_model", true, false, false, "alphas=array([  0.1,   1. ,  10. ]), class_weight=None,\n\tcv=None, fit_intercept=True, normalize=False, scoring=None"),
        RidgeCV("linear_model", false, true, false, "alphas=array([  0.1,   1. ,  10. ]), cv=None, fit_intercept=True,\n\tgcv_mode=None, normalize=False, scoring=None, store_cv_values=False"),
        SGDClassifier("linear_model", true, false, false, "\talpha=0.0001, average=False, class_weight=None, epsilon=0.1,\n\teta0=0.0, fit_intercept=True, l1_ratio=0.15,\n\tlearning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1,\n\tpenalty='l2', power_t=0.5, random_state=None, shuffle=True,\n\tverbose=0, warm_start=False") { // from class: weka.classifiers.sklearn.ScikitLearnClassifier.Learner.1
            /**
             * Decides from the user-supplied parameter string whether probability
             * estimates are requested from the model. The answer is {@code true}
             * when the parameters mention {@code log} or {@code modified_huber}
             * (presumably the SGD loss functions that support
             * {@code predict_proba} - confirm against the scikit-learn version in use).
             *
             * @param str the learner parameter string; may be {@code null}
             * @return {@code true} if the parameters contain one of the two markers
             */
            @Override // weka.classifiers.sklearn.ScikitLearnClassifier.Learner
            public boolean producesProbabilities(String str) {
                // No parameters means the default loss is in effect: no probabilities.
                if (str == null) {
                    return false;
                }
                // The literal "log" replaces a constant from an unrelated plotting
                // library that the decompiler had substituted for it.
                return str.contains("log") || str.contains("modified_huber");
            }
        },
        SGDRegressor("linear_model", false, true, false, "\talpha=0.0001, average=False, epsilon=0.1, eta0=0.01,\n\tfit_intercept=True, l1_ratio=0.15, learning_rate='invscaling',\n\tloss='squared_loss', n_iter=5, penalty='l2', power_t=0.25,\n\trandom_state=None, shuffle=True, verbose=0, warm_start=False"),
        TheilSenRegressor("linear_model", false, true, false, "\tcopy_X=True, fit_intercept=True, max_iter=300,\n\tmax_subpopulation=10000, n_jobs=1, n_subsamples=None,\n\trandom_state=None, tol=0.001, verbose=False"),
        GaussianProcess("gaussian_process", false, true, false, "\tregr='constant', corr='squared_exponential',\n\tbeta0=None, storage_mode='full', verbose=False, theta0=0.1,\n \tthetaL=None, thetaU=None, optimizer='fmin_cobyla', random_start=1,\n \tnormalize=True, nugget=2.2204460492503131e-15, random_state=None"),
        KernelRidge("kernel_ridge", false, true, false, "\talpha=1, coef0=1, degree=3, gamma=None, kernel='linear',\n\tkernel_params=None"),
        KNeighborsClassifier("neighbors", true, false, true, "\talgorithm='auto', leaf_size=30, metric='minkowski',\n\tmetric_params=None, n_neighbors=5, p=2, weights='uniform'"),
        RadiusNeighborsClassifier("neighbors", true, false, false, "\talgorithm='auto', leaf_size=30, metric='minkowski',\n\tmetric_params=None, outlier_label=None, p=2, radius=1.0,\n\tweights='uniform'"),
        KNeighborsRegressor("neighbors", false, true, false, "algorithm='auto', leaf_size=30, metric='minkowski',\n\tmetric_params=None, n_neighbors=5, p=2, weights='uniform'"),
        RadiusNeighborsRegressor("neighbors", false, true, false, ""),
        SVC("svm", true, false, false, "\tC=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,\n\tkernel='rbf', max_iter=-1, probability=False, random_state=None,\n\tshrinking=True, tol=0.001, verbose=False"),
        LinearSVC("svm", true, false, false, "\tC=1.0, class_weight=None, dual=True, fit_intercept=True,\n\tintercept_scaling=1, loss='squared_hinge', max_iter=1000,\n\tmulti_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n\tverbose=0"),
        NuSVC("svm", true, false, false, "\tcache_size=200, coef0=0.0, degree=3, gamma=0.0, kernel='rbf',\n\tmax_iter=-1, nu=0.5, probability=False, random_state=None,\n\tshrinking=True, tol=0.001, verbose=False"),
        SVR("svm", false, true, false, "\tC=1.0, cache_size=200, coef0=0.0, degree=3, epsilon=0.1, gamma=0.0,\n\tkernel='rbf', max_iter=-1, shrinking=True, tol=0.001, verbose=False"),
        NuSVR("svm", false, true, false, "\tC=1.0, cache_size=200, coef0=0.0, degree=3, gamma=0.0, kernel='rbf',\n\tmax_iter=-1, nu=0.5, shrinking=True, tol=0.001, verbose=False"),
        AdaBoostClassifier("ensemble", true, false, true, "\talgorithm='SAMME.R', base_estimator=None,\n\tlearning_rate=1.0, n_estimators=50, random_state=None"),
        AdaBoostRegressor("ensemble", false, true, false, "\tbase_estimator=None, learning_rate=1.0, loss='linear',\n\tn_estimators=50, random_state=None"),
        BaggingClassifier("ensemble", true, false, true, "\tbase_estimator=None, bootstrap=True,\n\tbootstrap_features=False, max_features=1.0, max_samples=1.0,\n\tn_estimators=10, n_jobs=1, oob_score=False, random_state=None,\n\tverbose=0"),
        BaggingRegressor("ensemble", false, true, false, "\tbase_estimator=None, bootstrap=True,\n\tbootstrap_features=False, max_features=1.0, max_samples=1.0,\n\tn_estimators=10, n_jobs=1, oob_score=False, random_state=None,\n\tverbose=0"),
        ExtraTreeClassifier("tree", true, false, true, "\tclass_weight=None, criterion='gini', max_depth=None,\n\tmax_features='auto', max_leaf_nodes=None, min_samples_leaf=1,\n\tmin_samples_split=2, min_weight_fraction_leaf=0.0,\n\trandom_state=None, splitter='random'"),
        ExtraTreeRegressor("tree", false, true, false, "\tcriterion='mse', max_depth=None, max_features='auto',\n\tmax_leaf_nodes=None, min_samples_leaf=1, min_samples_split=2,\n\tmin_weight_fraction_leaf=0.0, random_state=None,\n\tsplitter='random'"),
        GradientBoostingClassifier("ensemble", true, false, true, "\tinit=None, learning_rate=0.1, loss='deviance',\n\tmax_depth=3, max_features=None, max_leaf_nodes=None,\n\tmin_samples_leaf=1, min_samples_split=2,\n\tmin_weight_fraction_leaf=0.0, n_estimators=100,\n\trandom_state=None, subsample=1.0, verbose=0,\n\twarm_start=False"),
        GradientBoostingRegressor("ensemble", false, true, false, "\talpha=0.9, init=None, learning_rate=0.1, loss='ls',\n\tmax_depth=3, max_features=None, max_leaf_nodes=None,\n\tmin_samples_leaf=1, min_samples_split=2,\n\tmin_weight_fraction_leaf=0.0, n_estimators=100,\n\trandom_state=None, subsample=1.0, verbose=0, warm_start=False"),
        RandomForestClassifier("ensemble", true, false, true, "\tbootstrap=True, class_weight=None, criterion='gini',\n\tmax_depth=None, max_features='auto', max_leaf_nodes=None,\n\tmin_samples_leaf=1, min_samples_split=2,\n\tmin_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n\toob_score=False, random_state=None, verbose=0,\n\twarm_start=False"),
        RandomForestRegressor("ensemble", false, true, false, "\tbootstrap=True, criterion='mse', max_depth=None,\n\tmax_features='auto', max_leaf_nodes=None, min_samples_leaf=1,\n\tmin_samples_split=2, min_weight_fraction_leaf=0.0,\n\tn_estimators=10, n_jobs=1, oob_score=False, random_state=None,\n\tverbose=0, warm_start=False");

        /** Name of the sklearn sub-module that contains the learner, e.g. {@code tree}. */
        private final String m_module;
        /** True if the learner handles classification problems. */
        private final boolean m_classification;
        /** True if the learner handles regression problems. */
        private final boolean m_regression;
        /** Default answer of {@link #producesProbabilities(String)}. */
        private final boolean m_producesProbabilities;
        /** Human-readable listing of the learner's default parameters. */
        private final String m_defaultParameters;

        /**
         * Creates a learner description. All state is immutable once set.
         *
         * @param module the sklearn sub-module containing the learner
         * @param classification true if the learner is a classifier
         * @param regression true if the learner is a regressor
         * @param producesProbabilities true if the learner yields probability estimates
         * @param defaultParameters textual listing of the default parameters
         */
        Learner(String module, boolean classification, boolean regression, boolean producesProbabilities,
                String defaultParameters) {
            this.m_module = module;
            this.m_producesProbabilities = producesProbabilities;
            this.m_classification = classification;
            this.m_regression = regression;
            this.m_defaultParameters = defaultParameters;
        }

        /**
         * Returns the module name that was supplied when this constant was declared.
         *
         * @return the sklearn sub-module containing this learner
         */
        public String getModule() {
            return m_module;
        }

        /**
         * Tells whether this learner yields probability estimates. This default
         * implementation ignores the argument and returns the fixed flag given at
         * declaration; individual constants may override it.
         *
         * @param str the learner parameter string (unused here)
         * @return the declared probability flag
         */
        public boolean producesProbabilities(String str) {
            return m_producesProbabilities;
        }

        /**
         * Reports the classification flag given at declaration.
         *
         * @return true if this learner is a classifier
         */
        public boolean isClassifier() {
            return m_classification;
        }

        /**
         * Reports the regression flag given at declaration.
         *
         * @return true if this learner is a regressor
         */
        public boolean isRegressor() {
            return m_regression;
        }

        /**
         * Returns the default-parameter listing given at declaration.
         *
         * @return textual description of the learner's default parameters
         */
        public String getDefaultParameters() {
            return m_defaultParameters;
        }
    }

    /**
     * Builds the description of this scheme: a short introduction followed by
     * one entry per available {@link Learner}, giving its name, the kind of
     * problem it handles and its default parameters.
     *
     * @return the description text
     */
    public String globalInfo() {
        StringBuilder info = new StringBuilder();
        info.append("A wrapper for classifiers implemented in the scikit-learn python library. The following learners are available:\n\n");
        for (Learner learner : Learner.values()) {
            // Plain "\n" literals are used instead of a third-party line-separator constant.
            info.append(learner.toString()).append("\n");
            info.append("[");
            if (learner.isClassifier()) {
                info.append(" classification ");
            }
            if (learner.isRegressor()) {
                info.append(" regression ");
            }
            info.append("]").append("\nDefault parameters:\n");
            info.append(learner.getDefaultParameters()).append("\n");
        }
        return info.toString();
    }

    /**
     * Returns the capabilities of this classifier. Everything is disabled when
     * no python session is available and one cannot be started; otherwise the
     * attribute capabilities are enabled together with the class capabilities
     * matching the selected learner.
     *
     * @return the capabilities
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        boolean pythonReady = PythonSession.pythonAvailable();
        if (!pythonReady) {
            try {
                pythonReady = PythonSession.initSession("python", getDebug());
            } catch (WekaException ignored) {
                // A failed start-up simply leaves every capability disabled.
                pythonReady = false;
            }
        }
        if (!pythonReady) {
            return result;
        }

        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        if (m_learner.isClassifier()) {
            result.enable(Capabilities.Capability.BINARY_CLASS);
            result.enable(Capabilities.Capability.NOMINAL_CLASS);
        }
        if (m_learner.isRegressor()) {
            result.enable(Capabilities.Capability.NUMERIC_CLASS);
        }
        return result;
    }

    /**
     * Returns whether the supervised nominal-to-binary filter is selected.
     *
     * @return true if the supervised filter is used
     */
    @OptionMetadata(displayName = "Use supervised nominal to binary conversion", description = "Use supervised nominal to binary conversion of nominal attributes.", commandLineParamName = "S", commandLineParamSynopsis = "-S", commandLineParamIsFlag = true, displayOrder = 3)
    public boolean getUseSupervisedNominalToBinary() {
        return m_useSupervisedNominalToBinary;
    }

    /**
     * Selects between the supervised and the unsupervised nominal-to-binary filter.
     *
     * @param z true to use the supervised filter
     */
    public void setUseSupervisedNominalToBinary(boolean z) {
        m_useSupervisedNominalToBinary = z;
    }

    /**
     * Returns the selected learner as a tag whose ID is the learner's ordinal.
     *
     * @return the currently selected learner
     */
    @OptionMetadata(displayName = "Scikit-learn learner", description = "Scikit-learn learner to use.\nAvailable learners:\nDecisionTreeClassifier, DecisionTreeRegressor, GaussianNB, MultinomialNB,BernoulliNB, LDA, QDA, LogisticRegression, LogisticRegressionCV,\nLinearRegression, ARDRegression, BayesianRidge, ElasticNet, Lars,\nLarsCV, Lasso, LassoCV, LassoLars, LassoLarsCV, LassoLarsIC, OrthogonalMatchingPursuit,\nOrthogonalMatchingPursuitCV, PassiveAggressiveClassifier, PassiveAggressiveRegressor, Perceptron, RANSACRegressor,\nRidge, RidgeClassifier, RidgeClassifierCV, RidgeCV, SGDClassifier,\nSGDRegressor,TheilSenRegressor, GaussianProcess, KernelRidge, KNeighborsClassifier, \nRadiusNeighborsClassifier, KNeighborsRegressor, RadiusNeighborsRegressor, SVC,\nLinearSVC, NuSVC, SVR, NuSVR, AdaBoostClassifier, AdaBoostRegressor,BaggingClassifier, BaggingRegressor,\nExtraTreeClassifier, ExtraTreeRegressor,GradientBoostingClassifier, GradientBoostingRegressor,\nRandomForestClassifier, RandomForestRegressor.\n(default = DecisionTreeClassifier)", commandLineParamName = "learner", commandLineParamSynopsis = "-learner <learner name>", displayOrder = 1)
    public SelectedTag getLearner() {
        int selectedId = m_learner.ordinal();
        return new SelectedTag(selectedId, TAGS_LEARNER);
    }

    /**
     * Selects the learner whose ordinal equals the ID of the given tag. An ID
     * that matches no learner leaves the current selection untouched.
     *
     * @param selectedTag the tag identifying the learner to use
     */
    public void setLearner(SelectedTag selectedTag) {
        int wantedOrdinal = selectedTag.getSelectedTag().getID();
        Learner[] candidates = Learner.values();
        // values() is ordered by ordinal, so the ID doubles as an index.
        if (wantedOrdinal >= 0 && wantedOrdinal < candidates.length) {
            m_learner = candidates[wantedOrdinal];
        }
    }

    /**
     * Returns the parameter string passed to the learner's constructor in python.
     *
     * @return the learner parameters
     */
    @OptionMetadata(displayName = "Learner parameters", description = "learner parameters to use", displayOrder = 2, commandLineParamName = "parameters", commandLineParamSynopsis = "-parameters <comma-separated list of name=value pairs>")
    public String getLearnerOpts() {
        return m_learnerOpts;
    }

    /**
     * Sets the parameter string passed to the learner's constructor in python.
     *
     * @param str the learner parameters
     */
    public void setLearnerOpts(String str) {
        m_learnerOpts = str;
    }

    /**
     * Sets the preferred batch size for batch prediction.
     *
     * @param str the batch size, as a string
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.BatchPredictor
    public void setBatchSize(String str) {
        m_batchPredictSize = str;
    }

    /**
     * Returns the preferred batch size for batch prediction.
     *
     * @return the batch size, as a string
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.BatchPredictor
    @OptionMetadata(displayName = "Batch size", description = "The preferred number of instances to transfer into python for prediction\n(if operatingin batch prediction mode). More or fewer instances than this will be accepted.", commandLineParamName = "batch", commandLineParamSynopsis = "-batch <batch size>", displayOrder = 4)
    public String getBatchSize() {
        return m_batchPredictSize;
    }

    /**
     * Signals that this classifier provides its own batch prediction.
     *
     * @return always true
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.BatchPredictor
    public boolean implementsMoreEfficientBatchPrediction() {
        final boolean batchPredictionSupported = true;
        return batchPredictionSupported;
    }

    /**
     * Sets whether processing continues when the python script writes to its
     * error stream.
     *
     * @param z true to print such output and carry on instead of failing
     */
    public void setContinueOnSysErr(boolean z) {
        m_continueOnSysErr = z;
    }

    /**
     * Returns whether processing continues when the python script writes to its
     * error stream.
     *
     * @return true if such output is printed rather than raised as an exception
     */
    @OptionMetadata(displayName = "Try to continue after sys err output from script", description = "Try to continue after sys err output from script.\nSome schemes report warnings to the system error stream.", displayOrder = 5, commandLineParamName = "continue-on-err", commandLineParamSynopsis = "-continue-on-err", commandLineParamIsFlag = true)
    public boolean getContinueOnSysErr() {
        return m_continueOnSysErr;
    }

    /**
     * Sets whether the trained model is left in python instead of being retrieved.
     *
     * @param z true to skip retrieving the model
     */
    public void setDontFetchModelFromPython(boolean z) {
        m_dontFetchModelFromPython = z;
    }

    /**
     * Returns whether the trained model is left in python instead of being retrieved.
     *
     * @return true if the model is not retrieved
     */
    @OptionMetadata(displayName = "Don't retrieve model from python", description = "Don't retrieve the model from python - speeds up cross-validation,\nbut prevents this classifier from being used after deserialization.\nSome models in python (e.g. large random forests) may exceed the maximum size for transfer\n(currently Integer.MAX_VALUE bytes)", displayOrder = 6, commandLineParamName = "dont-fetch-model", commandLineParamSynopsis = "-dont-fetch-model", commandLineParamIsFlag = true)
    public boolean getDontFetchModelFromPython() {
        return m_dontFetchModelFromPython;
    }

    /**
     * Trains the selected scikit-learn learner on the supplied data.
     *
     * <p>Instances with a missing class are dropped. If nothing remains, or the
     * data consists of the class attribute only, a {@link ZeroR} model is built
     * instead and python is not involved. Otherwise missing values are replaced,
     * nominal attributes are binarized, the data is transferred into python and
     * a generated script constructs and fits the learner. Unless disabled, the
     * fitted model is then retrieved in pickled form.
     *
     * @param instances the training data
     * @throws Exception if the data cannot be handled, python cannot be started,
     *         or the script writes to its error stream while continue-on-error
     *         is off
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        this.m_pickledModel = null;
        getCapabilities().testWithFail(instances);
        this.m_zeroR = null;
        if (!PythonSession.pythonAvailable() && !PythonSession.initSession("python", getDebug())) {
            throw new Exception("Was unable to start python environment: " + PythonSession.getPythonEnvCheckResults());
        }
        if (this.m_modelHash == null) {
            this.m_modelHash = "" + hashCode();
        }

        Instances trainingData = new Instances(instances);
        trainingData.deleteWithMissingClass();

        boolean noInstances = trainingData.numInstances() == 0;
        boolean classOnly = trainingData.numAttributes() == 1;
        if (noInstances || classOnly) {
            if (noInstances) {
                System.err.println("No instances with non-missing class - using ZeroR model");
            } else {
                System.err.println("Only the class attribute is present in the data - using ZeroR model");
            }
            this.m_zeroR = new ZeroR();
            this.m_zeroR.buildClassifier(trainingData);
            return;
        }

        if (trainingData.classAttribute().isNominal()) {
            // Remember class values that never occur in the training data; the
            // prediction code skips them when reading the probability columns.
            AttributeStats classStats = trainingData.attributeStats(trainingData.classIndex());
            this.m_nominalEmptyClassIndexes = new boolean[trainingData.classAttribute().numValues()];
            for (int i = 0; i < classStats.nominalWeights.length; i++) {
                if (classStats.nominalWeights[i] == 0.0) {
                    this.m_nominalEmptyClassIndexes[i] = true;
                }
            }
        }

        this.m_replaceMissing.setInputFormat(trainingData);
        Instances withoutMissing = Filter.useFilter(trainingData, this.m_replaceMissing);
        if (getUseSupervisedNominalToBinary()) {
            this.m_nominalToBinary = new NominalToBinary();
        } else {
            this.m_nominalToBinary = new weka.filters.unsupervised.attribute.NominalToBinary();
        }
        this.m_nominalToBinary.setInputFormat(withoutMissing);
        Instances binarized = Filter.useFilter(withoutMissing, this.m_nominalToBinary);

        // Treat a null parameter string like an empty one instead of failing with an NPE.
        String learnerOptions = getLearnerOpts() == null ? "" : getLearnerOpts();
        String modelVariable = MODEL_ID + this.m_modelHash;
        try {
            PythonSession session = PythonSession.acquireSession(this);
            session.instancesToPythonAsScikitLearn(binarized, TRAINING_DATA_ID, getDebug());

            StringBuilder script = new StringBuilder();
            script.append("from sklearn import *").append("\n").append("import numpy as np").append("\n");
            script.append(modelVariable + " = " + this.m_learner.getModule() + "." + this.m_learner.toString() + "(" + learnerOptions + ")").append("\n");
            script.append(modelVariable + ".fit(X,np.ravel(Y))").append("\n");

            List<String> outAndErr = session.executeScript(script.toString(), getDebug());
            // The second entry is treated as the script's error output.
            if (outAndErr.size() == 2 && outAndErr.get(1).length() > 0) {
                if (!this.m_continueOnSysErr) {
                    throw new Exception(outAndErr.get(1));
                }
                System.err.println(outAndErr.get(1));
            }

            this.m_learnerToString = session.getVariableValueFromPythonAsPlainString(modelVariable, getDebug());
            if (!getDontFetchModelFromPython()) {
                this.m_pickledModel = session.getVariableValueFromPythonAsPickledObject(modelVariable, getDebug());
            }
        } finally {
            PythonSession.releaseSession(this);
        }
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [double[], double[][]] */
    private double[][] batchScoreWithZeroR(Instances instances) throws Exception {
        ?? r0 = new double[instances.numInstances()];
        for (int i = 0; i < instances.numInstances(); i++) {
            r0[i] = this.m_zeroR.distributionForInstance(instances.instance(i));
        }
        return r0;
    }

    /**
     * Predicts a single instance by wrapping it in a one-element dataset and
     * delegating to {@link #distributionsForInstances(Instances)}.
     *
     * @param instance the instance to predict
     * @return the distribution computed for the instance
     * @throws Exception if the batch prediction fails
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        Instances singleton = new Instances(instance.dataset(), 0);
        singleton.add(instance);
        double[][] distributions = distributionsForInstances(singleton);
        return distributions[0];
    }

    /* JADX WARN: Type inference failed for: r0v44, types: [double[], double[][]] */
    @Override // weka.classifiers.AbstractClassifier, weka.core.BatchPredictor
    public double[][] distributionsForInstances(Instances instances) throws Exception {
        if (this.m_zeroR != null) {
            return batchScoreWithZeroR(instances);
        }
        if (!PythonSession.pythonAvailable() && !PythonSession.initSession("python", getDebug())) {
            throw new Exception("Was unable to start python environment: " + PythonSession.getPythonEnvCheckResults());
        }
        Instances useFilter = Filter.useFilter(Filter.useFilter(instances, this.m_replaceMissing), this.m_nominalToBinary);
        Attribute classAttribute = useFilter.classAttribute();
        Remove remove = new Remove();
        remove.setAttributeIndices("" + (useFilter.classIndex() + 1));
        remove.setInputFormat(useFilter);
        Instances useFilter2 = Filter.useFilter(useFilter, remove);
        useFilter2.setClassIndex(-1);
        try {
            PythonSession acquireSession = PythonSession.acquireSession(this);
            acquireSession.instancesToPythonAsScikitLearn(useFilter2, TEST_DATA_ID, getDebug());
            StringBuilder sb = new StringBuilder();
            if (!acquireSession.checkIfPythonVariableIsSet(MODEL_ID + this.m_modelHash, getDebug())) {
                if (this.m_pickledModel == null || this.m_pickledModel.length() == 0) {
                    throw new Exception("There is no model to transfer into Python!");
                }
                acquireSession.setPythonPickledVariableValue(MODEL_ID + this.m_modelHash, this.m_pickledModel, getDebug());
            }
            sb.append("from sklearn." + this.m_learner.getModule() + " import " + this.m_learner.toString()).append(IOUtils.LINE_SEPARATOR_UNIX);
            sb.append("preds = weka_scikit_learner" + this.m_modelHash + ".predict" + (this.m_learner.producesProbabilities(this.m_learnerOpts) ? "_proba" : "") + "(X)").append("\npreds = preds.tolist()\n");
            List<String> executeScript = acquireSession.executeScript(sb.toString(), getDebug());
            if (executeScript.size() == 2 && executeScript.get(1).length() > 0) {
                if (!this.m_continueOnSysErr) {
                    throw new Exception(executeScript.get(1));
                }
                System.err.println(executeScript.get(1));
            }
            List<List> list = (List) acquireSession.getVariableValueFromPythonAsJson("preds", getDebug());
            if (list == null) {
                throw new Exception("Was unable to retrieve predictions from python");
            }
            if (list.size() != useFilter2.numInstances()) {
                throw new Exception("Learner did not return as many predictions as there are test instances");
            }
            ?? r0 = new double[useFilter2.numInstances()];
            if (this.m_learner.producesProbabilities(this.m_learnerOpts) && classAttribute.isNominal()) {
                int i = 0;
                for (List list2 : list) {
                    double[] dArr = new double[classAttribute.numValues()];
                    int i2 = 0;
                    for (int i3 = 0; i3 < dArr.length; i3++) {
                        if (!this.m_nominalEmptyClassIndexes[i3]) {
                            int i4 = i2;
                            i2++;
                            dArr[i3] = ((Number) list2.get(i4)).doubleValue();
                        }
                    }
                    Utils.normalize(dArr);
                    int i5 = i;
                    i++;
                    r0[i5] = dArr;
                }
            } else if (classAttribute.isNominal()) {
                int i6 = 0;
                for (Object obj : list) {
                    double[] dArr2 = new double[classAttribute.numValues()];
                    dArr2[(obj instanceof List ? (Number) ((List) obj).get(0) : (Number) obj).intValue()] = 1.0d;
                    int i7 = i6;
                    i6++;
                    r0[i7] = dArr2;
                }
            } else {
                int i8 = 0;
                for (Object obj2 : list) {
                    double[] dArr3 = new double[1];
                    dArr3[0] = (obj2 instanceof List ? (Number) ((List) obj2).get(0) : (Number) obj2).doubleValue();
                    int i9 = i8;
                    i8++;
                    r0[i9] = dArr3;
                }
            }
            return r0;
        } finally {
            PythonSession.releaseSession(this);
        }
    }

    /**
     * Describes the model: the ZeroR fallback if it is active, otherwise the
     * description retrieved from python, or a placeholder when there is none.
     *
     * @return the textual description of the model
     */
    public String toString() {
        if (m_zeroR != null) {
            return m_zeroR.toString();
        }
        boolean noDescription = m_learnerToString == null || m_learnerToString.length() == 0;
        if (noDescription) {
            return "SckitLearnClassifier: model not built yet!";
        }
        return m_learnerToString;
    }

    /**
     * Command-line entry point; runs a fresh instance of this classifier with
     * the given options.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        ScikitLearnClassifier classifier = new ScikitLearnClassifier();
        runClassifier(classifier, strArr);
    }

    // Fill TAGS_LEARNER with one tag per learner; the tag ID is the learner's ordinal.
    static {
        Learner[] learners = Learner.values();
        for (int ordinal = 0; ordinal < learners.length; ordinal++) {
            TAGS_LEARNER[ordinal] = new Tag(ordinal, learners[ordinal].toString());
        }
    }
}
