[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest/rf_earlystopping.hxx
00001 #ifndef RF_EARLY_STOPPING_P_HXX 00002 #define RF_EARLY_STOPPING_P_HXX 00003 #include <cmath> 00004 #include "rf_common.hxx" 00005 00006 namespace vigra 00007 { 00008 00009 #if 0 00010 namespace es_detail 00011 { 00012 template<class T> 00013 T power(T const & in, int n) 00014 { 00015 T result = NumericTraits<T>::one(); 00016 for(int ii = 0; ii < n ;++ii) 00017 result *= in; 00018 return result; 00019 } 00020 } 00021 #endif 00022 00023 /**Base class from which all EarlyStopping Functors derive. 00024 */ 00025 class StopBase 00026 { 00027 protected: 00028 ProblemSpec<> ext_param_; 00029 int tree_count_ ; 00030 bool is_weighted_; 00031 00032 public: 00033 template<class T> 00034 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false) 00035 { 00036 ext_param_ = prob; 00037 is_weighted_ = is_weighted; 00038 tree_count_ = tree_count; 00039 } 00040 00041 /** called after the prediction of a tree was added to the total prediction 00042 * \param WeightIter Iterator to the weights delivered by current tree. 00043 * \param k after kth tree 00044 * \param prob Total probability array 00045 * \param totalCt sum of probability array. 
00046 */ 00047 template<class WeightIter, class T, class C> 00048 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */) 00049 {return false;} 00050 }; 00051 00052 00053 /**Stop predicting after a set number of trees 00054 */ 00055 class StopAfterTree : public StopBase 00056 { 00057 public: 00058 double max_tree_p; 00059 int max_tree_; 00060 typedef StopBase SB; 00061 00062 ArrayVector<double> depths; 00063 00064 /** Constructor 00065 * \param max_tree number of trees to be used for prediction 00066 */ 00067 StopAfterTree(double max_tree) 00068 : 00069 max_tree_p(max_tree) 00070 {} 00071 00072 template<class T> 00073 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false) 00074 { 00075 max_tree_ = ceil(max_tree_p * tree_count); 00076 SB::set_external_parameters(prob, tree_count, is_weighted); 00077 } 00078 00079 template<class WeightIter, class T, class C> 00080 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */) 00081 { 00082 if(k == SB::tree_count_ -1) 00083 { 00084 depths.push_back(double(k+1)/double(SB::tree_count_)); 00085 return false; 00086 } 00087 if(k < max_tree_) 00088 return false; 00089 depths.push_back(double(k+1)/double(SB::tree_count_)); 00090 return true; 00091 } 00092 }; 00093 00094 /** Stop predicting after a certain amount of votes exceed certain proportion. 00095 * case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_ 00096 * case weighted votion: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ; 00097 * (maximal number of votes possible in both cases) 00098 */ 00099 class StopAfterVoteCount : public StopBase 00100 { 00101 public: 00102 double proportion_; 00103 typedef StopBase SB; 00104 ArrayVector<double> depths; 00105 00106 /** Constructor 00107 * \param proportion specify proportion to be used. 
00108 */ 00109 StopAfterVoteCount(double proportion) 00110 : 00111 proportion_(proportion) 00112 {} 00113 00114 template<class WeightIter, class T, class C> 00115 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */) 00116 { 00117 if(k == SB::tree_count_ -1) 00118 { 00119 depths.push_back(double(k+1)/double(SB::tree_count_)); 00120 return false; 00121 } 00122 00123 00124 if(SB::is_weighted_) 00125 { 00126 if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_) 00127 { 00128 depths.push_back(double(k+1)/double(SB::tree_count_)); 00129 return true; 00130 } 00131 } 00132 else 00133 { 00134 if(prob[argMax(prob)] > proportion_ * SB::tree_count_) 00135 { 00136 depths.push_back(double(k+1)/double(SB::tree_count_)); 00137 return true; 00138 } 00139 } 00140 return false; 00141 } 00142 00143 }; 00144 00145 00146 /** Stop predicting if the 2norm of the probabilities does not change*/ 00147 class StopIfConverging : public StopBase 00148 00149 { 00150 public: 00151 double thresh_; 00152 int num_; 00153 MultiArray<2, double> last_; 00154 MultiArray<2, double> cur_; 00155 ArrayVector<double> depths; 00156 typedef StopBase SB; 00157 00158 /** Constructor 00159 * \param thresh: If the two norm of the probabilites changes less then thresh then stop 00160 * \param num : look at atleast num trees before stopping 00161 */ 00162 StopIfConverging(double thresh, int num = 10) 00163 : 00164 thresh_(thresh), 00165 num_(num) 00166 {} 00167 00168 template<class T> 00169 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false) 00170 { 00171 last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0); 00172 cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0); 00173 SB::set_external_parameters(prob, tree_count, is_weighted); 00174 } 00175 template<class WeightIter, class T, class C> 00176 bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> 
const & prob, double totalCt) 00177 { 00178 if(k == SB::tree_count_ -1) 00179 { 00180 depths.push_back(double(k+1)/double(SB::tree_count_)); 00181 return false; 00182 } 00183 if(k <= num_) 00184 { 00185 last_ = prob; 00186 last_/= last_.norm(1); 00187 return false; 00188 } 00189 else 00190 { 00191 cur_ = prob; 00192 cur_ /= cur_.norm(1); 00193 last_ -= cur_; 00194 double nrm = last_.norm(); 00195 if(nrm < thresh_) 00196 { 00197 depths.push_back(double(k+1)/double(SB::tree_count_)); 00198 return true; 00199 } 00200 else 00201 { 00202 last_ = cur_; 00203 } 00204 } 00205 return false; 00206 } 00207 }; 00208 00209 00210 /** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion 00211 * case unweighted voting: stop if margin exceeds proportion * SB::tree_count_ 00212 * case weighted votion: stop if margin exceeds proportion * msample_ * SB::tree_count_ ; 00213 * (maximal number of votes possible in both cases) 00214 */ 00215 class StopIfMargin : public StopBase 00216 { 00217 public: 00218 double proportion_; 00219 typedef StopBase SB; 00220 ArrayVector<double> depths; 00221 00222 /** Constructor 00223 * \param proportion specify proportion to be used. 
00224 */ 00225 StopIfMargin(double proportion) 00226 : 00227 proportion_(proportion) 00228 {} 00229 00230 template<class WeightIter, class T, class C> 00231 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> prob, double /* totalCt */) 00232 { 00233 if(k == SB::tree_count_ -1) 00234 { 00235 depths.push_back(double(k+1)/double(SB::tree_count_)); 00236 return false; 00237 } 00238 int index = argMax(prob); 00239 double a = prob[argMax(prob)]; 00240 prob[argMax(prob)] = 0; 00241 double b = prob[argMax(prob)]; 00242 prob[index] = a; 00243 double margin = a - b; 00244 if(SB::is_weighted_) 00245 { 00246 if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_) 00247 { 00248 depths.push_back(double(k+1)/double(SB::tree_count_)); 00249 return true; 00250 } 00251 } 00252 else 00253 { 00254 if(prob[argMax(prob)] > proportion_ * SB::tree_count_) 00255 { 00256 depths.push_back(double(k+1)/double(SB::tree_count_)); 00257 return true; 00258 } 00259 } 00260 return false; 00261 } 00262 }; 00263 00264 00265 /**Probabilistic Stopping criterion (binomial test) 00266 * 00267 * Can only be used in a two class setting 00268 * 00269 * Stop if the Parameters estimated for the underlying binomial distribution 00270 * can be estimated with certainty over 1-alpha. 00271 * (Thesis, Rahul Nair Page 80 onwards: called the "binomial" criterion 00272 */ 00273 class StopIfBinTest : public StopBase 00274 { 00275 public: 00276 double alpha_; 00277 MultiArrayView<2, double> n_choose_k; 00278 /** Constructor 00279 * \param proportion specify alpha value for binomial test. 00280 * \param nck_ Matrix with precomputed values for n choose k 00281 * nck_(n, k) is n choose k. 
00282 */ 00283 StopIfBinTest(double alpha, MultiArrayView<2, double> nck_) 00284 : 00285 alpha_(alpha), 00286 n_choose_k(nck_) 00287 {} 00288 typedef StopBase SB; 00289 00290 /**ArrayVector that will contain the fraction of trees that was visited before terminating 00291 */ 00292 ArrayVector<double> depths; 00293 00294 double binomial(int N, int k, double p) 00295 { 00296 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k); 00297 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k); 00298 } 00299 00300 template<class WeightIter, class T, class C> 00301 bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt) 00302 { 00303 if(k == SB::tree_count_ -1) 00304 { 00305 depths.push_back(double(k+1)/double(SB::tree_count_)); 00306 return false; 00307 } 00308 if(k < 10) 00309 { 00310 return false; 00311 } 00312 int index = argMax(prob); 00313 int n_a = prob[index]; 00314 int n_b = prob[(index+1)%2]; 00315 int n_tilde = (SB::tree_count_ - n_a + n_b); 00316 double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde); 00317 vigra_precondition(p_a <= 1, "probability should be smaller than 1"); 00318 double cum_val = 0; 00319 int c = 0; 00320 // std::cerr << "prob: " << p_a << std::endl; 00321 if(n_a <= 0)n_a = 0; 00322 if(n_b <= 0)n_b = 0; 00323 for(int ii = 0; ii <= n_b + n_a;++ii) 00324 { 00325 // std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl; 00326 cum_val += binomial(n_b + n_a, ii, p_a); 00327 if(cum_val >= 1 -alpha_) 00328 { 00329 c = ii; 00330 break; 00331 } 00332 } 00333 // std::cerr << c << " " << n_a << " " << n_b << " " << p_a << alpha_ << std::endl; 00334 if(c < n_a) 00335 { 00336 depths.push_back(double(k+1)/double(SB::tree_count_)); 00337 return true; 00338 } 00339 00340 return false; 00341 } 00342 }; 00343 00344 /**Probabilistic Stopping criteria. 
(toChange) 00345 * 00346 * Can only be used in a two class setting 00347 * 00348 * Stop if the probability that the decision will change after seeing all trees falls under 00349 * a specified value alpha. 00350 * (Thesis, Rahul Nair Page 80 onwards: called the "toChange" criterion 00351 */ 00352 class StopIfProb : public StopBase 00353 { 00354 public: 00355 double alpha_; 00356 MultiArrayView<2, double> n_choose_k; 00357 00358 00359 /** Constructor 00360 * \param proportion specify alpha value 00361 * \param nck_ Matrix with precomputed values for n choose k 00362 * nck_(n, k) is n choose k. 00363 */ 00364 StopIfProb(double alpha, MultiArrayView<2, double> nck_) 00365 : 00366 alpha_(alpha), 00367 n_choose_k(nck_) 00368 {} 00369 typedef StopBase SB; 00370 /**ArrayVector that will contain the fraction of trees that was visited before terminating 00371 */ 00372 ArrayVector<double> depths; 00373 00374 double binomial(int N, int k, double p) 00375 { 00376 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k); 00377 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k); 00378 } 00379 00380 template<class WeightIter, class T, class C> 00381 bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt) 00382 { 00383 if(k == SB::tree_count_ -1) 00384 { 00385 depths.push_back(double(k+1)/double(SB::tree_count_)); 00386 return false; 00387 } 00388 if(k <= 10) 00389 { 00390 return false; 00391 } 00392 int index = argMax(prob); 00393 int n_a = prob[index]; 00394 int n_b = prob[(index+1)%2]; 00395 int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a; 00396 int n_tilde = SB::tree_count_ - (n_a +n_b); 00397 if(n_tilde <= 0) n_tilde = 0; 00398 if(n_needed <= 0) n_needed = 0; 00399 double p = 0; 00400 for(int ii = n_needed; ii < n_tilde; ++ii) 00401 p += binomial(n_tilde, ii, 0.5); 00402 00403 if(p >= 1-alpha_) 00404 { 00405 depths.push_back(double(k+1)/double(SB::tree_count_)); 00406 return true; 00407 } 00408 
00409 return false; 00410 } 00411 }; 00412 } //namespace vigra; 00413 #endif //RF_EARLY_STOPPING_P_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|