ergo
mat_utils.h
Go to the documentation of this file.
00001 /* Ergo, version 3.2, a program for linear scaling electronic structure
00002  * calculations.
00003  * Copyright (C) 2012 Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek.
00004  * 
00005  * This program is free software: you can redistribute it and/or modify
00006  * it under the terms of the GNU General Public License as published by
00007  * the Free Software Foundation, either version 3 of the License, or
00008  * (at your option) any later version.
00009  * 
00010  * This program is distributed in the hope that it will be useful,
00011  * but WITHOUT ANY WARRANTY; without even the implied warranty of
00012  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00013  * GNU General Public License for more details.
00014  * 
00015  * You should have received a copy of the GNU General Public License
00016  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
00017  * 
00018  * Primary academic reference:
00019  * Kohn-Sham Density Functional Theory Electronic Structure Calculations
00020  * with Linearly Scaling Computational Time and Memory Usage,
00021  * Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek,
00022  * J. Chem. Theory Comput. 7, 340 (2011),
00023  * <http://dx.doi.org/10.1021/ct100611z>
00024  * 
00025  * For further information about Ergo, see <http://www.ergoscf.org>.
00026  */
00027 #ifndef MAT_UTILS_HEADER
00028 #define MAT_UTILS_HEADER
00029 #include "Interval.h"
00030 #include "matrix_proxy.h"
00031 namespace mat {
00032 
00033   template<typename Tmatrix, typename Treal>
00034     struct DiffMatrix {
00035       typedef typename Tmatrix::VectorType VectorType;
00036       void getCols(SizesAndBlocks & colsCopy) const {
00037         A.getCols(colsCopy);
00038       }
00039       int get_nrows() const { 
00040         assert( A.get_nrows() == B.get_nrows() );
00041         return A.get_nrows(); 
00042       }
00043       Treal frob() const {
00044         return Tmatrix::frob_diff(A, B);
00045       }
00046       void quickEuclBounds(Treal & euclLowerBound, 
00047                            Treal & euclUpperBound) const {
00048         Treal frobTmp = frob();
00049         euclLowerBound = frobTmp  / template_blas_sqrt( (Treal)get_nrows() );
00050         euclUpperBound = frobTmp;
00051       }
00052 
00053       Tmatrix const & A;
00054       Tmatrix const & B;
00055       DiffMatrix(Tmatrix const & A_, Tmatrix const & B_)
00056       : A(A_), B(B_) {}
00057       template<typename Tvector>
00058       void matVecProd(Tvector & y, Tvector const & x) const {
00059         Tvector tmp(y);
00060         tmp = (Treal)-1.0 * B * x;   // -B * x
00061         y   = (Treal)1.0 * A * x;    // A * x
00062         y  += (Treal)1.0 * tmp;        // A * x - B * x  => (A - B) * x
00063       }
00064     };
00065 
00066 
00067   // ATAMatrix AT*A 
00068   template<typename Tmatrix, typename Treal>
00069     struct ATAMatrix {
00070       typedef typename Tmatrix::VectorType VectorType;
00071       Tmatrix const & A;
00072       explicit ATAMatrix(Tmatrix const & A_)
00073       : A(A_) {}
00074       void getCols(SizesAndBlocks & colsCopy) const {
00075         A.getRows(colsCopy);
00076       }
00077       void quickEuclBounds(Treal & euclLowerBound, 
00078                            Treal & euclUpperBound) const {
00079         Treal frobA = A.frob();
00080         euclLowerBound = 0;
00081         euclUpperBound = frobA * frobA;
00082       }
00083       
00084       // y = AT*A*x
00085       template<typename Tvector>
00086       void matVecProd(Tvector & y, Tvector const & x) const {
00087         y = x;
00088         y = A * y;
00089         y = transpose(A) * y;
00090       }
00091       // Number of rows of A^T * A is the number of columns of A 
00092       int get_nrows() const { return A.get_ncols(); }       
00093     };
00094 
00095 
00096   template<typename Tmatrix, typename Tmatrix2, typename Treal>
00097     struct TripleMatrix {
00098       typedef typename Tmatrix::VectorType VectorType;
00099       void getCols(SizesAndBlocks & colsCopy) const {
00100         A.getCols(colsCopy);
00101       }
00102       int get_nrows() const { 
00103         assert( A.get_nrows() == Z.get_nrows() );
00104         return A.get_nrows(); 
00105       }
00106       void quickEuclBounds(Treal & euclLowerBound, 
00107                            Treal & euclUpperBound) const {
00108         Treal frobA = A.frob();
00109         Treal frobZ = Z.frob();
00110         euclLowerBound = 0;
00111         euclUpperBound = frobA * frobZ * frobZ;
00112       }
00113       
00114       Tmatrix  const & A;
00115       Tmatrix2 const & Z;
00116       TripleMatrix(Tmatrix const & A_, Tmatrix2 const & Z_)
00117       : A(A_), Z(Z_) {}
00118       void matVecProd(VectorType & y, VectorType const & x) const {
00119         VectorType tmp(x);
00120         tmp = Z * tmp;            // Z * x
00121         y = (Treal)1.0 * A * tmp; // A * Z * x
00122         y = transpose(Z) * y;     // Z^T * A * Z * x
00123       }
00124     };
00125 
00126 
00127   template<typename Tmatrix, typename Tmatrix2, typename Treal>
00128     struct CongrTransErrorMatrix {
00129       typedef typename Tmatrix::VectorType VectorType;
00130       void getCols(SizesAndBlocks & colsCopy) const {
00131         E.getRows(colsCopy);
00132       }
00133       int get_nrows() const { 
00134         return E.get_ncols(); 
00135       }
00136       void quickEuclBounds(Treal & euclLowerBound, 
00137                            Treal & euclUpperBound) const {
00138         Treal frobA = A.frob();
00139         Treal frobZ = Zt.frob();
00140         Treal frobE = E.frob();
00141         euclLowerBound = 0;
00142         euclUpperBound = frobA * frobE * frobE + 2 * frobA * frobE * frobZ;
00143       }
00144       
00145       Tmatrix  const & A;
00146       Tmatrix2 const & Zt;
00147       Tmatrix2 const & E;
00148       
00149       CongrTransErrorMatrix(Tmatrix const & A_, 
00150                             Tmatrix2 const & Z_,
00151                             Tmatrix2 const & E_)
00152       : A(A_), Zt(Z_), E(E_) {}
00153       void matVecProd(VectorType & y, VectorType const & x) const {
00154         
00155         VectorType tmp(x);
00156         tmp = E * tmp;               // E * x
00157         y = (Treal)-1.0 * A * tmp;   // -A * E * x
00158         y = transpose(E) * y;        // -E^T * A * E * x
00159         
00160         VectorType tmp1;
00161         tmp = x;
00162         tmp = Zt * tmp;              // Zt * x
00163         tmp1 = (Treal)1.0 * A * tmp; // A * Zt * x
00164         tmp1 = transpose(E) * tmp1;  // E^T * A * Zt * x
00165         y += (Treal)1.0 * tmp1;
00166 
00167         tmp = x;
00168         tmp = E * tmp;               // E * x
00169         tmp1 = (Treal)1.0 * A * tmp; // A * E * x
00170         tmp1 = transpose(Zt) * tmp1; // Zt^T * A * E * x
00171         y += (Treal)1.0 * tmp1; 
00172       }
00173     };
00174 
00175 
00176 
00177 }  /* end namespace mat */
00178 #endif