SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
multiclass
tree
ConditionalProbabilityTree.h
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Chiyuan Zhang
8
* Copyright (C) 2012 Chiyuan Zhang
9
*/
10
11
#ifndef CONDITIONALPROBABILITYTREE_H__
12
#define CONDITIONALPROBABILITYTREE_H__
13
14
#include <map>
15
16
#include <
shogun/features/streaming/StreamingDenseFeatures.h
>
17
#include <
shogun/multiclass/tree/TreeMachine.h
>
18
#include <
shogun/multiclass/tree/ConditionalProbabilityTreeNodeData.h
>
19
20
namespace
shogun
21
{
22
31
/**
 * @brief Abstract base class for conditional probability tree (CPT) multiclass machines.
 *
 * The tree is a CTreeMachine whose nodes carry ConditionalProbabilityTreeNodeData.
 * Training reads examples from streaming dense float32 features; prediction
 * returns a multiclass label per example.
 *
 * The class is abstract: subclasses must implement which_subtree() to decide
 * how an example is routed at an internal node.
 *
 * NOTE(review): the training/prediction algorithms are defined out of line;
 * descriptions below that go beyond the declarations are inferred from member
 * names and are marked as such.
 */
class CConditionalProbabilityTree: public CTreeMachine<ConditionalProbabilityTreeNodeData>
{
public:
	/** Constructor.
	 *
	 * Declared explicit: thanks to the default argument this is a
	 * single-argument constructor, and an int32_t should never be implicitly
	 * converted into a tree.
	 *
	 * @param num_passes number of passes over the training data
	 *        (stored in m_num_passes; presumably consumed by train_machine())
	 */
	explicit CConditionalProbabilityTree(int32_t num_passes=1)
		: m_num_passes(num_passes), m_feats(NULL)
	{
	}

	/** Destructor. Releases the reference held on the training features. */
	virtual ~CConditionalProbabilityTree() { SG_UNREF(m_feats); }

	/** @return object name, "ConditionalProbabilityTree" */
	virtual const char* get_name() const { return "ConditionalProbabilityTree"; }

	/** Set the number of passes over the training data.
	 *
	 * @param num_passes new number of passes
	 */
	void set_num_passes(int32_t num_passes)
	{
		m_num_passes = num_passes;
	}

	/** @return the number of passes over the training data */
	int32_t get_num_passes() const
	{
		return m_num_passes;
	}

	/** Set the streaming features used for training.
	 *
	 * The tree keeps a counted reference to @p feats and drops its reference
	 * to the previously set features.
	 *
	 * @param feats streaming dense float32 features (may be NULL)
	 */
	void set_features(CStreamingDenseFeatures<float32_t> *feats)
	{
		// Reference the new object before releasing the old one so that
		// re-setting the currently held features (feats == m_feats) cannot
		// drop the last reference and free them.
		SG_REF(feats);
		SG_UNREF(m_feats);
		m_feats = feats;
	}

	/** Apply the machine to a set of features.
	 *
	 * @param data features to classify (defaults to NULL)
	 *        NOTE(review): NULL presumably means "use the features set
	 *        earlier" — confirm in the out-of-line definition.
	 * @return predicted multiclass labels
	 */
	virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);

	/** Apply the machine to a single example.
	 *
	 * @param ex the example's feature vector
	 * @return predicted class label
	 */
	virtual int32_t apply_multiclass_example(SGVector<float32_t> ex);

	/** Print the tree (presumably to the shogun I/O channel — confirm). */
	void print_tree();

protected:
	/** @return false: training does not require labels to be set on the
	 * machine beforehand.
	 * NOTE(review): labels are presumably taken from the streaming features
	 * during training — confirm in train_machine().
	 */
	virtual bool train_require_labels() const { return false; }

	/** Train the tree.
	 *
	 * @param data training features
	 *        NOTE(review): presumably expected to be streaming dense float32
	 *        features, matching m_feats — confirm.
	 * @return whether training succeeded
	 */
	virtual bool train_machine(CFeatures* data);

	/** Train on a single example.
	 *
	 * @param ex the example's feature vector
	 * @param label the example's class label
	 */
	void train_example(SGVector<float32_t> ex, int32_t label);

	/** Train the node classifiers along a path for one example.
	 *
	 * NOTE(review): from the name, presumably the path from @p node up to
	 * the root — confirm in the definition.
	 *
	 * @param ex the example's feature vector
	 * @param node starting node of the path
	 */
	void train_path(SGVector<float32_t> ex, bnode_t *node);

	/** Train the classifier attached to one node.
	 *
	 * @param ex the example's feature vector
	 * @param label target for the node classifier (a float64_t, presumably a
	 *        binary +1/-1 routing target — confirm)
	 * @param node the node to train
	 */
	void train_node(SGVector<float32_t> ex, float64_t label, bnode_t *node);

	/** Evaluate the classifier attached to one node.
	 *
	 * @param ex the example's feature vector
	 * @param node the node whose classifier is evaluated
	 * @return the node classifier's output for @p ex
	 */
	float64_t predict_node(SGVector<float32_t> ex, bnode_t *node);

	/** Create a new node classifier.
	 *
	 * @param ex the example's feature vector
	 * @return an int32_t identifier (presumably the index of the created
	 *         machine — confirm)
	 */
	int32_t create_machine(SGVector<float32_t> ex);

	/** Decide which child subtree an example is routed to at @p node.
	 *
	 * Pure virtual; this is what makes the class abstract. The meaning of
	 * true versus false is defined by the implementing subclass.
	 *
	 * @param node the internal node being considered
	 * @param ex the example's feature vector
	 * @return the chosen subtree, encoded as a bool
	 */
	virtual bool which_subtree(bnode_t *node, SGVector<float32_t> ex)=0;

	/** Compute conditional probabilities for an example.
	 *
	 * NOTE(review): presumably stores the results in the tree's node data —
	 * confirm in the definition.
	 *
	 * @param ex the example's feature vector
	 */
	void compute_conditional_probabilities(SGVector<float32_t> ex);

	/** Accumulate the conditional probability associated with a leaf.
	 *
	 * @param leaf the leaf node
	 * @return the accumulated probability (presumably combined along the path
	 *         from the root to @p leaf — confirm)
	 */
	float64_t accumulate_conditional_probability(bnode_t *leaf);

	/** number of passes over the training data */
	int32_t m_num_passes;
	/** int32_t key -> leaf node pointer
	 * (presumably class label -> leaf; non-owning as far as visible here) */
	std::map<int32_t, bnode_t*> m_leaves;
	/** training features; SG_REF'd in set_features(), SG_UNREF'd in the
	 * destructor */
	CStreamingDenseFeatures<float32_t> *m_feats;
};
139
140
}
/* shogun */
141
142
#endif
/* end of include guard: CONDITIONALPROBABILITYTREE_H__ */
143
SHOGUN
机器学习工具包 - 项目文档