/*
 * Decompiled with CFR 0.152.
 */
package de.unijena.bioinf.fingerid;

import com.google.common.base.Function;
import de.unijena.bioinf.ChemistryBase.chem.InChI;
import de.unijena.bioinf.ChemistryBase.chem.MolecularFormula;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.fingerid.OptimizationStrategy;
import de.unijena.bioinf.fingerid.TrainCompoundClasses;
import de.unijena.bioinf.fingerid.svm.CSelection;
import de.unijena.bioinf.fingerid.svm.Crossvalidation;
import de.unijena.bioinf.fingerid.svm.FeatureList;
import de.unijena.bioinf.fingerid.svm.Sample;
import de.unijena.bioinf.fingerid.svm.Svm;
import de.unijena.bioinf.fingerid.svm.SvmInstance;
import de.unijena.bioinf.fingerid.svm.SvmModel;
import de.unijena.bioinf.fingerid.svm.linear.LinearSvmWithWeights;
import gnu.trove.list.array.TDoubleArrayList;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Trains a linear SVM that predicts a single compound class from a subset of molecular
 * fingerprint positions, augmented with formula-derived features from
 * {@code TrainCompoundClasses.getAdditionalFingerprintsFor}.
 *
 * <p>Compounds are added either with a binary fingerprint (list of set bit indices or a boolean
 * array) or with per-position probabilities (e.g. Platt-scaled predictions). Only the positions
 * listed in {@code usedIndizes} become features.
 *
 * <p>Each instance owns a fixed thread pool used for C selection in {@link #train()};
 * call {@link #close()} to shut it down.
 */
public class TrainCompoundClassesLinear
implements Closeable {
    // Thread pool handed to CSelection in train(); shut down in close().
    private final ExecutorService service = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
    // Sorted defensive copy of the fingerprint positions that are used as features.
    private final int[] usedIndizes;
    // usedIndizesMap[i] is true iff position i occurs in usedIndizes.
    private final boolean[] usedIndizesMap;
    private final SvmInstance svm;
    private List<Compound> compounds;
    // Feature list holding 1.0 at every position 0..F-1; binary features are copied from it.
    private FeatureList prototype;
    // Weight for compounds without a binary fingerprint (read lazily by the weight function).
    private double cmodifierValue = 0.0;
    private static final Double ONE = 1.0;

    /**
     * Returns the number of compounds currently in the training set.
     */
    public int numberOfCompounds() {
        return this.compounds.size();
    }

    /**
     * Rescales Platt probabilities in place with a log-linear mapping.
     *
     * <p>For p &lt; 0.5 the result is {@code b + a * log(max(p, minimalPlatt))}, chosen so that
     * {@code minimalPlatt} maps to 0 and 0.5 maps to 0.5. Values p &gt;= 0.5 are mapped
     * symmetrically as {@code 1 - f(1 - p)}. Results are clamped to [0, 1]; 0 stays 0.
     *
     * @param platt        probabilities to transform, modified in place
     * @param minimalPlatt lower cut-off, must lie strictly between 0 and 0.5
     * @throws IllegalArgumentException if {@code minimalPlatt} is not in (0, 0.5); the mapping
     *                                  would otherwise divide by zero or produce NaN
     */
    public static void transform(double[] platt, double minimalPlatt) {
        if (!(minimalPlatt > 0.0 && minimalPlatt < 0.5)) {
            throw new IllegalArgumentException("minimalPlatt must be in (0, 0.5), got " + minimalPlatt);
        }
        double minimum = Math.log(minimalPlatt);
        double maximum = Math.log(0.5);
        double a = 0.5 / (maximum - minimum);
        double b = -a * minimum;
        for (int k = 0; k < platt.length; ++k) {
            platt[k] = platt[k] >= 0.5 ? Math.min(1.0, 1.0 - (b + a * Math.log(Math.max(1.0 - platt[k], minimalPlatt)))) : Math.max(0.0, b + a * Math.log(Math.max(platt[k], minimalPlatt)));
        }
    }

    /**
     * Augments the training set with synthetic probability-based compounds.
     *
     * <p>A preliminary model is trained on all current compounds with c = 1e-4 (this value stays
     * set on the SVM afterwards). Positives are ranked by ascending decision value and negatives
     * by descending decision value. Up to {@code count} of each (only compounds that carry a
     * binary fingerprint) are re-added with per-position probabilities sampled from the
     * cross-validation predictions in {@code cvFile}, optionally rescaled by
     * {@link #transform(double[], double)}. Finally {@link #setCModifier(double)} is applied.
     *
     * @param count           maximal number of positives and of negatives to add
     * @param usedIndizes     fingerprint positions matching the columns of {@code cvFile};
     *                        expected in ascending order
     * @param linearTransform minimal Platt value for the transform, or &lt;= 0 to disable it
     * @param cvFile          tab-separated cross-validation predictions
     * @param cmodifier       weight for probability-based compounds
     * @throws IOException if {@code cvFile} cannot be read or is malformed
     */
    public void autoSamplePlattAndTransform(int count, int[] usedIndizes, double linearTransform, File cvFile, double cmodifier) throws IOException {
        // The sampled vector must be long enough for both the sampler's positions and ours,
        // because addCompound(double[]) reads it at this.usedIndizes.
        int length = Math.max(TrainCompoundClassesLinear.maxIndex(usedIndizes) + 1, this.prototype.size());
        this.svm.setParameter("c", 1.0E-4);
        Crossvalidation<Compound> crossvalidation = new Crossvalidation<Compound>(this.compounds, 5, true, false);
        Sample[] trainSamples = crossvalidation.getFoldsArray(0, 1, 2, 3, 4);
        this.svm.setSamples(trainSamples);
        SvmModel model = this.svm.train();
        ArrayList<SampleWithScore> positiveSamples = new ArrayList<SampleWithScore>();
        ArrayList<SampleWithScore> negativeSamples = new ArrayList<SampleWithScore>();
        for (Sample s : trainSamples) {
            double score = model.computeDecisionValue(s);
            if (s.getLabel() > 0.0) {
                positiveSamples.add(new SampleWithScore(s, score));
            } else {
                negativeSamples.add(new SampleWithScore(s, score));
            }
        }
        // Lowest-scoring positives and highest-scoring negatives come first.
        Collections.sort(positiveSamples);
        Collections.sort(negativeSamples, Collections.reverseOrder());
        Sampler sampler = new Sampler(usedIndizes);
        sampler.setLinearTransform(linearTransform);
        sampler.readCrossvalidation(cvFile);
        this.addSampledCompounds(positiveSamples, count, sampler, length);
        this.addSampledCompounds(negativeSamples, count, sampler, length);
        this.setCModifier(cmodifier);
    }

    /**
     * Re-adds up to {@code count} leading entries of {@code ranked} with sampled probabilities.
     * Entries without a binary fingerprint are skipped because there is nothing to sample from.
     */
    private void addSampledCompounds(List<SampleWithScore> ranked, int count, Sampler sampler, int length) {
        int added = 0;
        for (SampleWithScore scored : ranked) {
            if (added >= count) {
                break;
            }
            Compound c = (Compound)scored.sample;
            if (c.fingerprints == null) {
                continue;
            }
            this.addCompound(c.inchikey, c.inchi.in2D, (int)c.classification, sampler.sample(c.fingerprints, length));
            ++added;
        }
    }

    /**
     * Returns the largest value in {@code indices}.
     *
     * @throws IllegalArgumentException if {@code indices} is empty
     */
    private static int maxIndex(int[] indices) {
        if (indices.length == 0) {
            throw new IllegalArgumentException("usedIndizes must not be empty");
        }
        int max = indices[0];
        for (int index : indices) {
            if (index > max) {
                max = index;
            }
        }
        return max;
    }

    /**
     * Sets the sample weight of compounds that were added with probabilities (no binary
     * fingerprint). If {@code cmodifier != 1.0}, a weight function is installed on the SVM that
     * returns the current modifier for such compounds and 1.0 otherwise. The modifier is read
     * each time the function is applied, so later calls also update an installed function.
     */
    public void setCModifier(double cmodifier) {
        this.cmodifierValue = cmodifier;
        if (cmodifier != 1.0) {
            ((LinearSvmWithWeights)this.svm).setWeightFunction(new Function<Sample, Double>(){

                public Double apply(Sample input) {
                    Compound c = (Compound)input;
                    if (c.fingerprints == null) {
                        return TrainCompoundClassesLinear.this.cmodifierValue;
                    }
                    return ONE;
                }
            });
        }
    }

    /**
     * Creates a trainer using the given fingerprint positions as features.
     *
     * @param usedIndizes fingerprint positions to use; copied and sorted, the argument itself is
     *                    left untouched
     * @throws IllegalArgumentException if {@code usedIndizes} is empty
     */
    public TrainCompoundClassesLinear(int[] usedIndizes) {
        this.usedIndizes = (int[])usedIndizes.clone();
        // Sort the defensive copy (not the caller's array): feature lists are filled in this order.
        Arrays.sort(this.usedIndizes);
        this.compounds = new ArrayList<Compound>();
        this.svm = new LinearSvmWithWeights();
        this.svm.disableDebugMode();
        this.setParameters();
        int F = TrainCompoundClassesLinear.maxIndex(this.usedIndizes) + 1;
        this.prototype = this.svm.newFeatureList(F);
        for (int i = 0; i < F; ++i) {
            this.prototype.add(i, 1.0);
        }
        this.usedIndizesMap = new boolean[F];
        for (int index : this.usedIndizes) {
            this.usedIndizesMap[index] = true;
        }
    }

    /**
     * Returns whether {@code model} predicts {@code classification} for a compound given by its
     * set fingerprint bits. The compound is not added to the training set.
     */
    public boolean evalCompound(SvmModel model, String inchiKey, String inchi, int classification, short[] fingerprints) {
        Compound c = new Compound(new InChI(inchiKey, inchi), fingerprints, classification);
        this.makeFeatures(c);
        double classif = model.predict(c);
        return (int)classif == classification;
    }

    /**
     * Returns whether {@code model} predicts {@code classification} for a compound given by
     * per-position probabilities (indexed by fingerprint position). Not added to the training set.
     */
    public boolean evalCompound(SvmModel model, String inchiKey, String inchi, int classification, double[] plattScores) {
        Compound c = new Compound(new InChI(inchiKey, inchi), null, classification);
        this.makeFeatures(c, TrainCompoundClassesLinear.filter(this.usedIndizes, plattScores));
        double classif = model.predict(c);
        return (int)classif == classification;
    }

    /**
     * Writes {@code model} to {@code file} using the underlying SVM implementation.
     */
    public void writeModel(File file, SvmModel model) throws IOException {
        this.svm.writeModel(file, model);
    }

    // Intentionally empty hook called from the constructor.
    private void setParameters() {
    }

    /**
     * Selects C by cross-validation, reports performance and trains the final model.
     *
     * <p>C is chosen from a fixed grid by {@code CSelection} on folds 0-3, optimizing F-score. A
     * model trained on folds 0-3 is then evaluated on fold 4 (printed to stdout), and the
     * returned model is trained on all compounds with the selected C.
     *
     * @throws RuntimeException wrapping an {@link InterruptedException}; the thread's interrupt
     *                          flag is restored in that case
     */
    public SvmModel train() {
        Crossvalidation<Compound> crossvalidation = new Crossvalidation<Compound>(this.compounds, 5, true, false);
        CSelection cs = new CSelection(this.service, Runtime.getRuntime().availableProcessors());
        cs.setPossibleCs(Math.pow(2.0, -15.0), Math.pow(2.0, -10.0), Math.pow(2.0, -9.0), Math.pow(2.0, -8.0), Math.pow(2.0, -7.0), Math.pow(2.0, -6.0), Math.pow(2.0, -5.0), Math.pow(2.0, -4.0), Math.pow(2.0, -3.0), Math.pow(2.0, -2.0), Math.pow(2.0, -1.0), 1.0, Math.pow(2.0, 1.0), Math.pow(2.0, 2.0), Math.pow(2.0, 3.0), Math.pow(2.0, 4.0), Math.pow(2.0, 5.0), Math.pow(2.0, 6.0));
        try {
            PredictionPerformance cperformance = new PredictionPerformance(0.0, 0.0, 0.0, 0.0);
            double bestC = cs.learnC(this.svm, crossvalidation, new OptimizationStrategy.ByFScore(), cperformance, 0, 1, 2, 3);
            this.svm.setParameter("c", bestC);
            System.out.println("Final Performance on c selection: " + cperformance.toString() + " with c is " + bestC);
            this.svm.setSamples(crossvalidation.getFoldsArray(0, 1, 2, 3));
            SvmModel model = this.svm.train();
            Sample[] eval = crossvalidation.getFoldsArray(4);
            double[] results = model.predict(eval);
            PredictionPerformance performance = Svm.evaluateClassificationPerformance(eval, results);
            System.out.println("Performance on evaluation set: " + performance);
            this.svm.setSamples(this.compounds.toArray(new Sample[this.compounds.size()]));
            return this.svm.train();
        }
        catch (InterruptedException e) {
            // Keep the interruption visible to callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Adds a compound given by a boolean fingerprint (indexed by fingerprint position).
     */
    public void addCompound(String inchiKey, String inchi, int classification, boolean[] fingerprint) {
        Compound c = new Compound(new InChI(inchiKey, inchi), TrainCompoundClasses.transformFingerprintToIntegerArray(this.usedIndizes, fingerprint), classification);
        this.compounds.add(c);
        this.makeFeatures(c);
    }

    /** Sets the SVM's {@code "bias"} parameter. */
    public void setBias(double bias) {
        this.svm.setParameter("bias", bias);
    }

    /** Sets the SVM's {@code "epsilon"} parameter. */
    public void setEpsilon(double epsilon) {
        this.svm.setParameter("epsilon", epsilon);
    }

    /** Enables or disables the SVM's debug mode. */
    public void setDebugMode(boolean flag) {
        if (flag) {
            this.svm.enableDebugMode();
        } else {
            this.svm.disableDebugMode();
        }
    }

    /**
     * Adds a compound given by the indices of its set fingerprint bits; indices outside the used
     * positions are dropped.
     */
    public void addCompound(String inchiKey, String inchi, int classification, short[] fingerprints) {
        Compound c = new Compound(new InChI(inchiKey, inchi), TrainCompoundClassesLinear.filter(this.usedIndizes, this.usedIndizesMap, fingerprints), classification);
        this.compounds.add(c);
        this.makeFeatures(c);
    }

    /**
     * Adds a compound given by per-position probabilities (indexed by fingerprint position). Such
     * compounds have no binary fingerprint and are weighted via {@link #setCModifier(double)}.
     */
    public void addCompound(String inchiKey, String inchi, int classification, double[] probabilities) {
        Compound c = new Compound(new InChI(inchiKey, inchi), null, classification);
        probabilities = TrainCompoundClassesLinear.filter(this.usedIndizes, probabilities);
        this.makeFeatures(c, probabilities);
        this.compounds.add(c);
    }

    /**
     * Builds the feature list for a probability-based compound: {@code probabilities[k]} at
     * position {@code usedIndizes[k]}, followed by the formula features at offsets F, F+1, ...
     */
    private void makeFeatures(Compound c, double[] probabilities) {
        int k;
        double[] additionalFeatures = TrainCompoundClasses.getAdditionalFingerprintsFor(c.formula);
        int F = this.prototype.size();
        FeatureList fl = this.svm.newFeatureList(F + additionalFeatures.length);
        for (k = 0; k < this.usedIndizes.length; ++k) {
            fl.add(this.usedIndizes[k], probabilities[k]);
        }
        for (k = 0; k < additionalFeatures.length; ++k) {
            fl.add(F + k, additionalFeatures[k]);
        }
        c.setFeatureList(fl);
    }

    /**
     * Keeps only those set-bit indices that are used positions, preserving their order.
     */
    private static short[] filter(int[] usedIndizes, boolean[] usedIndizesMap, short[] fingerprints) {
        // Sized by the input so that duplicate indices cannot overflow the buffer.
        short[] list = new short[fingerprints.length];
        int j = 0;
        for (short fingerprint : fingerprints) {
            if (fingerprint < 0 || fingerprint >= usedIndizesMap.length || !usedIndizesMap[fingerprint]) continue;
            list[j++] = fingerprint;
        }
        if (j < list.length) {
            return Arrays.copyOf(list, j);
        }
        return list;
    }

    /**
     * Gathers {@code fingerprints[usedIndizes[k]]} into a dense array of length
     * {@code usedIndizes.length}.
     */
    private static double[] filter(int[] usedIndizes, double[] fingerprints) {
        double[] platts = new double[usedIndizes.length];
        for (int k = 0; k < platts.length; ++k) {
            platts[k] = fingerprints[usedIndizes[k]];
        }
        return platts;
    }

    /**
     * Builds the feature list for a binary compound: the prototype value at each set bit,
     * followed by the formula features at offsets F, F+1, ...
     */
    private void makeFeatures(Compound c) {
        double[] additionalFeatures = TrainCompoundClasses.getAdditionalFingerprintsFor(c.formula);
        int F = this.prototype.size();
        FeatureList fl = this.svm.newFeatureList(c.fingerprints.length + additionalFeatures.length);
        for (short index : c.fingerprints) {
            fl.addFeatureFrom(this.prototype, index);
        }
        for (int k = 0; k < additionalFeatures.length; ++k) {
            fl.add(F + k, additionalFeatures[k]);
        }
        c.setFeatureList(fl);
    }

    /** Shuts down the internal thread pool. */
    @Override
    public void close() throws IOException {
        this.service.shutdown();
    }

    /**
     * Removes compounds whose binary fingerprint equals that of an earlier compound, keeping the
     * first occurrence, and prints how many were removed and how many had a conflicting label.
     * Probability-based compounds (no binary fingerprint) are always kept: they cannot be
     * compared by fingerprint.
     */
    public void removeDuplicateEntries() {
        HashMap<UniqueFingerprint, Byte> firstLabel = new HashMap<UniqueFingerprint, Byte>();
        ArrayList<Compound> compoundList = new ArrayList<Compound>();
        int posC = 0;
        int negC = 0;
        int missc = 0;
        for (Compound c : this.compounds) {
            if (c.fingerprints == null) {
                compoundList.add(c);
                continue;
            }
            UniqueFingerprint fc = new UniqueFingerprint(c);
            Byte label = firstLabel.get(fc);
            if (label == null) {
                firstLabel.put(fc, c.classification);
                compoundList.add(c);
                continue;
            }
            if (c.classification > 0) {
                ++posC;
            } else {
                ++negC;
            }
            if (c.classification != label.byteValue()) {
                ++missc;
            }
        }
        this.compounds = compoundList;
        System.out.println("Remove " + posC + " positive and " + negC + " negative duplicate entries. " + missc + " entries could not be distinguished between positive and negative sets.");
    }

    /**
     * Samples synthetic per-position probabilities from empirical distributions of
     * cross-validation predictions, split by whether the true bit was set.
     *
     * <p>Distribution {@code i} belongs to position {@code usedIndizes[i]}.
     * {@link #sample(short[], int)} additionally requires {@code usedIndizes} in ascending order.
     */
    public static class Sampler {
        private int[] usedIndizes;
        private TDoubleArrayList[] positives;
        private TDoubleArrayList[] negatives;
        private boolean dirty = true;
        private double transform = 0.0;
        // One generator per sampler instead of one per drawn value.
        private final Random random = new Random();

        public Sampler(int[] usedIndizes) {
            this.usedIndizes = (int[])usedIndizes.clone();
            this.positives = new TDoubleArrayList[usedIndizes.length];
            this.negatives = new TDoubleArrayList[usedIndizes.length];
            for (int i = 0; i < usedIndizes.length; ++i) {
                this.positives[i] = new TDoubleArrayList();
                this.negatives[i] = new TDoubleArrayList();
            }
        }

        /**
         * Sets the minimal Platt value used to rescale sampled vectors; values &lt;= 0 disable
         * the rescaling.
         */
        public void setLinearTransform(double minPlatt) {
            this.transform = minPlatt;
        }

        /**
         * Samples a probability vector of the given length for a compound given by its set bit
         * indices (ascending). Used positions whose bit is set draw from the positive
         * distribution, all other used positions from the negative one; unused positions are 0.
         */
        public double[] sample(short[] fingerprint, int length) {
            this.refresh();
            double[] platts = new double[length];
            int i = 0;
            for (short index : fingerprint) {
                // used positions below the current set bit are absent
                while (i < this.usedIndizes.length && this.usedIndizes[i] < index) {
                    platts[this.usedIndizes[i]] = this.sample(this.negatives[i]);
                    ++i;
                }
                if (i < this.usedIndizes.length && this.usedIndizes[i] == index) {
                    platts[this.usedIndizes[i]] = this.sample(this.positives[i]);
                    ++i;
                }
            }
            // used positions after the last set bit are absent as well
            for (; i < this.usedIndizes.length; ++i) {
                platts[this.usedIndizes[i]] = this.sample(this.negatives[i]);
            }
            if (this.transform > 0.0) {
                TrainCompoundClassesLinear.transform(platts, this.transform);
            }
            return platts;
        }

        /**
         * Samples a probability vector for a compound given as a boolean array indexed by
         * fingerprint position; the result has the same length.
         */
        public double[] sample(boolean[] fingerprint) {
            this.refresh();
            double[] platts = new double[fingerprint.length];
            for (int j = 0; j < this.usedIndizes.length; ++j) {
                int i = this.usedIndizes[j];
                platts[i] = fingerprint[i] ? this.sample(this.positives[j]) : this.sample(this.negatives[j]);
            }
            if (this.transform > 0.0) {
                TrainCompoundClassesLinear.transform(platts, this.transform);
            }
            return platts;
        }

        /**
         * Draws one value: the geometric mean of a random window of k consecutive entries of the
         * sorted distribution. k starts at 2 and grows geometrically (stop probability 0.6 per
         * step), capped at the distribution size. Callers must call {@link #refresh()} first.
         *
         * @throws IllegalStateException if the distribution is empty
         */
        private double sample(TDoubleArrayList distribution) {
            int size = distribution.size();
            if (size == 0) {
                throw new IllegalStateException("cannot sample from an empty distribution");
            }
            int k = 1;
            do {
                ++k;
            } while (!(this.random.nextDouble() < 0.6));
            k = Math.min(k, size);
            // start may range over 0..size-k inclusive, so every window is reachable
            int start = this.random.nextInt(size - k + 1);
            double platt = 1.0;
            for (int j = start; j < start + k; ++j) {
                platt *= distribution.getQuick(j);
            }
            return Math.pow(platt, 1.0 / (double)k);
        }

        /** Sorts all distributions if values were added since the last sort. */
        protected void refresh() {
            if (this.dirty) {
                for (TDoubleArrayList distr : this.positives) {
                    distr.sort();
                }
                for (TDoubleArrayList distr : this.negatives) {
                    distr.sort();
                }
                this.dirty = false;
            }
        }

        /**
         * Adds one observation: for each used position, the prediction {@code platts[pos]} goes
         * to the positive or negative distribution depending on {@code real[pos]}.
         */
        public void add(boolean[] real, double[] platts) {
            this.dirty = true;
            for (int i = 0; i < this.usedIndizes.length; ++i) {
                if (real[this.usedIndizes[i]]) {
                    this.positives[i].add(platts[this.usedIndizes[i]]);
                    continue;
                }
                this.negatives[i].add(platts[this.usedIndizes[i]]);
            }
        }

        /**
         * Reads cross-validation predictions from a tab-separated UTF-8 file. Column 3 holds the
         * true fingerprint as a '0'/'1' string; column {@code 4 + i} holds the prediction for
         * distribution {@code i}. Blank lines are skipped.
         *
         * @throws IOException if the file cannot be read or a line is malformed
         */
        public void readCrossvalidation(File crossvalidationFile) throws IOException {
            this.dirty = true;
            try (BufferedReader br = Files.newBufferedReader(crossvalidationFile.toPath(), Charset.forName("UTF-8"))) {
                String line;
                while ((line = br.readLine()) != null) {
                    if (line.isEmpty()) {
                        continue;
                    }
                    String[] columns = line.split("\t");
                    if (columns.length < 4) {
                        throw new IOException("Malformed cross-validation line (expected at least 4 columns): " + line);
                    }
                    String fingerprint = columns[3];
                    int predictions = columns.length - 4;
                    if (predictions > this.positives.length || predictions > fingerprint.length()) {
                        throw new IOException("Cross-validation line has " + predictions + " predictions, but only " + this.positives.length + " positions are used and the fingerprint has " + fingerprint.length() + " bits");
                    }
                    for (int k = 4; k < columns.length; ++k) {
                        int index = k - 4;
                        boolean real = fingerprint.charAt(index) == '1';
                        double prediction = Double.parseDouble(columns[k]);
                        if (real) {
                            this.positives[index].add(prediction);
                            continue;
                        }
                        this.negatives[index].add(prediction);
                    }
                }
            }
        }
    }

    /**
     * Training sample: a compound with its label, optional binary fingerprint (null for
     * probability-based compounds), molecular formula and fold assignment.
     */
    private static final class Compound
    implements Sample {
        private FeatureList featureList;
        private final short[] fingerprints;
        private final String inchikey;
        private byte classification;
        private final MolecularFormula formula;
        private final InChI inchi;
        private byte fold;

        public Compound(InChI inchi, short[] fingerprint, int classification) {
            this.inchi = inchi;
            this.inchikey = inchi.key2D();
            this.formula = inchi.extractFormula();
            this.fingerprints = fingerprint;
            this.classification = (byte)classification;
        }

        @Override
        public FeatureList getFeatureList() {
            return this.featureList;
        }

        @Override
        public void setFeatureList(FeatureList list) {
            this.featureList = list;
        }

        @Override
        public double getLabel() {
            return this.classification;
        }

        @Override
        public void setLabel(double label) {
            this.classification = (byte)label;
        }

        @Override
        public int getBatchNum() {
            return this.fold;
        }

        @Override
        public void setBatchNum(int num) {
            this.fold = (byte)num;
        }

        // Grouping key for cross-validation: the 2D InChI key.
        @Override
        public String getGroup() {
            return this.inchikey;
        }
    }

    /** Hash key comparing compounds by the content of their binary fingerprint. */
    private static final class UniqueFingerprint {
        private final Compound compound;
        private final int hashCode;

        public UniqueFingerprint(Compound compound) {
            this.compound = compound;
            this.hashCode = Arrays.hashCode(compound.fingerprints);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            UniqueFingerprint that = (UniqueFingerprint)o;
            if (this.hashCode != that.hashCode) {
                return false;
            }
            return Arrays.equals(this.compound.fingerprints, that.compound.fingerprints);
        }

        @Override
        public int hashCode() {
            return this.hashCode;
        }
    }

    /** A sample paired with its decision value; ordered ascending by that value. */
    private static class SampleWithScore
    implements Comparable<SampleWithScore> {
        private Sample sample;
        private double score;

        public SampleWithScore(Sample sample, double score) {
            this.sample = sample;
            this.score = score;
        }

        @Override
        public int compareTo(SampleWithScore o) {
            return Double.compare(this.score, o.score);
        }
    }
}

