17 using namespace shogun;
25 SG_ERROR(
"Expected StreamingDenseFeatures\n")
27 SG_ERROR(
"Expected float32_t feature type\n")
32 vector<int32_t> predicts;
34 m_feats->start_parser();
35 while (m_feats->get_next_example())
37 predicts.push_back(apply_multiclass_example(m_feats->get_vector()));
38 m_feats->release_example();
40 m_feats->end_parser();
43 for (
size_t i=0; i < predicts.size(); ++i)
50 compute_conditional_probabilities(ex);
52 for (map<int32_t,bnode_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
54 probs[it->first] = accumulate_conditional_probability(it->second);
61 stack<bnode_t *> nodes;
64 while (!nodes.empty())
70 nodes.push(node->
left());
71 nodes.push(node->
right());
74 node->
data.p_right = predict_node(ex, node);
85 if (leaf == par->
left())
86 prob *= (1-par->
data.p_right);
88 prob *= par->
data.p_right;
102 SG_ERROR(
"Expected StreamingDenseFeatures\n")
104 SG_ERROR(
"Expected float32_t features\n")
110 SG_ERROR(
"No data features provided\n")
113 m_machines->reset_array();
119 m_feats->start_parser();
120 for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
122 while (m_feats->get_next_example())
124 train_example(m_feats->get_vector(),
static_cast<int32_t
>(m_feats->get_label()));
125 m_feats->release_example();
128 if (ipass < m_num_passes-1)
129 m_feats->reset_stream();
131 m_feats->end_parser();
133 for (int32_t i=0; i < m_machines->get_num_elements(); ++i)
148 printf(
"Empty Tree\n");
156 m_root->data.label = label;
157 m_leaves.insert(make_pair(label, (
bnode_t*) m_root));
158 m_root->machine(create_machine(ex));
162 if (m_leaves.find(label) != m_leaves.end())
164 train_path(ex, m_leaves[label]);
169 while (node->
left() != NULL)
172 bool is_left = which_subtree(node, ex);
178 train_node(ex, node_label, node);
183 node = node->
right();
186 m_leaves.erase(node->
data.label);
189 left_node->
data.label = node->
data.label;
190 node->
data.label = -1;
195 m_machines->push_back(mch);
196 left_node->
machine(m_machines->get_num_elements()-1);
197 m_leaves.insert(make_pair(left_node->
data.label, left_node));
198 node->
left(left_node);
201 right_node->
data.label = label;
202 right_node->
machine(create_machine(ex));
203 m_leaves.insert(make_pair(label, right_node));
204 node->
right(right_node);
211 train_node(ex, node_label, node);
216 if (par->
left() == node)
221 train_node(ex, node_label, par);
229 REQUIRE(node,
"Node must not be NULL\n");
231 REQUIRE(mch,
"Instance of %s could not be casted to COnlineLibLinear\n", node->
get_name());
238 REQUIRE(node,
"Node must not be NULL\n");
240 REQUIRE(mch,
"Instance of %s could not be casted to COnlineLibLinear\n", node->
get_name());
252 m_machines->push_back(mch);
253 return m_machines->get_num_elements()-1;