[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest_common.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2014-2015 by Ullrich Koethe and Philip Schill */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef VIGRA_RF3_COMMON_HXX
36 #define VIGRA_RF3_COMMON_HXX
37 
38 #include <iterator>
39 #include <type_traits>
40 #include <cmath>
41 #include <numeric>
42 
43 #include "../multi_array.hxx"
44 #include "../mathutil.hxx"
45 
46 namespace vigra
47 {
48 
49 namespace rf3
50 {
51 
52 /** \addtogroup MachineLearning
53 **/
54 //@{
55 
/// \brief Threshold test on a single feature.
///
/// The test compares feature number \a dim_ of an instance against the
/// threshold \a val_ and reports the outcome as 0 or 1.
///
/// \tparam T type of the threshold value.
template <typename T>
struct LessEqualSplitTest
{
public:
    /// \param dim index of the feature that is tested
    /// \param val threshold the feature is compared against
    LessEqualSplitTest(size_t dim = 0, T const & val = 0)
        : dim_(dim), val_(val)
    {}

    /// \brief Evaluate the test for one instance.
    ///
    /// \param features callable; <tt>features(dim_)</tt> must yield the tested feature value
    /// \return 0 if <tt>features(dim_) <= val_</tt>, otherwise 1
    template<typename FEATURES>
    size_t operator()(FEATURES const & features) const
    {
        if (features(dim_) <= val_)
            return 0;
        return 1;
    }

    size_t dim_; ///< index of the tested feature
    T val_;      ///< threshold used by the comparison
};
75 
76 
77 
/// \brief Accumulator that converts a sequence of class indices into relative frequencies.
///
/// Every input element is interpreted as an index. The functor counts how often
/// each index occurs and writes <tt>count / number_of_inputs</tt> for the indices
/// <tt>0 .. max_index</tt> to the output iterator.
///
/// The object keeps an internal histogram buffer that is reused between calls.
struct ArgMaxAcc
{
public:
    typedef size_t input_type;

    /// \param begin start of the index sequence
    /// \param end   end of the index sequence
    /// \param out   output iterator; receives <tt>max_index + 1</tt> values.
    ///              Nothing is written when the input range is empty.
    template <typename ITER, typename OUTITER>
    void operator()(ITER begin, ITER end, OUTITER out)
    {
        // Without any input there is no frequency to report. Returning early also
        // avoids reading buffer_[0] from a still-empty buffer and dividing by zero.
        if (begin == end)
            return;

        std::fill(buffer_.begin(), buffer_.end(), 0);
        size_t max_v = 0;
        size_t n = 0;
        for (ITER it = begin; it != end; ++it)
        {
            size_t const v = *it;
            // Grow the histogram on demand so that index v is addressable.
            if (v >= buffer_.size())
            {
                buffer_.resize(v+1, 0);
            }
            ++buffer_[v];
            ++n;
            max_v = std::max(max_v, v);
        }

        double const total = static_cast<double>(n);
        for (size_t i = 0; i <= max_v; ++i, ++out)
        {
            *out = buffer_[i] / total;
        }
    }
private:
    std::vector<size_t> buffer_; // histogram of the indices seen in the current call
};
109 
110 
111 
/// \brief Accumulator that sums normalized histograms.
///
/// Every input element is a vector of values. Each vector is divided by the sum
/// of its own entries and the results are added up element-wise. The accumulated
/// values are written to the output iterator; their number equals the length of
/// the longest input vector.
///
/// NOTE(review): a vector whose entries sum to zero is still divided by that sum,
/// so it contributes NaN/inf entries. Callers presumably never pass such vectors;
/// confirm before relying on it.
///
/// \tparam VALUETYPE element type of the input vectors.
template <typename VALUETYPE>
struct ArgMaxVectorAcc
{
public:
    typedef VALUETYPE value_type;
    typedef std::vector<value_type> input_type;

    /// \param begin start of the sequence of input vectors
    /// \param end   end of the sequence of input vectors
    /// \param out   output iterator; receives one value per position of the longest
    ///              input vector. Nothing is written if there is no such position.
    template <typename ITER, typename OUTITER>
    void operator()(ITER begin, ITER end, OUTITER out)
    {
        std::fill(buffer_.begin(), buffer_.end(), 0);

        // Number of values to emit. Tracking the length itself (instead of
        // "size() - 1") keeps empty vectors from wrapping around to SIZE_MAX.
        size_t n_out = 0;
        for (ITER it = begin; it != end; ++it)
        {
            input_type const & vec = *it;
            if (vec.size() > buffer_.size())
            {
                buffer_.resize(vec.size(), 0);
            }
            n_out = std::max(n_out, vec.size());

            value_type const n = std::accumulate(vec.begin(), vec.end(), static_cast<value_type>(0));
            for (size_t i = 0; i < vec.size(); ++i)
            {
                buffer_[i] += vec[i] / static_cast<double>(n);
            }
        }

        for (size_t i = 0; i < n_out; ++i, ++out)
        {
            *out = buffer_[i];
        }
    }
private:
    std::vector<double> buffer_; // element-wise sum of the normalized input vectors
};
146 
147 
148 
149 // struct LargestSumAcc
150 // {
151 // public:
152 // typedef std::vector<size_t> input_type;
153 // template <typename ITER>
154 // size_t operator()(ITER begin, ITER end)
155 // {
156 // std::fill(buffer_.begin(), buffer_.end(), 0);
157 // for (ITER it = begin; it != end; ++it)
158 // {
159 // auto const & v = *it;
160 // if (v.size() > buffer_.size())
161 // {
162 // buffer_.resize(v.size(), 0);
163 // }
164 // for (size_t i = 0; i < v.size(); ++i)
165 // {
166 // buffer_[i] += v[i];
167 // }
168 // }
169 // size_t max_label = 0;
170 // size_t max_count = 0;
171 // for (size_t i = 0; i < buffer_.size(); ++i)
172 // {
173 // if (buffer_[i] > max_count)
174 // {
175 // max_count = buffer_[i];
176 // max_label = i;
177 // }
178 // }
179 // return max_label;
180 // }
181 // private:
182 // std::vector<size_t> buffer_;
183 // };
184 
185 
186 
187 // struct ForestGarroteAcc
188 // {
189 // public:
190 // typedef double input_type;
191 // template <typename ITER, typename OUTITER>
192 // void operator()(ITER begin, ITER end, OUTITER out)
193 // {
194 // double s = 0.0;
195 // for (ITER it = begin; it != end; ++it)
196 // {
197 // s += *it;
198 // }
199 // if (s < 0.0)
200 // s = 0.0;
201 // else if (s > 1.0)
202 // s = 1.0;
203 // *out = 1.0-s;
204 // ++out;
205 // *out = s;
206 // }
207 // };
208 
209 
210 
211 namespace detail
212 {
213 
214  /// Abstract scorer that iterates over all split candidates, uses FUNCTOR to compute a score,
215  /// and saves the split with the minimum score.
216  template <typename FUNCTOR>
218  {
219  public:
220 
221  typedef FUNCTOR Functor;
222 
223  GeneralScorer(std::vector<double> const & priors)
224  :
225  split_found_(false),
226  best_split_(0),
227  best_dim_(0),
228  best_score_(std::numeric_limits<double>::max()),
229  priors_(priors),
230  n_total_(std::accumulate(priors.begin(), priors.end(), 0.0))
231  {}
232 
233  template <typename FEATURES, typename LABELS, typename WEIGHTS, typename ITER>
234  void operator()(
235  FEATURES const & features,
236  LABELS const & labels,
237  WEIGHTS const & weights,
238  ITER begin,
239  ITER end,
240  size_t dim
241  ){
242  if (begin == end)
243  return;
244 
245  Functor score;
246 
247  std::vector<double> counts(priors_.size(), 0.0);
248  double n_left = 0;
249  ITER next = begin;
250  ++next;
251  for (; next != end; ++begin, ++next)
252  {
253  // Move the label from the right side to the left side.
254  size_t const left_index = *begin;
255  size_t const right_index = *next;
256  size_t const label = static_cast<size_t>(labels(left_index));
257  counts[label] += weights[left_index];
258  n_left += weights[left_index];
259 
260  // Skip if there is no new split.
261  auto const left = features(left_index, dim);
262  auto const right = features(right_index, dim);
263  if (left == right)
264  continue;
265 
266  // Update the score.
267  split_found_ = true;
268  double const s = score(priors_, counts, n_total_, n_left);
269  bool const better_score = s < best_score_;
270  if (better_score)
271  {
272  best_score_ = s;
273  best_split_ = 0.5*(left+right);
274  best_dim_ = dim;
275  }
276  }
277  }
278 
279  bool split_found_; // whether a split was found at all
280  double best_split_; // the threshold of the best split
281  size_t best_dim_; // the dimension of the best split
282  double best_score_; // the score of the best split
283 
284  private:
285 
286  std::vector<double> const priors_; // the weighted number of datapoints per class
287  double const n_total_; // the weighted number of datapoints
288  };
289 
290 } // namespace detail
291 
/// \brief Functor that computes the gini score.
///
/// This functor is typically selected indirectly by passing the value <tt>RF_GINI</tt>
/// to vigra::rf3::RandomForestOptions::split().
class GiniScore
{
public:
    /// \brief Score of a split into a left and a right part.
    ///
    /// \param priors  weighted number of datapoints per class in the whole node
    /// \param counts  weighted number of datapoints per class in the left part
    /// \param n_total weighted number of datapoints in the whole node
    /// \param n_left  weighted number of datapoints in the left part
    /// \return <tt>n_left * gini_left + n_right * gini_right</tt>, where each gini value
    ///         is one minus the sum of the squared class proportions of that part.
    ///
    /// The function divides by \a n_left and by <tt>n_total - n_left</tt> without
    /// checking them, so both parts must carry non-zero weight.
    double operator()(std::vector<double> const & priors,
                      std::vector<double> const & counts, double n_total, double n_left) const
    {
        double const n_right = n_total - n_left;
        double gini_left = 1.0;
        double gini_right = 1.0;
        for (size_t i = 0; i < counts.size(); ++i)
        {
            double const p_left = counts[i] / n_left;
            // Whatever is not on the left side is on the right side.
            double const p_right = (priors[i] - counts[i]) / n_right;
            gini_left -= (p_left*p_left);
            gini_right -= (p_right*p_right);
        }
        return n_left*gini_left + n_right*gini_right;
    }

    /// \brief Gini score of a single region; needed for Gini-based variable importance calculation.
    ///
    /// \param labels  indexable as <tt>labels[instance]</tt>, must yield non-negative class indices
    /// \param weights indexable as <tt>weights[instance]</tt>
    /// \param begin   start of the range of instance indices
    /// \param end     end of the range of instance indices
    /// \return <tt>total - sum_k(count_k^2) / total</tt>, with \a total the summed weight of the region.
    ///         An empty range yields 0.
    template <typename LABELS, typename WEIGHTS, typename ITER>
    static double region_score(LABELS const & labels, WEIGHTS const & weights, ITER begin, ITER end)
    {
        // Count the occurrences.
        std::vector<double> counts;
        double total = 0.0;
        for (auto it = begin; it != end; ++it)
        {
            auto const d = *it;
            // The explicit cast avoids a signed/unsigned comparison for signed label types.
            size_t const lbl = static_cast<size_t>(labels[d]);
            if (counts.size() <= lbl)
            {
                counts.resize(lbl+1, 0.0);
            }
            counts[lbl] += weights[d];
            total += weights[d];
        }

        // Compute the gini.
        double gini = total;
        for (auto x : counts)
        {
            gini -= x*x/total;
        }
        return gini;
    }
};
343 
344 /// \brief Functor that computes the entropy score.
345 ///
346 /// This functor is typically selected indirectly by passing the value <tt>RF_ENTROPY</tt>
347 /// to vigra::rf3::RandomForestOptions::split().
349 {
350 public:
351  double operator()(std::vector<double> const & priors, std::vector<double> const & counts, double n_total, double n_left) const
352  {
353  double const n_right = n_total - n_left;
354  double ig = 0;
355  for (size_t i = 0; i < counts.size(); ++i)
356  {
357  double c = counts[i];
358  if (c != 0)
359  ig -= c * std::log(c / n_left);
360 
361  c = priors[i] - c;
362  if (c != 0)
363  ig -= c * std::log(c / n_right);
364  }
365  return ig;
366  }
367 
368  template <typename LABELS, typename WEIGHTS, typename ITER>
369  double region_score(LABELS const & /*labels*/, WEIGHTS const & /*weights*/, ITER /*begin*/, ITER /*end*/) const
370  {
371  vigra_fail("EntropyScore::region_score(): Not implemented yet.");
372  return 0.0; // FIXME
373  }
374 };
375 
376 /// \brief Functor that computes the Kolmogorov-Smirnov score.
377 ///
378 /// Actually, it reutrns the negated KSD score, because we want to minimize.
379 ///
380 /// This functor is typically selected indirectly by passing the value <tt>RF_KSD</tt>
381 /// to vigra::rf3::RandomForestOptions::split().
383 {
384 public:
385  double operator()(std::vector<double> const & priors, std::vector<double> const & counts, double /*n_total*/, double /*n_left*/) const // Fix unused parameter warning, but leave in to not break compatibility with overall API
386  {
387  double const eps = 1e-10;
388  double nnz = 0;
389  std::vector<double> norm_counts(counts.size(), 0.0);
390  for (size_t i = 0; i < counts.size(); ++i)
391  {
392  if (priors[i] > eps)
393  {
394  norm_counts[i] = counts[i] / priors[i];
395  ++nnz;
396  }
397  }
398  if (nnz < eps)
399  return 0.0;
400 
401  // NOTE to future self:
402  // In std::accumulate, it makes a huge difference whether you use 0 or 0.0 as init. Think about that before making changes.
403  double const mean = std::accumulate(norm_counts.begin(), norm_counts.end(), 0.0) / nnz;
404 
405  // Compute the sum of the squared distances.
406  double ksd = 0.0;
407  for (size_t i = 0; i < norm_counts.size(); ++i)
408  {
409  if (priors[i] != 0)
410  {
411  double const v = (mean-norm_counts[i]);
412  ksd += v*v;
413  }
414  }
415  return -ksd;
416  }
417 
418  template <typename LABELS, typename WEIGHTS, typename ITER>
419  double region_score(LABELS const & /*labels*/, WEIGHTS const & /*weights*/, ITER /*begin*/, ITER /*end*/) const
420  {
421  vigra_fail("KolmogorovSmirnovScore::region_score(): Region score not available for the Kolmogorov-Smirnov split.");
422  return 0.0;
423  }
424 };
425 
// Description of a single node: its depth and the weighted number of datapoints per class.
//
// Only a reference to the priors is stored, so the array passed to the
// constructor must stay alive for as long as the description is used.
template <typename ARR>
struct RFNodeDescription
{
public:
    RFNodeDescription(size_t depth, ARR const & priors)
        : depth_(depth), priors_(priors)
    {}

    size_t depth_;       // depth of the node
    ARR const & priors_; // weighted number of datapoints per class (not owned)
};
439 
440 
441 
442 // Return true if the given node is pure.
443 template <typename LABELS, typename ITER>
444 bool is_pure(LABELS const & /*labels*/, RFNodeDescription<ITER> const & desc)
445 {
446  bool found = false;
447  for (auto n : desc.priors_)
448  {
449  if (n > 0)
450  {
451  if (found)
452  return false;
453  else
454  found = true;
455  }
456  }
457  return true;
458 }
459 
460 /// @brief Random forest 'node purity' stop criterion.
461 ///
462 /// Stop splitting a node when it contains only instanes of a single class.
464 {
465 public:
466  template <typename LABELS, typename ITER>
467  bool operator()(LABELS const & labels, RFNodeDescription<ITER> const & desc) const
468  {
469  return is_pure(labels, desc);
470  }
471 };
472 
473 /// @brief Random forest 'maximum depth' stop criterion.
474 ///
475 /// Stop splitting a node when the its depth reaches a given value or when it is pure.
477 {
478 public:
479  /// @brief Constructor: terminate tree construction at \a max_depth.
480  DepthStop(size_t max_depth)
481  :
482  max_depth_(max_depth)
483  {}
484 
485  template <typename LABELS, typename ITER>
486  bool operator()(LABELS const & labels, RFNodeDescription<ITER> const & desc) const
487  {
488  if (desc.depth_ >= max_depth_)
489  return true;
490  else
491  return is_pure(labels, desc);
492  }
493  size_t max_depth_;
494 };
495 
496 /// @brief Random forest 'number of datapoints' stop criterion.
497 ///
498 /// Stop splitting a node when it contains too few instances or when it is pure.
500 {
501 public:
502  /// @brief Constructor: terminate tree construction when node contains less than \a min_n instances.
503  NumInstancesStop(size_t min_n)
504  :
505  min_n_(min_n)
506  {}
507 
508  template <typename LABELS, typename ARR>
509  bool operator()(LABELS const & labels, RFNodeDescription<ARR> const & desc) const
510  {
511  typedef typename ARR::value_type value_type;
512  if (std::accumulate(desc.priors_.begin(), desc.priors_.end(), static_cast<value_type>(0)) <= min_n_)
513  return true;
514  else
515  return is_pure(labels, desc);
516  }
517  size_t min_n_;
518 };
519 
520 /// @brief Random forest 'node complexity' stop criterion.
521 ///
522 /// Stop splitting a node when it allows for too few different data arrangements.
523 /// This includes purity, which offers only a sinlge data arrangement.
525 {
526 public:
527  /// @brief Constructor: stop when fewer than <tt>1/tau</tt> label arrangements are possible.
528  NodeComplexityStop(double tau = 0.001)
529  :
530  logtau_(std::log(tau))
531  {
532  vigra_precondition(tau > 0 && tau < 1, "NodeComplexityStop(): Tau must be in the open interval (0, 1).");
533  }
534 
535  template <typename LABELS, typename ARR>
536  bool operator()(LABELS const & /*labels*/, RFNodeDescription<ARR> const & desc) // Fix unused parameter, but leave in for API compatability
537  {
538  typedef typename ARR::value_type value_type;
539 
540  // Count the labels.
541  size_t const total = std::accumulate(desc.priors_.begin(), desc.priors_.end(), static_cast<value_type>(0));
542 
543  // Compute log(prod_k(n_k!)).
544  size_t nnz = 0;
545  double lg = 0.0;
546  for (auto v : desc.priors_)
547  {
548  if (v > 0)
549  {
550  ++nnz;
551  lg += loggamma(static_cast<double>(v+1));
552  }
553  }
554  lg += loggamma(static_cast<double>(nnz+1));
555  lg -= loggamma(static_cast<double>(total+1));
556  if (nnz <= 1)
557  return true;
558 
559  return lg >= logtau_;
560  }
561 
562  double logtau_;
563 };
564 
/// \brief Tags used by RandomForestOptions to select the number of features
///        per node and the split criterion.
enum RandomForestOptionTags
{
    // Number of features considered per node.
    RF_SQRT,    ///< square root of the total number of features
    RF_LOG,     ///< logarithm of the total number of features
    RF_CONST,   ///< fixed number, set via RandomForestOptions::features_per_node(int)
    RF_ALL,     ///< all features

    // Split criterion.
    RF_GINI,    ///< Gini criterion
    RF_ENTROPY, ///< entropy criterion
    RF_KSD      ///< Kolmogorov-Smirnov criterion
};
575 
576 
577 /** \brief Options class for \ref vigra::rf3::RandomForest version 3.
578 
579  <b>\#include</b> <vigra/random_forest_3.hxx><br/>
580  Namespace: vigra::rf3
581 */
583 {
584 public:
585 
587  :
588  tree_count_(255),
589  features_per_node_(0),
590  features_per_node_switch_(RF_SQRT),
591  bootstrap_sampling_(true),
592  resample_count_(0),
593  split_(RF_GINI),
594  max_depth_(0),
595  node_complexity_tau_(-1),
596  min_num_instances_(1),
597  use_stratification_(false),
598  n_threads_(-1),
599  class_weights_()
600  {}
601 
602  /**
603  * @brief The number of trees.
604  *
605  * Default: 255
606  */
607  RandomForestOptions & tree_count(int p_tree_count)
608  {
609  tree_count_ = p_tree_count;
610  return *this;
611  }
612 
613  /**
614  * @brief The number of features that are considered when computing the split.
615  *
616  * @param p_features_per_node the number of features
617  *
618  * Default: use sqrt of the total number of features.
619  */
620  RandomForestOptions & features_per_node(int p_features_per_node)
621  {
622  features_per_node_switch_ = RF_CONST;
623  features_per_node_ = p_features_per_node;
624  return *this;
625  }
626 
627  /**
628  * @brief The number of features that are considered when computing the split.
629  *
630  * @param p_features_per_node_switch possible values: <br/>
631  <tt>vigra::rf3::RF_SQRT</tt> (use square root of total number of features, recommended for classification), <br/>
632  <tt>vigra::rf3::RF_LOG</tt> (use logarithm of total number of features, recommended for regression), <br/>
633  <tt>vigra::rf3::RF_ALL</tt> (use all features).
634  *
635  * Default: <tt>vigra::rf3::RF_SQRT</tt>
636  */
637  RandomForestOptions & features_per_node(RandomForestOptionTags p_features_per_node_switch)
638  {
639  vigra_precondition(p_features_per_node_switch == RF_SQRT ||
640  p_features_per_node_switch == RF_LOG ||
641  p_features_per_node_switch == RF_ALL,
642  "RandomForestOptions::features_per_node(): Input must be RF_SQRT, RF_LOG or RF_ALL.");
643  features_per_node_switch_ = p_features_per_node_switch;
644  return *this;
645  }
646 
647  /**
648  * @brief Use bootstrap sampling.
649  *
650  * Default: true
651  */
653  {
654  bootstrap_sampling_ = b;
655  return *this;
656  }
657 
658  /**
659  * @brief If resample_count is greater than zero, the split in each node is computed using only resample_count data points.
660  *
661  * Default: \a n = 0 (don't resample in every node)
662  */
664  {
665  resample_count_ = n;
666  bootstrap_sampling_ = false;
667  return *this;
668  }
669 
670  /**
671  * @brief The split criterion.
672  *
673  * @param p_split possible values: <br/>
674  <tt>vigra::rf3::RF_GINI</tt> (use Gini criterion, \ref vigra::rf3::GiniScorer), <br/>
675  <tt>vigra::rf3::RF_ENTROPY</tt> (use entropy criterion, \ref vigra::rf3::EntropyScorer), <br/>
676  <tt>vigra::rf3::RF_KSD</tt> (use Kolmogorov-Smirnov criterion, \ref vigra::rf3::KSDScorer).
677  *
678  * Default: <tt>vigra::rf3::RF_GINI</tt>
679  */
680  RandomForestOptions & split(RandomForestOptionTags p_split)
681  {
682  vigra_precondition(p_split == RF_GINI ||
683  p_split == RF_ENTROPY ||
684  p_split == RF_KSD,
685  "RandomForestOptions::split(): Input must be RF_GINI, RF_ENTROPY or RF_KSD.");
686  split_ = p_split;
687  return *this;
688  }
689 
690  /**
691  * @brief Do not split a node if its depth is greater or equal to max_depth.
692  *
693  * Default: \a d = 0 (don't use depth as a termination criterion)
694  */
696  {
697  max_depth_ = d;
698  return *this;
699  }
700 
701  /**
702  * @brief Value of the node complexity termination criterion.
703  *
704  * Default: \a tau = -1 (don't use complexity as a termination criterion)
705  */
707  {
708  node_complexity_tau_ = tau;
709  return *this;
710  }
711 
712  /**
713  * @brief Do not split a node if it contains less than min_num_instances data points.
714  *
715  * Default: \a n = 1 (don't use instance count as a termination criterion)
716  */
718  {
719  min_num_instances_ = n;
720  return *this;
721  }
722 
723  /**
724  * @brief Use stratification when creating the bootstrap samples.
725  *
726  * That is, preserve the proportion between the number of class instances exactly
727  * rather than on average.
728  *
729  * Default: false
730  */
732  {
733  use_stratification_ = b;
734  return *this;
735  }
736 
737  /**
738  * @brief The number of threads that are used in training.
739  *
740  * \a n = -1 means use number of cores, \a n = 0 means single-threaded training.
741  *
742  * Default: \a n = -1 (use as many threads as there are cores in the machine).
743  */
745  {
746  n_threads_ = n;
747  return *this;
748  }
749 
750  /**
751  * @brief Each datapoint is weighted by its class weight. By default, each class has weight 1.
752  * @details
753  * The classes in the random forest training have to follow a strict ordering. The weights must be given in that order.
754  * Example:
755  * You have the classes 3, 8 and 5 and use the vector {0.2, 0.3, 0.4} for the class weights.
756  * The ordering of the classes is 3, 5, 8, so class 3 will get weight 0.2, class 5 will get weight 0.3
757  * and class 8 will get weight 0.4.
758  */
759  RandomForestOptions & class_weights(std::vector<double> const & v)
760  {
761  class_weights_ = v;
762  return *this;
763  }
764 
765  /**
766  * @brief Get the actual number of features per node.
767  *
768  * @param total the total number of features
769  *
770  * This function is normally only called internally before training is started.
771  */
772  size_t get_features_per_node(size_t total) const
773  {
774  if (features_per_node_switch_ == RF_SQRT)
775  return std::ceil(std::sqrt(total));
776  else if (features_per_node_switch_ == RF_LOG)
777  return std::ceil(std::log(total));
778  else if (features_per_node_switch_ == RF_CONST)
779  return features_per_node_;
780  else if (features_per_node_switch_ == RF_ALL)
781  return total;
782  vigra_fail("RandomForestOptions::get_features_per_node(): Unknown switch.");
783  return 0;
784  }
785 
786  int tree_count_;
787  int features_per_node_;
788  RandomForestOptionTags features_per_node_switch_;
789  bool bootstrap_sampling_;
790  size_t resample_count_;
791  RandomForestOptionTags split_;
792  size_t max_depth_;
793  double node_complexity_tau_;
794  size_t min_num_instances_;
795  bool use_stratification_;
796  int n_threads_;
797  std::vector<double> class_weights_;
798 
799 };
800 
801 
802 
/// \brief Description of a learning problem: data dimensions, classes and sampling parameters.
///
/// All setters return a reference to the object, so calls can be chained.
///
/// \tparam LabelType type of the class labels.
template <typename LabelType>
class ProblemSpec
{
public:

    /// Create a specification with all counts set to zero and no classes.
    ProblemSpec()
        : num_features_(0),
          num_instances_(0),
          num_classes_(0),
          distinct_classes_(),
          actual_mtry_(0),
          actual_msample_(0)
    {}

    /// Set the number of features.
    ProblemSpec & num_features(size_t n)
    {
        num_features_ = n;
        return *this;
    }

    /// Set the number of instances.
    ProblemSpec & num_instances(size_t n)
    {
        num_instances_ = n;
        return *this;
    }

    /// Set the number of classes.
    ProblemSpec & num_classes(size_t n)
    {
        num_classes_ = n;
        return *this;
    }

    /// Set the distinct class labels; this also sets the number of classes to <tt>v.size()</tt>.
    ProblemSpec & distinct_classes(std::vector<LabelType> v)
    {
        distinct_classes_ = v;
        num_classes_ = v.size();
        return *this;
    }

    /// Set the value stored in actual_mtry_.
    ProblemSpec & actual_mtry(size_t m)
    {
        actual_mtry_ = m;
        return *this;
    }

    /// Set the value stored in actual_msample_.
    ProblemSpec & actual_msample(size_t m)
    {
        actual_msample_ = m;
        return *this;
    }

    /// Two specifications are equal if all of their fields are equal.
    bool operator==(ProblemSpec const & other) const
    {
        return !(num_features_ != other.num_features_)
            && !(num_instances_ != other.num_instances_)
            && !(num_classes_ != other.num_classes_)
            && !(distinct_classes_ != other.distinct_classes_)
            && !(actual_mtry_ != other.actual_mtry_)
            && !(actual_msample_ != other.actual_msample_);
    }

    size_t num_features_;
    size_t num_instances_;
    size_t num_classes_;
    std::vector<LabelType> distinct_classes_;
    size_t actual_mtry_;
    size_t actual_msample_;

};
876 
877 //@}
878 
879 } // namespace rf3
880 
881 } // namespace vigra
882 
883 #endif
884 
RandomForestOptions & min_num_instances(size_t n)
Do not split a node if it contains less than min_num_instances data points.
Definition: random_forest_common.hxx:717
RandomForestOptions & split(RandomForestOptionTags p_split)
The split criterion.
Definition: random_forest_common.hxx:680
Random forest 'maximum depth' stop criterion.
Definition: random_forest_common.hxx:476
size_t get_features_per_node(size_t total) const
Get the actual number of features per node.
Definition: random_forest_common.hxx:772
DepthStop(size_t max_depth)
Constructor: terminate tree construction at max_depth.
Definition: random_forest_common.hxx:480
RandomForestOptions & features_per_node(RandomForestOptionTags p_features_per_node_switch)
The number of features that are considered when computing the split.
Definition: random_forest_common.hxx:637
RandomForestOptions & max_depth(size_t d)
Do not split a node if its depth is greater or equal to max_depth.
Definition: random_forest_common.hxx:695
RandomForestOptions & bootstrap_sampling(bool b)
Use bootstrap sampling.
Definition: random_forest_common.hxx:652
problem specification class for the random forest.
Definition: rf_common.hxx:538
Definition: random_forest_common.hxx:217
RandomForestOptions & class_weights(std::vector< double > const &v)
Each datapoint is weighted by its class weight. By default, each class has weight 1...
Definition: random_forest_common.hxx:759
RandomForestOptions & use_stratification(bool b)
Use stratification when creating the bootstrap samples.
Definition: random_forest_common.hxx:731
RandomForestOptions & resample_count(size_t n)
If resample_count is greater than zero, the split in each node is computed using only resample_count ...
Definition: random_forest_common.hxx:663
Functor that computes the entropy score.
Definition: random_forest_common.hxx:348
RandomForestOptions & n_threads(int n)
The number of threads that are used in training.
Definition: random_forest_common.hxx:744
bool operator==(FFTWComplex< R > const &a, const FFTWComplex< R > &b)
equal
Definition: fftw3.hxx:825
Random forest 'node purity' stop criterion.
Definition: random_forest_common.hxx:463
NodeComplexityStop(double tau=0.001)
Constructor: stop when fewer than 1/tau label arrangements are possible.
Definition: random_forest_common.hxx:528
RandomForestOptions & node_complexity_tau(double tau)
Value of the node complexity termination criterion.
Definition: random_forest_common.hxx:706
linalg::TemporaryMatrix< T > log(MultiArrayView< 2, T, C > const &v)
double loggamma(double x)
The natural logarithm of the gamma function.
Definition: mathutil.hxx:1603
Functor that computes the gini score.
Definition: random_forest_common.hxx:296
Random forest 'node complexity' stop criterion.
Definition: random_forest_common.hxx:524
Options class for vigra::rf3::RandomForest version 3.
Definition: random_forest_common.hxx:582
RandomForestOptions & features_per_node(int p_features_per_node)
The number of features that are considered when computing the split.
Definition: random_forest_common.hxx:620
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition: fixedpoint.hxx:675
RandomForestOptions & tree_count(int p_tree_count)
The number of trees.
Definition: random_forest_common.hxx:607
NumInstancesStop(size_t min_n)
Constructor: terminate tree construction when node contains less than min_n instances.
Definition: random_forest_common.hxx:503
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root.
Definition: fixedpoint.hxx:616
Functor that computes the Kolmogorov-Smirnov score.
Definition: random_forest_common.hxx:382
Random forest 'number of datapoints' stop criterion.
Definition: random_forest_common.hxx:499

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1 (Fri May 19 2017)