STK++ 0.9.13
STK_Gamma_ak_bj.h
Go to the documentation of this file.
1/*--------------------------------------------------------------------*/
2/* Copyright (C) 2004-2016 Serge Iovleff, Université Lille 1, Inria
3
4 This program is free software; you can redistribute it and/or modify
5 it under the terms of the GNU Lesser General Public License as
6 published by the Free Software Foundation; either version 2 of the
7 License, or (at your option) any later version.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU Lesser General Public License for more details.
13
14 You should have received a copy of the GNU Lesser General Public
15 License along with this program; if not, write to the
16 Free Software Foundation, Inc.,
17 59 Temple Place,
18 Suite 330,
19 Boston, MA 02111-1307
20 USA
21
22 Contact : S..._DOT_I..._AT_stkpp.org (see copyright for ...)
23*/
24
25/*
26 * Project: stkpp::Clustering
27 * created on: 5 sept. 2013
28 * Author: iovleff, serge.iovleff@stkpp.org
29 **/
30
35#ifndef STK_GAMMA_AK_BJ_H
36#define STK_GAMMA_AK_BJ_H
37
39#include "../GammaModels/STK_GammaBase.h"
40
/* Maximal number of iterations of the alternate estimation loop of
 * Gamma_ak_bj::run. The macro is undefined again at the end of this header. */
#define MAXITER 400
/* Tolerance handed to the zero finder used for the shape parameters and
 * threshold on the qValue increase that stops the loop of Gamma_ak_bj::run.
 * The macro is undefined again at the end of this header. */
#define TOL 1e-8
43
44namespace STK
45{
46template<class Array>class Gamma_ak_bj;
47
// Namespace holding implementation details of the library.
48namespace hidden
49{
// NOTE(review): the definition introduced by the template header below is
// absent from this listing (source lines 53-58 are missing); presumably it is
// the mixture traits specialization for Gamma_ak_bj — restore it from the
// upstream header before compiling.
52template<class Array_>
59
60} // namespace hidden
61
62
72template<class Array>
73class Gamma_ak_bj: public GammaBase< Gamma_ak_bj<Array> >
74{
75 public:
77 using Base::param_;
78 using Base::p_data;
79 using Base::meanjk;
80 using Base::variancejk;
81
96 void randomInit( CArrayXX const* const& p_tik, CPointX const* const& p_tk) ;
98 bool run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) ;
100 inline int computeNbFreeParameters() const
101 { return this->nbCluster()+ p_data()->sizeCols();}
102};
103
104/* Initialize randomly the parameters of the Gaussian mixture. The centers
105 * will be selected randomly among the data set and the standard-deviation
106 * will be set to 1.
107 */
108template<class Array>
109void Gamma_ak_bj<Array>::randomInit( CArrayXX const* const& p_tik, CPointX const* const& p_tk)
110{
111 // compute moments
112 this->moments(p_tik);
113 // simulates ak
114 for (int k= p_tik->beginCols(); k < p_tik->endCols(); ++k)
115 {
116 Real value= 0.;
117 for (int j=p_data()->beginCols(); j < p_data()->endCols(); ++j)
118 {
119 Real mean = meanjk(j,k), variance = variancejk(j,k);
120 value += mean*mean/variance;
121 }
122 param_.shape_[k]= Law::Exponential::rand(value/(p_data()->sizeCols()));
123 }
124 // simulate bj
125 for (int j=p_data()->beginCols(); j < p_data()->endCols(); ++j)
126 {
127 Real value= 0.;
128 for (int k= p_tik->beginCols(); k < p_tik->endCols(); ++k)
129 {
130 Real mean = meanjk(j,k), variance = variancejk(j,k);
131 value += p_tk->elt(k) * variance/mean;
132 }
133 param_.scale_[j] = Law::Exponential::rand(value/(this->nbSample()));
134 }
135#ifdef STK_MIXTURE_VERY_VERBOSE
136 stk_cout << _T(" Gamma_ak_bj<Array>::randomInit done\n");
137#endif
138}
139
140/* Compute the weighted mean and the common variance. */
141template<class Array>
142bool Gamma_ak_bj<Array>::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk)
143{
144 if (!this->moments(p_tik)) { return false;}
145 // start estimations of the ajk and bj
146 Real qvalue = this->qValue(p_tik, p_tk);
147 int iter;
148 for(iter=0; iter<MAXITER; ++iter)
149 {
150 // compute ak
151 for (int k= p_tik->beginCols(); k < p_tik->endCols(); ++k)
152 {
153 // moment estimate and oldest value
154 Real x0 = (param_.mean_[k].square()/param_.variance_[k]).mean();
155 Real x1 = param_.shape_[k];
156 if ((x0 <=0.) || !Arithmetic<Real>::isFinite(x0)) return false;
157
158 // compute shape
159 hidden::invPsi f((param_.meanLog_[k] - param_.scale_.log()).mean());
160 Real a = Algo::findZero(f, x0, x1, TOL);
161
163 {
164 param_.shape_[k]= x0; // use moment estimate
165#ifdef STK_MIXTURE_DEBUG
166 stk_cout << _T("ML estimation failed in Gamma_ak_bj::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) \n");
167 stk_cout << "x0 =" << x0 << _T("\n";);
168 stk_cout << "f(x0) =" << f(x0) << _T("\n";);
169 stk_cout << "x1 =" << x1 << _T("\n";);
170 stk_cout << "f(x1) =" << f(x1) << _T("\n";);
171#endif
172 }
173 else { param_.shape_[k]= a;}
174 }
175 // update all the b^j
176 for (int j=p_data()->beginCols(); j<p_data()->endCols(); ++j)
177 {
178 Real num = 0., den = 0.;
179 for (int k= p_tik->beginCols(); k < p_tik->endCols(); ++k)
180 {
181 num += param_.mean_[k][j] * p_tk->elt(k);
182 den += param_.shape_[k] * p_tk->elt(k);
183 }
184 // compute b_j
185 Real b = num/den;
186 // divergence
187 if (!Arithmetic<Real>::isFinite(b)) { return false;}
188 param_.scale_[j] = b;
189 }
190 // check convergence
191 Real value = this->qValue(p_tik, p_tk);
192#ifdef STK_MIXTURE_DEBUG
193 if (value < qvalue)
194 {
195 stk_cout << _T("In Gamma_ak_bj::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) : run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) diverge\n");
196 stk_cout << _T("New value =") << value << _T(", qvalue =") << qvalue << _T("\n");
197 }
198#endif
199 if ((value - qvalue) < TOL) break;
200 qvalue = value;
201 }
202#ifdef STK_MIXTURE_DEBUG
203 if (iter == MAXITER)
204 {
205 stk_cout << _T("In Gamma_ak_bj::run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) : run( CArrayXX const* const& p_tik, CPointX const* const& p_tk) did not converge\n");
206 stk_cout << _T("qvalue =") << qvalue << _T("\n");
207 }
208#endif
209 return true;
210}
211
212
213} // namespace STK
214
215#undef MAXITER
216#undef TOL
217
218#endif /* STK_GAMMA_AK_BJ_H */
#define TOL
In this file we implement the exponential law.
#define MAXITER
#define stk_cout
Standard stk output stream.
#define _T(x)
Let x unmodified.
Base class for the gamma models.
Parameters param_
parameters of the derived mixture model.
Real meanjk(int j, int k)
get the weighted mean of the jth variable of the kth cluster.
Real variancejk(int j, int k)
get the weighted variance of the jth variable of the kth cluster.
Gamma_ak_bj is a mixture model of the following form.
Gamma_ak_bj(Gamma_ak_bj const &model)
copy constructor
~Gamma_ak_bj()
destructor
void randomInit(CArrayXX const *const &p_tik, CPointX const *const &p_tk)
Initialize randomly the parameters of the gamma mixture.
Gamma_ak_bj(int nbCluster)
default constructor
GammaBase< Gamma_ak_bj< Array > > Base
int computeNbFreeParameters() const
bool run(CArrayXX const *const &p_tik, CPointX const *const &p_tk)
Estimate the shape parameters ak and the scale parameters bj of the model.
virtual Real rand() const
Generate a pseudo Exponential random variate.
The MultidimRegression class allows to regress a multidimensional output variable among a multivariat...
Functor computing the difference between the psi function and a fixed value.
Real findZero(IFunction< Function > const &f, Real const &x0, Real const &x1, Real tol)
find the zero of a function.
double Real
STK fundamental type of Real values.
hidden::SliceVisitorSelector< Derived, hidden::MeanVisitor, Arrays::by_col_ >::type_result mean(Derived const &A)
If A is a row-vector or a column-vector then the function will return the usual mean value of the vec...
The namespace STK is the main domain space of the Statistical ToolKit project.
Arithmetic properties of STK fundamental types.
static bool isFinite(Type const &x)
ModelParameters< Clust::Gamma_ak_bj_ > Parameters
Type of the structure storing the parameters of a Gamma_ak_bj model.
Main class for the mixtures traits policy.