Sacado Package Browser (Single Doxygen Collection) Version of the Day
Loading...
Searching...
No Matches
dfad_sfad_example.cpp
Go to the documentation of this file.
1// @HEADER
2// ***********************************************************************
3//
4// Sacado Package
5// Copyright (2006) Sandia Corporation
6//
7// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8// the U.S. Government retains certain rights in this software.
9//
10// This library is free software; you can redistribute it and/or modify
11// it under the terms of the GNU Lesser General Public License as
12// published by the Free Software Foundation; either version 2.1 of the
13// License, or (at your option) any later version.
14//
15// This library is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18// Lesser General Public License for more details.
19//
20// You should have received a copy of the GNU Lesser General Public
21// License along with this library; if not, write to the Free Software
22// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
23// USA
24// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
25// (etphipp@sandia.gov).
26//
27// ***********************************************************************
28// @HEADER
29
30// dfad_sfad_example
31//
32// usage:
33// dfad_sfad_example
34//
35// output:
36// prints the results of computing the second derivative times a vector
37// for a simple function with forward nested forward mode AD using the
38// Sacado::Fad::DFad and Sacado::Fad::SFad classes.
39
40#include <iostream>
41#include <iomanip>
42
43#include "Sacado.hpp"
44
// The function to differentiate:
//
//   r(a, b, c) = c * log(b + 1) / sin(a)
//
// Templated on the scalar type so the same expression can be evaluated with
// plain doubles (as main does) or with Sacado AD types (as func_and_deriv
// does) to obtain derivatives.
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) {
  return c * std::log(b + 1.) / std::sin(a);
}
51
// Closed-form first and second partial derivatives of
//
//   r = c * log(b + 1) / sin(a)
//
// with respect to a and b (c is held constant). main uses these values as
// the reference that the AD results are checked against.
//
// Outputs (written through the references):
//   drda    = dr/da
//   drdb    = dr/db
//   d2rda2  = d^2 r / da^2
//   d2rdb2  = d^2 r / db^2
//   d2rdadb = d^2 r / (da db)
void analytic_deriv(double a, double b, double c,
                    double& drda, double& drdb,
                    double& d2rda2, double& d2rdb2, double& d2rdadb)
{
  const double sin_a = std::sin(a);
  const double cos_a = std::cos(a);
  const double b1 = b + 1.;
  const double c_log_b1 = c*std::log(b1);  // recurring factor c*log(b+1)

  drda    = -(c_log_b1/std::pow(sin_a, 2.))*cos_a;
  drdb    = c / (b1*sin_a);
  d2rda2  = c_log_b1/sin_a
            + 2.*(c_log_b1/std::pow(sin_a, 3.))*std::pow(cos_a, 2.);
  d2rdb2  = -c / (std::pow(b1, 2.)*sin_a);
  d2rdadb = -c / (b1*std::pow(sin_a, 2.))*cos_a;
}
63
64// Function that computes func and its first derivative w.r.t a & b using
65// Sacado AD
66template <typename ScalarT>
67void func_and_deriv(const ScalarT& a, const ScalarT& b, const ScalarT& c,
68 ScalarT& r, ScalarT& drda, ScalarT& drdb) {
70 FadType a_fad(2, 0, a);
71 FadType b_fad(2, 1, b);
72 FadType c_fad = c;
73
74 FadType r_fad = func(a_fad, b_fad, c_fad);
75 r = r_fad.val();
76 drda = r_fad.dx(0);
77 drdb = r_fad.dx(1);
78}
79
80// Function that computes func, its first derivative w.r.t a & b, and its
81// second derivative in the direction of [v_a, v_b] with Sacado AD
82//
83// Define x = [a, b], v = [v_a, v_b], and y(t) = x + t*v. Then
84// df/dx*v = d/dt f(y(t)) |_{t=0}.
85//
86// In the code below, we differentiate with respect to t in this manner.
87// Addtionally we take a short-cut and don't introduce t directly and
88// compute a(t) = a + t*v_a, b(t) = b + t*v_b. Instead we
89// initialize a_fad and b_fad directly as if we had computed them in this way.
90template <typename ScalarT>
91void func_and_deriv2(const ScalarT& a, const ScalarT& b, const ScalarT& c,
92 const ScalarT& v_a, const ScalarT& v_b,
93 ScalarT& r, ScalarT& drda, ScalarT& drdb,
94 ScalarT& z_a, ScalarT& z_b) {
96
97 // The below is equivalent to:
98 // FadType t(1, 0.0); f_fad.fastAccessDx(0) = 1;
99 // FadType a_fad = a + t*v_a;
100 // FadType b_fad = b + t*v_b;
101 FadType a_fad(1, a); a_fad.fastAccessDx(0) = v_a;
102 FadType b_fad(1, b); b_fad.fastAccessDx(0) = v_b;
103 FadType c_fad = c;
104
105 FadType r_fad, drda_fad, drdb_fad;
106 func_and_deriv(a_fad, b_fad, c_fad, r_fad, drda_fad, drdb_fad);
107 r = r_fad.val(); // r
108 // note: also have r_fad.dx(0) = dr/da*v_a + dr/db*v_b
109 drda = drda_fad.val(); // dr/da
110 drdb = drdb_fad.val(); // dr/db
111 z_a = drda_fad.dx(0); // d^2r/da^2 * v_a + d^2r/dadb * v_b
112 z_b = drdb_fad.dx(0); // d^2r/dadb * v_a + d^2r/db^2 * v_b
113}
114
115int main(int argc, char **argv)
116{
117 double pi = std::atan(1.0)*4.0;
118
119 // Values of function arguments
120 double a = pi/4;
121 double b = 2.0;
122 double c = 3.0;
123
124 // Direction we wish to differentiate for second derivative
125 double v_a = 1.5;
126 double v_b = 3.6;
127
128 // Compute derivatives via AD
129 double r_ad, drda_ad, drdb_ad, z_a_ad, z_b_ad;
130 func_and_deriv2(a, b, c, v_a, v_b, r_ad, drda_ad, drdb_ad, z_a_ad, z_b_ad);
131
132 // Compute function
133 double r = func(a, b, c);
134
135 // Compute derivatives analytically
136 double drda, drdb, d2rda2, d2rdb2, d2rdadb;
137 analytic_deriv(a, b, c, drda, drdb, d2rda2, d2rdb2, d2rdadb);
138 double z_a = d2rda2*v_a + d2rdadb*v_b;
139 double z_b = d2rdadb*v_a + d2rdb2*v_b;
140
141 // Print the results
142 int p = 4;
143 int w = p+7;
144 std::cout.setf(std::ios::scientific);
145 std::cout.precision(p);
146 std::cout << " r = " << std::setw(w) << r << " (original) == "
147 << std::setw(w) << r_ad << " (AD) Error = " << std::setw(w)
148 << r - r_ad << std::endl
149 << "dr/da = " << std::setw(w) << drda << " (analytic) == "
150 << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w)
151 << drda - drda_ad << std::endl
152 << "dr/db = " << std::setw(w) << drdb << " (analytic) == "
153 << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w)
154 << drdb - drdb_ad << std::endl
155 << "z_a = " << std::setw(w) << z_a << " (analytic) == "
156 << std::setw(w) << z_a_ad << " (AD) Error = " << std::setw(w)
157 << z_a - z_a_ad << std::endl
158 << "z_b = " << std::setw(w) << z_b << " (analytic) == "
159 << std::setw(w) << z_b_ad << " (AD) Error = " << std::setw(w)
160 << z_b - z_b_ad << std::endl;
161
162 double tol = 1.0e-14;
163 if (std::fabs(r - r_ad) < tol &&
164 std::fabs(drda - drda_ad) < tol &&
165 std::fabs(drdb - drdb_ad) < tol &&
166 std::fabs(z_a - z_a_ad) < tol &&
167 std::fabs(z_b - z_b_ad) < tol) {
168 std::cout << "\nExample passed!" << std::endl;
169 return 0;
170 }
171 else {
172 std::cout <<"\nSomething is wrong, example failed!" << std::endl;
173 return 1;
174 }
175}
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
int main()
Definition: ad_example.cpp:191
Sacado::Fad::DFad< double > FadType
ScalarT func(const ScalarT &a, const ScalarT &b, const ScalarT &c)
void func_and_deriv2(const ScalarT &a, const ScalarT &b, const ScalarT &c, const ScalarT &v_a, const ScalarT &v_b, ScalarT &r, ScalarT &drda, ScalarT &drdb, ScalarT &z_a, ScalarT &z_b)
void func_and_deriv(const ScalarT &a, const ScalarT &b, const ScalarT &c, ScalarT &r, ScalarT &drda, ScalarT &drdb)
void analytic_deriv(double a, double b, double c, double &drda, double &drdb, double &d2rda2, double &d2rdb2, double &d2rdadb)
const char * p
const double tol