SHOGUN
3.2.1
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
machine
gp
InferenceMethod.h
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2013 Roman Votyakov
8
* Written (W) 2013-2014 Heiko Strathmann
9
* Copyright (C) 2012 Jacob Walker
10
* Copyright (C) 2013 Roman Votyakov
11
*/
12
13
#ifndef CINFERENCEMETHOD_H_
14
#define CINFERENCEMETHOD_H_
15
16
#include <
shogun/lib/config.h
>
17
18
#ifdef HAVE_EIGEN3
19
20
#include <
shogun/base/SGObject.h
>
21
#include <
shogun/kernel/Kernel.h
>
22
#include <
shogun/features/Features.h
>
23
#include <
shogun/labels/Labels.h
>
24
#include <
shogun/machine/gp/LikelihoodModel.h
>
25
#include <
shogun/machine/gp/MeanFunction.h
>
26
#include <
shogun/evaluation/DifferentiableFunction.h
>
27
28
namespace
shogun
29
{
30
32
/** @brief Tags identifying which inference scheme a CInferenceMethod implements.
 *
 * CInferenceMethod::get_inference_type() returns one of these values; the
 * abstract base class itself reports INF_NONE. The enumerator names indicate
 * the scheme of the implementing subclass (INF_EP presumably expectation
 * propagation, INF_FITC the FITC approximation — confirm against subclasses).
 *
 * The values are spaced in steps of 10, presumably to leave room for new
 * entries. NOTE(review): confirm these numbers are not persisted/serialized
 * anywhere before renumbering them.
 */
enum EInferenceType
{
    INF_NONE      =  0, ///< no particular scheme (base-class default)
    INF_EXACT     = 10, ///< exact inference
    INF_FITC      = 20, ///< FITC inference
    INF_LAPLACIAN = 30, ///< Laplacian inference
    INF_EP        = 40  ///< EP inference
};
40
50
class
CInferenceMethod
:
public
CDifferentiableFunction
51
{
52
public
:
54
CInferenceMethod
();
55
64
CInferenceMethod
(
CKernel
* kernel,
CFeatures
* features,
65
CMeanFunction
* mean,
CLabels
* labels,
CLikelihoodModel
* model);
66
67
virtual
~CInferenceMethod
();
68
73
virtual
EInferenceType
get_inference_type
()
const
{
return
INF_NONE
; }
74
86
virtual
float64_t
get_negative_log_marginal_likelihood
()=0;
87
123
float64_t
get_marginal_likelihood_estimate
(int32_t num_importance_samples=1,
124
float64_t
ridge_size=1e-15);
125
139
virtual
CMap<TParameter*, SGVector<float64_t>
>*
140
get_negative_log_marginal_likelihood_derivatives
(
CMap
<
TParameter
*,
141
CSGObject
*>* parameters);
142
153
virtual
SGVector<float64_t>
get_alpha
()=0;
154
166
virtual
SGMatrix<float64_t>
get_cholesky
()=0;
167
179
virtual
SGVector<float64_t>
get_diagonal_vector
()=0;
180
196
virtual
SGVector<float64_t>
get_posterior_mean
()=0;
197
213
virtual
SGMatrix<float64_t>
get_posterior_covariance
()=0;
214
222
virtual
CMap<TParameter*, SGVector<float64_t>
>*
get_gradient
(
223
CMap<TParameter*, CSGObject*>
* parameters)
224
{
225
return
get_negative_log_marginal_likelihood_derivatives
(parameters);
226
}
227
232
virtual
SGVector<float64_t>
get_value
()
233
{
234
SGVector<float64_t>
result(1);
235
result[0]=
get_negative_log_marginal_likelihood
();
236
return
result;
237
}
238
243
virtual
CFeatures
*
get_features
() {
SG_REF
(
m_features
);
return
m_features
; }
244
249
virtual
void
set_features
(
CFeatures
* feat)
250
{
251
SG_REF
(feat);
252
SG_UNREF
(
m_features
);
253
m_features
=feat;
254
}
255
260
virtual
CKernel
*
get_kernel
() {
SG_REF
(
m_kernel
);
return
m_kernel
; }
261
266
virtual
void
set_kernel
(
CKernel
* kern)
267
{
268
SG_REF
(kern);
269
SG_UNREF
(
m_kernel
);
270
m_kernel
=kern;
271
}
272
277
virtual
CMeanFunction
*
get_mean
() {
SG_REF
(
m_mean
);
return
m_mean
; }
278
283
virtual
void
set_mean
(
CMeanFunction
* m)
284
{
285
SG_REF
(m);
286
SG_UNREF
(
m_mean
);
287
m_mean
=m;
288
}
289
294
virtual
CLabels
*
get_labels
() {
SG_REF
(
m_labels
);
return
m_labels
; }
295
300
virtual
void
set_labels
(
CLabels
* lab)
301
{
302
SG_REF
(lab);
303
SG_UNREF
(
m_labels
);
304
m_labels
=lab;
305
}
306
311
CLikelihoodModel
*
get_model
() {
SG_REF
(
m_model
);
return
m_model
; }
312
317
virtual
void
set_model
(
CLikelihoodModel
* mod)
318
{
319
SG_REF
(mod);
320
SG_UNREF
(
m_model
);
321
m_model
=mod;
322
}
323
328
virtual
float64_t
get_scale
()
const
{
return
m_scale
; }
329
334
virtual
void
set_scale
(
float64_t
scale) {
m_scale
=scale; }
335
341
virtual
bool
supports_regression
()
const
{
return
false
; }
342
348
virtual
bool
supports_binary
()
const
{
return
false
; }
349
355
virtual
bool
supports_multiclass
()
const
{
return
false
; }
356
358
virtual
void
update
();
359
360
protected
:
362
virtual
void
check_members
()
const
;
363
365
virtual
void
update_alpha
()=0;
366
368
virtual
void
update_chol
()=0;
369
373
virtual
void
update_deriv
()=0;
374
376
virtual
void
update_train_kernel
();
377
385
virtual
SGVector<float64_t>
get_derivative_wrt_inference_method
(
386
const
TParameter
* param)=0;
387
395
virtual
SGVector<float64_t>
get_derivative_wrt_likelihood_model
(
396
const
TParameter
* param)=0;
397
405
virtual
SGVector<float64_t>
get_derivative_wrt_kernel
(
406
const
TParameter
* param)=0;
407
415
virtual
SGVector<float64_t>
get_derivative_wrt_mean
(
416
const
TParameter
* param)=0;
417
421
static
void
*
get_derivative_helper
(
void
* p);
422
423
private
:
424
void
init();
425
426
protected
:
428
CKernel
*
m_kernel
;
429
431
CMeanFunction
*
m_mean
;
432
434
CLikelihoodModel
*
m_model
;
435
437
CFeatures
*
m_features
;
438
440
CLabels
*
m_labels
;
441
443
SGVector<float64_t>
m_alpha
;
444
446
SGMatrix<float64_t>
m_L
;
447
449
float64_t
m_scale
;
450
452
SGMatrix<float64_t>
m_ktrtr
;
453
};
454
}
455
#endif
/* HAVE_EIGEN3 */
456
#endif
/* CINFERENCEMETHOD_H_ */
SHOGUN
Machine Learning Toolbox - Documentation