ROL
ROL_KLDivergence.hpp
Go to the documentation of this file.
1// @HEADER
2// ************************************************************************
3//
4// Rapid Optimization Library (ROL) Package
5// Copyright (2014) Sandia Corporation
6//
7// Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
8// license for use of this work by or on behalf of the U.S. Government.
9//
10// Redistribution and use in source and binary forms, with or without
11// modification, are permitted provided that the following conditions are
12// met:
13//
14// 1. Redistributions of source code must retain the above copyright
15// notice, this list of conditions and the following disclaimer.
16//
17// 2. Redistributions in binary form must reproduce the above copyright
18// notice, this list of conditions and the following disclaimer in the
19// documentation and/or other materials provided with the distribution.
20//
21// 3. Neither the name of the Corporation nor the names of the
22// contributors may be used to endorse or promote products derived from
23// this software without specific prior written permission.
24//
25// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36//
37// Questions? Contact lead developers:
38// Drew Kouri (dpkouri@sandia.gov) and
39// Denis Ridzal (dridzal@sandia.gov)
40//
41// ************************************************************************
42// @HEADER
43
44#ifndef ROL_KLDIVERGENCE_HPP
45#define ROL_KLDIVERGENCE_HPP
46
48
78namespace ROL {
79
80template<class Real>
81class KLDivergence : public RandVarFunctional<Real> {
82private:
83 Real eps_;
84
85 Real gval_;
86 Real gvval_;
87 Real hval_;
88 ROL::Ptr<Vector<Real> > scaledGradient_;
89 ROL::Ptr<Vector<Real> > scaledHessVec_;
90
92
93 using RandVarFunctional<Real>::val_;
94 using RandVarFunctional<Real>::gv_;
95 using RandVarFunctional<Real>::g_;
96 using RandVarFunctional<Real>::hv_;
98
99 using RandVarFunctional<Real>::point_;
100 using RandVarFunctional<Real>::weight_;
101
106
// Validate the stored threshold eps_: it must be strictly positive.
// When eps_ <= 0 the condition, std::invalid_argument and the message below
// are handed to ROL_TEST_FOR_EXCEPTION (presumably raising that exception --
// the macro is defined outside this listing).
107 void checkInputs(void) const {
108 Real zero(0);
109 ROL_TEST_FOR_EXCEPTION((eps_ <= zero), std::invalid_argument,
110 ">>> ERROR (ROL::KLDivergence): Threshold must be positive!");
111 }
112
113public:
// Constructor taking the threshold directly.
//   eps  threshold stored in eps_ (default 1.e-2); checked by checkInputs(),
//        which rejects non-positive values.
// firstResetKLD_ starts as true so that the first initialize() call clones
// the scaledGradient_ / scaledHessVec_ work vectors.
118 KLDivergence(const Real eps = 1.e-2)
119 : RandVarFunctional<Real>(), eps_(eps), firstResetKLD_(true) {
120 checkInputs();
121 }
122
// Constructor reading the threshold from a parameter list.
//   parlist  the value is looked up under the nested sublists
//            "SOL" -> "Risk Measure" -> "KL Divergence", key "Threshold".
// The value read is stored in eps_ and then checked by checkInputs(), which
// rejects non-positive values.
131 KLDivergence(ROL::ParameterList &parlist)
132 : RandVarFunctional<Real>(), firstResetKLD_(true) {
133 ROL::ParameterList &list
134 = parlist.sublist("SOL").sublist("Risk Measure").sublist("KL Divergence");
135 eps_ = list.get<Real>("Threshold");
136 checkInputs();
137 }
138
// Initialize temporary variables.
//   x  vector whose dual() is cloned to size the two work vectors.
// On the first call the work vectors scaledGradient_ and scaledHessVec_ are
// allocated; on every call the scalar accumulators and both work vectors are
// reset to zero.
139 void initialize(const Vector<Real> &x) {
// NOTE(review): this listing jumps from line 139 to line 141, so a statement
// may be missing here -- confirm against the original header.
141 if ( firstResetKLD_ ) {
// One-time allocation of the work vectors.
142 scaledGradient_ = x.dual().clone();
143 scaledHessVec_ = x.dual().clone();
144 firstResetKLD_ = false;
145 }
146 const Real zero(0);
// Clear the accumulators that the update*() methods add into.
147 gval_ = zero; gvval_ = zero; hval_ = zero;
148 scaledGradient_->zero(); scaledHessVec_->zero();
149 }
150
// updateValue: update internal storage for value computation.
// NOTE(review): the opening line of this signature (listing line 151, which
// should carry the method name and the `obj` parameter used below) is absent
// from this listing; only the remaining parameters are shown.
//   x      point passed on to computeValue()
//   xstat  xstat[0] is multiplied by eps_ to form the exponent scaling
//   tol    tolerance passed on to computeValue()
152 const Vector<Real> &x,
153 const std::vector<Real> &xstat,
154 Real &tol) {
155 Real val = computeValue(obj,x,tol);
// Mathematically exp(val * xstat[0] * eps_); exponential() clamps overflow.
156 Real ev = exponential(val,xstat[0]*eps_);
// Accumulate weight_ * ev into the running sum val_.
157 val_ += weight_ * ev;
158 }
159
// Return risk measure value.
//   x        not read in this body
//   xstat    xstat[0] is the divisor of the returned expression
//   sampler  used to sum the accumulated val_ via sumAll()
// Returns ROL_INF<Real>() when xstat[0] is exactly zero; otherwise
// (1 + log(ev)/eps_) / xstat[0], where ev is the summed val_.
160 Real getValue(const Vector<Real> &x,
161 const std::vector<Real> &xstat,
162 SampleGenerator<Real> &sampler) {
// Guard the division below; note this returns before sumAll() is called.
163 if ( xstat[0] == static_cast<Real>(0) ) {
164 return ROL_INF<Real>();
165 }
166 Real ev(0);
167 sampler.sumAll(&val_,&ev,1);
168 return (static_cast<Real>(1) + std::log(ev)/eps_)/xstat[0];
169 }
170
// updateGradient: update internal risk measure storage for gradient
// computation.
// NOTE(review): the opening line of this signature (listing line 171, which
// should carry the method name and the `obj` parameter used below) is absent
// from this listing; only the remaining parameters are shown.
//   x      point passed on to computeValue() / computeGradient()
//   xstat  xstat[0] is multiplied by eps_ to form the exponent scaling
//   tol    tolerance passed on to the compute*() helpers
172 const Vector<Real> &x,
173 const std::vector<Real> &xstat,
174 Real &tol) {
175 Real val = computeValue(obj,x,tol);
// Mathematically exp(val * xstat[0] * eps_); exponential() clamps overflow.
176 Real ev = exponential(val,xstat[0]*eps_);
// Scalar accumulators: sum of weight_*ev and of weight_*ev*val.
177 val_ += weight_ * ev;
178 gval_ += weight_ * ev * val;
// dualVector_ receives the gradient, which is added into g_ scaled by weight_*ev.
179 computeGradient(*dualVector_,obj,x,tol);
180 g_->axpy(weight_*ev,*dualVector_);
181 }
182
// getGradient: return risk measure (sub)gradient.
// NOTE(review): the opening line of this signature (listing line 183, which
// should carry the method name and the output vector `g` used below) is
// absent from this listing; only the remaining parameters are shown.
//   gstat    output; gstat[0] receives the derivative with respect to xstat[0]
//   x        not read in this body
//   xstat    xstat[0] enters the gstat[0] formula
//   sampler  used to sum the accumulated scalars and g_ via sumAll()
184 std::vector<Real> &gstat,
185 const Vector<Real> &x,
186 const std::vector<Real> &xstat,
187 SampleGenerator<Real> &sampler) {
// Sum both scalar accumulators in a single sumAll() call.
188 std::vector<Real> local(2), global(2);
189 local[0] = val_;
190 local[1] = gval_;
191 sampler.sumAll(&local[0],&global[0],2);
192 Real ev = global[0], egval = global[1];
193
// g = (summed g_) / ev.
194 sampler.sumAll(*g_,g);
195 g.scale(static_cast<Real>(1)/ev);
196
// xstat[0] == 0 would divide by zero below, so ROL_INF is reported instead.
197 if ( xstat[0] == static_cast<Real>(0) ) {
198 gstat[0] = ROL_INF<Real>();
199 }
200 else {
// gstat[0] = -((1 + log(ev)/eps_)/xstat[0] - egval/ev) / xstat[0].
201 gstat[0] = -((static_cast<Real>(1) + std::log(ev)/eps_)/xstat[0]
202 - egval/ev)/xstat[0];
203 }
204 }
205
// updateHessVec: update internal risk measure storage for
// Hessian-times-a-vector computation.
// NOTE(review): the opening line of this signature (listing line 206, which
// should carry the method name and the `obj` parameter used below) is absent
// from this listing; only the remaining parameters are shown.
//   v      direction passed on to computeGradVec() / computeHessVec()
//   vstat  not read in this body
//   x      point passed on to the compute*() helpers
//   xstat  xstat[0] is multiplied by eps_ to form the exponent scaling
//   tol    tolerance passed on to the compute*() helpers
207 const Vector<Real> &v,
208 const std::vector<Real> &vstat,
209 const Vector<Real> &x,
210 const std::vector<Real> &xstat,
211 Real &tol) {
212 Real val = computeValue(obj,x,tol);
// Mathematically exp(val * xstat[0] * eps_); exponential() clamps overflow.
213 Real ev = exponential(val,xstat[0]*eps_);
// After this call dualVector_ is used below as the gradient; gv is the scalar
// returned by computeGradVec().
214 Real gv = computeGradVec(*dualVector_,obj,v,x,tol);
// Scalar accumulators, each scaled by weight_ * ev.
215 val_ += weight_ * ev;
216 gv_ += weight_ * ev * gv;
217 gval_ += weight_ * ev * val;
218 gvval_ += weight_ * ev * val * gv;
219 hval_ += weight_ * ev * val * val;
// Vector accumulators built from the same dualVector_ with different scalings.
220 g_->axpy(weight_*ev,*dualVector_);
221 scaledGradient_->axpy(weight_*ev*gv,*dualVector_);
222 scaledHessVec_->axpy(weight_*ev*val,*dualVector_);
// dualVector_ is overwritten here, so it must be consumed above first.
223 computeHessVec(*dualVector_,obj,v,x,tol);
224 hv_->axpy(weight_*ev,*dualVector_);
225 }
226
// getHessVec: return risk measure Hessian-times-a-vector.
// NOTE(review): the opening line of this signature (listing line 227, which
// should carry the method name and the output vector `hv` used below) is
// absent from this listing; only the remaining parameters are shown.
//   hvstat   output; hvstat[0] receives the statistic component
//   v        not read in this body
//   vstat    vstat[0] scales several terms below
//   x        not read in this body
//   xstat    xstat[0] scales several terms below
//   sampler  used to sum the accumulated scalars and vectors via sumAll()
228 std::vector<Real> &hvstat,
229 const Vector<Real> &v,
230 const std::vector<Real> &vstat,
231 const Vector<Real> &x,
232 const std::vector<Real> &xstat,
233 SampleGenerator<Real> &sampler) {
// Sum all five scalar accumulators in a single sumAll() call.
234 std::vector<Real> local(5), global(5);
235 local[0] = val_;
236 local[1] = gv_;
237 local[2] = gval_;
238 local[3] = gvval_;
239 local[4] = hval_;
240 sampler.sumAll(&local[0],&global[0],5);
241 Real ev = global[0], egv = global[1], egval = global[2];
242 Real egvval = global[3], ehval = global[4];
// c0 = 1/ev, c1 = egval/ev, c2 = egv/ev, c3 = eps_/ev.
243 Real c0 = static_cast<Real>(1)/ev, c1 = c0*egval, c2 = c0*egv, c3 = eps_*c0;
244
245 sampler.sumAll(*hv_,hv);
246 dualVector_->zero();
// NOTE(review): the listing jumps from line 246 to 248; as shown, dualVector_
// is still zero when added below. A statement filling it (presumably a
// sumAll of one of the scaled work vectors) appears to be missing -- confirm
// against the original header.
248 hv.axpy(xstat[0]*eps_,*dualVector_);
249 hv.scale(c0);
250
// Subtract the summed g_ term, scaled by c3*(vstat[0]*c1 + xstat[0]*c2).
251 dualVector_->zero();
252 sampler.sumAll(*g_,*dualVector_);
253 hv.axpy(-c3*(vstat[0]*c1 + xstat[0]*c2),*dualVector_);
254
255 dualVector_->zero();
// NOTE(review): the listing jumps from line 255 to 257; as above, the
// statement that fills dualVector_ appears to be missing -- confirm.
257 hv.axpy(vstat[0]*c3,*dualVector_);
258
// xstat[0] == 0 would divide by zero below, so ROL_INF is reported instead.
259 if ( xstat[0] == static_cast<Real>(0) ) {
260 hvstat[0] = ROL_INF<Real>();
261 }
262 else {
// xstat2 = 2 / xstat[0]^2; h11 multiplies vstat[0] in hvstat[0].
263 Real xstat2 = static_cast<Real>(2)/(xstat[0]*xstat[0]);
264 Real h11 = xstat2*((static_cast<Real>(1) + std::log(ev)/eps_)/xstat[0] - c1)
265 + (c3*ehval - eps_*c1*c1)/xstat[0];
266 hvstat[0] = vstat[0] * h11 + (c3*egvval - eps_*c1*c2);
267 }
268 }
269
270private:
// Evaluate exp(arg1 * arg2) as pow(exp(smaller), larger): the smaller of the
// two arguments is exponentiated first and the result is raised to the
// larger one. Both steps go through the clamping helpers exponential(arg)
// and power(), which return ROL_INF<Real>() once their threshold is reached.
271 Real exponential(const Real arg1, const Real arg2) const {
272 if ( arg1 < arg2 ) {
273 return power(exponential(arg1),arg2);
274 }
275 else {
276 return power(exponential(arg2),arg1);
277 }
278 }
279
// Clamped exponential: returns ROL_INF<Real>() when arg is at or above
// log(ROL_INF<Real>()), i.e. when exp(arg) would reach ROL_INF; otherwise
// returns std::exp(arg).
280 Real exponential(const Real arg) const {
281 if ( arg >= std::log(ROL_INF<Real>()) ) {
282 return ROL_INF<Real>();
283 }
284 else {
285 return std::exp(arg);
286 }
287 }
288
// Clamped power: returns ROL_INF<Real>() when arg is at or above
// ROL_INF^(1/pow), i.e. when arg^pow would reach ROL_INF; otherwise returns
// std::pow(arg,pow).
// NOTE(review): the threshold is written for a positive exponent. Assuming
// ROL_INF<Real>() is a large positive value, a negative `pow` makes
// ROL_INF^(1/pow) a tiny positive number, so almost any positive `arg` takes
// the ROL_INF branch even though arg^pow would be small. Confirm that callers
// never pass a negative exponent (in this listing the exponent comes from
// exponential(val, xstat[0]*eps_)).
289 Real power(const Real arg, const Real pow) const {
290 if ( arg >= std::pow(ROL_INF<Real>(),static_cast<Real>(1)/pow) ) {
291 return ROL_INF<Real>();
292 }
293 else {
294 return std::pow(arg,pow);
295 }
296 }
297};
298
299}
300
301#endif
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0_ zero()
Provides an interface for the Kullback-Leibler distributionally robust expectation.
void checkInputs(void) const
void updateValue(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal storage for value computation.
void getHessVec(Vector< Real > &hv, std::vector< Real > &hvstat, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure Hessian-times-a-vector.
KLDivergence(ROL::ParameterList &parlist)
Constructor.
void updateHessVec(Objective< Real > &obj, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for Hessian-times-a-vector computation.
void getGradient(Vector< Real > &g, std::vector< Real > &gstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure (sub)gradient.
Real getValue(const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure value.
KLDivergence(const Real eps=1.e-2)
Constructor.
Real power(const Real arg, const Real pow) const
void initialize(const Vector< Real > &x)
Initialize temporary variables.
Real exponential(const Real arg) const
Real exponential(const Real arg1, const Real arg2) const
void updateGradient(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for gradient computation.
ROL::Ptr< Vector< Real > > scaledGradient_
ROL::Ptr< Vector< Real > > scaledHessVec_
Provides the interface to evaluate objective functions.
Provides the interface to implement any functional that maps a random variable to a (extended) real number.
Real computeValue(Objective< Real > &obj, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > g_
virtual void initialize(const Vector< Real > &x)
Initialize temporary variables.
void computeHessVec(Vector< Real > &hv, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > hv_
void computeGradient(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > dualVector_
Real computeGradVec(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
void sumAll(Real *input, Real *output, int dim) const
Defines the linear algebra or vector space interface.
Definition: ROL_Vector.hpp:84
virtual void scale(const Real alpha)=0
Compute \(y \leftarrow \alpha y\) where \(y = \mathtt{*this}\).
virtual const Vector & dual() const
Return dual representation of \(\mathtt{*this}\), for example, the result of applying a Riesz map, or change of basis, or change of memory layout.
Definition: ROL_Vector.hpp:226
virtual ROL::Ptr< Vector > clone() const =0
Clone to make a new (uninitialized) vector.
virtual void axpy(const Real alpha, const Vector &x)
Compute \(y \leftarrow \alpha x + y\) where \(y = \mathtt{*this}\).
Definition: ROL_Vector.hpp:153