SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
labels
MulticlassLabels.cpp
浏览该文件的文档.
1
#include <
shogun/labels/DenseLabels.h
>
2
#include <
shogun/labels/BinaryLabels.h
>
3
#include <
shogun/labels/MulticlassLabels.h
>
4
#include <
shogun/base/ParameterMap.h
>
5
6
using namespace
shogun;
7
8
/* Default constructor: the CDenseLabels base is default-constructed
 * implicitly, then the multiclass-specific state is reset by init(). */
CMulticlassLabels::CMulticlassLabels()
{
	init();
}
12
13
/* Constructs labels with storage for num_labels entries; the size is
 * forwarded unchanged to the CDenseLabels base, after which init() resets
 * the multiclass confidence matrix. */
CMulticlassLabels::CMulticlassLabels(int32_t num_labels)
	: CDenseLabels(num_labels)
{
	init();
}
17
18
/* Constructs labels from an existing vector of label values.
 * The base is default-constructed (implicitly); init() must run before
 * set_labels() so the confidence matrix is reset before labels are stored. */
CMulticlassLabels::CMulticlassLabels(const SGVector<float64_t> src)
{
	init();
	set_labels(src);
}
23
24
/* Constructs labels via the CDenseLabels(CFile*) base constructor, passing
 * the loader through unchanged, then resets multiclass state with init(). */
CMulticlassLabels::CMulticlassLabels(CFile* loader)
	: CDenseLabels(loader)
{
	init();
}
28
29
/* Destructor: nothing to release explicitly here; members clean up on
 * their own and the base-class destructor runs afterwards. */
CMulticlassLabels::~CMulticlassLabels()
{
}
32
33
void
CMulticlassLabels::init()
34
{
35
/* for this to work, migration has to be fixed */
36
// SG_ADD(&m_multiclass_confidences, "multiclass_confidences", "Vectors of "
37
// "multiclass confidences", MS_NOT_AVAILABLE);
38
39
// m_parameter_map->finalize_map();
40
41
m_multiclass_confidences
=
SGMatrix<float64_t>
();
42
}
43
44
void
CMulticlassLabels::set_multiclass_confidences
(int32_t i,
45
SGVector<float64_t>
confidences)
46
{
47
REQUIRE
(confidences.
size
()==
m_multiclass_confidences
.
num_rows
,
48
"%s::set_multiclass_confidences(): Length of confidences should "
49
"match size of the matrix"
,
get_name
());
50
51
for
(
index_t
j=0; j<confidences.
size
(); j++)
52
m_multiclass_confidences
(j,i) = confidences[j];
53
}
54
55
/* Returns a fresh copy of column i of the confidence matrix
 * (element j of the result is m_multiclass_confidences(j,i)).
 *
 * @param i  column to read; must lie in [0, num_cols) whenever the matrix
 *           has at least one row
 * @return   vector of length m_multiclass_confidences.num_rows; empty when
 *           the matrix has no rows
 *
 * Errors (via REQUIRE) on an out-of-range column. */
SGVector<float64_t> CMulticlassLabels::get_multiclass_confidences(int32_t i)
{
	/* Previously i was never checked, so an invalid column read out of
	 * bounds. The zero-row case reads nothing and still returns an empty
	 * vector, as before. */
	REQUIRE(m_multiclass_confidences.num_rows==0 ||
		(i>=0 && i<m_multiclass_confidences.num_cols),
		"%s::get_multiclass_confidences(): Index %d is out of range "
		"[0, %d)", get_name(), i, m_multiclass_confidences.num_cols);

	SGVector<float64_t> confs(m_multiclass_confidences.num_rows);
	for (index_t j=0; j<confs.size(); j++)
		confs[j] = m_multiclass_confidences(j,i);

	return confs;
}
63
64
/* Allocates a fresh confidence matrix with n_classes rows and one column per
 * entry of m_labels, discarding any previous matrix.
 *
 * Errors (via REQUIRE) when there are no labels.
 * NOTE(review): the column count uses m_labels.size(), not get_num_labels();
 * if a subset is active these may differ -- confirm this is intended. */
void CMulticlassLabels::allocate_confidences_for(int32_t n_classes)
{
	const int32_t stored_labels=m_labels.size();
	REQUIRE(stored_labels!=0,
		"%s::allocate_confidences_for(): There should be "
		"labels to store confidences", get_name());

	SGMatrix<float64_t> fresh(n_classes, stored_labels);
	m_multiclass_confidences=fresh;
}
72
73
void
CMulticlassLabels::ensure_valid
(
const
char
* context)
74
{
75
CDenseLabels::ensure_valid
(context);
76
77
int32_t subset_size=
get_num_labels
();
78
for
(int32_t i=0; i<subset_size; i++)
79
{
80
int32_t real_i =
m_subset_stack
->
subset_idx_conversion
(i);
81
int32_t label = int32_t(
m_labels
[real_i]);
82
83
if
(label<0 ||
float64_t
(label)!=
m_labels
[real_i])
84
{
85
SG_ERROR
(
"%s%sMulticlass Labels must be in range 0...<nr_classes-1> and integers!\n"
,
86
context?context:
""
, context?
": "
:
""
);
87
}
88
}
89
}
90
91
/* Identifies this label container as multiclass. */
ELabelType CMulticlassLabels::get_label_type() const
{
	const ELabelType kind=LT_MULTICLASS;
	return kind;
}
95
96
/* Builds one-vs-rest binary labels for class i, one entry per label in the
 * current subset.
 *
 * If the confidence matrix has non-zero rows and columns, each entry's
 * magnitude is m_multiclass_confidences(label, k), where label is the stored
 * class of entry k. Otherwise the magnitude is 1.0. The sign is positive
 * when entry k's label equals i and negative otherwise.
 *
 * @return newly allocated CBinaryLabels
 *         (NOTE(review): caller presumably owns the result -- confirm). */
CBinaryLabels* CMulticlassLabels::get_binary_for_class(int32_t i)
{
	const int32_t count=get_num_labels();
	SGVector<float64_t> binary_labels(count);

	const bool has_confidences=m_multiclass_confidences.num_rows!=0
		&& m_multiclass_confidences.num_cols!=0;

	for (int32_t k=0; k<count; k++)
	{
		const int32_t label=get_int_label(k);

		float64_t magnitude=1.0;
		if (has_confidences)
			magnitude=m_multiclass_confidences(label,k);

		binary_labels[k]=(label==i) ? magnitude : -magnitude;
	}

	return new CBinaryLabels(binary_labels);
}
124
125
/* Returns the distinct label values as a new vector.
 *
 * Works on get_labels_copy(), so the stored labels are not modified.
 * SGVector::unique() rearranges the copy in place; the count it returns is
 * taken as the number of valid leading entries. */
SGVector<float64_t> CMulticlassLabels::get_unique_labels()
{
	SGVector<float64_t> scratch=get_labels_copy();
	const index_t num_unique=SGVector<float64_t>::unique(scratch.vector, scratch.vlen);

	SGVector<float64_t> result(num_unique);
	for (index_t k=0; k<num_unique; k++)
		result[k]=scratch[k];

	return result;
}
137
138
139
/* Number of classes, computed as the number of distinct label values
 * returned by get_unique_labels(). */
int32_t CMulticlassLabels::get_num_classes()
{
	return get_unique_labels().vlen;
}
SHOGUN
机器学习工具包 - 项目文档