/*
 * Decompiled with CFR 0.152.
 */
package de.unijena.bioinf.fingerid;

import de.unijena.bioinf.ChemistryBase.chem.InChI;
import de.unijena.bioinf.ChemistryBase.chem.MolecularFormula;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.fingerid.OptimizationStrategy;
import de.unijena.bioinf.fingerid.TrainCompoundClasses;
import de.unijena.bioinf.fingerid.svm.CSelection;
import de.unijena.bioinf.fingerid.svm.Crossvalidation;
import de.unijena.bioinf.fingerid.svm.FeatureList;
import de.unijena.bioinf.fingerid.svm.Sample;
import de.unijena.bioinf.fingerid.svm.Svm;
import de.unijena.bioinf.fingerid.svm.SvmInstance;
import de.unijena.bioinf.fingerid.svm.SvmModel;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TrainCompoundClassesKernel
implements Closeable {
    /** Fixed-size thread pool (one thread per available processor); handed to the C selection in {@link #train()} and shut down in {@link #close()}. */
    private final ExecutorService service = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
    /** Private copy of the fingerprint indices passed to the constructor. */
    private final int[] usedIndizes;
    /** Lookup table: {@code usedIndizesMap[i]} is {@code true} iff {@code i} is one of the used indices. */
    private final boolean[] usedIndizesMap;
    /** SVM instance obtained from {@code Svm.getKernelSvm()}; used to build feature lists, train and write models. */
    private final SvmInstance svm;
    /** Training compounds collected through the {@code addCompound} overloads. */
    private List<Compound> compounds;
    /** Feature list holding the value 1.0 at every index below (largest used index + 1); binary fingerprint features are copied from it. */
    private FeatureList prototype;

    /**
     * Creates a trainer that only considers the given fingerprint positions.
     *
     * <p>The index array is copied and the copy is sorted; the caller's array is
     * left untouched. (Previously the caller's array was sorted in place while
     * the internal copy stayed in its original order.)
     *
     * @param usedIndizes fingerprint indices to use as features; must be non-empty
     *                    and must not contain negative values
     * @throws IllegalArgumentException if {@code usedIndizes} is empty or contains
     *                                  a negative index
     * @throws NullPointerException     if {@code usedIndizes} is {@code null}
     */
    public TrainCompoundClassesKernel(int[] usedIndizes) {
        if (usedIndizes.length == 0) {
            throw new IllegalArgumentException("usedIndizes must contain at least one index");
        }
        // Sort the private copy, never the argument: the caller still owns its array.
        final int[] sortedIndizes = usedIndizes.clone();
        Arrays.sort(sortedIndizes);
        if (sortedIndizes[0] < 0) {
            throw new IllegalArgumentException("usedIndizes must not contain negative indices: " + sortedIndizes[0]);
        }
        this.usedIndizes = sortedIndizes;
        this.compounds = new ArrayList<Compound>();
        this.svm = Svm.getKernelSvm();
        this.svm.disableDebugMode();
        this.setParameters();
        // The feature space spans every index up to and including the largest used one.
        final int F = sortedIndizes[sortedIndizes.length - 1] + 1;
        this.prototype = this.svm.newFeatureList(F);
        for (int i = 0; i < F; ++i) {
            this.prototype.add(i, 1.0);
        }
        this.usedIndizesMap = new boolean[F];
        for (int index : sortedIndizes) {
            this.usedIndizesMap[index] = true;
        }
    }

    /**
     * Predicts the class of a single compound from its set fingerprint positions
     * and checks the prediction against the expected label.
     *
     * <p>The fingerprint positions are restricted to the used indices in the same
     * way as in {@code addCompound(String, String, int, short[])}, so evaluation
     * uses the same feature space as training.
     *
     * @param model          trained model used for the prediction
     * @param inchiKey       InChI key of the compound
     * @param inchi          InChI string of the compound
     * @param classification expected class label
     * @param fingerprints   indices of the fingerprint positions that are set
     * @return {@code true} if the prediction, truncated to {@code int}, equals
     *         {@code classification}
     */
    public boolean evalCompound(SvmModel model, String inchiKey, String inchi, int classification, short[] fingerprints) {
        // Drop positions the model was never trained on before building features.
        final short[] usedFingerprints = TrainCompoundClassesKernel.filter(this.usedIndizes, this.usedIndizesMap, fingerprints);
        final Compound c = new Compound(new InChI(inchiKey, inchi), usedFingerprints, classification);
        this.makeFeatures(c);
        final double classif = model.predict(c);
        return (int)classif == classification;
    }

    /**
     * Predicts the class of a single compound from per-position scores and
     * checks the prediction against the expected label.
     *
     * @param model          trained model used for the prediction
     * @param inchiKey       InChI key of the compound
     * @param inchi          InChI string of the compound
     * @param classification expected class label
     * @param plattScores    one score per fingerprint position; only the used
     *                       positions are read
     * @return {@code true} if the prediction, truncated to {@code int}, equals
     *         {@code classification}
     */
    public boolean evalCompound(SvmModel model, String inchiKey, String inchi, int classification, double[] plattScores) {
        final Compound candidate = new Compound(new InChI(inchiKey, inchi), null, classification);
        final double[] selectedScores = filter(usedIndizes, plattScores);
        makeFeatures(candidate, selectedScores);
        final double predicted = model.predict(candidate);
        return classification == (int) predicted;
    }

    /**
     * Writes the given model to a file by delegating to the underlying SVM instance.
     *
     * @param file  destination file
     * @param model model to persist
     * @throws IOException if the underlying SVM instance fails to write the model
     */
    public void writeModel(File file, SvmModel model) throws IOException {
        svm.writeModel(file, model);
    }

    /**
     * Applies the fixed SVM settings used by this trainer.
     */
    private void setParameters() {
        // NOTE(review): presumably 2 selects the RBF kernel and the cache size is in MB,
        // as in libsvm -- confirm against the Svm wrapper.
        final double kernelType = 2.0;
        final double cacheSize = 5096.0;
        svm.setParameter("kernel_type", kernelType);
        svm.setParameter("cache_size", cacheSize);
    }

    /**
     * Selects {@code gamma} and {@code c} by grid search and trains the final model.
     *
     * <p>For every candidate gamma the best C is learned on folds 0-3 of a
     * cross-validation over the collected compounds; the (gamma, C) pair whose
     * performance ranks highest under the F-score comparator wins. A model trained
     * on folds 0-3 with the winning parameters is evaluated on fold 4 and the result
     * is printed. Finally a model trained on all compounds is returned.
     *
     * @return the model trained on all collected compounds with the selected parameters
     * @throws RuntimeException wrapping an {@link InterruptedException} if the thread
     *                          is interrupted; the interrupt flag is restored first
     */
    public SvmModel train() {
        // NOTE(review): 5 is presumably the fold count (folds 0..4 are used below); the
        // meaning of the two boolean flags is not visible here.
        final Crossvalidation<Compound> crossvalidation = new Crossvalidation<Compound>(this.compounds, 5, true, false);
        final CSelection cs = new CSelection(this.service, Runtime.getRuntime().availableProcessors());
        // Candidate C values 2^-3 .. 2^3.
        cs.setPossibleCs(0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0);
        try {
            double bestGamma = -1.0;
            double bestC = 0.0;
            PredictionPerformance gammaPerformance = null;
            final Comparator<PredictionPerformance> pcomp = new OptimizationStrategy.ByFScore().getComparator();
            for (double gamma : new double[]{0.001, 0.01, 0.1, 1.0, 10.0, 100.0}) {
                final PredictionPerformance cperformance = new PredictionPerformance(0.0, 0.0, 0.0, 0.0);
                this.svm.setParameter("gamma", gamma);
                final double candidateC = cs.learnC(this.svm, crossvalidation, new OptimizationStrategy.ByFScore(), cperformance, 0, 1, 2, 3);
                // %n terminates the line so successive progress reports do not run together.
                System.out.printf(Locale.US, "Performance of c=%.4f and gamma=%.4f is F1=%.2f%n", candidateC, gamma, cperformance.getF());
                // Keep the first result, afterwards only strictly better ones.
                if (gammaPerformance == null || pcomp.compare(cperformance, gammaPerformance) > 0) {
                    gammaPerformance = cperformance;
                    bestC = candidateC;
                    bestGamma = gamma;
                }
            }
            this.svm.setParameter("c", bestC);
            this.svm.setParameter("gamma", bestGamma);
            System.out.println("Final Performance on c selection: " + gammaPerformance.toString() + " with c is " + bestC);
            // Hold-out check: train on folds 0-3, evaluate on fold 4.
            this.svm.setSamples(crossvalidation.getFoldsArray(0, 1, 2, 3));
            final SvmModel model = this.svm.train();
            final Sample[] eval = crossvalidation.getFoldsArray(4);
            final double[] results = model.predict(eval);
            final PredictionPerformance performance = Svm.evaluateClassificationPerformance(eval, results);
            System.out.println("Performance on evaluation set: " + performance);
            // Final model uses every compound.
            this.svm.setSamples(this.compounds.toArray(new Sample[this.compounds.size()]));
            return this.svm.train();
        }
        catch (InterruptedException e) {
            // Restore the interrupt status so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Adds a training compound described by a full boolean fingerprint.
     *
     * @param inchiKey       InChI key of the compound
     * @param inchi          InChI string of the compound
     * @param classification class label of the compound
     * @param fingerprint    boolean fingerprint; converted to an index array via
     *                       {@code TrainCompoundClasses.transformFingerprintToIntegerArray}
     */
    public void addCompound(String inchiKey, String inchi, int classification, boolean[] fingerprint) {
        final InChI structure = new InChI(inchiKey, inchi);
        final short[] setPositions = TrainCompoundClasses.transformFingerprintToIntegerArray(usedIndizes, fingerprint);
        final Compound added = new Compound(structure, setPositions, classification);
        compounds.add(added);
        makeFeatures(added);
    }

    /**
     * Sets the {@code "bias"} parameter of the underlying SVM instance.
     *
     * @param bias value forwarded unchanged to the SVM
     */
    public void setBias(double bias) {
        svm.setParameter("bias", bias);
    }

    /**
     * Sets the {@code "epsilon"} parameter of the underlying SVM instance.
     *
     * @param epsilon value forwarded unchanged to the SVM
     */
    public void setEpsilon(double epsilon) {
        svm.setParameter("epsilon", epsilon);
    }

    /**
     * Switches the debug mode of the underlying SVM instance on or off.
     *
     * @param flag {@code true} to enable debug mode, {@code false} to disable it
     */
    public void setDebugMode(boolean flag) {
        if (!flag) {
            svm.disableDebugMode();
            return;
        }
        svm.enableDebugMode();
    }

    /**
     * Adds a training compound described by the indices of its set fingerprint
     * positions. Positions that are not among the used indices are dropped.
     *
     * @param inchiKey       InChI key of the compound
     * @param inchi          InChI string of the compound
     * @param classification class label of the compound
     * @param fingerprints   indices of the fingerprint positions that are set
     */
    public void addCompound(String inchiKey, String inchi, int classification, short[] fingerprints) {
        final InChI structure = new InChI(inchiKey, inchi);
        final short[] usedPositions = filter(usedIndizes, usedIndizesMap, fingerprints);
        final Compound added = new Compound(structure, usedPositions, classification);
        compounds.add(added);
        makeFeatures(added);
    }

    /**
     * Adds a training compound described by per-position probabilities. The
     * compound is stored without a fingerprint index array.
     *
     * @param inchiKey       InChI key of the compound
     * @param inchi          InChI string of the compound
     * @param classification class label of the compound
     * @param probabilities  one value per fingerprint position; only the used
     *                       positions are read
     */
    public void addCompound(String inchiKey, String inchi, int classification, double[] probabilities) {
        final Compound added = new Compound(new InChI(inchiKey, inchi), null, classification);
        final double[] usedProbabilities = filter(usedIndizes, probabilities);
        makeFeatures(added, usedProbabilities);
        compounds.add(added);
    }

    /**
     * Builds the feature list of a compound from real-valued scores and attaches
     * it to the compound.
     *
     * <p>Each score is stored under its fingerprint index; the additional
     * formula-derived features follow at indices starting at the prototype size.
     *
     * @param c             compound that receives the feature list
     * @param probabilities scores aligned with {@code usedIndizes}, i.e.
     *                      {@code probabilities[k]} belongs to index {@code usedIndizes[k]}
     */
    private void makeFeatures(Compound c, double[] probabilities) {
        final double[] extras = TrainCompoundClasses.getAdditionalFingerprintsFor(c.formula);
        final int offset = prototype.size();
        final FeatureList features = svm.newFeatureList(offset + extras.length);
        int position = 0;
        for (int index : usedIndizes) {
            features.add(index, probabilities[position]);
            ++position;
        }
        for (int e = 0; e < extras.length; ++e) {
            features.add(offset + e, extras[e]);
        }
        c.setFeatureList(features);
    }

    /**
     * Keeps only those fingerprint positions that are marked as used.
     *
     * <p>The relative order of the retained positions is preserved, and repeated
     * positions are retained as often as they occur. Negative positions and
     * positions beyond the lookup table are dropped.
     *
     * @param usedIndizes    used indices; not read by the implementation, kept so
     *                       existing call sites stay valid
     * @param usedIndizesMap lookup table, {@code true} at every used index
     * @param fingerprints   indices of the fingerprint positions that are set
     * @return a new array with the retained positions, or an array of the same
     *         length as the input when nothing was dropped
     */
    private static short[] filter(int[] usedIndizes, boolean[] usedIndizesMap, short[] fingerprints) {
        // At most every input entry survives. Sizing by the number of used indices
        // (as before) overflowed when the input repeated a used position.
        final short[] list = new short[fingerprints.length];
        int j = 0;
        for (short fingerprint : fingerprints) {
            if (fingerprint < 0 || fingerprint >= usedIndizesMap.length || !usedIndizesMap[fingerprint]) continue;
            list[j++] = fingerprint;
        }
        if (j < list.length) {
            return Arrays.copyOf(list, j);
        }
        return list;
    }

    /**
     * Picks the values at the used indices out of a full value vector.
     *
     * @param usedIndizes  indices to read
     * @param fingerprints value vector indexed by fingerprint position
     * @return a new array whose {@code k}-th entry is
     *         {@code fingerprints[usedIndizes[k]]}
     */
    private static double[] filter(int[] usedIndizes, double[] fingerprints) {
        final double[] selected = new double[usedIndizes.length];
        int target = 0;
        for (int index : usedIndizes) {
            selected[target++] = fingerprints[index];
        }
        return selected;
    }

    /**
     * Builds the feature list of a compound from its fingerprint index array and
     * attaches it to the compound.
     *
     * <p>Each fingerprint index is copied from the prototype feature list; the
     * additional formula-derived features follow at indices starting at the
     * prototype size.
     *
     * @param c compound that receives the feature list; its fingerprint array
     *          must not be {@code null}
     */
    private void makeFeatures(Compound c) {
        final double[] extras = TrainCompoundClasses.getAdditionalFingerprintsFor(c.formula);
        final int offset = prototype.size();
        final short[] positions = c.fingerprints;
        final FeatureList features = svm.newFeatureList(positions.length + extras.length);
        for (int p = 0; p < positions.length; ++p) {
            features.addFeatureFrom(prototype, positions[p]);
        }
        int e = 0;
        while (e < extras.length) {
            features.add(offset + e, extras[e]);
            ++e;
        }
        c.setFeatureList(features);
    }

    /**
     * Initiates shutdown of the internal thread pool. Does not wait for it to finish.
     *
     * @throws IOException declared by {@link Closeable}; not thrown by this implementation
     */
    @Override
    public void close() throws IOException {
        service.shutdown();
    }

    /**
     * Removes compounds whose fingerprint index array equals that of an earlier
     * compound, keeping the first occurrence, and prints a summary.
     *
     * <p>Compounds without a fingerprint array (those added through a probability
     * vector) cannot be compared this way and are always kept. Previously they all
     * compared equal to each other, so every one but the first was discarded.
     */
    public void removeDuplicateEntries() {
        final HashMap<UniqueFingerprint, Integer> seen = new HashMap<UniqueFingerprint, Integer>();
        final ArrayList<Compound> compoundList = new ArrayList<Compound>(this.compounds.size());
        int posC = 0;
        int negC = 0;
        int missc = 0;
        for (Compound c : this.compounds) {
            if (c.fingerprints == null) {
                // No fingerprint to compare: never treat as a duplicate.
                compoundList.add(c);
                continue;
            }
            final UniqueFingerprint fc = new UniqueFingerprint(c);
            final Integer knownLabel = seen.get(fc);
            if (knownLabel == null) {
                // First compound with this fingerprint.
                seen.put(fc, Integer.valueOf(c.classification));
                compoundList.add(c);
                continue;
            }
            if (c.classification > 0) {
                ++posC;
            } else {
                ++negC;
            }
            // Same fingerprint but a different label than the retained compound.
            if (c.classification != knownLabel.intValue()) {
                ++missc;
            }
        }
        this.compounds = compoundList;
        System.out.println("Remove " + posC + " positive and " + negC + " negative duplicate entries. " + missc + " entries could not be distinguished between positive and negative sets.");
    }

    /**
     * A single training or evaluation sample: a compound with its label, fold
     * number and feature list.
     */
    private static final class Compound
    implements Sample {
        /** 2D InChI key of the compound; returned as the sample's group. */
        private final String inchikey;
        /** Molecular formula extracted from the InChI. */
        private final MolecularFormula formula;
        /** Indices of set fingerprint positions; {@code null} for compounds built from scores. */
        private final short[] fingerprints;
        /** Class label, stored as a byte. */
        private byte classification;
        /** Batch (fold) number, stored as a byte. */
        private byte fold;
        /** Feature list assigned after construction. */
        private FeatureList featureList;

        /**
         * @param inchi          structure providing the 2D key and the formula
         * @param fingerprint    indices of set fingerprint positions, may be {@code null}
         * @param classification class label; narrowed to a byte
         */
        public Compound(InChI inchi, short[] fingerprint, int classification) {
            inchikey = inchi.key2D();
            formula = inchi.extractFormula();
            fingerprints = fingerprint;
            this.classification = (byte) classification;
        }

        /** @return the feature list assigned to this compound, or {@code null} if none was set */
        @Override
        public FeatureList getFeatureList() {
            return featureList;
        }

        /** @param list feature list to attach to this compound */
        @Override
        public void setFeatureList(FeatureList list) {
            featureList = list;
        }

        /** @return the class label as a double */
        @Override
        public double getLabel() {
            return classification;
        }

        /** @param label new class label; narrowed to a byte */
        @Override
        public void setLabel(double label) {
            classification = (byte) label;
        }

        /** @return the batch (fold) number */
        @Override
        public int getBatchNum() {
            return fold;
        }

        /** @param num new batch (fold) number; narrowed to a byte */
        @Override
        public void setBatchNum(int num) {
            fold = (byte) num;
        }

        /** @return the 2D InChI key of the compound */
        @Override
        public String getGroup() {
            return inchikey;
        }
    }

    /**
     * Hash-map key that identifies a compound by the content of its fingerprint
     * index array. The hash code is computed once at construction.
     */
    private static final class UniqueFingerprint {
        private final Compound compound;
        private final int hashCode;

        /** @param compound compound whose fingerprint array defines this key */
        public UniqueFingerprint(Compound compound) {
            this.compound = compound;
            this.hashCode = Arrays.hashCode(compound.fingerprints);
        }

        /**
         * Two keys are equal when they are of the same class and their compounds'
         * fingerprint arrays have equal content.
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (o == null || o.getClass() != getClass()) {
                return false;
            }
            final UniqueFingerprint other = (UniqueFingerprint) o;
            // Cheap hash comparison first; fall through to the element-wise check.
            return hashCode == other.hashCode
                && Arrays.equals(compound.fingerprints, other.compound.fingerprints);
        }

        /** @return the hash of the fingerprint array computed at construction */
        @Override
        public int hashCode() {
            return hashCode;
        }
    }
}

