STK++ 0.9.13
STK_CG.h
Go to the documentation of this file.
1 /*--------------------------------------------------------------------*/
2/* Copyright (C) 2013-2015 Serge Iovleff, Quentin Grimonprez
3
4 This program is free software; you can redistribute it and/or modify
5 it under the terms of the GNU Lesser General Public License as
6 published by the Free Software Foundation; either version 2 of the
7 License, or (at your option) any later version.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU Lesser General Public License for more details.
13
14 You should have received a copy of the GNU Lesser General Public
15 License along with this program; if not, write to the
16 Free Software Foundation, Inc.,
17 59 Temple Place,
18 Suite 330,
19 Boston, MA 02111-1307
20 USA
21
22 Contact : S..._DOT_I..._AT_stkpp.org (see copyright for ...)
23*/
24
25/*
26 * Project: stkpp::Algebra
27 * created on: 24 mai 2013
28 * Author: Quentin Grimonprez, Serge Iovleff
29 **/
30
36#ifndef STK_CG_H
37#define STK_CG_H
38
39#include <Sdk.h>
40
41namespace STK
42{
43
// Default initialization functor: calling it yields a default-constructed ColVector.
// It is the default value of the InitFunctor template parameter of the solver classes
// declared further down in this file.
// NOTE(review): listing line 45 (presumably the `struct DefaultFunctor` head, since the
// solvers refer to `DefaultFunctor<ColVector>`) is absent from this extract.
44template<class ColVector>
46{
// Returns a default-constructed ColVector; takes no argument and reads no state.
47 ColVector operator()() const
48 { return ColVector();}
49};
50
// Conjugate gradient solver. Starting from an initial x_, cg() repeatedly updates x_
// along successive search directions until r_.norm2()/b.norm2() < eps_, where
// r_ = b - mult(x_).
// Template parameters:
//   MultFunctor - callable as mult(v); its result is subtracted from b to form the residual.
//   ColVector   - vector type; must expose a nested Type and the members used by cg():
//                 norm2(), dot(), move(), exchange(), operator[] and the arithmetic operators.
//   InitFunctor - callable with no argument, returning the starting value of x_ (optional).
// NOTE(review): several listing lines of this class (98, 101, 115-117, 142, 154, 209,
// 213, 215, 217) are absent from this extract; each gap is flagged where it occurs.
85template<class MultFunctor, class ColVector, class InitFunctor = DefaultFunctor<ColVector> >
86class CG
87{
88 public:
// Scalar type of the vectors, taken from ColVector.
89 typedef typename ColVector::Type Type;
// Default constructor: empty vectors, zero tolerance, all functor/vector pointers null.
// setB() and setMultFunctor() must be called before run(), since cg() dereferences
// p_b_ and p_mult_ unconditionally.
91 CG(): x_(), r_(), eps_(0.), iter_(0), nbStart_(0), p_mult_(0), p_init_(0), p_b_(0) {}
// Constructor from the multiplication functor, the right-hand side b, an optional
// initialization functor and a tolerance.
// NOTE(review): listing lines 98 and 101 (the constructor head with the `mult` parameter,
// and the `eps` parameter with the closing parenthesis) are missing here; the member index
// later in this file gives the full signature as
// CG(MultFunctor const& mult, ColVector const& b, InitFunctor* const& p_init=0,
//    Type eps=Arithmetic<Type>::epsilon()).
99 , ColVector const& b
100 , InitFunctor* const& p_init =0
102 : x_(), r_()
103 , eps_(eps), iter_(0), nbStart_(0)
// Only the addresses of mult and b are stored; no copy of either object is made.
104 , p_mult_(&mult)
105 , p_init_(p_init)
106 , p_b_(&b)
107 {}
108
// Copy constructor. The visible initializers copy x_, r_ and p_b_.
// NOTE(review): listing lines 115-117 (presumably the remaining members) are missing.
112 CG( CG const& cg)
113 : x_(cg.x_)
114 , r_(cg.r_)
118 , p_b_(cg.p_b_)
119 {}
// Destructor: empty body; nothing is released here.
121 ~CG() {}
// Clone pattern: returns a copy of this solver allocated with new.
123 CG* clone() const { return new CG(*this);}
124
// Current solution vector x_.
126 inline ColVector const& x() const { return x_;}
// Element i of the solution vector.
// NOTE(review): the return type is Real rather than Type — confirm this is intended.
128 inline Real const& x(int const& i) const { return x_[i];}
// Number of iterations counted by the last call to cg().
130 inline int const& iter() const { return iter_;}
// Value of nbStart_.
// NOTE(review): in the visible lines cg() counts its restarts in a local variable and
// never assigns nbStart_, so this appears to keep the value set at construction.
132 inline int const& nbStart() const { return nbStart_;}
// Current residual vector r_.
134 inline ColVector const& r() const { return r_;}
// Set the tolerance used in the relative residual test.
136 inline void setEps(Type const& eps) {eps_ = eps;}
// Set the right-hand side; only the address of b is stored.
138 inline void setB(ColVector const& b) { p_b_=&b;}
// Set the functor computing the product with a vector; only its address is stored.
// NOTE(review): listing line 142 (setInitFunctor, per the member index) is missing.
140 inline void setMultFunctor(MultFunctor const& mult) { p_mult_= &mult; }
// Run the solver; forwards to cg() and returns its result (the iteration count).
146 int run() { return cg();}
// Last error message.
// NOTE(review): no visible line of this class writes msg_error_.
150 inline String const& error() const { return msg_error_;}
151
152 protected:
// NOTE(review): listing line 154 (`String msg_error_`, per the member index) is missing.
// Conjugate gradient implementation.
// Resets iter_, initializes x_, then performs up to 3 passes of the outer loop; each pass
// recomputes the residual from the current x_ and runs the inner iteration.
// Returns the total number of inner iterations performed (iter_).
156 int cg()
157 {
158 iter_ = 0;
// local pass counter (distinct from the member nbStart_)
159 int nbStart = 0;
160 ColVector xOld, z, p_;
161
162 // initialization: start from b when no init functor is set, otherwise from its result
163 if(!p_init_) {x_ = *p_b_;}
164 else { x_ = (*p_init_)();}
165 while(nbStart<3)
166 {
167 int step = 0; //number of step
// NOTE(review): norm2() is presumably the squared Euclidean norm (alpha below then
// matches the usual conjugate gradient step length) — confirm against ColVector.
168 Real bnorm2 = p_b_->norm2(), alpha, beta; //
// avoid dividing by zero in the relative residual test when b is null
169 if (bnorm2 == 0.) bnorm2 = 1.;
170 //compute the residuals
171 r_ = *p_b_ - (*p_mult_)(x_);
// already within tolerance: leave the outer loop
172 if (r_.norm2()/bnorm2 <eps_) { break;}
173 //initialization of the conjugate direction
174 p_= r_;
175 while(1)
176 {
177 Real rnorm2 = r_.norm2();
178 //compute z=A p
// NOTE(review): move() is defined outside this file; presumably z takes over the result.
179 z.move((*p_mult_)(p_));
180 //compute alpha
181 alpha = rnorm2/p_.dot(z);
182 //update x_
// NOTE(review): exchange() is defined outside this file; presumably it swaps the two
// vectors, leaving the previous iterate in xOld.
183 xOld.exchange(x_);
184 x_ = xOld + alpha * p_;
185 iter_++;
186 //update residuals
187 r_ = r_ - (alpha * z);
188 //compute beta = (new r_.norm2()) / (previous r_.norm2())
189 beta = 1/rnorm2;
// converged: setting nbStart to 2 makes the increment below end the outer loop
190 if ((rnorm2=r_.norm2())/bnorm2 <eps_) { nbStart = 2; break;}
191 beta *= rnorm2;
192 //update p_
193 p_ = (p_ * beta) + r_;
194 step++;
// leave the inner loop once more than 50 steps were taken; the outer loop then
// restarts from the current x_
195 if( step > 50 ) { break;}
196 }
197 nbStart++;
198 }
199 // return the number of iterations performed
200 return iter_;
201 }
202
203 private:
// solution of the system
205 ColVector x_;
// residuals of the system
207 ColVector r_;
// number of iterations
// NOTE(review): listing lines 209, 213, 215 and 217 (eps_, nbStart_, p_mult_, p_init_,
// per the member index) are missing from this extract.
211 int iter_;
// pointer to the right-hand side of the system (not owned by this class)
219 ColVector const* p_b_;
220};
221
// Preconditioned conjugate gradient solver. Same scheme as a conjugate gradient, except
// that the search direction is built from y = cond(r_) instead of the residual r_ itself.
// Template parameters:
//   MultFunctor - callable as mult(v); its result is subtracted from b to form the residual.
//   CondFunctor - callable as cond(r); applied to the residual before each direction update.
//   ColVector   - vector type; must expose a nested Type and the members used by pcg():
//                 norm2(), dot(), move(), exchange(), operator[] and the arithmetic operators.
//   InitFunctor - callable with no argument, returning the starting value of x_ (optional).
// NOTE(review): several listing lines of this class (276-279, 288-290, 323, 382-388) are
// absent from this extract; each gap is flagged where it occurs.
263template<class MultFunctor, class CondFunctor, class ColVector, class InitFunctor = DefaultFunctor<ColVector> >
264class PCG
265{
266 public:
// Scalar type of the vectors, taken from ColVector.
267 typedef typename ColVector::Type Type;
// Constructor from the multiplication functor, the conditioning functor, the right-hand
// side b, an optional initialization functor and a tolerance (defaults to
// Arithmetic<Type>::epsilon()).
// NOTE(review): listing lines 276-279 (presumably the initializers of eps_, iter_,
// p_mult_, p_cond_ and p_init_) are missing; only x_, r_ and p_b_ are visible.
274 PCG( MultFunctor const& mult, CondFunctor const& cond, ColVector const& b, InitFunctor* const& p_init =0, Type eps=Arithmetic<Type>::epsilon())
275 : x_(), r_()
// Only the address of b is stored; no copy of the vector is made.
280 , p_b_(&b)
281 {};
// Copy constructor: copies x_, r_, eps_ and p_b_, and sets the iteration counter to 0.
// NOTE(review): listing lines 288-290 (presumably the functor pointers) are missing.
285 PCG( PCG const& pcg)
286 : x_(pcg.x_), r_(pcg.r_)
287 , eps_(pcg.eps_), iter_(0)
291 , p_b_(pcg.p_b_)
292 {};
// Destructor: empty body; nothing is released here.
294 ~PCG() {};
// Clone pattern: returns a copy of this solver allocated with new.
296 PCG* clone() const { return new PCG(*this);}
297
// Current solution vector x_.
299 inline ColVector const& x() const { return x_;}
// Element i of the solution vector.
// NOTE(review): the return type is Real rather than Type — confirm this is intended.
301 inline Real const& x(int const& i) const { return x_[i];}
// Current residual vector r_.
303 inline ColVector const& r() const { return r_;}
// Set the tolerance used in the relative residual test.
305 inline void setEps(Type const& eps) {eps_ = eps;}
// Set the right-hand side; only the address of b is stored.
307 inline void setB(ColVector const& b) { p_b_=&b;}
// Set the functor computing x at initialization; only its address is stored.
309 inline void setInitFunctor(InitFunctor const& init) { p_init_= &init; }
// Set the functor computing the product with a vector; only its address is stored.
311 inline void setMultFunctor(MultFunctor const& mult) { p_mult_= &mult; }
// Set the conditioning functor; only its address is stored.
313 inline void setCondFunctor(CondFunctor const& cond) { p_cond_= &cond; }
// Run the solver; forwards to pcg() and returns its result (the iteration count).
315 inline int run() { return pcg();}
// Last error message.
// NOTE(review): no visible line of this class writes msg_error_.
319 inline String const& error() const { return msg_error_;}
320
321 protected:
// NOTE(review): listing line 323 (`String msg_error_`, per the member index) is missing.
// Preconditioned conjugate gradient implementation.
// Resets iter_, initializes x_, then iterates until r_.norm2()/b.norm2() < eps_.
// Returns the number of inner iterations performed (iter_).
325 int pcg()
326 {
327 int nbStart = 0;
328 ColVector xOld, y, z, p;
329
// NOTE(review): norm2() is presumably the squared Euclidean norm — confirm against ColVector.
330 Real bnorm2 = p_b_->norm2(), alpha, beta; //
331 iter_= 0;
332 // initialization: start from b when no init functor is set, otherwise from its result
333 if(!p_init_) {x_ = *p_b_;}
334 else { x_ = (*p_init_)();}
335 // first loop -> allow to restart algorithm in case of divergence.
// NOTE(review): the inner loop below only exits on convergence (which also ends this
// loop), so in the visible code no restart actually takes place.
336 while(nbStart<2)
337 {
// avoid dividing by zero in the relative residual test when b is null
338 if (bnorm2 == 0.) bnorm2 = 1.;
339 //compute the residuals
340 r_ = *p_b_ - (*p_mult_)(x_);
// already within tolerance: leave the outer loop
341 if (r_.norm2()/bnorm2 <eps_) { break;}
342 //initialization of the conjugate direction from the conditioned residual
343 y = (*p_cond_)(r_);
344 p = y;
345 Real rty=r_.dot(y);
// NOTE(review): this loop has no cap on the number of steps; it ends only when the
// relative residual drops below eps_.
346 while(1)
347 {
// NOTE(review): this initial value is overwritten in the convergence test below
// before it is ever read.
348 Real rnorm2 = r_.norm2();
349 //compute z=A p
// NOTE(review): move() is defined outside this file; presumably z takes over the result.
350 z.move((*p_mult_)(p));
351 //compute alpha
352 alpha = rty/p.dot(z);
353 //update x_
// NOTE(review): exchange() is defined outside this file; presumably it swaps the two
// vectors, leaving the previous iterate in xOld.
354 xOld.exchange(x_);
355 x_ = xOld + alpha * p;
356 iter_++;
357 //update residuals
358 r_ = r_ - (alpha * z);
359 //update y
360 y = (*p_cond_)(r_);
361 //compute beta = (new r_.dot(y)) / (previous r_.dot(y))
362 beta = 1/rty;
// converged: setting nbStart to 2 makes the increment below end the outer loop
363 if ((rnorm2=r_.norm2())/bnorm2 <eps_) { nbStart = 2; break;}
364 rty = r_.dot(y);
365 beta *= rty;
366 //update p
367 p = (p * beta) + y;
368 }
369 nbStart++;
370 }
// number of iterations performed
371 return iter_;
372 }
373
374 private:
// solution of the system
376 ColVector x_;
// residuals of the system
378 ColVector r_;
// number of iterations
// NOTE(review): listing lines 382-388 (eps_, p_mult_, p_cond_, p_init_, per the member
// index) are missing from this extract.
380 int iter_;
// pointer to the right-hand side of the system (not owned by this class)
390 ColVector const* p_b_;
391};
392
393} // namespace STK
394
395#undef MAXITER
396
397#endif /* STK_CG_H */
This file includes all the other header files of the Sdk project.
The conjugate gradient method is an algorithm for the numerical solution of particular systems of linear equations...
Definition STK_CG.h:87
int const & nbStart() const
Definition STK_CG.h:132
void setEps(Type const &eps)
Set the tolerance.
Definition STK_CG.h:136
Real const & x(int const &i) const
Definition STK_CG.h:128
int nbStart_
number of restarts
Definition STK_CG.h:213
int cg()
Definition STK_CG.h:156
CG * clone() const
clone pattern
Definition STK_CG.h:123
CG(CG const &cg)
Copy constructor.
Definition STK_CG.h:112
String const & error() const
get the last error message.
Definition STK_CG.h:150
CG(MultFunctor const &mult, ColVector const &b, InitFunctor *const &p_init=0, Type eps=Arithmetic< Type >::epsilon())
Constructor.
Definition STK_CG.h:98
String msg_error_
String with the last error message.
Definition STK_CG.h:154
ColVector const & x() const
Definition STK_CG.h:126
void setB(ColVector const &b)
Set the constant vector.
Definition STK_CG.h:138
int run()
Definition STK_CG.h:146
ColVector r_
residuals of the system
Definition STK_CG.h:207
ColVector::Type Type
Definition STK_CG.h:89
Type eps_
tolerance
Definition STK_CG.h:209
MultFunctor const * p_mult_
pointer on the functor performing Ax
Definition STK_CG.h:215
~CG()
destructor
Definition STK_CG.h:121
ColVector const * p_b_
pointer to the constant right-hand side vector of the system
Definition STK_CG.h:219
void setInitFunctor(InitFunctor *const &p_init)
Set functor computing x at initialization.
Definition STK_CG.h:142
ColVector const & r() const
Definition STK_CG.h:134
InitFunctor const * p_init_
pointer on the functor initializing x
Definition STK_CG.h:217
int iter_
number of iterations
Definition STK_CG.h:211
ColVector x_
solution of the system
Definition STK_CG.h:205
int const & iter() const
Definition STK_CG.h:130
CG()
Default Constructor.
Definition STK_CG.h:91
void setMultFunctor(MultFunctor const &mult)
Set functor computing Ax.
Definition STK_CG.h:140
The MultidimRegression class allows to regress a multidimensional output variable among a multivariat...
In most cases, preconditioning is necessary to ensure fast convergence of the conjugate gradient meth...
Definition STK_CG.h:265
p_mult_ p_cond_ p_b_ b
Definition STK_CG.h:281
String msg_error_
String with the last error message.
Definition STK_CG.h:323
String const & error() const
get the last error message.
Definition STK_CG.h:319
Type eps_
tolerance
Definition STK_CG.h:382
int run()
run the conjugate gradient
Definition STK_CG.h:315
void setInitFunctor(InitFunctor const &init)
Set functor computing x at initialization.
Definition STK_CG.h:309
PCG(PCG const &pcg)
Copy constructor.
Definition STK_CG.h:285
ColVector::Type Type
Definition STK_CG.h:267
p_mult_ mult
Definition STK_CG.h:277
ColVector const & x() const
Definition STK_CG.h:299
CondFunctor const * p_cond_
pointer to the functor performing the preconditioning of the residual
Definition STK_CG.h:386
~PCG()
destructor
Definition STK_CG.h:294
void setB(ColVector const &b)
Set the constant vector.
Definition STK_CG.h:307
ColVector x_
solution of the system
Definition STK_CG.h:376
void setMultFunctor(MultFunctor const &mult)
Set functor computing Ax.
Definition STK_CG.h:311
void setEps(Type const &eps)
Set the tolerance.
Definition STK_CG.h:305
InitFunctor const * p_init_
pointer on the functor initializing x
Definition STK_CG.h:388
Real const & x(int const &i) const
Definition STK_CG.h:301
int iter_
number of iterations
Definition STK_CG.h:380
int pcg()
preconditioned Gradient implementation
Definition STK_CG.h:325
ColVector const * p_b_
pointer to the constant right-hand side vector of the system
Definition STK_CG.h:390
ColVector const & r() const
Definition STK_CG.h:303
ColVector r_
residuals of the system
Definition STK_CG.h:378
PCG * clone() const
clone pattern
Definition STK_CG.h:296
MultFunctor const * p_mult_
pointer on the functor performing Ax
Definition STK_CG.h:384
p_mult_ p_cond_ cond
Definition STK_CG.h:278
void setCondFunctor(CondFunctor const &cond)
Set the functor computing the preconditioned residual.
Definition STK_CG.h:313
Arrays::MultOp< Lhs, Rhs >::result_type mult(Lhs const &lhs, Rhs const &rhs)
convenience function for the multiplication of two matrices
std::basic_string< Char > String
STK fundamental type of a String.
double Real
STK fundamental type of Real values.
The namespace STK is the main domain space of the Statistical ToolKit project.
Arithmetic properties of STK fundamental types.
ColVector operator()() const
Definition STK_CG.h:47