SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
multiclass
tree
RelaxedTree.h
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Chiyuan Zhang
8
* Copyright (C) 2012 Chiyuan Zhang
9
*/
10
11
#ifndef RELAXEDTREE_H__
12
#define RELAXEDTREE_H__
13
14
#include <utility>
15
#include <vector>
16
17
#include <
shogun/features/DenseFeatures.h
>
18
#include <
shogun/classifier/svm/LibSVM.h
>
19
#include <
shogun/multiclass/tree/TreeMachine.h
>
20
#include <
shogun/multiclass/tree/RelaxedTreeNodeData.h
>
21
22
namespace
shogun
23
{
24
25
// Forward declaration: this header only refers to CBaseMulticlassMachine
// through pointers (a setter parameter and a member of CRelaxedTree), so the
// full class definition is not needed here.
class
CBaseMulticlassMachine;
26
34
/** @brief Tree-based multiclass machine derived from
 * CTreeMachine<RelaxedTreeNodeData>.
 *
 * The class holds the configuration of the learner (SVM C and epsilon, the
 * parameters A and B, the maximum number of iterations), plus pointers to the
 * kernel, the dense training features and a multiclass machine used for the
 * confusion matrix. Objects handed in through the setters are retained with
 * SG_REF and the previously stored object is released with SG_UNREF.
 *
 * Only the inline members are defined in this header; everything else is
 * defined in the implementation file.
 */
class CRelaxedTree: public CTreeMachine<RelaxedTreeNodeData>
{
public:
	/** constructor */
	CRelaxedTree();

	/** destructor */
	virtual ~CRelaxedTree();

	/** @return the name of this class, "RelaxedTree" */
	virtual const char* get_name() const { return "RelaxedTree"; }

	/** apply the machine to data and return multiclass labels
	 *
	 * @param data features to apply to (defaults to NULL)
	 * @return multiclass labels; NOTE(review): ownership of the returned
	 *         object is decided in the implementation file - confirm there
	 */
	virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);

	/** set the dense features stored in this machine
	 *
	 * The new object is SG_REF'ed before the old one is SG_UNREF'ed, so
	 * passing the pointer that is already stored does not release it.
	 *
	 * @param feats features to store
	 */
	void set_features(CDenseFeatures<float64_t> *feats)
	{
		SG_REF(feats);
		SG_UNREF(m_feats);
		m_feats = feats;
	}

	/** set the kernel stored in this machine
	 *
	 * Same reference handling as set_features(): reference the new kernel
	 * first, then release the old one.
	 *
	 * @param kernel kernel to store
	 */
	virtual void set_kernel(CKernel *kernel)
	{
		SG_REF(kernel);
		SG_UNREF(m_kernel);
		m_kernel = kernel;
	}

	/** set the labels; they must be CMulticlassLabels
	 *
	 * The labels are forwarded to CMachine::set_labels() and the number of
	 * classes is cached in m_num_classes.
	 *
	 * @param lab labels; REQUIRE fails if lab is NULL or is not a
	 *        CMulticlassLabels instance
	 */
	virtual void set_labels(CLabels* lab)
	{
		CMulticlassLabels *mlab = dynamic_cast<CMulticlassLabels *>(lab);
		// Validate the result of the cast, not the input pointer: a non-NULL
		// CLabels of another type yields mlab == NULL, which would otherwise
		// be dereferenced below. A NULL lab also yields a NULL mlab.
		REQUIRE(mlab, "requires MulticlassLabes\n");

		CMachine::set_labels(mlab);
		m_num_classes = mlab->get_num_classes();
	}

	/** set the multiclass machine used for the confusion matrix
	 *
	 * Same reference handling as set_features().
	 *
	 * @param machine multiclass machine to store
	 */
	void set_machine_for_confusion_matrix(CBaseMulticlassMachine *machine)
	{
		SG_REF(machine);
		SG_UNREF(m_machine_for_confusion_matrix);
		m_machine_for_confusion_matrix = machine;
	}

	/** set SVM parameter C
	 * @param C value stored in m_svm_C
	 */
	void set_svm_C(float64_t C)
	{
		m_svm_C = C;
	}

	/** @return SVM parameter C */
	float64_t get_svm_C() const
	{
		return m_svm_C;
	}

	/** set SVM epsilon
	 * @param epsilon value stored in m_svm_epsilon
	 */
	void set_svm_epsilon(float64_t epsilon)
	{
		m_svm_epsilon = epsilon;
	}

	/** @return SVM epsilon */
	float64_t get_svm_epsilon() const
	{
		return m_svm_epsilon;
	}

	/** set parameter A
	 * @param A value stored in m_A
	 */
	void set_A(float64_t A)
	{
		m_A = A;
	}

	/** @return parameter A */
	float64_t get_A() const
	{
		return m_A;
	}

	/** set parameter B
	 * @param B value stored in m_B
	 */
	void set_B(int32_t B)
	{
		m_B = B;
	}

	/** @return parameter B */
	int32_t get_B() const
	{
		return m_B;
	}

	/** set the maximum number of iterations
	 * @param n_iter value stored in m_max_num_iter
	 */
	void set_max_num_iter(int32_t n_iter)
	{
		m_max_num_iter = n_iter;
	}

	/** @return the maximum number of iterations */
	int32_t get_max_num_iter() const
	{
		return m_max_num_iter;
	}

	/** train the machine
	 *
	 * Forwards directly to CMachine::train().
	 *
	 * @param data training data (defaults to NULL)
	 * @return the value returned by CMachine::train()
	 */
	virtual bool train(CFeatures* data=NULL)
	{
		return CMachine::train(data);
	}

	/** entry type: a pair of int32_t indices together with a float64_t value */
	typedef std::pair<std::pair<int32_t, int32_t>, float64_t> entry_t;

protected:
	/** apply to a single example
	 * @param idx index of the example
	 * @return result for that example (defined in the implementation file)
	 */
	float64_t apply_one(int32_t idx);

	/** train the machine; defined in the implementation file
	 * @param data training data
	 * @return whether training succeeded - presumably; confirm in the
	 *         implementation
	 */
	virtual bool train_machine(CFeatures* data);

	/** train a node from a confusion matrix and a set of classes
	 * (defined in the implementation file)
	 */
	bnode_t *train_node(const SGMatrix<float64_t> &conf_mat, SGVector<int32_t> classes);

	/** produce the initial entries for a node from the global confusion
	 * matrix and a set of classes (defined in the implementation file)
	 */
	std::vector<entry_t> init_node(const SGMatrix<float64_t> &global_conf_mat, SGVector<int32_t> classes);

	/** train a node starting from one entry returned by init_node() -
	 * presumably; confirm in the implementation file
	 */
	SGVector<int32_t> train_node_with_initialization(const CRelaxedTree::entry_t &mu_entry, SGVector<int32_t> classes, CSVM *svm);

	/** compute a score for mu with the given SVM
	 * (defined in the implementation file)
	 */
	float64_t compute_score(SGVector<int32_t> mu, CSVM *svm);

	/** color the label space for the given classes using the given SVM
	 * (defined in the implementation file)
	 */
	SGVector<int32_t> color_label_space(CSVM *svm, SGVector<int32_t> classes);

	/** evaluate the binary model K for the given SVM
	 * (defined in the implementation file)
	 */
	SGVector<float64_t> eval_binary_model_K(CSVM *svm);

	/** enforce the upper balance constraints; arguments other than B_prime
	 * are taken by non-const reference (defined in the implementation file)
	 */
	void enforce_balance_constraints_upper(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);

	/** enforce the lower balance constraints; arguments other than B_prime
	 * are taken by non-const reference (defined in the implementation file)
	 */
	void enforce_balance_constraints_lower(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class);

	/** maximum number of iterations */
	int32_t m_max_num_iter;
	/** parameter A */
	float64_t m_A;
	/** parameter B */
	int32_t m_B;
	/** SVM parameter C */
	float64_t m_svm_C;
	/** SVM epsilon */
	float64_t m_svm_epsilon;
	/** kernel; referenced via SG_REF in set_kernel() */
	CKernel *m_kernel;
	/** dense features; referenced via SG_REF in set_features() */
	CDenseFeatures<float64_t> *m_feats;
	/** multiclass machine used for the confusion matrix; referenced via
	 * SG_REF in set_machine_for_confusion_matrix()
	 */
	CBaseMulticlassMachine *m_machine_for_confusion_matrix;
	/** number of classes, cached by set_labels() */
	int32_t m_num_classes;
};
241
242
}
/* shogun */
243
244
#endif
/* end of include guard: RELAXEDTREE_H__ */
245
SHOGUN
机器学习工具包 - 项目文档