SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
evaluation
CrossValidationMulticlassStorage.h
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Heiko Strathmann, Sergey Lisitsyn
8
*
9
*/
10
11
#ifndef CROSSVALIDATIONMULTICLASSSTORAGE_H_
12
#define CROSSVALIDATIONMULTICLASSSTORAGE_H_
13
14
#include <
shogun/evaluation/CrossValidationOutput.h
>
15
#include <
shogun/evaluation/BinaryClassEvaluation.h
>
16
#include <
shogun/labels/MulticlassLabels.h
>
17
#include <
shogun/lib/SGMatrix.h
>
18
#include <
shogun/lib/DynamicObjectArray.h
>
19
20
namespace
shogun
21
{
22
23
class
CMachine;
24
class
CLabels;
25
class
CEvaluation;
26
31
class
CCrossValidationMulticlassStorage
:
public
CCrossValidationOutput
32
{
33
public
:
34
40
CCrossValidationMulticlassStorage
(
bool
compute_ROC=
true
,
bool
compute_PRC=
false
,
bool
compute_conf_matrices=
false
);
41
43
virtual
~CCrossValidationMulticlassStorage
();
44
52
SGMatrix<float64_t>
get_fold_ROC
(int32_t run, int32_t fold, int32_t c)
53
{
54
ASSERT
(0<=run)
55
ASSERT
(run<
m_num_runs
)
56
ASSERT
(0<=fold)
57
ASSERT
(fold<
m_num_folds
)
58
ASSERT
(0<=c)
59
ASSERT
(c<
m_num_classes
)
60
REQUIRE
(
m_compute_ROC
,
"ROC computation was not enabled\n"
)
61
return
m_fold_ROC_graphs
[run*
m_num_folds
*
m_num_classes
+fold*
m_num_classes
+c];
62
}
63
71
SGMatrix<float64_t>
get_fold_PRC
(int32_t run, int32_t fold, int32_t c)
72
{
73
ASSERT
(0<=run)
74
ASSERT
(run<
m_num_runs
)
75
ASSERT
(0<=fold)
76
ASSERT
(fold<
m_num_folds
)
77
ASSERT
(0<=c)
78
ASSERT
(c<
m_num_classes
)
79
REQUIRE
(
m_compute_PRC
,
"PRC computation was not enabled\n"
)
80
return
m_fold_PRC_graphs
[run*
m_num_folds
*
m_num_classes
+fold*
m_num_classes
+c];
81
}
82
87
void
append_binary_evaluation
(
CBinaryClassEvaluation
* evaluation)
88
{
89
m_binary_evaluations
->
push_back
(evaluation);
90
}
91
96
CBinaryClassEvaluation
*
get_binary_evaluation
(int32_t idx)
97
{
98
return
(
CBinaryClassEvaluation
*)
m_binary_evaluations
->
get_element_safe
(idx);
99
}
100
108
float64_t
get_fold_evaluation_result
(int32_t run, int32_t fold, int32_t c, int32_t e)
109
{
110
ASSERT
(0<=run)
111
ASSERT
(run<
m_num_runs
)
112
ASSERT
(0<=fold)
113
ASSERT
(fold<
m_num_folds
)
114
ASSERT
(0<=c)
115
ASSERT
(c<
m_num_classes
)
116
ASSERT
(0<=e)
117
int32_t n_evals =
m_binary_evaluations
->
get_num_elements
();
118
ASSERT
(e<n_evals)
119
return
m_evaluations_results
[run*
m_num_folds
*
m_num_classes
*n_evals+fold*
m_num_classes
*n_evals+c*n_evals+e];
120
}
121
126
float64_t
get_fold_accuracy
(int32_t run, int32_t fold)
127
{
128
ASSERT
(0<=run)
129
ASSERT
(run<
m_num_runs
)
130
ASSERT
(0<=fold)
131
ASSERT
(fold<
m_num_folds
)
132
return
m_accuracies
[run*
m_num_folds
+fold];
133
}
134
139
SGMatrix<int32_t>
get_fold_conf_matrix
(int32_t run, int32_t fold)
140
{
141
ASSERT
(0<=run)
142
ASSERT
(run<
m_num_runs
)
143
ASSERT
(0<=fold)
144
ASSERT
(fold<
m_num_folds
)
145
REQUIRE
(
m_compute_conf_matrices
,
"Confusion matrices computation was not enabled\n"
)
146
return
m_conf_matrices
[run*
m_num_folds
+fold];
147
}
148
150
virtual
void
post_init
();
151
153
virtual
void
post_update_results
();
154
158
virtual
void
init_expose_labels
(
CLabels
* labels);
159
165
virtual
void
update_test_result
(
CLabels
* results,
166
const
char
* prefix=
""
);
167
173
virtual
void
update_test_true_result
(
CLabels
* results,
174
const
char
* prefix=
""
);
175
177
virtual
const
char
*
get_name
()
const
{
return
"CrossValidationMulticlassStorage"
; }
178
179
protected
:
180
182
bool
m_initialized
;
183
185
CDynamicObjectArray
*
m_binary_evaluations
;
186
188
SGVector<float64_t>
m_evaluations_results
;
189
191
SGVector<float64_t>
m_accuracies
;
192
194
bool
m_compute_ROC
;
195
197
SGMatrix<float64_t>
*
m_fold_ROC_graphs
;
198
200
bool
m_compute_PRC
;
201
203
SGMatrix<float64_t>
*
m_fold_PRC_graphs
;
204
206
bool
m_compute_conf_matrices
;
207
209
SGMatrix<int32_t>
*
m_conf_matrices
;
210
212
CMulticlassLabels
*
m_pred_labels
;
213
215
CMulticlassLabels
*
m_true_labels
;
216
218
int32_t
m_num_classes
;
219
220
};
221
222
}
223
224
#endif
/* CROSSVALIDATIONMULTICLASSSTORAGE_H_ */
SHOGUN
机器学习工具包 - 项目文档