SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
multiclass
tree
VwConditionalProbabilityTree.h
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Chiyuan Zhang
8
* Copyright (C) 2012 Chiyuan Zhang
9
*/
10
11
#ifndef CONDITIONALPROBABILITYTREE_H__
12
#define CONDITIONALPROBABILITYTREE_H__
13
14
#include <map>
15
16
#include <
shogun/multiclass/tree/TreeMachine.h
>
17
#include <
shogun/classifier/vw/VowpalWabbit.h
>
18
19
namespace
shogun
20
{
21
23
struct
VwConditionalProbabilityTreeNodeData
24
{
26
int32_t
label
;
28
float64_t
p_right
;
29
31
VwConditionalProbabilityTreeNodeData
():
label
(-1),
p_right
(0) {}
32
};
33
35
typedef
CBinaryTreeMachineNode<VwConditionalProbabilityTreeNodeData>
bnode_t
;
36
38
/** Abstract tree machine for multiclass prediction over Vowpal Wabbit
 * examples (VwExample), read from CStreamingVwFeatures.
 *
 * Subclasses must implement which_subtree(). Tree nodes carry a
 * VwConditionalProbabilityTreeNodeData payload (see bnode_t).
 *
 * NOTE(review): this object holds a counted reference to m_feats but
 * declares no copy constructor / assignment; copying an instance would
 * share m_feats without taking an extra reference. Objects of this kind
 * are expected to be handled through pointers only.
 */
class CVwConditionalProbabilityTree: public CTreeMachine<VwConditionalProbabilityTreeNodeData>
{
public:
	/** constructor
	 *
	 * explicit: an int32_t should never silently convert into a tree.
	 *
	 * @param num_passes value stored as the number of passes (default 1)
	 */
	explicit CVwConditionalProbabilityTree(int32_t num_passes=1)
		: m_num_passes(num_passes), m_feats(NULL)
	{
	}

	/** destructor
	 *
	 * set_features() takes a reference (SG_REF) on the features it stores,
	 * so that reference must be dropped here or the features object leaks.
	 */
	virtual ~CVwConditionalProbabilityTree()
	{
		SG_UNREF(m_feats);
	}

	/** @return object name, "VwConditionalProbabilityTree" */
	virtual const char* get_name() const { return "VwConditionalProbabilityTree"; }

	/** set number of passes
	 *
	 * @param num_passes new value for m_num_passes
	 */
	void set_num_passes(int32_t num_passes)
	{
		m_num_passes = num_passes;
	}

	/** @return number of passes (m_num_passes) */
	int32_t get_num_passes() const
	{
		return m_num_passes;
	}

	/** set the streaming features this machine works on
	 *
	 * Takes a reference on @p feats and releases the previously held
	 * features. The new reference is taken before the old one is dropped,
	 * so passing the currently held object again is safe.
	 *
	 * @param feats features to use (may be NULL)
	 */
	void set_features(CStreamingVwFeatures *feats)
	{
		SG_REF(feats);
		SG_UNREF(m_feats);
		m_feats = feats;
	}

	/** apply the machine and produce multiclass labels
	 *
	 * NOTE(review): presumably falls back to the features given via
	 * set_features() when @p data is NULL — confirm in the .cpp.
	 *
	 * @param data features to apply to (default NULL)
	 * @return predicted labels
	 */
	virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);

	/** apply the machine to a single example
	 *
	 * @param ex example to classify
	 * @return presumably the predicted class label — confirm in the .cpp
	 */
	virtual int32_t apply_multiclass_example(VwExample* ex);

protected:
	/** @return false: training does not require a separate labels object
	 * (NOTE(review): presumably labels travel inside each VwExample)
	 */
	virtual bool train_require_labels() const { return false; }

	/** train the machine
	 *
	 * @param data training features (declared here; behavior lives in the .cpp)
	 * @return whether training succeeded
	 */
	virtual bool train_machine(CFeatures* data);

	/** train on a single example
	 *
	 * @param ex example to train on
	 */
	void train_example(VwExample *ex);

	/** train along a path of the tree
	 *
	 * @param ex example to train on
	 * @param node tree node the path refers to (presumably its end point — confirm)
	 */
	void train_path(VwExample *ex, bnode_t *node);

	/** train the machine attached to a single node
	 *
	 * @param ex example to train on
	 * @param node node whose machine is trained
	 * @return a float64_t result of training (NOTE(review): meaning not visible here)
	 */
	float64_t train_node(VwExample *ex, bnode_t *node);

	/** create a new machine for the tree
	 *
	 * @param ex example used when creating the machine
	 * @return presumably the index of the newly created machine — confirm
	 */
	int32_t create_machine(VwExample *ex);

	/** decide which subtree of @p node the example belongs to
	 *
	 * Pure virtual: this is the hook concrete trees must provide.
	 *
	 * @param node current node
	 * @param ex example being routed
	 * @return which subtree to take (NOTE(review): presumably true = right)
	 */
	virtual bool which_subtree(bnode_t *node, VwExample *ex)=0;

	/** compute conditional probabilities for an example
	 *
	 * @param ex example to evaluate
	 */
	void compute_conditional_probabilities(VwExample *ex);

	/** accumulate conditional probability for a leaf
	 *
	 * @param leaf leaf node
	 * @return accumulated probability value
	 */
	float64_t accumulate_conditional_probability(bnode_t *leaf);

	/** number of passes */
	int32_t m_num_passes;
	/** map from int32_t key to leaf node (presumably keyed by class label) */
	std::map<int32_t, bnode_t*> m_leaves;
	/** features (counted reference owned via set_features()) */
	CStreamingVwFeatures *m_feats;
};
135
136
}
/* shogun */
137
138
#endif
/* end of include guard: CONDITIONALPROBABILITYTREE_H__ */
139
SHOGUN
机器学习工具包 - 项目文档