SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
classifier
svm
SVMLin.cpp
浏览该文件的文档.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2006-2009 Soeren Sonnenburg
 * Copyright (C) 2006-2009 Fraunhofer Institute FIRST and Max-Planck-Society
 */
10
11
#include <
shogun/classifier/svm/SVMLin.h
>
12
#include <
shogun/labels/Labels.h
>
13
#include <
shogun/mathematics/Math.h
>
14
#include <shogun/lib/external/ssl.h>
15
#include <
shogun/machine/LinearMachine.h
>
16
#include <
shogun/features/DotFeatures.h
>
17
#include <
shogun/labels/Labels.h
>
18
#include <
shogun/labels/BinaryLabels.h
>
19
20
using namespace
shogun;
21
22
/** Default constructor.
 *
 * Both class costs (C1, C2) start at 1, the solver tolerance at 1e-5, and
 * the bias term is enabled. No features or labels are attached; they must
 * be provided before training (train_machine() asserts both are present).
 */
CSVMLin::CSVMLin()
	: CLinearMachine(),
	  C1(1.0),
	  C2(1.0),
	  epsilon(1e-5),
	  use_bias(true)
{
	// all state is set up by the initialiser list above
}
26
27
/** Construct the machine and attach training data in one step.
 *
 * @param C        cost applied to both classes (stored in C1 and C2)
 * @param traindat training features (must support dot products)
 * @param trainlab training labels
 *
 * The solver tolerance starts at 1e-5 and the bias term is enabled.
 */
CSVMLin::CSVMLin(float64_t C, CDotFeatures* traindat, CLabels* trainlab)
	: CLinearMachine(),
	  C1(C),
	  C2(C),
	  epsilon(1e-5),
	  use_bias(true)
{
	this->set_features(traindat);
	this->set_labels(trainlab);
}
34
35
36
/** Destructor.
 *
 * The members set by the constructors are plain scalars, so nothing is
 * released here; base-class destructors run automatically afterwards.
 */
CSVMLin::~CSVMLin() { }
39
40
/** Train the linear classifier with the SVMLin solver (ssl_train, algo = SVM).
 *
 * @param data optional training features. If non-NULL they must have the
 *        FP_DOT property and replace the machine's current features.
 * @return true after the weight vector and bias have been stored
 *
 * Fails (ASSERT / SG_ERROR) when labels or features are missing, when the
 * label count differs from the vector count, when there are no training
 * vectors, or when the solver returns a weight vector of unexpected size.
 *
 * Cost handling: lambda = 1/(2*C1); examples with label > 0 get relative
 * cost C2/C1, all others get relative cost 1.
 */
bool CSVMLin::train_machine(CFeatures* data)
{
	ASSERT(m_labels)

	if (data)
	{
		if (!data->has_property(FP_DOT))
			SG_ERROR("Specified features are not of type CDotFeatures\n")
		set_features((CDotFeatures*) data);
	}

	ASSERT(features)

	// NOTE(review): unchecked cast - assumes m_labels is a CBinaryLabels; confirm callers guarantee this.
	SGVector<float64_t> train_labels=((CBinaryLabels*) m_labels)->get_labels();
	int32_t num_feat=features->get_dim_feature_space();
	int32_t num_vec=features->get_num_vectors();

	ASSERT(num_vec==train_labels.vlen)

	// train_labels.vector[0] is read after training to fix the sign of the
	// solution, so an empty training set must be rejected up front.
	if (num_vec<1)
		SG_ERROR("SVMLin requires at least one training vector\n")

	struct options Options;
	struct data Data;
	struct vector_double Weights;
	struct vector_double Outputs;

	// Per-example costs. Held by a ref-counted SGVector (instead of a raw
	// SG_MALLOC) so the buffer is released even if ssl_train() or a later
	// assertion aborts training.
	SGVector<float64_t> costs(num_vec);

	Data.l=num_vec;
	Data.m=num_vec;
	Data.u=0;
	Data.n=num_feat+1;    // +1: extra dimension used for the bias
	Data.nz=num_feat+1;
	Data.Y=train_labels.vector;
	Data.features=features;
	Data.C=costs.vector;

	Options.algo = SVM;
	Options.lambda=1/(2*get_C1());
	Options.lambda_u=1/(2*get_C1());
	Options.S=10000;
	Options.R=0.5;
	Options.epsilon = get_epsilon();
	Options.cgitermax=10000;
	Options.mfnitermax=50;
	Options.Cp = get_C2()/get_C1();
	Options.Cn = 1;

	if (use_bias)
		Options.bias=1.0;
	else
		Options.bias=0.0;

	for (int32_t i=0;i<num_vec;i++)
	{
		if (train_labels.vector[i]>0)
			Data.C[i]=Options.Cp;
		else
			Data.C[i]=Options.Cn;
	}

	ssl_train(&Data, &Options, &Weights, &Outputs);

	// Training-set outputs are not used by this machine; release them before
	// any check below can abort, so they cannot leak.
	SG_FREE(Outputs.vec);

	bool weights_ok = Weights.vec && Weights.d==num_feat+1;
	if (!weights_ok && Weights.vec)
		SG_FREE(Weights.vec);    // do not leak a wrongly sized solution
	ASSERT(weights_ok)

	// NOTE(review): the whole solution (weights and bias) is scaled by the
	// first training label. The reason is not visible in this file - presumably
	// a convention of ssl_train; confirm before changing.
	float64_t sgn=train_labels.vector[0];
	for (int32_t i=0; i<num_feat+1; i++)
		Weights.vec[i]*=sgn;

	// Weights.vec is not freed here. The SGVector handed to set_w() is
	// expected to adopt the buffer, and its last slot (the bias) stays valid
	// for the read below. NOTE(review): relies on SGVector's ref-counting
	// pointer constructor taking ownership - confirm.
	set_w(SGVector<float64_t>(Weights.vec, num_feat));
	set_bias(Weights.vec[num_feat]);

	return true;
}
SHOGUN
机器学习工具包 - 项目文档