Sacado Package Browser (Single Doxygen Collection) Version of the Day
Loading...
Searching...
No Matches
vector_blas_example.cpp
Go to the documentation of this file.
1// $Id$
2// $Source$
3// @HEADER
4// ***********************************************************************
5//
6// Sacado Package
7// Copyright (2006) Sandia Corporation
8//
9// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
10// the U.S. Government retains certain rights in this software.
11//
12// This library is free software; you can redistribute it and/or modify
13// it under the terms of the GNU Lesser General Public License as
14// published by the Free Software Foundation; either version 2.1 of the
15// License, or (at your option) any later version.
16//
17// This library is distributed in the hope that it will be useful, but
18// WITHOUT ANY WARRANTY; without even the implied warranty of
19// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20// Lesser General Public License for more details.
21//
22// You should have received a copy of the GNU Lesser General Public
23// License along with this library; if not, write to the Free Software
24// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
25// USA
26// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
27// (etphipp@sandia.gov).
28//
29// ***********************************************************************
30// @HEADER
31
32// vector_blas_example
33//
34// usage:
35// vector_blas_example
36//
37// output:
// prints the results of differentiating a BLAS routine (GEMV) with forward
// mode AD using the Sacado::Fad::DVFad class, which dynamically allocates
// the derivative components and stores them in a contiguous array, and
// checks them against a reference computed with the plain double BLAS.
42
#include <cmath>
#include <iomanip>
#include <iostream>
#include <vector>

#include "Sacado_No_Kokkos.hpp"
#include "Teuchos_BLAS.hpp"
#include "Sacado_Fad_BLAS.hpp"
49
51
52int main(int argc, char **argv)
53{
54 const unsigned int n = 5;
56 for (unsigned int i=0; i<n; i++) {
57 for (unsigned int j=0; j<n; j++)
58 A[i+j*n] = FadType(Teuchos::ScalarTraits<double>::random());
59 B[i] = FadType(n, Teuchos::ScalarTraits<double>::random());
60 for (unsigned int j=0; j<n; j++)
61 B[i].fastAccessDx(j) = Teuchos::ScalarTraits<double>::random();
62 C[i] = 0.0;
63 }
64
65 double *a = A.vals();
66 double *b = B.vals();
67 double *bdx = B.dx();
68 std::vector<double> c(n), cdx(n*n);
69
70 Teuchos::BLAS<int,double> blas;
71 blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &a[0], n, &b[0], 1, 0.0, &c[0], 1);
72 blas.GEMM(Teuchos::NO_TRANS, Teuchos::NO_TRANS, n, n, n, 1.0, &a[0], n, &bdx[0], n, 0.0, &cdx[0], n);
73
74 // Teuchos::BLAS<int,FadType> blas_fad;
75 // blas_fad.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1);
76
77 Teuchos::BLAS<int,FadType> sacado_fad_blas(false);
78 sacado_fad_blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1);
79
80 // Print the results
81 int p = 4;
82 int w = p+7;
83 std::cout.setf(std::ios::scientific);
84 std::cout.precision(p);
85
86 std::cout << "BLAS GEMV calculation:" << std::endl;
87 std::cout << "a = " << std::endl;
88 for (unsigned int i=0; i<n; i++) {
89 for (unsigned int j=0; j<n; j++)
90 std::cout << " " << std::setw(w) << a[i+j*n];
91 std::cout << std::endl;
92 }
93 std::cout << "b = " << std::endl;
94 for (unsigned int i=0; i<n; i++) {
95 std::cout << " " << std::setw(w) << b[i];
96 }
97 std::cout << std::endl;
98 std::cout << "bdot = " << std::endl;
99 for (unsigned int i=0; i<n; i++) {
100 for (unsigned int j=0; j<n; j++)
101 std::cout << " " << std::setw(w) << bdx[i+j*n];
102 std::cout << std::endl;
103 }
104 std::cout << "c = " << std::endl;
105 for (unsigned int i=0; i<n; i++) {
106 std::cout << " " << std::setw(w) << c[i];
107 }
108 std::cout << std::endl;
109 std::cout << "cdot = " << std::endl;
110 for (unsigned int i=0; i<n; i++) {
111 for (unsigned int j=0; j<n; j++)
112 std::cout << " " << std::setw(w) << cdx[i+j*n];
113 std::cout << std::endl;
114 }
115 std::cout << std::endl << std::endl;
116
117 std::cout << "FAD BLAS GEMV calculation:" << std::endl;
118 std::cout << "A.val() (should = a) = " << std::endl;
119 for (unsigned int i=0; i<n; i++) {
120 for (unsigned int j=0; j<n; j++)
121 std::cout << " " << std::setw(w) << A[i+j*n].val();
122 std::cout << std::endl;
123 }
124 std::cout << "B.val() (should = b) = " << std::endl;
125 for (unsigned int i=0; i<n; i++) {
126 std::cout << " " << std::setw(w) << B[i].val();
127 }
128 std::cout << std::endl;
129 std::cout << "B.dx() (should = bdot) = " << std::endl;
130 double *Bdx = B.dx();
131 for (unsigned int i=0; i<n; i++) {
132 for (unsigned int j=0; j<n; j++)
133 std::cout << " " << std::setw(w) << Bdx[i+j*n];
134 std::cout << std::endl;
135 }
136 std::cout << "C.val() (should = c) = " << std::endl;
137 for (unsigned int i=0; i<n; i++) {
138 std::cout << " " << std::setw(w) << C[i].val();
139 }
140 std::cout << std::endl;
141 std::cout << "C.dx() (should = cdot) = " << std::endl;
142 double *Cdx = C.dx();
143 for (unsigned int i=0; i<n; i++) {
144 for (unsigned int j=0; j<n; j++)
145 std::cout << " " << std::setw(w) << Cdx[i+j*n];
146 std::cout << std::endl;
147 }
148
149 double tol = 1.0e-14;
150 bool failed = false;
151 for (unsigned int i=0; i<n; i++) {
152 if (std::fabs(C[i].val() - c[i]) > tol)
153 failed = true;
154 for (unsigned int j=0; j<n; j++) {
155 if (std::fabs(C[i].dx(j) - cdx[i+j*n]) > tol)
156 failed = true;
157 }
158 }
159 if (!failed) {
160 std::cout << "\nExample passed!" << std::endl;
161 return 0;
162 }
163 else {
164 std::cout <<"\nSomething is wrong, example failed!" << std::endl;
165 return 1;
166 }
167}
expr expr dx(i)
expr expr expr fastAccessDx(i)) FAD_UNARYOP_MACRO(exp
expr val()
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
#define A
#define C(x)
int main()
Fad specializations for Teuchos::BLAS wrappers.
const char * p
const double tol
Sacado::Fad::DVFad< double > FadType