OpenMS
Loading...
Searching...
No Matches
CrossValidation.h
Go to the documentation of this file.
1// Copyright (c) 2002-present, OpenMS Inc. -- EKU Tuebingen, ETH Zurich, and FU Berlin
2// SPDX-License-Identifier: BSD-3-Clause
3//
4// --------------------------------------------------------------------------
5// $Maintainer: Justin Sing $
6// $Authors: Justin Sing $
7// --------------------------------------------------------------------------
8//
9
10#pragma once
11
12#include <OpenMS/config.h>
16
17#include <algorithm>
18#include <cmath>
19#include <cstddef>
20#include <utility>
21#include <vector>
22
23namespace OpenMS
24{
25
45{
46public:
57 {
61 };
62
76 static std::vector<std::vector<Size>> makeKFolds(Size n, Size K)
77 {
78 if (n == 0)
79 {
80 throw Exception::InvalidValue(__FILE__, __LINE__, OPENMS_PRETTY_FUNCTION,
81 "n", String(n));
82 }
83 if (K == 0)
84 {
85 throw Exception::InvalidValue(__FILE__, __LINE__, OPENMS_PRETTY_FUNCTION,
86 "K", String(K));
87 }
88 if (K > n) K = n;
89
90 std::vector<std::vector<Size>> folds(K);
91 for (Size i = 0; i < n; ++i) folds[i % K].push_back(i);
92 return folds;
93 }
94
123 template <typename CandIter, typename TrainEval, typename ScoreFn>
124 static std::pair<typename std::iterator_traits<CandIter>::value_type, double>
125 gridSearch1D(CandIter cbegin, CandIter cend,
126 const std::vector<std::vector<Size>>& folds,
127 TrainEval train_eval,
128 ScoreFn score,
129 double tie_tol = 1e-12,
131 {
132 using CandT = typename std::iterator_traits<CandIter>::value_type;
133
134 if (cbegin == cend)
135 {
136 throw Exception::InvalidRange(__FILE__, __LINE__, OPENMS_PRETTY_FUNCTION);
137 }
138
139 CandT best_cand = *cbegin;
140 double best_score = std::numeric_limits<double>::infinity();
141 bool first = true;
142
143 for (auto it = cbegin; it != cend; ++it)
144 {
145 const CandT cand = *it;
146
147 std::vector<double> abs_errs;
148 abs_errs.reserve(256); // grows as needed
149 train_eval(cand, folds, abs_errs);
150
151 const double s = score(abs_errs);
152
153 // Prefer larger candidate on numerical ties (more stable smoothing, etc.)
154 const bool better = (s < best_score - tie_tol);
155 const bool tie = (std::fabs(s - best_score) <= tie_tol);
156
157 bool wins_on_tie = false;
158 if (tie)
159 {
160 switch (tie_break)
161 {
162 case CandidateTieBreak::PreferLarger: wins_on_tie = cand > best_cand; break;
163 case CandidateTieBreak::PreferSmaller: wins_on_tie = cand < best_cand; break;
164 case CandidateTieBreak::PreferAny: wins_on_tie = false; break;
165 }
166 }
167
168 if (first || better || wins_on_tie)
169 {
170 best_cand = cand;
171 best_score = s;
172 first = false;
173 }
174 }
175
176 return {best_cand, best_score};
177 }
178};
179
180} // namespace OpenMS
Lightweight K-fold / LOO cross-validation utilities and 1-D grid search.
Definition CrossValidation.h:45
Invalid range exception.
Definition Exception.h:257
Invalid value exception.
Definition Exception.h:306
A more convenient string class.
Definition String.h:34
size_t Size
Size type e.g. used as variable which can hold result of size()
Definition Types.h:97
CandidateTieBreak
Tie-breaking preference for equal (within tolerance) CV scores.
Definition CrossValidation.h:57
static std::pair< typename std::iterator_traits< CandIter >::value_type, double > gridSearch1D(CandIter cbegin, CandIter cend, const std::vector< std::vector< Size > > &folds, TrainEval train_eval, ScoreFn score, double tie_tol=1e-12, CandidateTieBreak tie_break=CandidateTieBreak::PreferLarger)
One-dimensional grid search with external cross-validation evaluation.
Definition CrossValidation.h:125
static std::vector< std::vector< Size > > makeKFolds(Size n, Size K)
Build K folds for indices [0, n).
Definition CrossValidation.h:76
Main OpenMS namespace.
Definition openswathalgo/include/OpenMS/OPENSWATHALGO/DATAACCESS/ISpectrumAccess.h:19