Sacado Package Browser (Single Doxygen Collection) Version of the Day
Loading...
Searching...
No Matches
fad_lj_grad.cpp
Go to the documentation of this file.
1// $Id$
2// $Source$
3// @HEADER
4// ***********************************************************************
5//
6// Sacado Package
7// Copyright (2006) Sandia Corporation
8//
9// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
10// the U.S. Government retains certain rights in this software.
11//
12// This library is free software; you can redistribute it and/or modify
13// it under the terms of the GNU Lesser General Public License as
14// published by the Free Software Foundation; either version 2.1 of the
15// License, or (at your option) any later version.
16//
17// This library is distributed in the hope that it will be useful, but
18// WITHOUT ANY WARRANTY; without even the implied warranty of
19// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20// Lesser General Public License for more details.
21//
22// You should have received a copy of the GNU Lesser General Public
23// License along with this library; if not, write to the Free Software
24// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
25// USA
26// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
27// (etphipp@sandia.gov).
28//
29// ***********************************************************************
30// @HEADER
31
#include <exception>
#include <iomanip>
#include <iostream>

#include "Sacado_Random.hpp"
#include "Sacado_No_Kokkos.hpp"

#include "Fad/fad.h"
#include "TinyFadET/tfad.h"

#include "Teuchos_Time.hpp"
#include "Teuchos_CommandLineProcessor.hpp"
40
// A simple performance test that computes the gradient of a Lennard-Jones
// pair energy using many variants of Fad and compares each against a
// hand-coded analytic gradient.
43
44void FAD::error(const char *msg) {
45 std::cout << msg << std::endl;
46}
47
// File-local state shared by the kernels and timing drivers below.
namespace {
  double xi[3];    // first point; filled with random values in main
  double xj[3];    // second point; filled with random values in main
  double pa[4];    // potential parameters; lj/lj_and_grad read pa[1..3]
  double f[3];     // running sum of negated gradients over all timing runs
  double delr[3];  // scratch displacement xi-xj written by lj_and_grad
}
51
/// Squared Euclidean distance between two 3-component points.
///
/// @tparam T  scalar type of @p xi (used here with double and Fad types).
/// @param xi  first point, 3 entries of type T.
/// @param xj  second point, 3 plain doubles.
/// @return    (xi-xj)·(xi-xj), summed in component order 0, 1, 2.
template <typename T>
inline T
vec3_distsq(const T xi[], const double xj[]) {
  const T dx = xi[0] - xj[0];
  const T dy = xi[1] - xj[1];
  const T dz = xi[2] - xj[2];
  return dx*dx + dy*dy + dz*dz;
}
60
/// Squared Euclidean distance between two 3-component points that also
/// reports the displacement.
///
/// @tparam T    scalar type of @p xi and @p delr.
/// @param xi    first point, 3 entries of type T.
/// @param xj    second point, 3 plain doubles.
/// @param delr  output: receives xi[k]-xj[k] for k = 0, 1, 2.
/// @return      delr·delr, summed in component order 0, 1, 2.
template <typename T>
inline T
vec3_distsq(const T xi[], const double xj[], T delr[]) {
  for (int k = 0; k < 3; ++k)
    delr[k] = xi[k] - xj[k];
  return delr[0]*delr[0] + delr[1]*delr[1] + delr[2]*delr[2];
}
69
70template <typename T>
71inline void
72lj(const T xi[], const double xj[], T& energy) {
73 T delr2 = vec3_distsq(xi,xj);
74 T delr_2 = 1.0/delr2;
75 T delr_6 = delr_2*delr_2*delr_2;
76 energy = (pa[1]*delr_6 - pa[2])*delr_6 - pa[3];
77}
78
79inline void
80lj_and_grad(const double xi[], const double xj[], double& energy,
81 double f[]) {
82 double delr2 = vec3_distsq(xi,xj,delr);
83 double delr_2 = 1.0/delr2;
84 double delr_6 = delr_2*delr_2*delr_2;
85 energy = (pa[1]*delr_6 - pa[2])*delr_6 - pa[3];
86 double tmp = (-12.0*pa[1]*delr_6 - 6.0*pa[2])*delr_6*delr_2;
87 f[0] = delr[0]*tmp;
88 f[1] = delr[1]*tmp;
89 f[2] = delr[2]*tmp;
90}
91
92template <typename FadType>
93double
94do_time(int nloop)
95{
96 Teuchos::Time timer("lj", false);
97 FadType xi_fad[3], energy;
98
99 for (int i=0; i<3; i++) {
100 xi_fad[i] = FadType(3, i, xi[i]);
101 }
102
103 timer.start(true);
104 for (int j=0; j<nloop; j++) {
105
106 lj(xi_fad, xj, energy);
107
108 for (int i=0; i<3; i++)
109 f[i] += -energy.fastAccessDx(i);
110 }
111 timer.stop();
112
113 return timer.totalElapsedTime() / nloop;
114}
115
116double
118{
119 Teuchos::Time timer("lj", false);
120 double energy, ff[3];
121
122 timer.start(true);
123 for (int j=0; j<nloop; j++) {
124
125 lj_and_grad(xi, xj, energy, ff);
126
127 for (int i=0; i<3; i++)
128 f[i] += -ff[i];
129
130 }
131 timer.stop();
132
133 return timer.totalElapsedTime() / nloop;
134}
135
136int main(int argc, char* argv[]) {
137 int ierr = 0;
138
139 try {
140 double t, ta;
141 int p = 2;
142 int w = p+7;
143
144 // Set up command line options
145 Teuchos::CommandLineProcessor clp;
146 clp.setDocString("This program tests the speed of various forward mode AD implementations for a single multiplication operation");
147 int nloop = 1000000;
148 clp.setOption("nloop", &nloop, "Number of loops");
149
150 // Parse options
151 Teuchos::CommandLineProcessor::EParseCommandLineReturn
152 parseReturn= clp.parse(argc, argv);
153 if(parseReturn != Teuchos::CommandLineProcessor::PARSE_SUCCESSFUL)
154 return 1;
155
156 std::cout.setf(std::ios::scientific);
157 std::cout.precision(p);
158 std::cout << "Times (sec) nloop = " << nloop << ": " << std::endl;
159
160 Sacado::Random<double> urand(0.0, 1.0);
161 for (int i=0; i<3; i++) {
162 xi[i] = urand.number();
163 xj[i] = urand.number();
164 pa[i] = urand.number();
165 }
166 pa[3] = urand.number();
167
168 ta = do_time_analytic(nloop);
169 std::cout << "Analytic: " << std::setw(w) << ta << std::endl;
170
171 t = do_time< FAD::TFad<3,double> >(nloop);
172 std::cout << "TFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
173
174 t = do_time< FAD::Fad<double> >(nloop);
175 std::cout << "Fad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
176
177 t = do_time< Sacado::Fad::SFad<double,3> >(nloop);
178 std::cout << "SFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
179
180 t = do_time< Sacado::Fad::SLFad<double,3> >(nloop);
181 std::cout << "SLFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
182
183 t = do_time< Sacado::Fad::DFad<double> >(nloop);
184 std::cout << "DFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
185
186 t = do_time< Sacado::ELRFad::SFad<double,3> >(nloop);
187 std::cout << "ELRSFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
188
189 t = do_time< Sacado::ELRFad::SLFad<double,3> >(nloop);
190 std::cout << "ELRSLFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
191
192 t = do_time< Sacado::ELRFad::DFad<double> >(nloop);
193 std::cout << "ELRDFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
194
195 t = do_time< Sacado::CacheFad::DFad<double> >(nloop);
196 std::cout << "CacheFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
197
198 t = do_time< Sacado::Fad::DVFad<double> >(nloop);
199 std::cout << "DVFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
200
201 }
202 catch (std::exception& e) {
203 std::cout << e.what() << std::endl;
204 ierr = 1;
205 }
206 catch (const char *s) {
207 std::cout << s << std::endl;
208 ierr = 1;
209 }
210 catch (...) {
211 std::cout << "Caught unknown exception!" << std::endl;
212 ierr = 1;
213 }
214
215 return ierr;
216}
int main()
Sacado::Fad::DFad< double > FadType
A random number generator that generates random numbers uniformly distributed in the interval (a,b).
ScalarT number()
Get random number.
void lj(const T xi[], const double xj[], T &energy)
T vec3_distsq(const T xi[], const double xj[])
void lj_and_grad(const double xi[], const double xj[], double &energy, double f[])
double do_time(int nloop)
double do_time_analytic(int nloop)
const char * p