package de.unijena.bioinf.GibbsSampling.model;

import de.unijena.bioinf.ChemistryBase.algorithm.scoring.Scored;
import de.unijena.bioinf.ChemistryBase.ms.CompoundQuality;
import de.unijena.bioinf.GibbsSampling.model.Candidate;
import de.unijena.bioinf.jjobs.BasicMasterJJob;
import de.unijena.bioinf.jjobs.JJob;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.TIntHashSet;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.ExecutionException;

/* loaded from: input_file:de/unijena/bioinf/GibbsSampling/model/TwoPhaseGibbsSampling.class */
/**
 * Two-phase ZODIAC Gibbs sampling job.
 *
 * <p>Round one builds a graph from a selected subset of compounds (see
 * {@link #selectCompoundsForFirstRoundGibbsSampling()}) and samples it. If that subset already
 * contains every compound, its result is the final result. Otherwise round two rebuilds the graph
 * over all compounds. In that graph, every round-one compound's candidates carry their round-one
 * probability as their only node score. The set of round-one compound indices is handed to the
 * graph builder and the sampler (presumably so those compounds are held fixed — confirm in
 * {@code GraphBuilder}/{@code GibbsParallel}). In the combined result, round-one compounds keep
 * their round-one result and all others take the round-two result.
 *
 * <p>{@link #setIterationSteps(int, int)} must be called before the job runs.
 *
 * @param <C> candidate type
 */
public class TwoPhaseGibbsSampling<C extends Candidate<?>> extends BasicMasterJJob<ZodiacResult<C>> {
    /** Minimum number of not-bad-quality compounds required to restrict round one to them. */
    private static final int MIN_GOOD_QUALITY_COMPOUNDS = 300;
    /** Minimum fraction of not-bad-quality compounds required to restrict round one to them. */
    private static final double MIN_GOOD_QUALITY_FRACTION = 0.33d;
    /** Cumulative score at which {@link #combineNewAndOld} stops keeping further candidates. */
    private static final double MAX_CUMULATIVE_SCORE = 0.99d;

    /** Compound ids; index i belongs to {@code possibleFormulas[i]}. */
    private final String[] ids;
    /** Candidate lists, one row per compound. */
    private final C[][] possibleFormulas;
    private final NodeScorer<C>[] nodeScorers;
    private final EdgeScorer<C>[] edgeScorers;
    private final EdgeFilter edgeFilter;
    /** Number of parallel repetitions handed to {@link GibbsParallel}. */
    private final int repetitions;
    /** Runtime candidate class; inferred from the first non-empty candidate row if null. */
    private Class<C> cClass;
    private CompoundResult<C>[] results1;
    private CompoundResult<C>[] results2;
    private CompoundResult<C>[] combinedResult;
    private String[] usedIds;
    private Graph<C> graph;
    private GibbsParallel<C> gibbsParallel;
    private String[] firstRoundIds;
    /** Indices into {@link #possibleFormulas} of the compounds sampled in round one. */
    private TIntArrayList firstRoundCompoundsIdx;
    /** Negative until {@link #setIterationSteps(int, int)} is called. */
    private int maxSteps = -1;
    /** Negative until {@link #setIterationSteps(int, int)} is called. */
    private int burnIn = -1;

    /**
     * Creates the job.
     *
     * @param ids              compound ids, parallel to {@code possibleFormulas}
     * @param possibleFormulas candidate lists, one row per compound
     * @param nodeScorers      node scorers used when building each graph
     * @param edgeScorers      edge scorers used when building each graph
     * @param edgeFilter       edge filter used when building each graph
     * @param repetitions      number of parallel sampling repetitions
     * @param cClass           runtime candidate class, or {@code null} to infer it from the
     *                         first non-empty candidate row
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public TwoPhaseGibbsSampling(String[] ids, C[][] possibleFormulas, NodeScorer[] nodeScorers, EdgeScorer<C>[] edgeScorers, EdgeFilter edgeFilter, int repetitions, Class<C> cClass) {
        super(JJob.JobType.CPU);
        this.ids = ids;
        this.possibleFormulas = possibleFormulas;
        this.nodeScorers = nodeScorers;
        this.edgeScorers = edgeScorers;
        this.edgeFilter = edgeFilter;
        this.repetitions = repetitions;
        this.cClass = cClass;
    }

    /**
     * Prepares round one: selects its compounds, extracts their ids and candidates, and builds
     * the round-one graph into {@link #graph}.
     *
     * @throws ExecutionException    if the graph-building sub-job fails
     * @throws IllegalStateException if a subset is needed but the candidate class is unknown
     */
    private void init() throws ExecutionException {
        inferCandidateClassIfMissing();
        this.firstRoundCompoundsIdx = selectCompoundsForFirstRoundGibbsSampling();
        final int selectedCount = this.firstRoundCompoundsIdx.size();

        final C[][] firstRoundCandidates;
        if (selectedCount == this.possibleFormulas.length) {
            // Every compound is selected: reuse the input arrays as-is.
            firstRoundCandidates = this.possibleFormulas;
            this.firstRoundIds = this.ids;
        } else {
            firstRoundCandidates = newCandidateMatrix(selectedCount);
            this.firstRoundIds = new String[selectedCount];
            for (int k = 0; k < selectedCount; k++) {
                final int compoundIdx = this.firstRoundCompoundsIdx.get(k);
                firstRoundCandidates[k] = this.possibleFormulas[compoundIdx];
                this.firstRoundIds[k] = this.ids[compoundIdx];
            }
        }

        logInfo("Start first round with " + selectedCount + " of " + this.possibleFormulas.length + " compounds.");
        logInfo("ZODIAC: Graph building");
        final long start = System.currentTimeMillis();
        this.graph = (Graph) submitSubJob(GraphBuilder.createGraphBuilder(this.firstRoundIds, firstRoundCandidates, this.nodeScorers, this.edgeScorers, this.edgeFilter, this.cClass)).awaitResult();
        logInfo("finished building graph after: " + (System.currentTimeMillis() - start) + " ms");
    }

    /**
     * Sets {@link #cClass} from the runtime class of the first candidate of the first non-empty
     * candidate row, unless it is already known. Leaves it {@code null} if every row is empty.
     */
    @SuppressWarnings("unchecked")
    private void inferCandidateClassIfMissing() {
        if (this.cClass != null) {
            return;
        }
        for (C[] candidates : this.possibleFormulas) {
            if (candidates.length > 0) {
                this.cClass = (Class<C>) candidates[0].getClass();
                return;
            }
        }
    }

    /**
     * Chooses the compounds sampled in round one. Compounds without candidates are never chosen.
     *
     * <p>If at least {@value #MIN_GOOD_QUALITY_COMPOUNDS} compounds are not of bad quality, and
     * they make up at least {@value #MIN_GOOD_QUALITY_FRACTION} of all compounds, exactly those
     * are chosen. Otherwise a compound is chosen when it is not flagged {@code FewPeaks},
     * {@code Chimeric} or {@code PoorlyExplained}. Quality is read from the experiment of each
     * compound's first candidate.
     *
     * @return ascending indices into {@link #possibleFormulas}
     */
    private TIntArrayList selectCompoundsForFirstRoundGibbsSampling() {
        final int numberOfCompounds = this.possibleFormulas.length;
        // Looked up once per compound; stays null for compounds without candidates.
        final CompoundQuality[] qualities = new CompoundQuality[numberOfCompounds];
        int notBadQualityCount = 0;
        for (int i = 0; i < numberOfCompounds; i++) {
            final C[] candidates = this.possibleFormulas[i];
            if (candidates.length == 0) {
                continue;
            }
            qualities[i] = candidates[0].getExperiment().getAnnotation(CompoundQuality.class, CompoundQuality::new);
            if (qualities[i].isNotBadQuality()) {
                notBadQualityCount++;
            }
        }

        // Short-circuit keeps the division safe: a count >= 300 implies numberOfCompounds > 0.
        final boolean onlyNotBadQuality = notBadQualityCount >= MIN_GOOD_QUALITY_COMPOUNDS
                && (double) notBadQualityCount / numberOfCompounds >= MIN_GOOD_QUALITY_FRACTION;

        final TIntArrayList selected = new TIntArrayList();
        for (int i = 0; i < numberOfCompounds; i++) {
            final CompoundQuality quality = qualities[i];
            if (quality == null) {
                continue;
            }
            final boolean accept = onlyNotBadQuality
                    ? quality.isNotBadQuality()
                    : quality.isNot(CompoundQuality.CompoundQualityFlag.FewPeaks)
                            && quality.isNot(CompoundQuality.CompoundQualityFlag.Chimeric)
                            && quality.isNot(CompoundQuality.CompoundQualityFlag.PoorlyExplained);
            if (accept) {
                selected.add(i);
            }
        }
        return selected;
    }

    /**
     * Sets the number of sampling steps used in both rounds. Must be called before the job
     * runs; negative values are rejected when the job starts.
     *
     * @param maxSteps number of sampling iterations
     * @param burnIn   number of burn-in iterations
     */
    public void setIterationSteps(int maxSteps, int burnIn) {
        this.maxSteps = maxSteps;
        this.burnIn = burnIn;
    }

    /**
     * Runs round one and, if not every compound took part in it, round two; returns the
     * combined result.
     *
     * <p>NOTE(review): this name was produced by the decompiler ("renamed from: compute, merged
     * with bridge method"). The job framework presumably invokes {@code compute()}; confirm that
     * this method is actually reached.
     *
     * @return result over all {@link #ids}, the last graph built, and one result per compound
     * @throws IllegalArgumentException if {@link #setIterationSteps(int, int)} was not called
     * @throws Exception                propagated from sub-jobs and interruption checks
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public ZodiacResult<C> m12compute() throws Exception {
        if (this.maxSteps < 0 || this.burnIn < 0) {
            throw new IllegalArgumentException("number of iterations steps not set.");
        }
        checkForInterruption();
        init();
        logInfo("Running ZODIAC with " + this.firstRoundIds.length + " of " + this.ids.length + " compounds.");

        // Round 1: sample only the selected compounds.
        Graph.validateAndThrowError(this.graph, this::logWarn);
        this.gibbsParallel = new GibbsParallel<>(this.graph, this.repetitions);
        this.gibbsParallel.setIterationSteps(this.maxSteps, this.burnIn);
        final long start = System.currentTimeMillis();
        submitSubJob(this.gibbsParallel);
        this.results1 = (CompoundResult[]) this.gibbsParallel.awaitResult();
        logDebug("finished running " + this.repetitions + " repetitions in parallel: " + (System.currentTimeMillis() - start) + " ms");
        checkForInterruption();
        // results1 is paired with the graph's ids (see combineResults), so take them from there.
        this.firstRoundIds = this.gibbsParallel.getGraph().getIds();

        if (this.firstRoundIds.length == this.possibleFormulas.length) {
            // Every compound was sampled in round 1; nothing is left to score.
            this.combinedResult = this.results1;
            this.usedIds = this.firstRoundIds;
        } else {
            // Round 2: all compounds, with round-1 candidates carrying their round-1 probability.
            logInfo("Running second round: Score " + (this.ids.length - this.results1.length) + " low quality compounds. " + this.ids.length + " compounds overall.");
            final C[][] allCandidates = combineNewAndOldAndSetFixedProbabilities(this.results1, this.firstRoundCompoundsIdx);
            final TIntHashSet fixedCompounds = new TIntHashSet(this.firstRoundCompoundsIdx);
            this.graph = (Graph) submitSubJob(GraphBuilder.createGraphBuilder(this.ids, allCandidates, this.nodeScorers, this.edgeScorers, this.edgeFilter, fixedCompounds, this.cClass)).awaitResult();
            checkForInterruption();
            Graph.validateAndThrowError(this.graph, this::logWarn);
            this.gibbsParallel = new GibbsParallel<>(this.graph, this.repetitions, fixedCompounds);
            this.gibbsParallel.setIterationSteps(this.maxSteps, this.burnIn);
            submitSubJob(this.gibbsParallel);
            this.results2 = (CompoundResult[]) this.gibbsParallel.awaitResult();
            checkForInterruption();
            this.usedIds = this.gibbsParallel.getGraph().getIds();
            this.combinedResult = combineResults(this.results1, this.firstRoundIds, this.results2, this.usedIds);
        }
        return new ZodiacResult<>(this.ids, this.graph, this.combinedResult);
    }

    /**
     * Adds a {@link Connectivity} annotation to each result, taken from
     * {@code graph.getMaxNumberOfConnectedCompounds(i)} for result index i.
     * Not called from within this class.
     *
     * @param compoundResults results whose index i corresponds to graph compound i
     * @param graph           graph providing the connectivity counts
     * @param keepExisting    if true, results that already carry a Connectivity are left as-is
     */
    private void addConnectivityInfo(CompoundResult<C>[] compoundResults, Graph<C> graph, boolean keepExisting) {
        for (int i = 0; i < compoundResults.length; i++) {
            final CompoundResult<C> compoundResult = compoundResults[i];
            if (keepExisting && compoundResult.hasAnnotation(Connectivity.class)) {
                continue;
            }
            compoundResult.addAnnotation(Connectivity.class, new Connectivity(graph.getMaxNumberOfConnectedCompounds(i)));
        }
    }

    /**
     * Merges two rounds of results in the order of the second round. For each second-round id
     * that also appears among the first-round ids, the first-round result is used. Otherwise the
     * second-round result at the same position is used.
     *
     * @param firstResults  round-one results, parallel to {@code firstIds}
     * @param firstIds      round-one ids
     * @param secondResults round-two results, parallel to {@code secondIds}
     * @param secondIds     round-two ids
     * @return merged results, parallel to {@code secondIds}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private CompoundResult<C>[] combineResults(CompoundResult<C>[] firstResults, String[] firstIds, CompoundResult<C>[] secondResults, String[] secondIds) {
        final TObjectIntHashMap<String> firstIndexById = new TObjectIntHashMap<>();
        for (int i = 0; i < firstIds.length; i++) {
            firstIndexById.put(firstIds[i], i);
        }
        final CompoundResult<C>[] combined = new CompoundResult[secondResults.length];
        for (int i = 0; i < secondIds.length; i++) {
            final String id = secondIds[i];
            combined[i] = firstIndexById.containsKey(id) ? firstResults[firstIndexById.get(id)] : secondResults[i];
        }
        return combined;
    }

    /**
     * Builds a per-compound candidate matrix in which compounds whose id appears in
     * {@code scoredIds} keep only their leading scored candidates. Candidates are kept until the
     * cumulative score reaches {@value #MAX_CUMULATIVE_SCORE}; the candidate that crosses it is
     * included. All other compounds keep their original candidates. Not called from within this
     * class.
     *
     * @param scored    scored candidates, parallel to {@code scoredIds}
     * @param scoredIds ids of the scored compounds
     * @return new matrix parallel to {@link #possibleFormulas}, or {@link #possibleFormulas}
     *         itself if {@code scored} is empty
     */
    private C[][] combineNewAndOld(Scored<C>[][] scored, String[] scoredIds) {
        if (scored.length == 0) {
            return this.possibleFormulas;
        }
        final TObjectIntHashMap<String> scoredIndexById = new TObjectIntHashMap<>();
        for (int i = 0; i < scoredIds.length; i++) {
            scoredIndexById.put(scoredIds[i], i);
        }
        final C[][] combined = newCandidateMatrix(this.possibleFormulas.length);
        for (int i = 0; i < this.possibleFormulas.length; i++) {
            if (!scoredIndexById.containsKey(this.ids[i])) {
                combined[i] = this.possibleFormulas[i];
                continue;
            }
            final ArrayList<C> kept = new ArrayList<>();
            double cumulativeScore = 0.0d;
            for (Scored<C> candidate : scored[scoredIndexById.get(this.ids[i])]) {
                kept.add(candidate.getCandidate());
                cumulativeScore += candidate.getScore();
                if (cumulativeScore >= MAX_CUMULATIVE_SCORE) {
                    break;
                }
            }
            combined[i] = kept.toArray(newCandidateRow(0));
        }
        return combined;
    }

    /**
     * Builds the round-two candidate matrix.
     *
     * <p>For each round-one compound, every candidate of its round-one result has its node scores
     * cleared. Its round-one score is then added as a node probability score. Note that this
     * mutates the candidate objects held by the round-one results. All other compounds keep their
     * original candidates.
     *
     * @param firstRoundResults    round-one results; entry k belongs to compound
     *                             {@code firstRoundCompoundIdx.get(k)}
     * @param firstRoundCompoundIdx indices into {@link #possibleFormulas} of round-one compounds
     * @return new matrix parallel to {@link #possibleFormulas}, or {@link #possibleFormulas}
     *         itself if {@code firstRoundResults} is empty
     * @throws IllegalStateException if the result count does not match the number of round-one
     *                               compounds, or if a compound's candidates cannot be updated
     *                               (the original exception is attached as cause)
     */
    private C[][] combineNewAndOldAndSetFixedProbabilities(CompoundResult<C>[] firstRoundResults, TIntArrayList firstRoundCompoundIdx) {
        if (firstRoundResults.length == 0) {
            return this.possibleFormulas;
        }
        // Results are matched to compounds by position. A mismatch would silently pair wrong
        // compounds (and eventually index out of bounds), so fail fast with a clear message.
        if (firstRoundResults.length != firstRoundCompoundIdx.size()) {
            throw new IllegalStateException("Number of first round results (" + firstRoundResults.length
                    + ") does not match number of first round compounds (" + firstRoundCompoundIdx.size() + ").");
        }
        // -1 is used as no-entry value; valid result positions are >= 0.
        final TIntIntHashMap resultIdxByCompoundIdx = new TIntIntHashMap(firstRoundCompoundIdx.size(), 0.75f, -1, -1);
        for (int k = 0; k < firstRoundCompoundIdx.size(); k++) {
            resultIdxByCompoundIdx.put(firstRoundCompoundIdx.get(k), k);
        }

        final C[][] combined = newCandidateMatrix(this.possibleFormulas.length);
        for (int i = 0; i < this.possibleFormulas.length; i++) {
            final int resultIdx = resultIdxByCompoundIdx.get(i);
            if (resultIdx < 0) {
                combined[i] = this.possibleFormulas[i];
                continue;
            }
            try {
                combined[i] = withFixedProbabilities(firstRoundResults[resultIdx].getCandidates());
            } catch (RuntimeException e) {
                throw new IllegalStateException("Could not set fixed probabilities for compound " + i + " (" + this.ids[i] + ").", e);
            }
        }
        return combined;
    }

    /**
     * Replaces each candidate's node scores with its score from {@code scoredCandidates}.
     *
     * @return the (mutated) candidates, in the same order
     */
    private C[] withFixedProbabilities(Scored<C>[] scoredCandidates) {
        final C[] row = newCandidateRow(scoredCandidates.length);
        for (int k = 0; k < scoredCandidates.length; k++) {
            final Scored<C> scored = scoredCandidates[k];
            final C candidate = scored.getCandidate();
            candidate.clearNodeScores();
            candidate.addNodeProbabilityScore(scored.getScore());
            row[k] = candidate;
        }
        return row;
    }

    /** Allocates a {@code C[rows][]} whose rows are still unset (null). */
    @SuppressWarnings("unchecked")
    private C[][] newCandidateMatrix(int rows) {
        final Class<?> rowClass = Array.newInstance(requireCandidateClass(), 0).getClass();
        return (C[][]) Array.newInstance(rowClass, rows);
    }

    /** Allocates a {@code C[length]} with the runtime component type {@link #cClass}. */
    @SuppressWarnings("unchecked")
    private C[] newCandidateRow(int length) {
        return (C[]) Array.newInstance(requireCandidateClass(), length);
    }

    /**
     * Returns {@link #cClass}.
     *
     * @throws IllegalStateException if it was neither given nor inferable
     */
    private Class<C> requireCandidateClass() {
        if (this.cClass == null) {
            throw new IllegalStateException("Candidate class unknown: it was not given and no compound has any candidate.");
        }
        return this.cClass;
    }

    /**
     * Returns, per compound of the final result, its scored candidates.
     * Only valid after the job has run.
     *
     * @return array parallel to {@link #getIds()}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Scored<C>[][] getChosenFormulas() {
        final Scored<C>[][] chosen = (Scored<C>[][]) new Scored[this.combinedResult.length][];
        for (int i = 0; i < this.combinedResult.length; i++) {
            chosen[i] = this.combinedResult[i].getCandidates();
        }
        return chosen;
    }

    /**
     * Returns the most recently built graph: the round-one graph, or the round-two graph if a
     * second round ran. {@code null} before the job has started building.
     */
    public Graph<C> getGraph() {
        return this.graph;
    }

    /**
     * Returns the ids the final result is ordered by; {@code null} before the job has run.
     */
    public String[] getIds() {
        return this.usedIds;
    }
}
