SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
multiclass
MulticlassOneVsRestStrategy.cpp
浏览该文件的文档.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2012 Chiyuan Zhang
 * Copyright (C) 2012 Chiyuan Zhang
 */
10
11
#include <
shogun/multiclass/MulticlassOneVsRestStrategy.h
>
12
#include <
shogun/labels/BinaryLabels.h
>
13
#include <
shogun/labels/MulticlassLabels.h
>
14
#include <
shogun/mathematics/Math.h
>
15
16
using namespace
shogun;
17
18
/** Default constructor; only forwards to the CMulticlassStrategy default constructor. */
CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy() : CMulticlassStrategy()
{
	// nothing to initialise beyond the base class
}
22
23
/**
 * Constructor taking a probability heuristic.
 *
 * @param prob_heuris heuristic type, passed unchanged to the
 *        CMulticlassStrategy constructor
 */
CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy(EProbHeuristicType prob_heuris)
	: CMulticlassStrategy(prob_heuris)
{
	// nothing to initialise beyond the base class
}
27
28
/**
 * Write the binary training labels for the current training iteration.
 *
 * Every example whose original multiclass label equals m_train_iter gets
 * +1.0 in m_train_labels, every other example gets -1.0. Afterwards the
 * base-class train_prepare_next() is invoked.
 *
 * @return a default-constructed SGVector<int32_t>
 */
SGVector<int32_t> CMulticlassOneVsRestStrategy::train_prepare_next()
{
	CMulticlassLabels* multiclass_labels = static_cast<CMulticlassLabels*>(m_orig_labels);
	CBinaryLabels* binary_labels = static_cast<CBinaryLabels*>(m_train_labels);

	for (int32_t idx=0; idx<m_orig_labels->get_num_labels(); ++idx)
	{
		const bool is_current_class =
			multiclass_labels->get_int_label(idx)==m_train_iter;

		binary_labels->set_label(idx, is_current_class ? +1.0 : -1.0);
	}

	// the base-class call increases m_train_iter, so it has to come
	// *after* the labels above were written
	CMulticlassStrategy::train_prepare_next();

	return SGVector<int32_t>();
}
43
44
/**
 * Pick the label for one example from the per-machine outputs.
 *
 * @param outputs one output value per binary machine
 * @return CDenseLabels::REJECTION_LABEL when a rejection strategy is set
 *         and it rejects the outputs; otherwise the result of arg_max
 *         over all entries of outputs
 */
int32_t CMulticlassOneVsRestStrategy::decide_label(SGVector<float64_t> outputs)
{
	// the rejection strategy is optional; it is only consulted when set
	const bool rejected =
		m_rejection_strategy && m_rejection_strategy->reject(outputs);

	if (!rejected)
		return SGVector<float64_t>::arg_max(outputs.vector, 1, outputs.vlen);

	return CDenseLabels::REJECTION_LABEL;
}
51
52
/**
 * Return the indices of the first n_outputs entries of outputs after
 * ordering them with CMath::qsort_backward_index.
 *
 * The input vector itself is not modified; sorting happens on a copy.
 *
 * @param outputs one output value per binary machine
 * @param n_outputs number of indices to return; must lie in
 *        [0, outputs.vlen], otherwise an error is raised via SG_ERROR
 * @return vector of n_outputs indices into outputs
 */
SGVector<index_t> CMulticlassOneVsRestStrategy::decide_label_multiple_output(SGVector<float64_t> outputs, int32_t n_outputs)
{
	// Validate before allocating anything: a too large n_outputs would read
	// past the end of the index buffer below, and a negative one would be
	// used as a vector size.
	if (n_outputs<0 || n_outputs>outputs.vlen)
	{
		SG_ERROR("%s::decide_label_multiple_output(): n_outputs = %d must be in [0, %d]\n",
				get_name(), n_outputs, outputs.vlen);
		// defensive: never fall through to the copy loop with a bad count
		return SGVector<index_t>();
	}

	float64_t* outputs_ = SG_MALLOC(float64_t, outputs.vlen);
	int32_t* indices_ = SG_MALLOC(int32_t, outputs.vlen);
	for (int32_t i=0; i<outputs.vlen; i++)
	{
		outputs_[i] = outputs[i];
		indices_[i] = i;
	}

	// reorders outputs_ and applies the same permutation to indices_
	CMath::qsort_backward_index(outputs_, indices_, outputs.vlen);

	SGVector<index_t> result(n_outputs);
	for (int32_t i=0; i<n_outputs; i++)
		result[i] = indices_[i];

	SG_FREE(outputs_);
	SG_FREE(indices_);
	return result;
}
69
70
/**
 * Rescale the outputs according to the configured probability heuristic.
 *
 * PROB_HEURIS_NONE leaves the outputs untouched, OVA_NORM delegates to
 * rescale_heuris_norm(). OVA_SOFTMAX is reported as an error here because
 * this overload receives no sigmoid parameters; any other value is reported
 * as an unknown heuristic.
 *
 * @param outputs values to rescale
 */
void CMulticlassOneVsRestStrategy::rescale_outputs(SGVector<float64_t> outputs)
{
	switch (get_prob_heuris_type())
	{
		case PROB_HEURIS_NONE:
			return;

		case OVA_NORM:
			rescale_heuris_norm(outputs);
			return;

		case OVA_SOFTMAX:
			SG_ERROR("%s::rescale_outputs(): Need to specify sigmoid parameters!\n", get_name());
			return;

		default:
			SG_ERROR("%s::rescale_outputs(): Unknown OVA probability heuristic type!\n", get_name());
			return;
	}
}
87
88
/**
 * Rescale the outputs, with sigmoid parameters available.
 *
 * When the configured heuristic is OVA_SOFTMAX the work is done by
 * rescale_heuris_softmax(); for every other heuristic As and Bs are ignored
 * and the single-argument rescale_outputs() is used.
 *
 * @param outputs values to rescale
 * @param As first sigmoid parameter vector, forwarded to rescale_heuris_softmax()
 * @param Bs second sigmoid parameter vector, forwarded to rescale_heuris_softmax()
 */
void CMulticlassOneVsRestStrategy::rescale_outputs(SGVector<float64_t> outputs,
		const SGVector<float64_t> As, const SGVector<float64_t> Bs)
{
	if (get_prob_heuris_type()!=OVA_SOFTMAX)
	{
		rescale_outputs(outputs);
		return;
	}

	rescale_heuris_softmax(outputs, As, Bs);
}
96
97
/**
 * Divide every output by the sum of all outputs (plus 1E-10).
 *
 * An error is raised via SG_ERROR when the number of outputs differs from
 * m_num_classes.
 *
 * @param outputs values to rescale
 */
void CMulticlassOneVsRestStrategy::rescale_heuris_norm(SGVector<float64_t> outputs)
{
	if (outputs.vlen != m_num_classes)
	{
		SG_ERROR("%s::rescale_heuris_norm(): size(outputs) = %d != m_num_classes = %d\n",
				get_name(), outputs.vlen, m_num_classes);
	}

	// the small constant is added to the sum before it is used as divisor
	const float64_t denominator = SGVector<float64_t>::sum(outputs) + 1E-10;

	for (int32_t k=0; k<outputs.vlen; ++k)
		outputs[k] /= denominator;
}
110
111
/**
 * Replace every output by exp(-As[i]*outputs[i]-Bs[i]) and then divide
 * all of them by their sum (plus 1E-10).
 *
 * Errors are raised via SG_ERROR when the number of outputs differs from
 * m_num_classes, or when As or Bs hold fewer entries than outputs.
 *
 * @param outputs values to rescale
 * @param As per-output scale parameters; needs at least outputs.vlen entries
 * @param Bs per-output offset parameters; needs at least outputs.vlen entries
 */
void CMulticlassOneVsRestStrategy::rescale_heuris_softmax(SGVector<float64_t> outputs,
		const SGVector<float64_t> As, const SGVector<float64_t> Bs)
{
	if (m_num_classes != outputs.vlen)
	{
		SG_ERROR("%s::rescale_heuris_softmax(): size(outputs) = %d != m_num_classes = %d\n",
				get_name(), outputs.vlen, m_num_classes);
	}

	// As[i] and Bs[i] are read for every i below outputs.vlen, so both
	// vectors must be at least that long to stay in bounds.
	if (As.vlen<outputs.vlen || Bs.vlen<outputs.vlen)
	{
		SG_ERROR("%s::rescale_heuris_softmax(): size(As) = %d and size(Bs) = %d must both be at least size(outputs) = %d\n",
				get_name(), As.vlen, Bs.vlen, outputs.vlen);
		// defensive: never index the parameter vectors out of range
		return;
	}

	for (int32_t i=0; i<outputs.vlen; i++)
		outputs[i] = CMath::exp(-As[i]*outputs[i]-Bs[i]);

	float64_t norm = SGVector<float64_t>::sum(outputs);
	norm += 1E-10;
	for (int32_t i=0; i<outputs.vlen; i++)
		outputs[i] /= norm;
}
SHOGUN
机器学习工具包 - 项目文档