35 #ifndef VIGRA_RF3_COMMON_HXX
36 #define VIGRA_RF3_COMMON_HXX
39 #include <type_traits>
43 #include "../multi_array.hxx"
44 #include "../mathutil.hxx"
57 struct LessEqualSplitTest
60 LessEqualSplitTest(
size_t dim = 0, T
const & val = 0)
66 template<
typename FEATURES>
67 size_t operator()(FEATURES
const & features)
const
69 return features(dim_) <= val_ ? 0 : 1;
81 typedef size_t input_type;
83 template <
typename ITER,
typename OUTITER>
84 void operator()(ITER begin, ITER end, OUTITER out)
86 std::fill(buffer_.begin(), buffer_.end(), 0);
89 for (ITER it = begin; it != end; ++it)
92 if (v >= buffer_.size())
94 buffer_.resize(v+1, 0);
98 max_v = std::max(max_v, v);
100 for (
size_t i = 0; i <= max_v; ++i)
102 *out = buffer_[i] /
static_cast<double>(n);
107 std::vector<size_t> buffer_;
112 template <
typename VALUETYPE>
113 struct ArgMaxVectorAcc
116 typedef VALUETYPE value_type;
117 typedef std::vector<value_type> input_type;
118 template <
typename ITER,
typename OUTITER>
119 void operator()(ITER begin, ITER end, OUTITER out)
121 std::fill(buffer_.begin(), buffer_.end(), 0);
123 for (ITER it = begin; it != end; ++it)
125 input_type
const & vec = *it;
126 if (vec.size() >= buffer_.size())
128 buffer_.resize(vec.size(), 0);
130 value_type
const n = std::accumulate(vec.begin(), vec.end(),
static_cast<value_type
>(0));
131 for (
size_t i = 0; i < vec.size(); ++i)
133 buffer_[i] += vec[i] /
static_cast<double>(n);
135 max_v = std::max(vec.size()-1, max_v);
137 for (
size_t i = 0; i <= max_v; ++i)
144 std::vector<double> buffer_;
216 template <
typename FUNCTOR>
221 typedef FUNCTOR Functor;
228 best_score_(std::numeric_limits<double>::max()),
230 n_total_(std::accumulate(priors.begin(), priors.end(), 0.0))
233 template <
typename FEATURES,
typename LABELS,
typename WEIGHTS,
typename ITER>
235 FEATURES
const & features,
236 LABELS
const & labels,
237 WEIGHTS
const & weights,
247 std::vector<double> counts(priors_.size(), 0.0);
251 for (; next != end; ++begin, ++next)
254 size_t const left_index = *begin;
255 size_t const right_index = *next;
256 size_t const label =
static_cast<size_t>(labels(left_index));
257 counts[label] += weights[left_index];
258 n_left += weights[left_index];
261 auto const left = features(left_index, dim);
262 auto const right = features(right_index, dim);
268 double const s = score(priors_, counts, n_total_, n_left);
269 bool const better_score = s < best_score_;
273 best_split_ = 0.5*(left+right);
286 std::vector<double>
const priors_;
287 double const n_total_;
299 double operator()(std::vector<double>
const & priors,
300 std::vector<double>
const & counts,
double n_total,
double n_left)
const
302 double const n_right = n_total - n_left;
303 double gini_left = 1.0;
304 double gini_right = 1.0;
305 for (
size_t i = 0; i < counts.size(); ++i)
307 double const p_left = counts[i] / n_left;
308 double const p_right = (priors[i] - counts[i]) / n_right;
309 gini_left -= (p_left*p_left);
310 gini_right -= (p_right*p_right);
312 return n_left*gini_left + n_right*gini_right;
316 template <
typename LABELS,
typename WEIGHTS,
typename ITER>
317 static double region_score(LABELS
const & labels, WEIGHTS
const & weights, ITER begin, ITER end)
320 std::vector<double> counts;
322 for (
auto it = begin; it != end; ++it)
325 auto const lbl = labels[d];
326 if (counts.size() <= lbl)
328 counts.resize(lbl+1, 0.0);
330 counts[lbl] += weights[d];
336 for (
auto x : counts)
351 double operator()(std::vector<double>
const & priors, std::vector<double>
const & counts,
double n_total,
double n_left)
const
353 double const n_right = n_total - n_left;
355 for (
size_t i = 0; i < counts.size(); ++i)
357 double c = counts[i];
368 template <
typename LABELS,
typename WEIGHTS,
typename ITER>
369 double region_score(LABELS
const & , WEIGHTS
const & , ITER , ITER )
const
371 vigra_fail(
"EntropyScore::region_score(): Not implemented yet.");
385 double operator()(std::vector<double>
const & priors, std::vector<double>
const & counts,
double ,
double )
const
387 double const eps = 1e-10;
389 std::vector<double> norm_counts(counts.size(), 0.0);
390 for (
size_t i = 0; i < counts.size(); ++i)
394 norm_counts[i] = counts[i] / priors[i];
403 double const mean = std::accumulate(norm_counts.begin(), norm_counts.end(), 0.0) / nnz;
407 for (
size_t i = 0; i < norm_counts.size(); ++i)
411 double const v = (mean-norm_counts[i]);
418 template <
typename LABELS,
typename WEIGHTS,
typename ITER>
419 double region_score(LABELS
const & , WEIGHTS
const & , ITER , ITER )
const
421 vigra_fail(
"KolmogorovSmirnovScore::region_score(): Region score not available for the Kolmogorov-Smirnov split.");
427 template <
typename ARR>
428 struct RFNodeDescription
431 RFNodeDescription(
size_t depth, ARR
const & priors)
443 template <
typename LABELS,
typename ITER>
444 bool is_pure(LABELS
const & , RFNodeDescription<ITER>
const & desc)
447 for (
auto n : desc.priors_)
466 template <
typename LABELS,
typename ITER>
467 bool operator()(LABELS
const & labels, RFNodeDescription<ITER>
const & desc)
const
469 return is_pure(labels, desc);
482 max_depth_(max_depth)
485 template <
typename LABELS,
typename ITER>
486 bool operator()(LABELS
const & labels, RFNodeDescription<ITER>
const & desc)
const
488 if (desc.depth_ >= max_depth_)
491 return is_pure(labels, desc);
508 template <
typename LABELS,
typename ARR>
509 bool operator()(LABELS
const & labels, RFNodeDescription<ARR>
const & desc)
const
511 typedef typename ARR::value_type value_type;
512 if (std::accumulate(desc.priors_.begin(), desc.priors_.end(),
static_cast<value_type
>(0)) <= min_n_)
515 return is_pure(labels, desc);
530 logtau_(std::
log(tau))
532 vigra_precondition(tau > 0 && tau < 1,
"NodeComplexityStop(): Tau must be in the open interval (0, 1).");
535 template <
typename LABELS,
typename ARR>
536 bool operator()(LABELS
const & , RFNodeDescription<ARR>
const & desc)
538 typedef typename ARR::value_type value_type;
541 size_t const total = std::accumulate(desc.priors_.begin(), desc.priors_.end(),
static_cast<value_type
>(0));
546 for (
auto v : desc.priors_)
551 lg +=
loggamma(static_cast<double>(v+1));
554 lg +=
loggamma(static_cast<double>(nnz+1));
555 lg -=
loggamma(static_cast<double>(total+1));
559 return lg >= logtau_;
565 enum RandomForestOptionTags
589 features_per_node_(0),
590 features_per_node_switch_(RF_SQRT),
591 bootstrap_sampling_(
true),
595 node_complexity_tau_(-1),
596 min_num_instances_(1),
597 use_stratification_(
false),
609 tree_count_ = p_tree_count;
622 features_per_node_switch_ = RF_CONST;
623 features_per_node_ = p_features_per_node;
639 vigra_precondition(p_features_per_node_switch == RF_SQRT ||
640 p_features_per_node_switch == RF_LOG ||
641 p_features_per_node_switch == RF_ALL,
642 "RandomForestOptions::features_per_node(): Input must be RF_SQRT, RF_LOG or RF_ALL.");
643 features_per_node_switch_ = p_features_per_node_switch;
654 bootstrap_sampling_ = b;
666 bootstrap_sampling_ =
false;
682 vigra_precondition(p_split == RF_GINI ||
683 p_split == RF_ENTROPY ||
685 "RandomForestOptions::split(): Input must be RF_GINI, RF_ENTROPY or RF_KSD.");
708 node_complexity_tau_ = tau;
719 min_num_instances_ = n;
733 use_stratification_ = b;
774 if (features_per_node_switch_ == RF_SQRT)
776 else if (features_per_node_switch_ == RF_LOG)
778 else if (features_per_node_switch_ == RF_CONST)
779 return features_per_node_;
780 else if (features_per_node_switch_ == RF_ALL)
782 vigra_fail(
"RandomForestOptions::get_features_per_node(): Unknown switch.");
787 int features_per_node_;
788 RandomForestOptionTags features_per_node_switch_;
789 bool bootstrap_sampling_;
790 size_t resample_count_;
791 RandomForestOptionTags split_;
793 double node_complexity_tau_;
794 size_t min_num_instances_;
795 bool use_stratification_;
797 std::vector<double> class_weights_;
803 template <
typename LabelType>
818 ProblemSpec & num_features(
size_t n)
824 ProblemSpec & num_instances(
size_t n)
830 ProblemSpec & num_classes(
size_t n)
836 ProblemSpec & distinct_classes(std::vector<LabelType> v)
838 distinct_classes_ = v;
839 num_classes_ = v.size();
843 ProblemSpec & actual_mtry(
size_t m)
849 ProblemSpec & actual_msample(
size_t m)
855 bool operator==(ProblemSpec
const & other)
const
857 #define COMPARE(field) if (field != other.field) return false;
858 COMPARE(num_features_);
859 COMPARE(num_instances_);
860 COMPARE(num_classes_);
861 COMPARE(distinct_classes_);
862 COMPARE(actual_mtry_);
863 COMPARE(actual_msample_);
868 size_t num_features_;
869 size_t num_instances_;
871 std::vector<LabelType> distinct_classes_;
873 size_t actual_msample_;
RandomForestOptions & min_num_instances(size_t n)
Do not split a node if it contains fewer than min_num_instances data points.
Definition: random_forest_common.hxx:717
RandomForestOptions & split(RandomForestOptionTags p_split)
The split criterion.
Definition: random_forest_common.hxx:680
Random forest 'maximum depth' stop criterion.
Definition: random_forest_common.hxx:476
size_t get_features_per_node(size_t total) const
Get the actual number of features per node.
Definition: random_forest_common.hxx:772
DepthStop(size_t max_depth)
Constructor: terminate tree construction at max_depth.
Definition: random_forest_common.hxx:480
RandomForestOptions & features_per_node(RandomForestOptionTags p_features_per_node_switch)
The number of features that are considered when computing the split.
Definition: random_forest_common.hxx:637
RandomForestOptions & max_depth(size_t d)
Do not split a node if its depth is greater than or equal to max_depth.
Definition: random_forest_common.hxx:695
RandomForestOptions & bootstrap_sampling(bool b)
Use bootstrap sampling.
Definition: random_forest_common.hxx:652
Problem specification class for the random forest.
Definition: rf_common.hxx:538
Definition: random_forest_common.hxx:217
RandomForestOptions & class_weights(std::vector< double > const &v)
Each datapoint is weighted by its class weight. By default, each class has weight 1...
Definition: random_forest_common.hxx:759
RandomForestOptions & use_stratification(bool b)
Use stratification when creating the bootstrap samples.
Definition: random_forest_common.hxx:731
RandomForestOptions & resample_count(size_t n)
If resample_count is greater than zero, the split in each node is computed using only resample_count ...
Definition: random_forest_common.hxx:663
Functor that computes the entropy score.
Definition: random_forest_common.hxx:348
RandomForestOptions & n_threads(int n)
The number of threads that are used in training.
Definition: random_forest_common.hxx:744
bool operator==(FFTWComplex< R > const &a, const FFTWComplex< R > &b)
equal
Definition: fftw3.hxx:825
Random forest 'node purity' stop criterion.
Definition: random_forest_common.hxx:463
NodeComplexityStop(double tau=0.001)
Constructor: stop when fewer than 1/tau label arrangements are possible.
Definition: random_forest_common.hxx:528
RandomForestOptions & node_complexity_tau(double tau)
Value of the node complexity termination criterion.
Definition: random_forest_common.hxx:706
linalg::TemporaryMatrix< T > log(MultiArrayView< 2, T, C > const &v)
double loggamma(double x)
The natural logarithm of the gamma function.
Definition: mathutil.hxx:1603
Functor that computes the gini score.
Definition: random_forest_common.hxx:296
Random forest 'node complexity' stop criterion.
Definition: random_forest_common.hxx:524
Options class for vigra::rf3::RandomForest version 3.
Definition: random_forest_common.hxx:582
RandomForestOptions & features_per_node(int p_features_per_node)
The number of features that are considered when computing the split.
Definition: random_forest_common.hxx:620
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition: fixedpoint.hxx:675
RandomForestOptions & tree_count(int p_tree_count)
The number of trees.
Definition: random_forest_common.hxx:607
NumInstancesStop(size_t min_n)
Constructor: terminate tree construction when a node contains fewer than min_n instances.
Definition: random_forest_common.hxx:503
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root.
Definition: fixedpoint.hxx:616
Functor that computes the Kolmogorov-Smirnov score.
Definition: random_forest_common.hxx:382
Random forest 'number of datapoints' stop criterion.
Definition: random_forest_common.hxx:499