87 TEMPUS_FUNC_TIME_MONITOR_DIFF(
"Tempus::IntegratorAdjointSensitivity::advanceTime()", TEMPUS_AS_AT);
90 using Teuchos::rcp_dynamic_cast;
92 using Thyra::VectorSpaceBase;
99 using Thyra::createMember;
100 using Thyra::createMembers;
102 typedef Thyra::ModelEvaluatorBase MEB;
103 typedef Thyra::DefaultMultiVectorProductVector<Scalar> DMVPV;
104 typedef Thyra::DefaultProductVector<Scalar> DPV;
107 RCP<const SolutionHistory<Scalar> > state_solution_history =
108 state_integrator_->getSolutionHistory();
109 RCP<const SolutionState<Scalar> > initial_state =
110 (*state_solution_history)[0];
113 bool state_status =
true;
115 TEMPUS_FUNC_TIME_MONITOR_DIFF(
"Tempus::IntegratorAdjointSensitivity::advanceTime::state", TEMPUS_AS_AT_FWD);
117 state_status = state_integrator_->advanceTime(timeFinal);
124 adjoint_aux_model_->setFinalTime(state_integrator_->getTime());
127 adjoint_aux_model_->setForwardSolutionHistory(state_solution_history);
130 RCP<const VectorSpaceBase<Scalar> > g_space = model_->get_g_space(g_index_);
131 RCP<const VectorSpaceBase<Scalar> > x_space = model_->get_x_space();
132 const int num_g = g_space->dim();
133 RCP<MultiVectorBase<Scalar> > dgdx = createMembers(x_space, num_g);
134 MEB::InArgs<Scalar> inargs = model_->getNominalValues();
135 RCP<const SolutionState<Scalar> > state =
136 state_solution_history->getCurrentState();
137 inargs.set_t(state->getTime());
138 inargs.set_x(state->getX());
139 inargs.set_x_dot(state->getXDot());
140 MEB::OutArgs<Scalar> outargs = model_->createOutArgs();
141 MEB::OutArgs<Scalar> adj_outargs = adjoint_model_->createOutArgs();
142 outargs.set_DgDx(g_index_,
143 MEB::Derivative<Scalar>(dgdx, MEB::DERIV_MV_GRADIENT_FORM));
144 model_->evalModel(inargs, outargs);
145 outargs.set_DgDx(g_index_, MEB::Derivative<Scalar>());
151 RCP<DPV> adjoint_init =
152 rcp_dynamic_cast<DPV>(Thyra::createMember(adjoint_aux_model_->get_x_space()));
153 RCP<MultiVectorBase<Scalar> > adjoint_init_mv =
154 rcp_dynamic_cast<DMVPV>(adjoint_init->getNonconstVectorBlock(0))->getNonconstMultiVector();
155 assign(adjoint_init->getNonconstVectorBlock(1).ptr(),
156 Teuchos::ScalarTraits<Scalar>::zero());
157 if (mass_matrix_is_identity_)
158 assign(adjoint_init_mv.ptr(), *dgdx);
160 inargs.set_alpha(1.0);
161 inargs.set_beta(0.0);
162 RCP<LinearOpWithSolveBase<Scalar> > W;
163 if (adj_outargs.supports(MEB::OUT_ARG_W)) {
165 W = adjoint_model_->create_W();
166 adj_outargs.set_W(W);
167 adjoint_model_->evalModel(inargs, adj_outargs);
168 adj_outargs.set_W(Teuchos::null);
172 RCP<const LinearOpWithSolveFactoryBase<Scalar> > lowsfb =
173 adjoint_model_->get_W_factory();
174 TEUCHOS_TEST_FOR_EXCEPTION(
175 lowsfb == Teuchos::null, std::logic_error,
176 "Adjoint ME must support W out-arg or provide a W_factory for non-identity mass matrix");
179 RCP<LinearOpBase<Scalar> > W_op = adjoint_model_->create_W_op();
180 adj_outargs.set_W_op(W_op);
181 RCP<PreconditionerFactoryBase<Scalar> > prec_factory =
182 lowsfb->getPreconditionerFactory();
183 RCP<PreconditionerBase<Scalar> > W_prec;
184 if (prec_factory != Teuchos::null)
185 W_prec = prec_factory->createPrec();
186 else if (adj_outargs.supports(MEB::OUT_ARG_W_prec)) {
187 W_prec = adjoint_model_->create_W_prec();
188 adj_outargs.set_W_prec(W_prec);
190 adjoint_model_->evalModel(inargs, adj_outargs);
191 adj_outargs.set_W_op(Teuchos::null);
192 if (adj_outargs.supports(MEB::OUT_ARG_W_prec))
193 adj_outargs.set_W_prec(Teuchos::null);
196 W = lowsfb->createOp();
197 if (W_prec != Teuchos::null) {
198 if (prec_factory != Teuchos::null)
199 prec_factory->initializePrec(
200 Thyra::defaultLinearOpSource<Scalar>(W_op), W_prec.get());
201 Thyra::initializePreconditionedOp<Scalar>(
202 *lowsfb, W_op, W_prec, W.ptr());
205 Thyra::initializeOp<Scalar>(*lowsfb, W_op, W.ptr());
207 TEUCHOS_TEST_FOR_EXCEPTION(
208 W == Teuchos::null, std::logic_error,
209 "A null W has been encountered in Tempus::IntegratorAdjointSensitivity::advanceTime!\n");
212 assign(adjoint_init_mv.ptr(), Teuchos::ScalarTraits<Scalar>::zero());
213 W->solve(Thyra::NOTRANS, *dgdx, adjoint_init_mv.ptr());
217 bool sens_status =
true;
219 TEMPUS_FUNC_TIME_MONITOR_DIFF(
"Tempus::IntegratorAdjointSensitivity::advanceTime::adjoint", TEMPUS_AS_AT_ADJ);
221 const Scalar tinit = adjoint_integrator_->getTimeStepControl()->getInitTime();
222 adjoint_integrator_->initializeSolutionHistory(tinit, adjoint_init);
223 sens_status = adjoint_integrator_->advanceTime(timeFinal);
225 RCP<const SolutionHistory<Scalar> > adjoint_solution_history =
226 adjoint_integrator_->getSolutionHistory();
229 RCP<const VectorSpaceBase<Scalar> > p_space = model_->get_p_space(p_index_);
230 dgdp_ = createMembers(p_space, num_g);
231 if (g_depends_on_p_) {
232 MEB::DerivativeSupport dgdp_support =
233 outargs.supports(MEB::OUT_ARG_DgDp, g_index_, p_index_);
234 if (dgdp_support.supports(MEB::DERIV_MV_GRADIENT_FORM)) {
235 outargs.set_DgDp(g_index_, p_index_,
236 MEB::Derivative<Scalar>(dgdp_,
237 MEB::DERIV_MV_GRADIENT_FORM));
238 model_->evalModel(inargs, outargs);
240 else if (dgdp_support.supports(MEB::DERIV_MV_JACOBIAN_FORM)) {
241 const int num_p = p_space->dim();
242 RCP<MultiVectorBase<Scalar> > dgdp_trans =
243 createMembers(g_space, num_p);
244 outargs.set_DgDp(g_index_, p_index_,
245 MEB::Derivative<Scalar>(dgdp_trans,
246 MEB::DERIV_MV_JACOBIAN_FORM));
247 model_->evalModel(inargs, outargs);
248 Thyra::DetachedMultiVectorView<Scalar> dgdp_view(*dgdp_);
249 Thyra::DetachedMultiVectorView<Scalar> dgdp_trans_view(*dgdp_trans);
250 for (
int i=0; i<num_p; ++i)
251 for (
int j=0; j<num_g; ++j)
252 dgdp_view(i,j) = dgdp_trans_view(j,i);
255 TEUCHOS_TEST_FOR_EXCEPTION(
true, std::logic_error,
256 "Invalid dg/dp support");
257 outargs.set_DgDp(g_index_, p_index_, MEB::Derivative<Scalar>());
260 assign(dgdp_.ptr(), Scalar(0.0));
264 if (ic_depends_on_p_ && dxdp_init_ != Teuchos::null) {
265 RCP<const SolutionState<Scalar> > adjoint_state =
266 adjoint_solution_history->getCurrentState();
267 RCP<const VectorBase<Scalar> > adjoint_x =
268 rcp_dynamic_cast<const DPV>(adjoint_state->getX())->getVectorBlock(0);
269 RCP<const MultiVectorBase<Scalar> > adjoint_mv =
270 rcp_dynamic_cast<const DMVPV>(adjoint_x)->getMultiVector();
271 if (mass_matrix_is_identity_)
272 dxdp_init_->apply(Thyra::CONJTRANS, *adjoint_mv, dgdp_.ptr(), Scalar(1.0),
275 inargs.set_t(initial_state->getTime());
276 inargs.set_x(initial_state->getX());
277 inargs.set_x_dot(initial_state->getXDot());
278 inargs.set_alpha(1.0);
279 inargs.set_beta(0.0);
280 RCP<LinearOpBase<Scalar> > W_op = adjoint_model_->create_W_op();
281 adj_outargs.set_W_op(W_op);
282 adjoint_model_->evalModel(inargs, adj_outargs);
283 adj_outargs.set_W_op(Teuchos::null);
284 RCP<MultiVectorBase<Scalar> > tmp = createMembers(x_space, num_g);
285 W_op->apply(Thyra::NOTRANS, *adjoint_mv, tmp.ptr(), Scalar(1.0),
287 dxdp_init_->apply(Thyra::CONJTRANS, *tmp, dgdp_.ptr(), Scalar(1.0),
295 if (f_depends_on_p_) {
296 RCP<const SolutionState<Scalar> > adjoint_state =
297 adjoint_solution_history->getCurrentState();
298 RCP<const VectorBase<Scalar> > z =
299 rcp_dynamic_cast<const DPV>(adjoint_state->getX())->getVectorBlock(1);
300 RCP<const MultiVectorBase<Scalar> > z_mv =
301 rcp_dynamic_cast<const DMVPV>(z)->getMultiVector();
302 Thyra::V_VmV(dgdp_.ptr(), *dgdp_, *z_mv);
305 buildSolutionHistory(state_solution_history, adjoint_solution_history);
307 return state_status && sens_status;
605 const Teuchos::RCP<
const SolutionHistory<Scalar> >& state_solution_history,
606 const Teuchos::RCP<
const SolutionHistory<Scalar> >& adjoint_solution_history)
610 using Teuchos::rcp_dynamic_cast;
611 using Teuchos::ParameterList;
614 using Thyra::VectorSpaceBase;
615 using Thyra::createMembers;
616 using Thyra::multiVectorProductVector;
618 typedef Thyra::DefaultProductVectorSpace<Scalar> DPVS;
619 typedef Thyra::DefaultProductVector<Scalar> DPV;
621 RCP<const VectorSpaceBase<Scalar> > x_space = model_->get_x_space();
622 RCP<const VectorSpaceBase<Scalar> > adjoint_space =
623 rcp_dynamic_cast<const DPVS>(adjoint_aux_model_->get_x_space())->getBlock(0);
624 Teuchos::Array< RCP<const VectorSpaceBase<Scalar> > > spaces(2);
626 spaces[1] = adjoint_space;
627 RCP<const DPVS > prod_space = Thyra::productVectorSpace(spaces());
629 int num_states = state_solution_history->getNumStates();
630 const Scalar t_init = state_integrator_->getTimeStepControl()->getInitTime();
631 const Scalar t_final = state_integrator_->getTime();
632 for (
int i=0; i<num_states; ++i) {
633 RCP<const SolutionState<Scalar> > forward_state =
634 (*state_solution_history)[i];
635 RCP<const SolutionState<Scalar> > adjoint_state =
636 adjoint_solution_history->findState(t_final+t_init-forward_state->getTime());
639 RCP<DPV> x = Thyra::defaultProductVector(prod_space);
640 RCP<const VectorBase<Scalar> > adjoint_x =
641 rcp_dynamic_cast<const DPV>(adjoint_state->getX())->getVectorBlock(0);
642 assign(x->getNonconstVectorBlock(0).ptr(), *(forward_state->getX()));
643 assign(x->getNonconstVectorBlock(1).ptr(), *(adjoint_x));
644 RCP<VectorBase<Scalar> > x_b = x;
647 RCP<DPV> x_dot = Thyra::defaultProductVector(prod_space);
648 RCP<const VectorBase<Scalar> > adjoint_x_dot =
649 rcp_dynamic_cast<const DPV>(adjoint_state->getXDot())->getVectorBlock(0);
650 assign(x_dot->getNonconstVectorBlock(0).ptr(), *(forward_state->getXDot()));
651 assign(x_dot->getNonconstVectorBlock(1).ptr(), *(adjoint_x_dot));
652 RCP<VectorBase<Scalar> > x_dot_b = x_dot;
656 if (forward_state->getXDotDot() != Teuchos::null) {
657 x_dot_dot = Thyra::defaultProductVector(prod_space);
658 RCP<const VectorBase<Scalar> > adjoint_x_dot_dot =
659 rcp_dynamic_cast<const DPV>(
660 adjoint_state->getXDotDot())->getVectorBlock(0);
661 assign(x_dot_dot->getNonconstVectorBlock(0).ptr(),
662 *(forward_state->getXDotDot()));
663 assign(x_dot_dot->getNonconstVectorBlock(1).ptr(),
664 *(adjoint_x_dot_dot));
666 RCP<VectorBase<Scalar> > x_dot_dot_b = x_dot_dot;
668 RCP<SolutionState<Scalar> > prod_state = forward_state->clone();
669 prod_state->setX(x_b);
670 prod_state->setXDot(x_dot_b);
671 prod_state->setXDotDot(x_dot_dot_b);
672 prod_state->setPhysicsState(Teuchos::null);
673 solutionHistory_->addState(prod_state);
681 Teuchos::RCP<Teuchos::ParameterList> inputPL,
686 Teuchos::RCP<Teuchos::ParameterList> spl = Teuchos::parameterList();
687 if (inputPL != Teuchos::null)
688 *spl = inputPL->sublist(
"Sensitivities");
690 int p_index = spl->get<
int>(
"Sensitivity Parameter Index", 0);
691 int g_index = spl->get<
int>(
"Response Function Index", 0);
692 bool g_depends_on_p = spl->get<
bool>(
"Response Depends on Parameters",
true);
693 bool f_depends_on_p = spl->get<
bool>(
"Residual Depends on Parameters",
true);
694 bool ic_depends_on_p = spl->get<
bool>(
"IC Depends on Parameters",
true);
695 bool mass_matrix_is_identity = spl->get<
bool>(
"Mass Matrix Is Identity",
false);
697 auto state_integrator = createIntegratorBasic<Scalar>(inputPL, model);
700 if (spl->isParameter(
"Response Depends on Parameters"))
701 spl->remove(
"Response Depends on Parameters");
702 if (spl->isParameter(
"Residual Depends on Parameters"))
703 spl->remove(
"Residual Depends on Parameters");
704 if (spl->isParameter(
"IC Depends on Parameters"))
705 spl->remove(
"IC Depends on Parameters");
707 const Scalar tinit = state_integrator->getTimeStepControl()->getInitTime();
708 const Scalar tfinal = state_integrator->getTimeStepControl()->getFinalTime();
712 Teuchos::RCP<Thyra::ModelEvaluator<Scalar>> adjt_model = adjoint_model;
713 if (adjoint_model == Teuchos::null)
716 auto adjoint_aux_model = Teuchos::rcp(
new AdjointAuxSensitivityModelEvaluator<Scalar>(model, adjt_model, tinit, tfinal, spl));
720 auto integrator_name = inputPL->get<std::string>(
"Integrator Name");
721 auto integratorPL = Teuchos::sublist(inputPL, integrator_name,
true);
722 auto shPL = Teuchos::sublist(integratorPL,
"Solution History",
true);
723 auto combined_solution_History = createSolutionHistoryPL<Scalar>(shPL);
725 auto adjoint_integrator = createIntegratorBasic<Scalar>(inputPL, adjoint_aux_model);
727 Teuchos::RCP<IntegratorAdjointSensitivity<Scalar>> integrator = Teuchos::rcp(
new IntegratorAdjointSensitivity<Scalar>(
728 model, state_integrator, adjt_model, adjoint_aux_model, adjoint_integrator, combined_solution_History, p_index, g_index, g_depends_on_p,
729 f_depends_on_p, ic_depends_on_p, mass_matrix_is_identity));