/*
 * Decompiled with CFR 0.152.
 */
package de.unijena.bioinf.fingerid;

import com.google.common.collect.Iterables;
import de.unijena.bioinf.ChemistryBase.fp.FingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.ProbabilityFingerprint;
import de.unijena.bioinf.ChemistryBase.ms.ft.FTree;
import de.unijena.bioinf.ChemistryBase.ms.utils.SimpleSpectrum;
import de.unijena.bioinf.fingerid.HighorderKernel;
import de.unijena.bioinf.fingerid.Kernel;
import de.unijena.bioinf.fingerid.KernelMatrix;
import de.unijena.bioinf.fingerid.Kernels;
import de.unijena.bioinf.fingerid.MatrixUtils;
import de.unijena.bioinf.fingerid.MsKernel;
import de.unijena.bioinf.fingerid.NormalizationType;
import de.unijena.bioinf.fingerid.PolynomialKernel;
import de.unijena.bioinf.fingerid.TrainedCSIFingerId;
import de.unijena.bioinf.fingerid.TreeKernel;
import de.unijena.bioinf.jjobs.BasicJJob;
import de.unijena.bioinf.jjobs.JJob;
import gnu.trove.list.array.TIntArrayList;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * Predicts fingerprint values for a query (spectrum, fragmentation tree, precursor mass) using the
 * kernels and predictors stored in a {@link TrainedCSIFingerId} model.
 *
 * <p>The query is compared against the training data with every trained kernel. The resulting
 * kernel rows are normalized/centered according to
 * {@link NormalizationType#ENABLED_NORMALIZATION_TYPE} and combined into a single weighted row
 * ({@link #computeMKL(double[][])}). That row is fed to the per-fingerprint predictors of the model.
 *
 * <p>Each instance owns a {@link Kernels} computation object. It is released via {@link #shutdown()}.
 */
public class Prediction {
    private final TrainedCSIFingerId fingerid;
    private final Kernels kernels;
    private final Kernels.PreprocessedData preprocessedData;
    /**
     * Parallel to {@code fingerid.getKernels()}: the key under which each kernel's raw row is looked
     * up. For high-order kernels this is replaced by the name of the underlying kernel.
     */
    private String[] kernelNameRefs;
    /**
     * For the i-th high-order kernel (in kernel order), the index of its underlying kernel in
     * {@code fingerid.getKernels()}.
     */
    private int[] highOrderKernelindizes;
    /** Kernel implementations resolved from the trained kernel names, parallel to the model's kernels. */
    private Kernel[] kernelTypes;

    /**
     * Creates a predictor for {@code fingerid} and preprocesses the training data for all its kernels.
     *
     * @param fingerid the trained model
     * @throws IllegalStateException if a high-order kernel refers to an underlying kernel that is not
     *     one of the model's kernels
     */
    public Prediction(TrainedCSIFingerId fingerid) {
        this.fingerid = fingerid;
        this.kernels = new Kernels(Runtime.getRuntime().availableProcessors());
        boolean initialized = false;
        try {
            this.kernelNameRefs = trainedKernelNames(fingerid);
            this.kernelTypes = Kernels.getKernelsByNames(this.kernelNameRefs);
            this.preprocessedData = this.kernels.preprocessTrainKernels(fingerid.trainingSpectra,
                    fingerid.precursors, fingerid.trainingTrees, Arrays.asList(this.kernelTypes));
            this.highOrderKernelindizes = resolveHighOrderKernels();
            initialized = true;
        } finally {
            // If construction fails, nobody can call shutdown() on this instance, so release here.
            if (!initialized) {
                this.kernels.shutdown();
            }
        }
    }

    /**
     * Creates a predictor for {@code fingerid}. Preprocessed kernel data is loaded from
     * {@code kernelCache}, and whatever the cache does not provide is recomputed.
     *
     * @param fingerid the trained model
     * @param kernelCache file previously written with the preprocessed kernel data
     * @throws IOException if the cache file cannot be opened or read
     * @throws IllegalStateException if a high-order kernel refers to an underlying kernel that is not
     *     one of the model's kernels
     */
    public Prediction(TrainedCSIFingerId fingerid, File kernelCache) throws IOException {
        this.fingerid = fingerid;
        this.kernels = new Kernels(Runtime.getRuntime().availableProcessors());
        boolean initialized = false;
        try {
            this.kernelNameRefs = trainedKernelNames(fingerid);
            this.kernelTypes = Kernels.getKernelsByNames(this.kernelNameRefs);
            try (InputStream in = new BufferedInputStream(new FileInputStream(kernelCache))) {
                this.preprocessedData = this.kernels.loadFromFileAndRecomputeTheRest(in,
                        fingerid.trainingSpectra, fingerid.precursors, fingerid.trainingTrees,
                        Arrays.asList(this.kernelTypes));
            }
            this.highOrderKernelindizes = resolveHighOrderKernels();
            initialized = true;
        } finally {
            // If construction fails, nobody can call shutdown() on this instance, so release here.
            if (!initialized) {
                this.kernels.shutdown();
            }
        }
    }

    /** Returns the names of the model's trained kernels, in model order. */
    private static String[] trainedKernelNames(TrainedCSIFingerId fingerid) {
        final KernelMatrix[] trained = fingerid.getKernels();
        final String[] names = new String[trained.length];
        for (int k = 0; k < trained.length; ++k) {
            names[k] = trained[k].getKernelName();
        }
        return names;
    }

    /**
     * Redirects every high-order kernel to its underlying kernel and collects the underlying indices.
     *
     * <p>For each high-order kernel, {@code kernelNameRefs} is rewritten to the underlying kernel's
     * name, and that kernel's index in the model is recorded.
     *
     * @return the underlying-kernel indices, in the order the high-order kernels appear
     * @throws IllegalStateException if an underlying kernel is not among the model's kernels
     */
    private int[] resolveHighOrderKernels() {
        final KernelMatrix[] trained = this.fingerid.getKernels();
        final Map<String, Integer> indexByName = new HashMap<>();
        for (int k = 0; k < trained.length; ++k) {
            indexByName.put(trained[k].getKernelName(), k);
        }
        final TIntArrayList indices = new TIntArrayList();
        for (int k = 0; k < this.kernelTypes.length; ++k) {
            if (!(this.kernelTypes[k] instanceof HighorderKernel)) {
                continue;
            }
            final String underlying = ((HighorderKernel) this.kernelTypes[k]).underlyingKernel().getName();
            final Integer index = indexByName.get(underlying);
            if (index == null) {
                throw new IllegalStateException("High-order kernel at position " + k
                        + " refers to unknown underlying kernel '" + underlying + "'");
            }
            this.kernelNameRefs[k] = underlying;
            indices.add(index.intValue());
        }
        return indices.toArray();
    }

    /** Returns the trained model this predictor uses. */
    public TrainedCSIFingerId getFingerid() {
        return this.fingerid;
    }

    /** Returns the kernel computation object owned by this predictor. */
    public Kernels getKernelComputation() {
        return this.kernels;
    }

    /** Shuts down the owned {@link Kernels} computation object. */
    public void shutdown() {
        this.kernels.shutdown();
    }

    /**
     * Loads a model from {@code fingerid} and its preprocessed kernel data from {@code kernelCache},
     * without IOKR data.
     *
     * @throws IOException if either file cannot be read
     */
    public static Prediction loadFromFiles(File fingerid, File kernelCache) throws IOException {
        return Prediction.loadFromFiles(fingerid, kernelCache, false);
    }

    /**
     * Loads a model from {@code fingerid} and its preprocessed kernel data from {@code kernelCache}.
     *
     * @param loadIOKR passed through to {@link TrainedCSIFingerId#load}
     * @throws IOException if either file cannot be read
     */
    public static Prediction loadFromFiles(File fingerid, File kernelCache, boolean loadIOKR) throws IOException {
        try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(fingerid))) {
            return new Prediction(TrainedCSIFingerId.load(in, loadIOKR), kernelCache);
        }
    }

    /** Returns an IO job that performs {@link #loadFromFile(File)} when run. */
    public static BasicJJob<Prediction> asyncLoadFromFile(final File fingerid) {
        return new BasicJJob<Prediction>(JJob.JobType.IO) {
            @Override
            protected Prediction compute() throws Exception {
                return Prediction.loadFromFile(fingerid);
            }
        };
    }

    /**
     * Loads a model from {@code fingerid} (without IOKR data) and preprocesses its kernels.
     *
     * @throws IOException if the file cannot be read
     */
    public static Prediction loadFromFile(File fingerid) throws IOException {
        return Prediction.loadFromFile(fingerid, false);
    }

    /**
     * Loads a model from {@code fingerid} and preprocesses its kernels.
     *
     * @param loadIOKR passed through to {@link TrainedCSIFingerId#load}
     * @throws IOException if the file cannot be read
     */
    public static Prediction loadFromFile(File fingerid, boolean loadIOKR) throws IOException {
        try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(fingerid))) {
            return new Prediction(TrainedCSIFingerId.load(in, loadIOKR));
        }
    }

    /**
     * Predicts a probabilistic fingerprint for the query.
     *
     * <p>The result combines the values of {@link #predictPlatts} with the model's masked
     * fingerprint version.
     */
    public ProbabilityFingerprint predictProbabilityFingerprint(SimpleSpectrum spectrum, FTree tree, double precursor) {
        final double[] platts = this.predictPlatts(spectrum, tree, precursor);
        return new ProbabilityFingerprint((FingerprintVersion) this.fingerid.maskedFingerprintVersion, platts);
    }

    /** Returns one raw predictor decision value per fingerprint for the query. */
    public double[] predictDecisionValues(SimpleSpectrum spectrum, FTree tree, double precursor) {
        final double[] mkl = this.computeMKL(spectrum, tree, precursor);
        final int n = this.fingerid.numberOfFingerprints();
        final double[] decisionValues = new double[n];
        for (int k = 0; k < n; ++k) {
            decisionValues[k] = this.fingerid.getPredictors()[k].predictValue(mkl);
        }
        return decisionValues;
    }

    /**
     * Returns one probability estimate per fingerprint from already centered/normalized kernel rows.
     *
     * @param kernels one row per trained kernel, in model order
     */
    public double[] predictPlattsFromCenteredNormalizedKernels(double[][] kernels) {
        return this.estimateProbabilities(this.computeMKL(kernels));
    }

    /** Returns one probability estimate per fingerprint for the query. */
    public double[] predictPlatts(SimpleSpectrum spectrum, FTree tree, double precursor) {
        return this.estimateProbabilities(this.computeMKL(spectrum, tree, precursor));
    }

    /** Applies every fingerprint predictor's probability estimate to the combined kernel row. */
    private double[] estimateProbabilities(double[] mkl) {
        final int n = this.fingerid.numberOfFingerprints();
        final double[] platts = new double[n];
        for (int k = 0; k < n; ++k) {
            platts[k] = this.fingerid.getPredictors()[k].estimateProbability(mkl);
        }
        return platts;
    }

    /**
     * Computes the raw kernel row of the query against the training data, one row per trained kernel.
     *
     * <p>Rows are looked up by {@code kernelNameRefs}, so a high-order kernel receives the row of its
     * underlying kernel. No centering or normalization is applied here.
     */
    public double[][] computeKernelValues(SimpleSpectrum spectrum, FTree tree, double precursor) {
        final Map<?, ?> kernelValues = this.kernels.computeKernelFor(this.preprocessedData, tree, spectrum, precursor);
        final double[][] matrix = new double[this.fingerid.kernels.length][];
        for (int k = 0; k < this.kernelNameRefs.length; ++k) {
            matrix[k] = (double[]) kernelValues.get(this.kernelNameRefs[k]);
        }
        return matrix;
    }

    /** Returns the query's norm for each trained kernel, looked up by {@code kernelNameRefs}. */
    public double[] computeKernelNorms(SimpleSpectrum spectrum, FTree tree, double precursor) {
        final Map<?, ?> normalizations = this.kernels.computeNorms(this.preprocessedData, tree, spectrum, precursor);
        final double[] norms = new double[this.fingerid.kernels.length];
        for (int j = 0; j < norms.length; ++j) {
            norms[j] = (Double) normalizations.get(this.kernelNameRefs[j]);
        }
        return norms;
    }

    /**
     * Computes the query's kernel rows and prepares them for prediction.
     *
     * <p>For each trained kernel, the raw row is first expanded by the kernel itself if it is a
     * high-order kernel. It is then normalized and/or centered according to
     * {@link NormalizationType#ENABLED_NORMALIZATION_TYPE}.
     *
     * @throws RuntimeException if the enabled normalization type is CENTER_NORMALIZE_CENTER
     *     (not implemented) or is not recognized
     */
    public double[][] computeCenteredNormalizedKernelValues(SimpleSpectrum spectrum, FTree tree, double precursor) {
        final double[][] matrix = this.computeKernelValues(spectrum, tree, precursor);
        final Map<?, ?> normalizations = this.kernels.computeNorms(this.preprocessedData, tree, spectrum, precursor);
        int k = 0;
        int highOrderIndex = 0;
        for (KernelMatrix km : this.fingerid.kernels) {
            // NOTE(review): the norm is looked up by the trained kernel's own name, whereas
            // computeKernelNorms uses kernelNameRefs (the underlying name for high-order kernels).
            // Confirm which key computeNorms provides.
            double norm = (Double) normalizations.get(km.kernelName);
            if (this.kernelTypes[k] instanceof HighorderKernel) {
                final HighorderKernel hk = (HighorderKernel) this.kernelTypes[k];
                final double[] row = matrix[k];
                final double[] target = new double[row.length];
                hk.computeRow(this.fingerid.kernels[this.highOrderKernelindizes[highOrderIndex]].getNormalizations(),
                        norm, row, 0, row.length, target);
                matrix[k] = target;
                norm = hk.computeNorm(norm);
                ++highOrderIndex;
            }
            if (NormalizationType.ENABLED_NORMALIZATION_TYPE == NormalizationType.CENTER_NORMALIZE_CENTER) {
                throw new RuntimeException("center -> normalize -> center not implemented yet.");
            }
            if (NormalizationType.ENABLED_NORMALIZATION_TYPE == NormalizationType.CENTER_NORMALIZE) {
                km.kernelCentering.applyToKernelRow(matrix[k], norm);
            } else if (NormalizationType.ENABLED_NORMALIZATION_TYPE == NormalizationType.NORMALIZE_CENTER_NORMALIZE) {
                MatrixUtils.normalizeTest(matrix[k], norm, km.normalizations);
                km.kernelCentering.applyToKernelRow(matrix[k], 1.0);
            } else if (NormalizationType.ENABLED_NORMALIZATION_TYPE == NormalizationType.NORMALIZE_CENTER) {
                MatrixUtils.normalizeTest(matrix[k], norm, km.normalizations);
                km.kernelCentering.withoutNormalizing().applyToKernelRow(matrix[k], 1.0);
            } else {
                throw new RuntimeException("Unknown normalization type");
            }
            ++k;
        }
        return matrix;
    }

    /** Computes the combined (weighted) kernel row for the query. */
    public double[] computeMKL(SimpleSpectrum spectrum, FTree tree, double precursor) {
        return this.computeMKL(this.computeCenteredNormalizedKernelValues(spectrum, tree, precursor));
    }

    /**
     * Combines per-kernel rows into one row.
     *
     * <p>Each element is the sum over kernels i of {@code values[i][j] * weight_i}, where
     * {@code weight_i} is the i-th trained kernel's weight. The weights are not divided by their sum.
     *
     * @param values one row per trained kernel, in model order
     * @return a row of length {@code fingerid.numberOfTrainingData()}
     */
    public double[] computeMKL(double[][] values) {
        final KernelMatrix[] trained = this.fingerid.getKernels();
        final double[] mkl = new double[this.fingerid.numberOfTrainingData()];
        for (int i = 0; i < values.length; ++i) {
            final double weight = trained[i].getWeight();
            final double[] kernelRow = values[i];
            for (int j = 0; j < kernelRow.length; ++j) {
                mkl[j] += kernelRow[j] * weight;
            }
        }
        return mkl;
    }

    /** Resolves a kernel implementation for each of the model's trained kernel names. */
    public Kernel[] getKernelMethods() {
        final String[] names = new String[this.fingerid.kernels.length];
        for (int k = 0; k < this.fingerid.kernels.length; ++k) {
            names[k] = this.fingerid.kernels[k].getKernelName();
        }
        return Kernels.getKernelsByNames(names);
    }

    /** Returns the kernel methods that are {@link TreeKernel}s, in model order. */
    public TreeKernel[] getTreeKernelMethods() {
        return Iterables.toArray(Iterables.filter(Arrays.asList(this.getKernelMethods()), TreeKernel.class), TreeKernel.class);
    }

    /** Returns the kernel methods that are {@link PolynomialKernel}s, in model order. */
    private PolynomialKernel[] getPolynomialKernels() {
        return Iterables.toArray(Iterables.filter(Arrays.asList(this.getKernelMethods()), PolynomialKernel.class), PolynomialKernel.class);
    }

    /** Returns the kernel methods that are {@link MsKernel}s, in model order. */
    public MsKernel[] getMsKernelMethods() {
        return Iterables.toArray(Iterables.filter(Arrays.asList(this.getKernelMethods()), MsKernel.class), MsKernel.class);
    }

    /** Returns the preprocessed training data used for kernel computation. */
    public Kernels.PreprocessedData getPreprocessedData() {
        return this.preprocessedData;
    }
}

