/*
 * Decompiled with CFR 0.152.
 */
package de.unijena.bioinf.fingerid.svm;

import de.unijena.bioinf.ChemistryBase.fp.PredictionPerformance;
import de.unijena.bioinf.fingerid.OptimizationStrategy;
import de.unijena.bioinf.fingerid.svm.Crossvalidation;
import de.unijena.bioinf.fingerid.svm.Sample;
import de.unijena.bioinf.fingerid.svm.Svm;
import de.unijena.bioinf.fingerid.svm.SvmInstance;
import de.unijena.bioinf.fingerid.svm.SvmModel;
import gnu.trove.list.array.TDoubleArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;

/**
 * Selects the SVM parameter {@code c} from a fixed set of candidate values by leave-one-fold-out
 * cross-validation. The per-fold training/evaluation tasks are submitted to the
 * {@link ExecutorService} passed to the constructor.
 */
public class CSelection {
    /** Executor that runs the per-fold train/evaluate tasks. */
    private final ExecutorService service;
    /** Stored from the constructor; not read anywhere in this class. */
    private final int nThreads;
    /** Candidate values for {@code c}, sorted ascending; {@code null} until {@link #setPossibleCs} is called. */
    private double[] cvalues;

    /**
     * @param service         executor that runs the per-fold training/evaluation tasks
     * @param numberOfThreads stored in {@code nThreads}, which this class does not currently read
     */
    public CSelection(ExecutorService service, int numberOfThreads) {
        this.service = service;
        this.nThreads = numberOfThreads;
    }

    /**
     * Sets the candidate values for {@code c}. The argument is copied and the copy is sorted
     * ascending, so later changes to the caller's array have no effect.
     *
     * @param cvalues candidate values for {@code c}
     */
    public void setPossibleCs(double... cvalues) {
        this.cvalues = cvalues.clone();
        Arrays.sort(this.cvalues);
    }

    /**
     * Chooses a value for {@code c} from the candidates set by {@link #setPossibleCs}.
     *
     * <p>The search starts at the middle of the sorted candidates and walks outward in both
     * directions. Each round it cross-validates up to two new candidates per direction, so their
     * folds run concurrently on the executor. A direction is abandoned once two consecutive
     * candidates in it fail to beat that direction's best result under
     * {@code strategy.getComparator()}. A result replaces the best one only if it compares
     * strictly greater.
     *
     * <p>The better of the two directional winners is returned; on a tie the larger-c side wins.
     * Each evaluated candidate's performance is printed to standard output.
     *
     * @param instance        SVM configuration; a copy is obtained via {@code getCopy()} for every fold
     * @param crossvalidation source of the fold sample sets
     * @param strategy        provides the comparator that ranks performances (greater is better)
     * @param bestPerformance receives the winning performance via {@code set(...)}
     * @param allowedFolds    fold indices used for leave-one-fold-out; at least two are required
     * @return the selected value of {@code c}
     * @throws IllegalStateException    if {@link #setPossibleCs} has not been called
     * @throws IllegalArgumentException if there are fewer than two candidates or fewer than two folds
     * @throws InterruptedException     if interrupted while waiting for a fold result
     * @throws RuntimeException         wrapping the cause if a fold evaluation fails
     */
    public <T extends Sample> double learnC(SvmInstance instance, Crossvalidation<T> crossvalidation, OptimizationStrategy strategy, PredictionPerformance bestPerformance, int... allowedFolds) throws InterruptedException {
        if (this.cvalues == null) {
            throw new IllegalStateException("No c parameters set; call setPossibleCs first");
        }
        if (this.cvalues.length <= 1) {
            throw new IllegalArgumentException("Too few c parameters");
        }
        if (allowedFolds.length < 2) {
            // Each fold is held out in turn, so at least one other fold is needed for training.
            throw new IllegalArgumentException("At least two folds are required, got " + allowedFolds.length);
        }
        try {
            // With exactly two candidates, start at index 0 so the larger-c side still gets index 1.
            // Using length / 2 would leave that side empty and its best performance null.
            int middle = this.cvalues.length <= 2 ? 0 : this.cvalues.length / 2;
            int lo = middle;      // next index to try towards smaller c
            int hi = middle + 1;  // next index to try towards larger c
            SearchDirection smaller = new SearchDirection();
            SearchDirection larger = new SearchDirection();
            while (lo >= 0 || hi < this.cvalues.length) {
                // Queue up to two candidates per direction, interleaving submissions (small, large, small, large).
                for (int round = 0; round < 2; ++round) {
                    if (lo >= 0) {
                        smaller.enqueue(this.tryC(instance, crossvalidation, allowedFolds, this.cvalues[lo]), this.cvalues[lo]);
                        --lo;
                    }
                    if (hi < this.cvalues.length) {
                        larger.enqueue(this.tryC(instance, crossvalidation, allowedFolds, this.cvalues[hi]), this.cvalues[hi]);
                        ++hi;
                    }
                }
                if (smaller.evaluatePending(strategy)) {
                    lo = -1;
                }
                if (larger.evaluatePending(strategy)) {
                    hi = this.cvalues.length;
                }
            }
            if (strategy.getComparator().compare(smaller.best, larger.best) > 0) {
                bestPerformance.set(smaller.best);
                return smaller.bestC;
            }
            bestPerformance.set(larger.best);
            return larger.bestC;
        } catch (ExecutionException e) {
            throw new RuntimeException(e.getCause());
        }
    }

    /** Search state for one direction (towards smaller or larger c) of the outward search in {@link #learnC}. */
    private static final class SearchDirection {
        /** Best performance seen so far in this direction; {@code null} before the first evaluation. */
        PredictionPerformance best;
        /** Candidate c that produced {@link #best}. */
        double bestC;
        /** Consecutive evaluated candidates that did not beat {@link #best}; persists across rounds. */
        int misses;
        /** Queued (not yet run) candidate evaluations, parallel to {@link #pendingC}. */
        final List<FutureTask<PredictionPerformance>> pending = new ArrayList<FutureTask<PredictionPerformance>>();
        /** Candidate c for each entry of {@link #pending}. */
        final TDoubleArrayList pendingC = new TDoubleArrayList();

        /** Queues a candidate evaluation task together with its c value. */
        void enqueue(FutureTask<PredictionPerformance> task, double c) {
            pending.add(task);
            pendingC.add(c);
        }

        /**
         * Runs every queued task in enqueue order on the calling thread and records the best
         * result, then clears the queue.
         *
         * @return {@code true} if, at any point in this batch, two consecutive candidates failed to
         *     improve on the best result, meaning this direction should stop
         */
        boolean evaluatePending(OptimizationStrategy strategy) throws InterruptedException, ExecutionException {
            boolean exhausted = false;
            for (int k = 0; k < pending.size(); ++k) {
                FutureTask<PredictionPerformance> task = pending.get(k);
                task.run();
                PredictionPerformance pf = task.get();
                double c = pendingC.get(k);
                System.out.println("For c = " + c + " performance is " + pf);
                if (best == null || strategy.getComparator().compare(pf, best) > 0) {
                    best = pf;
                    bestC = c;
                    misses = 0;
                } else if (++misses >= 2) {
                    exhausted = true;
                }
            }
            pending.clear();
            pendingC.clear();
            return exhausted;
        }
    }

    /**
     * Starts a leave-one-fold-out cross-validation of {@code instance} with parameter "c" set to
     * {@code C}. One task per entry of {@code allowedFolds} is submitted to the executor
     * immediately.
     *
     * @return a task that is not submitted anywhere; the caller must run it. When run, it waits for
     *     all fold results, merges them into the first fold's {@link PredictionPerformance}, calls
     *     {@code calc()} on it and returns it
     */
    private <T extends Sample> FutureTask<PredictionPerformance> tryC(final SvmInstance instance, final Crossvalidation<T> crossvalidation, final int[] allowedFolds, final double C) {
        final List<Future<PredictionPerformance>> foldResults = new ArrayList<Future<PredictionPerformance>>(allowedFolds.length);
        for (int k = 0; k < allowedFolds.length; ++k) {
            final int heldOut = k;
            foldResults.add(this.service.submit(new Callable<PredictionPerformance>() {
                @Override
                public PredictionPerformance call() throws Exception {
                    return evaluateFold(instance, crossvalidation, allowedFolds, heldOut, C);
                }
            }));
        }
        return new FutureTask<PredictionPerformance>(new Callable<PredictionPerformance>() {
            @Override
            public PredictionPerformance call() throws Exception {
                PredictionPerformance merged = foldResults.get(0).get();
                for (int k = 1; k < foldResults.size(); ++k) {
                    merged.merge(foldResults.get(k).get());
                }
                merged.calc();
                return merged;
            }
        });
    }

    /**
     * Trains a copy of {@code instance} with parameter "c" = {@code C} on every allowed fold except
     * {@code allowedFolds[heldOut]}, then evaluates the resulting model on that held-out fold.
     */
    private static <T extends Sample> PredictionPerformance evaluateFold(SvmInstance instance, Crossvalidation<T> crossvalidation, int[] allowedFolds, int heldOut, double C) throws Exception {
        int[] trainingFolds = new int[allowedFolds.length - 1];
        int next = 0;
        for (int f = 0; f < allowedFolds.length; ++f) {
            if (f != heldOut) {
                trainingFolds[next++] = allowedFolds[f];
            }
        }
        Sample[] samples = crossvalidation.getFoldsArray(trainingFolds);
        SvmInstance myInstance = instance.getCopy();
        myInstance.setSamples(samples);
        myInstance.setParameter("c", C);
        SvmModel model = myInstance.train();
        Sample[] evaluation = crossvalidation.getFoldsArray(allowedFolds[heldOut]);
        double[] predictions = model.predict(evaluation);
        return Svm.evaluateClassificationPerformance(evaluation, predictions);
    }
}

