SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
latent
LatentSVM.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Viktor Gal
8
* Copyright (C) 2012 Viktor Gal
9
*/
10
11
#include <typeinfo>
12
13
#include <
shogun/classifier/svm/SVMOcas.h
>
14
#include <
shogun/latent/LatentSVM.h
>
15
16
using namespace
shogun;
17
18
/** Default constructor.
 *
 * Delegates entirely to CLinearLatentMachine's default constructor;
 * CLatentSVM adds no state of its own.
 */
CLatentSVM::CLatentSVM() : CLinearLatentMachine() { }
22
23
/** Construct a latent SVM bound to a latent model.
 *
 * Both arguments are forwarded unchanged to CLinearLatentMachine.
 *
 * @param model latent model that supplies examples, psi features and
 *              latent-variable inference
 * @param C     regularization constant (presumably stored by the base as
 *              m_C, which do_inner_loop() passes to SVMOcas — verify in base)
 */
CLatentSVM::CLatentSVM(CLatentModel* model, float64_t C)
	: CLinearLatentMachine(model, C)
{ }
27
28
/** Destructor.
 *
 * Releases nothing itself; any cleanup happens in the base-class
 * destructors, which run automatically.
 */
CLatentSVM::~CLatentSVM() { }
31
32
/** Infer the latent variable of every example and score it with w.
 *
 * A fresh CLatentLabels (wrapping a fresh CBinaryLabels) is created and
 * installed on the model via m_model->set_labels() before inference runs.
 * For each example i, m_model->infer_latent_variable(w, i) yields the latent
 * value h_i, which is appended to the latent labels. The wrapped binary labels
 * are then filled with the dot products between w and the model's psi feature
 * vectors, with bias 0.0.
 *
 * Raises SG_ERROR if no latent model is set.
 *
 * @return the latent labels described above, or NULL when the model has
 *         no examples
 */
CLatentLabels* CLatentSVM::apply_latent()
{
	if (!m_model)
		SG_ERROR("LatentModel is not set!\n");

	if (m_model->get_num_vectors() < 1)
		return NULL;

	index_t num_examples = m_model->get_num_vectors();
	CLatentLabels* hs = new CLatentLabels(num_examples);
	CBinaryLabels* ys = new CBinaryLabels(num_examples);
	hs->set_labels(ys);
	// NOTE(review): installed on the model before inference, presumably so
	// the psi features computed below reflect the inferred h values — confirm.
	m_model->set_labels(hs);

	for (index_t i = 0; i < num_examples; ++i)
	{
		/* find h for the example */
		CData* h = m_model->infer_latent_variable(w, i);
		hs->add_latent_label(h);
	}

	/* compute the y labels */
	CDotFeatures* x = m_model->get_psi_feature_vectors();
	// get_labels() presumably returns an SGVector sharing ys's ref-counted
	// buffer, so the scores are written straight into ys — TODO confirm.
	x->dense_dot_range(ys->get_labels().vector, 0, num_examples, NULL,
			w.vector, w.vlen, 0.0);
	// The getter hands back a reference owned by the caller (do_inner_loop()
	// unrefs the same getter's result); release it so it does not leak.
	SG_UNREF(x);

	return hs;
}
59
60
/** Solve the convex inner problem with the latent variables held fixed.
 *
 * Trains an SVMOcas linear SVM with regularization constant m_C on the
 * model's psi features and on the labels wrapped by the model's latent
 * labels. The cached psi features are used when caching is enabled. The
 * learned weight vector is then copied into w.
 *
 * Raises SG_ERROR if the trained weight vector's length differs from w's.
 *
 * @param cooling_eps stopping tolerance passed to SVMOcas::set_epsilon()
 * @return the primal objective value of the trained SVM
 */
float64_t CLatentSVM::do_inner_loop(float64_t cooling_eps)
{
	// Both getters return a reference owned by the caller. Keep the outer
	// CLatentLabels in a local so that reference can be released as well
	// (chaining the calls used to leak it).
	CLatentLabels* latent_labels = m_model->get_labels();
	CLabels* ys = latent_labels->get_labels();
	CDotFeatures* feats = (m_model->get_caching() ?
			m_model->get_cached_psi_features() :
			m_model->get_psi_feature_vectors());
	CSVMOcas svm(m_C, feats, ys);
	svm.set_epsilon(cooling_eps);
	svm.train();
	SG_UNREF(ys);
	SG_UNREF(feats);
	SG_UNREF(latent_labels);

	/* copy the resulting w */
	SGVector<float64_t> cur_w = svm.get_w();
	// Guard the raw copy: a length mismatch would overflow w's buffer.
	if (cur_w.vlen != w.vlen)
		SG_ERROR("SVM weight vector has %d entries but w has %d\n",
				cur_w.vlen, w.vlen);
	memcpy(w.vector, cur_w.vector, cur_w.vlen*sizeof(float64_t));

	return svm.compute_primal_objective();
}
78
SHOGUN
机器学习工具包 - 项目文档