SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
multiclass
MCLDA.h
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2013 Kevin Hughes
8
* Copyright (C) 2013 Kevin Hughes
9
*
10
* Thanks to Fernando José Iglesias García (shogun)
11
* and Matthieu Perrot (scikit-learn)
12
*/
13
14
#ifndef _MCLDA_H__
15
#define _MCLDA_H__
16
17
#include <
shogun/lib/config.h
>
18
19
#ifdef HAVE_EIGEN3
20
21
#include <
shogun/features/DotFeatures.h
>
22
#include <
shogun/features/DenseFeatures.h
>
23
#include <
shogun/machine/NativeMulticlassMachine.h
>
24
#include <
shogun/lib/SGNDArray.h
>
25
26
namespace
shogun
27
{
28
29
//#define DEBUG_MCLDA
30
39
class
CMCLDA
:
public
CNativeMulticlassMachine
40
{
41
public
:
42
MACHINE_PROBLEM_TYPE
(
PT_MULTICLASS
)
43
44
49
CMCLDA
(
float64_t
tolerance = 1e-4,
bool
store_cov =
false
);
50
58
CMCLDA
(
CDenseFeatures<float64_t>
* traindat,
CLabels
* trainlab,
float64_t
tolerance = 1e-4,
bool
store_cov =
false
);
59
60
virtual
~CMCLDA
();
61
67
virtual
CMulticlassLabels
*
apply_multiclass
(
CFeatures
* data=NULL);
68
73
inline
void
set_tolerance
(
float64_t
tolerance) { m_tolerance = tolerance; }
74
79
inline
bool
get_tolerance
() {
return
m_tolerance; }
80
85
virtual
EMachineType
get_classifier_type
() {
return
CT_LDA
; }
// for now add to machine typers properly later
86
91
virtual
void
set_features
(
CDotFeatures
* feat)
92
{
93
if
(feat->
get_feature_class
() !=
C_DENSE
||
94
feat->
get_feature_type
() !=
F_DREAL
)
95
SG_ERROR
(
"MCLDA requires SIMPLE REAL valued features\n"
)
96
97
SG_REF
(feat);
98
SG_UNREF
(m_features);
99
m_features = feat;
100
}
101
106
virtual
CDotFeatures
*
get_features
() {
SG_REF
(m_features);
return
m_features; }
107
112
virtual
const
char
*
get_name
()
const
{
return
"MCLDA"
; }
113
120
inline
SGVector< float64_t >
get_mean
(int32_t c)
const
121
{
122
return
SGVector< float64_t >
(m_means.
get_column_vector
(c), m_dim,
false
);
123
}
124
129
inline
SGMatrix< float64_t >
get_cov
()
const
130
{
131
return
m_cov;
132
}
133
134
protected
:
141
virtual
bool
train_machine
(
CFeatures
* data = NULL);
142
143
private
:
144
void
init();
145
146
void
cleanup();
147
148
private
:
150
CDotFeatures
* m_features;
151
153
float64_t
m_tolerance;
154
156
bool
m_store_cov;
157
159
int32_t m_num_classes;
160
162
int32_t m_dim;
163
167
SGMatrix< float64_t >
m_cov;
168
170
SGMatrix< float64_t >
m_means;
171
173
SGVector< float64_t >
m_xbar;
174
176
int32_t m_rank;
177
179
SGMatrix< float64_t >
m_scalings;
180
182
SGMatrix< float64_t >
m_coef;
183
185
SGVector< float64_t >
m_intercept;
186
187
};
/* class MCLDA */
188
}
/* namespace shogun */
189
190
#endif
/* HAVE_EIGEN3 */
191
#endif
/* _MCLDA_H__ */
SHOGUN
机器学习工具包 - 项目文档