STK++ 0.9.13
STK_WeightedSvd.h
Go to the documentation of this file.
1/*--------------------------------------------------------------------*/
2/* Copyright (C) 2004-2016 Serge Iovleff, Université Lille 1, Inria
3
4 This program is free software; you can redistribute it and/or modify
5 it under the terms of the GNU Lesser General Public License as
6 published by the Free Software Foundation; either version 2 of the
7 License, or (at your option) any later version.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU Lesser General Public License for more details.
13
14 You should have received a copy of the GNU Lesser General Public
15 License along with this program; if not, write to the
16 Free Software Foundation, Inc.,
17 59 Temple Place,
18 Suite 330,
19 Boston, MA 02111-1307
20 USA
21
22 Contact : S..._Dot_I..._At_stkpp_Dot_org (see copyright for ...)
23 */
24
25/*
26 * Project: stkpp::Algebra
27 * created on: 10 August 2015
28 * Author: iovleff, S..._Dot_I..._At_stkpp_Dot_org (see copyright for ...)
29 **/
30
35#ifndef STK_WEIGHTEDSVD_H
36#define STK_WEIGHTEDSVD_H
37
38namespace STK
39{
40// forward declaration
41template<class Array, class WRows, class WCols> class WeightedSvd;
42
43namespace hidden
44{
48template<class Array_, class WRows, class WCols>
55} // namespace hidden
56
58template<class Array, class WRows, class WCols>
59class WeightedSvd: public ISvd< WeightedSvd<Array, WRows, WCols> >
60{
61 public:
63 using Base::U_;
64 using Base::D_;
65 using Base::V_;
74 WeightedSvd( Array const& a, WRows const& wrows, WCols const& wcols, int dim)
75 : Base(a, false, (dim>0) ? true:false, (dim>0) ? true:false)
76 , wrows_(wrows), wcols_(wcols), dim_(dim)
77
78 {
79 if (wrows.range() != U_.rows()) { wrows_.resize(U_.rows()).setValue(1./U_.sizeRows());}
80 else { wrows_ /= wrows.sum();}
81 if (wcols.range() != U_.cols()) { wcols_.resize(U_.cols()).setOne();}
82 dim_ = std::min(dim_, U_.sizeRows());
83 dim_ = std::min(dim_, U_.sizeCols());
84 }
86 virtual ~WeightedSvd() {}
88 virtual bool run()
89 {
91#ifdef STKUSELAPACK
92 lapack::Svd solver(U_, false, this->withU_, this->withV_);
93 // if there is no cv, fall back to STK++ svd
94 if (!solver.run())
95 {
96 Svd<CArrayXX> dec(U_, true, this->withU_, this->withV_);
97 if (!dec.run()) return false;
98 else
99 {
100 U_.move(dec.U_);
101 D_ = dec.D_;
102 V_ = dec.V_;
103 }
104 }
105 else
106 {
107 U_.move(solver.U_);
108 D_.move(solver.D_);
109 V_.move(solver.V_);
110 }
111#else
112 Svd<CArrayXX> solver(U_, true, this->withU_, this->withV_);
113 if (!solver.run()) return false;
114 else
115 {
116 U_.move(dec.U_);
117 D_ = dec.D_;
118 V_ = dec.V_;
119 }
120
121#endif
122 // weight back
123 return true;
124 }
125 private:
131 int dim_;
132};
133
134} // namespace STK
135
136#endif /* STK_WEIGHTEDSVD_H */
DiagonalizeOperator< Derived > const diagonalize() const
Derived & resize(Range const &I, Range const &J)
resize the Array.
virtual bool run()
run the computations.
Compute the Singular Value Decomposition of an array.
Definition STK_ISvd.h:61
virtual bool run()
implement the run method
Definition STK_ISvd.h:155
ArrayD D_
Diagonal array of the singular values.
Definition STK_ISvd.h:193
The MultidimRegression class allows to regress a multidimensional output variable among a multivariat...
This class perform a weighted svd decomposition.
WeightedSvd(Array const &a, WRows const &wrows, WCols const &wcols, int dim)
constructor from the data array, the row and column weights and the requested number of singular vectors.
virtual bool run()
run the weighted svd
ISvd< WeightedSvd< Array, WRows, WCols > > Base
CPointX wcols_
columns weights
int dim_
number of singular vectors (left and right)
virtual ~WeightedSvd()
destructor
CVectorX wrows_
rows weights
The namespace STK is the main domain space of the Statistical ToolKit project.
traits class for the algebra methods.