Stokhos Package Browser (Single Doxygen Collection) Version of the Day
Loading...
Searching...
No Matches
block_lu.h
Go to the documentation of this file.
1/*
2 * Copyright 2008-2009 NVIDIA Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#pragma once
18
19#include <cusp/array1d.h>
20#include <cusp/array2d.h>
21#include <cusp/linear_operator.h>
22
23#include <cmath>
24
25namespace cusp
26{
27namespace detail
28{
29
30template <typename IndexType, typename ValueType, typename MemorySpace, typename OrientationA>
31int block_lu_factor(cusp::array2d<ValueType,MemorySpace,OrientationA>& A,
32 cusp::array1d<IndexType,MemorySpace>& pivot)
33{
34 const int n = A.num_rows;
35
36 // For each row and column, k = 0, ..., n-1,
37 for (int k = 0; k < n; k++)
38 {
39 // find the pivot row
40 pivot[k] = k;
41 ValueType max = std::fabs(A(k,k));
42
43 for (int j = k + 1; j < n; j++)
44 {
45 if (max < std::fabs(A(j,k)))
46 {
47 max = std::fabs(A(j,k));
48 pivot[k] = j;
49 }
50 }
51
52 // and if the pivot row differs from the current row, then
53 // interchange the two rows.
54 if (pivot[k] != k)
55 for (int j = 0; j < n; j++)
56 std::swap(A(k,j), A(pivot[k],j));
57
58 // and if the matrix is singular, return error
59 if (A(k,k) == 0.0)
60 return -1;
61
62 // otherwise find the lower triangular matrix elements for column k.
63 for (int i = k + 1; i < n; i++)
64 A(i,k) /= A(k,k);
65
66 // update remaining matrix
67 for (int i = k + 1; i < n; i++)
68 for (int j = k + 1; j < n; j++)
69 A(i,j) -= A(i,k) * A(k,j);
70 }
71 return 0;
72}
73
74
75//LU solve for multiple right hand sides
76template <typename IndexType, typename ValueType, typename MemorySpace,
77 typename OrientationA, typename OrientationB>
78int block_lu_solve(const cusp::array2d<ValueType,MemorySpace,OrientationA>& A,
79 const cusp::array1d<IndexType,MemorySpace>& pivot,
80 const cusp::array2d<ValueType,MemorySpace,OrientationB>& b,
81 cusp::array2d<ValueType,MemorySpace,OrientationB>& x,
82 cusp::array2d_format)
83{
84 const int n = A.num_rows;
85 const int numRHS = b.num_cols;
86 // copy rhs to x
87 x = b;
88 // Solve the linear equation Lx = b for x, where L is a lower triangular matrix
89
90 for (int k = 0; k < n; k++)
91 {
92 if (pivot[k] != k){//swap row k of x with row pivot[k]
93 for (int j = 0; j < numRHS; j++)
94 std::swap(x(k,j),x(pivot[k],j));
95 }
96
97 for (int i = 0; i < k; i++){
98 for (int j = 0; j< numRHS; j++)
99 x(k,j) -= A(k,i) * x(i,j);
100 }
101 }
102
103 // Solve the linear equation Ux = y, where y is the solution
104 // obtained above of Lx = b and U is an upper triangular matrix.
105 for (int k = n - 1; k >= 0; k--)
106 {
107 for (int j = 0; j< numRHS; j++){
108 for (int i = k + 1; i < n; i++){
109 x(k, j) -= A(k,i) * x(i, j);
110 }
111 if (A(k,k) == 0)
112 return -1;
113 x(k,j) /= A(k,k);
114 }
115
116 }
117 return 0;
118}
119
120
121
122
123template <typename ValueType, typename MemorySpace>
124class block_lu_solver : public cusp::linear_operator<ValueType,MemorySpace>
125{
126 cusp::array2d<ValueType,cusp::host_memory> lu;
127 cusp::array1d<int,cusp::host_memory> pivot;
128
129public:
130 block_lu_solver() : linear_operator<ValueType,MemorySpace>() {}
131
132 template <typename MatrixType>
133 block_lu_solver(const MatrixType& A) :
134 linear_operator<ValueType,MemorySpace>(A.num_rows, A.num_cols, A.num_entries)
135 {
136 CUSP_PROFILE_SCOPED();
137
138 lu = A;
139 pivot.resize(A.num_rows);
141 }
142
143 template <typename VectorType1, typename VectorType2>
144 void operator()(const VectorType1& b, VectorType2& x) const
145 {
146 CUSP_PROFILE_SCOPED();
147 block_lu_solve(lu, pivot, b, x, typename VectorType2::format());
148 }
149};
150
151} // end namespace detail
152} // end namespace cusp
block_lu_solver(const MatrixType &A)
Definition block_lu.h:133
void operator()(const VectorType1 &b, VectorType2 &x) const
Definition block_lu.h:144
cusp::array1d< int, cusp::host_memory > pivot
Definition block_lu.h:127
cusp::array2d< ValueType, cusp::host_memory > lu
Definition block_lu.h:126
int block_lu_factor(cusp::array2d< ValueType, MemorySpace, OrientationA > &A, cusp::array1d< IndexType, MemorySpace > &pivot)
Definition block_lu.h:31
int block_lu_solve(const cusp::array2d< ValueType, MemorySpace, OrientationA > &A, const cusp::array1d< IndexType, MemorySpace > &pivot, const cusp::array2d< ValueType, MemorySpace, OrientationB > &b, cusp::array2d< ValueType, MemorySpace, OrientationB > &x, cusp::array2d_format)
Definition block_lu.h:78