vigra/random_forest/rf_common.hxx
/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RF_COMMON_HXX
#define VIGRA_RF_COMMON_HXX

#include <map>
#include <string>
#include <algorithm>
#include <iterator>

#include "vigra/array_vector.hxx"
#include "vigra/multi_array.hxx"
#include "vigra/error.hxx"

namespace vigra
{


struct ClassificationTag
{};

struct RegressionTag
{};

namespace detail
{
    class RF_DEFAULT;
}
inline detail::RF_DEFAULT& rf_default();
namespace detail
{

/**\brief singleton default tag class
 *
 * use the rf_default() factory function to use the tag.
 * \sa RandomForest<>::learn()
 */
class RF_DEFAULT
{
  private:
    RF_DEFAULT()
    {}
  public:
    friend RF_DEFAULT& ::vigra::rf_default();

    /** workaround for automatic choice of the decision tree
     *  stack entry.
     */
};
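/* A minimal sketch of how the RF_DEFAULT sentinel is typically consumed:
 * overload resolution picks the default behaviour when the caller passes
 * rf_default(). The overloads below are hypothetical, not part of VIGRA:
 * \code
 * template<class Stop>
 * void learn_impl(Stop & stop);           // user-supplied stopping criterion
 * void learn_impl(detail::RF_DEFAULT &);  // construct a default criterion
 *
 * // learn_impl(my_stop);       // calls the first overload
 * // learn_impl(rf_default());  // calls the second -- default behaviour
 * \endcode
 */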
/**\brief chooses between the default type and the type supplied
 *
 * This is an internal class and you shouldn't really care about it.
 * It is only used internally in RandomForest<>::learn().
 * Usage:
 *\code
 * // example: use the container type supplied by the user, or ArrayVector
 * // if rf_default() was specified as the argument
 * template<class Container_t>
 * void do_some_foo(Container_t in)
 * {
 *     typedef ArrayVector<int> Default_Container_t;
 *     Default_Container_t default_value;
 *
 *     // if the user didn't care and in was of type RF_DEFAULT,
 *     // then default_value is used.
 *     do_some_more_foo(Value_Chooser<Container_t, Default_Container_t>
 *                          ::choose(in, default_value));
 * }
 *\endcode
 */
template<class T, class C>
class Value_Chooser
{
  public:
    typedef T type;

    static T & choose(T & t, C &)
    {
        return t;
    }
};

template<class C>
class Value_Chooser<detail::RF_DEFAULT, C>
{
  public:
    typedef C type;

    static C & choose(detail::RF_DEFAULT &, C & c)
    {
        return c;
    }
};
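/* A minimal sketch exercising both instantiations above (the variable
 * names are illustrative):
 * \code
 * ArrayVector<int>     user_supplied;
 * detail::RF_DEFAULT & not_supplied = rf_default();
 * ArrayVector<int>     fallback;
 *
 * // primary template: the user-supplied object is returned
 * ArrayVector<int> & a
 *     = detail::Value_Chooser<ArrayVector<int>, ArrayVector<int> >
 *           ::choose(user_supplied, fallback);
 *
 * // specialisation: RF_DEFAULT was passed, so the fallback is returned
 * ArrayVector<int> & b
 *     = detail::Value_Chooser<detail::RF_DEFAULT, ArrayVector<int> >
 *           ::choose(not_supplied, fallback);
 * \endcode
 */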
} //namespace detail


/**\brief factory function to return a RF_DEFAULT tag
 * \sa RandomForest<>::learn()
 */
detail::RF_DEFAULT& rf_default()
{
    static detail::RF_DEFAULT result;
    return result;
}

/** tags used with the RandomForestOptions class
 * \sa RF_Traits::Option_t
 */
enum RF_OptionTag { RF_EQUAL,
                    RF_PROPORTIONAL,
                    RF_EXTERNAL,
                    RF_NONE,
                    RF_FUNCTION,
                    RF_LOG,
                    RF_SQRT,
                    RF_CONST,
                    RF_ALL};


/** \addtogroup MachineLearning
**/
//@{

/**\brief Options object for the random forest
 *
 * usage:
 * RandomForestOptions a = RandomForestOptions()
 *                             .param1(value1)
 *                             .param2(value2)
 *                             ...
 *
 * This class only contains options/parameters that are not problem
 * dependent. The ProblemSpec class contains methods to set class weights
 * if necessary.
 *
 * Note that the return value of all methods is *this, which makes
 * chaining of options as above possible.
 */
class RandomForestOptions
{
  public:
    /**\name sampling options*/
    /*\{*/
    // look at the member access functions for documentation
    double  training_set_proportion_;
    int     training_set_size_;
    int     (*training_set_func_)(int);
    RF_OptionTag
            training_set_calc_switch_;

    bool    sample_with_replacement_;
    RF_OptionTag
            stratification_method_;
    /*\}*/

    /**\name general random forest options
     *
     * these usually will be used by most split functors and
     * stopping predicates
     */
    /*\{*/
    RF_OptionTag mtry_switch_;
    int     mtry_;
    int     (*mtry_func_)(int);

    bool    predict_weighted_;
    int     tree_count_;
    int     min_split_node_size_;
    bool    prepare_online_learning_;
    /*\}*/

    int serialized_size() const
    {
        return 12;
    }


    bool operator==(RandomForestOptions & rhs) const
    {
        bool result = true;
        #define COMPARE(field) result = result && (this->field == rhs.field);
        COMPARE(training_set_proportion_);
        COMPARE(training_set_size_);
        COMPARE(training_set_calc_switch_);
        COMPARE(sample_with_replacement_);
        COMPARE(stratification_method_);
        COMPARE(mtry_switch_);
        COMPARE(mtry_);
        COMPARE(tree_count_);
        COMPARE(min_split_node_size_);
        COMPARE(predict_weighted_);
        #undef COMPARE

        return result;
    }
    bool operator!=(RandomForestOptions & rhs_) const
    {
        return !(*this == rhs_);
    }

    template<class Iter>
    void unserialize(Iter const & begin, Iter const & end)
    {
        Iter iter = begin;
        vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
                           "RandomForestOptions::unserialize():"
                           "wrong number of parameters");
        #define PULL(item_, type_) item_ = type_(*iter); ++iter;
        PULL(training_set_proportion_, double);
        PULL(training_set_size_, int);
        ++iter; //PULL(training_set_func_, double);
        PULL(training_set_calc_switch_, (RF_OptionTag)int);
        PULL(sample_with_replacement_, 0 !=);
        PULL(stratification_method_, (RF_OptionTag)int);
        PULL(mtry_switch_, (RF_OptionTag)int);
        PULL(mtry_, int);
        ++iter; //PULL(mtry_func_, double);
        PULL(tree_count_, int);
        PULL(min_split_node_size_, int);
        PULL(predict_weighted_, 0 !=);
        #undef PULL
    }
    template<class Iter>
    void serialize(Iter const & begin, Iter const & end) const
    {
        Iter iter = begin;
        vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
                           "RandomForestOptions::serialize():"
                           "wrong number of parameters");
        #define PUSH(item_) *iter = double(item_); ++iter;
        PUSH(training_set_proportion_);
        PUSH(training_set_size_);
        if(training_set_func_ != 0)
        {
            PUSH(1);
        }
        else
        {
            PUSH(0);
        }
        PUSH(training_set_calc_switch_);
        PUSH(sample_with_replacement_);
        PUSH(stratification_method_);
        PUSH(mtry_switch_);
        PUSH(mtry_);
        if(mtry_func_ != 0)
        {
            PUSH(1);
        }
        else
        {
            PUSH(0);
        }
        PUSH(tree_count_);
        PUSH(min_split_node_size_);
        PUSH(predict_weighted_);
        #undef PUSH
    }
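    /* A minimal sketch of the serialize()/unserialize() round trip above;
     * every option is stored as a double in a flat array of
     * serialized_size() elements (the variable names are illustrative):
     * \code
     * RandomForestOptions opt = RandomForestOptions().tree_count(100);
     * ArrayVector<double> buf(opt.serialized_size());
     * opt.serialize(buf.begin(), buf.end());
     *
     * RandomForestOptions restored;
     * restored.unserialize(buf.begin(), buf.end());
     * // restored == opt -- note that the function pointers are stored
     * // only as 0/1 flags and therefore cannot be recovered.
     * \endcode
     */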
    void make_from_map(std::map<std::string, ArrayVector<double> > & in)
    {
        typedef MultiArrayShape<2>::type Shp;
        #define PULL(item_, type_) item_ = type_(in[#item_][0]);
        #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
        PULL(training_set_proportion_, double);
        PULL(training_set_size_, int);
        PULL(mtry_, int);
        PULL(tree_count_, int);
        PULL(min_split_node_size_, int);
        PULLBOOL(sample_with_replacement_, bool);
        PULLBOOL(prepare_online_learning_, bool);
        PULLBOOL(predict_weighted_, bool);

        PULL(training_set_calc_switch_, (RF_OptionTag)int);
        PULL(stratification_method_, (RF_OptionTag)int);
        PULL(mtry_switch_, (RF_OptionTag)int);

        /* don't pull */
        //PULL(mtry_func_!=0, int);
        //PULL(training_set_func_, int);
        #undef PULL
        #undef PULLBOOL
    }
    void make_map(std::map<std::string, ArrayVector<double> > & in) const
    {
        typedef MultiArrayShape<2>::type Shp;
        #define PUSH(item_, type_) in[#item_] = ArrayVector<double>(1, double(item_));
        #define PUSHFUNC(item_, type_) in[#item_] = ArrayVector<double>(1, double(item_ != 0));
        PUSH(training_set_proportion_, double);
        PUSH(training_set_size_, int);
        PUSH(mtry_, int);
        PUSH(tree_count_, int);
        PUSH(min_split_node_size_, int);
        PUSH(sample_with_replacement_, bool);
        PUSH(prepare_online_learning_, bool);
        PUSH(predict_weighted_, bool);

        PUSH(training_set_calc_switch_, RF_OptionTag);
        PUSH(stratification_method_, RF_OptionTag);
        PUSH(mtry_switch_, RF_OptionTag);

        PUSHFUNC(mtry_func_, int);
        PUSHFUNC(training_set_func_, int);
        #undef PUSH
        #undef PUSHFUNC
    }


    /**\brief create a RandomForestOptions object with default initialisation.
     *
     * look at the other member functions for more information on default
     * values
     */
    RandomForestOptions()
    :
        training_set_proportion_(1.0),
        training_set_size_(0),
        training_set_func_(0),
        training_set_calc_switch_(RF_PROPORTIONAL),
        sample_with_replacement_(true),
        stratification_method_(RF_NONE),
        mtry_switch_(RF_SQRT),
        mtry_(0),
        mtry_func_(0),
        predict_weighted_(false),
        tree_count_(256),
        min_split_node_size_(1),
        prepare_online_learning_(false)
    {}

    /**\brief specify stratification strategy
     *
     * default: RF_NONE
     * possible values: RF_EQUAL, RF_PROPORTIONAL,
     *                  RF_EXTERNAL, RF_NONE
     * RF_EQUAL:        get an equal amount of samples per class.
     * RF_PROPORTIONAL: sample proportionally to the fraction of class
     *                  samples in the population
     * RF_EXTERNAL:     the strata_weights_ field of the ProblemSpec_t
     *                  object has been set externally. (defunct)
     */
    RandomForestOptions & use_stratification(RF_OptionTag in)
    {
        vigra_precondition(in == RF_EQUAL ||
                           in == RF_PROPORTIONAL ||
                           in == RF_EXTERNAL ||
                           in == RF_NONE,
                           "RandomForestOptions::use_stratification()"
                           "input must be RF_EQUAL, RF_PROPORTIONAL,"
                           "RF_EXTERNAL or RF_NONE");
        stratification_method_ = in;
        return *this;
    }

    RandomForestOptions & prepare_online_learning(bool in)
    {
        prepare_online_learning_ = in;
        return *this;
    }

    /**\brief sample from training population with or without replacement?
     *
     * <br> Default: true
     */
    RandomForestOptions & sample_with_replacement(bool in)
    {
        sample_with_replacement_ = in;
        return *this;
    }

    /**\brief specify the fraction of the total number of samples
     * used per tree for learning.
     *
     * This value should be in [0.0, 1.0] if sampling without
     * replacement has been specified.
     *
     * <br> default: 1.0
     */
    RandomForestOptions & samples_per_tree(double in)
    {
        training_set_proportion_ = in;
        training_set_calc_switch_ = RF_PROPORTIONAL;
        return *this;
    }

    /**\brief directly specify the number of samples per tree
     */
    RandomForestOptions & samples_per_tree(int in)
    {
        training_set_size_ = in;
        training_set_calc_switch_ = RF_CONST;
        return *this;
    }

    /**\brief use an external function to calculate the number of samples
     * each tree should be learnt with.
     *
     * \param in function pointer that takes the number of rows in the
     *           learning data and outputs the number of samples per tree.
     */
    RandomForestOptions & samples_per_tree(int (*in)(int))
    {
        training_set_func_ = in;
        training_set_calc_switch_ = RF_FUNCTION;
        return *this;
    }

    /**\brief weight each tree with the number of samples in that node
     */
    RandomForestOptions & predict_weighted()
    {
        predict_weighted_ = true;
        return *this;
    }

    /**\brief use a built-in mapping to calculate mtry
     *
     * Use one of the built-in mappings to calculate mtry from the number
     * of columns in the input feature data.
     * \param in possible values: RF_LOG, RF_SQRT or RF_ALL
     *           <br> default: RF_SQRT.
     */
    RandomForestOptions & features_per_node(RF_OptionTag in)
    {
        vigra_precondition(in == RF_LOG ||
                           in == RF_SQRT ||
                           in == RF_ALL,
                           "RandomForestOptions::features_per_node():"
                           "input must be RF_LOG, RF_SQRT or RF_ALL");
        mtry_switch_ = in;
        return *this;
    }

    /**\brief set mtry to a constant value
     *
     * mtry is the number of columns/variates/variables randomly chosen
     * from which the best split is selected.
     */
    RandomForestOptions & features_per_node(int in)
    {
        mtry_ = in;
        mtry_switch_ = RF_CONST;
        return *this;
    }

    /**\brief use an external function to calculate mtry
     *
     * \param in function pointer that takes the number of columns
     *           of the feature data (int) and returns mtry (int)
     */
    RandomForestOptions & features_per_node(int (*in)(int))
    {
        mtry_func_ = in;
        mtry_switch_ = RF_FUNCTION;
        return *this;
    }

    /** How many trees to create?
     *
     * <br> Default: 256.
     */
    RandomForestOptions & tree_count(int in)
    {
        tree_count_ = in;
        return *this;
    }

    /**\brief Number of examples required for a node to be split.
     *
     * When the number of examples in a node is below this number,
     * the node is not split even if class separation is not yet perfect.
     * Instead, the node returns the proportion of each class
     * (among the remaining examples) during the prediction phase.
     * <br> Default: 1 (complete growing)
     */
    RandomForestOptions & min_split_node_size(int in)
    {
        min_split_node_size_ = in;
        return *this;
    }
};
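/* A minimal sketch of the chaining idiom described in the class
 * documentation above (the values are illustrative):
 * \code
 * RandomForestOptions options = RandomForestOptions()
 *                                   .tree_count(100)
 *                                   .samples_per_tree(0.8)
 *                                   .features_per_node(RF_SQRT)
 *                                   .min_split_node_size(2);
 * \endcode
 */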
/** \brief problem types
 */
enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER};


/** \brief problem specification class for the random forest.
 *
 * This class contains all the problem specific parameters the random
 * forest needs for learning. Specification of an instance of this class
 * is optional, as all necessary fields will be computed prior to learning
 * if not specified.
 *
 * If needed, usage is similar to that of RandomForestOptions.
 */
template<class LabelType = double>
class ProblemSpec
{
  public:

    /** \brief problem class label type
     */
    typedef LabelType Label_t;
    ArrayVector<Label_t> classes;

    int     column_count_;    // number of features
    int     class_count_;     // number of classes
    int     row_count_;       // number of samples

    int     actual_mtry_;     // mtry used in training
    int     actual_msample_;  // number of in-bag samples per tree

    Problem_t problem_type_;  // classification or regression

    int     used_;            // this ProblemSpec is valid
    ArrayVector<double> class_weights_; // if classes have different importance
    int     is_weighted_;     // class_weights_ are used
    double  precision_;       // termination criterion for regression loss


    template<class T>
    void to_classlabel(int index, T & out) const
    {
        out = T(classes[index]);
    }
    template<class T>
    int to_classIndex(T index) const
    {
        return std::find(classes.begin(), classes.end(), index) - classes.begin();
    }
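    /* A minimal sketch of the label/index mapping above (the values are
     * illustrative):
     * \code
     * ProblemSpec<int> spec;
     * int labels[] = { 10, 20, 30 };
     * spec.classes_(labels, labels + 3);
     *
     * int idx = spec.to_classIndex(20);   // idx == 1
     * int lab;
     * spec.to_classlabel(2, lab);         // lab == 30
     * \endcode
     */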
    #define EQUALS(field) field(rhs.field)
    ProblemSpec(ProblemSpec const & rhs)
    :
        EQUALS(column_count_),
        EQUALS(class_count_),
        EQUALS(row_count_),
        EQUALS(actual_mtry_),
        EQUALS(actual_msample_),
        EQUALS(problem_type_),
        EQUALS(used_),
        EQUALS(class_weights_),
        EQUALS(is_weighted_),
        EQUALS(precision_)
    {
        std::back_insert_iterator<ArrayVector<Label_t> >
            iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
    }
    #undef EQUALS
    #define EQUALS(field) field(rhs.field)
    template<class T>
    ProblemSpec(ProblemSpec<T> const & rhs)
    :
        EQUALS(column_count_),
        EQUALS(class_count_),
        EQUALS(row_count_),
        EQUALS(actual_mtry_),
        EQUALS(actual_msample_),
        EQUALS(problem_type_),
        EQUALS(used_),
        EQUALS(class_weights_),
        EQUALS(is_weighted_),
        EQUALS(precision_)
    {
        std::back_insert_iterator<ArrayVector<Label_t> >
            iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
    }
    #undef EQUALS

    // for some reason the function below does not match
    // the default copy constructor
    #define EQUALS(field) (this->field = rhs.field);
    ProblemSpec & operator=(ProblemSpec const & rhs)
    {
        EQUALS(column_count_);
        EQUALS(class_count_);
        EQUALS(row_count_);
        EQUALS(actual_mtry_);
        EQUALS(actual_msample_);
        EQUALS(problem_type_);
        EQUALS(used_);
        EQUALS(is_weighted_);
        EQUALS(precision_);
        class_weights_.clear();
        std::back_insert_iterator<ArrayVector<double> >
            iter2(class_weights_);
        std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
        classes.clear();
        std::back_insert_iterator<ArrayVector<Label_t> >
            iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
        return *this;
    }

    template<class T>
    ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs)
    {
        EQUALS(column_count_);
        EQUALS(class_count_);
        EQUALS(row_count_);
        EQUALS(actual_mtry_);
        EQUALS(actual_msample_);
        EQUALS(problem_type_);
        EQUALS(used_);
        EQUALS(is_weighted_);
        EQUALS(precision_);
        class_weights_.clear();
        std::back_insert_iterator<ArrayVector<double> >
            iter2(class_weights_);
        std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
        classes.clear();
        std::back_insert_iterator<ArrayVector<Label_t> >
            iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
        return *this;
    }
    #undef EQUALS

    template<class T>
    bool operator==(ProblemSpec<T> const & rhs)
    {
        bool result = true;
        #define COMPARE(field) result = result && (this->field == rhs.field);
        COMPARE(column_count_);
        COMPARE(class_count_);
        COMPARE(row_count_);
        COMPARE(actual_mtry_);
        COMPARE(actual_msample_);
        COMPARE(problem_type_);
        COMPARE(is_weighted_);
        COMPARE(precision_);
        COMPARE(used_);
        COMPARE(class_weights_);
        COMPARE(classes);
        #undef COMPARE
        return result;
    }

    bool operator!=(ProblemSpec & rhs)
    {
        return !(*this == rhs);
    }


    size_t serialized_size() const
    {
        return 9 + class_count_ * int(is_weighted_ + 1);
    }
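    /* The flat layout implied by serialized_size() above and produced by
     * serialize()/unserialize() below -- all slots are doubles:
     *
     *   [0] column_count_    [1] class_count_     [2] row_count_
     *   [3] actual_mtry_     [4] actual_msample_  [5] problem_type_
     *   [6] is_weighted_     [7] used_            [8] precision_
     *
     * followed by class_count_ class weights (only if is_weighted_),
     * followed by class_count_ class labels.
     */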
    template<class Iter>
    void unserialize(Iter const & begin, Iter const & end)
    {
        Iter iter = begin;
        vigra_precondition(end - begin >= 9,
                           "ProblemSpec::unserialize():"
                           "wrong number of parameters");
        #define PULL(item_, type_) item_ = type_(*iter); ++iter;
        PULL(column_count_, int);
        PULL(class_count_, int);

        vigra_precondition(end - begin >= 9 + class_count_,
                           "ProblemSpec::unserialize(): 1");
        PULL(row_count_, int);
        PULL(actual_mtry_, int);
        PULL(actual_msample_, int);
        PULL(problem_type_, Problem_t);
        PULL(is_weighted_, int);
        PULL(used_, int);
        PULL(precision_, double);
        if(is_weighted_)
        {
            vigra_precondition(end - begin == 9 + 2*class_count_,
                               "ProblemSpec::unserialize(): 2");
            class_weights_.insert(class_weights_.end(),
                                  iter,
                                  iter + class_count_);
            iter += class_count_;
        }
        classes.insert(classes.end(), iter, end);
        #undef PULL
    }


    template<class Iter>
    void serialize(Iter const & begin, Iter const & end) const
    {
        Iter iter = begin;
        vigra_precondition(end - begin == serialized_size(),
                           "ProblemSpec::serialize():"
                           "wrong number of parameters");
        #define PUSH(item_) *iter = double(item_); ++iter;
        PUSH(column_count_);
        PUSH(class_count_);
        PUSH(row_count_);
        PUSH(actual_mtry_);
        PUSH(actual_msample_);
        PUSH(problem_type_);
        PUSH(is_weighted_);
        PUSH(used_);
        PUSH(precision_);
        if(is_weighted_)
        {
            std::copy(class_weights_.begin(),
                      class_weights_.end(),
                      iter);
            iter += class_count_;
        }
        std::copy(classes.begin(),
                  classes.end(),
                  iter);
        #undef PUSH
    }

    void make_from_map(std::map<std::string, ArrayVector<double> > & in)
    {
        typedef MultiArrayShape<2>::type Shp;
        #define PULL(item_, type_) item_ = type_(in[#item_][0]);
        PULL(column_count_, int);
        PULL(class_count_, int);
        PULL(row_count_, int);
        PULL(actual_mtry_, int);
        PULL(actual_msample_, int);
        PULL(problem_type_, (Problem_t)int);
        PULL(is_weighted_, int);
        PULL(used_, int);
        PULL(precision_, double);
        class_weights_ = in["class_weights_"];
        #undef PULL
    }
    void make_map(std::map<std::string, ArrayVector<double> > & in) const
    {
        typedef MultiArrayShape<2>::type Shp;
        #define PUSH(item_) in[#item_] = ArrayVector<double>(1, double(item_));
        PUSH(column_count_);
        PUSH(class_count_);
        PUSH(row_count_);
        PUSH(actual_mtry_);
        PUSH(actual_msample_);
        PUSH(problem_type_);
        PUSH(is_weighted_);
        PUSH(used_);
        PUSH(precision_);
        in["class_weights_"] = class_weights_;
        #undef PUSH
    }

    /**\brief set default values (-> values not set)
     */
    ProblemSpec()
    :   column_count_(0),
        class_count_(0),
        row_count_(0),
        actual_mtry_(0),
        actual_msample_(0),
        problem_type_(CHECKLATER),
        used_(false),
        is_weighted_(false),
        precision_(0.0)
    {}


    ProblemSpec & column_count(int in)
    {
        column_count_ = in;
        return *this;
    }

    /**\brief supply with class labels
     *
     * the preprocessor will not calculate the labels needed in this case.
     */
    template<class C_Iter>
    ProblemSpec & classes_(C_Iter begin, C_Iter end)
    {
        int size = end - begin;
        for(int k=0; k<size; ++k, ++begin)
            classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
        class_count_ = size;
        return *this;
    }

    /** \brief supply with class weights
     *
     * this is the only case where you would really have to
     * create a ProblemSpec object.
     */
    template<class W_Iter>
    ProblemSpec & class_weights(W_Iter begin, W_Iter end)
    {
        class_weights_.insert(class_weights_.end(), begin, end);
        is_weighted_ = true;
        return *this;
    }


    void clear()
    {
        used_ = false;
        classes.clear();
        class_weights_.clear();
        column_count_ = 0;
        class_count_ = 0;
        actual_mtry_ = 0;
        actual_msample_ = 0;
        problem_type_ = CHECKLATER;
        is_weighted_ = false;
        precision_ = 0.0;
    }

    bool used() const
    {
        return used_ != 0;
    }
};
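/* A minimal sketch of the chaining idiom for ProblemSpec, here supplying
 * class labels and class weights explicitly (the values are illustrative):
 * \code
 * double weights[] = { 1.0, 4.0 };
 * int    labels[]  = { 0, 1 };
 * ProblemSpec<int> spec = ProblemSpec<int>()
 *                             .classes_(labels, labels + 2)
 *                             .class_weights(weights, weights + 2);
 * \endcode
 */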
//@}


/**\brief Standard early stopping criterion
 *
 * Stop if region.size() < min_split_node_size_
 */
class EarlyStoppStd
{
  public:
    int min_split_node_size_;

    template<class Opt>
    EarlyStoppStd(Opt opt)
    :   min_split_node_size_(opt.min_split_node_size_)
    {}

    template<class T>
    void set_external_parameters(ProblemSpec<T> const &, int /* tree_count */ = 0, bool /* is_weighted_ */ = false)
    {}

    template<class Region>
    bool operator()(Region & region)
    {
        return region.size() < min_split_node_size_;
    }

    template<class WeightIter, class T, class C>
    bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */)
    {
        return false;
    }
};


} // namespace vigra

#endif // VIGRA_RF_COMMON_HXX