STK++ 0.9.13
STK_Gamma_ajk_b.h
Go to the documentation of this file.
1/*--------------------------------------------------------------------*/
2/* Copyright (C) 2004-2016 Serge Iovleff, Université Lille 1, Inria
3
4 This program is free software; you can redistribute it and/or modify
5 it under the terms of the GNU Lesser General Public License as
6 published by the Free Software Foundation; either version 2 of the
7 License, or (at your option) any later version.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU Lesser General Public License for more details.
13
14 You should have received a copy of the GNU Lesser General Public
15 License along with this program; if not, write to the
16 Free Software Foundation, Inc.,
17 59 Temple Place,
18 Suite 330,
19 Boston, MA 02111-1307
20 USA
21
22 Contact : S..._Dot_I..._At_stkpp_Dot_org (see copyright for ...)
23*/
24
25/*
26 * Project: stkpp::Clustering
27 * created on: 29 août 2014
28 * Author: iovleff, S..._Dot_I..._At_stkpp_Dot_org (see copyright for ...)
29 **/
30
36#ifndef STK_GAMMA_AJK_B_H
37#define STK_GAMMA_AJK_B_H
38
40#include "../GammaModels/STK_GammaBase.h"
41
42#define MAXITER 400
43#define TOL 1e-8
44
45
46namespace STK
47{
// Forward declaration of the model class, so that it can be referred to
// before its definition (presumably by the traits specialization below).
48template<class Array>class Gamma_ajk_b;
49
// NOTE(review): only the template header of the hidden traits
// specialization for Gamma_ajk_b is visible in this listing; its body is
// not shown here.
50namespace hidden
51{
54template<class Array_>
61
62} // namespace hidden
63
/* Gamma_ajk_b: gamma mixture model in which the shape a_jk depends on both
 * the variable j and the cluster k (param_.shape_[k][j]), while a single
 * scale b (param_.scale_) is shared by every cluster and every variable.
 * NOTE(review): the `Base` typedef, the constructors and the destructor are
 * not shown in this listing.
 */
73template<class Array>
74class Gamma_ajk_b: public GammaBase<Gamma_ajk_b<Array> >
75{
76 public:
    // Members of the (dependent) template base class must be brought into
    // scope explicitly to be usable unqualified.
78 using Base::param_;
79
80 using Base::p_data;
81 using Base::meanjk;
82 using Base::variancejk;
83
    // Draw random initial shapes a_jk and a random common scale b.
95 void randomInit( CArrayXX const* const& p_tik, CPointX const* const& p_tk) ;
    // Estimate the shapes a_jk and the common scale b; returns false on failure.
97 bool run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) ;
    // One shape per (cluster, variable) pair plus the single common scale.
99 inline int computeNbFreeParameters() const
100 { return this->nbCluster()*p_data()->sizeCols() + 1;}
101};
102
103/* Initialize randomly the parameters of the gamma mixture. The centers
104 * will be selected randomly among the data set and the standard-deviation
105 * will be set to 1.
106 */
107template<class Array>
108void Gamma_ajk_b<Array>::randomInit( CArrayXX const* const& p_tik, CPointX const* const& p_tk)
109{
110 // compute moments
111 this->moments(p_tik);
112 Real value =0.;
113 for (int j=p_data()->beginCols(); j < p_data()->endCols(); ++j)
114 {
115 // random scale for each cluster
116 for (int k= p_tik->beginCols(); k < p_tik->endCols(); ++k)
117 {
118 Real mean = meanjk(j,k), variance = variancejk(j,k);
119 param_.shape_[k][j] = Law::Exponential::rand((mean*mean/variance));
120 value += p_tk->elt(k) * variance/mean;
121 }
122 }
123 param_.scale_ = Law::Exponential::rand(value/(p_data()->sizeCols()*this->nbSample()));
124#ifdef STK_MIXTURE_VERY_VERBOSE
125 stk_cout << _T(" Gamma_ajk_b<Array>::randomInit done\n");
126#endif
127}
128
129/* Compute the weighted mean and the common variance. */
130template<class Array>
131bool Gamma_ajk_b<Array>::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk)
132{
133 bool flag = true;
134 if (!this->moments(p_tik)) { flag = false;}
135 // start estimations of the ajk and bj
136 Real qvalue = this->qValue(p_tik, p_tk);
137 int iter;
138 for(iter=0; iter<MAXITER; ++iter)
139 {
140 for (int j=p_data()->beginCols(); j<p_data()->endCols(); ++j)
141 {
142 // compute ajk
143 for (int k= p_tik->beginCols(); k < p_tik->endCols(); ++k)
144 {
145 // moment estimate and oldest value
146 Real x0 = meanjk(j,k)*meanjk(j,k)/variancejk(j,k);
147 Real x1 = param_.shape_[k][j];
148 if ((x0 <=0.) || !Arithmetic<Real>::isFinite(x0)) return false;
149 // compute shape
150 hidden::invPsi f(param_.meanLog_[k][j] - std::log(param_.scale_));
151
152
153 Real a = Algo::findZero(f, x0, x1, TOL);
154
156 {
157 param_.shape_[k][j] = x0; // use moment estimate
158#ifdef STK_MIXTURE_DEBUG
159 stk_cout << _T("ML estimation failed in Gamma_ajk_bj::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) \n");
160 stk_cout << "x0 =" << x0 << _T("\n";);
161 stk_cout << "f(x0) =" << f(x0) << _T("\n";);
162 stk_cout << "x1 =" << x1 << _T("\n";);
163 stk_cout << "f(x1) =" << f(x1) << _T("\n";);
164#endif
165 }
166 else { param_.shape_[k][j] = a;}
167 }
168 }
169 Real num=0., den = 0.;
170 for (int k= p_tik->beginCols(); k < p_tik->endCols(); ++k)
171 {
172 num += param_.mean_[k].sum() * p_tk->elt(k);
173 den += param_.shape_[k].sum() * p_tk->elt(k);
174 }
175 // compute b
176 Real b = num/den;
177 // divergence
178 if (!Arithmetic<Real>::isFinite(b)) { return false;}
179 param_.scale_ = b;
180 // check convergence
181 Real value = this->qValue(p_tik, p_tk);
182#ifdef STK_MIXTURE_DEBUG
183 if (value < qvalue)
184 {
185 stk_cout << _T("In Gamma_ajk_b::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) : run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) diverge\n");
186 stk_cout << _T("New value =") << value << _T(", qvalue =") << qvalue << _T("\n");
187 }
188#endif
189 if ((value - qvalue) < TOL) break;
190 qvalue = value;
191 }
192#ifdef STK_MIXTURE_DEBUG
193 if (iter == MAXITER)
194 {
195 stk_cout << _T("In Gamma_ajk_b::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) : run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) did not converge\n");
196 stk_cout << _T("qvalue =") << qvalue << _T("\n");
197 }
198#endif
199 return flag;
200}
201
202} // namespace STK
203
204#undef MAXITER
205#undef TOL
206
207#endif /* STK_Gamma_AJK_B_H */
#define TOL
In this file we implement the exponential law.
#define MAXITER
#define stk_cout
Standard stk output stream.
#define _T(x)
Let x unmodified.
Base class for the gamma models.
Parameters param_
parameters of the derived mixture model.
Real meanjk(int j, int k)
get the weighted mean of the jth variable of the kth cluster.
Real variancejk(int j, int k)
get the weighted variance of the jth variable of the kth cluster.
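Gamma_ajk_b is a gamma mixture model in which each variable j of each cluster k has its own shape a_{jk}, while a single scale b is shared by all clusters and variables.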
Gamma_ajk_b(int nbCluster)
default constructor
GammaBase< Gamma_ajk_b< Array > > Base
Gamma_ajk_b(Gamma_ajk_b const &model)
copy constructor
void randomInit(CArrayXX const *const &p_tik, CPointX const *const &p_tk)
Initialize randomly the parameters of the Gamma mixture.
int computeNbFreeParameters() const
~Gamma_ajk_b()
destructor
bool run(CArrayXX const *const &p_tik, CPointX const *const &p_tk)
Estimate the shapes a_{jk} and the common scale b of the model.
virtual Real rand() const
Generate a pseudo Exponential random variate.
The MultidimRegression class allows to regress a multidimensional output variable among a multivariat...
Functor computing the difference between the psi function and a fixed value.
Real findZero(IFunction< Function > const &f, Real const &x0, Real const &x1, Real tol)
find the zero of a function.
double Real
STK fundamental type of Real values.
hidden::SliceVisitorSelector< Derived, hidden::MeanVisitor, Arrays::by_col_ >::type_result mean(Derived const &A)
If A is a row-vector or a column-vector then the function will return the usual mean value of the vec...
The namespace STK is the main domain space of the Statistical ToolKit project.
Arithmetic properties of STK fundamental types.
ModelParameters< Clust::Gamma_ajk_b_ > Parameters
Type of the structure storing the parameters of a Gamma_ajk_b model.
Main class for the mixtures traits policy.