Sacado Package Browser (Single Doxygen Collection) Version of the Day
Loading...
Searching...
No Matches
dfad_view_handle_example.cpp
Go to the documentation of this file.
1// $Id$
2// $Source$
3// @HEADER
4// ***********************************************************************
5//
6// Sacado Package
7// Copyright (2006) Sandia Corporation
8//
9// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
10// the U.S. Government retains certain rights in this software.
11//
12// This library is free software; you can redistribute it and/or modify
13// it under the terms of the GNU Lesser General Public License as
14// published by the Free Software Foundation; either version 2.1 of the
15// License, or (at your option) any later version.
16//
17// This library is distributed in the hope that it will be useful, but
18// WITHOUT ANY WARRANTY; without even the implied warranty of
19// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20// Lesser General Public License for more details.
21//
22// You should have received a copy of the GNU Lesser General Public
23// License along with this library; if not, write to the Free Software
24// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
25// USA
26// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
27// (etphipp@sandia.gov).
28//
29// ***********************************************************************
30// @HEADER
31
// dfad_view_handle_example
//
// usage:
//  dfad_view_handle_example
//
// output:
//  prints the results of differentiating a simple function with forward
//  mode AD using the Sacado::Fad::DFad class (uses dynamic memory
//  allocation for the number of derivative components) and ViewFad as a
//  handle into externally stored derivative data
42
43#include <iostream>
44#include <iomanip>
45
46#include "Sacado.hpp"
47
// The function to differentiate:
//
//   r(a, b, c) = c * log(b + 1) / sin(a)
//
// ScalarRes is the result type. a and b share the type Scalar1 (so both may
// be AD types), while c has its own type Scalar2 (so it may be a passive
// constant or a different AD type).
template <typename ScalarRes, typename Scalar1, typename Scalar2>
ScalarRes func(const Scalar1& a, const Scalar1& b, const Scalar2& c)
{
  return c * std::log(b + 1.) / std::sin(a);
}
55
56// The analytic derivative of func(a,b,c) with respect to a and b
57void func_deriv(double a, double b, double c, double& drda, double& drdb)
58{
59 drda = -(c*std::log(b+1.)/std::pow(std::sin(a),2.))*std::cos(a);
60 drdb = c / ((b+1.)*std::sin(a));
61}
62
63int main(int argc, char **argv)
64{
65 Kokkos::initialize();
66 int ret = 0;
67 {
68
69 double pi = std::atan(1.0)*4.0;
70
71 // Values of function arguments
72 double a = pi/4;
73 double b = 2.0;
74 double c = 3.0;
75
76 // View to store derivative data
77 const int num_deriv = 2;
78 Kokkos::View<double**,Kokkos::LayoutLeft,Kokkos::HostSpace> v( "v", 2, num_deriv );
79
80 // Initialize derivative data
81 Kokkos::deep_copy( v, 0.0 );
82 v(0,0) = 1.0; // First (0) indep. var
83 v(1,1) = 1.0; // Second (1) indep. var
84
85 // The Fad type
87
88 // View handle type -- first 0 is static length (e.g., SFad), second 0
89 // is static stride, which you can make 1 if you know the View will be
90 // LayoutRight (e.g., not GPU). When values are 0, they are treated
91 // dynamically
93
94 // Fad objects
95 ViewFadType afad( &v(0,0), &a, num_deriv, v.stride_1() );
96 ViewFadType bfad( &v(1,0), &b, num_deriv, v.stride_1() );
97 FadType cfad(c);
98 FadType rfad;
99
100 // Compute function
101 double r = func<double>(a, b, c);
102
103 // Compute derivative analytically
104 double drda, drdb;
105 func_deriv(a, b, c, drda, drdb);
106
107 // Compute function and derivative with AD
108 rfad = func<FadType>(afad, bfad, cfad);
109
110 // Extract value and derivatives
111 double r_ad = rfad.val(); // r
112 double drda_ad = rfad.dx(0); // dr/da
113 double drdb_ad = rfad.dx(1); // dr/db
114
115 // Print the results
116 int p = 4;
117 int w = p+7;
118 std::cout.setf(std::ios::scientific);
119 std::cout.precision(p);
120 std::cout << " r = " << r << " (original) == " << std::setw(w) << r_ad
121 << " (AD) Error = " << std::setw(w) << r - r_ad << std::endl
122 << "dr/da = " << std::setw(w) << drda << " (analytic) == "
123 << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w)
124 << drda - drda_ad << std::endl
125 << "dr/db = " << std::setw(w) << drdb << " (analytic) == "
126 << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w)
127 << drdb - drdb_ad << std::endl;
128
129 double tol = 1.0e-14;
130 if (std::fabs(r - r_ad) < tol &&
131 std::fabs(drda - drda_ad) < tol &&
132 std::fabs(drdb - drdb_ad) < tol) {
133 std::cout << "\nExample passed!" << std::endl;
134 ret = 0;
135 }
136 else {
137 std::cout <<"\nSomething is wrong, example failed!" << std::endl;
138 ret = 1;
139 }
140
141 }
142 Kokkos::finalize();
143 return ret;
144}
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
int main()
Sacado::Fad::DFad< double > FadType
Fad specializations for Teuchos::BLAS wrappers.
void func_deriv(double a, double b, double c, double &drda, double &drdb)
ScalarRes func(const Scalar1 &a, const Scalar1 &b, const Scalar2 &c)
const char * p
const double tol