35 #ifndef VIGRA_RF3_VISITORS_HXX
36 #define VIGRA_RF3_VISITORS_HXX
40 #include "../multi_array.hxx"
41 #include "../multi_shape.hxx"
89 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
98 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
105 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
115 template <
typename TREE,
179 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
186 double const EPS = 1e-20;
190 is_in_bag_.resize(weights.size(),
true);
191 for (
size_t i = 0; i < weights.size(); ++i)
195 is_in_bag_[i] =
false;
201 throw std::runtime_error(
"OOBError::visit_before_tree(): The tree has no out-of-bags.");
207 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
211 const FEATURES & features,
212 const LABELS & labels
215 vigra_precondition(rf.num_trees() > 0,
"OOBError::visit_after_training(): Number of trees must be greater than zero after training.");
216 vigra_precondition(visitors.size() == rf.num_trees(),
"OOBError::visit_after_training(): Number of visitors must be equal to number of trees.");
217 size_t const num_instances = features.shape()[0];
218 auto const num_features = features.shape()[1];
219 for (
auto vptr : visitors)
220 vigra_precondition(vptr->is_in_bag_.size() == num_instances,
"OOBError::visit_after_training(): Some visitors have the wrong number of data points.");
223 typedef typename std::remove_const<LABELS>::type Labels;
226 for (
size_t i = 0; i < (size_t)num_instances; ++i)
229 std::vector<size_t> tree_indices;
230 for (
size_t k = 0; k < visitors.size(); ++k)
231 if (!visitors[k]->is_in_bag_[i])
232 tree_indices.push_back(k);
235 auto const sub_features = features.subarray(Shape2(i, 0), Shape2(i+1, num_features));
236 rf.predict(sub_features, pred, 1, tree_indices);
237 if (pred(0) != labels(i))
249 std::vector<bool> is_in_bag_;
269 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
279 auto const num_features = features.shape()[1];
283 double const EPS = 1e-20;
285 is_in_bag_.resize(weights.size(),
true);
286 for (
size_t i = 0; i < weights.size(); ++i)
290 is_in_bag_[i] =
false;
295 throw std::runtime_error(
"VariableImportance::visit_before_tree(): The tree has no out-of-bags.");
301 template <
typename TREE,
317 typename SCORER::Functor functor;
318 auto const region_impurity = functor.region_score(labels, weights, begin, end);
319 auto const split_impurity = scorer.best_score_;
320 variable_importance_(scorer.best_dim_, tree.num_classes()+1) += region_impurity - split_impurity;
326 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
328 const FEATURES & features,
329 const LABELS & labels,
333 typedef typename std::remove_const<FEATURES>::type Features;
334 typedef typename std::remove_const<LABELS>::type Labels;
336 typedef typename Features::value_type FeatureType;
338 auto const num_features = features.shape()[1];
345 copy_out_of_bags(features, labels, feats, labs);
346 auto const num_oobs = feats.shape()[0];
351 rf.predict(feats, pred, 1);
352 for (
size_t i = 0; i < (size_t)labs.size(); ++i)
354 if (labs(i) == pred(i))
356 oob_right(labs(i)) += 1.0;
357 oob_right(rf.num_classes()) += 1.0;
363 for (
size_t j = 0; j < (size_t)num_features; ++j)
366 backup = feats.template bind<1>(j);
372 for (
int ii = num_oobs-1; ii >= 1; --ii)
373 std::swap(feats(ii, j), feats(randint(ii+1), j));
376 rf.predict(feats, pred, 1);
377 for (
size_t i = 0; i < (size_t)labs.size(); ++i)
379 if (labs(i) == pred(i))
381 perm_oob_right(0, labs(i)) += 1.0;
382 perm_oob_right(0, rf.num_classes()) += 1.0;
389 perm_oob_right.bind<0>(0) -= oob_right;
390 perm_oob_right *= -1;
391 perm_oob_right /= num_oobs;
395 feats.template bind<1>(j) = backup;
402 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
406 const FEATURES & features,
409 vigra_precondition(rf.num_trees() > 0,
"VariableImportance::visit_after_training(): Number of trees must be greater than zero after training.");
410 vigra_precondition(visitors.size() == rf.num_trees(),
"VariableImportance::visit_after_training(): Number of visitors must be equal to number of trees.");
413 auto const num_features = features.shape()[1];
415 for (
auto vptr : visitors)
418 "VariableImportance::visit_after_training(): Shape mismatch.");
464 template <
typename F0,
typename L0,
typename F1,
typename L1>
465 void copy_out_of_bags(
466 F0
const & features_in,
467 L0
const & labels_in,
471 auto const num_instances = features_in.shape()[0];
472 auto const num_features = features_in.shape()[1];
476 for (
auto x : is_in_bag_)
481 features_out.reshape(Shape2(num_oobs, num_features));
482 labels_out.reshape(
Shape1(num_oobs));
484 for (
size_t i = 0; i < (size_t)num_instances; ++i)
488 auto const src = features_in.template bind<0>(i);
489 auto out = features_out.template bind<0>(current);
491 labels_out(current) = labels_in(i);
497 std::vector<bool> is_in_bag_;
518 template <
typename VISITOR,
typename NEXT = RFStopVisiting,
bool CPY = false>
523 typedef VISITOR Visitor;
526 typename std::conditional<CPY, Visitor, Visitor &>::type visitor_;
543 visitor_(other.visitor_),
549 visitor_(other.visitor_),
553 void visit_before_training()
555 if (visitor_.is_active())
556 visitor_.visit_before_training();
557 next_.visit_before_training();
560 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
561 void visit_after_training(VISITORS & v, RF & rf,
const FEATURES & features,
const LABELS & labels)
563 typedef typename VISITORS::value_type VisitorNodeType;
564 typedef typename VisitorNodeType::Visitor VisitorType;
565 typedef typename VisitorNodeType::Next NextType;
571 if (visitor_.is_active())
573 std::vector<VisitorType*> visitors;
575 visitors.push_back(&x.visitor_);
576 visitor_.visit_after_training(visitors, rf, features, labels);
580 std::vector<NextType> nexts;
582 nexts.push_back(x.next_);
585 next_.visit_after_training(nexts, rf, features, labels);
588 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
589 void visit_before_tree(TREE & tree, FEATURES & features, LABELS & labels, WEIGHTS & weights)
591 if (visitor_.is_active())
592 visitor_.visit_before_tree(tree, features, labels, weights);
593 next_.visit_before_tree(tree, features, labels, weights);
596 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
597 void visit_after_tree(RF & rf,
602 if (visitor_.is_active())
603 visitor_.visit_after_tree(rf, features, labels, weights);
604 next_.visit_after_tree(rf, features, labels, weights);
607 template <
typename TREE,
613 void visit_after_split(TREE & tree,
622 if (visitor_.is_active())
623 visitor_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
624 next_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
634 template <
typename VISITOR>
654 detail::RFVisitorNode<A>
655 create_visitor(A & a)
657 typedef detail::RFVisitorNode<A> _0_t;
662 template<
typename A,
typename B>
663 detail::RFVisitorNode<A, detail::RFVisitorNode<B> >
664 create_visitor(A & a, B & b)
666 typedef detail::RFVisitorNode<B> _1_t;
668 typedef detail::RFVisitorNode<A, _1_t> _0_t;
673 template<
typename A,
typename B,
typename C>
674 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C> > >
677 typedef detail::RFVisitorNode<C> _2_t;
679 typedef detail::RFVisitorNode<B, _2_t> _1_t;
681 typedef detail::RFVisitorNode<A, _1_t> _0_t;
686 template<
typename A,
typename B,
typename C,
typename D>
687 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
688 detail::RFVisitorNode<D> > > >
691 typedef detail::RFVisitorNode<D> _3_t;
693 typedef detail::RFVisitorNode<C, _3_t> _2_t;
695 typedef detail::RFVisitorNode<B, _2_t> _1_t;
697 typedef detail::RFVisitorNode<A, _1_t> _0_t;
702 template<
typename A,
typename B,
typename C,
typename D,
typename E>
703 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
704 detail::RFVisitorNode<D, detail::RFVisitorNode<E> > > > >
707 typedef detail::RFVisitorNode<E> _4_t;
709 typedef detail::RFVisitorNode<D, _4_t> _3_t;
711 typedef detail::RFVisitorNode<C, _3_t> _2_t;
713 typedef detail::RFVisitorNode<B, _2_t> _1_t;
715 typedef detail::RFVisitorNode<A, _1_t> _0_t;
720 template<
typename A,
typename B,
typename C,
typename D,
typename E,
722 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
723 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F> > > > > >
726 typedef detail::RFVisitorNode<F> _5_t;
728 typedef detail::RFVisitorNode<E, _5_t> _4_t;
730 typedef detail::RFVisitorNode<D, _4_t> _3_t;
732 typedef detail::RFVisitorNode<C, _3_t> _2_t;
734 typedef detail::RFVisitorNode<B, _2_t> _1_t;
736 typedef detail::RFVisitorNode<A, _1_t> _0_t;
741 template<
typename A,
typename B,
typename C,
typename D,
typename E,
742 typename F,
typename G>
743 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
744 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
745 detail::RFVisitorNode<G> > > > > > >
748 typedef detail::RFVisitorNode<G> _6_t;
750 typedef detail::RFVisitorNode<F, _6_t> _5_t;
752 typedef detail::RFVisitorNode<E, _5_t> _4_t;
754 typedef detail::RFVisitorNode<D, _4_t> _3_t;
756 typedef detail::RFVisitorNode<C, _3_t> _2_t;
758 typedef detail::RFVisitorNode<B, _2_t> _1_t;
760 typedef detail::RFVisitorNode<A, _1_t> _0_t;
765 template<
typename A,
typename B,
typename C,
typename D,
typename E,
766 typename F,
typename G,
typename H>
767 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
768 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
769 detail::RFVisitorNode<G, detail::RFVisitorNode<H> > > > > > > >
770 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h)
772 typedef detail::RFVisitorNode<H> _7_t;
774 typedef detail::RFVisitorNode<G, _7_t> _6_t;
776 typedef detail::RFVisitorNode<F, _6_t> _5_t;
778 typedef detail::RFVisitorNode<E, _5_t> _4_t;
780 typedef detail::RFVisitorNode<D, _4_t> _3_t;
782 typedef detail::RFVisitorNode<C, _3_t> _2_t;
784 typedef detail::RFVisitorNode<B, _2_t> _1_t;
786 typedef detail::RFVisitorNode<A, _1_t> _0_t;
791 template<
typename A,
typename B,
typename C,
typename D,
typename E,
792 typename F,
typename G,
typename H,
typename I>
793 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
794 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
795 detail::RFVisitorNode<G, detail::RFVisitorNode<H, detail::RFVisitorNode<I> > > > > > > > >
796 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i)
798 typedef detail::RFVisitorNode<I> _8_t;
800 typedef detail::RFVisitorNode<H, _8_t> _7_t;
802 typedef detail::RFVisitorNode<G, _7_t> _6_t;
804 typedef detail::RFVisitorNode<F, _6_t> _5_t;
806 typedef detail::RFVisitorNode<E, _5_t> _4_t;
808 typedef detail::RFVisitorNode<D, _4_t> _3_t;
810 typedef detail::RFVisitorNode<C, _3_t> _2_t;
812 typedef detail::RFVisitorNode<B, _2_t> _1_t;
814 typedef detail::RFVisitorNode<A, _1_t> _0_t;
819 template<
typename A,
typename B,
typename C,
typename D,
typename E,
820 typename F,
typename G,
typename H,
typename I,
typename J>
821 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
822 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
823 detail::RFVisitorNode<G, detail::RFVisitorNode<H, detail::RFVisitorNode<I,
824 detail::RFVisitorNode<J> > > > > > > > > >
825 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i,
828 typedef detail::RFVisitorNode<J> _9_t;
830 typedef detail::RFVisitorNode<I, _9_t> _8_t;
832 typedef detail::RFVisitorNode<H, _8_t> _7_t;
834 typedef detail::RFVisitorNode<G, _7_t> _6_t;
836 typedef detail::RFVisitorNode<F, _6_t> _5_t;
838 typedef detail::RFVisitorNode<E, _5_t> _4_t;
840 typedef detail::RFVisitorNode<D, _4_t> _3_t;
842 typedef detail::RFVisitorNode<C, _3_t> _2_t;
844 typedef detail::RFVisitorNode<B, _2_t> _1_t;
846 typedef detail::RFVisitorNode<A, _1_t> _0_t;
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &)
Do something before a tree has been learned.
Definition: random_forest_visitors.hxx:99
void visit_before_tree(TREE &tree, FEATURES &features, LABELS &, WEIGHTS &weights)
Definition: random_forest_visitors.hxx:270
void visit_after_split(TREE &, FEATURES &, LABELS &, WEIGHTS &, SCORER &, ITER, ITER, ITER)
Do something after the split was made.
Definition: random_forest_visitors.hxx:121
const difference_type & shape() const
Definition: multi_array.hxx:1648
void visit_after_tree(RF &, FEATURES &, LABELS &, WEIGHTS &)
Do something after a tree has been learned.
Definition: random_forest_visitors.hxx:106
void deactivate()
Deactivate the visitor.
Definition: random_forest_visitors.hxx:150
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
Base class from which all random forest visitors derive.
Definition: random_forest_visitors.hxx:68
size_t repetition_count_
Definition: random_forest_visitors.hxx:457
The default visitor node (= "do nothing").
Definition: random_forest_visitors.hxx:509
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &labels)
Definition: random_forest_visitors.hxx:208
Compute the variable importance.
Definition: random_forest_visitors.hxx:257
Compute the out of bag error.
Definition: random_forest_visitors.hxx:172
double oob_err_
Definition: random_forest_visitors.hxx:246
void visit_after_tree(RF &rf, const FEATURES &features, const LABELS &labels, WEIGHTS &)
Definition: random_forest_visitors.hxx:327
Definition: random_forest_visitors.hxx:635
void activate()
Activate the visitor.
Definition: random_forest_visitors.hxx:142
Class for fixed size vectors. This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:940
void visit_after_split(TREE &tree, FEATURES &, LABELS &labels, WEIGHTS &weights, SCORER &scorer, ITER begin, ITER, ITER end)
Definition: random_forest_visitors.hxx:307
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &weights)
Definition: random_forest_visitors.hxx:180
Container elements of the statically linked visitor list. Use the create_visitor() functions to creat...
Definition: random_forest_visitors.hxx:519
MultiArray< 2, double > variable_importance_
Definition: random_forest_visitors.hxx:452
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude)
Definition: fftw3.hxx:1002
bool is_active() const
Return whether the visitor is active or not.
Definition: random_forest_visitors.hxx:134
void visit_before_training()
Do something before training starts.
Definition: random_forest_visitors.hxx:80
MultiArrayView subarray(difference_type p, difference_type q) const
Definition: multi_array.hxx:1528
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:344
void visit_after_training(VISITORS &, RF &, const FEATURES &, const LABELS &)
Do something after all trees have been learned.
Definition: random_forest_visitors.hxx:90
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &)
Definition: random_forest_visitors.hxx:403