// dfad_sfc_example.cpp
// (Extracted from the Sacado Package Browser, Single Doxygen Collection,
// "Version of the Day".)
// @HEADER
// ***********************************************************************
//
//                           Sacado Package
//                 Copyright (2006) Sandia Corporation
//
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
// the U.S. Government retains certain rights in this software.
//
// This library is free software; you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation; either version 2.1 of the
// License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
// USA
// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
// (etphipp@sandia.gov).
//
// ***********************************************************************
// @HEADER
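
// dfad_sfc_example
//
//  usage:
//     dfad_sfc_example
//
//  output:
//     Uses the scalar flop counter (SFC) to count the flops needed to
//     evaluate a simple function, its hand-coded analytic derivative, and
//     its derivative computed with DFad, and prints the counts and results.
//     Exits with status 0 if the AD results match the analytic ones, 1 otherwise.
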
#include <cmath>
#include <iomanip>
#include <iostream>

#include "Sacado_No_Kokkos.hpp"

// The function to differentiate:
//
//   r(a, b, c) = c * log(b + 1) / sin(a)
//
// Templated on ScalarT so the same expression can be evaluated with plain
// doubles, with the scalar flop counter, or with a DFad whose value type is
// the flop counter (SFC and FAD_SFC in main()).
//
// The named temporary `r` is kept deliberately: with the flop-counting type,
// assignments are counted, so changing how the result is returned can change
// the reported flop totals.
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) {
  ScalarT r = c*std::log(b+1.)/std::sin(a);

  return r;
}

// The analytic derivative of func(a,b,c) with respect to a and b:
//
//   dr/da = -c * log(b + 1) * cos(a) / sin(a)^2
//   dr/db =  c / ((b + 1) * sin(a))
//
// The results are written to the output parameters drda and drdb.
// The expression form (including std::pow for sin(a)^2) is kept as-is
// because main() reports the flop count of this evaluation; rewriting the
// expressions would change those counts.
template <typename ScalarT>
void func_deriv(const ScalarT& a, const ScalarT& b, const ScalarT& c,
                ScalarT& drda, ScalarT& drdb)
{
  drda = -(c*std::log(b+1.)/std::pow(std::sin(a),2.))*std::cos(a);
  drdb = c / ((b+1.)*std::sin(a));
}
60
63
64int main(int argc, char **argv)
65{
66 double pi = std::atan(1.0)*4.0;
67
68 // Values of function arguments
69 double a = pi/4;
70 double b = 2.0;
71 double c = 3.0;
72
73 // Number of independent variables
74 int num_deriv = 2;
75
76 // Compute function
77 SFC as(a);
78 SFC bs(b);
79 SFC cs(c);
81 SFC rs = func(as, bs, cs);
83
84 std::cout << "Flop counts for function evaluation:";
85 SFC::printCounters(std::cout);
86
87 // Compute derivative analytically
88 SFC drdas, drdbs;
90 func_deriv(as, bs, cs, drdas, drdbs);
92
93 std::cout << "\nFlop counts for analytic derivative evaluation:";
94 SFC::printCounters(std::cout);
95
96 // Compute function and derivative with AD
97 FAD_SFC afad(num_deriv, 0, a);
98 FAD_SFC bfad(num_deriv, 1, b);
99 FAD_SFC cfad(c);
101 FAD_SFC rfad = func(afad, bfad, cfad);
103
104 std::cout << "\nFlop counts for AD function and derivative evaluation:";
105 SFC::printCounters(std::cout);
106
107 // Extract value and derivatives
108 double r = rs.val(); // r
109 double drda = drdas.val(); // dr/da
110 double drdb = drdbs.val(); // dr/db
111
112 double r_ad = rfad.val().val(); // r
113 double drda_ad = rfad.dx(0).val(); // dr/da
114 double drdb_ad = rfad.dx(1).val(); // dr/db
115
116 // Print the results
117 int p = 4;
118 int w = p+7;
119 std::cout.setf(std::ios::scientific);
120 std::cout.precision(p);
121 std::cout << "\nValues/derivatives of computation" << std::endl
122 << " r = " << r << " (original) == " << std::setw(w) << r_ad
123 << " (AD) Error = " << std::setw(w) << r - r_ad << std::endl
124 << "dr/da = " << std::setw(w) << drda << " (analytic) == "
125 << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w)
126 << drda - drda_ad << std::endl
127 << "dr/db = " << std::setw(w) << drdb << " (analytic) == "
128 << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w)
129 << drdb - drdb_ad << std::endl;
130
131 double tol = 1.0e-14;
133 // The Solaris and Irix CC compilers get higher counts for operator=
134 // than does g++, which avoids an extra copy when returning a function value.
135 // The test on fc.totalFlopCount allows for this variation.
136 if (std::fabs(r - r_ad) < tol &&
137 std::fabs(drda - drda_ad) < tol &&
138 std::fabs(drdb - drdb_ad) < tol &&
139 (fc.totalFlopCount == 40)) {
140 std::cout << "\nExample passed!" << std::endl;
141 return 0;
142 }
143 else {
144 std::cout <<"\nSomething is wrong, example failed!" << std::endl;
145 return 1;
146 }
147}
/*
 * Doxygen cross-reference tooltips captured with this source-browser page.
 * They are not part of the example program. A run of template-parameter
 * names (expr, expr1, expr2, c) is omitted. Some entries, e.g. "int main()",
 * "const char * p" and "const double tol", appear to belong to other files
 * in the same Doxygen collection.
 *
 *   int main()
 *   Fad specializations for Teuchos::BLAS wrappers.
 *   Class storing flop counts and summary flop counts.
 *   static std::ostream & printCounters(std::ostream &out)
 *       Print the current static flop counts to out.
 *   static FlopCounts getCounters()
 *       Get the flop counts after a block of computations.
 *   static void resetCounters()
 *       Reset static flop counters before starting a block of computations.
 *   const T & val() const
 *       Return the current value.
 *   static void finalizeCounters()
 *       Finalize total flop count after block of computations.
 *   Sacado::FlopCounterPack::ScalarFlopCounter< double > SFC
 *   void func_deriv(const ScalarT &a, const ScalarT &b, const ScalarT &c,
 *                   ScalarT &drda, ScalarT &drdb)
 *   ScalarT func(const ScalarT &a, const ScalarT &b, const ScalarT &c)
 *   Sacado::Fad::DFad< SFC > FAD_SFC
 *   const char * p
 *   const double tol
 */