adevs
/home/rotten/adevs-2.6/include/adevs_corrected_euler.h
00001 #ifndef _adevs_corrected_euler_h_
00002 #define _adevs_corrected_euler_h_
#include <algorithm>
#include <cfloat>
#include <cmath>
#include "adevs_hybrid.h"
00005 
00006 namespace adevs
00007 {
00008 
00013 template <typename X> class corrected_euler:
00014         public ode_solver<X>
00015 {
00016         public:
00021                 corrected_euler(ode_system<X>* sys, double err_tol, double h_max);
00023                 ~corrected_euler();
00024                 double integrate(double* q, double h_lim);
00025                 void advance(double* q, double h);
00026         private:
00027                 double *dq, // derivative
00028                            *qq, // trial solution
00029                            *t,  // temporary variable for computing k2
00030                            *k[2]; // k1 and k2
00031                 const double err_tol; // Error tolerance
00032                 const double h_max; // Maximum time step
00033                 double h_cur; // Previous time step that satisfied error constraint
00034                 // Compute a step of size h, put it in qq, and return the error
00035                 double trial_step(double h);
00036 };
00037 
00038 template <typename X>
00039 corrected_euler<X>::corrected_euler(ode_system<X>* sys, double err_tol,
00040                 double h_max):
00041         ode_solver<X>(sys),err_tol(err_tol),h_max(h_max),h_cur(h_max)
00042 {
00043         for (int i = 0; i < 2; i++) k[i] = new double[sys->numVars()];
00044         dq = new double[sys->numVars()];
00045         qq = new double[sys->numVars()];
00046         t = new double[sys->numVars()];
00047 }
00048 
00049 template <typename X>
00050 corrected_euler<X>::~corrected_euler()
00051 {
00052         delete [] t; delete [] qq; delete [] dq;
00053         for (int i = 0; i < 2; i++) delete [] k[i];
00054 }
00055 
00056 template <typename X>
00057 void corrected_euler<X>::advance(double* q, double h)
00058 {
00059         double dt;
00060         while ((dt = integrate(q,h)) < h) h -= dt;
00061 }
00062 
00063 template <typename X>
00064 double corrected_euler<X>::integrate(double* q, double h_lim)
00065 {
00066         // Initial error estimate and step size
00067         double err = DBL_MAX, h = std::min(h_cur*1.1,std::min(h_max,h_lim));
00068         for (;;) {
00069                 // Copy q to the trial vector
00070                 for (int i = 0; i < this->sys->numVars(); i++) qq[i] = q[i];
00071                 // Make the trial step which will be stored in qq
00072                 err = trial_step(h);
00073                 // If the error is ok, then we have found the proper step size
00074                 if (err <= err_tol) { // Keep h if shrunk to control the error
00075                         if (h_lim >= h_cur) h_cur = h; 
00076                         break;
00077                 }
00078                 // Otherwise shrink the step size and try again
00079                 else {
00080                         double h_guess = 0.8*err_tol*h/fabs(err);
00081                         if (h < h_guess) h *= 0.8;
00082                         else h = h_guess;
00083                 }
00084         }
00085         // Put the trial solution in q and return the selected step size
00086         for (int i = 0; i < this->sys->numVars(); i++) q[i] = qq[i];
00087         return h;
00088 }
00089 
00090 template <typename X>
00091 double corrected_euler<X>::trial_step(double step)
00092 {
00093         int j;
00094         // Compute k1
00095         this->sys->der_func(qq,dq); 
00096         for (j = 0; j < this->sys->numVars(); j++) k[0][j] = step*dq[j];
00097         // Compute k2
00098         for (j = 0; j < this->sys->numVars(); j++) t[j] = qq[j] + 0.5*k[0][j];
00099         this->sys->der_func(t,dq);
00100         for (j = 0; j < this->sys->numVars(); j++) k[1][j] = step*dq[j];
00101         // Compute next state and approximate error
00102         double err = 0.0;
00103         for (j = 0; j < this->sys->numVars(); j++) {
00104                 qq[j] += k[1][j]; // Next state
00105                 err = std::max(err,fabs(k[0][j]-k[1][j])); // Maximum error
00106         }
00107         return err; // Return the error
00108 }
00109 
00110 } // end of namespace
00111 #endif