package de.unijena.bioinf.fingerid;

import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;

/* loaded from: input_file:de/unijena/bioinf/fingerid/Predictor.class */
/**
 * Binary SVM decision function over a precomputed kernel, together with sigmoid (Platt-style)
 * probability parameters {@code probA}/{@code probB} and confusion-matrix statistics.
 *
 * <p>Models are read from and written to a subset of the LIBSVM text model format
 * ({@code kernel_type precomputed}, label order {@code 1 -1}).
 */
public class Predictor {
    /** Index assigned to this predictor; {@link #parseModelFile(File)} derives it from the file name. */
    protected int realIndex;
    /** Bias term subtracted from the weighted kernel sum in {@link #predictValue(double[])}. */
    protected double rho;
    /** Additive parameter of the sigmoid used by {@link #estimateProbability(double[])}. */
    protected double probB;
    /** Multiplicative parameter of the sigmoid used by {@link #estimateProbability(double[])}. */
    protected double probA;
    /** Dual coefficients, parallel to {@link #supportVectors}. */
    protected double[] coefficients;
    /** 1-based indices into the kernel-value vector passed to {@link #predictValue(double[])}. */
    protected int[] supportVectors;
    protected double tp;
    protected double fp;
    protected double tn;
    protected double fn;
    protected ParameterC c;

    /**
     * Creates a predictor from explicit model parameters.
     *
     * @param realIndex      index assigned to this predictor
     * @param rho            bias term of the decision function
     * @param probA          multiplicative sigmoid parameter
     * @param probB          additive sigmoid parameter
     * @param coefficients   dual coefficients, one per support vector (stored, not copied)
     * @param supportVectors 1-based support-vector indices (stored, not copied)
     * @throws IllegalArgumentException if the two arrays have different lengths
     */
    public Predictor(int realIndex, double rho, double probA, double probB, double[] coefficients, int[] supportVectors) {
        if (coefficients.length != supportVectors.length) {
            throw new IllegalArgumentException("Numbers of support vectors and coefficients differ");
        }
        this.rho = rho;
        this.probA = probA;
        this.probB = probB;
        this.coefficients = coefficients;
        this.supportVectors = supportVectors;
        this.realIndex = realIndex;
    }

    public double getProbB() {
        return this.probB;
    }

    public void setProbB(double probB) {
        this.probB = probB;
    }

    public double getProbA() {
        return this.probA;
    }

    public void setProbA(double probA) {
        this.probA = probA;
    }

    /**
     * Classifies a sample.
     *
     * @param kernelValues kernel values, indexed by (support-vector index - 1)
     * @return {@code true} iff the decision value is strictly positive
     */
    public boolean predict(double[] kernelValues) {
        return predictValue(kernelValues) > 0.0d;
    }

    /**
     * Maps the decision value through the sigmoid defined by {@code probA}/{@code probB}.
     *
     * @param kernelValues kernel values, indexed by (support-vector index - 1)
     * @return a value in [0, 1]
     */
    public double estimateProbability(double[] kernelValues) {
        return sigmoid_predict(predictValue(kernelValues), this.probA, this.probB);
    }

    /** Copies the tp/fp/tn/fn counts from {@code predictionPerformance}. */
    public void setStatistics(PredictionPerformance predictionPerformance) {
        this.tp = predictionPerformance.getTp();
        this.fp = predictionPerformance.getFp();
        this.tn = predictionPerformance.getTn();
        this.fn = predictionPerformance.getFn();
    }

    /** Returns a new {@code PredictionPerformance} built from the stored tp/fp/tn/fn counts. */
    public PredictionPerformance getPerformance() {
        return new PredictionPerformance(this.tp, this.fp, this.tn, this.fn);
    }

    public int getRealIndex() {
        return this.realIndex;
    }

    /** Sets the tp/fp/tn/fn counts directly. */
    public void setStatistics(double tp, double fp, double tn, double fn) {
        this.tp = tp;
        this.fp = fp;
        this.tn = tn;
        this.fn = fn;
    }

    public void setParameterC(ParameterC parameterC) {
        this.c = parameterC;
    }

    public ParameterC getParameterC() {
        return this.c;
    }

    public void setRealIndex(int realIndex) {
        this.realIndex = realIndex;
    }

    public double getTp() {
        return this.tp;
    }

    public double getFp() {
        return this.fp;
    }

    public double getTn() {
        return this.tn;
    }

    public double getFn() {
        return this.fn;
    }

    /**
     * Computes the SVM decision value {@code sum_k coefficients[k] * kernelValues[supportVectors[k] - 1] - rho}.
     *
     * @param kernelValues kernel values; presumably kernel evaluations between the query and the
     *                     training examples (precomputed kernel) — TODO confirm with callers
     * @return the signed decision value
     */
    public double predictValue(double[] kernelValues) {
        double sum = 0.0d;
        for (int k = 0; k < this.supportVectors.length; k++) {
            // Support-vector indices are stored 1-based, hence the -1.
            sum += this.coefficients[k] * kernelValues[this.supportVectors[k] - 1];
        }
        return sum - this.rho;
    }

    /**
     * Parses a model file whose name (without its last extension) is the integer real index,
     * e.g. {@code 42.model} gives index 42. A name without any '.' is used as-is.
     *
     * @throws NumberFormatException if the file-name stem is not an integer
     * @throws IOException           on read errors or malformed model content
     */
    public static Predictor parseModelFile(File file) throws IOException {
        String name = file.getName();
        int dot = name.lastIndexOf('.');
        String stem = dot < 0 ? name : name.substring(0, dot);
        return parseModelFile(file, Integer.parseInt(stem));
    }

    /**
     * Parses a model file and assigns the given real index. The file is always closed.
     *
     * @throws IOException on read errors or malformed model content
     */
    public static Predictor parseModelFile(File file, int realIndex) throws IOException {
        try (Reader reader = new FileReader(file)) {
            return parseModel(reader, realIndex);
        }
    }

    /** Parses a model with real index 0. The reader is not closed. */
    public static Predictor parseModel(Reader reader) throws IOException {
        return parseModel(reader, 0);
    }

    /**
     * Parses a LIBSVM-style text model. Recognised headers are {@code total_sv}, {@code label}
     * (must start with {@code "label 1 -1"}), {@code probA}, {@code probB} and {@code rho}. Other
     * header lines are ignored. Every line after {@code SV} with at least three tokens (split on
     * whitespace and ':') is read as {@code coefficient <ignored> index}. The reader is not closed.
     *
     * @throws IOException on read errors, unexpected label order, a missing {@code total_sv}
     *                     header, a header without a value, or a support-vector count that
     *                     differs from {@code total_sv}
     */
    public static Predictor parseModel(Reader reader, int realIndex) throws IOException {
        BufferedReader in = new BufferedReader(reader);
        double rho = 0.0d;
        double probA = 0.0d;
        double probB = 0.0d;
        double[] coefficients = null;
        int[] supportVectors = null;
        int count = 0;
        boolean inSvSection = false;
        String line;
        while ((line = in.readLine()) != null) {
            if (inSvSection) {
                // trim() so leading whitespace does not produce an empty first token
                String[] tokens = line.trim().split("\\s+|:");
                if (tokens.length < 3) {
                    continue; // blank or too-short lines are skipped
                }
                if (coefficients == null) {
                    throw new IOException("SV section found before total_sv header");
                }
                if (count >= coefficients.length) {
                    throw new IOException("more support vectors than declared by total_sv (" + coefficients.length + ")");
                }
                coefficients[count] = Double.parseDouble(tokens[0]);
                supportVectors[count] = Integer.parseInt(tokens[2]);
                count++;
            } else if (line.startsWith("total_sv")) {
                int total = Integer.parseInt(headerValue(line));
                coefficients = new double[total];
                supportVectors = new int[total];
            } else if (line.startsWith("label")) {
                if (!line.startsWith("label 1 -1")) {
                    throw new IOException("strange order of labels");
                }
            } else if (line.startsWith("probA")) {
                probA = Double.parseDouble(headerValue(line));
            } else if (line.startsWith("probB")) {
                probB = Double.parseDouble(headerValue(line));
            } else if (line.startsWith("rho")) {
                rho = Double.parseDouble(headerValue(line));
            } else if (line.startsWith("SV")) {
                inSvSection = true;
            }
        }
        if (coefficients == null) {
            throw new IOException("missing total_sv header");
        }
        if (count != coefficients.length) {
            throw new IOException("expected " + coefficients.length + " support vectors but found " + count);
        }
        return new Predictor(realIndex, rho, probA, probB, coefficients, supportVectors);
    }

    /** Returns the second whitespace-separated token of a header line such as {@code "rho 0.5"}. */
    private static String headerValue(String line) throws IOException {
        String[] tokens = line.trim().split("\\s+");
        if (tokens.length < 2) {
            throw new IOException("missing value in header line: " + line);
        }
        return tokens[1];
    }

    /**
     * Writes this model in LIBSVM text format (c_svc, precomputed kernel, two classes). The
     * writer is neither flushed nor closed.
     *
     * @throws IOException on write errors
     */
    public void writeModel(BufferedWriter bufferedWriter) throws IOException {
        // Support vectors with a positive coefficient are counted for label 1, the rest for label -1.
        int positives = 0;
        int negatives = 0;
        for (double coefficient : this.coefficients) {
            if (coefficient > 0.0d) {
                positives++;
            } else {
                negatives++;
            }
        }
        assert positives + negatives == this.supportVectors.length;

        writeLine(bufferedWriter, "svm_type c_svc");
        writeLine(bufferedWriter, "kernel_type precomputed");
        writeLine(bufferedWriter, "nr_class 2");
        writeLine(bufferedWriter, "total_sv " + this.supportVectors.length);
        writeLine(bufferedWriter, "rho " + this.rho);
        writeLine(bufferedWriter, "label 1 -1");
        writeLine(bufferedWriter, "probA " + this.probA);
        writeLine(bufferedWriter, "probB " + this.probB);
        // LIBSVM's header keyword is "nr_sv"; its loader rejects unknown keywords such as "nr_svg".
        writeLine(bufferedWriter, "nr_sv " + positives + " " + negatives);
        writeLine(bufferedWriter, "SV");
        for (int k = 0; k < this.supportVectors.length; k++) {
            // Precomputed-kernel SV line: "<coefficient> 0:<support-vector index>"
            writeLine(bufferedWriter, this.coefficients[k] + " 0:" + this.supportVectors[k]);
        }
    }

    /** Writes {@code text} followed by a platform line separator. */
    private static void writeLine(BufferedWriter writer, String text) throws IOException {
        writer.write(text);
        writer.newLine();
    }

    /**
     * Evaluates the sigmoid {@code 1 / (1 + exp(decisionValue * a + b))}.
     *
     * @param decisionValue SVM decision value
     * @param a             multiplicative parameter (probA)
     * @param b             additive parameter (probB)
     * @return a value in [0, 1]
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static double sigmoid_predict(double decisionValue, double a, double b) {
        double fApB = decisionValue * a + b;
        // Branch on the sign so Math.exp only ever gets a non-positive argument; the naive
        // exp(-x)/(1+exp(-x)) form yields Infinity/Infinity = NaN for very negative x.
        return fApB >= 0.0d
                ? Math.exp(-fApB) / (1.0d + Math.exp(-fApB))
                : 1.0d / (1.0d + Math.exp(fApB));
    }
}
