/*
 * Decompiled with CFR 0.152.
 */
package de.unijena.bioinf.fingerid;

import de.unijena.bioinf.ChemistryBase.fp.AbstractFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.ArrayFingerprint;
import de.unijena.bioinf.ChemistryBase.fp.CdkFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.FingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.MaskedFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.ms.ft.FTree;
import de.unijena.bioinf.ChemistryBase.ms.utils.SimpleSpectrum;
import de.unijena.bioinf.babelms.binary.SpectrumBinaryReader;
import de.unijena.bioinf.babelms.binary.SpectrumBinaryWriter;
import de.unijena.bioinf.babelms.json.FTJsonReader;
import de.unijena.bioinf.babelms.json.FTJsonWriter;
import de.unijena.bioinf.fingerid.KernelCentering;
import de.unijena.bioinf.fingerid.KernelMatrix;
import de.unijena.bioinf.fingerid.NormalizationType;
import de.unijena.bioinf.fingerid.Predictor;
import de.unijena.bioinf.fingerid.utils.FingerIDProperties;
import de.unijena.bioinf.iokr.IOKRModel;
import de.unijena.bioinf.iokr.IOKRPredict;
import de.unijena.bioinf.iokr.IOKRScore;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 * A trained CSI:FingerID model.
 *
 * <p>Holds the training compounds (names, InChIs, fragmentation trees, spectra, precursor m/z and
 * fingerprints), the kernel centering/normalization data, one {@link Predictor} per enabled
 * fingerprint index and, optionally, an {@link IOKRModel}.
 *
 * <p>The model is persisted in a custom binary layout written by {@link #dump(OutputStream)} and
 * read by {@link #load(InputStream, boolean)}; the layout is tagged with {@link #VERSION_ID}.
 */
public class TrainedCSIFingerId
implements Serializable {
    /** Version tag written first by {@link #dump(OutputStream)}; loading rejects any other value. */
    public static final short VERSION_ID = 6;
    protected String version;
    protected String[] names;
    protected String[] inchis;
    protected FTree[] trainingTrees;
    protected SimpleSpectrum[] trainingSpectra;
    /** Enabled fingerprint indices; searched with binary search, so expected to be sorted. */
    protected int[] fingerprintIndizes;
    protected ArrayFingerprint[] trainingFingerprints;
    protected KernelMatrix[] kernels;
    protected double[] kernelNormalizationVector;
    /** One predictor per entry of {@link #fingerprintIndizes}, in the same order. */
    protected Predictor[] predictors;
    protected double[] precursors;
    /**
     * Layout of the fingerprint-version section: 1 = default CDK version (nothing stored),
     * 2 = a CDK bitset identifier follows as a long.
     */
    protected int fingerprintVersion = 2;
    protected MaskedFingerprintVersion maskedFingerprintVersion;
    protected IOKRModel iokr;
    /** Sentinel text appended after the compressed tree JSON to detect misalignment. */
    private static final String CHECK = "end of tree block.";
    /** Byte written around the IOKR section to detect misalignment. */
    private static final byte MAGIC_NUMBER_CHECK_ALIGNMENT = -82;
    /** Byte written directly before the compressed tree block. */
    private static final byte TREE_BLOCK_CONTROL_BYTE = 42;
    /** NOTE(review): not referenced in this class; presumably an IOKR flag used elsewhere. */
    protected static final byte IOKR_IS_ENABLED = 1;
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /**
     * Creates an empty model whose arrays are sized for later filling.
     *
     * @param version           masked fingerprint version; its allowed indices become the
     *                          predictor indices
     * @param numberOftrainData number of training compounds
     * @param numberOfKernels   number of kernel matrices
     */
    public TrainedCSIFingerId(MaskedFingerprintVersion version, int numberOftrainData, int numberOfKernels) {
        this.names = new String[numberOftrainData];
        this.inchis = new String[numberOftrainData];
        this.trainingSpectra = new SimpleSpectrum[numberOftrainData];
        this.trainingTrees = new FTree[numberOftrainData];
        this.kernelNormalizationVector = new double[numberOftrainData];
        this.version = FingerIDProperties.fingeridVersion();
        this.maskedFingerprintVersion = version;
        this.fingerprintIndizes = version.allowedIndizes();
        this.predictors = new Predictor[this.fingerprintIndizes.length];
        this.trainingFingerprints = new ArrayFingerprint[numberOftrainData];
        this.kernels = new KernelMatrix[numberOfKernels];
        this.precursors = new double[numberOftrainData];
    }

    /** Used by the static {@code load} factories, which fill every field from a stream. */
    private TrainedCSIFingerId() {
    }

    /**
     * Returns the predictor for an absolute fingerprint index.
     *
     * @return the predictor, or {@code null} if {@code realIndex} is not an enabled index
     */
    public Predictor getPredictorByRealIndex(int realIndex) {
        int k = Arrays.binarySearch(this.fingerprintIndizes, realIndex);
        if (k < 0) {
            return null;
        }
        return this.predictors[k];
    }

    /** Returns the prediction performance of every predictor, in predictor order. */
    public PredictionPerformance[] getPredictionPerformances() {
        PredictionPerformance[] perf = new PredictionPerformance[this.numberOfFingerprints()];
        for (int k = 0; k < perf.length; ++k) {
            perf[k] = this.predictors[k].getPerformance();
        }
        return perf;
    }

    public IOKRModel getIOKRModel() {
        return this.iokr;
    }

    public double[] getPrecursorMz() {
        return this.precursors;
    }

    public int numberOfTrainingData() {
        return this.trainingTrees.length;
    }

    public int numberOfFingerprints() {
        return this.fingerprintIndizes.length;
    }

    public String getVersion() {
        return this.version;
    }

    public int numberOfKernels() {
        return this.kernels.length;
    }

    public MaskedFingerprintVersion getMaskedFingerprintVersion() {
        return this.maskedFingerprintVersion;
    }

    public int getFingerprintVersion() {
        return this.fingerprintVersion;
    }

    /**
     * Writes this model in the binary layout read by {@link #load(InputStream, boolean)}.
     * The given stream is not closed.
     *
     * @throws IOException if writing fails
     */
    public void dump(OutputStream outputStream) throws IOException {
        DataOutputStream out = new DataOutputStream(outputStream);
        out.writeShort(VERSION_ID);
        out.writeUTF(this.version);
        out.writeInt(this.trainingTrees.length);
        out.writeInt(this.fingerprintIndizes.length);
        out.writeInt(this.kernels.length);
        this.writeStringArray(out, this.names);
        this.writeStringArray(out, this.inchis);
        this.writeTrees(out, this.trainingTrees);
        this.writeDoubleArray(out, this.precursors);
        SpectrumBinaryWriter.writeSpectra((DataOutputStream)out, (SimpleSpectrum[])this.trainingSpectra);
        this.writeIntArray(out, this.fingerprintIndizes);
        // fingerprints are stored sparsely as the list of set indices
        for (ArrayFingerprint fingerprint : this.trainingFingerprints) {
            short[] setIndices = fingerprint.toIndizesArray();
            out.writeInt(setIndices.length);
            this.writeShortArray(out, setIndices);
        }
        NormalizationType normalization = NormalizationType.ENABLED_NORMALIZATION_TYPE;
        out.writeUTF(normalization.name());
        for (KernelMatrix kernel : this.kernels) {
            out.writeUTF(kernel.kernelName);
            out.writeDouble(kernel.weight);
            if (normalization.isNormalizingBeforeCentering()) {
                this.writeDoubleArray(out, kernel.getNormalizations());
            }
            out.writeDouble(kernel.kernelCentering.average);
            this.writeDoubleArray(out, kernel.kernelCentering.averages);
            if (normalization.isNormalizingAfterCentering()) {
                this.writeDoubleArray(out, kernel.kernelCentering.diagonal);
            }
        }
        this.writeDoubleArray(out, this.kernelNormalizationVector);
        for (Predictor p : this.predictors) {
            out.writeInt(p.realIndex);
            out.writeDouble(p.rho);
            out.writeDouble(p.probA);
            out.writeDouble(p.probB);
            // the reader relies on coefficients and support vectors having equal length
            out.writeInt(p.coefficients.length);
            this.writeDoubleArray(out, p.coefficients);
            this.writeIntArray(out, p.supportVectors);
            out.writeDouble(p.tp);
            out.writeDouble(p.fp);
            out.writeDouble(p.tn);
            out.writeDouble(p.fn);
        }
        out.writeInt(this.fingerprintVersion);
        out.writeLong(((CdkFingerprintVersion)this.maskedFingerprintVersion.getMaskedFingerprintVersion()).getBitsetIdentifier());
        this.writeIOKR(out);
    }

    /** Writes the IOKR section: {@code 0} if there is no model, otherwise the versioned payload. */
    private void writeIOKR(DataOutputStream out) throws IOException {
        if (this.iokr == null) {
            out.writeInt(0);
            return;
        }
        out.writeInt(this.iokr.getBinaryCompatibilityVersion());
        this.writeMagicNumber(out);
        int[] indizes = this.iokr.getFingerprintVersion().allowedIndizes();
        out.writeInt(indizes.length);
        this.writeIntArray(out, indizes);
        this.writeDoubleArray(out, this.iokr.getPredictor().getModel());
        this.writeDoubleArray(out, this.iokr.getPredictor().getMklWeights());
        out.writeDouble(this.iokr.getScorer().getGamma());
        this.writeDoubleArray(out, this.iokr.getScorer().getRowMean());
        this.writeMagicNumber(out);
    }

    /**
     * Reads the IOKR section written by {@link #writeIOKR(DataOutputStream)}.
     *
     * @return the model, or {@code null} if none was stored
     * @throws IOException on an unknown IOKR binary version or corrupted data
     */
    private IOKRModel readIOKR(DataInputStream in) throws IOException {
        int iokrVersion = in.readInt();
        if (iokrVersion == 0) {
            return null;
        }
        if (iokrVersion == 1) {
            return this.parseStandardIOKRVersion(in);
        }
        throw new IOException("Unknown IOKR version " + iokrVersion);
    }

    /**
     * Parses IOKR binary version 1. Requires the training data, kernels and masked fingerprint
     * version of this model to be loaded already.
     */
    private IOKRModel parseStandardIOKRVersion(DataInputStream in) throws IOException {
        this.checkMagicNumber(in);
        int[] indizes = new int[in.readInt()];
        this.readIntArray(in, indizes);
        MaskedFingerprintVersion mfv = buildMask((FingerprintVersion)this.getMaskedFingerprintVersion().getMaskedFingerprintVersion(), indizes);
        int numberOfCompounds = this.numberOfTrainingData();
        // the model is stored as a packed triangular matrix over the training compounds
        double[] triangularMatrix = new double[numberOfCompounds * (numberOfCompounds + 1) / 2];
        this.readDoubleArray(in, triangularMatrix);
        double[] weights = new double[this.numberOfKernels()];
        this.readDoubleArray(in, weights);
        double gamma = in.readDouble();
        double[] rowMean = new double[numberOfCompounds];
        this.readDoubleArray(in, rowMean);
        this.checkMagicNumber(in);
        // one (re-masked) fingerprint per training compound -- not per fingerprint bit
        ArrayFingerprint[] fps = new ArrayFingerprint[numberOfCompounds];
        for (int i = 0; i < numberOfCompounds; ++i) {
            fps[i] = (ArrayFingerprint)mfv.mask((AbstractFingerprint)this.trainingFingerprints[i]);
        }
        return new IOKRModel(mfv, new IOKRPredict(triangularMatrix, weights), new IOKRScore(fps, rowMean, gamma));
    }

    /** Builds a mask over {@code base} in which exactly {@code enabledIndices} are enabled. */
    private static MaskedFingerprintVersion buildMask(FingerprintVersion base, int[] enabledIndices) {
        MaskedFingerprintVersion.Builder builder = MaskedFingerprintVersion.buildMaskFor(base).disableAll();
        for (int index : enabledIndices) {
            builder.enable(index);
        }
        return builder.toMask();
    }

    private void checkMagicNumber(DataInputStream in) throws IOException {
        if (in.readByte() != MAGIC_NUMBER_CHECK_ALIGNMENT) {
            throw new IOException("Missalignment happened! Critical error in data serialization of TrainedCSIFingerID!");
        }
    }

    private void writeMagicNumber(DataOutputStream out) throws IOException {
        out.writeByte(MAGIC_NUMBER_CHECK_ALIGNMENT);
    }

    /**
     * Writes all trees as one GZIP-compressed UTF-8 block of concatenated JSON documents followed
     * by {@link #CHECK}. The char length of each JSON document is written up front so the reader
     * can split the stream again.
     */
    private void writeTrees(DataOutputStream out, FTree[] trainingTrees) throws IOException {
        ByteArrayOutputStream compressed = new ByteArrayOutputStream(1024 * trainingTrees.length);
        int[] lengths = new int[trainingTrees.length];
        FTJsonWriter jsonWriter = new FTJsonWriter();
        try (OutputStreamWriter writer = new OutputStreamWriter(new GZIPOutputStream(compressed), UTF_8)) {
            for (int k = 0; k < trainingTrees.length; ++k) {
                String json = jsonWriter.treeToJsonString(trainingTrees[k]);
                writer.write(json);
                lengths[k] = json.length();
            }
            writer.write(CHECK);
        }
        byte[] bytes = compressed.toByteArray();
        out.writeInt(trainingTrees.length);
        for (int length : lengths) {
            out.writeInt(length);
        }
        out.writeInt(bytes.length);
        out.writeByte(TREE_BLOCK_CONTROL_BYTE);
        out.write(bytes);
    }

    /**
     * Reads the tree block written by {@link #writeTrees(DataOutputStream, FTree[])}. A tree whose
     * JSON cannot be parsed is reported on stderr and left {@code null}.
     *
     * @throws IOException if the block is truncated or misaligned
     */
    private FTree[] readTrees(DataInputStream in) throws IOException {
        int numberOfTrees = in.readInt();
        if (numberOfTrees < 0) {
            throw new IOException("Alignment problem. File seems to be corrupted");
        }
        int[] lengths = new int[numberOfTrees];
        // the buffer must also hold the trailing CHECK sentinel, even when there are no trees
        int maxLength = CHECK.length();
        for (int i = 0; i < numberOfTrees; ++i) {
            lengths[i] = in.readInt();
            if (lengths[i] < 0) {
                throw new IOException("Alignment problem. File seems to be corrupted");
            }
            maxLength = Math.max(maxLength, lengths[i]);
        }
        int byteSize = in.readInt();
        if (byteSize < 0 || in.readByte() != TREE_BLOCK_CONTROL_BYTE) {
            throw new IOException("Alignment problem. File seems to be corrupted");
        }
        byte[] bytes = new byte[byteSize];
        in.readFully(bytes);
        FTree[] trees = new FTree[numberOfTrees];
        char[] buffer = new char[maxLength];
        FTJsonReader jsonReader = new FTJsonReader();
        try (InputStreamReader reader = new InputStreamReader(new GZIPInputStream(new ByteArrayInputStream(bytes)), UTF_8)) {
            for (int k = 0; k < numberOfTrees; ++k) {
                readCharsFully(reader, buffer, lengths[k]);
                String json = new String(buffer, 0, lengths[k]);
                try {
                    trees[k] = jsonReader.treeFromJsonString(json, null);
                }
                catch (RuntimeException e) {
                    // best effort: keep loading the remaining trees
                    e.printStackTrace();
                    System.err.println(json);
                }
            }
            readCharsFully(reader, buffer, CHECK.length());
            if (!CHECK.equals(new String(buffer, 0, CHECK.length()))) {
                throw new IOException("Alignment problem. File seems to be corrupted");
            }
        }
        return trees;
    }

    /**
     * Reads exactly {@code length} chars into {@code buffer}; {@code Reader.read} may return
     * fewer chars than requested, so it is called repeatedly.
     *
     * @throws IOException if the stream ends early
     */
    private static void readCharsFully(InputStreamReader reader, char[] buffer, int length) throws IOException {
        int offset = 0;
        while (offset < length) {
            int read = reader.read(buffer, offset, length - offset);
            if (read < 0) {
                throw new IOException("Alignment problem. File seems to be corrupted");
            }
            offset += read;
        }
    }

    /** Writes a string as int byte length followed by its UTF-8 bytes. */
    private void writeString(DataOutputStream out, String s) throws IOException {
        byte[] ba = s.getBytes(UTF_8);
        out.writeInt(ba.length);
        out.write(ba);
    }

    /** Reads a string written by {@link #writeString(DataOutputStream, String)}. */
    private String readString(DataInputStream in) throws IOException {
        byte[] bytes = new byte[in.readInt()];
        in.readFully(bytes);
        return new String(bytes, UTF_8);
    }

    private void writeStringArray(DataOutputStream out, String[] names) throws IOException {
        out.writeInt(names.length);
        for (String name : names) {
            out.writeUTF(name);
        }
    }

    /** Fills {@code names} from the stream; the stored length must equal {@code names.length}. */
    private void readStringArray(DataInputStream in, String[] names) throws IOException {
        int size = in.readInt();
        if (names.length != size) {
            throw new IndexOutOfBoundsException("Wrong ary length: " + size + " expected but " + names.length + " given");
        }
        for (int k = 0; k < names.length; ++k) {
            names[k] = in.readUTF();
        }
    }

    /** Loads a model from {@code file} without the IOKR section. */
    public static TrainedCSIFingerId load(File file) throws IOException {
        return TrainedCSIFingerId.load(file, false);
    }

    /**
     * Loads a model from {@code file}.
     *
     * @param loadIOKR whether the trailing IOKR section is read as well
     */
    public static TrainedCSIFingerId load(File file, boolean loadIOKR) throws IOException {
        try (BufferedInputStream stream = new BufferedInputStream(new FileInputStream(file))) {
            return TrainedCSIFingerId.load(stream, loadIOKR);
        }
    }

    /** Loads a model from {@code in} without the IOKR section. The stream is not closed. */
    public static TrainedCSIFingerId load(InputStream in) throws IOException {
        return TrainedCSIFingerId.load(in, false);
    }

    /**
     * Loads a model from {@code in}. The stream is not closed.
     *
     * @param loadIOKR whether the trailing IOKR section is read as well
     */
    public static TrainedCSIFingerId load(InputStream in, boolean loadIOKR) throws IOException {
        TrainedCSIFingerId instance = new TrainedCSIFingerId();
        instance.loadFromStream(in, loadIOKR);
        return instance;
    }

    /**
     * Reads every field in the order written by {@link #dump(OutputStream)}.
     *
     * @throws IOException      on a version mismatch or corrupted data
     * @throws RuntimeException if the stored normalization type or fingerprint version is not
     *                          supported by this build
     */
    private void loadFromStream(InputStream inputStream, boolean loadIOKR) throws IOException {
        DataInputStream in = new DataInputStream(inputStream);
        short versionId = in.readShort();
        if (versionId != VERSION_ID) {
            throw new IOException("Incompatible version ids. Binary format is in v" + versionId + " but script expects version v" + VERSION_ID);
        }
        this.version = in.readUTF();
        int numberOfCompounds = in.readInt();
        int numberOfPredictors = in.readInt();
        int numberOfKernels = in.readInt();
        this.names = new String[numberOfCompounds];
        this.inchis = new String[numberOfCompounds];
        this.precursors = new double[numberOfCompounds];
        this.kernelNormalizationVector = new double[numberOfCompounds];
        this.trainingFingerprints = new ArrayFingerprint[numberOfCompounds];
        this.fingerprintIndizes = new int[numberOfPredictors];
        this.predictors = new Predictor[numberOfPredictors];
        this.kernels = new KernelMatrix[numberOfKernels];

        this.readStringArray(in, this.names);
        this.readStringArray(in, this.inchis);
        FTree[] trees = this.readTrees(in);
        if (trees.length != numberOfCompounds) {
            throw new IOException("Alignment problem. File seems to be corrupted");
        }
        this.trainingTrees = trees;
        this.readDoubleArray(in, this.precursors);
        this.trainingSpectra = SpectrumBinaryReader.readSpectra((DataInputStream)in);
        this.readIntArray(in, this.fingerprintIndizes);
        // fingerprints can only be built once the fingerprint version (stored later) is known
        short[][] setIndices = new short[numberOfCompounds][];
        for (int i = 0; i < numberOfCompounds; ++i) {
            setIndices[i] = new short[in.readInt()];
            this.readShortArray(in, setIndices[i]);
        }
        NormalizationType normalizationType = NormalizationType.valueOf(in.readUTF());
        if (normalizationType != NormalizationType.ENABLED_NORMALIZATION_TYPE) {
            throw new RuntimeException("Current version of CSI:FingerID is not compatible with fingerid.data: Normalization Type differs! Current: " + NormalizationType.ENABLED_NORMALIZATION_TYPE.toString() + ", in fingerid.data: " + normalizationType.toString());
        }
        for (int i = 0; i < numberOfKernels; ++i) {
            this.kernels[i] = this.readKernel(in, normalizationType, numberOfCompounds);
        }
        this.readDoubleArray(in, this.kernelNormalizationVector);
        for (int i = 0; i < numberOfPredictors; ++i) {
            this.predictors[i] = this.readPredictor(in);
        }
        this.fingerprintVersion = in.readInt();
        FingerprintVersion baseVersion;
        if (this.fingerprintVersion == 1) {
            baseVersion = (FingerprintVersion)CdkFingerprintVersion.getDefault();
        } else if (this.fingerprintVersion == 2) {
            long identifier = in.readLong();
            baseVersion = (FingerprintVersion)CdkFingerprintVersion.getFromBitsetIdentifier((long)identifier);
        } else {
            throw new RuntimeException("Unknown fingerprint version: " + this.fingerprintVersion);
        }
        this.maskedFingerprintVersion = buildMask(baseVersion, this.fingerprintIndizes);
        for (int i = 0; i < numberOfCompounds; ++i) {
            this.trainingFingerprints[i] = new ArrayFingerprint(this.maskedFingerprintVersion.getMaskedFingerprintVersion(), setIndices[i]);
            setIndices[i] = null; // release the raw indices early; the model can be large
        }
        if (loadIOKR) {
            this.iokr = this.readIOKR(in);
        }
    }

    /**
     * Reads one kernel entry. Normalizations stay zero-filled and the diagonal stays {@code null}
     * when the normalization type does not store them.
     */
    private KernelMatrix readKernel(DataInputStream in, NormalizationType normalizationType, int size) throws IOException {
        String name = in.readUTF();
        double weight = in.readDouble();
        double[] normalizations = new double[size];
        if (normalizationType.isNormalizingBeforeCentering()) {
            this.readDoubleArray(in, normalizations);
        }
        double average = in.readDouble();
        double[] averages = new double[size];
        this.readDoubleArray(in, averages);
        double[] diagonal = null;
        if (normalizationType.isNormalizingAfterCentering()) {
            diagonal = new double[size];
            this.readDoubleArray(in, diagonal);
        }
        return new KernelMatrix(name, new KernelCentering(average, averages, diagonal), normalizations, weight);
    }

    /** Reads one predictor entry, including its tp/fp/tn/fn statistics. */
    private Predictor readPredictor(DataInputStream in) throws IOException {
        int realIndex = in.readInt();
        double rho = in.readDouble();
        double probA = in.readDouble();
        double probB = in.readDouble();
        int numberOfSupportVectors = in.readInt();
        double[] coefficients = new double[numberOfSupportVectors];
        int[] supportVectors = new int[numberOfSupportVectors];
        this.readDoubleArray(in, coefficients);
        this.readIntArray(in, supportVectors);
        double tp = in.readDouble();
        double fp = in.readDouble();
        double tn = in.readDouble();
        double fn = in.readDouble();
        Predictor predictor = new Predictor(realIndex, rho, probA, probB, coefficients, supportVectors);
        predictor.setStatistics(tp, fp, tn, fn);
        return predictor;
    }

    /** Reads {@code ary.length} shorts; no length prefix is read. */
    private void readShortArray(DataInputStream in, short[] ary) throws IOException {
        for (int i = 0; i < ary.length; ++i) {
            ary[i] = in.readShort();
        }
    }

    /** Writes the shorts of {@code ary}; no length prefix is written. */
    private void writeShortArray(DataOutputStream out, short[] ary) throws IOException {
        for (short value : ary) {
            out.writeShort(value);
        }
    }

    /** Fills {@code ary}; the stored length prefix must equal {@code ary.length}. */
    private void readDoubleArray(DataInputStream in, double[] ary) throws IOException {
        int size = in.readInt();
        if (ary.length != size) {
            throw new IndexOutOfBoundsException("Wrong ary length: " + size + " expected but " + ary.length + " given");
        }
        for (int i = 0; i < ary.length; ++i) {
            ary[i] = in.readDouble();
        }
    }

    /**
     * Reads a bit-packed boolean array written by {@link #writeBooleanArray(DataOutputStream,
     * boolean[])}: 8 values per byte (LSB first) plus one trailing byte for the remainder.
     */
    private void readBooleanArray(DataInputStream in, boolean[] ary) throws IOException {
        int k;
        byte vector;
        int size = in.readInt();
        if (ary.length != size) {
            throw new IndexOutOfBoundsException("Wrong ary length: " + size + " expected but " + ary.length + " given");
        }
        int n = ary.length / 8;
        byte[] MASKS = new byte[8];
        for (int i = 0; i < 8; ++i) {
            MASKS[i] = (byte)(1 << i);
        }
        int c = 0;
        for (int j = 0; j < n; ++j) {
            vector = in.readByte();
            for (k = 0; k < 8; ++k) {
                ary[c++] = (MASKS[k] & vector) == MASKS[k];
            }
        }
        int rest = ary.length % 8;
        // the writer always emits a trailing byte, even when rest == 0
        vector = in.readByte();
        for (k = 0; k < rest; ++k) {
            ary[c++] = (MASKS[k] & vector) == MASKS[k];
        }
    }

    /** Fills {@code ary}; the stored length prefix must equal {@code ary.length}. */
    private void readIntArray(DataInputStream in, int[] ary) throws IOException {
        int size = in.readInt();
        if (ary.length != size) {
            throw new IndexOutOfBoundsException("Wrong ary length: " + size + " expected but " + ary.length + " given");
        }
        for (int i = 0; i < ary.length; ++i) {
            ary[i] = in.readInt();
        }
    }

    /** Java serialization hook: the model is written in its own binary layout. */
    private void writeObject(ObjectOutputStream outputStream) throws IOException {
        this.dump(outputStream);
    }

    /**
     * Java serialization hook (the signature required by {@link Serializable}). Reads the IOKR
     * section too, since {@link #writeObject(ObjectOutputStream)} always writes it.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        this.loadFromStream(in, true);
    }

    /** Not a serialization hook (extra parameter); kept for compatibility. */
    private void readObject(ObjectInputStream in, boolean loadIOKR) throws IOException, ClassNotFoundException {
        this.loadFromStream(in, loadIOKR);
    }

    private void readObjectNoData() throws ObjectStreamException {
    }

    private void writeIntArray(DataOutputStream out, int[] ary) throws IOException {
        out.writeInt(ary.length);
        for (int value : ary) {
            out.writeInt(value);
        }
    }

    private void writeDoubleArray(DataOutputStream out, double[] ary) throws IOException {
        out.writeInt(ary.length);
        for (double value : ary) {
            out.writeDouble(value);
        }
    }

    /** Writes a length prefix, 8 booleans per byte (LSB first) and one trailing remainder byte. */
    private void writeBooleanArray(DataOutputStream out, boolean[] ary) throws IOException {
        out.writeInt(ary.length);
        int n = ary.length / 8;
        int c = 0;
        for (int j = 0; j < n; ++j) {
            int vec = 0;
            int x = 1;
            for (int k = 0; k < 8; ++k) {
                if (ary[c++]) {
                    vec = (byte)(vec | x);
                }
                x = (byte)(x << 1);
            }
            out.writeByte(vec);
        }
        int vec = 0;
        int x = 1;
        int rest = ary.length % 8;
        for (int i = 0; i < rest; ++i) {
            if (ary[c++]) {
                vec = (byte)(vec | x);
            }
            x = (byte)(x << 1);
        }
        out.write(vec);
    }

    public String[] getNames() {
        return this.names;
    }

    public String[] getInchis() {
        return this.inchis;
    }

    public FTree[] getTrainingTrees() {
        return this.trainingTrees;
    }

    public SimpleSpectrum[] getTrainingSpectra() {
        return this.trainingSpectra;
    }

    public int[] getFingerprintIndizes() {
        return this.fingerprintIndizes;
    }

    public ArrayFingerprint[] getTrainingFingerprints() {
        return this.trainingFingerprints;
    }

    public KernelMatrix[] getKernels() {
        return this.kernels;
    }

    public Predictor[] getPredictors() {
        return this.predictors;
    }

    public double[] getKernelNormalizationVector() {
        return this.kernelNormalizationVector;
    }

    public IOKRModel getIokr() {
        return this.iokr;
    }

    public void setIokr(IOKRModel iokr) {
        this.iokr = iokr;
    }
}

