SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
machine
gp
EPInferenceMethod.h
浏览该文件的文档.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Roman Votyakov
 *
 * Based on ideas from GAUSSIAN PROCESS REGRESSION AND CLASSIFICATION Toolbox
 * Copyright (C) 2005-2013 by Carl Edward Rasmussen & Hannes Nickisch under the
 * FreeBSD License
 * http://www.gaussianprocess.org/gpml/code/matlab/doc/
 */
14
15
#ifndef _EPINFERENCEMETHOD_H_
16
#define _EPINFERENCEMETHOD_H_
17
18
#include <
shogun/lib/config.h
>
19
20
#ifdef HAVE_EIGEN3
21
22
#include <
shogun/machine/gp/InferenceMethod.h
>
23
24
namespace
shogun
25
{
26
34
class
CEPInferenceMethod
:
public
CInferenceMethod
35
{
36
public
:
38
CEPInferenceMethod
();
39
48
CEPInferenceMethod
(
CKernel
* kernel,
CFeatures
* features,
CMeanFunction
* mean,
49
CLabels
* labels,
CLikelihoodModel
* model);
50
51
virtual
~CEPInferenceMethod
();
52
57
virtual
EInferenceType
get_inference_type
()
const
{
return
INF_EP
; }
58
63
virtual
const
char
*
get_name
()
const
{
return
"EPInferenceMethod"
; }
64
76
virtual
float64_t
get_negative_log_marginal_likelihood
();
77
100
virtual
SGVector<float64_t>
get_alpha
();
101
116
virtual
SGMatrix<float64_t>
get_cholesky
();
117
129
virtual
SGVector<float64_t>
get_diagonal_vector
();
130
151
virtual
SGVector<float64_t>
get_posterior_mean
();
152
172
virtual
SGMatrix<float64_t>
get_posterior_covariance
();
173
178
virtual
float64_t
get_tolerance
()
const
{
return
m_tol; }
179
184
virtual
void
set_tolerance
(
const
float64_t
tol) { m_tol=tol; }
185
190
virtual
uint32_t
get_min_sweep
()
const
{
return
m_min_sweep; }
191
196
virtual
void
set_min_sweep
(
const
uint32_t min_sweep) { m_min_sweep=min_sweep; }
197
202
virtual
uint32_t
get_max_sweep
()
const
{
return
m_max_sweep; }
203
208
virtual
void
set_max_sweep
(
const
uint32_t max_sweep) { m_max_sweep=max_sweep; }
209
214
virtual
bool
supports_binary
()
const
215
{
216
check_members
();
217
return
m_model
->
supports_binary
();
218
}
219
221
virtual
void
update
();
222
223
protected
:
225
virtual
void
update_alpha
();
226
228
virtual
void
update_chol
();
229
231
virtual
void
update_approx_cov
();
232
234
virtual
void
update_approx_mean
();
235
237
virtual
void
update_negative_ml
();
238
242
virtual
void
update_deriv
();
243
251
virtual
SGVector<float64_t>
get_derivative_wrt_inference_method
(
252
const
TParameter
* param);
253
261
virtual
SGVector<float64_t>
get_derivative_wrt_likelihood_model
(
262
const
TParameter
* param);
263
271
virtual
SGVector<float64_t>
get_derivative_wrt_kernel
(
272
const
TParameter
* param);
273
281
virtual
SGVector<float64_t>
get_derivative_wrt_mean
(
282
const
TParameter
* param);
283
284
private
:
285
void
init();
286
287
private
:
289
SGVector<float64_t>
m_mu;
290
292
SGMatrix<float64_t>
m_Sigma;
293
295
float64_t
m_nlZ;
296
300
SGVector<float64_t>
m_tnu;
301
305
SGVector<float64_t>
m_ttau;
306
308
SGVector<float64_t>
m_sttau;
309
311
float64_t
m_tol;
312
314
uint32_t m_min_sweep;
315
317
uint32_t m_max_sweep;
318
319
SGMatrix<float64_t>
m_F;
320
};
321
}
322
#endif
/* HAVE_EIGEN3 */
323
#endif
/* _EPINFERENCEMETHOD_H_ */
SHOGUN
机器学习工具包 - 项目文档