SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
mathematics
linalg
linsolver
ConjugateOrthogonalCGSolver.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2013 Soumyajit De
8
*/
9
10
#include <
shogun/lib/common.h
>
11
12
#ifdef HAVE_EIGEN3
13
14
#include <
shogun/lib/SGVector.h
>
15
#include <
shogun/lib/Time.h
>
16
#include <
shogun/mathematics/eigen3.h
>
17
#include <
shogun/mathematics/Math.h
>
18
#include <
shogun/mathematics/linalg/linop/LinearOperator.h
>
19
#include <
shogun/mathematics/linalg/linsolver/ConjugateOrthogonalCGSolver.h
>
20
#include <
shogun/mathematics/linalg/linsolver/IterativeSolverIterator.h
>
21
using namespace
Eigen;
22
23
namespace
shogun
24
{
25
26
/**
 * Default constructor.
 *
 * Delegates to the default CIterativeLinearSolver<complex128_t, float64_t>
 * constructor and emits a GC debug trace with the object's address.
 */
CConjugateOrthogonalCGSolver::CConjugateOrthogonalCGSolver()
	: CIterativeLinearSolver<complex128_t, float64_t>()
{
	SG_GCDEBUG("%s created (%p)\n", get_name(), this);
}
31
32
/**
 * Constructor.
 *
 * @param store_residuals forwarded unchanged to the CIterativeLinearSolver
 * base constructor (presumably it sets m_store_residuals, which solve()
 * consults before recording per-iteration residual norms — confirm in base)
 */
CConjugateOrthogonalCGSolver::CConjugateOrthogonalCGSolver(bool store_residuals)
	: CIterativeLinearSolver<complex128_t, float64_t>(store_residuals)
{
	SG_GCDEBUG("%s created (%p)\n", get_name(), this);
}
37
38
/** Destructor: owns no extra resources, only emits a GC debug trace. */
CConjugateOrthogonalCGSolver::~CConjugateOrthogonalCGSolver()
{
	SG_GCDEBUG("%s destroyed (%p)\n", get_name(), this);
}
42
43
/**
 * Solves \f$Ax=b\f$ for a complex linear operator A and a real right-hand
 * side b using the conjugate orthogonal conjugate gradient (COCG) iteration,
 * starting from the initial guess \f$x_0=0\f$.
 *
 * COCG uses the unconjugated bilinear form \f$r^{T}r\f$ (Eigen transpose(),
 * not adjoint()) in place of the Hermitian inner product used by CG.
 *
 * @param A the linear operator; must be non-NULL and have dimension b.vlen
 * @param b the real right-hand side vector
 * @return the approximate solution vector x (length b.vlen)
 */
SGVector<complex128_t> CConjugateOrthogonalCGSolver::solve(
	CLinearOperator<complex128_t>* A, SGVector<float64_t> b)
{
	SG_DEBUG("CConjugateOrthogonalCGSolver::solve(): Entering..\n");

	// sanity check
	REQUIRE(A, "Operator is NULL!\n");
	REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch! %d vs %d\n",
		A->get_dimension(), b.vlen);

	// the final solution vector, initial guess is 0
	SGVector<complex128_t> result(b.vlen);
	result.set_const(0.0);

	// the rest of the part hinges on eigen3 for computing norms;
	// x writes straight into result's storage
	Map<VectorXcd> x(result.vector, result.vlen);
	Map<VectorXd> b_map(b.vector, b.vlen);

	// direction vector; kept as an SGVector so it can be passed to A->apply()
	SGVector<complex128_t> p_(result.vlen);
	Map<VectorXcd> p(p_.vector, p_.vlen);

	// residual r_i=b-Ax_i, here x_0=[0], so r_0=b
	VectorXcd r=b_map.cast<complex128_t>();

	// initial direction is same as residual
	p=r;

	// the iterator for this iterative solver
	IterativeSolverIterator<complex128_t> it(r, m_max_iteration_limit,
		m_relative_tolerence, m_absolute_tolerence);

	// start the timer
	CTime time;
	time.start();

	// set the residuals to zero
	if (m_store_residuals)
		m_residuals.set_const(0.0);

	// COCG iteration begins; r^{T}r (no conjugation) replaces ||r||^2 of CG
	complex128_t r_norm2=r.transpose()*r;

	for (it.begin(r); !it.end(r); ++it)
	{
		SG_DEBUG("CG iteration %d, residual norm %f\n",
			it.get_iter_info().iteration_count,
			it.get_iter_info().residual_norm);

		if (m_store_residuals)
		{
			m_residuals[it.get_iter_info().iteration_count]
				=it.get_iter_info().residual_norm;
		}

		// apply linear operator to the direction vector
		SGVector<complex128_t> Ap_=A->apply(p_);
		Map<VectorXcd> Ap(Ap_.vector, Ap_.vlen);

		// compute p^{T}Ap; if zero, alpha is undefined (breakdown)
		complex128_t p_T_times_Ap=p.transpose()*Ap;
		if (p_T_times_Ap==0.0)
			break;

		// compute the alpha parameter of COCG
		complex128_t alpha=r_norm2/p_T_times_Ap;

		// update the solution vector and residual
		// x_{i}=x_{i-1}+\alpha_{i}p
		x+=alpha*p;

		// r_{i}=r_{i-1}-\alpha_{i}Ap
		r-=alpha*Ap;

		// compute the new r^{T}r. Zero means either r=0 (converged) or, for
		// complex r, a COCG breakdown; in both cases stop and let the
		// succeeded() check below report the outcome
		complex128_t r_norm2_i=r.transpose()*r;
		if (r_norm2_i==0.0)
			break;

		// compute the beta parameter of COCG
		complex128_t beta=r_norm2_i/r_norm2;

		// update direction, and r^{T}r
		r_norm2=r_norm2_i;
		p=r+beta*p;
	}

	float64_t elapsed=time.cur_time_diff();

	if (!it.succeeded(r))
		SG_WARNING("Did not converge!\n");

	// %d matches the specifier used for iteration_count in the SG_DEBUG above
	SG_INFO("Iteration took %d times, residual norm=%.20lf, time elapsed=%lf\n",
		it.get_iter_info().iteration_count, it.get_iter_info().residual_norm,
		elapsed);

	SG_DEBUG("CConjugateOrthogonalCGSolver::solve(): Leaving..\n");
	return result;
}
141
142
}
143
#endif // HAVE_EIGEN3
SHOGUN
机器学习工具包 - 项目文档