package de.unijena.bioinf.fingerid;

import com.google.common.base.Function;
import de.unijena.bioinf.ChemistryBase.chem.InChI;
import de.unijena.bioinf.ChemistryBase.chem.MolecularFormula;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.fingerid.OptimizationStrategy;
import de.unijena.bioinf.fingerid.svm.CSelection;
import de.unijena.bioinf.fingerid.svm.Crossvalidation;
import de.unijena.bioinf.fingerid.svm.FeatureList;
import de.unijena.bioinf.fingerid.svm.Sample;
import de.unijena.bioinf.fingerid.svm.Svm;
import de.unijena.bioinf.fingerid.svm.SvmInstance;
import de.unijena.bioinf.fingerid.svm.SvmModel;
import de.unijena.bioinf.fingerid.svm.linear.LinearSvmWithWeights;
import gnu.trove.list.array.TDoubleArrayList;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClassesLinear.class */
public class TrainCompoundClassesLinear implements Closeable {
    // Copy of the fingerprint index array handed to the constructor.
    private final int[] usedIndizes;
    // Lookup table: usedIndizesMap[k] is true when fingerprint index k was passed to the constructor.
    private final boolean[] usedIndizesMap;
    // The SVM implementation; the constructor always installs a LinearSvmWithWeights.
    private final SvmInstance svm;
    // All training compounds added so far; replaced wholesale by removeDuplicateEntries().
    private List<Compound> compounds;
    // Feature list holding value 1.0 for every index below (largest used index + 1);
    // makeFeatures(Compound) copies single features out of it.
    private FeatureList prototype;
    // Shared boxed 1.0 returned by the weight function installed in setCModifier.
    private static final Double ONE = Double.valueOf(1.0d);
    // Weight given to compounds without a fingerprint array; see setCModifier.
    private double cmodifierValue = 0.0d;
    // Worker pool sized to the number of available processors; handed to CSelection in train()
    // and shut down in close().
    private final ExecutorService service = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClassesLinear$Compound.class */
    /**
     * One training or evaluation sample handed to the SVM.
     *
     * <p>Besides the {@link Sample} state (feature list, label, batch number) it keeps the
     * structure it was built from: the InChI, its 2D key (used as the sample's group) and the
     * molecular formula extracted from the InChI. {@code fingerprints} may be {@code null};
     * the enclosing class creates such compounds when features come from numeric values
     * instead of a list of set fingerprint indices.
     */
    public static final class Compound implements Sample {
        private InChI inchi;
        private final String inchikey;
        private final MolecularFormula formula;
        private final short[] fingerprints;
        private FeatureList featureList;
        private byte classification;
        private byte fold;

        /**
         * @param inChI structure identifier; its 2D key and formula are extracted immediately
         * @param sArr  fingerprint index array (stored as is, not copied); may be {@code null}
         * @param i     class label, narrowed to a {@code byte}
         */
        public Compound(InChI inChI, short[] sArr, int i) {
            inchi = inChI;
            inchikey = inChI.key2D();
            formula = inChI.extractFormula();
            fingerprints = sArr;
            classification = (byte) i;
        }

        /** @return the feature list last set via {@link #setFeatureList}, or {@code null} if none was set */
        @Override // de.unijena.bioinf.fingerid.svm.Sample
        public FeatureList getFeatureList() {
            return featureList;
        }

        /** Replaces the feature list of this sample. */
        @Override // de.unijena.bioinf.fingerid.svm.Sample
        public void setFeatureList(FeatureList featureList) {
            this.featureList = featureList;
        }

        /** @return the stored class label, widened to {@code double} */
        @Override // de.unijena.bioinf.fingerid.svm.Sample
        public double getLabel() {
            return classification;
        }

        /** Stores the label narrowed to a {@code byte}; fractional parts are truncated by the cast. */
        @Override // de.unijena.bioinf.fingerid.svm.Sample
        public void setLabel(double d) {
            classification = (byte) d;
        }

        /** @return the batch (fold) number last set via {@link #setBatchNum} */
        @Override // de.unijena.bioinf.fingerid.svm.Sample
        public int getBatchNum() {
            return fold;
        }

        /** Stores the batch (fold) number narrowed to a {@code byte}. */
        @Override // de.unijena.bioinf.fingerid.svm.Sample
        public void setBatchNum(int i) {
            fold = (byte) i;
        }

        /** @return the 2D InChI key, which serves as the group identifier of this sample */
        @Override // de.unijena.bioinf.fingerid.svm.Sample
        public String getGroup() {
            return inchikey;
        }
    }

    /* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClassesLinear$SampleWithScore.class */
    /**
     * Pairs a sample with a score so that samples can be sorted by that score.
     * The natural order is ascending by score.
     */
    private static class SampleWithScore implements Comparable<SampleWithScore> {
        private double score;
        private Sample sample;

        /**
         * @param sample the sample being ranked
         * @param d      the score attached to the sample
         */
        public SampleWithScore(Sample sample, double d) {
            this.score = d;
            this.sample = sample;
        }

        /** Orders by score only, using {@link Double#compare(double, double)}. */
        @Override // java.lang.Comparable
        public int compareTo(SampleWithScore sampleWithScore) {
            final double mine = score;
            final double theirs = sampleWithScore.score;
            return Double.compare(mine, theirs);
        }
    }

    /* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClassesLinear$Sampler.class */
    /**
     * Collects observed values per used fingerprint position, kept apart by whether the
     * position was set (positives) or not set (negatives), and draws synthetic value vectors
     * from those observations.
     *
     * <p>Slot {@code k} of {@code positives}/{@code negatives} belongs to fingerprint index
     * {@code usedIndizes[k]}. {@link #sample(short[], int)} walks {@code usedIndizes} and the
     * given fingerprint in parallel and therefore relies on both being sorted ascending.
     */
    public static class Sampler {
        private int[] usedIndizes;
        private TDoubleArrayList[] positives;
        private TDoubleArrayList[] negatives;
        /** True while the observation lists may be unsorted; cleared by {@link #refresh()}. */
        private boolean dirty = true;
        /** Threshold forwarded to {@link TrainCompoundClassesLinear#transform}; values &lt;= 0 disable it. */
        private double transform = 0.0d;
        /** One generator for all draws instead of allocating a new one per drawn value. */
        private final Random random = new Random();

        /**
         * @param iArr fingerprint indices to sample for; stored as is (not copied)
         */
        public Sampler(int[] iArr) {
            this.usedIndizes = iArr;
            this.positives = new TDoubleArrayList[iArr.length];
            this.negatives = new TDoubleArrayList[iArr.length];
            for (int i = 0; i < iArr.length; i++) {
                this.positives[i] = new TDoubleArrayList();
                this.negatives[i] = new TDoubleArrayList();
            }
        }

        /**
         * Sets the threshold applied to sampled vectors through
         * {@link TrainCompoundClassesLinear#transform}. The transform is only applied when
         * the value is greater than zero.
         */
        public void setLinearTransform(double d) {
            this.transform = d;
        }

        /**
         * Draws a value vector for a fingerprint given as a list of set indices.
         *
         * <p>Every used index that occurs in {@code sArr} receives a draw from its positive
         * observations, every other used index a draw from its negative observations; unused
         * positions stay 0. Compared with the previous implementation this method
         * <ul>
         *   <li>no longer runs past the end of the used indices when the fingerprint contains
         *       an index beyond the last used one,</li>
         *   <li>also fills the used indices that follow the last set index (they used to stay 0),</li>
         *   <li>applies the configured transform, as {@link #sample(boolean[])} does.</li>
         * </ul>
         *
         * @param sArr set fingerprint indices, sorted ascending
         * @param i    length of the returned vector; must exceed the largest used index
         * @return a new vector of length {@code i}
         */
        public double[] sample(short[] sArr, int i) {
            refresh();
            final double[] dArr = new double[i];
            final int usedCount = this.usedIndizes.length;
            int slot = 0;
            for (short s : sArr) {
                // used indices skipped on the way to s are not set in this fingerprint
                while (slot < usedCount && this.usedIndizes[slot] < s) {
                    dArr[this.usedIndizes[slot]] = sample(this.negatives[slot]);
                    slot++;
                }
                if (slot < usedCount && this.usedIndizes[slot] == s) {
                    dArr[this.usedIndizes[slot]] = sample(this.positives[slot]);
                    slot++;
                }
            }
            // used indices after the last set index are not set either
            for (; slot < usedCount; slot++) {
                dArr[this.usedIndizes[slot]] = sample(this.negatives[slot]);
            }
            if (this.transform > 0.0d) {
                TrainCompoundClassesLinear.transform(dArr, this.transform);
            }
            return dArr;
        }

        /**
         * Draws a value vector for a fingerprint given as a boolean array.
         *
         * @param zArr fingerprint; {@code zArr[k]} tells whether index {@code k} is set
         * @return a new vector of the same length as {@code zArr}; unused positions are 0
         */
        public double[] sample(boolean[] zArr) {
            refresh();
            double[] dArr = new double[zArr.length];
            for (int i = 0; i < this.usedIndizes.length; i++) {
                int i2 = this.usedIndizes[i];
                if (zArr[i2]) {
                    dArr[i2] = sample(this.positives[i]);
                } else {
                    dArr[i2] = sample(this.negatives[i]);
                }
            }
            if (this.transform > 0.0d) {
                TrainCompoundClassesLinear.transform(dArr, this.transform);
            }
            return dArr;
        }

        /**
         * Draws one value: the geometric mean of a random run of neighbouring entries of the
         * sorted observation list.
         *
         * <p>The run length starts at 2 and grows by one with probability 0.4 per step; it is
         * capped at the list size. Returns NaN when there are no observations at all, which is
         * what the earlier code produced implicitly through {@code Math.pow(1, Infinity)}.
         */
        private double sample(TDoubleArrayList tDoubleArrayList) {
            refresh();
            int wanted = 2;
            while (this.random.nextDouble() >= 0.6d) {
                wanted++;
            }
            final int size = tDoubleArrayList.size();
            final int runLength = Math.min(wanted, size);
            if (runLength == 0) {
                return Double.NaN;
            }
            // Valid start offsets are 0..size-runLength inclusive. The previous bound
            // (size - runLength) excluded the last window, so the largest observation could
            // never be part of a run unless the run covered the whole list.
            final int start = this.random.nextInt(size - runLength + 1);
            double product = 1.0d;
            for (int k = start; k < start + runLength; k++) {
                product *= tDoubleArrayList.getQuick(k);
            }
            return Math.pow(product, 1.0d / runLength);
        }

        /** Sorts all observation lists if observations were added since the last sort. */
        protected void refresh() {
            if (this.dirty) {
                for (TDoubleArrayList tDoubleArrayList : this.positives) {
                    tDoubleArrayList.sort();
                }
                for (TDoubleArrayList tDoubleArrayList2 : this.negatives) {
                    tDoubleArrayList2.sort();
                }
                this.dirty = false;
            }
        }

        /**
         * Records one observation per used index.
         *
         * @param zArr fingerprint; decides whether a value counts as positive or negative
         * @param dArr observed values, indexed by fingerprint index
         */
        public void add(boolean[] zArr, double[] dArr) {
            this.dirty = true;
            for (int i = 0; i < this.usedIndizes.length; i++) {
                if (zArr[this.usedIndizes[i]]) {
                    this.positives[i].add(dArr[this.usedIndizes[i]]);
                } else {
                    this.negatives[i].add(dArr[this.usedIndizes[i]]);
                }
            }
        }

        /**
         * Reads observations from a tab-separated UTF-8 file.
         *
         * <p>In every line the column with index 3 is a string of characters in which
         * {@code '1'} at position {@code k} marks slot {@code k} as positive; the columns from
         * index 4 on hold one numeric value per slot. Empty lines are skipped.
         *
         * @param file the file to read
         * @throws IOException if the file cannot be read
         */
        public void readCrossvalidation(File file) throws IOException {
            this.dirty = true;
            try (BufferedReader reader = Files.newBufferedReader(file.toPath(), Charset.forName("UTF-8"))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    if (line.isEmpty()) {
                        // e.g. a trailing blank line; it carries no columns to parse
                        continue;
                    }
                    final String[] split = line.split("\t");
                    final String str = split[3];
                    for (int i = 4; i < split.length; i++) {
                        final int slot = i - 4;
                        final double value = Double.parseDouble(split[i]);
                        if (str.charAt(slot) == '1') {
                            this.positives[slot].add(value);
                        } else {
                            this.negatives[slot].add(value);
                        }
                    }
                }
            }
        }
    }

    /* loaded from: input_file:de/unijena/bioinf/fingerid/TrainCompoundClassesLinear$UniqueFingerprint.class */
    /**
     * Map key that identifies a compound by the content of its fingerprint index array.
     * The hash is computed once at construction time from that array.
     */
    private static final class UniqueFingerprint {
        private final int hashCode;
        private final Compound compound;

        /**
         * @param compound the compound whose fingerprint array defines equality
         */
        public UniqueFingerprint(Compound compound) {
            this.hashCode = Arrays.hashCode(compound.fingerprints);
            this.compound = compound;
        }

        /** Two keys are equal when their compounds' fingerprint arrays have equal content. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null || obj.getClass() != getClass()) {
                return false;
            }
            final UniqueFingerprint other = (UniqueFingerprint) obj;
            // cheap rejection through the cached hash before comparing the arrays
            return hashCode == other.hashCode
                    && Arrays.equals(compound.fingerprints, other.compound.fingerprints);
        }

        /** @return the hash cached by the constructor */
        @Override
        public int hashCode() {
            return hashCode;
        }
    }

    /**
     * @return how many compounds are currently held for training
     */
    public int numberOfCompounds() {
        final List<Compound> held = compounds;
        return held.size();
    }

    /**
     * Rescales every entry of {@code dArr} in place on a logarithmic scale.
     *
     * <p>Entries below 0.5 are clamped to at least {@code d} and mapped so that {@code d}
     * becomes 0 and 0.5 stays 0.5; entries of 0.5 or more are mirrored ({@code 1 - x}),
     * mapped the same way and mirrored back, so {@code 1 - d} becomes 1. Results are limited
     * to the range [0, 1]. The formula is only meaningful for {@code 0 < d < 0.5}: it takes
     * {@code log(d)} and divides by {@code log(0.5) - log(d)}.
     *
     * @param dArr values to rescale; modified in place
     * @param d    the value that is mapped to 0
     */
    public static void transform(double[] dArr, double d) {
        final double logFloor = Math.log(d);
        final double slope = 0.5d / (Math.log(0.5d) - logFloor);
        final double offset = (-slope) * logFloor;
        for (int k = 0; k < dArr.length; k++) {
            final double value = dArr[k];
            final boolean upperHalf = value >= 0.5d;
            // the upper half is handled through its mirror image 1 - value
            final double clamped = Math.max(upperHalf ? 1.0d - value : value, d);
            final double mapped = offset + (slope * Math.log(clamped));
            dArr[k] = upperHalf ? Math.min(1.0d, 1.0d - mapped) : Math.max(0.0d, mapped);
        }
    }

    /**
     * Adds synthetic copies of selected compounds whose feature values are drawn by a
     * {@link Sampler} instead of being taken from the fingerprint itself.
     *
     * <p>Steps: set C to 1e-4, train on all five cross-validation folds, score every sample
     * with the resulting model, then take up to {@code i} positives with the lowest decision
     * values and up to {@code i} negatives with the highest decision values. For each of them
     * a new compound with sampled values is added. Finally the C modifier is set to
     * {@code d2}. The C parameter of the SVM is left at 1e-4 by this method.
     *
     * <p>NOTE(review): every selected compound must have a non-null fingerprint array;
     * compounds added from numeric values carry {@code null} there and would make the
     * sampler fail. Confirm that this method is only called before such compounds exist.
     *
     * @param i    maximum number of positives and of negatives to sample
     * @param iArr fingerprint indices for the sampler; the last entry is taken as the largest index
     * @param d    transform threshold handed to {@link Sampler#setLinearTransform}
     * @param file observations file read through {@link Sampler#readCrossvalidation}
     * @param d2   value passed to {@link #setCModifier}
     * @throws IOException if the observations file cannot be read
     */
    public void autoSamplePlattAndTransform(int i, int[] iArr, double d, File file, double d2) throws IOException {
        final int vectorLength = iArr[iArr.length - 1] + 1;

        svm.setParameter(Svm.C, 1.0E-4d);
        final Sample[] allSamples = new Crossvalidation(compounds, 5, true, false).getFoldsArray(0, 1, 2, 3, 4);
        svm.setSamples(allSamples);
        final SvmModel model = svm.train();

        final List<SampleWithScore> positives = new ArrayList<SampleWithScore>();
        final List<SampleWithScore> negatives = new ArrayList<SampleWithScore>();
        for (Sample sample : allSamples) {
            final SampleWithScore scored = new SampleWithScore(sample, model.computeDecisionValue(sample));
            if (sample.getLabel() > 0.0d) {
                positives.add(scored);
            } else {
                negatives.add(scored);
            }
        }
        Collections.sort(positives);
        Collections.sort(negatives, Collections.reverseOrder());

        final Sampler sampler = new Sampler(iArr);
        sampler.setLinearTransform(d);
        sampler.readCrossvalidation(file);

        addSampledCopies(positives, i, sampler, vectorLength);
        addSampledCopies(negatives, i, sampler, vectorLength);
        setCModifier(d2);
    }

    /** Adds a sampled copy for each of the first {@code limit} entries of {@code ranked}. */
    private void addSampledCopies(List<SampleWithScore> ranked, int limit, Sampler sampler, int vectorLength) {
        final int count = Math.min(limit, ranked.size());
        for (int k = 0; k < count; k++) {
            final Compound source = (Compound) ranked.get(k).sample;
            final double[] sampled = sampler.sample(source.fingerprints, vectorLength);
            addCompound(source.inchikey, source.inchi.in2D, source.classification, sampled);
        }
    }

    /**
     * Stores the weight for compounds that have no fingerprint array and, unless the weight
     * is exactly 1, installs a weight function on the SVM: such compounds get the stored
     * weight, all others get 1. The function reads the stored weight each time it is applied,
     * so a later call with a different value changes the weights of an already installed
     * function too.
     *
     * @param d weight for compounds without a fingerprint array
     */
    public void setCModifier(double d) {
        cmodifierValue = d;
        if (d == 1.0d) {
            // nothing to install when the weight is neutral
            return;
        }
        final Function<Sample, Double> weightFunction = new Function<Sample, Double>() { // from class: de.unijena.bioinf.fingerid.TrainCompoundClassesLinear.1
            @Override
            public Double apply(Sample sample) {
                final boolean withoutFingerprint = ((Compound) sample).fingerprints == null;
                if (withoutFingerprint) {
                    return Double.valueOf(TrainCompoundClassesLinear.this.cmodifierValue);
                }
                return TrainCompoundClassesLinear.ONE;
            }
        };
        ((LinearSvmWithWeights) svm).setWeightFunction(weightFunction);
    }

    /**
     * Creates a trainer restricted to the given fingerprint indices.
     *
     * <p>The indices are copied and the copy is sorted ascending. The previous code sorted
     * the caller's array instead, which left the internal copy in the caller's original
     * order and modified the argument as a side effect. The argument is now left untouched.
     *
     * <p>Also sets up a {@link LinearSvmWithWeights} with debug mode disabled, a prototype
     * feature list holding 1.0 for every index up to the largest used one, and the lookup
     * table of used indices.
     *
     * @param iArr fingerprint indices to use; must contain at least one entry
     * @throws IllegalArgumentException if {@code iArr} is empty
     */
    public TrainCompoundClassesLinear(int[] iArr) {
        final int[] sortedIndizes = iArr.clone();
        if (sortedIndizes.length == 0) {
            throw new IllegalArgumentException("at least one fingerprint index is required");
        }
        Arrays.sort(sortedIndizes);
        this.usedIndizes = sortedIndizes;
        this.compounds = new ArrayList<Compound>();
        this.svm = new LinearSvmWithWeights();
        this.svm.disableDebugMode();
        setParameters();
        // after sorting, the last entry is the largest index
        final int length = sortedIndizes[sortedIndizes.length - 1] + 1;
        this.prototype = this.svm.newFeatureList(length);
        for (int index = 0; index < length; index++) {
            this.prototype.add(index, 1.0d);
        }
        this.usedIndizesMap = new boolean[length];
        for (int index : sortedIndizes) {
            this.usedIndizesMap[index] = true;
        }
    }

    /**
     * Checks whether a model predicts the expected label for a compound given by its set
     * fingerprint indices. The indices are used as passed in; they are not reduced to the
     * used indices here.
     *
     * @param svmModel model to evaluate
     * @param str      first argument of the {@code InChI} constructor
     * @param str2     second argument of the {@code InChI} constructor
     * @param i        expected label
     * @param sArr     set fingerprint indices
     * @return {@code true} if the prediction, truncated to an int, equals {@code i}
     */
    public boolean evalCompound(SvmModel svmModel, String str, String str2, int i, short[] sArr) {
        final Compound probe = new Compound(new InChI(str, str2), sArr, i);
        makeFeatures(probe);
        final int predicted = (int) svmModel.predict(probe);
        return predicted == i;
    }

    /**
     * Checks whether a model predicts the expected label for a compound given by numeric
     * values per fingerprint index. Only the values at the used indices enter the features.
     *
     * @param svmModel model to evaluate
     * @param str      first argument of the {@code InChI} constructor
     * @param str2     second argument of the {@code InChI} constructor
     * @param i        expected label
     * @param dArr     values indexed by fingerprint index
     * @return {@code true} if the prediction, truncated to an int, equals {@code i}
     */
    public boolean evalCompound(SvmModel svmModel, String str, String str2, int i, double[] dArr) {
        final Compound probe = new Compound(new InChI(str, str2), null, i);
        final double[] usedValues = filter(usedIndizes, dArr);
        makeFeatures(probe, usedValues);
        final int predicted = (int) svmModel.predict(probe);
        return predicted == i;
    }

    /**
     * Writes a model to a file by delegating to the underlying SVM implementation.
     *
     * @param file     destination file
     * @param svmModel model to write
     * @throws IOException declared because the underlying call is made inside this method
     */
    public void writeModel(File file, SvmModel svmModel) throws IOException {
        final SvmInstance writer = svm;
        writer.writeModel(file, svmModel);
    }

    // Called once from the constructor right after the SVM is created. The body is empty,
    // so the SVM keeps whatever parameters it starts with.
    private void setParameters() {
    }

    /**
     * Selects the C parameter, reports performance and trains the final model.
     *
     * <p>The compounds are split into five folds. C is chosen by {@link CSelection} from a
     * fixed list of powers of two, optimising by F-score on folds 0-3. A model trained on
     * folds 0-3 with that C is then evaluated on fold 4, and both results are printed to
     * standard output. The returned model is trained on all compounds with the selected C.
     *
     * @return the model trained on all compounds
     * @throws RuntimeException wrapping the {@link InterruptedException} if the C selection
     *         is interrupted; the thread's interrupt flag is restored before throwing
     */
    public SvmModel train() {
        Crossvalidation crossvalidation = new Crossvalidation(this.compounds, 5, true, false);
        CSelection cSelection = new CSelection(this.service, Runtime.getRuntime().availableProcessors());
        cSelection.setPossibleCs(Math.pow(2.0d, -15.0d), Math.pow(2.0d, -10.0d), Math.pow(2.0d, -9.0d), Math.pow(2.0d, -8.0d), Math.pow(2.0d, -7.0d), Math.pow(2.0d, -6.0d), Math.pow(2.0d, -5.0d), Math.pow(2.0d, -4.0d), Math.pow(2.0d, -3.0d), Math.pow(2.0d, -2.0d), Math.pow(2.0d, -1.0d), 1.0d, Math.pow(2.0d, 1.0d), Math.pow(2.0d, 2.0d), Math.pow(2.0d, 3.0d), Math.pow(2.0d, 4.0d), Math.pow(2.0d, 5.0d), Math.pow(2.0d, 6.0d));
        try {
            // handed to learnC and printed afterwards; presumably filled in by learnC
            PredictionPerformance predictionPerformance = new PredictionPerformance(0.0d, 0.0d, 0.0d, 0.0d);
            double learnC = cSelection.learnC(this.svm, crossvalidation, new OptimizationStrategy.ByFScore(), predictionPerformance, 0, 1, 2, 3);
            this.svm.setParameter(Svm.C, learnC);
            System.out.println("Final Performance on c selection: " + predictionPerformance.toString() + " with c is " + learnC);
            // fold 4 was not part of the C selection and serves as evaluation set
            this.svm.setSamples(crossvalidation.getFoldsArray(0, 1, 2, 3));
            SvmModel train = this.svm.train();
            Sample[] foldsArray = crossvalidation.getFoldsArray(4);
            System.out.println("Performance on evaluation set: " + Svm.evaluateClassificationPerformance(foldsArray, train.predict(foldsArray)));
            this.svm.setSamples((Sample[]) this.compounds.toArray(new Sample[this.compounds.size()]));
            return this.svm.train();
        } catch (InterruptedException e) {
            // keep the interruption visible to callers further up the stack
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Adds a training compound whose fingerprint is given as a boolean array. The array is
     * converted with {@code TrainCompoundClasses.transformFingerprintToIntegerArray} using
     * the used indices before the compound is stored and its features are built.
     *
     * @param str  first argument of the {@code InChI} constructor
     * @param str2 second argument of the {@code InChI} constructor
     * @param i    class label
     * @param zArr fingerprint as boolean array
     */
    public void addCompound(String str, String str2, int i, boolean[] zArr) {
        final InChI structure = new InChI(str, str2);
        final short[] setIndizes = TrainCompoundClasses.transformFingerprintToIntegerArray(usedIndizes, zArr);
        final Compound added = new Compound(structure, setIndizes, i);
        compounds.add(added);
        makeFeatures(added);
    }

    /**
     * Sets the {@code Svm.BIAS} parameter of the underlying SVM.
     *
     * @param d new bias value
     */
    public void setBias(double d) {
        final SvmInstance target = svm;
        target.setParameter(Svm.BIAS, d);
    }

    /**
     * Sets the {@code Svm.EPS} parameter of the underlying SVM.
     *
     * @param d new epsilon value
     */
    public void setEpsilon(double d) {
        final SvmInstance target = svm;
        target.setParameter(Svm.EPS, d);
    }

    /**
     * Switches the debug mode of the underlying SVM on or off.
     *
     * @param z {@code true} to enable debug mode, {@code false} to disable it
     */
    public void setDebugMode(boolean z) {
        if (!z) {
            svm.disableDebugMode();
            return;
        }
        svm.enableDebugMode();
    }

    /**
     * Adds a training compound whose fingerprint is given as a list of set indices. Indices
     * that are not among the used indices are dropped before the compound is stored and its
     * features are built.
     *
     * @param str  first argument of the {@code InChI} constructor
     * @param str2 second argument of the {@code InChI} constructor
     * @param i    class label
     * @param sArr set fingerprint indices
     */
    public void addCompound(String str, String str2, int i, short[] sArr) {
        final InChI structure = new InChI(str, str2);
        final short[] usedOnly = filter(usedIndizes, usedIndizesMap, sArr);
        final Compound added = new Compound(structure, usedOnly, i);
        compounds.add(added);
        makeFeatures(added);
    }

    /**
     * Adds a training compound described by numeric values per fingerprint index. The
     * compound is created without a fingerprint array; its features are built from the
     * values at the used indices, and only then is it appended to the training set.
     *
     * @param str  first argument of the {@code InChI} constructor
     * @param str2 second argument of the {@code InChI} constructor
     * @param i    class label
     * @param dArr values indexed by fingerprint index
     */
    public void addCompound(String str, String str2, int i, double[] dArr) {
        final Compound added = new Compound(new InChI(str, str2), null, i);
        final double[] usedValues = filter(usedIndizes, dArr);
        makeFeatures(added, usedValues);
        compounds.add(added);
    }

    /**
     * Builds and attaches the feature list of a compound from numeric values.
     *
     * <p>The k-th value is stored under the k-th used index. The values returned by
     * {@code TrainCompoundClasses.getAdditionalFingerprintsFor} for the compound's formula
     * are appended behind the fingerprint range, starting at the size of the prototype.
     *
     * @param compound compound that receives the feature list
     * @param dArr     one value per used index, in the order of the used indices
     */
    private void makeFeatures(Compound compound, double[] dArr) {
        final double[] extras = TrainCompoundClasses.getAdditionalFingerprintsFor(compound.formula);
        final int extrasOffset = prototype.size();
        final FeatureList features = svm.newFeatureList(extrasOffset + extras.length);
        int position = 0;
        for (int index : usedIndizes) {
            features.add(index, dArr[position]);
            position++;
        }
        int extraIndex = extrasOffset;
        for (double extra : extras) {
            features.add(extraIndex, extra);
            extraIndex++;
        }
        compound.setFeatureList(features);
    }

    /**
     * Keeps those entries of {@code sArr} that are marked in {@code zArr}, preserving order.
     *
     * @param iArr used indices; only its length is read, as an upper bound for the result size
     * @param zArr lookup table; an entry {@code s} is kept when {@code s < zArr.length} and {@code zArr[s]}
     * @param sArr candidate indices
     * @return the kept entries in an array of exactly matching length
     */
    private static short[] filter(int[] iArr, boolean[] zArr, short[] sArr) {
        final short[] kept = new short[Math.min(sArr.length, iArr.length)];
        int count = 0;
        for (int k = 0; k < sArr.length; k++) {
            final short candidate = sArr[k];
            if (candidate >= zArr.length || !zArr[candidate]) {
                continue;
            }
            kept[count] = candidate;
            count++;
        }
        if (count < kept.length) {
            return Arrays.copyOf(kept, count);
        }
        return kept;
    }

    /**
     * Picks the values at the given indices.
     *
     * @param iArr indices into {@code dArr}
     * @param dArr source values
     * @return a new array whose k-th entry is {@code dArr[iArr[k]]}
     */
    private static double[] filter(int[] iArr, double[] dArr) {
        final double[] picked = new double[iArr.length];
        int position = 0;
        for (int index : iArr) {
            picked[position] = dArr[index];
            position++;
        }
        return picked;
    }

    /**
     * Builds and attaches the feature list of a compound from its fingerprint index array.
     *
     * <p>For every index in the array the matching feature is copied from the prototype.
     * The values returned by {@code TrainCompoundClasses.getAdditionalFingerprintsFor} for
     * the compound's formula are appended starting at the size of the prototype. The
     * compound must have a non-null fingerprint array.
     *
     * @param compound compound that receives the feature list
     */
    private void makeFeatures(Compound compound) {
        final double[] extras = TrainCompoundClasses.getAdditionalFingerprintsFor(compound.formula);
        final int extrasOffset = prototype.size();
        final short[] setIndizes = compound.fingerprints;
        final FeatureList features = svm.newFeatureList(setIndizes.length + extras.length);
        for (int k = 0; k < setIndizes.length; k++) {
            features.addFeatureFrom(prototype, setIndizes[k]);
        }
        int extraIndex = extrasOffset;
        for (double extra : extras) {
            features.add(extraIndex, extra);
            extraIndex++;
        }
        compound.setFeatureList(features);
    }

    /**
     * Shuts down the worker pool owned by this trainer. The call does not wait for running
     * tasks to finish.
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() throws IOException {
        final ExecutorService pool = service;
        pool.shutdown();
    }

    /**
     * Drops compounds whose fingerprint index array equals that of an earlier compound and
     * prints a summary to standard output.
     *
     * <p>The first compound with a given fingerprint is kept. Each dropped compound is
     * counted as positive (label &gt; 0) or negative, and additionally as a conflict when its
     * label differs from the label of the kept compound.
     *
     * <p>NOTE(review): compounds without a fingerprint array all compare equal to each
     * other here, so only the first of them survives. Confirm that this is intended.
     */
    public void removeDuplicateEntries() {
        final HashMap<UniqueFingerprint, Integer> firstLabelByFingerprint = new HashMap<UniqueFingerprint, Integer>();
        final List<Compound> unique = new ArrayList<Compound>();
        int droppedPositives = 0;
        int droppedNegatives = 0;
        int conflicts = 0;
        for (Compound compound : compounds) {
            final UniqueFingerprint key = new UniqueFingerprint(compound);
            final Integer firstLabel = firstLabelByFingerprint.get(key);
            if (firstLabel == null) {
                // first occurrence of this fingerprint: remember its label and keep it
                firstLabelByFingerprint.put(key, Integer.valueOf(compound.classification));
                unique.add(compound);
                continue;
            }
            if (compound.classification > 0) {
                droppedPositives++;
            } else {
                droppedNegatives++;
            }
            if (compound.classification != firstLabel.byteValue()) {
                conflicts++;
            }
        }
        compounds = unique;
        System.out.println("Remove " + droppedPositives + " positive and " + droppedNegatives + " negative duplicate entries. " + conflicts + " entries could not be distinguished between positive and negative sets.");
    }
}
