44 #include "multi_array.hxx"
45 #include "sampling.hxx"
46 #include "threading.hxx"
47 #include "threadpool.hxx"
48 #include "random_forest_3/random_forest.hxx"
49 #include "random_forest_3/random_forest_common.hxx"
50 #include "random_forest_3/random_forest_visitors.hxx"
// Type selector. For a given FEATURES / LABELS pair, `type` names the
// RandomForest instantiation that uses:
//   - LessEqualSplitTest on the feature value type (FEATURES::value_type) as
//     the split test;
//   - ArgMaxVectorAcc<double> as the accumulator.
// NOTE(review): the enclosing struct header and one RandomForest template
// argument are not visible in this excerpt.
68 template <
typename FEATURES,
typename LABELS>
71 typedef RandomForest<FEATURES,
73 LessEqualSplitTest<typename FEATURES::value_type>,
74 ArgMaxVectorAcc<double> > type;
// Generic node-response updater, parameterized on the accumulator type ACC.
// Presumably this is the primary RFMapUpdater template; its ArgMaxAcc
// specialization follows, and the tree-growing code instantiates
// detail::RFMapUpdater<ACC>.
// NOTE(review): the struct header and the operator body are not visible in
// this excerpt. How `b` is written into `a` for a general ACC cannot be
// confirmed here.
85 template <
typename ACC>
// Call operator arguments:
//   a - the node's response value (written);
//   b - the node's class distribution (read only).
88 template <
typename A,
typename B>
89 void operator()(A & a, B
const & b)
const
// Specialization for ArgMaxAcc. The node response becomes the index of the
// largest entry of the class distribution `b`.
//   - std::max_element returns the first maximum, so ties resolve to the
//     lowest class index.
//   - An empty `b` yields 0.
98 struct RFMapUpdater<ArgMaxAcc>
100 template <
typename A,
typename B>
101 void operator()(A & a, B
const & b)
const
// Iterator to the (first) largest entry of the distribution.
103 auto it = std::max_element(b.begin(), b.end());
// Store its position, i.e. the winning class index, as the response.
104 a = std::distance(b.begin(), it);
// Evaluates candidate splits for one node. For every feature dimension drawn
// by `dim_sampler`, the node's instances are reordered by their value in that
// dimension and handed to the scorer in that order.
//
// Parameters:
//   features         - feature matrix, accessed as features(instance, dim).
//   labels           - forwarded to the scorer unchanged.
//   instance_weights - forwarded to the scorer unchanged.
//   instances        - indices of the instances that belong to the node.
//   dim_sampler      - provides the sampled dimensions via dim_sampler[i].
//
// NOTE(review): the function name, the return type and the scorer parameter
// `score` are declared on lines not visible in this excerpt.
111 template <
typename FEATURES,
typename LABELS,
typename SAMPLER,
typename SCORER>
113 FEATURES
const & features,
114 LABELS
const & labels,
115 std::vector<double>
const & instance_weights,
116 std::vector<size_t>
const & instances,
117 SAMPLER
const & dim_sampler,
120 typedef typename FEATURES::value_type FeatureType;
// Scratch buffers sized to the node. They are allocated once and reused for
// every sampled dimension.
122 auto feats = std::vector<FeatureType>(instances.size());
123 auto sorted_indices = std::vector<size_t>(feats.size());
124 auto tosort_instances = std::vector<size_t>(feats.size());
// NOTE(review): `i` is an int compared against dim_sampler.sampleSize().
// Confirm that return type to avoid a signed/unsigned comparison.
126 for (
int i = 0; i < dim_sampler.sampleSize(); ++i)
128 size_t const d = dim_sampler[i];
// Gather the node's feature values in dimension d.
131 for (
size_t kk = 0; kk < instances.size(); ++kk)
132 feats[kk] = features(instances[kk], d);
// sorted_indices receives the permutation that sorts feats, using
// indexSort's default comparator.
135 indexSort(feats.begin(), feats.end(), sorted_indices.begin());
// NOTE(review): if applyPermutation writes every output element, this copy
// is redundant. Confirm before removing it.
136 std::copy(instances.begin(), instances.end(), tosort_instances.begin());
// tosort_instances = instances, reordered by their value in dimension d.
137 applyPermutation(sorted_indices.begin(), sorted_indices.end(), instances.begin(), tosort_instances.begin());
// Let the scorer evaluate split positions along dimension d.
140 score(features, labels, instance_weights, tosort_instances.begin(), tosort_instances.end(), d);
// Grows one decision tree on the given training data.
//
// Parameters:
//   features   - feature matrix. Rows are instances (features.shape()[0]);
//                columns are features (features.shape()[1]).
//   labels     - one label per instance. Labels are used directly as indices
//                into the class histograms (sized spec.num_classes_) and into
//                options.class_weights_, so they must be class indices in
//                [0, spec.num_classes_).
//   options    - bootstrap sampling, stratification, class weights and
//                resample count.
//   randengine - random engine handed to the samplers created here.
//
// NOTE(review): `visitor`, `stop` and the tree being filled (`tree`) are
// parameters declared on lines not visible in this excerpt.
//
// The tree is built depth-first with an explicit stack. Each node records:
//   - the sub-range of `instance_indices` it owns;
//   - its weighted class histogram;
//   - its depth.
// Leaves get their response through detail::RFMapUpdater<ACC>.
149 template <
typename RF,
typename SCORER,
typename VISITOR,
typename STOP,
typename RANDENGINE>
150 void random_forest_single_tree(
151 typename RF::Features
const & features,
152 MultiArray<1, size_t>
const & labels,
153 RandomForestOptions
const & options,
157 RANDENGINE
const & randengine
159 typedef typename RF::Features Features;
160 typedef typename Features::value_type FeatureType;
161 typedef LessEqualSplitTest<FeatureType> SplitTests;
162 typedef typename RF::Node Node;
163 typedef typename RF::ACC ACC;
164 typedef typename ACC::input_type ACCInputType;
// Only LessEqualSplitTest forests are supported. Splits are stored below as
// SplitTests(best_dim, best_split) and applied as
// `features(i, best_dim) <= best_split`.
166 static_assert(std::is_same<SplitTests, typename RF::SplitTests>::value,
167 "random_forest_single_tree(): Wrong Random Forest class.");
170 int const num_instances = features.shape()[0];
171 size_t const num_features = features.shape()[1];
172 auto const & spec = tree.problem_spec_;
// NOTE(review): num_instances is an int compared with labels.size(), which is
// a signed/unsigned comparison.
174 vigra_precondition(num_instances == labels.size(),
175 "random_forest_single_tree(): Shape mismatch between features and labels.");
176 vigra_precondition(num_features == spec.num_features_,
177 "random_forest_single_tree(): Wrong number of features.");
// Instance indices 0 .. num_instances-1. Each node owns a contiguous sub-range
// of this vector, which std::partition reorders in place when splitting.
180 std::vector<size_t> instance_indices(num_instances);
181 std::iota(instance_indices.begin(), instance_indices.end(), 0);
182 typedef std::vector<size_t>::iterator InstanceIter;
// Instance weights default to 1.0. With bootstrap sampling, each weight
// becomes the number of times the instance was drawn (with replacement), and
// undrawn instances keep weight 0.
185 std::vector<double> instance_weights(num_instances, 1.0);
186 if (options.bootstrap_sampling_)
188 std::fill(instance_weights.begin(), instance_weights.end(), 0.0);
// NOTE(review): stratification is requested via options.use_stratification_,
// but no strata are visible in this constructor call. Confirm that Sampler
// accepts stratified() without strata.
189 Sampler<MersenneTwister> sampler(num_instances,
190 SamplerOptions().withReplacement().stratified(options.use_stratification_),
193 for (
int i = 0; i < sampler.sampleSize(); ++i)
195 int const index = sampler[i];
196 ++instance_weights[index];
// Optional per-class weights scale each instance weight by its class weight.
201 if (options.class_weights_.size() > 0)
203 for (
size_t i = 0; i < instance_weights.size(); ++i)
204 instance_weights[i] *= options.class_weights_.at(labels(i));
// Feature sampler: draws spec.actual_mtry_ distinct dimensions on each
// sample() call. It is called once per node below.
208 auto const mtry = spec.actual_mtry_;
209 Sampler<MersenneTwister> dim_sampler(num_features, SamplerOptions().withoutReplacement().sampleSize(mtry), &randengine);
// Depth-first growth state: a stack of pending nodes plus per-node
// properties.
212 std::stack<Node> node_stack;
213 typedef std::pair<InstanceIter, InstanceIter> IterPair;
214 PropertyMap<Node, IterPair> instance_range;
215 PropertyMap<Node, std::vector<double> > node_distributions;
216 PropertyMap<Node, size_t> node_depths;
// Root node: owns all instances, has depth 0, and its histogram is the
// weighted class count over all instances.
218 auto const rootnode = tree.graph_.addNode();
219 node_stack.push(rootnode);
221 instance_range.insert(rootnode, IterPair(instance_indices.begin(), instance_indices.end()));
223 std::vector<double> priors(spec.num_classes_, 0.0);
224 for (
auto i : instance_indices)
225 priors[labels(i)] += instance_weights[i];
226 node_distributions.insert(rootnode, priors);
228 node_depths.insert(rootnode, 0);
232 visitor.visit_before_tree(tree, features, labels, instance_weights);
// Writes a leaf's response from its class histogram.
235 detail::RFMapUpdater<ACC> node_map_updater;
236 while (!node_stack.empty())
// Process the node on top of the stack.
239 auto const node = node_stack.top();
241 auto const begin = instance_range.at(node).first;
242 auto const end = instance_range.at(node).second;
// NOTE(review): this `priors` shadows the root histogram declared above.
243 auto const & priors = node_distributions.at(node);
244 auto const depth = node_depths.at(node);
// Only instances with weight > 1e-10 take part in the split search. With
// bootstrap sampling, this excludes instances that were never drawn.
247 std::vector<size_t> used_instances;
248 for (
auto it = begin; it != end; ++it)
249 if (instance_weights[*it] > 1e-10)
250 used_instances.push_back(*it);
// Draw a fresh feature subset for this node, and create a scorer seeded with
// the node's class histogram.
253 dim_sampler.sample();
254 SCORER score(priors);
// Without resampling (resample_count_ == 0), or when the node has at most
// resample_count_ used instances, all used instances are scored. Otherwise a
// random subset of resample_count_ of them is drawn without replacement;
// resampler[i] is a position into used_instances.
// NOTE(review): the calls that run the split search are not visible in this
// excerpt.
255 if (options.resample_count_ == 0 || used_instances.size() <= options.resample_count_)
270 Sampler<MersenneTwister> resampler(used_instances.begin(), used_instances.end(), SamplerOptions().withoutReplacement().sampleSize(options.resample_count_), &randengine);
272 auto indices = std::vector<size_t>(options.resample_count_);
273 for (
size_t i = 0; i < options.resample_count_; ++i)
274 indices[i] = used_instances[resampler[i]];
// No admissible split was found, so the node becomes a leaf.
288 if (!score.split_found_)
290 tree.node_responses_.insert(node, ACCInputType());
291 node_map_updater(tree.node_responses_.at(node), node_distributions.at(node));
// A split was found. Create two children, then reorder the node's whole
// instance range (including zero-weight instances) so that instances with
// features(i, best_dim) <= best_split come first (left child).
296 auto const n_left = tree.graph_.addNode();
297 auto const n_right = tree.graph_.addNode();
298 tree.graph_.addArc(node, n_left);
299 tree.graph_.addArc(node, n_right);
300 auto const best_split = score.best_split_;
301 auto const best_dim = score.best_dim_;
302 auto const split_iter = std::partition(begin, end,
305 return features(i, best_dim) <= best_split;
310 visitor.visit_after_split(tree, features, labels, instance_weights, score, begin, split_iter, end);
// Child bookkeeping: instance ranges, the parent's split test, child depths.
312 instance_range.insert(n_left, IterPair(begin, split_iter));
313 instance_range.insert(n_right, IterPair(split_iter, end));
314 tree.split_tests_.insert(node, SplitTests(best_dim, best_split));
315 node_depths.insert(n_left, depth+1);
316 node_depths.insert(n_right, depth+1);
// Left child: weighted class histogram over its instances. If the stop
// criterion fires, the child becomes a leaf. The push below presumably sits
// in the corresponding else branch, which is not visible here.
319 auto priors_left = std::vector<double>(spec.num_classes_, 0.0);
320 for (
auto it = begin; it != split_iter; ++it)
321 priors_left[labels(*it)] += instance_weights[*it];
322 node_distributions.insert(n_left, priors_left);
325 if (stop(labels, RFNodeDescription<decltype(priors_left)>(depth+1, priors_left)))
327 tree.node_responses_.insert(n_left, ACCInputType());
328 node_map_updater(tree.node_responses_.at(n_left), node_distributions.at(n_left));
332 node_stack.push(n_left);
// Right child: same treatment as the left child.
336 auto priors_right = std::vector<double>(spec.num_classes_, 0.0);
337 for (
auto it = split_iter; it != end; ++it)
338 priors_right[labels(*it)] += instance_weights[*it];
339 node_distributions.insert(n_right, priors_right);
342 if (stop(labels, RFNodeDescription<decltype(priors_right)>(depth+1, priors_right)))
344 tree.node_responses_.insert(n_right, ACCInputType());
345 node_map_updater(tree.node_responses_.at(n_right), node_distributions.at(n_right));
349 node_stack.push(n_right);
354 visitor.visit_after_tree(tree, features, labels, instance_weights);
// Trains a forest of options.tree_count_ trees on a thread pool.
//
// Steps:
//   1. Labels are mapped to class indices 0..k-1, where k is the number of
//      distinct labels, in ascending label order (std::set).
//   2. Each tree is trained by random_forest_single_tree on the remapped
//      labels, with its own copy of the visitor.
//   3. The per-tree results are then assembled into `rf`.
//
// NOTE(review): not visible in this excerpt:
//   - the function name;
//   - the LABELS/VISITOR/SCORER/STOP/RANDENGINE template parameters;
//   - the `visitor`/`stop` parameters;
//   - the declaration of `rf` and the tree-merging step.
360 template <
typename FEATURES,
366 RandomForest<FEATURES, LABELS>
368 FEATURES
const & features,
369 LABELS
const & labels,
370 RandomForestOptions
const & options,
373 RANDENGINE & randengine
376 typedef LABELS Labels;
378 typedef typename Labels::value_type LabelType;
379 typedef RandomForest<FEATURES, LABELS> RF;
// Problem description shared by all trees:
//   - instance count = rows of `features`;
//   - feature count  = columns of `features`;
//   - mtry from options.get_features_per_node();
//   - msample = number of labels.
381 ProblemSpec<LabelType> pspec;
382 pspec.num_instances(features.shape()[0])
383 .num_features(features.shape()[1])
384 .actual_mtry(options.get_features_per_node(features.shape()[1]))
385 .actual_msample(labels.size());
388 size_t const tree_count = options.tree_count_;
389 vigra_precondition(tree_count > 0,
"random_forest_impl(): tree_count must not be zero.");
390 std::vector<RF> trees(tree_count);
// Distinct labels in ascending order, and a map label -> class index.
393 std::set<LabelType>
const dlabels(labels.begin(), labels.end());
394 std::vector<LabelType>
const distinct_labels(dlabels.begin(), dlabels.end());
395 pspec.distinct_classes(distinct_labels);
396 std::map<LabelType, size_t> label_map;
397 for (
size_t i = 0; i < distinct_labels.size(); ++i)
399 label_map[distinct_labels[i]] = i;
// Replace every label by its class index.
// NOTE(review): transformed_labels has element type LabelType, but
// random_forest_single_tree takes MultiArray<1, size_t> const &. Unless
// LabelType is size_t, this relies on a conversion at each call, likely a
// temporary copy per tree. Storing the indices as size_t would avoid that
// and any narrowing of the size_t index into LabelType. Confirm.
402 MultiArray<1, LabelType> transformed_labels(Shape1(labels.size()));
403 for (
size_t i = 0; i < (size_t)labels.size(); ++i)
405 transformed_labels(i) = label_map[labels(i)];
// Class weights are optional, but if given there must be one per class.
409 vigra_precondition(options.class_weights_.size() == 0 || options.class_weights_.size() == distinct_labels.size(),
410 "random_forest_impl(): The number of class weights must be 0 or equal to the number of classes.");
413 for (
auto & t : trees)
414 t.problem_spec_ = pspec;
// Thread count:
//   - options.n_threads_ >= 1 -> exactly that many threads;
//   - options.n_threads_ == -1 -> std::thread::hardware_concurrency();
//   - any other value -> 1.
// NOTE(review): hardware_concurrency() may return 0. In that case no seeds or
// engines are created below, yet every tree is still enqueued, and
// rand_engines[thread_id] would index an empty vector if the pool runs any
// task. Consider clamping n_threads to at least 1.
417 size_t n_threads = 1;
418 if (options.n_threads_ >= 1)
419 n_threads = options.n_threads_;
420 else if (options.n_threads_ == -1)
421 n_threads = std::thread::hardware_concurrency();
// Draw n_threads distinct seeds from `randengine` and build one engine per
// pool thread.
// NOTE(review): a tree uses the engine of whichever pool thread runs it. With
// several threads the result may therefore depend on scheduling, even for
// the same `randengine`.
424 UniformIntRandomFunctor<RANDENGINE> rand_functor(randengine);
425 std::set<UInt32> seeds;
426 while (seeds.size() < n_threads)
428 seeds.insert(rand_functor());
430 vigra_assert(seeds.size() == n_threads,
"random_forest_impl(): Could not create random seeds.");
433 std::vector<RANDENGINE> rand_engines;
434 for (
auto seed : seeds)
436 rand_engines.push_back(RANDENGINE(seed));
440 visitor.visit_before_training();
// One visitor copy per tree, so that trees trained concurrently do not share
// a visitor.
444 typedef typename VisitorCopy<VISITOR>::type VisitorCopyType;
445 std::vector<VisitorCopyType> tree_visitors;
446 for (
size_t i = 0; i < tree_count; ++i)
448 tree_visitors.emplace_back(visitor);
// Each task trains trees[i] with tree_visitors[i] and the engine belonging to
// the pool thread (thread_id) that executes it.
452 ThreadPool pool((
size_t)n_threads);
453 std::vector<threading::future<void> > futures;
454 for (
size_t i = 0; i < tree_count; ++i)
456 futures.emplace_back(
457 pool.enqueue([&features, &transformed_labels, &options, &tree_visitors, &stop, &trees, i, &rand_engines](
size_t thread_id)
459 random_forest_single_tree<RF, SCORER, VisitorCopyType, STOP>(features, transformed_labels, options, tree_visitors[i], stop, trees[i], rand_engines[thread_id]);
// Wait for all trees. The loop body is not visible in this excerpt.
464 for (
auto & fut : futures)
// Assemble the result. The loop presumably merges trees[1..] into `rf`; its
// body is not visible here.
469 rf.options_ = options;
470 for (
size_t i = 1; i < trees.size(); ++i)
476 visitor.visit_after_training(tree_visitors, rf, features, labels);
// Selects the node stop criterion from the options and forwards to
// random_forest_impl. Exactly one criterion is used, checked in this order:
//   max_depth_ > 0            -> DepthStop(max_depth_)
//   min_num_instances_ > 1    -> NumInstancesStop(min_num_instances_)
//   node_complexity_tau_ > 0  -> NodeComplexityStop(node_complexity_tau_)
//   otherwise                 -> PurityStop()
// If several of these options are set, only the first match takes effect.
//
// NOTE(review): the function name and the `visitor` parameter are declared on
// lines not visible in this excerpt. Callers in this file use
// detail::random_forest_impl0.
484 template <
typename FEATURES,
typename LABELS,
typename VISITOR,
typename SCORER,
typename RANDENGINE>
486 RandomForest<FEATURES, LABELS>
488 FEATURES
const & features,
489 LABELS
const & labels,
490 RandomForestOptions
const & options,
492 RANDENGINE & randengine
494 if (options.max_depth_ > 0)
495 return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, DepthStop, RANDENGINE>(features, labels, options, visitor, DepthStop(options.max_depth_), randengine);
496 else if (options.min_num_instances_ > 1)
497 return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, NumInstancesStop, RANDENGINE>(features, labels, options, visitor, NumInstancesStop(options.min_num_instances_), randengine);
498 else if (options.node_complexity_tau_ > 0)
499 return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, NodeComplexityStop, RANDENGINE>(features, labels, options, visitor, NodeComplexityStop(options.node_complexity_tau_), randengine);
501 return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, PurityStop, RANDENGINE>(features, labels, options, visitor, PurityStop(), randengine);
// Trains a random forest, using the split criterion chosen by options.split_:
//   RF_GINI    -> detail::GeneralScorer<GiniScore>
//   RF_ENTROPY -> detail::GeneralScorer<EntropyScore>
//   RF_KSD     -> detail::GeneralScorer<KolmogorovSmirnovScore>
// Any other value throws std::runtime_error.
//
// Parameters:
//   features, labels - the training data;
//   options          - training options;
//   randengine       - random engine forwarded to all sampling.
//
// NOTE(review): the function name and the `visitor` parameter are declared on
// lines not visible in this excerpt.
579 template <
typename FEATURES,
typename LABELS,
typename VISITOR,
typename RANDENGINE>
581 RandomForest<FEATURES, LABELS>
583 FEATURES
const & features,
584 LABELS
const & labels,
585 RandomForestOptions
const & options,
587 RANDENGINE & randengine
589 typedef detail::GeneralScorer<GiniScore> GiniScorer;
590 typedef detail::GeneralScorer<EntropyScore> EntropyScorer;
591 typedef detail::GeneralScorer<KolmogorovSmirnovScore> KSDScorer;
592 if (options.split_ == RF_GINI)
593 return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, GiniScorer, RANDENGINE>(features, labels, options, visitor, randengine);
594 else if (options.split_ == RF_ENTROPY)
595 return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, EntropyScorer, RANDENGINE>(features, labels, options, visitor, randengine);
596 else if (options.split_ == RF_KSD)
597 return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, KSDScorer, RANDENGINE>(features, labels, options, visitor, randengine);
// No recognized split criterion.
599 throw std::runtime_error(
"random_forest(): Unknown split criterion.");
// Overload without an explicit random engine. It forwards to the
// engine-taking random_forest overload.
// NOTE(review): `visitor` and `randengine` are declared on lines not visible
// in this excerpt, so the engine's type and seeding cannot be confirmed here.
602 template <
typename FEATURES,
typename LABELS,
typename VISITOR>
604 RandomForest<FEATURES, LABELS>
606 FEATURES
const & features,
607 LABELS
const & labels,
608 RandomForestOptions
const & options,
612 return random_forest(features, labels, options, visitor, randengine);
// Overload taking only the training data and the options.
// NOTE(review): the function name and body are not visible in this excerpt.
// Presumably it forwards to an overload with a default visitor; confirm.
615 template <
typename FEATURES,
typename LABELS>
617 RandomForest<FEATURES, LABELS>
619 FEATURES
const & features,
620 LABELS
const & labels,
621 RandomForestOptions
const & options
// Overload taking only the training data. It trains with default-constructed
// RandomForestOptions.
627 template <
typename FEATURES,
typename LABELS>
629 RandomForest<FEATURES, LABELS>
631 FEATURES
const & features,
632 LABELS
const & labels
634 return random_forest(features, labels, RandomForestOptions());
void applyPermutation(IndexIterator index_first, IndexIterator index_last, InIterator in, OutIterator out)
Sort an array according to the given index permutation.
Definition: algorithm.hxx:456
void indexSort(Iterator first, Iterator last, IndexIterator index_first, Compare c)
Return the index permutation that would sort the input array.
Definition: algorithm.hxx:414
static RandomNumberGenerator & global()
Definition: random.hxx:566
doxygen_overloaded_function(template<...> void separableConvolveBlockwise) template< unsigned int N
Separated convolution on ChunkedArrays.
void random_forest(...)
Train a vigra::rf3::RandomForest classifier.