STK++ 0.9.13
STK_Gamma_ajk_bj.h
Go to the documentation of this file.
1/*--------------------------------------------------------------------*/
2/* Copyright (C) 2004-2016 Serge Iovleff, Université Lille 1, Inria
3
4 This program is free software; you can redistribute it and/or modify
5 it under the terms of the GNU Lesser General Public License as
6 published by the Free Software Foundation; either version 2 of the
7 License, or (at your option) any later version.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU Lesser General Public License for more details.
13
14 You should have received a copy of the GNU Lesser General Public
15 License along with this program; if not, write to the
16 Free Software Foundation, Inc.,
17 59 Temple Place,
18 Suite 330,
19 Boston, MA 02111-1307
20 USA
21
22 Contact : S..._DOT_I..._AT_stkpp.org (see copyright for ...)
23*/
24
25/*
26 * Project: stkpp::Clustering
27 * created on: 5 sept. 2013
28 * Author: iovleff, serge.iovleff@stkpp.org
29 **/
30
35#ifndef STK_GAMMA_AJK_BJ_H
36#define STK_GAMMA_AJK_BJ_H
37
39#include "../GammaModels/STK_GammaBase.h"
40
/* Tuning constants for Gamma_ajk_bj::run(), local to this header (both are
 * #undef'd at the end of the file):
 *  - MAXITER bounds the number of passes of the outer estimation loop;
 *  - TOL is the tolerance handed to Algo::findZero and the threshold on the
 *    increase of the q-value under which the loop stops.
 */
41#define MAXITER 400
42#define TOL 1e-8
43
44namespace STK
45{
// Forward declaration of the class template defined further down in this file.
46template<class Array>class Gamma_ajk_bj;
47
48namespace hidden
49{
// NOTE(review): the definition introduced by this template header is missing
// from this listing -- presumably the traits specialization for Gamma_ajk_bj;
// confirm against the original header.
52template<class Array_>
59
60} // namespace hidden
61
/** Gamma mixture model with one shape per cluster and variable
 *  (param_.shape_[k][j]) and one scale per variable shared by all the
 *  clusters (param_.scale_[j]).
 *
 *  The class derives from GammaBase instantiated with itself (CRTP); the
 *  moments, the q-value and the parameters container come from that base.
 *
 *  @tparam Array type of the data container (it is not used directly in the
 *  visible part of this class).
 */
71template<class Array>
72class Gamma_ajk_bj: public GammaBase<Gamma_ajk_bj<Array> >
73{
74 public:
// NOTE(review): `Base` is not declared in the visible part of the class;
// presumably a typedef for GammaBase<Gamma_ajk_bj<Array> > -- confirm.
// parameters of the model (shape_, scale_, mean_, meanLog_ are accessed in this file)
76 using Base::param_;
77
// accessor to the data set and to the weighted moments of variable j in cluster k
78 using Base::p_data;
79 using Base::meanjk;
80 using Base::variancejk;
81
/** Draw random starting values for the shapes and the scales.
 *  @param p_tik pointer on the array whose columns are indexed by the clusters
 *  @param p_tk pointer on the per-cluster weights
 */
93 void randomInit( CArrayXX const* const& p_tik, CPointX const* const& p_tk) ;
/** Estimate the shapes and the scales.
 *  @param p_tik pointer on the array whose columns are indexed by the clusters
 *  @param p_tk pointer on the per-cluster weights
 *  @return @c false if the estimation had to be aborted, @c true otherwise
 */
95 bool run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) ;
/** @return the number of free parameters: one shape for each
 *  (cluster, variable) pair plus one scale for each variable.
 */
97 inline int computeNbFreeParameters() const
98 { return this->nbCluster()*p_data()->sizeCols()+ p_data()->sizeCols();}
99};
100
101/* Initialize randomly the parameters of the gamma mixture. The centers
102 * will be selected randomly among the data set and the standard-deviation
103 * will be set to 1.
104 */
105template<class Array>
106void Gamma_ajk_bj<Array>::randomInit( CArrayXX const* const& p_tik, CPointX const* const& p_tk)
107{
108 // compute moments
109 this->moments(p_tik);
110 for (int j=p_data()->beginCols(); j < p_data()->endCols(); ++j)
111 {
112 Real value =0.;
113 for (int k= p_tik->beginCols(); k < p_tik->endCols(); ++k)
114 {
115 Real mean = meanjk(j,k), variance = variancejk(j,k);
116 param_.shape_[k][j] = Law::Exponential::rand((mean*mean/variance));
117 value += p_tk->elt(k) * variance/mean;
118 }
119 param_.scale_[j] = Law::Exponential::rand(value/(this->nbSample()));
120 }
121#ifdef STK_MIXTURE_VERY_VERBOSE
122 stk_cout << _T(" Gamma_ajk_bj<Array>::randomInit done\n");
123#endif
124}
125
126/* Compute the weighted mean and the common variance. */
127template<class Array>
128bool Gamma_ajk_bj<Array>::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk)
129{
130 if (!this->moments(p_tik)) { return false;}
131 // start estimations of the ajk and bj
132 Real qvalue = this->qValue(p_tik, p_tk);
133 int iter;
134 for(iter=0; iter<MAXITER; ++iter)
135 {
136 for (int j=p_data()->beginCols(); j<p_data()->endCols(); ++j)
137 {
138 Real num=0., den = 0.;
139 // compute ajk
140 for (int k= p_tik->beginCols(); k < p_tik->endCols(); ++k)
141 {
142 // moment estimate and oldest value
143 Real x0 = this->meanjk(j,k)*this->meanjk(j,k)/this->variancejk(j,k);
144 Real x1 = param_.shape_[k][j];
145 if ((x0 <=0.) || !Arithmetic<Real>::isFinite(x0)) return false;
146 // compute shape
147 hidden::invPsi f(param_.meanLog_[k][j] - std::log(param_.scale_[j]));
148 Real a = Algo::findZero(f, x0, x1, TOL);
149
151 {
152 param_.shape_[k][j] = x0; // use moment estimate
153#ifdef STK_MIXTURE_DEBUG
154 stk_cout << _T("ML estimation failed in Gamma_ajk_bj::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) \n");
155 stk_cout << "x0 =" << x0 << _T("\n";);
156 stk_cout << "f(x0) =" << f(x0) << _T("\n";);
157 stk_cout << "x1 =" << x1 << _T("\n";);
158 stk_cout << "f(x1) =" << f(x1) << _T("\n";);
159#endif
160 }
161 else { param_.shape_[k][j] = a;}
162 num += param_.mean_[k][j] * p_tk->elt(k);
163 den += param_.shape_[k][j] * p_tk->elt(k);
164 }
165 // compute b_j
166 Real b = num/den;
167 // divergence
168 if (!Arithmetic<Real>::isFinite(b)) { return false;}
169 param_.scale_[j] = b;
170 }
171 // check convergence
172 Real value = this->qValue(p_tik, p_tk);
173#ifdef STK_MIXTURE_DEBUG
174 if (value < qvalue)
175 {
176 stk_cout << _T("In Gamma_ajk_bj::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) : run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) diverge\n");
177 stk_cout << _T("New value =") << value << _T(", qvalue =") << qvalue << _T("\n");
178 }
179#endif
180 if ((value - qvalue) < TOL) break;
181 qvalue = value;
182 }
183#ifdef STK_MIXTURE_DEBUG
184 if (iter == MAXITER)
185 {
186 stk_cout << _T("In Gamma_ajk_bj::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) : run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) did not converge\n");
187 stk_cout << _T("qvalue =") << qvalue << _T("\n");
188 }
189#endif
190 return true;
191}
192
193} // namespace STK
194
195#undef MAXITER
196#undef TOL
197
198#endif /* STK_GAMMA_AJK_BJ_H */
#define TOL
In this file we implement the exponential law.
#define MAXITER
#define stk_cout
Standard stk output stream.
#define _T(x)
Let x unmodified.
Base class for the gamma models.
Parameters param_
parameters of the derived mixture model.
Real meanjk(int j, int k)
get the weighted mean of the jth variable of the kth cluster.
Real variancejk(int j, int k)
get the weighted variance of the jth variable of the kth cluster.
Gamma_ajk_bj is a mixture model of the following form.
bool run(CArrayXX const *const &p_tik, CPointX const *const &p_tk)
Compute the weighted mean and the common variance.
GammaBase< Gamma_ajk_bj< Array > > Base
int computeNbFreeParameters() const
~Gamma_ajk_bj()
destructor
void randomInit(CArrayXX const *const &p_tik, CPointX const *const &p_tk)
Initialize randomly the parameters of the Gamma mixture.
Gamma_ajk_bj(int nbCluster)
default constructor
Gamma_ajk_bj(Gamma_ajk_bj const &model)
copy constructor
virtual Real rand() const
Generate a pseudo Exponential random variate.
The MultidimRegression class allows to regress a multidimensional output variable among a multivariat...
Functor computing the difference between the psi function and a fixed value.
Real findZero(IFunction< Function > const &f, Real const &x0, Real const &x1, Real tol)
find the zero of a function.
double Real
STK fundamental type of Real values.
hidden::SliceVisitorSelector< Derived, hidden::MeanVisitor, Arrays::by_col_ >::type_result mean(Derived const &A)
If A is a row-vector or a column-vector then the function will return the usual mean value of the vec...
The namespace STK is the main domain space of the Statistical ToolKit project.
Arithmetic properties of STK fundamental types.
ModelParameters< Clust::Gamma_ajk_bj_ > Parameters
Type of the structure storing the parameters of a Gamma_ajk_bj model.
Main class for the mixtures traits policy.