package de.unijena.bioinf.fingerid;

import de.unijena.bioinf.ChemistryBase.fp.CdkFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.MaskedFingerprintVersion;
import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.ChemistryBase.ms.Peak;
import de.unijena.bioinf.ChemistryBase.ms.ft.FTree;
import de.unijena.bioinf.ChemistryBase.ms.utils.SimpleSpectrum;
import de.unijena.bioinf.babelms.binary.SpectrumBinaryReader;
import de.unijena.bioinf.babelms.binary.SpectrumBinaryWriter;
import de.unijena.bioinf.babelms.json.FTJsonReader;
import de.unijena.bioinf.babelms.json.FTJsonWriter;
import de.unijena.bioinf.fingerid.utils.PROPERTIES;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:de/unijena/bioinf/fingerid/TrainedCSIFingerId.class */
public class TrainedCSIFingerId implements Serializable {
    public static final short VERSION_ID = 5;
    protected String version;
    protected String[] names;
    protected String[] inchis;
    protected FTree[] trainingTrees;
    protected SimpleSpectrum[] trainingSpectra;
    protected int[] fingerprintIndizes;
    protected boolean[][] trainingFingerprints;
    protected KernelMatrix[] kernels;
    protected double[] kernelNormalizationVector;
    protected Predictor[] predictors;
    protected double[] precursors;
    protected int fingerprintVersion;
    protected MaskedFingerprintVersion maskedFingerprintVersion;
    private static final String CHECK = "end of tree block.";

    public TrainedCSIFingerId(MaskedFingerprintVersion maskedFingerprintVersion, int i, int i2) {
        this.fingerprintVersion = 2;
        this.names = new String[i];
        this.inchis = new String[i];
        this.trainingSpectra = new SimpleSpectrum[i];
        this.trainingTrees = new FTree[i];
        this.kernelNormalizationVector = new double[i];
        this.version = PROPERTIES.fingeridVersion();
        this.maskedFingerprintVersion = maskedFingerprintVersion;
        this.fingerprintIndizes = maskedFingerprintVersion.allowedIndizes();
        int length = this.fingerprintIndizes.length;
        this.predictors = new Predictor[length];
        this.trainingFingerprints = new boolean[i][length];
        this.kernels = new KernelMatrix[i2];
        this.precursors = new double[i];
    }

    private TrainedCSIFingerId() {
        this.fingerprintVersion = 2;
    }

    public Predictor getPredictorByRealIndex(int i) {
        int binarySearch = Arrays.binarySearch(this.fingerprintIndizes, i);
        if (binarySearch < 0) {
            return null;
        }
        return this.predictors[binarySearch];
    }

    public PredictionPerformance[] getPredictionPerformances() {
        PredictionPerformance[] predictionPerformanceArr = new PredictionPerformance[numberOfFingerprints()];
        for (int i = 0; i < predictionPerformanceArr.length; i++) {
            predictionPerformanceArr[i] = this.predictors[i].getPerformance();
        }
        return predictionPerformanceArr;
    }

    public double[] getPrecursorMz() {
        return this.precursors;
    }

    public int numberOfTrainingData() {
        return this.trainingTrees.length;
    }

    public int numberOfFingerprints() {
        return this.fingerprintIndizes.length;
    }

    public String getVersion() {
        return this.version;
    }

    public int numberOfKernels() {
        return this.kernels.length;
    }

    public MaskedFingerprintVersion getMaskedFingerprintVersion() {
        return this.maskedFingerprintVersion;
    }

    public int getFingerprintVersion() {
        return this.fingerprintVersion;
    }

    public void dump(OutputStream outputStream) throws IOException {
        DataOutputStream dataOutputStream = new DataOutputStream(outputStream);
        dataOutputStream.writeShort(5);
        dataOutputStream.writeUTF(this.version);
        dataOutputStream.writeInt(this.trainingTrees.length);
        dataOutputStream.writeInt(this.fingerprintIndizes.length);
        dataOutputStream.writeInt(this.kernels.length);
        writeStringArray(dataOutputStream, this.names);
        writeStringArray(dataOutputStream, this.inchis);
        writeTrees(dataOutputStream, this.trainingTrees);
        writeDoubleArray(dataOutputStream, this.precursors);
        SpectrumBinaryWriter.writeSpectra(dataOutputStream, this.trainingSpectra);
        writeIntArray(dataOutputStream, this.fingerprintIndizes);
        for (int i = 0; i < this.trainingFingerprints.length; i++) {
            writeBooleanArray(dataOutputStream, this.trainingFingerprints[i]);
        }
        dataOutputStream.writeUTF(NormalizationType.ENABLED_NORMALIZATION_TYPE.name());
        for (int i2 = 0; i2 < this.kernels.length; i2++) {
            dataOutputStream.writeUTF(this.kernels[i2].kernelName);
            dataOutputStream.writeDouble(this.kernels[i2].weight);
            if (NormalizationType.ENABLED_NORMALIZATION_TYPE.isNormalizingBeforeCentering()) {
                writeDoubleArray(dataOutputStream, this.kernels[i2].getNormalizations());
            }
            dataOutputStream.writeDouble(this.kernels[i2].kernelCentering.average);
            writeDoubleArray(dataOutputStream, this.kernels[i2].kernelCentering.averages);
            if (NormalizationType.ENABLED_NORMALIZATION_TYPE.isNormalizingAfterCentering()) {
                writeDoubleArray(dataOutputStream, this.kernels[i2].kernelCentering.diagonal);
            }
        }
        writeDoubleArray(dataOutputStream, this.kernelNormalizationVector);
        for (int i3 = 0; i3 < this.predictors.length; i3++) {
            Predictor predictor = this.predictors[i3];
            dataOutputStream.writeInt(predictor.realIndex);
            dataOutputStream.writeDouble(predictor.rho);
            dataOutputStream.writeDouble(predictor.probA);
            dataOutputStream.writeDouble(predictor.probB);
            dataOutputStream.writeInt(predictor.coefficients.length);
            writeDoubleArray(dataOutputStream, predictor.coefficients);
            writeIntArray(dataOutputStream, predictor.supportVectors);
            dataOutputStream.writeDouble(predictor.tp);
            dataOutputStream.writeDouble(predictor.fp);
            dataOutputStream.writeDouble(predictor.tn);
            dataOutputStream.writeDouble(predictor.fn);
        }
        dataOutputStream.writeInt(this.fingerprintVersion);
        dataOutputStream.writeLong(this.maskedFingerprintVersion.getMaskedFingerprintVersion().getBitsetIdentifier());
    }

    private void writeTrees(DataOutputStream dataOutputStream, FTree[] fTreeArr) throws IOException {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(1024 * fTreeArr.length);
        int[] iArr = new int[fTreeArr.length];
        FTJsonWriter fTJsonWriter = new FTJsonWriter();
        OutputStreamWriter outputStreamWriter = new OutputStreamWriter(new GZIPOutputStream(byteArrayOutputStream), Charset.forName("UTF-8"));
        for (int i = 0; i < fTreeArr.length; i++) {
            String treeToJsonString = fTJsonWriter.treeToJsonString(fTreeArr[i]);
            outputStreamWriter.write(treeToJsonString);
            iArr[i] = treeToJsonString.length();
        }
        outputStreamWriter.write(CHECK);
        outputStreamWriter.close();
        byte[] byteArray = byteArrayOutputStream.toByteArray();
        dataOutputStream.writeInt(fTreeArr.length);
        for (int i2 : iArr) {
            dataOutputStream.writeInt(i2);
        }
        dataOutputStream.writeInt(byteArray.length);
        dataOutputStream.writeByte(42);
        dataOutputStream.write(byteArray);
    }

    private FTree[] readTrees(DataInputStream dataInputStream) throws IOException {
        int readInt = dataInputStream.readInt();
        FTree[] fTreeArr = new FTree[readInt];
        int[] iArr = new int[readInt];
        for (int i = 0; i < iArr.length; i++) {
            iArr[i] = dataInputStream.readInt();
        }
        int readInt2 = dataInputStream.readInt();
        if (dataInputStream.readByte() != 42) {
            throw new IOException("Alignment problem. File seems to be corrupted");
        }
        byte[] bArr = new byte[readInt2];
        int i2 = 0;
        do {
            i2 += dataInputStream.read(bArr, i2, bArr.length - i2);
        } while (i2 < bArr.length);
        InputStreamReader inputStreamReader = new InputStreamReader(new GZIPInputStream(new ByteArrayInputStream(bArr)), Charset.forName("UTF-8"));
        int i3 = 0;
        for (int i4 : iArr) {
            i3 = Math.max(i3, i4);
        }
        char[] cArr = new char[i3];
        FTJsonReader fTJsonReader = new FTJsonReader();
        for (int i5 = 0; i5 < iArr.length; i5++) {
            inputStreamReader.read(cArr, 0, iArr[i5]);
            fTreeArr[i5] = fTJsonReader.treeFromJsonString(new String(cArr, 0, iArr[i5]), (URL) null);
        }
        inputStreamReader.read(cArr, 0, CHECK.length());
        if (!new String(cArr, 0, CHECK.length()).equals(CHECK)) {
            throw new IOException("Alignment problem. File seems to be corrupted");
        }
        inputStreamReader.close();
        return fTreeArr;
    }

    private void writeString(DataOutputStream dataOutputStream, String str) throws IOException {
        byte[] bytes = str.getBytes(Charset.forName("UTF-8"));
        dataOutputStream.writeInt(bytes.length);
        dataOutputStream.write(bytes);
    }

    private String readString(DataInputStream dataInputStream) throws IOException {
        byte[] bArr = new byte[dataInputStream.readInt()];
        int i = 0;
        do {
            i += dataInputStream.read(bArr, i, bArr.length - i);
        } while (i < bArr.length);
        return new String(bArr, Charset.forName("UTF-8"));
    }

    private void writeStringArray(DataOutputStream dataOutputStream, String[] strArr) throws IOException {
        dataOutputStream.writeInt(strArr.length);
        for (String str : strArr) {
            dataOutputStream.writeUTF(str);
        }
    }

    private void readStringArray(DataInputStream dataInputStream, String[] strArr) throws IOException {
        int readInt = dataInputStream.readInt();
        if (strArr.length != readInt) {
            throw new IndexOutOfBoundsException("Wrong ary length: " + readInt + " expected but " + strArr.length + " given");
        }
        for (int i = 0; i < strArr.length; i++) {
            strArr[i] = dataInputStream.readUTF();
        }
    }

    public static TrainedCSIFingerId load(File file) throws IOException {
        BufferedInputStream bufferedInputStream = new BufferedInputStream(new FileInputStream(file));
        Throwable th = null;
        try {
            TrainedCSIFingerId trainedCSIFingerId = new TrainedCSIFingerId();
            trainedCSIFingerId.loadFromStream(bufferedInputStream);
            if (bufferedInputStream != null) {
                if (0 != 0) {
                    try {
                        bufferedInputStream.close();
                    } catch (Throwable th2) {
                        th.addSuppressed(th2);
                    }
                } else {
                    bufferedInputStream.close();
                }
            }
            return trainedCSIFingerId;
        } catch (Throwable th3) {
            if (bufferedInputStream != null) {
                if (0 != 0) {
                    try {
                        bufferedInputStream.close();
                    } catch (Throwable th4) {
                        th.addSuppressed(th4);
                    }
                } else {
                    bufferedInputStream.close();
                }
            }
            throw th3;
        }
    }

    public static TrainedCSIFingerId load(InputStream inputStream) throws IOException {
        TrainedCSIFingerId trainedCSIFingerId = new TrainedCSIFingerId();
        trainedCSIFingerId.loadFromStream(inputStream);
        return trainedCSIFingerId;
    }

    private void loadFromStream(InputStream inputStream) throws IOException {
        double[] dArr;
        DataInputStream dataInputStream = new DataInputStream(inputStream);
        short readShort = dataInputStream.readShort();
        if (readShort > 5) {
            throw new IOException("Model file is using version " + ((int) readShort) + " while this program expects version 5");
        }
        this.version = dataInputStream.readUTF();
        int readInt = dataInputStream.readInt();
        this.names = new String[readInt];
        this.inchis = new String[readInt];
        this.trainingSpectra = new SimpleSpectrum[readInt];
        this.trainingTrees = new FTree[readInt];
        this.precursors = new double[readInt];
        this.kernelNormalizationVector = new double[readInt];
        int readInt2 = dataInputStream.readInt();
        this.fingerprintIndizes = new int[readInt2];
        this.predictors = new Predictor[readInt2];
        this.trainingFingerprints = new boolean[readInt][readInt2];
        this.kernels = new KernelMatrix[dataInputStream.readInt()];
        readStringArray(dataInputStream, this.names);
        readStringArray(dataInputStream, this.inchis);
        this.trainingTrees = readTrees(dataInputStream);
        if (readShort > 1) {
            readDoubleArray(dataInputStream, this.precursors);
        } else {
            for (int i = 0; i < this.trainingTrees.length; i++) {
                FTree fTree = this.trainingTrees[i];
                this.precursors[i] = ((Peak) fTree.getFragmentAnnotationOrThrow(Peak.class).get(fTree.getRoot())).getMass();
            }
        }
        this.trainingSpectra = SpectrumBinaryReader.readSpectra(dataInputStream);
        readIntArray(dataInputStream, this.fingerprintIndizes);
        for (int i2 = 0; i2 < this.trainingFingerprints.length; i2++) {
            readBooleanArray(dataInputStream, this.trainingFingerprints[i2]);
        }
        String readUTF = dataInputStream.readUTF();
        System.out.println("READ '" + readUTF + "'");
        NormalizationType valueOf = NormalizationType.valueOf(readUTF);
        if (valueOf != NormalizationType.ENABLED_NORMALIZATION_TYPE) {
            throw new RuntimeException("Current version of CSI:FingerId is not compatible with fingerid.data: Normalization Type differs! Current: " + NormalizationType.ENABLED_NORMALIZATION_TYPE.toString() + ", in fingerid.data: " + valueOf.toString());
        }
        for (int i3 = 0; i3 < this.kernels.length; i3++) {
            String readUTF2 = dataInputStream.readUTF();
            double readDouble = dataInputStream.readDouble();
            double[] dArr2 = new double[readInt];
            if (valueOf.isNormalizingBeforeCentering()) {
                readDoubleArray(dataInputStream, dArr2);
            }
            double readDouble2 = dataInputStream.readDouble();
            double[] dArr3 = new double[readInt];
            readDoubleArray(dataInputStream, dArr3);
            if (valueOf.isNormalizingAfterCentering()) {
                dArr = new double[readInt];
                readDoubleArray(dataInputStream, dArr);
            } else {
                dArr = null;
            }
            this.kernels[i3] = new KernelMatrix(readUTF2, new KernelCentering(readDouble2, dArr3, dArr), dArr2, readDouble);
        }
        readDoubleArray(dataInputStream, this.kernelNormalizationVector);
        for (int i4 = 0; i4 < this.predictors.length; i4++) {
            int readInt3 = dataInputStream.readInt();
            double readDouble3 = dataInputStream.readDouble();
            double readDouble4 = dataInputStream.readDouble();
            double readDouble5 = dataInputStream.readDouble();
            int readInt4 = dataInputStream.readInt();
            double[] dArr4 = new double[readInt4];
            int[] iArr = new int[readInt4];
            readDoubleArray(dataInputStream, dArr4);
            readIntArray(dataInputStream, iArr);
            double readDouble6 = dataInputStream.readDouble();
            double readDouble7 = dataInputStream.readDouble();
            double readDouble8 = dataInputStream.readDouble();
            double readDouble9 = dataInputStream.readDouble();
            this.predictors[i4] = new Predictor(readInt3, readDouble3, readDouble4, readDouble5, dArr4, iArr);
            this.predictors[i4].setStatistics(readDouble6, readDouble7, readDouble8, readDouble9);
        }
        this.fingerprintVersion = dataInputStream.readInt();
        if (this.fingerprintVersion == 1) {
            MaskedFingerprintVersion.Builder disableAll = MaskedFingerprintVersion.buildMaskFor(CdkFingerprintVersion.getDefault()).disableAll();
            for (int i5 : this.fingerprintIndizes) {
                disableAll.enable(i5);
            }
            this.maskedFingerprintVersion = disableAll.toMask();
            return;
        }
        if (this.fingerprintVersion != 2) {
            throw new RuntimeException("Unknown fingerprint version: " + this.fingerprintVersion);
        }
        MaskedFingerprintVersion.Builder disableAll2 = MaskedFingerprintVersion.buildMaskFor(CdkFingerprintVersion.getFromBitsetIdentifier(dataInputStream.readLong())).disableAll();
        for (int i6 : this.fingerprintIndizes) {
            disableAll2.enable(i6);
        }
        this.maskedFingerprintVersion = disableAll2.toMask();
    }

    private void readDoubleArray(DataInputStream dataInputStream, double[] dArr) throws IOException {
        int readInt = dataInputStream.readInt();
        if (dArr.length != readInt) {
            throw new IndexOutOfBoundsException("Wrong ary length: " + readInt + " expected but " + dArr.length + " given");
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = dataInputStream.readDouble();
        }
    }

    private void readBooleanArray(DataInputStream dataInputStream, boolean[] zArr) throws IOException {
        int readInt = dataInputStream.readInt();
        if (zArr.length != readInt) {
            throw new IndexOutOfBoundsException("Wrong ary length: " + readInt + " expected but " + zArr.length + " given");
        }
        int length = zArr.length / 8;
        byte[] bArr = new byte[8];
        for (int i = 0; i < 8; i++) {
            bArr[i] = (byte) (1 << i);
        }
        int i2 = 0;
        for (int i3 = 0; i3 < length; i3++) {
            byte readByte = dataInputStream.readByte();
            for (int i4 = 0; i4 < 8; i4++) {
                int i5 = i2;
                i2++;
                zArr[i5] = (bArr[i4] & readByte) == bArr[i4];
            }
        }
        int length2 = zArr.length % 8;
        byte readByte2 = dataInputStream.readByte();
        for (int i6 = 0; i6 < length2; i6++) {
            int i7 = i2;
            i2++;
            zArr[i7] = (bArr[i6] & readByte2) == bArr[i6];
        }
    }

    private void readIntArray(DataInputStream dataInputStream, int[] iArr) throws IOException {
        int readInt = dataInputStream.readInt();
        if (iArr.length != readInt) {
            throw new IndexOutOfBoundsException("Wrong ary length: " + readInt + " expected but " + iArr.length + " given");
        }
        for (int i = 0; i < iArr.length; i++) {
            iArr[i] = dataInputStream.readInt();
        }
    }

    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        dump(objectOutputStream);
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        loadFromStream(objectInputStream);
    }

    private void readObjectNoData() throws ObjectStreamException {
    }

    private void writeIntArray(DataOutputStream dataOutputStream, int[] iArr) throws IOException {
        dataOutputStream.writeInt(iArr.length);
        for (int i : iArr) {
            dataOutputStream.writeInt(i);
        }
    }

    private void writeDoubleArray(DataOutputStream dataOutputStream, double[] dArr) throws IOException {
        dataOutputStream.writeInt(dArr.length);
        for (double d : dArr) {
            dataOutputStream.writeDouble(d);
        }
    }

    private void writeBooleanArray(DataOutputStream dataOutputStream, boolean[] zArr) throws IOException {
        dataOutputStream.writeInt(zArr.length);
        int length = zArr.length / 8;
        int i = 0;
        for (int i2 = 0; i2 < length; i2++) {
            byte b = 0;
            byte b2 = 1;
            for (int i3 = 0; i3 < 8; i3++) {
                int i4 = i;
                i++;
                if (zArr[i4]) {
                    b = (byte) (b | b2);
                }
                b2 = (byte) (b2 << 1);
            }
            dataOutputStream.writeByte(b);
        }
        byte b3 = 0;
        byte b4 = 1;
        int length2 = zArr.length % 8;
        for (int i5 = 0; i5 < length2; i5++) {
            int i6 = i;
            i++;
            if (zArr[i6]) {
                b3 = (byte) (b3 | b4);
            }
            b4 = (byte) (b4 << 1);
        }
        dataOutputStream.write(b3);
    }

    public String[] getNames() {
        return this.names;
    }

    public String[] getInchis() {
        return this.inchis;
    }

    public FTree[] getTrainingTrees() {
        return this.trainingTrees;
    }

    public SimpleSpectrum[] getTrainingSpectra() {
        return this.trainingSpectra;
    }

    public int[] getFingerprintIndizes() {
        return this.fingerprintIndizes;
    }

    public boolean[][] getTrainingFingerprints() {
        return this.trainingFingerprints;
    }

    public KernelMatrix[] getKernels() {
        return this.kernels;
    }

    public Predictor[] getPredictors() {
        return this.predictors;
    }

    public double[] getKernelNormalizationVector() {
        return this.kernelNormalizationVector;
    }
}
