OpenMS
CrossValidation.h
Go to the documentation of this file.
1 // Copyright (c) 2002-present, OpenMS Inc. -- EKU Tuebingen, ETH Zurich, and FU Berlin
2 // SPDX-License-Identifier: BSD-3-Clause
3 //
4 // --------------------------------------------------------------------------
5 // $Maintainer: Justin Sing $
6 // $Authors: Justin Sing $
7 // --------------------------------------------------------------------------
8 //
9 
10 #pragma once
11 
12 #include <OpenMS/config.h>
13 #include <OpenMS/CONCEPT/Types.h>
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <cstddef>
20 #include <utility>
21 #include <vector>
22 
23 namespace OpenMS
24 {
25 
45 {
46 public:
/**
  @brief Tie-breaking preference for equal (within tolerance) CV scores.

  Consulted by gridSearch1D() when a candidate's score lies within the tie
  tolerance of the best score seen so far.

  NOTE(review): the relative order of PreferLarger / PreferSmaller was
  reconstructed from their use in gridSearch1D(); confirm against the
  original header if the underlying integer values matter anywhere.
*/
enum class CandidateTieBreak
{
  PreferLarger,  ///< On a tie, the candidate comparing greater (operator>) replaces the current best
  PreferSmaller, ///< On a tie, the candidate comparing smaller (operator<) replaces the current best
  PreferAny      ///< On a tie, the current best is kept (first encountered wins)
};
62 
76  static std::vector<std::vector<Size>> makeKFolds(Size n, Size K)
77  {
78  if (n == 0)
79  {
80  throw Exception::InvalidValue(__FILE__, __LINE__, OPENMS_PRETTY_FUNCTION,
81  "n", String(n));
82  }
83  if (K == 0)
84  {
85  throw Exception::InvalidValue(__FILE__, __LINE__, OPENMS_PRETTY_FUNCTION,
86  "K", String(K));
87  }
88  if (K > n) K = n;
89 
90  std::vector<std::vector<Size>> folds(K);
91  for (Size i = 0; i < n; ++i) folds[i % K].push_back(i);
92  return folds;
93  }
94 
123  template <typename CandIter, typename TrainEval, typename ScoreFn>
124  static std::pair<typename std::iterator_traits<CandIter>::value_type, double>
125  gridSearch1D(CandIter cbegin, CandIter cend,
126  const std::vector<std::vector<Size>>& folds,
127  TrainEval train_eval,
128  ScoreFn score,
129  double tie_tol = 1e-12,
131  {
132  using CandT = typename std::iterator_traits<CandIter>::value_type;
133 
134  if (cbegin == cend)
135  {
136  throw Exception::InvalidRange(__FILE__, __LINE__, OPENMS_PRETTY_FUNCTION);
137  }
138 
139  CandT best_cand = *cbegin;
140  double best_score = std::numeric_limits<double>::infinity();
141  bool first = true;
142 
143  for (auto it = cbegin; it != cend; ++it)
144  {
145  const CandT cand = *it;
146 
147  std::vector<double> abs_errs;
148  abs_errs.reserve(256); // grows as needed
149  train_eval(cand, folds, abs_errs);
150 
151  const double s = score(abs_errs);
152 
153  // Prefer larger candidate on numerical ties (more stable smoothing, etc.)
154  const bool better = (s < best_score - tie_tol);
155  const bool tie = (std::fabs(s - best_score) <= tie_tol);
156 
157  bool wins_on_tie = false;
158  if (tie)
159  {
160  switch (tie_break)
161  {
162  case CandidateTieBreak::PreferLarger: wins_on_tie = cand > best_cand; break;
163  case CandidateTieBreak::PreferSmaller: wins_on_tie = cand < best_cand; break;
164  case CandidateTieBreak::PreferAny: wins_on_tie = false; break;
165  }
166  }
167 
168  if (first || better || wins_on_tie)
169  {
170  best_cand = cand;
171  best_score = s;
172  first = false;
173  }
174  }
175 
176  return {best_cand, best_score};
177  }
178 };
179 
180 } // namespace OpenMS
Lightweight K-fold / LOO cross-validation utilities and 1-D grid search.
Definition: CrossValidation.h:45
Invalid range exception.
Definition: Exception.h:257
Invalid value exception.
Definition: Exception.h:305
A more convenient string class.
Definition: String.h:34
size_t Size
Size type e.g. used as variable which can hold result of size()
Definition: Types.h:97
CandidateTieBreak
Tie-breaking preference for equal (within tolerance) CV scores.
Definition: CrossValidation.h:57
static std::pair< typename std::iterator_traits< CandIter >::value_type, double > gridSearch1D(CandIter cbegin, CandIter cend, const std::vector< std::vector< Size >> &folds, TrainEval train_eval, ScoreFn score, double tie_tol=1e-12, CandidateTieBreak tie_break=CandidateTieBreak::PreferLarger)
One-dimensional grid search with external cross-validation evaluation.
Definition: CrossValidation.h:125
static std::vector< std::vector< Size > > makeKFolds(Size n, Size K)
Build K folds for indices [0, n).
Definition: CrossValidation.h:76
Main OpenMS namespace.
Definition: openswathalgo/include/OpenMS/OPENSWATHALGO/DATAACCESS/ISpectrumAccess.h:19