Rythmos - Transient Integration for Differential Equations Version of the Day
Loading...
Searching...
No Matches
Rythmos_AdjointModelEvaluator.hpp
1//@HEADER
2// ***********************************************************************
3//
4// Rythmos Package
5// Copyright (2006) Sandia Corporation
6//
7// Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
8// license for use of this work by or on behalf of the U.S. Government.
9//
10// This library is free software; you can redistribute it and/or modify
11// it under the terms of the GNU Lesser General Public License as
12// published by the Free Software Foundation; either version 2.1 of the
13// License, or (at your option) any later version.
14//
15// This library is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18// Lesser General Public License for more details.
19//
20// You should have received a copy of the GNU Lesser General Public
21// License along with this library; if not, write to the Free Software
22// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
23// USA
24// Questions? Contact Todd S. Coffey (tscoffe@sandia.gov)
25//
26// ***********************************************************************
27//@HEADER
28
29#ifndef RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
30#define RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
31
32
33#include "Rythmos_IntegratorBase.hpp"
34#include "Thyra_ModelEvaluator.hpp" // Interface
35#include "Thyra_StateFuncModelEvaluatorBase.hpp" // Implementation
36#include "Thyra_ModelEvaluatorDelegatorBase.hpp"
37#include "Thyra_DefaultScaledAdjointLinearOp.hpp"
38#include "Thyra_DefaultAdjointLinearOpWithSolve.hpp"
39#include "Thyra_VectorStdOps.hpp"
40#include "Thyra_MultiVectorStdOps.hpp"
41#include "Teuchos_implicit_cast.hpp"
42#include "Teuchos_Assert.hpp"
43
44
45namespace Rythmos {
46
47
172template<class Scalar>
174 : virtual public Thyra::StateFuncModelEvaluatorBase<Scalar>
175{
176public:
177
180
183
185 void setFwdStateModel(
186 const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
187 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &basePoint );
188
192 void setFwdTimeRange( const TimeRange<Scalar> &fwdTimeRange );
193
207 const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer );
208
210
213
215 RCP<const Thyra::VectorSpaceBase<Scalar> > get_x_space() const;
217 RCP<const Thyra::VectorSpaceBase<Scalar> > get_f_space() const;
219 Thyra::ModelEvaluatorBase::InArgs<Scalar> getNominalValues() const;
221 RCP<Thyra::LinearOpWithSolveBase<Scalar> > create_W() const;
223 RCP<Thyra::LinearOpBase<Scalar> > create_W_op() const;
225 Thyra::ModelEvaluatorBase::InArgs<Scalar> createInArgs() const;
226
228
229private:
230
233
235 Thyra::ModelEvaluatorBase::OutArgs<Scalar> createOutArgsImpl() const;
237 void evalModelImpl(
238 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
239 const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
240 ) const;
241
243
244private:
245
246 // /////////////////////////
247 // Private data members
248
249 RCP<const Thyra::ModelEvaluator<Scalar> > fwdStateModel_;
250 Thyra::ModelEvaluatorBase::InArgs<Scalar> basePoint_;
251 TimeRange<Scalar> fwdTimeRange_;
252 RCP<const InterpolationBufferBase<Scalar> > fwdStateSolutionBuffer_;
253
254 mutable bool isInitialized_;
255 mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> prototypeInArgs_bar_;
256 mutable Thyra::ModelEvaluatorBase::OutArgs<Scalar> prototypeOutArgs_bar_;
257 mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> adjointNominalValues_;
258 mutable RCP<Thyra::LinearOpBase<Scalar> > my_W_bar_adj_op_;
259 mutable RCP<Thyra::LinearOpBase<Scalar> > my_d_f_d_x_dot_op_;
260
261 // /////////////////////////
262 // Private member functions
263
264 // Just-in-time initialization function
265 void initialize() const;
266
267};
268
269
274template<class Scalar>
275RCP<AdjointModelEvaluator<Scalar> >
277 const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
278 const TimeRange<Scalar> &fwdTimeRange
279 )
280{
281 RCP<AdjointModelEvaluator<Scalar> >
282 adjointModel = Teuchos::rcp(new AdjointModelEvaluator<Scalar>);
283 adjointModel->setFwdStateModel(fwdStateModel, fwdStateModel->getNominalValues());
284 adjointModel->setFwdTimeRange(fwdTimeRange);
285 return adjointModel;
286}
287
288
289// /////////////////////////////////
290// Implementations
291
292
// Constructors/Initializers/Accessors
294
295
296template<class Scalar>
300
301
302template<class Scalar>
304 const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
305 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &basePoint
306 )
307{
308 TEUCHOS_TEST_FOR_EXCEPT(is_null(fwdStateModel));
309 fwdStateModel_ = fwdStateModel;
310 basePoint_ = basePoint;
311 isInitialized_ = false;
312}
313
314
315template<class Scalar>
317 const TimeRange<Scalar> &fwdTimeRange )
318{
319 fwdTimeRange_ = fwdTimeRange;
320}
321
322
323template<class Scalar>
325 const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer )
326{
327 TEUCHOS_TEST_FOR_EXCEPT(is_null(fwdStateSolutionBuffer));
328 fwdStateSolutionBuffer_ = fwdStateSolutionBuffer;
329}
330
331
// Public functions overridden from ModelEvaluator
333
334
335template<class Scalar>
336RCP<const Thyra::VectorSpaceBase<Scalar> >
338{
339 initialize();
340 return fwdStateModel_->get_f_space();
341}
342
343
344template<class Scalar>
345RCP<const Thyra::VectorSpaceBase<Scalar> >
347{
348 initialize();
349 return fwdStateModel_->get_x_space();
350}
351
352
353template<class Scalar>
354Thyra::ModelEvaluatorBase::InArgs<Scalar>
356{
357 initialize();
358 return adjointNominalValues_;
359}
360
361
362template<class Scalar>
363RCP<Thyra::LinearOpWithSolveBase<Scalar> >
365{
366 initialize();
367 return Thyra::nonconstAdjointLows<Scalar>(fwdStateModel_->create_W());
368}
369
370
371template<class Scalar>
372RCP<Thyra::LinearOpBase<Scalar> >
374{
375 initialize();
376 return Thyra::nonconstAdjoint<Scalar>(fwdStateModel_->create_W_op());
377}
378
379
380template<class Scalar>
381Thyra::ModelEvaluatorBase::InArgs<Scalar>
383{
384 initialize();
385 return prototypeInArgs_bar_;
386}
387
388
// Private functions overridden from ModelEvaluatorDefaultBase
390
391
392template<class Scalar>
393Thyra::ModelEvaluatorBase::OutArgs<Scalar>
395{
396 initialize();
397 return prototypeOutArgs_bar_;
398}
399
400
401template<class Scalar>
402void AdjointModelEvaluator<Scalar>::evalModelImpl(
403 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
404 const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
405 ) const
406{
407
408 using Teuchos::rcp_dynamic_cast;
409 using Teuchos::describe;
410 typedef Teuchos::ScalarTraits<Scalar> ST;
411 typedef Thyra::ModelEvaluatorBase MEB;
412 typedef Thyra::DefaultScaledAdjointLinearOp<Scalar> DSALO;
413 typedef Thyra::DefaultAdjointLinearOpWithSolve<Scalar> DALOWS;
414 typedef Teuchos::VerboseObjectTempState<Thyra::ModelEvaluatorBase> VOTSME;
415
416 //
417 // A) Header stuff
418 //
419
420 THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_GEN_BEGIN(
421 "AdjointModelEvaluator", inArgs_bar, outArgs_bar, Teuchos::null );
422
423 initialize();
424
425 VOTSME fwdStateModel_outputTempState(fwdStateModel_,out,verbLevel);
426
427 //const bool trace = includesVerbLevel(verbLevel, Teuchos::VERB_LOW);
428 const bool dumpAll = includesVerbLevel(localVerbLevel, Teuchos::VERB_EXTREME);
429
430 //
431 // B) Unpack the input and output arguments to see what we have to compute
432 //
433
434 // B.1) InArgs
435
436 const Scalar t_bar = inArgs_bar.get_t();
437 const RCP<const Thyra::VectorBase<Scalar> >
438 lambda_rev_dot = inArgs_bar.get_x_dot().assert_not_null(), // x_bar_dot
439 lambda = inArgs_bar.get_x().assert_not_null(); // x_bar
440 const Scalar alpha_bar = inArgs_bar.get_alpha();
441 const Scalar beta_bar = inArgs_bar.get_beta();
442
443 if (dumpAll) {
444 *out << "\nlambda_rev_dot = " << describe(*lambda_rev_dot, Teuchos::VERB_EXTREME);
445 *out << "\nlambda = " << describe(*lambda, Teuchos::VERB_EXTREME);
446 *out << "\nalpha_bar = " << alpha_bar << "\n";
447 *out << "\nbeta_bar = " << beta_bar << "\n";
448 }
449
450 // B.2) OutArgs
451
452 const RCP<Thyra::VectorBase<Scalar> > f_bar = outArgs_bar.get_f();
453
454 RCP<DALOWS> W_bar;
455 if (outArgs_bar.supports(MEB::OUT_ARG_W))
456 W_bar = rcp_dynamic_cast<DALOWS>(outArgs_bar.get_W(), true);
457
458 RCP<DSALO> W_bar_op;
459 if (outArgs_bar.supports(MEB::OUT_ARG_W_op))
460 W_bar_op = rcp_dynamic_cast<DSALO>(outArgs_bar.get_W_op(), true);
461
462 if (dumpAll) {
463 if (!is_null(W_bar)) {
464 *out << "\nW_bar = " << describe(*W_bar, Teuchos::VERB_EXTREME);
465 }
466 if (!is_null(W_bar_op)) {
467 *out << "\nW_bar_op = " << describe(*W_bar_op, Teuchos::VERB_EXTREME);
468 }
469 }
470
471 //
472 // C) Evaluate the needed quantities from the underlying forward Model
473 //
474
475 MEB::InArgs<Scalar> fwdInArgs = fwdStateModel_->createInArgs();
476
477 // C.1) Set the required input arguments
478
479 fwdInArgs = basePoint_;
480
481 if (!is_null(fwdStateSolutionBuffer_)) {
482 const Scalar t = fwdTimeRange_.length() - t_bar;
483 RCP<const Thyra::VectorBase<Scalar> > x, x_dot;
484 get_x_and_x_dot<Scalar>( *fwdStateSolutionBuffer_, t,
485 outArg(x), outArg(x_dot) );
486 fwdInArgs.set_x(x);
487 fwdInArgs.set_x_dot(x);
488 }
489 else {
490 // If we don't have an IB object to get the state from, we will assume
491 // that the problem is linear and, therefore, we can pass in any old value
492 // of x, x_dot, and t and get the W_bar_adj object that we need. For this
493 // purpose, we will assume the model's base point will do.
494
495 // 2008/05/14: rabartl: ToDo: Implement real variable dependancy
496 // communication support to make sure that this is okay! If the model is
497 // really nonlinear we need to check for this and throw if the user did
498 // not set up a fwdStateSolutionBuffer object!
499 }
500
501
502 // C.2) Evaluate W_bar_adj if needed
503
504 RCP<Thyra::LinearOpWithSolveBase<Scalar> > W_bar_adj;
505 RCP<Thyra::LinearOpBase<Scalar> > W_bar_adj_op;
506 {
507
508 MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
509
510 // Get or create W_bar_adj or W_bar_adj_op if needed
511 if (!is_null(W_bar)) {
512 // If we have W_bar, the W_bar_adj was already created in
513 // this->create_W()
514 W_bar_adj = W_bar->getNonconstOp();
515 W_bar_adj_op = W_bar_adj;
516 }
517 else if (!is_null(W_bar_op)) {
518 // If we have W_bar_op, the W_bar_adj_op was already created in
519 // this->create_W_op()
520 W_bar_adj_op = W_bar_op->getNonconstOp();
521 }
522 else if (!is_null(f_bar)) {
523 TEUCHOS_TEST_FOR_EXCEPT_MSG(true, "ToDo: Unit test this code!");
524 // If the user did not pass in W_bar or W_bar_op, then we need to create
525 // our own local LOB form W_bar_adj_op of W_bar_adj in order to evaluate
526 // the residual f_bar
527 if (is_null(my_W_bar_adj_op_)) {
528 my_W_bar_adj_op_ = fwdStateModel_->create_W_op();
529 }
530 W_bar_adj_op = my_W_bar_adj_op_;
531 }
532
533 // Set W_bar_adj or W_bar_adj_op on the OutArgs object
534 if (!is_null(W_bar_adj)) {
535 fwdOutArgs.set_W(W_bar_adj);
536 }
537 else if (!is_null(W_bar_adj_op)) {
538 fwdOutArgs.set_W_op(W_bar_adj_op);
539 }
540
541 // Set alpha and beta on OutArgs object
542 if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
543 fwdInArgs.set_alpha(alpha_bar);
544 fwdInArgs.set_beta(beta_bar);
545 }
546
547 // Evaluate the model
548 if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
549 fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
550 }
551
552 // Print the objects if requested
553 if (!is_null(W_bar_adj) && dumpAll)
554 *out << "\nW_bar_adj = " << describe(*W_bar_adj, Teuchos::VERB_EXTREME);
555 if (!is_null(W_bar_adj_op) && dumpAll)
556 *out << "\nW_bar_adj_op = " << describe(*W_bar_adj_op, Teuchos::VERB_EXTREME);
557
558 }
559
560 // C.3) Evaluate d(f)/d(x_dot) if needed
561
562 RCP<Thyra::LinearOpBase<Scalar> > d_f_d_x_dot_op;
563 if (!is_null(f_bar)) {
564 if (is_null(my_d_f_d_x_dot_op_)) {
565 my_d_f_d_x_dot_op_ = fwdStateModel_->create_W_op();
566 }
567 d_f_d_x_dot_op = my_d_f_d_x_dot_op_;
568 MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
569 fwdOutArgs.set_W_op(d_f_d_x_dot_op);
570 fwdInArgs.set_alpha(ST::one());
571 fwdInArgs.set_beta(ST::zero());
572 fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
573 if (dumpAll) {
574 *out << "\nd_f_d_x_dot_op = " << describe(*d_f_d_x_dot_op, Teuchos::VERB_EXTREME);
575 }
576 }
577
578 //
579 // D) Evaluate the adjoint equation residual:
580 //
581 // f_bar = d(f)/d(x_dot)^T * lambda_hat + 1/beta_bar * W_bar_adj^T * lambda
582 // - d(g)/d(x)^T
583 //
584
585 if (!is_null(f_bar)) {
586
587 // D.1) lambda_hat = lambda_rev_dot - alpha_bar/beta_bar * lambda
588 const RCP<Thyra::VectorBase<Scalar> >
589 lambda_hat = createMember(lambda_rev_dot->space());
590 Thyra::V_VpStV<Scalar>( outArg(*lambda_hat),
591 *lambda_rev_dot, -alpha_bar/beta_bar, *lambda );
592 if (dumpAll)
593 *out << "\nlambda_hat = " << describe(*lambda_hat, Teuchos::VERB_EXTREME);
594
595 // D.2) f_bar = d(f)/d(x_dot)^T * lambda_hat
596 Thyra::apply<Scalar>( *d_f_d_x_dot_op, Thyra::CONJTRANS, *lambda_hat,
597 outArg(*f_bar) );
598
599 // D.3) f_bar += 1/beta_bar * W_bar_adj^T * lambda
600 Thyra::apply<Scalar>( *W_bar_adj_op, Thyra::CONJTRANS, *lambda,
601 outArg(*f_bar), 1.0/beta_bar, ST::one() );
602
603 // D.4) f_bar += - d(g)/d(x)^T
604 // 2008/05/15: rabart: ToDo: Implement once we add support for
605 // distributed response functions
606
607 if (dumpAll)
608 *out << "\nf_bar = " << describe(*f_bar, Teuchos::VERB_EXTREME);
609
610 }
611
612 if (dumpAll) {
613 if (!is_null(W_bar)) {
614 *out << "\nW_bar = " << describe(*W_bar, Teuchos::VERB_EXTREME);
615 }
616 if (!is_null(W_bar_op)) {
617 *out << "\nW_bar_op = " << describe(*W_bar_op, Teuchos::VERB_EXTREME);
618 }
619 }
620
621
622 //
623 // E) Do any remaining post processing
624 //
625
626 THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_END();
627
628}
629
630
631// private
632
633
634template<class Scalar>
635void AdjointModelEvaluator<Scalar>::initialize() const
636{
637
638 typedef Thyra::ModelEvaluatorBase MEB;
639
640 if (isInitialized_)
641 return;
642
643 //
644 // A) Validate the that forward Model is of the correct form!
645 //
646
647 MEB::InArgs<Scalar> fwdStateModelInArgs = fwdStateModel_->createInArgs();
648 MEB::OutArgs<Scalar> fwdStateModelOutArgs = fwdStateModel_->createOutArgs();
649
650#ifdef HAVE_RYTHMOS_DEBUG
651 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x_dot) );
652 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x) );
653 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_t) );
654 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_alpha) );
655 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_beta) );
656 TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_f) );
657 TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) );
658#endif
659
660 //
661 // B) Set up the prototypical InArgs and OutArgs
662 //
663
664 {
665 MEB::InArgsSetup<Scalar> inArgs_bar;
666 inArgs_bar.setModelEvalDescription(this->description());
667 inArgs_bar.setSupports( MEB::IN_ARG_x_dot );
668 inArgs_bar.setSupports( MEB::IN_ARG_x );
669 inArgs_bar.setSupports( MEB::IN_ARG_t );
670 inArgs_bar.setSupports( MEB::IN_ARG_alpha );
671 inArgs_bar.setSupports( MEB::IN_ARG_beta );
672 prototypeInArgs_bar_ = inArgs_bar;
673 }
674
675 {
676 MEB::OutArgsSetup<Scalar> outArgs_bar;
677 outArgs_bar.setModelEvalDescription(this->description());
678 outArgs_bar.setSupports(MEB::OUT_ARG_f);
679 if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) ) {
680 outArgs_bar.setSupports(MEB::OUT_ARG_W);
681 outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
682 }
683 if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W_op) ) {
684 outArgs_bar.setSupports(MEB::OUT_ARG_W_op);
685 outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
686 }
687 prototypeOutArgs_bar_ = outArgs_bar;
688 }
689
690 //
691 // D) Set up the nominal values for the adjoint
692 //
693
694 // Copy structure
695 adjointNominalValues_ = prototypeInArgs_bar_;
696 // Just set a zero initial condition for the adjoint
697 const RCP<Thyra::VectorBase<Scalar> > zero_lambda_vec =
698 createMember(fwdStateModel_->get_f_space());
699 V_S( zero_lambda_vec.ptr(), ScalarTraits<Scalar>::zero() );
700 adjointNominalValues_.set_x_dot(zero_lambda_vec);
701 adjointNominalValues_.set_x(zero_lambda_vec);
702
703 //
704 // E) Wipe out other cached objects
705 //
706
707 my_W_bar_adj_op_ = Teuchos::null;
708 my_d_f_d_x_dot_op_ = Teuchos::null;
709
710 //
711 // F) We are initialized!
712 //
713
714 isInitialized_ = true;
715
716}
717
718
719} // namespace Rythmos
720
721
722#endif // RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
Standard concrete adjoint ModelEvaluator for time-constant mass matrix models.
RCP< const Thyra::VectorSpaceBase< Scalar > > get_x_space() const
void setFwdTimeRange(const TimeRange< Scalar > &fwdTimeRange)
Set the forward time range that this adjoint model will be defined over.
Thyra::ModelEvaluatorBase::InArgs< Scalar > createInArgs() const
void setFwdStateSolutionBuffer(const RCP< const InterpolationBufferBase< Scalar > > &fwdStateSolutionBuffer)
Set the interpolation buffer that will return values of the state solution x and x_dot at various poi...
RCP< Thyra::LinearOpBase< Scalar > > create_W_op() const
RCP< Thyra::LinearOpWithSolveBase< Scalar > > create_W() const
RCP< AdjointModelEvaluator< Scalar > > adjointModelEvaluator(const RCP< const Thyra::ModelEvaluator< Scalar > > &fwdStateModel, const TimeRange< Scalar > &fwdTimeRange)
Nonmember constructor.
RCP< const Thyra::VectorSpaceBase< Scalar > > get_f_space() const
Thyra::ModelEvaluatorBase::InArgs< Scalar > getNominalValues() const
void setFwdStateModel(const RCP< const Thyra::ModelEvaluator< Scalar > > &fwdStateModel, const Thyra::ModelEvaluatorBase::InArgs< Scalar > &basePoint)
Set the underlying forward model and base point.