package de.unijena.bioinf.fingerid;

import de.unijena.bioinf.ChemistryBase.chem.Element;
import de.unijena.bioinf.ChemistryBase.chem.MolecularFormula;
import de.unijena.bioinf.ChemistryBase.chem.PeriodicTable;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.fingerid.OptimizationStrategy;
import de.unijena.bioinf.fingerid.svm.Svm;
import gnu.trove.list.array.TShortArrayList;
import gnu.trove.set.hash.TIntHashSet;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_print_interface;
import libsvm.svm_problem;

/* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClasses.class */
public class TrainCompoundClasses implements Closeable {
    // Debug switch; no read of it is visible in this part of the file.
    private static final boolean DEBUG = true;
    // Largest exponent k scanned by trainParameters, which tries C = 2^(+k) and C = 2^(-k).
    private int numberOfCs;
    // Inclusive range of polynomial kernel degrees scanned by trainParameters.
    private int minDegree;
    private int maxDegree;
    // Inclusive range of kernel coef0 values scanned by trainParameters.
    private int minCoef;
    private int maxCoef;
    // Not referenced by name in the visible code, which uses the literal 5 instead
    // (presumably the compiler inlined this constant; TODO confirm).
    private static final int FOLDS = 5;
    // Executor that runs the grid-search tasks of trainParameters; shut down in close().
    private ExecutorService service;
    // Fingerprint positions used as features; sorted in place by the constructor.
    private int[] usedIndizes;
    // Not read or written in the visible part of the file.
    private PredictionPerformance[] fingerprintPerformance;
    // One shared svm_node per used fingerprint position, addressed by that position;
    // entries for unused positions stay null.
    private svm_node[] NODE_POOL;
    // Compounds registered through addSeed.
    private List<Compound> seed;
    // Compounds registered through addPool; the train methods move misclassified ones into the training set.
    private List<Compound> pool;
    // Both lists are re-created and filled at the start of each train call.
    private List<Compound> trainingSet;
    private List<Compound> evaluationSet;
    // Per-class weights handed to libsvm (svm_parameter.weight); updateWeightLabels sets both entries to 1.
    private final double[] WEIGHT;
    // Class labels matching WEIGHT (svm_parameter.weight_label); initialised to {1, -1}.
    private final int[] WEIGHT_LABEL;
    // Optional per-feature weights, in the order of usedIndizes; null means "all 1" (see setWeighting).
    private double[] weighting;
    // Directory that receives the per-round model and report files; defaults to ".".
    private File outputDir;
    // Element objects for ELEMENTS, resolved in the constructor; no other use is visible here.
    private final Element[] elements;
    // Equals NODE_POOL.length; used as the first libsvm index of the formula-derived features.
    private final int FEATURES;
    // Element symbols whose atom counts become additional features.
    private static final String[] ELEMENTS = {"C", "H", "N", "O", "P", "S", "Cl", "Br", "I", "F"};
    // Divisors applied to the atom counts of ELEMENTS (same order) in getAdditionalFingerprintsFor.
    private static final double[] ELEM_WEIGHTS = {60.0d, 80.0d, 12.0d, 20.0d, 3.0d, 3.0d, 3.0d, 2.0d, 1.0d, 5.0d};
    // Lazily filled cache of Element objects for ELEMENTS (see getAdditionalFingerprintsFor).
    private static final Element[] ELEMENT_ARY = new Element[ELEMENTS.length];
    // One feature per element plus one for the molecular mass.
    private static final int ADDITIONAL_FINGERPRINTS = ELEMENTS.length + 1;
    // No use is visible in this part of the file; presumably a terminator node (TODO confirm).
    static final svm_node FINAL_NODE = new svm_node();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClasses$CFingerprint.class */
    /**
     * Hash-set key that compares two compounds by the content of their fingerprint
     * arrays only. Used by {@code removeDuplicates} to detect compounds that share
     * a fingerprint.
     */
    public static class CFingerprint {
        private Compound c;
        private int hashCode;

        private CFingerprint(Compound compound) {
            // Computed once; the wrapper only ever serves as a set key.
            this.hashCode = Arrays.hashCode(compound.fingerprint);
            this.c = compound;
        }

        /** Returns the hash of the fingerprint array computed at construction time. */
        @Override
        public int hashCode() {
            return hashCode;
        }

        /** Two wrappers are equal when their compounds' fingerprint arrays have equal content. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null || obj.getClass() != getClass()) {
                return false;
            }
            CFingerprint other = (CFingerprint) obj;
            return Arrays.equals(c.fingerprint, other.c.fingerprint);
        }
    }

    /* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClasses$Compound.class */
    /**
     * One labelled compound: identifier, molecular formula, class label and the
     * sparse fingerprint (positions of set bits). The libsvm node array is
     * attached later by {@code newCompound}.
     */
    public static class Compound {
        /** Identifier used for de-duplication and in the written reports. */
        private final String inchiKey;
        /** Class label; values greater than zero are treated as positive examples. */
        private final byte classification;
        /** Positions of the set fingerprint bits. */
        private final short[] fingerprint;
        private final MolecularFormula formula;
        /** libsvm feature vector; not set by this constructor. */
        private svm_node[] nodes;
        /** Never written in the visible code; stays at its default of 0. */
        private byte fold;

        /**
         * @param str              identifier of the compound (stored as {@code inchiKey})
         * @param molecularFormula formula used for the additional, formula-derived features
         * @param b                class label
         * @param sArr             positions of the set fingerprint bits; stored without copying
         */
        public Compound(String str, MolecularFormula molecularFormula, byte b, short[] sArr) {
            this.inchiKey = str;
            this.formula = molecularFormula;
            this.fingerprint = sArr;
            this.classification = b;
        }
    }

    /* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClasses$MPredictor.class */
    /**
     * Predictor that keeps the trained libsvm model next to the hyper-parameter
     * description ({@link Model}) it was trained with.
     */
    public static class MPredictor extends Predictor {
        private final Model model;
        private transient svm_model svmModel;

        /**
         * @param model        hyper-parameters the SVM was trained with
         * @param svm_modelVar trained libsvm model; its first rho, first coefficient row and
         *                     support-vector indices are forwarded to the {@code Predictor} constructor
         */
        public MPredictor(Model model, svm_model svm_modelVar) {
            super(0, svm_modelVar.rho[0], 0.0d, 0.0d, svm_modelVar.sv_coef[0], svm_modelVar.sv_indices);
            this.svmModel = svm_modelVar;
            this.model = model;
        }

        /**
         * Saves the wrapped libsvm model to the given file.
         *
         * @throws IOException if libsvm cannot write the file
         */
        public void writeModel(File file) throws IOException {
            String target = file.getAbsolutePath();
            svm.svm_save_model(target, svmModel);
        }
    }

    /* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClasses$Model.class */
    /**
     * Hyper-parameter set of a polynomial SVM (C, degree, coef0) together with the
     * performance measured for it. Models are ordered by the F-score based
     * comparator obtained from {@code OptimizationStrategy.ByFScore}.
     */
    public static class Model implements Comparable<Model> {
        private static Comparator<PredictionPerformance> comp = new OptimizationStrategy.ByFScore().getComparator();
        private final PredictionPerformance performance;
        private final double c;
        private final int degree;
        public double coef0;

        /**
         * @param predictionPerformance performance measured for this parameter set
         * @param d                     SVM cost parameter C
         * @param i                     polynomial kernel degree
         * @param d2                    kernel coef0
         */
        public Model(PredictionPerformance predictionPerformance, double d, int i, double d2) {
            this.c = d;
            this.degree = i;
            this.coef0 = d2;
            this.performance = predictionPerformance;
        }

        /** Orders models by their performance using the shared comparator. */
        @Override
        public int compareTo(Model model) {
            return comp.compare(performance, model.performance);
        }

        /** Human-readable summary of the parameters and the F value of the performance. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("polynomial svm, degree = ");
            text.append(degree);
            text.append(", coefficient = ").append(coef0);
            text.append(", c = ").append(c);
            text.append(", f = ").append(performance.getF());
            return text.toString();
        }
    }

    /** Convenience overload: computes the formula-derived features for the compound's formula. */
    private double[] getAdditionalFingerprintsFor(Compound compound) {
        MolecularFormula formula = compound.formula;
        return getAdditionalFingerprintsFor(formula);
    }

    /**
     * Derives the additional, formula-based features: for every entry of
     * {@code ELEMENTS} the atom count divided by the matching {@code ELEM_WEIGHTS}
     * entry, followed by the molecular mass divided by 1000.
     *
     * @param molecularFormula formula to describe
     * @return array of length {@code ADDITIONAL_FINGERPRINTS}
     */
    public static double[] getAdditionalFingerprintsFor(MolecularFormula molecularFormula) {
        double[] features = new double[ADDITIONAL_FINGERPRINTS];
        // Lazy, unsynchronized initialisation of the shared element cache.
        if (ELEMENT_ARY[0] == null) {
            PeriodicTable table = PeriodicTable.getInstance();
            for (int k = 0; k < ELEMENT_ARY.length; k++) {
                ELEMENT_ARY[k] = table.getByName(ELEMENTS[k]);
            }
        }
        final int elementCount = ELEMENT_ARY.length;
        for (int k = 0; k < elementCount; k++) {
            features[k] = molecularFormula.numberOf(ELEMENT_ARY[k]) / ELEM_WEIGHTS[k];
        }
        // The last slot holds the scaled mass.
        features[elementCount] = molecularFormula.getMass() / 1000.0d;
        return features;
    }

    /**
     * Initiates shutdown of the executor service used for training. Uses
     * {@code shutdown()}, so it does not wait for running tasks to finish.
     */
    @Override
    public void close() throws IOException {
        service.shutdown();
    }

    /** @return the largest exponent k for which C = 2^(+k) and C = 2^(-k) are tried during parameter search */
    public int getNumberOfCs() {
        return numberOfCs;
    }

    /** @param i the largest exponent k for which C = 2^(+k) and C = 2^(-k) are tried during parameter search */
    public void setNumberOfCs(int i) {
        numberOfCs = i;
    }

    /** @return the smallest polynomial kernel degree scanned during parameter search */
    public int getMinDegree() {
        return minDegree;
    }

    /** @param i the smallest polynomial kernel degree scanned during parameter search */
    public void setMinDegree(int i) {
        minDegree = i;
    }

    /** @return the largest polynomial kernel degree scanned during parameter search (inclusive) */
    public int getMaxDegree() {
        return maxDegree;
    }

    /** @param i the largest polynomial kernel degree scanned during parameter search (inclusive) */
    public void setMaxDegree(int i) {
        maxDegree = i;
    }

    /** @return the smallest kernel coef0 value scanned during parameter search */
    public int getMinCoef() {
        return minCoef;
    }

    /** @param i the smallest kernel coef0 value scanned during parameter search */
    public void setMinCoef(int i) {
        minCoef = i;
    }

    /** @return the largest kernel coef0 value scanned during parameter search (inclusive) */
    public int getMaxCoef() {
        return maxCoef;
    }

    /** @param i the largest kernel coef0 value scanned during parameter search (inclusive) */
    public void setMaxCoef(int i) {
        maxCoef = i;
    }

    /**
     * Removes compounds with an already seen fingerprint from the seed list and
     * then from the pool list. The seed list is processed first, so a compound
     * present in both survives in the seed.
     */
    public void removeDuplicateEntries() {
        List<List<Compound>> seedThenPool = new ArrayList<>(2);
        seedThenPool.add(seed);
        seedThenPool.add(pool);
        removeDuplicates(seedThenPool);
    }

    /**
     * Removes, in place, every compound whose fingerprint was already seen
     * earlier in the same list or in a preceding list, and prints the number of
     * removed entries to standard output.
     *
     * @param list lists to clean; their iterators must support {@code remove}
     */
    public static void removeDuplicates(List<List<Compound>> list) {
        HashSet<CFingerprint> seen = new HashSet<>();
        int removed = 0;
        for (List<Compound> compounds : list) {
            ListIterator<Compound> cursor = compounds.listIterator();
            while (cursor.hasNext()) {
                // add() returns false when an equal fingerprint is already present.
                if (!seen.add(new CFingerprint(cursor.next()))) {
                    cursor.remove();
                    removed++;
                }
            }
        }
        System.out.println("Remove " + removed + " duplicate entries");
    }

    /**
     * Stores the per-feature weights and writes them into the shared nodes of
     * the node pool. Because compounds reference those shared nodes, this changes
     * the feature values of compounds that were already added.
     *
     * @param dArr weights in the order of the (sorted) used indices, or {@code null}
     *             to reset every feature value to 1
     */
    public void setWeighting(double[] dArr) {
        this.weighting = dArr;
        if (dArr == null) {
            for (int index : usedIndizes) {
                NODE_POOL[index].value = 1.0d;
            }
            return;
        }
        for (int k = 0; k < usedIndizes.length; k++) {
            NODE_POOL[usedIndizes[k]].value = dArr[k];
        }
    }

    /** @return the directory into which the per-round model and report files are written */
    public File getOutputDir() {
        return outputDir;
    }

    /** @param file the directory into which the per-round model and report files are written */
    public void setOutputDir(File file) {
        outputDir = file;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [short[], short[][]] */
    public static short[][] transformFingerprintsToIntegerArray(int[] iArr, boolean[][] zArr) {
        ?? r0 = new short[zArr.length];
        TShortArrayList tShortArrayList = new TShortArrayList(524);
        for (int i = 0; i < zArr.length; i++) {
            boolean[] zArr2 = zArr[i];
            for (int i2 = 0; i2 < iArr.length; i2++) {
                if (zArr2[iArr[i2]]) {
                    tShortArrayList.add((short) iArr[i2]);
                }
            }
            r0[i] = tShortArrayList.toArray();
            tShortArrayList.resetQuick();
        }
        return r0;
    }

    /**
     * Converts one boolean fingerprint into its sparse form: the entries of
     * {@code iArr} whose bit is set, in the order in which they appear in {@code iArr}.
     *
     * @param iArr fingerprint positions to consider; each is narrowed to {@code short}
     * @param zArr boolean fingerprint, indexed by the entries of {@code iArr}
     * @return the set positions
     */
    public static short[] transformFingerprintToIntegerArray(int[] iArr, boolean[] zArr) {
        TShortArrayList setPositions = new TShortArrayList(128);
        for (int position : iArr) {
            if (zArr[position]) {
                setPositions.add((short) position);
            }
        }
        return setPositions.toArray();
    }

    /**
     * Creates a trainer backed by a new fixed-size thread pool with one thread per
     * available processor; that pool is shut down by {@link #close()}.
     *
     * @param iArr fingerprint positions to use as features; forwarded unchanged to
     *             {@link #TrainCompoundClasses(ExecutorService, int[])}
     */
    public TrainCompoundClasses(int[] iArr) {
        this(Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()), iArr);
    }

    /**
     * Creates a trainer with default search bounds (5 C exponents, degree 1..2,
     * coef0 0..1), output directory "." and one shared libsvm node per used
     * fingerprint position. Also installs a no-op libsvm print function, which
     * silences libsvm's own output.
     *
     * @param executorService executor that runs the parameter-search tasks
     * @param iArr            fingerprint positions to use as features; must be non-empty.
     *                        The array is kept by reference and sorted in place.
     */
    public TrainCompoundClasses(ExecutorService executorService, int[] iArr) {
        this.numberOfCs = 5;
        this.minDegree = 1;
        this.maxDegree = 2;
        this.minCoef = 0;
        this.maxCoef = 1;
        this.outputDir = new File(".");
        this.WEIGHT = new double[2];
        this.WEIGHT_LABEL = new int[]{1, -1};

        this.elements = new Element[ELEMENTS.length];
        final PeriodicTable table = PeriodicTable.getInstance();
        int slot = 0;
        for (String symbol : ELEMENTS) {
            this.elements[slot++] = table.getByName(symbol);
        }

        this.service = executorService;
        this.usedIndizes = iArr;
        Arrays.sort(iArr);
        // After sorting, the last entry is the largest position; the pool is addressed by position.
        final int largestPosition = iArr[iArr.length - 1];
        this.NODE_POOL = new svm_node[largestPosition + 1];
        for (int position : iArr) {
            svm_node node = new svm_node();
            node.index = position + 1;
            node.value = 1.0d;
            this.NODE_POOL[position] = node;
        }
        this.FEATURES = this.NODE_POOL.length;
        this.seed = new ArrayList<>();
        this.pool = new ArrayList<>();
        svm.svm_set_print_string_function(new svm_print_interface() {
            public void print(String str) {
                // intentionally empty
            }
        });
    }

    /**
     * Evaluates the model on a set of positive and a set of negative examples
     * given as real-valued fingerprints and merges both results.
     *
     * @param svm_modelVar         trained model
     * @param molecularFormulaArr  formulas of the positive examples (one per row of {@code dArr})
     * @param molecularFormulaArr2 formulas of the negative examples (one per row of {@code dArr2})
     * @param dArr                 fingerprint values of the positive examples
     * @param dArr2                fingerprint values of the negative examples
     * @return the performance on the positives, with the negatives' result merged into it
     */
    public PredictionPerformance evaluateOnRealFingerprints(svm_model svm_modelVar, MolecularFormula[] molecularFormulaArr, MolecularFormula[] molecularFormulaArr2, double[][] dArr, double[][] dArr2) throws IOException {
        PredictionPerformance combined = evaluateOnRealFingerprints(svm_modelVar, molecularFormulaArr, dArr, 1);
        PredictionPerformance onNegatives = evaluateOnRealFingerprints(svm_modelVar, molecularFormulaArr2, dArr2, -1);
        combined.merge(onNegatives);
        return combined;
    }

    /**
     * Evaluates the model on examples that all share one true label and are
     * given as real-valued fingerprints. Each row is turned into a dense libsvm
     * vector (fingerprint values, optionally multiplied by the configured
     * weighting, followed by the formula-derived features) and classified.
     *
     * <p>An empty {@code dArr} yields an all-zero performance instead of failing
     * on the access to its first row.
     *
     * @param svm_modelVar        trained model
     * @param molecularFormulaArr formula of each example, one per row of {@code dArr}
     * @param dArr                fingerprint values; assumes column k belongs to the k-th
     *                            used index — TODO confirm against callers
     * @param i                   true label of all examples; values greater than zero mean positive
     * @return counts passed to {@code PredictionPerformance} in the order
     *         (true positives, false positives, true negatives, false negatives)
     */
    public PredictionPerformance evaluateOnRealFingerprints(svm_model svm_modelVar, MolecularFormula[] molecularFormulaArr, double[][] dArr, int i) throws IOException {
        if (dArr.length == 0) {
            return new PredictionPerformance(0, 0, 0, 0);
        }
        int truePositives = 0;
        int falsePositives = 0;
        int trueNegatives = 0;
        int falseNegatives = 0;
        final boolean expectPositive = i > 0;
        final int fingerprintLength = dArr[0].length;

        // One vector is allocated up front and overwritten for every example.
        svm_node[] vector = new svm_node[fingerprintLength + ADDITIONAL_FINGERPRINTS];
        for (int k = 0; k < vector.length; k++) {
            vector[k] = new svm_node();
        }

        for (int row = 0; row < dArr.length; row++) {
            for (int k = 0; k < fingerprintLength; k++) {
                svm_node node = vector[k];
                node.index = this.usedIndizes[k] + 1;
                node.value = dArr[row][k];
                if (this.weighting != null) {
                    node.value *= this.weighting[k];
                }
            }
            // NOTE(review): the first additional index equals FEATURES, which is also the libsvm
            // index of the largest used fingerprint position (position + 1). newCompound uses the
            // same numbering, so both must be changed together if this is unintended.
            double[] extra = getAdditionalFingerprintsFor(molecularFormulaArr[row]);
            for (int k = 0; k < ADDITIONAL_FINGERPRINTS; k++) {
                svm_node node = vector[fingerprintLength + k];
                node.index = this.FEATURES + k;
                node.value = extra[k];
            }

            boolean predictedPositive = svm.svm_predict(svm_modelVar, vector) > 0.0d;
            if (predictedPositive) {
                if (expectPositive) {
                    truePositives++;
                } else {
                    falsePositives++;
                }
            } else if (expectPositive) {
                falseNegatives++;
            } else {
                trueNegatives++;
            }
        }
        // NOTE(review): the private evaluate(...) method passes its counters as (tp, fn, tn, fp);
        // confirm which order the PredictionPerformance constructor expects.
        return new PredictionPerformance(truePositives, falsePositives, trueNegatives, falseNegatives);
    }

    /**
     * Builds a compound and its libsvm feature vector: one shared pool node per
     * set fingerprint bit (restricted to the used indices), followed by freshly
     * allocated nodes for the formula-derived features.
     *
     * @param str              identifier of the compound
     * @param molecularFormula formula of the compound
     * @param i                class label; narrowed to {@code byte}
     * @param zArr             full boolean fingerprint
     * @throws RuntimeException if no pool node exists for one of the set positions
     */
    private Compound newCompound(String str, MolecularFormula molecularFormula, int i, boolean[] zArr) {
        short[] setPositions = transformFingerprintToIntegerArray(this.usedIndizes, zArr);
        Compound compound = new Compound(str, molecularFormula, (byte) i, setPositions);
        final int bitCount = setPositions.length;
        svm_node[] vector = new svm_node[bitCount + ADDITIONAL_FINGERPRINTS];
        compound.nodes = vector;

        for (int k = 0; k < bitCount; k++) {
            svm_node shared = this.NODE_POOL[setPositions[k]];
            vector[k] = shared;
            if (shared == null) {
                throw new RuntimeException("node is null: " + str + " at " + k);
            }
        }

        // NOTE(review): the first additional index equals FEATURES, which is also the libsvm index
        // (position + 1) of the largest used fingerprint position — confirm this overlap is intended.
        double[] extra = getAdditionalFingerprintsFor(compound);
        for (int k = 0; k < extra.length; k++) {
            svm_node node = new svm_node();
            vector[bitCount + k] = node;
            node.index = k + this.FEATURES;
            node.value = extra[k];
        }
        return compound;
    }

    /**
     * Adds a compound to the seed list, from which the train methods draw the
     * initial training and evaluation sets.
     *
     * @param str              identifier of the compound
     * @param molecularFormula formula of the compound
     * @param i                class label; values greater than zero mean positive
     * @param zArr             full boolean fingerprint
     */
    public void addSeed(String str, MolecularFormula molecularFormula, int i, boolean[] zArr) {
        Compound compound = newCompound(str, molecularFormula, i, zArr);
        seed.add(compound);
    }

    /**
     * Adds a compound to the pool, from which the train methods pull
     * misclassified compounds into the training set.
     *
     * @param str              identifier of the compound
     * @param molecularFormula formula of the compound
     * @param i                class label; values greater than zero mean positive
     * @param zArr             full boolean fingerprint
     */
    public void addPool(String str, MolecularFormula molecularFormula, int i, boolean[] zArr) {
        Compound compound = newCompound(str, molecularFormula, i, zArr);
        pool.add(compound);
    }

    /**
     * Resets both class weights to 1.
     *
     * @param list ignored by the current implementation
     */
    public void updateWeightLabels(List<Compound> list) {
        Arrays.fill(this.WEIGHT, 1.0d);
    }

    /**
     * Trains with fixed hyper-parameters. The seed is split into a training and
     * an evaluation set; then, for at most five rounds, an SVM is trained, the
     * pool is classified, a report is written, and up to 5000 randomly chosen
     * misclassified pool compounds are moved into the training set. Training
     * stops early once no pool compound is misclassified.
     *
     * @param i polynomial kernel degree
     * @param z whether coef0 is 1 ({@code true}) or 0 ({@code false})
     * @param d SVM cost parameter C
     * @return the predictor of the last round that was trained
     */
    public MPredictor train(int i, boolean z, double d) {
        final double coef0 = z ? 1.0d : 0.0d;
        final Model model = new Model(new PredictionPerformance(), d, i, coef0);
        this.trainingSet = new ArrayList<>();
        this.evaluationSet = new ArrayList<>();
        pickupTrainAndEval(this.seed, this.trainingSet, this.evaluationSet);
        updateWeightLabels(this.trainingSet);

        MPredictor latest = null;
        int reportCounter = 0;
        for (int round = 0; round < 5; round++) {
            System.out.println("Round " + round);
            latest = trainAndEvaluate(model, defineProblem(this.trainingSet), this.evaluationSet);
            List<Compound> misclassified = evaluateAndRememberFalseNegatives(latest.svmModel, this.pool);
            try {
                System.out.println("Write report with " + misclassified.size() + " failed instances and " + latest.getPerformance().toString());
                reportCounter++;
                writeReport(latest, misclassified.size(), round, reportCounter, this.trainingSet);
            } catch (IOException e) {
                // A failed report must not abort training.
                e.printStackTrace();
            }
            if (misclassified.isEmpty()) {
                break;
            }

            Collections.shuffle(misclassified);
            final int moved = Math.min(5000, misclassified.size());
            HashSet<String> movedKeys = new HashSet<>();
            for (Compound compound : misclassified.subList(0, moved)) {
                this.trainingSet.add(compound);
                movedKeys.add(compound.inchiKey);
            }
            System.out.println("Added " + moved + " new negative samples");

            // Drop the moved compounds from the pool; matching is by identifier.
            Iterator<Compound> poolCursor = this.pool.iterator();
            while (poolCursor.hasNext()) {
                if (movedKeys.contains(poolCursor.next().inchiKey)) {
                    poolCursor.remove();
                }
            }
        }
        return latest;
    }

    /**
     * Full training run with hyper-parameter search.
     *
     * <ol>
     *   <li>The seed is split into a training and an evaluation set and the
     *       hyper-parameters are tuned on the training set.</li>
     *   <li>For at most five rounds an SVM is trained, the pool is classified, a
     *       report is written and misclassified pool compounds are moved into the
     *       training set (3000 per round, or all of them once the 30000 limit is
     *       exceeded; in that case training compounds far from the margin are
     *       first moved back to the pool). The loop ends early when no pool
     *       compound is misclassified.</li>
     *   <li>The hyper-parameters are tuned again on the grown training set, the
     *       positive evaluation compounds are added to the training set, and the
     *       final model is trained and saved to
     *       {@code <parent of outputDir>/models/<name of outputDir>.model}.</li>
     * </ol>
     *
     * <p>The final model is trained with the re-tuned parameters that are printed
     * as "FINAL MODEL". Previously that search result was only printed and the
     * final model was trained with the parameters from step 1.
     *
     * <p>Failures while writing reports or the final model are printed and do
     * not abort training.
     *
     * @return the predictor trained in step 3
     * @throws InterruptedException if the thread is interrupted while waiting for the parameter search
     */
    public MPredictor train() throws InterruptedException {
        System.out.println("Train version: 0.4");
        this.trainingSet = new ArrayList<>();
        this.evaluationSet = new ArrayList<>();
        pickupTrainAndEval(this.seed, this.trainingSet, this.evaluationSet);
        int reportCounter = 0;
        System.out.println("New outer fold: 1");
        updateWeightLabels(this.trainingSet);
        System.out.println("Start training parameters");
        Model roundParameters = trainParameters(this.trainingSet);
        System.out.println("finally best found parameters: " + roundParameters.toString() + " with " + roundParameters.performance.toString());

        for (int round = 0; round < 5; round++) {
            System.out.println("Round " + round);
            MPredictor predictor = trainAndEvaluate(roundParameters, defineProblem(this.trainingSet), this.evaluationSet);
            ArrayList<Compound> farFromMargin = findFarFromMargin(predictor.svmModel, this.trainingSet, 2);
            System.out.println(farFromMargin.size() + " vectors are far away from margin");
            List<Compound> misclassified = evaluateAndRememberFalseNegatives(predictor.svmModel, this.pool);
            try {
                System.out.println("Write report with " + misclassified.size() + " failed instances and " + predictor.getPerformance().toString());
                reportCounter++;
                writeReport(predictor, misclassified.size(), 1, reportCounter, this.trainingSet);
            } catch (IOException e) {
                // A failed report must not abort training.
                e.printStackTrace();
            }
            if (misclassified.isEmpty()) {
                break;
            }

            // Make room by returning uninformative training compounds to the pool.
            if (misclassified.size() + this.trainingSet.size() > 30000) {
                this.trainingSet.removeAll(farFromMargin);
                this.pool.addAll(farFromMargin);
            }
            Collections.shuffle(misclassified);

            // Deliberately evaluated again: the training set may just have shrunk.
            final int limit = misclassified.size() + this.trainingSet.size() > 30000 ? misclassified.size() : 3000;
            final int moved = Math.min(limit, misclassified.size());
            HashSet<String> movedKeys = new HashSet<>();
            for (Compound compound : misclassified.subList(0, moved)) {
                this.trainingSet.add(compound);
                movedKeys.add(compound.inchiKey);
            }
            System.out.println("Added " + moved + " new negative samples");

            // Drop the moved compounds from the pool; matching is by identifier.
            Iterator<Compound> poolCursor = this.pool.iterator();
            while (poolCursor.hasNext()) {
                if (movedKeys.contains(poolCursor.next().inchiKey)) {
                    poolCursor.remove();
                }
            }
            uniqueList(this.evaluationSet);
            uniqueList(this.pool);
            uniqueList(this.trainingSet);
        }

        // Re-tune on the grown training set and actually use the result for the final fit.
        Model finalParameters = trainParameters(this.trainingSet);
        System.out.println("FINAL MODEL: " + finalParameters.toString());
        for (Compound compound : this.evaluationSet) {
            if (compound.classification > 0) {
                this.trainingSet.add(compound);
            }
        }
        MPredictor finalPredictor = trainAndEvaluate(finalParameters, defineProblem(this.trainingSet), this.evaluationSet);
        try {
            File modelDirectory = new File(this.outputDir.getParent(), "models");
            finalPredictor.writeModel(new File(modelDirectory, this.outputDir.getName() + ".model"));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return finalPredictor;
    }

    /**
     * Removes, in place, every compound whose identifier already occurred
     * earlier in the list; the first occurrence is kept.
     */
    private void uniqueList(List<Compound> list) {
        HashSet<String> seenKeys = new HashSet<>();
        for (ListIterator<Compound> cursor = list.listIterator(); cursor.hasNext(); ) {
            // add() returns false when the identifier is already present.
            if (!seenKeys.add(cursor.next().inchiKey)) {
                cursor.remove();
            }
        }
    }

    /**
     * Shrinks the training list to the negative-labelled support vectors plus
     * all positive compounds. Every other non-positive compound is appended to
     * {@code list2}.
     *
     * @param list       training compounds; replaced in place (support vectors first, then positives)
     * @param list2      receives the removed compounds
     * @param mPredictor predictor whose {@code supportVectors} hold 1-based positions into {@code list}
     */
    private void removeNonSupportVectorsFromTrainingSet(List<Compound> list, List<Compound> list2, MPredictor mPredictor) {
        List<Compound> kept = new ArrayList<>();
        TIntHashSet keptNegativePositions = new TIntHashSet();
        for (int oneBased : mPredictor.supportVectors) {
            final int position = oneBased - 1;
            Compound supportVector = list.get(position);
            if (supportVector.classification < 0) {
                kept.add(supportVector);
                keptNegativePositions.add(position);
            }
        }
        int position = 0;
        for (Compound compound : list) {
            if (compound.classification > 0) {
                kept.add(compound);
            } else if (!keptNegativePositions.contains(position)) {
                list2.add(compound);
            }
            position++;
        }
        list.clear();
        list.addAll(kept);
    }

    /**
     * Writes the model file {@code <i2+1>_<i3>.model} and the text report
     * {@code <i2+1>_<i3>.txt} into the output directory. The report lists the
     * predictor's performance, the number of failed pool instances, the
     * hyper-parameters, and the identifier of every training compound.
     *
     * @param mPredictor predictor to save and describe
     * @param i          number of misclassified pool instances
     * @param i2         zero-based first counter of the file name
     * @param i3         second counter of the file name
     * @param list       compounds the predictor was trained on
     * @throws IOException if the model or the report cannot be written
     */
    private void writeReport(MPredictor mPredictor, int i, int i2, int i3, List<Compound> list) throws IOException {
        final String baseName = (i2 + 1) + "_" + i3;
        File modelFile = new File(this.outputDir, baseName + ".model");
        File reportFile = new File(this.outputDir, baseName + ".txt");
        mPredictor.writeModel(modelFile);
        // try-with-resources closes the writer on every path and records a failing close()
        // as a suppressed exception.
        try (BufferedWriter writer = Files.newBufferedWriter(reportFile.toPath(), Charset.defaultCharset())) {
            writer.write(mPredictor.getPerformance().toString());
            writer.newLine();
            writer.write(String.valueOf(i));
            writer.write("\t");
            writer.write("failed instances in pool\n");
            writer.write(mPredictor.model.toString());
            writer.newLine();
            writer.write("trained on " + list.size() + " compounds\n---\n");
            for (Compound compound : list) {
                writer.write(compound.inchiKey);
                writer.newLine();
            }
        }
    }

    /**
     * Grid search over kernel degree, coef0 and C using five cross-validation
     * folds. For every (degree, coef0) pair two tasks are submitted to the
     * executor: one scanning C = 2^0, 2^-1, ... and one scanning C = 2^1, 2^2, ...,
     * each up to the exponent {@code numberOfCs}. A task stops at the second
     * candidate that fails to improve on its best model so far.
     *
     * @param list compounds to split into folds
     * @return the best model over all tasks
     * @throws InterruptedException if interrupted while waiting for a task
     * @throws RuntimeException     wrapping the {@code ExecutionException} of a failed task
     */
    private Model trainParameters(List<Compound> list) throws InterruptedException {
        List<Compound>[] trainFolds = new List[5];
        final List<Compound>[] testFolds = new List[5];
        pickupTrainAndEval(list, trainFolds, testFolds);
        updateWeightLabels(list);
        final svm_problem[] problems = new svm_problem[5];
        for (int fold = 0; fold < 5; fold++) {
            problems[fold] = defineProblem(trainFolds[fold]);
        }

        List<Future<Model>> pending = new ArrayList<>();
        final int[] directions = {-1, 1};
        for (int degree = this.minDegree; degree <= this.maxDegree; degree++) {
            for (int coef = this.minCoef; coef <= this.maxCoef; coef++) {
                final int taskDegree = degree;
                final int taskCoef = coef;
                for (final int direction : directions) {
                    pending.add(this.service.submit(new Callable<Model>() {
                        @Override
                        public Model call() throws Exception {
                            Model best = null;
                            int misses = 0;
                            // The "small C" task includes 2^0; the "large C" task starts at 2^1.
                            final int firstExponent = direction < 0 ? 0 : 1;
                            for (int exponent = firstExponent; exponent <= TrainCompoundClasses.this.numberOfCs; exponent++) {
                                svm_parameter parameters = TrainCompoundClasses.this.defaultParameters();
                                parameters.C = Math.pow(2.0d, exponent * direction);
                                parameters.cache_size = 1024.0d;
                                parameters.weight = TrainCompoundClasses.this.WEIGHT;
                                parameters.coef0 = taskCoef;
                                parameters.degree = taskDegree;
                                parameters.gamma = 1.0d;
                                parameters.weight_label = TrainCompoundClasses.this.WEIGHT_LABEL;
                                Model candidate = TrainCompoundClasses.this.trainAndEvaluateCrossFolds(problems, parameters, testFolds);
                                if (best == null || candidate.compareTo(best) > 0) {
                                    best = candidate;
                                } else {
                                    // The miss counter is never reset, so the second miss overall ends the scan.
                                    misses++;
                                    if (misses > 1) {
                                        break;
                                    }
                                }
                            }
                            // NOTE(review): best is still null here when numberOfCs is below firstExponent.
                            System.out.println("Best model for " + (direction > 0 ? "large" : "small") + " c values: " + best.toString() + " with " + best.performance.toString());
                            return best;
                        }
                    }));
                }
            }
        }

        List<Model> candidates = new ArrayList<>();
        for (Future<Model> task : pending) {
            try {
                candidates.add(task.get());
            } catch (ExecutionException e) {
                throw new RuntimeException(e);
            }
        }
        return Collections.max(candidates);
    }

    /**
     * Trains one SVM with the hyper-parameters of {@code model} and attaches
     * its performance on the given evaluation compounds to the returned predictor.
     *
     * @param model          supplies C, degree and coef0
     * @param svm_problemVar training problem
     * @param list           compounds to evaluate on
     */
    private MPredictor trainAndEvaluate(Model model, svm_problem svm_problemVar, List<Compound> list) {
        svm_parameter parameters = defaultParameters();
        parameters.degree = model.degree;
        parameters.coef0 = model.coef0;
        parameters.gamma = 1.0d;
        parameters.C = model.c;
        parameters.cache_size = 1024.0d;
        parameters.weight_label = this.WEIGHT_LABEL;
        parameters.weight = this.WEIGHT;

        svm_model trained = svm.svm_train(svm_problemVar, parameters);
        PredictionPerformance performance = evaluate(trained, list);
        MPredictor predictor = new MPredictor(model, trained);
        predictor.setStatistics(performance);
        return predictor;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Trains one SVM per fold with the given parameters, evaluates each on the
     * matching held-out compounds and merges the results into one performance.
     * Prints the resulting model together with the per-fold support-vector counts.
     *
     * @param svm_problemVarArr training problem of each fold
     * @param svm_parameterVar  libsvm parameters shared by all folds
     * @param listArr           held-out compounds of each fold (same order as the problems)
     * @return model describing the parameters and the merged performance
     */
    public Model trainAndEvaluateCrossFolds(svm_problem[] svm_problemVarArr, svm_parameter svm_parameterVar, List<Compound>[] listArr) {
        StringBuilder supportVectorCounts = new StringBuilder("#SV = ");
        PredictionPerformance merged = new PredictionPerformance();
        int fold = 0;
        for (svm_problem problem : svm_problemVarArr) {
            svm_model trained = svm.svm_train(problem, svm_parameterVar);
            supportVectorCounts.append("\t").append(trained.nSV[0]).append(" ");
            // Some models report a single class only; print 0 for the missing one.
            if (trained.nSV.length > 1) {
                supportVectorCounts.append(trained.nSV[1]);
            } else {
                supportVectorCounts.append("0");
            }
            merged.merge(evaluate(trained, listArr[fold]));
            fold++;
        }
        merged.calc();
        Model model = new Model(merged, svm_parameterVar.C, svm_parameterVar.degree, svm_parameterVar.coef0);
        System.out.println("found " + model.toString() + ", " + supportVectorCounts);
        return model;
    }

    /**
     * Stratified random split: one fifth (rounded down) of the positive and one
     * fifth of the non-positive compounds go to the evaluation list, the rest to
     * the training list.
     *
     * @param list  compounds to split
     * @param list2 receives the training compounds
     * @param list3 receives the evaluation compounds
     */
    private void pickupTrainAndEval(List<Compound> list, List<Compound> list2, List<Compound> list3) {
        List<Compound> positives = new ArrayList<>();
        List<Compound> negatives = new ArrayList<>();
        for (Compound compound : list) {
            (compound.classification > 0 ? positives : negatives).add(compound);
        }
        final int positiveCut = positives.size() / 5;
        final int negativeCut = negatives.size() / 5;
        Collections.shuffle(positives);
        Collections.shuffle(negatives);

        list3.addAll(positives.subList(0, positiveCut));
        list2.addAll(positives.subList(positiveCut, positives.size()));
        list3.addAll(negatives.subList(0, negativeCut));
        list2.addAll(negatives.subList(negativeCut, negatives.size()));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Stratified k-fold split, with k being the length of {@code listArr}.
     * Positives and non-positives are shuffled separately and dealt round-robin:
     * each compound is held out in exactly one fold and is part of the training
     * list of every other fold.
     *
     * @param list     compounds to split
     * @param listArr  receives the training list of each fold (entries are overwritten)
     * @param listArr2 receives the held-out list of each fold (entries are overwritten)
     */
    private void pickupTrainAndEval(List<Compound> list, List<Compound>[] listArr, List<Compound>[] listArr2) {
        final int folds = listArr.length;
        List<Compound> positives = new ArrayList<>();
        List<Compound> negatives = new ArrayList<>();
        for (Compound compound : list) {
            (compound.classification > 0 ? positives : negatives).add(compound);
        }
        Collections.shuffle(positives);
        Collections.shuffle(negatives);

        for (int fold = 0; fold < folds; fold++) {
            listArr[fold] = new ArrayList<>();
            listArr2[fold] = new ArrayList<>();
        }
        // Positives are dealt first, then negatives, each starting again at fold 0.
        for (List<Compound> group : Arrays.asList(positives, negatives)) {
            for (int position = 0; position < group.size(); position++) {
                Compound compound = group.get(position);
                final int heldOutFold = position % folds;
                for (int fold = 0; fold < folds; fold++) {
                    List<Compound> target = fold == heldOutFold ? listArr2[fold] : listArr[fold];
                    target.add(compound);
                }
            }
        }
    }

    /**
     * Collects every compound the model misclassifies. Despite the name, errors
     * in both directions are collected: a compound is returned whenever the
     * prediction differs from "label greater than zero".
     *
     * @return the misclassified compounds, in list order
     */
    private List<Compound> evaluateAndRememberFalseNegatives(svm_model svm_modelVar, List<Compound> list) {
        List<Compound> misclassified = new ArrayList<>();
        for (Compound compound : list) {
            boolean predictedPositive = svmPredict(svm_modelVar, compound);
            boolean labelledPositive = compound.classification > 0;
            if (predictedPositive != labelledPositive) {
                misclassified.add(compound);
            }
        }
        return misclassified;
    }

    /**
     * Predicts every compound in {@code list} with the given model and tallies the four
     * confusion-matrix cells against the sign of {@code classification}.
     *
     * @param svm_modelVar trained model used for prediction
     * @param list         compounds to evaluate
     * @return a {@link PredictionPerformance} built from the four counts
     */
    private PredictionPerformance evaluate(svm_model svm_modelVar, List<Compound> list) {
        int truePositives = 0;
        int falseNegatives = 0;
        int trueNegatives = 0;
        int falsePositives = 0;
        for (Compound compound : list) {
            final boolean predictedPositive = svmPredict(svm_modelVar, compound);
            final boolean labelledPositive = compound.classification > 0;
            if (labelledPositive && predictedPositive) {
                truePositives++;
            } else if (labelledPositive) {
                falseNegatives++;
            } else if (predictedPositive) {
                falsePositives++;
            } else {
                trueNegatives++;
            }
        }
        // NOTE(review): the counts are handed over as (TP, FN, TN, FP); confirm this matches
        // the parameter order of the PredictionPerformance constructor.
        return new PredictionPerformance(truePositives, falseNegatives, trueNegatives, falsePositives);
    }

    /**
     * Selects the compounds whose decision value lies far away from the separating
     * hyperplane, regardless of whether they are classified correctly.
     *
     * @param svm_modelVar trained model
     * @param list         compounds to inspect
     * @param i            distance threshold; a compound is kept when the absolute decision
     *                     value is strictly greater than this
     * @return the selected compounds, in the order they appear in {@code list}
     * @throws RuntimeException if the sign of the manually computed decision value disagrees
     *                          with the label predicted by libsvm (internal consistency check)
     */
    private ArrayList<Compound> findFarFromMargin(svm_model svm_modelVar, List<Compound> list, int i) {
        final ArrayList<Compound> farFromMargin = new ArrayList<>();
        for (Compound compound : list) {
            final double decisionValue = svmPredictValue(svm_modelVar, compound);
            // Sanity check: our own kernel evaluation must agree with libsvm's prediction.
            final boolean positiveByValue = decisionValue > 0.0d;
            if (positiveByValue != svmPredict(svm_modelVar, compound)) {
                throw new RuntimeException("WTF??");
            }
            if (Math.abs(decisionValue) > i) {
                farFromMargin.add(compound);
            }
        }
        return farFromMargin;
    }

    /**
     * Predicts the compound with libsvm and reports whether the returned value is positive.
     *
     * @param svm_modelVar trained model
     * @param compound     compound whose {@code nodes} are used as the feature vector
     * @return {@code true} if {@code svm.svm_predict} returns a value greater than zero
     */
    private boolean svmPredict(svm_model svm_modelVar, Compound compound) {
        final double predicted = svm.svm_predict(svm_modelVar, compound.nodes);
        return predicted > 0.0d;
    }

    /**
     * Computes the raw decision value of the model for a compound:
     * the sum over all support vectors of {@code sv_coef[0][k] * K(x, SV[k])}, minus
     * {@code rho[0]}.
     * <p>
     * Only the first coefficient row and the first {@code rho} entry are used.
     * NOTE(review): this presumably assumes a two-class model — confirm with the training code.
     *
     * @param svm_modelVar trained model
     * @param compound     compound whose {@code nodes} are used as the feature vector
     * @return the decision value
     * @throws RuntimeException if the first label stored in the model is negative
     */
    private double svmPredictValue(svm_model svm_modelVar, Compound compound) {
        final svm_node[] features = compound.nodes;
        if (svm_modelVar.label[0] < 0) {
            throw new RuntimeException("WTF?");
        }
        double decision = 0.0d;
        int sv = 0;
        while (sv < svm_modelVar.l) {
            final double coefficient = svm_modelVar.sv_coef[0][sv];
            final double kernel = k_function(features, svm_modelVar.SV[sv], svm_modelVar.param);
            decision += coefficient * kernel;
            sv++;
        }
        return decision - svm_modelVar.rho[0];
    }

    /**
     * Creates a fresh libsvm parameter set with this class's defaults: SVM type 0 with
     * kernel type 1 (the polynomial branch of {@code k_function}), i.e.
     * {@code (1.0 * <u,v> + 1.0)^2}, cost {@code C = 1}, and equal weights for the labels
     * {@code 1} and {@code -1}.
     * <p>
     * The method is public only because the decompiler widened its access modifier; it was
     * originally private.
     *
     * @return a new, independent {@code svm_parameter} instance
     */
    public svm_parameter defaultParameters() {
        final svm_parameter params = new svm_parameter();

        // Model type and kernel shape.
        params.svm_type = 0;
        params.kernel_type = 1;
        params.degree = 2;
        params.gamma = 1.0d;
        params.coef0 = 1.0d;

        // Solver settings.
        params.C = 1.0d;
        params.nu = 0.5d;
        params.p = 0.1d;
        params.eps = 0.001d;
        params.cache_size = 5000.0d;
        params.shrinking = 1;
        params.probability = 0;

        // Per-label weights; both labels start out equally weighted.
        final int[] labels = {1, -1};
        final double[] weights = {1.0d, 1.0d};
        params.weight_label = labels;
        params.weight = weights;
        params.nr_weight = weights.length;

        return params;
    }

    /**
     * Converts a list of compounds into a libsvm training problem.
     * <p>
     * Row {@code k} of the problem takes its feature vector from the {@code nodes} of the
     * {@code k}-th compound and its target value from that compound's {@code classification}.
     * The node arrays are shared with the compounds, not copied.
     *
     * @param list compounds to train on
     * @return a problem with {@code list.size()} rows
     */
    private svm_problem defineProblem(List<Compound> list) {
        final int rows = list.size();
        final svm_problem problem = new svm_problem();
        problem.l = rows;
        // x is an array of sparse feature vectors, so it must be allocated as svm_node[rows][].
        problem.x = new svm_node[rows][];
        problem.y = new double[rows];
        for (int row = 0; row < rows; row++) {
            final Compound compound = list.get(row);
            problem.x[row] = compound.nodes;
            problem.y[row] = compound.classification;
        }
        return problem;
    }

    /**
     * Dot product of two sparse vectors.
     * <p>
     * Both arrays are walked in parallel; only entries whose {@code index} matches
     * contribute. The walk advances whichever side currently has the smaller index, so it
     * relies on both arrays being ordered by ascending {@code index}.
     *
     * @param svm_nodeVarArr  first sparse vector
     * @param svm_nodeVarArr2 second sparse vector
     * @return the sum of {@code value} products over matching indices
     */
    private static double dot(svm_node[] svm_nodeVarArr, svm_node[] svm_nodeVarArr2) {
        double sum = 0.0d;
        final int leftLength = svm_nodeVarArr.length;
        final int rightLength = svm_nodeVarArr2.length;
        int left = 0;
        int right = 0;
        while (left < leftLength && right < rightLength) {
            final svm_node leftNode = svm_nodeVarArr[left];
            final svm_node rightNode = svm_nodeVarArr2[right];
            if (leftNode.index == rightNode.index) {
                sum += leftNode.value * rightNode.value;
                left++;
                right++;
            } else if (leftNode.index > rightNode.index) {
                right++;
            } else {
                left++;
            }
        }
        return sum;
    }

    /**
     * Evaluates the kernel selected by {@code svm_parameterVar.kernel_type} for two sparse
     * vectors.
     * <ul>
     *   <li>0: {@code <u,v>}</li>
     *   <li>1: {@code (gamma * <u,v> + coef0)^degree}</li>
     *   <li>2: {@code exp(-gamma * |u - v|^2)}</li>
     *   <li>3: {@code tanh(gamma * <u,v> + coef0)}</li>
     *   <li>4: looks up the value stored in the first vector at the position given by
     *       the value of the second vector's first node</li>
     * </ul>
     * Any other kernel type yields {@code 0.0}.
     *
     * @param svm_nodeVarArr   first sparse vector
     * @param svm_nodeVarArr2  second sparse vector
     * @param svm_parameterVar parameters supplying kernel type, gamma, coef0 and degree
     * @return the kernel value
     */
    private static double k_function(svm_node[] svm_nodeVarArr, svm_node[] svm_nodeVarArr2, svm_parameter svm_parameterVar) {
        switch (svm_parameterVar.kernel_type) {
            case Svm.LINEAR /* 0 */:
                return dot(svm_nodeVarArr, svm_nodeVarArr2);
            case 1: {
                final double base = (svm_parameterVar.gamma * dot(svm_nodeVarArr, svm_nodeVarArr2)) + svm_parameterVar.coef0;
                return powi(base, svm_parameterVar.degree);
            }
            case Svm.RBF /* 2 */: {
                final double squaredDistance = rbfSquaredDistance(svm_nodeVarArr, svm_nodeVarArr2);
                return Math.exp((-svm_parameterVar.gamma) * squaredDistance);
            }
            case 3: {
                final double argument = (svm_parameterVar.gamma * dot(svm_nodeVarArr, svm_nodeVarArr2)) + svm_parameterVar.coef0;
                return Math.tanh(argument);
            }
            case Svm.PRECOMPUTED /* 4 */: {
                final int column = (int) svm_nodeVarArr2[0].value;
                return svm_nodeVarArr[column].value;
            }
            default:
                return 0.0d;
        }
    }

    /**
     * Squared Euclidean distance between two sparse vectors: entries with matching
     * {@code index} contribute their squared difference, entries present on only one side
     * contribute their squared value.
     */
    private static double rbfSquaredDistance(svm_node[] first, svm_node[] second) {
        final int firstLength = first.length;
        final int secondLength = second.length;
        double sum = 0.0d;
        int a = 0;
        int b = 0;
        while (a < firstLength && b < secondLength) {
            final svm_node nodeA = first[a];
            final svm_node nodeB = second[b];
            if (nodeA.index == nodeB.index) {
                final double difference = nodeA.value - nodeB.value;
                sum += difference * difference;
                a++;
                b++;
            } else if (nodeA.index > nodeB.index) {
                sum += nodeB.value * nodeB.value;
                b++;
            } else {
                sum += nodeA.value * nodeA.value;
                a++;
            }
        }
        // Whatever remains on either side has no counterpart on the other.
        for (; a < firstLength; a++) {
            sum += first[a].value * first[a].value;
        }
        for (; b < secondLength; b++) {
            sum += second[b].value * second[b].value;
        }
        return sum;
    }

    /**
     * Raises {@code d} to the integer power {@code i} by repeated squaring.
     *
     * @param d base
     * @param i exponent; for any value that is zero or negative the result is {@code 1.0}
     * @return {@code d} to the power of {@code i}
     */
    private static double powi(double d, int i) {
        double result = 1.0d;
        double square = d;
        for (int remaining = i; remaining > 0; remaining /= 2) {
            // An odd remaining exponent means the current square is part of the product.
            if ((remaining & 1) == 1) {
                result *= square;
            }
            square *= square;
        }
        return result;
    }

    // Class initializer: sets the index of the shared FINAL_NODE to -1.
    // NOTE(review): -1 presumably marks FINAL_NODE as an end-of-vector sentinel — confirm
    // against the declaration and uses of FINAL_NODE elsewhere in this class.
    static {
        FINAL_NODE.index = -1;
    }
}
