vigra/random_forest/rf_visitors.hxx
/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/
#ifndef RF_VISITORS_HXX
#define RF_VISITORS_HXX

#ifdef HasHDF5
# include "vigra/hdf5impex.hxx"
#endif // HasHDF5
#include <vigra/windows.h>
#include <iostream>
#include <iomanip>
#include <vigra/timing.hxx>

namespace vigra
{
namespace rf
{
/** \addtogroup MachineLearning Machine Learning
**/
//@{

/**
  This namespace contains all classes and methods related to extracting information during
  learning of the random forest. All visitors share the same interface defined in
  visitors::VisitorBase. The member methods are invoked at certain points of the main code,
  in the order in which the visitors were supplied.

  For the random forest, the visitor concept is implemented as a statically linked list
  (using templates). Each visitor object is encapsulated in a detail::VisitorNode object. The
  VisitorNode object calls the next visitor after one of its visit() methods has terminated.

  To simplify usage, create_visitor() factory methods are supplied.
  Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
  It is possible to supply more than one visitor; they will then be invoked in serial order.

  The calculated information is stored in public data members of the respective visitor
  class - see the documentation of the individual visitors.

  When creating a new visitor, the new class should publicly inherit from VisitorBase
  (see e.g. visitors::OOB_Error).

\code

typedef xxx feature_t;  // replace xxx with the actual feature type
typedef yyy label_t;    // likewise for the label type
MultiArrayView<2, feature_t> f = get_some_features();
MultiArrayView<2, label_t>   l = get_some_labels();
RandomForest<> rf;

// calculate the OOB error
visitors::OOB_Error oob_v;
// calculate variable importance
visitors::VariableImportanceVisitor varimp_v;

double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
// the results can now be found in the public data members of oob_v and varimp_v

\endcode
*/
namespace visitors
{


/** Base class from which all visitors derive. Can be used as a template to create new
 *  visitors.
 */
class VisitorBase
{
    public:
    bool active_;
    bool is_active()
    {
        return active_;
    }

    bool has_value()
    {
        return false;
    }

    VisitorBase()
    :   active_(true)
    {}

    void deactivate()
    {
        active_ = false;
    }
    void activate()
    {
        active_ = true;
    }

    /** do something after the Split has decided how to process the Region
     *  (stack entry)
     *
     * \param tree      reference to the tree that is currently being learned
     * \param split     reference to the split object
     * \param parent    current stack entry which was used to decide the split
     * \param leftChild left stack entry that will be pushed
     * \param rightChild
     *                  right stack entry that will be pushed.
     * \param features  features matrix
     * \param labels    label matrix
     * \sa RF_Traits::StackEntry_t
     */
    template<class Tree, class Split, class Region, class Feature_t, class Label_t>
    void visit_after_split( Tree          & tree,
                            Split         & split,
                            Region        & parent,
                            Region        & leftChild,
                            Region        & rightChild,
                            Feature_t     & features,
                            Label_t       & labels)
    {}

    /** do something after each tree has been learned
     *
     * \param rf        reference to the random forest object that called this
     *                  visitor
     * \param pr        reference to the preprocessor that processed the input
     * \param sm        reference to the sampler object
     * \param st        reference to the first stack entry
     * \param index     index of the current tree
     */
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
    {}

    /** do something after all trees have been learned
     *
     * \param rf        reference to the random forest object that called this
     *                  visitor
     * \param pr        reference to the preprocessor that processed the input
     */
    template<class RF, class PR>
    void visit_at_end(RF const & rf, PR const & pr)
    {}

    /** do something before learning starts
     *
     * \param rf        reference to the random forest object that called this
     *                  visitor
     * \param pr        reference to the Processor class used.
     */
    template<class RF, class PR>
    void visit_at_beginning(RF const & rf, PR const & pr)
    {}
    /** do something while traversing the tree after it has been learned
     *  (external nodes)
     *
     * \param tr        reference to the tree object that called this visitor
     * \param index     index in the topology_ array we are currently at
     * \param node_t    type of node we have (will be e_....
- ) 00187 * \param weight Node weight of current node. 00188 * \sa NodeTags; 00189 * 00190 * you can create the node by using a switch on node_tag and using the 00191 * corresponding Node objects. Or - if you do not care about the type 00192 * use the Nodebase class. 00193 */ 00194 template<class TR, class IntT, class TopT,class Feat> 00195 void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features) 00196 {} 00197 00198 /** do something when visiting a internal node after it has been learned 00199 * 00200 * \sa visit_external_node 00201 */ 00202 template<class TR, class IntT, class TopT,class Feat> 00203 void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features) 00204 {} 00205 00206 /** return a double value. The value of the first 00207 * visitor encountered that has a return value is returned with the 00208 * RandomForest::learn() method - or -1.0 if no return value visitor 00209 * existed. This functionality basically only exists so that the 00210 * OOB - visitor can return the oob error rate like in the old version 00211 * of the random forest. 00212 */ 00213 double return_val() 00214 { 00215 return -1.0; 00216 } 00217 }; 00218 00219 00220 /** Last Visitor that should be called to stop the recursion. 00221 */ 00222 class StopVisiting: public VisitorBase 00223 { 00224 public: 00225 bool has_value() 00226 { 00227 return true; 00228 } 00229 double return_val() 00230 { 00231 return -1.0; 00232 } 00233 }; 00234 namespace detail 00235 { 00236 /** Container elements of the statically linked Visitor list. 00237 * 00238 * use the create_visitor() factory functions to create visitors up to size 10; 00239 * 00240 */ 00241 template <class Visitor, class Next = StopVisiting> 00242 class VisitorNode 00243 { 00244 public: 00245 00246 StopVisiting stop_; 00247 Next next_; 00248 Visitor & visitor_; 00249 VisitorNode(Visitor & visitor, Next & next) 00250 : 00251 next_(next), visitor_(visitor) 00252 {} 00253 00254 VisitorNode(Visitor & visitor) 00255 : 00256 next_(stop_), visitor_(visitor) 00257 {} 00258 00259 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 00260 void visit_after_split( Tree & tree, 00261 Split & split, 00262 Region & parent, 00263 Region & leftChild, 00264 Region & rightChild, 00265 Feature_t & features, 00266 Label_t & labels) 00267 { 00268 if(visitor_.is_active()) 00269 visitor_.visit_after_split(tree, split, 00270 parent, leftChild, rightChild, 00271 features, labels); 00272 next_.visit_after_split(tree, split, parent, leftChild, rightChild, 00273 features, labels); 00274 } 00275 00276 template<class RF, class PR, class SM, class ST> 00277 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00278 { 00279 if(visitor_.is_active()) 00280 visitor_.visit_after_tree(rf, pr, sm, st, index); 00281 next_.visit_after_tree(rf, pr, sm, st, index); 00282 } 00283 00284 template<class RF, class PR> 00285 void visit_at_beginning(RF & rf, PR & pr) 00286 { 00287 if(visitor_.is_active()) 00288 visitor_.visit_at_beginning(rf, pr); 00289 next_.visit_at_beginning(rf, pr); 00290 } 00291 template<class RF, class PR> 00292 void visit_at_end(RF & rf, PR & pr) 00293 { 00294 if(visitor_.is_active()) 00295 visitor_.visit_at_end(rf, pr); 00296 next_.visit_at_end(rf, pr); 00297 } 00298 00299 template<class TR, class IntT, class TopT,class Feat> 00300 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features) 00301 { 00302 if(visitor_.is_active()) 00303 visitor_.visit_external_node(tr, index, node_t,features); 00304 
next_.visit_external_node(tr, index, node_t,features); 00305 } 00306 template<class TR, class IntT, class TopT,class Feat> 00307 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features) 00308 { 00309 if(visitor_.is_active()) 00310 visitor_.visit_internal_node(tr, index, node_t,features); 00311 next_.visit_internal_node(tr, index, node_t,features); 00312 } 00313 00314 double return_val() 00315 { 00316 if(visitor_.is_active() && visitor_.has_value()) 00317 return visitor_.return_val(); 00318 return next_.return_val(); 00319 } 00320 }; 00321 00322 } //namespace detail 00323 00324 ////////////////////////////////////////////////////////////////////////////// 00325 // Visitor Factory function up to 10 visitors // 00326 ////////////////////////////////////////////////////////////////////////////// 00327 00328 /** factory method to to be used with RandomForest::learn() 00329 */ 00330 template<class A> 00331 detail::VisitorNode<A> 00332 create_visitor(A & a) 00333 { 00334 typedef detail::VisitorNode<A> _0_t; 00335 _0_t _0(a); 00336 return _0; 00337 } 00338 00339 00340 /** factory method to to be used with RandomForest::learn() 00341 */ 00342 template<class A, class B> 00343 detail::VisitorNode<A, detail::VisitorNode<B> > 00344 create_visitor(A & a, B & b) 00345 { 00346 typedef detail::VisitorNode<B> _1_t; 00347 _1_t _1(b); 00348 typedef detail::VisitorNode<A, _1_t> _0_t; 00349 _0_t _0(a, _1); 00350 return _0; 00351 } 00352 00353 00354 /** factory method to to be used with RandomForest::learn() 00355 */ 00356 template<class A, class B, class C> 00357 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > > 00358 create_visitor(A & a, B & b, C & c) 00359 { 00360 typedef detail::VisitorNode<C> _2_t; 00361 _2_t _2(c); 00362 typedef detail::VisitorNode<B, _2_t> _1_t; 00363 _1_t _1(b, _2); 00364 typedef detail::VisitorNode<A, _1_t> _0_t; 00365 _0_t _0(a, _1); 00366 return _0; 00367 } 00368 00369 00370 /** factory method to to be used with RandomForest::learn() 00371 */ 00372 template<class A, class B, class C, class D> 00373 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00374 detail::VisitorNode<D> > > > 00375 create_visitor(A & a, B & b, C & c, D & d) 00376 { 00377 typedef detail::VisitorNode<D> _3_t; 00378 _3_t _3(d); 00379 typedef detail::VisitorNode<C, _3_t> _2_t; 00380 _2_t _2(c, _3); 00381 typedef detail::VisitorNode<B, _2_t> _1_t; 00382 _1_t _1(b, _2); 00383 typedef detail::VisitorNode<A, _1_t> _0_t; 00384 _0_t _0(a, _1); 00385 return _0; 00386 } 00387 00388 00389 /** factory method to to be used with RandomForest::learn() 00390 */ 00391 template<class A, class B, class C, class D, class E> 00392 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00393 detail::VisitorNode<D, detail::VisitorNode<E> > > > > 00394 create_visitor(A & a, B & b, C & c, 00395 D & d, E & e) 00396 { 00397 typedef detail::VisitorNode<E> _4_t; 00398 _4_t _4(e); 00399 typedef detail::VisitorNode<D, _4_t> _3_t; 00400 _3_t _3(d, _4); 00401 typedef detail::VisitorNode<C, _3_t> _2_t; 00402 _2_t _2(c, _3); 00403 typedef detail::VisitorNode<B, _2_t> _1_t; 00404 _1_t _1(b, _2); 00405 typedef detail::VisitorNode<A, _1_t> _0_t; 00406 _0_t _0(a, _1); 00407 return _0; 00408 } 00409 00410 00411 /** factory method to to be used with RandomForest::learn() 00412 */ 00413 template<class A, class B, class C, class D, class E, 00414 class F> 00415 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00416 detail::VisitorNode<D, 
detail::VisitorNode<E, detail::VisitorNode<F> > > > > > 00417 create_visitor(A & a, B & b, C & c, 00418 D & d, E & e, F & f) 00419 { 00420 typedef detail::VisitorNode<F> _5_t; 00421 _5_t _5(f); 00422 typedef detail::VisitorNode<E, _5_t> _4_t; 00423 _4_t _4(e, _5); 00424 typedef detail::VisitorNode<D, _4_t> _3_t; 00425 _3_t _3(d, _4); 00426 typedef detail::VisitorNode<C, _3_t> _2_t; 00427 _2_t _2(c, _3); 00428 typedef detail::VisitorNode<B, _2_t> _1_t; 00429 _1_t _1(b, _2); 00430 typedef detail::VisitorNode<A, _1_t> _0_t; 00431 _0_t _0(a, _1); 00432 return _0; 00433 } 00434 00435 00436 /** factory method to to be used with RandomForest::learn() 00437 */ 00438 template<class A, class B, class C, class D, class E, 00439 class F, class G> 00440 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00441 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 00442 detail::VisitorNode<G> > > > > > > 00443 create_visitor(A & a, B & b, C & c, 00444 D & d, E & e, F & f, G & g) 00445 { 00446 typedef detail::VisitorNode<G> _6_t; 00447 _6_t _6(g); 00448 typedef detail::VisitorNode<F, _6_t> _5_t; 00449 _5_t _5(f, _6); 00450 typedef detail::VisitorNode<E, _5_t> _4_t; 00451 _4_t _4(e, _5); 00452 typedef detail::VisitorNode<D, _4_t> _3_t; 00453 _3_t _3(d, _4); 00454 typedef detail::VisitorNode<C, _3_t> _2_t; 00455 _2_t _2(c, _3); 00456 typedef detail::VisitorNode<B, _2_t> _1_t; 00457 _1_t _1(b, _2); 00458 typedef detail::VisitorNode<A, _1_t> _0_t; 00459 _0_t _0(a, _1); 00460 return _0; 00461 } 00462 00463 00464 /** factory method to to be used with RandomForest::learn() 00465 */ 00466 template<class A, class B, class C, class D, class E, 00467 class F, class G, class H> 00468 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00469 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 00470 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > > 00471 create_visitor(A & a, B & b, C & c, 00472 D & d, E & e, F & f, 00473 G & g, H & h) 00474 { 00475 typedef detail::VisitorNode<H> _7_t; 00476 _7_t _7(h); 00477 typedef detail::VisitorNode<G, _7_t> _6_t; 00478 _6_t _6(g, _7); 00479 typedef detail::VisitorNode<F, _6_t> _5_t; 00480 _5_t _5(f, _6); 00481 typedef detail::VisitorNode<E, _5_t> _4_t; 00482 _4_t _4(e, _5); 00483 typedef detail::VisitorNode<D, _4_t> _3_t; 00484 _3_t _3(d, _4); 00485 typedef detail::VisitorNode<C, _3_t> _2_t; 00486 _2_t _2(c, _3); 00487 typedef detail::VisitorNode<B, _2_t> _1_t; 00488 _1_t _1(b, _2); 00489 typedef detail::VisitorNode<A, _1_t> _0_t; 00490 _0_t _0(a, _1); 00491 return _0; 00492 } 00493 00494 00495 /** factory method to to be used with RandomForest::learn() 00496 */ 00497 template<class A, class B, class C, class D, class E, 00498 class F, class G, class H, class I> 00499 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00500 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 00501 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > > 00502 create_visitor(A & a, B & b, C & c, 00503 D & d, E & e, F & f, 00504 G & g, H & h, I & i) 00505 { 00506 typedef detail::VisitorNode<I> _8_t; 00507 _8_t _8(i); 00508 typedef detail::VisitorNode<H, _8_t> _7_t; 00509 _7_t _7(h, _8); 00510 typedef detail::VisitorNode<G, _7_t> _6_t; 00511 _6_t _6(g, _7); 00512 typedef detail::VisitorNode<F, _6_t> _5_t; 00513 _5_t _5(f, _6); 00514 typedef detail::VisitorNode<E, _5_t> _4_t; 00515 _4_t _4(e, _5); 00516 typedef detail::VisitorNode<D, _4_t> _3_t; 00517 _3_t _3(d, 
_4); 00518 typedef detail::VisitorNode<C, _3_t> _2_t; 00519 _2_t _2(c, _3); 00520 typedef detail::VisitorNode<B, _2_t> _1_t; 00521 _1_t _1(b, _2); 00522 typedef detail::VisitorNode<A, _1_t> _0_t; 00523 _0_t _0(a, _1); 00524 return _0; 00525 } 00526 00527 /** factory method to to be used with RandomForest::learn() 00528 */ 00529 template<class A, class B, class C, class D, class E, 00530 class F, class G, class H, class I, class J> 00531 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00532 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 00533 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I, 00534 detail::VisitorNode<J> > > > > > > > > > 00535 create_visitor(A & a, B & b, C & c, 00536 D & d, E & e, F & f, 00537 G & g, H & h, I & i, 00538 J & j) 00539 { 00540 typedef detail::VisitorNode<J> _9_t; 00541 _9_t _9(j); 00542 typedef detail::VisitorNode<I, _9_t> _8_t; 00543 _8_t _8(i, _9); 00544 typedef detail::VisitorNode<H, _8_t> _7_t; 00545 _7_t _7(h, _8); 00546 typedef detail::VisitorNode<G, _7_t> _6_t; 00547 _6_t _6(g, _7); 00548 typedef detail::VisitorNode<F, _6_t> _5_t; 00549 _5_t _5(f, _6); 00550 typedef detail::VisitorNode<E, _5_t> _4_t; 00551 _4_t _4(e, _5); 00552 typedef detail::VisitorNode<D, _4_t> _3_t; 00553 _3_t _3(d, _4); 00554 typedef detail::VisitorNode<C, _3_t> _2_t; 00555 _2_t _2(c, _3); 00556 typedef detail::VisitorNode<B, _2_t> _1_t; 00557 _1_t _1(b, _2); 00558 typedef detail::VisitorNode<A, _1_t> _0_t; 00559 _0_t _0(a, _1); 00560 return _0; 00561 } 00562 00563 ////////////////////////////////////////////////////////////////////////////// 00564 // Visitors of communal interest. // 00565 ////////////////////////////////////////////////////////////////////////////// 00566 00567 00568 /** Visitor to gain information, later needed for online learning. 
00569 */ 00570 00571 class OnlineLearnVisitor: public VisitorBase 00572 { 00573 public: 00574 //Set if we adjust thresholds 00575 bool adjust_thresholds; 00576 //Current tree id 00577 int tree_id; 00578 //Last node id for finding parent 00579 int last_node_id; 00580 //Need to now the label for interior node visiting 00581 vigra::Int32 current_label; 00582 //marginal distribution for interior nodes 00583 struct MarginalDistribution 00584 { 00585 ArrayVector<Int32> leftCounts; 00586 Int32 leftTotalCounts; 00587 ArrayVector<Int32> rightCounts; 00588 Int32 rightTotalCounts; 00589 double gap_left; 00590 double gap_right; 00591 }; 00592 typedef ArrayVector<vigra::Int32> IndexList; 00593 00594 //All information for one tree 00595 struct TreeOnlineInformation 00596 { 00597 std::vector<MarginalDistribution> mag_distributions; 00598 std::vector<IndexList> index_lists; 00599 //map for linear index of mag_distiributions 00600 std::map<int,int> interior_to_index; 00601 //map for linear index of index_lists 00602 std::map<int,int> exterior_to_index; 00603 }; 00604 00605 //All trees 00606 std::vector<TreeOnlineInformation> trees_online_information; 00607 00608 /** Initilize, set the number of trees 00609 */ 00610 template<class RF,class PR> 00611 void visit_at_beginning(RF & rf,const PR & pr) 00612 { 00613 tree_id=0; 00614 trees_online_information.resize(rf.options_.tree_count_); 00615 } 00616 00617 /** Reset a tree 00618 */ 00619 void reset_tree(int tree_id) 00620 { 00621 trees_online_information[tree_id].mag_distributions.clear(); 00622 trees_online_information[tree_id].index_lists.clear(); 00623 trees_online_information[tree_id].interior_to_index.clear(); 00624 trees_online_information[tree_id].exterior_to_index.clear(); 00625 } 00626 00627 /** simply increase the tree count 00628 */ 00629 template<class RF, class PR, class SM, class ST> 00630 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00631 { 00632 tree_id++; 00633 } 00634 00635 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 00636 void visit_after_split( Tree & tree, 00637 Split & split, 00638 Region & parent, 00639 Region & leftChild, 00640 Region & rightChild, 00641 Feature_t & features, 00642 Label_t & labels) 00643 { 00644 int linear_index; 00645 int addr=tree.topology_.size(); 00646 if(split.createNode().typeID() == i_ThresholdNode) 00647 { 00648 if(adjust_thresholds) 00649 { 00650 //Store marginal distribution 00651 linear_index=trees_online_information[tree_id].mag_distributions.size(); 00652 trees_online_information[tree_id].interior_to_index[addr]=linear_index; 00653 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution()); 00654 00655 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_; 00656 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_; 00657 00658 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_; 00659 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_; 00660 //Store the gap 00661 double gap_left,gap_right; 00662 int i; 00663 gap_left=features(leftChild[0],split.bestSplitColumn()); 00664 for(i=1;i<leftChild.size();++i) 00665 if(features(leftChild[i],split.bestSplitColumn())>gap_left) 00666 gap_left=features(leftChild[i],split.bestSplitColumn()); 00667 gap_right=features(rightChild[0],split.bestSplitColumn()); 00668 for(i=1;i<rightChild.size();++i) 00669 
if(features(rightChild[i],split.bestSplitColumn())<gap_right) 00670 gap_right=features(rightChild[i],split.bestSplitColumn()); 00671 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left; 00672 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right; 00673 } 00674 } 00675 else 00676 { 00677 //Store index list 00678 linear_index=trees_online_information[tree_id].index_lists.size(); 00679 trees_online_information[tree_id].exterior_to_index[addr]=linear_index; 00680 00681 trees_online_information[tree_id].index_lists.push_back(IndexList()); 00682 00683 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0); 00684 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin()); 00685 } 00686 } 00687 void add_to_index_list(int tree,int node,int index) 00688 { 00689 if(!this->active_) 00690 return; 00691 TreeOnlineInformation &ti=trees_online_information[tree]; 00692 ti.index_lists[ti.exterior_to_index[node]].push_back(index); 00693 } 00694 void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index) 00695 { 00696 if(!this->active_) 00697 return; 00698 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index]; 00699 trees_online_information[src_tree].exterior_to_index.erase(src_index); 00700 } 00701 /** do something when visiting a internal node during getToLeaf 00702 * 00703 * remember as last node id, for finding the parent of the last external node 00704 * also: adjust class counts and borders 00705 */ 00706 template<class TR, class IntT, class TopT,class Feat> 00707 void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features) 00708 { 00709 last_node_id=index; 00710 if(adjust_thresholds) 00711 { 00712 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes"); 00713 //Check if we are in the gap 00714 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column()); 00715 TreeOnlineInformation &ti=trees_online_information[tree_id]; 00716 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]]; 00717 if(value>m.gap_left && value<m.gap_right) 00718 { 00719 //Check which site we want to go 00720 if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts)) 00721 { 00722 //We want to go left 00723 m.gap_left=value; 00724 } 00725 else 00726 { 00727 //We want to go right 00728 m.gap_right=value; 00729 } 00730 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0; 00731 } 00732 //Adjust class counts 00733 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()) 00734 { 00735 ++m.rightTotalCounts; 00736 ++m.rightCounts[current_label]; 00737 } 00738 else 00739 { 00740 ++m.leftTotalCounts; 00741 ++m.rightCounts[current_label]; 00742 } 00743 } 00744 } 00745 /** do something when visiting a extern node during getToLeaf 00746 * 00747 * Store the new index! 00748 */ 00749 }; 00750 00751 ////////////////////////////////////////////////////////////////////////////// 00752 // Out of Bag Error estimates // 00753 ////////////////////////////////////////////////////////////////////////////// 00754 00755 00756 /** Visitor that calculates the oob error of each individual randomized 00757 * decision tree. 
00758 * 00759 * After training a tree, all those samples that are OOB for this particular tree 00760 * are put down the tree and the error estimated. 00761 * the per tree oob error is the average of the individual error estimates. 00762 * (oobError = average error of one randomized tree) 00763 * Note: This is Not the OOB - Error estimate suggested by Breiman (See OOB_Error 00764 * visitor) 00765 */ 00766 class OOB_PerTreeError:public VisitorBase 00767 { 00768 public: 00769 /** Average error of one randomized decision tree 00770 */ 00771 double oobError; 00772 00773 int totalOobCount; 00774 ArrayVector<int> oobCount,oobErrorCount; 00775 00776 OOB_PerTreeError() 00777 : oobError(0.0), 00778 totalOobCount(0) 00779 {} 00780 00781 00782 bool has_value() 00783 { 00784 return true; 00785 } 00786 00787 00788 /** does the basic calculation per tree*/ 00789 template<class RF, class PR, class SM, class ST> 00790 void visit_after_tree( RF& rf, PR & pr, SM & sm, ST & st, int index) 00791 { 00792 //do the first time called. 00793 if(int(oobCount.size()) != rf.ext_param_.row_count_) 00794 { 00795 oobCount.resize(rf.ext_param_.row_count_, 0); 00796 oobErrorCount.resize(rf.ext_param_.row_count_, 0); 00797 } 00798 // go through the samples 00799 for(int l = 0; l < rf.ext_param_.row_count_; ++l) 00800 { 00801 // if the lth sample is oob... 00802 if(!sm.is_used()[l]) 00803 { 00804 ++oobCount[l]; 00805 if( rf.tree(index) 00806 .predictLabel(rowVector(pr.features(), l)) 00807 != pr.response()(l,0)) 00808 { 00809 ++oobErrorCount[l]; 00810 } 00811 } 00812 00813 } 00814 } 00815 00816 /** Does the normalisation 00817 */ 00818 template<class RF, class PR> 00819 void visit_at_end(RF & rf, PR & pr) 00820 { 00821 // do some normalisation 00822 for(int l=0; l < (int)rf.ext_param_.row_count_; ++l) 00823 { 00824 if(oobCount[l]) 00825 { 00826 oobError += double(oobErrorCount[l]) / oobCount[l]; 00827 ++totalOobCount; 00828 } 00829 } 00830 oobError/=totalOobCount; 00831 } 00832 00833 }; 00834 00835 /** Visitor that calculates the oob error of the ensemble 00836 * This rate should be used to estimate the crossvalidation 00837 * error rate. 00838 * Here each sample is put down those trees, for which this sample 00839 * is OOB i.e. if sample #1 is OOB for trees 1, 3 and 5 we calculate 00840 * the output using the ensemble consisting only of trees 1 3 and 5. 00841 * 00842 * Using normal bagged sampling each sample is OOB for approx. 33% of trees 00843 * The error rate obtained as such therefore corresponds to crossvalidation 00844 * rate obtained using a ensemble containing 33% of the trees. 00845 */ 00846 class OOB_Error : public VisitorBase 00847 { 00848 typedef MultiArrayShape<2>::type Shp; 00849 int class_count; 00850 bool is_weighted; 00851 MultiArray<2,double> tmp_prob; 00852 public: 00853 00854 MultiArray<2, double> prob_oob; 00855 /** Ensemble oob error rate 00856 */ 00857 double oob_breiman; 00858 00859 MultiArray<2, double> oobCount; 00860 ArrayVector< int> indices; 00861 OOB_Error() : VisitorBase(), oob_breiman(0.0) {} 00862 00863 void save(std::string filen, std::string pathn) 00864 { 00865 if(*(pathn.end()-1) != '/') 00866 pathn += "/"; 00867 const char* filename = filen.c_str(); 00868 MultiArray<2, double> temp(Shp(1,1), 0.0); 00869 temp[0] = oob_breiman; 00870 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp); 00871 } 00872 // negative value if sample was ib, number indicates how often. 
00873 // value >=0 if sample was oob, 0 means fail 1, corrrect 00874 00875 template<class RF, class PR> 00876 void visit_at_beginning(RF & rf, PR & pr) 00877 { 00878 class_count = rf.class_count(); 00879 tmp_prob.reshape(Shp(1, class_count), 0); 00880 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0); 00881 is_weighted = rf.options().predict_weighted_; 00882 indices.resize(rf.ext_param().row_count_); 00883 if(int(oobCount.size()) != rf.ext_param_.row_count_) 00884 { 00885 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0); 00886 } 00887 for(int ii = 0; ii < rf.ext_param().row_count_; ++ii) 00888 { 00889 indices[ii] = ii; 00890 } 00891 } 00892 00893 template<class RF, class PR, class SM, class ST> 00894 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00895 { 00896 // go through the samples 00897 int total_oob =0; 00898 int wrong_oob =0; 00899 // FIXME: magic number 10000: invoke special treatment when when msample << sample_count 00900 // (i.e. the OOB sample ist very large) 00901 // 40000: use at most 40000 OOB samples per class for OOB error estimate 00902 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000) 00903 { 00904 ArrayVector<int> oob_indices; 00905 ArrayVector<int> cts(class_count, 0); 00906 std::random_shuffle(indices.begin(), indices.end()); 00907 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii) 00908 { 00909 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000) 00910 { 00911 oob_indices.push_back(indices[ii]); 00912 ++cts[pr.response()(indices[ii], 0)]; 00913 } 00914 } 00915 for(unsigned int ll = 0; ll < oob_indices.size(); ++ll) 00916 { 00917 // update number of trees in which current sample is oob 00918 ++oobCount[oob_indices[ll]]; 00919 00920 // update number of oob samples in this tree. 00921 ++total_oob; 00922 // get the predicted votes ---> tmp_prob; 00923 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll])); 00924 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 00925 rf.tree(index).parameters_, 00926 pos); 00927 tmp_prob.init(0); 00928 for(int ii = 0; ii < class_count; ++ii) 00929 { 00930 tmp_prob[ii] = node.prob_begin()[ii]; 00931 } 00932 if(is_weighted) 00933 { 00934 for(int ii = 0; ii < class_count; ++ii) 00935 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1)); 00936 } 00937 rowVector(prob_oob, oob_indices[ll]) += tmp_prob; 00938 int label = argMax(tmp_prob); 00939 00940 } 00941 }else 00942 { 00943 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll) 00944 { 00945 // if the lth sample is oob... 00946 if(!sm.is_used()[ll]) 00947 { 00948 // update number of trees in which current sample is oob 00949 ++oobCount[ll]; 00950 00951 // update number of oob samples in this tree. 00952 ++total_oob; 00953 // get the predicted votes ---> tmp_prob; 00954 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll)); 00955 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 00956 rf.tree(index).parameters_, 00957 pos); 00958 tmp_prob.init(0); 00959 for(int ii = 0; ii < class_count; ++ii) 00960 { 00961 tmp_prob[ii] = node.prob_begin()[ii]; 00962 } 00963 if(is_weighted) 00964 { 00965 for(int ii = 0; ii < class_count; ++ii) 00966 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1)); 00967 } 00968 rowVector(prob_oob, ll) += tmp_prob; 00969 int label = argMax(tmp_prob); 00970 00971 } 00972 } 00973 } 00974 // go through the ib samples; 00975 } 00976 00977 /** Normalise variable importance after the number of trees is known. 
00978 */ 00979 template<class RF, class PR> 00980 void visit_at_end(RF & rf, PR & pr) 00981 { 00982 // ullis original metric and breiman style stuff 00983 int totalOobCount =0; 00984 int breimanstyle = 0; 00985 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 00986 { 00987 if(oobCount[ll]) 00988 { 00989 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0)) 00990 ++breimanstyle; 00991 ++totalOobCount; 00992 } 00993 } 00994 oob_breiman = double(breimanstyle)/totalOobCount; 00995 } 00996 }; 00997 00998 00999 /** Visitor that calculates different OOB error statistics 01000 */ 01001 class CompleteOOBInfo : public VisitorBase 01002 { 01003 typedef MultiArrayShape<2>::type Shp; 01004 int class_count; 01005 bool is_weighted; 01006 MultiArray<2,double> tmp_prob; 01007 public: 01008 01009 /** OOB Error rate of each individual tree 01010 */ 01011 MultiArray<2, double> oob_per_tree; 01012 /** Mean of oob_per_tree 01013 */ 01014 double oob_mean; 01015 /**Standard deviation of oob_per_tree 01016 */ 01017 double oob_std; 01018 01019 MultiArray<2, double> prob_oob; 01020 /** Ensemble OOB error 01021 * 01022 * \sa OOB_Error 01023 */ 01024 double oob_breiman; 01025 01026 MultiArray<2, double> oobCount; 01027 MultiArray<2, double> oobErrorCount; 01028 /** Per Tree OOB error calculated as in OOB_PerTreeError 01029 * (Ulli's version) 01030 */ 01031 double oob_per_tree2; 01032 01033 /**Column containing the development of the Ensemble 01034 * error rate with increasing number of trees 01035 */ 01036 MultiArray<2, double> breiman_per_tree; 01037 /** 4 dimensional array containing the development of confusion matrices 01038 * with number of trees - can be used to estimate ROC curves etc. 01039 * 01040 * oobroc_per_tree(ii,jj,kk,ll) 01041 * corresponds true label = ii 01042 * predicted label = jj 01043 * confusion matrix after ll trees 01044 * 01045 * explaination of third index: 01046 * 01047 * Two class case: 01048 * kk = 0 - (treeCount-1) 01049 * Threshold is on Probability for class 0 is kk/(treeCount-1); 01050 * More classes: 01051 * kk = 0. Threshold on probability set by argMax of the probability array. 01052 */ 01053 MultiArray<4, double> oobroc_per_tree; 01054 01055 CompleteOOBInfo() : VisitorBase(), oob_mean(0), oob_std(0), oob_per_tree2(0) {} 01056 01057 /** save to HDF5 file 01058 */ 01059 void save(std::string filen, std::string pathn) 01060 { 01061 if(*(pathn.end()-1) != '/') 01062 pathn += "/"; 01063 const char* filename = filen.c_str(); 01064 MultiArray<2, double> temp(Shp(1,1), 0.0); 01065 writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree); 01066 writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree); 01067 writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree); 01068 temp[0] = oob_mean; 01069 writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp); 01070 temp[0] = oob_std; 01071 writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp); 01072 temp[0] = oob_breiman; 01073 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp); 01074 temp[0] = oob_per_tree2; 01075 writeHDF5(filename, (pathn + "ulli_error").c_str(), temp); 01076 } 01077 // negative value if sample was ib, number indicates how often. 
01078 // value >=0 if sample was oob, 0 means fail 1, corrrect 01079 01080 template<class RF, class PR> 01081 void visit_at_beginning(RF & rf, PR & pr) 01082 { 01083 class_count = rf.class_count(); 01084 if(class_count == 2) 01085 oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count())); 01086 else 01087 oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count())); 01088 tmp_prob.reshape(Shp(1, class_count), 0); 01089 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0); 01090 is_weighted = rf.options().predict_weighted_; 01091 oob_per_tree.reshape(Shp(1, rf.tree_count()), 0); 01092 breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0); 01093 //do the first time called. 01094 if(int(oobCount.size()) != rf.ext_param_.row_count_) 01095 { 01096 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0); 01097 oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0); 01098 } 01099 } 01100 01101 template<class RF, class PR, class SM, class ST> 01102 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 01103 { 01104 // go through the samples 01105 int total_oob =0; 01106 int wrong_oob =0; 01107 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll) 01108 { 01109 // if the lth sample is oob... 01110 if(!sm.is_used()[ll]) 01111 { 01112 // update number of trees in which current sample is oob 01113 ++oobCount[ll]; 01114 01115 // update number of oob samples in this tree. 01116 ++total_oob; 01117 // get the predicted votes ---> tmp_prob; 01118 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll)); 01119 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 01120 rf.tree(index).parameters_, 01121 pos); 01122 tmp_prob.init(0); 01123 for(int ii = 0; ii < class_count; ++ii) 01124 { 01125 tmp_prob[ii] = node.prob_begin()[ii]; 01126 } 01127 if(is_weighted) 01128 { 01129 for(int ii = 0; ii < class_count; ++ii) 01130 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1)); 01131 } 01132 rowVector(prob_oob, ll) += tmp_prob; 01133 int label = argMax(tmp_prob); 01134 01135 if(label != pr.response()(ll, 0)) 01136 { 01137 // update number of wrong oob samples in this tree. 01138 ++wrong_oob; 01139 // update number of trees in which current sample is wrong oob 01140 ++oobErrorCount[ll]; 01141 } 01142 } 01143 } 01144 int breimanstyle = 0; 01145 int totalOobCount = 0; 01146 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 01147 { 01148 if(oobCount[ll]) 01149 { 01150 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0)) 01151 ++breimanstyle; 01152 ++totalOobCount; 01153 if(oobroc_per_tree.shape(2) == 1) 01154 { 01155 oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++; 01156 } 01157 } 01158 } 01159 if(oobroc_per_tree.shape(2) == 1) 01160 oobroc_per_tree.bindOuter(index)/=totalOobCount; 01161 if(oobroc_per_tree.shape(2) > 1) 01162 { 01163 MultiArrayView<3, double> current_roc 01164 = oobroc_per_tree.bindOuter(index); 01165 for(int gg = 0; gg < current_roc.shape(2); ++gg) 01166 { 01167 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 01168 { 01169 if(oobCount[ll]) 01170 { 01171 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))? 
                                        1 : 0;
                        current_roc(pr.response()(ll, 0), pred, gg) += 1;
                    }
                }
                current_roc.bindOuter(gg) /= totalOobCount;
            }
        }
        breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
        oob_per_tree[index] = double(wrong_oob)/double(total_oob);
        // go through the ib samples;
    }

    /** Normalise the OOB statistics after the number of trees is known.
     */
    template<class RF, class PR>
    void visit_at_end(RF & rf, PR & pr)
    {
        // Ulli's original metric and Breiman-style error
        oob_per_tree2 = 0;
        int totalOobCount = 0;
        int breimanstyle = 0;
        for(int ll = 0; ll < (int)rf.ext_param_.row_count_; ++ll)
        {
            if(oobCount[ll])
            {
                if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
                    ++breimanstyle;
                oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
                ++totalOobCount;
            }
        }
        oob_per_tree2 /= totalOobCount;
        oob_breiman = double(breimanstyle)/totalOobCount;
        // mean error of each tree
        MultiArrayView<2, double> mean(Shp(1,1), &oob_mean);
        MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
        rowStatistics(oob_per_tree, mean, stdDev);
    }
};

/** Calculate variable importance while learning.
 */
class VariableImportanceVisitor : public VisitorBase
{
    public:

    /** This array has the same entries as the R random forest variable
     *  importance.
     *  The matrix is featureCount by (classCount + 2).
     *  variable_importance_(ii, jj) is the variable importance measure of
     *  the ii-th variable according to:
     *  jj = 0 - (classCount-1)
     *      classwise permutation importance
     *  jj = columnCount(variable_importance_) - 2
     *      permutation importance
     *  jj = columnCount(variable_importance_) - 1
     *      gini decrease importance.
     *
     *  permutation importance:
     *  The difference between the fraction of OOB samples classified correctly
     *  before and after permuting (randomizing) the ii-th column is calculated.
     *  The ii-th column is permuted rep_cnt times.
     *
     *  classwise permutation importance:
     *  same as permutation importance, but only those OOB samples whose
     *  response corresponds to class jj are considered.
     *
     *  gini decrease importance:
     *  row ii corresponds to the sum of all gini decreases induced by variable ii
     *  in each node of the random forest.
     */
    MultiArray<2, double> variable_importance_;
    int repetition_count_;
    bool in_place_;

#ifdef HasHDF5
    void save(std::string filename, std::string prefix)
    {
        prefix = "variable_importance_" + prefix;
        writeHDF5(filename.c_str(),
                  prefix.c_str(),
                  variable_importance_);
    }
#endif
    /** Constructor
     * \param rep_cnt (default: 10) how often the permutation should
     * take place. Set to 1 to make the calculation faster (but
     * possibly less stable).
     */
    VariableImportanceVisitor(int rep_cnt = 10)
    :   repetition_count_(rep_cnt)
    {}

    /** calculates impurity decrease based variable importance after every
     * split.
01268 */ 01269 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 01270 void visit_after_split( Tree & tree, 01271 Split & split, 01272 Region & parent, 01273 Region & leftChild, 01274 Region & rightChild, 01275 Feature_t & features, 01276 Label_t & labels) 01277 { 01278 //resize to right size when called the first time 01279 01280 Int32 const class_count = tree.ext_param_.class_count_; 01281 Int32 const column_count = tree.ext_param_.column_count_; 01282 if(variable_importance_.size() == 0) 01283 { 01284 01285 variable_importance_ 01286 .reshape(MultiArrayShape<2>::type(column_count, 01287 class_count+2)); 01288 } 01289 01290 if(split.createNode().typeID() == i_ThresholdNode) 01291 { 01292 Node<i_ThresholdNode> node(split.createNode()); 01293 variable_importance_(node.column(),class_count+1) 01294 += split.region_gini_ - split.minGini(); 01295 } 01296 } 01297 01298 /**compute permutation based var imp. 01299 * (Only an Array of size oob_sample_count x 1 is created. 01300 * - apposed to oob_sample_count x feature_count in the other method. 01301 * 01302 * \sa FieldProxy 01303 */ 01304 template<class RF, class PR, class SM, class ST> 01305 void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & st, int index) 01306 { 01307 typedef MultiArrayShape<2>::type Shp_t; 01308 Int32 column_count = rf.ext_param_.column_count_; 01309 Int32 class_count = rf.ext_param_.class_count_; 01310 01311 /* This solution saves memory uptake but not multithreading 01312 * compatible 01313 */ 01314 // remove the const cast on the features (yep , I know what I am 01315 // doing here.) data is not destroyed. 01316 //typename PR::Feature_t & features 01317 // = const_cast<typename PR::Feature_t &>(pr.features()); 01318 01319 typename PR::FeatureWithMemory_t features = pr.features(); 01320 01321 //find the oob indices of current tree. 01322 ArrayVector<Int32> oob_indices; 01323 ArrayVector<Int32>::iterator 01324 iter; 01325 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii) 01326 if(!sm.is_used()[ii]) 01327 oob_indices.push_back(ii); 01328 01329 //create space to back up a column 01330 std::vector<double> backup_column; 01331 01332 // Random foo 01333 #ifdef CLASSIFIER_TEST 01334 RandomMT19937 random(1); 01335 #else 01336 RandomMT19937 random(RandomSeed); 01337 #endif 01338 UniformIntRandomFunctor<RandomMT19937> 01339 randint(random); 01340 01341 01342 //make some space for the results 01343 MultiArray<2, double> 01344 oob_right(Shp_t(1, class_count + 1)); 01345 MultiArray<2, double> 01346 perm_oob_right (Shp_t(1, class_count + 1)); 01347 01348 01349 // get the oob success rate with the original samples 01350 for(iter = oob_indices.begin(); 01351 iter != oob_indices.end(); 01352 ++iter) 01353 { 01354 if(rf.tree(index) 01355 .predictLabel(rowVector(features, *iter)) 01356 == pr.response()(*iter, 0)) 01357 { 01358 //per class 01359 ++oob_right[pr.response()(*iter,0)]; 01360 //total 01361 ++oob_right[class_count]; 01362 } 01363 } 01364 //get the oob rate after permuting the ii'th dimension. 01365 for(int ii = 0; ii < column_count; ++ii) 01366 { 01367 perm_oob_right.init(0.0); 01368 //make backup of orinal column 01369 backup_column.clear(); 01370 for(iter = oob_indices.begin(); 01371 iter != oob_indices.end(); 01372 ++iter) 01373 { 01374 backup_column.push_back(features(*iter,ii)); 01375 } 01376 01377 //get the oob rate after permuting the ii'th dimension. 01378 for(int rr = 0; rr < repetition_count_; ++rr) 01379 { 01380 //permute dimension. 
01381 int n = oob_indices.size(); 01382 for(int jj = 1; jj < n; ++jj) 01383 std::swap(features(oob_indices[jj], ii), 01384 features(oob_indices[randint(jj+1)], ii)); 01385 01386 //get the oob sucess rate after permuting 01387 for(iter = oob_indices.begin(); 01388 iter != oob_indices.end(); 01389 ++iter) 01390 { 01391 if(rf.tree(index) 01392 .predictLabel(rowVector(features, *iter)) 01393 == pr.response()(*iter, 0)) 01394 { 01395 //per class 01396 ++perm_oob_right[pr.response()(*iter, 0)]; 01397 //total 01398 ++perm_oob_right[class_count]; 01399 } 01400 } 01401 } 01402 01403 01404 //normalise and add to the variable_importance array. 01405 perm_oob_right /= repetition_count_; 01406 perm_oob_right -=oob_right; 01407 perm_oob_right *= -1; 01408 perm_oob_right /= oob_indices.size(); 01409 variable_importance_ 01410 .subarray(Shp_t(ii,0), 01411 Shp_t(ii+1,class_count+1)) += perm_oob_right; 01412 //copy back permuted dimension 01413 for(int jj = 0; jj < int(oob_indices.size()); ++jj) 01414 features(oob_indices[jj], ii) = backup_column[jj]; 01415 } 01416 } 01417 01418 /** calculate permutation based impurity after every tree has been 01419 * learned default behaviour is that this happens out of place. 01420 * If you have very big data sets and want to avoid copying of data 01421 * set the in_place_ flag to true. 01422 */ 01423 template<class RF, class PR, class SM, class ST> 01424 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 01425 { 01426 after_tree_ip_impl(rf, pr, sm, st, index); 01427 } 01428 01429 /** Normalise variable importance after the number of trees is known. 01430 */ 01431 template<class RF, class PR> 01432 void visit_at_end(RF & rf, PR & pr) 01433 { 01434 variable_importance_ /= rf.trees_.size(); 01435 } 01436 }; 01437 01438 /** Verbose output 01439 */ 01440 class RandomForestProgressVisitor : public VisitorBase { 01441 public: 01442 RandomForestProgressVisitor() : VisitorBase() {} 01443 01444 template<class RF, class PR, class SM, class ST> 01445 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index){ 01446 if(index != rf.options().tree_count_-1) { 01447 std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]" 01448 << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush; 01449 } 01450 else { 01451 std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl; 01452 } 01453 } 01454 01455 template<class RF, class PR> 01456 void visit_at_end(RF const & rf, PR const & pr) { 01457 std::string a = TOCS; 01458 std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl; 01459 } 01460 01461 template<class RF, class PR> 01462 void visit_at_beginning(RF const & rf, PR const & pr) { 01463 TIC; 01464 std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl; 01465 } 01466 01467 private: 01468 USETICTOC; 01469 }; 01470 01471 01472 /** Computes Correlation/Similarity Matrix of features while learning 01473 * random forest. 01474 */ 01475 class CorrelationVisitor : public VisitorBase 01476 { 01477 public: 01478 /** gini_missc(ii, jj) describes how well variable jj can describe a partition 01479 * created on variable ii(when variable ii was chosen) 01480 */ 01481 MultiArray<2, double> gini_missc; 01482 MultiArray<2, int> tmp_labels; 01483 /** additional noise features. 
01484 */ 01485 MultiArray<2, double> noise; 01486 MultiArray<2, double> noise_l; 01487 /** how well can a noise column describe a partition created on variable ii. 01488 */ 01489 MultiArray<2, double> corr_noise; 01490 MultiArray<2, double> corr_l; 01491 01492 /** Similarity Matrix 01493 * 01494 * (numberOfFeatures + 1) by (number Of Features + 1) Matrix 01495 * gini_missc 01496 * - row normalized by the number of times the column was chosen 01497 * - mean of corr_noise subtracted 01498 * - and symmetrised. 01499 * 01500 */ 01501 MultiArray<2, double> similarity; 01502 /** Distance Matrix 1-similarity 01503 */ 01504 MultiArray<2, double> distance; 01505 ArrayVector<int> tmp_cc; 01506 01507 /** How often was variable ii chosen 01508 */ 01509 ArrayVector<int> numChoices; 01510 typedef BestGiniOfColumn<GiniCriterion> ColumnDecisionFunctor; 01511 BestGiniOfColumn<GiniCriterion> bgfunc; 01512 void save(std::string file, std::string prefix) 01513 { 01514 /* 01515 std::string tmp; 01516 #define VAR_WRITE(NAME) \ 01517 tmp = #NAME;\ 01518 tmp += "_";\ 01519 tmp += prefix;\ 01520 vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME); 01521 VAR_WRITE(gini_missc); 01522 VAR_WRITE(corr_noise); 01523 VAR_WRITE(distance); 01524 VAR_WRITE(similarity); 01525 vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data())); 01526 #undef VAR_WRITE 01527 */ 01528 } 01529 template<class RF, class PR> 01530 void visit_at_beginning(RF const & rf, PR & pr) 01531 { 01532 typedef MultiArrayShape<2>::type Shp; 01533 int n = rf.ext_param_.column_count_; 01534 gini_missc.reshape(Shp(n +1,n+ 1)); 01535 corr_noise.reshape(Shp(n + 1, 10)); 01536 corr_l.reshape(Shp(n +1, 10)); 01537 01538 noise.reshape(Shp(pr.features().shape(0), 10)); 01539 noise_l.reshape(Shp(pr.features().shape(0), 10)); 01540 RandomMT19937 random(RandomSeed); 01541 for(int ii = 0; ii < noise.size(); ++ii) 01542 { 01543 noise[ii] = random.uniform53(); 01544 noise_l[ii] = random.uniform53() > 0.5; 01545 } 01546 bgfunc = ColumnDecisionFunctor( rf.ext_param_); 01547 tmp_labels.reshape(pr.response().shape()); 01548 tmp_cc.resize(2); 01549 numChoices.resize(n+1); 01550 // look at allaxes 01551 } 01552 template<class RF, class PR> 01553 void visit_at_end(RF const & rf, PR const & pr) 01554 { 01555 typedef MultiArrayShape<2>::type Shp; 01556 similarity.reshape(gini_missc.shape()); 01557 similarity = gini_missc;; 01558 MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1)); 01559 rowStatistics(corr_noise, mean_noise); 01560 mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data()); 01561 int rC = similarity.shape(0); 01562 for(int jj = 0; jj < rC-1; ++jj) 01563 { 01564 rowVector(similarity, jj) /= numChoices[jj]; 01565 rowVector(similarity, jj) -= mean_noise(jj, 0); 01566 } 01567 for(int jj = 0; jj < rC; ++jj) 01568 { 01569 similarity(rC -1, jj) /= numChoices[jj]; 01570 } 01571 rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0); 01572 similarity = abs(similarity); 01573 FindMinMax<double> minmax; 01574 inspectMultiArray(srcMultiArrayRange(similarity), minmax); 01575 01576 for(int jj = 0; jj < rC; ++jj) 01577 similarity(jj, jj) = minmax.max; 01578 01579 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)) 01580 += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose(); 01581 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2; 01582 columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose(); 01583 for(int jj = 0; jj < rC; ++jj) 01584 similarity(jj, 
jj) = 0; 01585 01586 FindMinMax<double> minmax2; 01587 inspectMultiArray(srcMultiArrayRange(similarity), minmax2); 01588 for(int jj = 0; jj < rC; ++jj) 01589 similarity(jj, jj) = minmax2.max; 01590 distance.reshape(gini_missc.shape(), minmax2.max); 01591 distance -= similarity; 01592 } 01593 01594 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 01595 void visit_after_split( Tree & tree, 01596 Split & split, 01597 Region & parent, 01598 Region & leftChild, 01599 Region & rightChild, 01600 Feature_t & features, 01601 Label_t & labels) 01602 { 01603 if(split.createNode().typeID() == i_ThresholdNode) 01604 { 01605 double wgini; 01606 tmp_cc.init(0); 01607 for(int ii = 0; ii < parent.size(); ++ii) 01608 { 01609 tmp_labels[parent[ii]] 01610 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold()); 01611 ++tmp_cc[tmp_labels[parent[ii]]]; 01612 } 01613 double region_gini = bgfunc.loss_of_region(tmp_labels, 01614 parent.begin(), 01615 parent.end(), 01616 tmp_cc); 01617 01618 int n = split.bestSplitColumn(); 01619 ++numChoices[n]; 01620 ++(*(numChoices.end()-1)); 01621 //this functor does all the work 01622 for(int k = 0; k < features.shape(1); ++k) 01623 { 01624 bgfunc(columnVector(features, k), 01625 0, 01626 tmp_labels, 01627 parent.begin(), parent.end(), 01628 tmp_cc); 01629 wgini = (region_gini - bgfunc.min_gini_); 01630 gini_missc(n, k) 01631 += wgini; 01632 } 01633 for(int k = 0; k < 10; ++k) 01634 { 01635 bgfunc(columnVector(noise, k), 01636 0, 01637 tmp_labels, 01638 parent.begin(), parent.end(), 01639 tmp_cc); 01640 wgini = (region_gini - bgfunc.min_gini_); 01641 corr_noise(n, k) 01642 += wgini; 01643 } 01644 01645 for(int k = 0; k < 10; ++k) 01646 { 01647 bgfunc(columnVector(noise_l, k), 01648 0, 01649 tmp_labels, 01650 parent.begin(), parent.end(), 01651 tmp_cc); 01652 wgini = (region_gini - bgfunc.min_gini_); 01653 corr_l(n, k) 01654 += wgini; 01655 } 01656 bgfunc(labels,0, tmp_labels, parent.begin(), parent.end(),tmp_cc); 01657 wgini = (region_gini - bgfunc.min_gini_); 01658 gini_missc(n, columnCount(gini_missc)-1) 01659 += wgini; 01660 01661 region_gini = split.region_gini_; 01662 #if 1 01663 Node<i_ThresholdNode> node(split.createNode()); 01664 gini_missc(rowCount(gini_missc)-1, 01665 node.column()) 01666 +=split.region_gini_ - split.minGini(); 01667 #endif 01668 for(int k = 0; k < 10; ++k) 01669 { 01670 split.bgfunc(columnVector(noise, k), 01671 0, 01672 labels, 01673 parent.begin(), parent.end(), 01674 parent.classCounts()); 01675 corr_noise(rowCount(gini_missc)-1, 01676 k) 01677 += wgini; 01678 } 01679 #if 0 01680 for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k) 01681 { 01682 wgini = region_gini - split.min_gini_[k]; 01683 01684 gini_missc(rowCount(gini_missc)-1, 01685 split.splitColumns[k]) 01686 += wgini; 01687 } 01688 01689 for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k) 01690 { 01691 split.bgfunc(columnVector(features, split.splitColumns[k]), 01692 labels, 01693 parent.begin(), parent.end(), 01694 parent.classCounts()); 01695 wgini = region_gini - split.bgfunc.min_gini_; 01696 gini_missc(rowCount(gini_missc)-1, 01697 split.splitColumns[k]) += wgini; 01698 } 01699 #endif 01700 // remember to partition the data according to the best. 
            gini_missc(rowCount(gini_missc)-1,
                       columnCount(gini_missc)-1)
                += region_gini;
            SortSamplesByDimensions<Feature_t>
                sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
            std::partition(parent.begin(), parent.end(), sorter);
        }
    }
};


} // namespace visitors
} // namespace rf
} // namespace vigra

//@}
#endif // RF_VISITORS_HXX
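The header above only defines the visitor machinery. As a quick orientation, and as a hypothetical sketch rather than part of the original file, the following shows how a user-defined visitor derived from VisitorBase might be chained with the predefined OOB_Error visitor. The data arrays, the helper function train_with_visitors and the TreeCountVisitor class are illustrative placeholders; only VisitorBase, OOB_Error, create_visitor() and RandomForest::learn() come from the listing itself.

\code
// Sketch only: a custom visitor that counts the trees as they are learned,
// chained together with the predefined OOB_Error visitor.
#include <vigra/multi_array.hxx>
#include <vigra/random_forest.hxx>

struct TreeCountVisitor : public vigra::rf::visitors::VisitorBase
{
    int trees_seen;
    TreeCountVisitor() : trees_seen(0) {}

    // invoked once per tree by the visitor chain
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
    {
        ++trees_seen;
    }
};

// 'features' and 'labels' are placeholders for the caller's training data.
void train_with_visitors(vigra::MultiArrayView<2, double> features,
                         vigra::MultiArrayView<2, int>    labels)
{
    vigra::RandomForest<> rf;
    TreeCountVisitor      counter;
    vigra::rf::visitors::OOB_Error oob_v;

    // visitors are invoked in the order in which they are passed to create_visitor()
    rf.learn(features, labels, vigra::rf::visitors::create_visitor(counter, oob_v));

    // counter.trees_seen and oob_v.oob_breiman now hold the collected results
}
\endcode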