SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
mathematics
linalg
linsolver
ConjugateGradientSolver.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2013 Soumyajit De
8
*/
9
10
#include <
shogun/lib/common.h
>
11
12
#ifdef HAVE_EIGEN3
13
14
#include <
shogun/lib/SGVector.h
>
15
#include <
shogun/lib/Time.h
>
16
#include <
shogun/mathematics/eigen3.h
>
17
#include <
shogun/mathematics/linalg/linop/LinearOperator.h
>
18
#include <
shogun/mathematics/linalg/linsolver/ConjugateGradientSolver.h
>
19
#include <
shogun/mathematics/linalg/linsolver/IterativeSolverIterator.h
>
20
21
using namespace
Eigen;
22
23
namespace
shogun
24
{
25
26
CConjugateGradientSolver::CConjugateGradientSolver()
27
:
CIterativeLinearSolver
<
float64_t
>()
28
{
29
SG_GCDEBUG
(
"%s created (%p)\n"
, this->
get_name
(),
this
);
30
}
31
32
CConjugateGradientSolver::CConjugateGradientSolver
(
bool
store_residuals)
33
:
CIterativeLinearSolver
<
float64_t
>(store_residuals)
34
{
35
SG_GCDEBUG
(
"%s created (%p)\n"
, this->
get_name
(),
this
);
36
}
37
38
CConjugateGradientSolver::~CConjugateGradientSolver
()
39
{
40
SG_GCDEBUG
(
"%s destroyed (%p)\n"
, this->
get_name
(),
this
);
41
}
42
43
SGVector<float64_t>
CConjugateGradientSolver::solve
(
44
CLinearOperator<float64_t>
* A,
SGVector<float64_t>
b)
45
{
46
SG_DEBUG
(
"CConjugateGradientSolve::solve(): Entering..\n"
);
47
48
// sanity check
49
REQUIRE
(A,
"Operator is NULL!\n"
);
50
REQUIRE
(A->
get_dimension
()==b.
vlen
,
"Dimension mismatch!\n"
);
51
52
// the final solution vector, initial guess is 0
53
SGVector<float64_t>
result(b.
vlen
);
54
result.set_const(0.0);
55
56
// the rest of the part hinges on eigen3 for computing norms
57
Map<VectorXd> x(result.vector, result.vlen);
58
Map<VectorXd> b_map(b.
vector
, b.
vlen
);
59
60
// direction vector
61
SGVector<float64_t>
p_(result.vlen);
62
Map<VectorXd> p(p_.
vector
, p_.
vlen
);
63
64
// residual r_i=b-Ax_i, here x_0=[0], so r_0=b
65
VectorXd r=b_map;
66
67
// initial direction is same as residual
68
p=r;
69
70
// the iterator for this iterative solver
71
IterativeSolverIterator<float64_t>
it(b_map,
m_max_iteration_limit
,
72
m_relative_tolerence
,
m_absolute_tolerence
);
73
74
// CG iteration begins
75
float64_t
r_norm2=r.dot(r);
76
77
// start the timer
78
CTime
time;
79
time.
start
();
80
81
// set the residuals to zero
82
if
(
m_store_residuals
)
83
m_residuals
.
set_const
(0.0);
84
85
for
(it.
begin
(r); !it.
end
(r); ++it)
86
{
87
SG_DEBUG
(
"CG iteration %d, residual norm %f\n"
,
88
it.
get_iter_info
().iteration_count,
89
it.
get_iter_info
().residual_norm);
90
91
if
(
m_store_residuals
)
92
{
93
m_residuals
[it.
get_iter_info
().iteration_count]
94
=it.
get_iter_info
().residual_norm;
95
}
96
97
// apply linear operator to the direction vector
98
SGVector<float64_t>
Ap_=A->
apply
(p_);
99
Map<VectorXd> Ap(Ap_.
vector
, Ap_.
vlen
);
100
101
// compute p^{T}Ap, if zero, failure
102
float64_t
p_dot_Ap=p.dot(Ap);
103
if
(p_dot_Ap==0.0)
104
break
;
105
106
// compute the alpha parameter of CG
107
float64_t
alpha=r_norm2/p_dot_Ap;
108
109
// update the solution vector and residual
110
// x_{i}=x_{i-1}+\alpha_{i}p
111
x+=alpha*p;
112
113
// r_{i}=r_{i-1}-\alpha_{i}p
114
r-=alpha*Ap;
115
116
// compute new ||r||_{2}, if zero, converged
117
float64_t
r_norm2_i=r.dot(r);
118
if
(r_norm2_i==0.0)
119
break
;
120
121
// compute the beta parameter of CG
122
float64_t
beta=r_norm2_i/r_norm2;
123
124
// update direction, and ||r||_{2}
125
r_norm2=r_norm2_i;
126
p=r+beta*p;
127
}
128
129
float64_t
elapsed=time.
cur_time_diff
();
130
131
if
(!it.
succeeded
(r))
132
SG_WARNING
(
"Did not converge!\n"
);
133
134
SG_INFO
(
"Iteration took %ld times, residual norm=%.20lf, time elapsed=%lf\n"
,
135
it.
get_iter_info
().iteration_count, it.
get_iter_info
().residual_norm, elapsed);
136
137
SG_DEBUG
(
"CConjugateGradientSolve::solve(): Leaving..\n"
);
138
return
result;
139
}
140
141
}
142
#endif // HAVE_EIGEN3
SHOGUN
机器学习工具包 - 项目文档