STK++ 0.9.13
STK_CvHandler.h
Go to the documentation of this file.
1/*--------------------------------------------------------------------*/
2/* Copyright (C) 2004-2016 Serge Iovleff, Université Lille 1, Inria
3
4 This program is free software; you can redistribute it and/or modify
5 it under the terms of the GNU Lesser General Public License as
6 published by the Free Software Foundation; either version 2 of the
7 License, or (at your option) any later version.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU Lesser General Public License for more details.
13
14 You should have received a copy of the GNU Lesser General Public
15 License along with this program; if not, write to the
16 Free Software Foundation, Inc.,
17 59 Temple Place,
18 Suite 330,
19 Boston, MA 02111-1307
20 USA
21
22 Contact : S..._Dot_I..._At_stkpp_Dot_org (see copyright for ...)
23*/
24
25/*
26 * Project: stkpp::DManager
27 * created on: 15 nov. 2016
28 * Author: iovleff, S..._Dot_I..._At_stkpp_Dot_org (see copyright for ...)
29 **/
30
37#ifndef STK_CVHANDLER_H
38#define STK_CVHANDLER_H
39
40#include <Sdk.h>
41
44
45namespace STK
46{
47
53{
  // NOTE(review): the class header line (`class CvHandler ...`) is missing above
  // this brace in this listing; the constructor's init list shows the base is IRunnerBase.
54 public:
    /** Constructor.
     *  @param rangeData range of the rows of the data set
     *  @param nbFolds number of folds
     **/
58 CvHandler( Range const& rangeData, int nbFolds);
    /** Destructor. */
60 inline virtual ~CvHandler() {}
61
    /** @return the number of folds */
63 inline int nbFolds() const { return nbFolds_;}
    /** @return the range of the rows of the data set */
65 inline Range const& rangeData() const { return rangeData_;}
    /** @return for each row of the data set, the index of the fold it belongs to */
67 inline CVectorXi const& partitions() const { return partitions_;}
    /** @return the number of rows in each fold */
69 inline CVectorXi const& sizePartitions() const { return sizePartitions_;}
70
    /** Draw a new random partition of the rows (see partition()) and mark the
     *  handler as having run, so that getKFold() can be used.
     *  @return always true
     **/
71 inline virtual bool run()
72 { partition(); hasRun_ = true; return true;}
73
    /** Reset the data range and the number of folds.
     *  Resets hasRun_, so run() must be called again before getKFold().
     *  NOTE(review): presumably also stores rangeData/nbFolds and clears the
     *  partitions; those statements are not present here — confirm.
     *  @param rangeData range of the rows of the data set
     *  @param nbFolds number of folds
     **/
74 inline void setData( Range const& rangeData, int nbFolds)
75 {
80 hasRun_ = false;
81 }
    /** Split x into a training set (rows outside fold k) and a test set (rows of fold k).
     *  @param k index of the fold used as test set
     *  @param x the data, whose rows must match rangeData()
     *  @param xFold receives the rows of x not belonging to fold k
     *  @param xTest receives the rows of x belonging to fold k
     *  @return false if run() has not been called, if the rows of x differ from
     *  rangeData(), or if k is not a valid fold index; true otherwise
     **/
83 template<class Data>
84 bool getKFold( int k, Data const& x,Data& xFold, Data& xTest);
    /** Same as getKFold(k, x, xFold, xTest) but also splits y using the same rows.
     *  @param y second data set, whose rows must also match rangeData()
     *  @param yFold receives the rows of y not belonging to fold k
     *  @param yTest receives the rows of y belonging to fold k
     *  @return false under the same conditions as the single data set version, or if
     *  the rows of y differ from rangeData(); true otherwise
     **/
86 template<class xData, class yData>
87 bool getKFold( int k, xData const& x, xData& xFold, xData& xTest
88 , yData const& y, yData& yFold, yData& yTest);
89
90 protected:
    /** Create a random partition of the rows into folds; fills partitions_. */
92 inline void partition();
93
94 private:
    // NOTE(review): the declarations of rangeData_, nbFolds_, partitions_ and
    // sizePartitions_ are missing from this listing.
103};
104
/* Constructor. nbFolds is set to the number of observations
 * @param rangeData the range of the data to set
 * @param nbFolds number of folds
 **/
109inline CvHandler::CvHandler( Range const& rangeData, int nbFolds)
110 : IRunnerBase()
111 , rangeData_(rangeData), nbFolds_(nbFolds)
  // partitions_ and sizePartitions_ start empty; run() (via partition()) fills them.
112 , partitions_(), sizePartitions_()
113{
114 // check nbFolds parameter
115 if (nbFolds_<1)
  // NOTE(review): the statement guarded by this `if` is missing here; an `if (...)`
  // immediately followed by the closing brace does not compile. The tooltip list
  // mentions STKRUNTIME_ERROR_1ARG, so it presumably throws — restore from upstream.
119}
/* get the training data set (all rows except fold k) and the test data set (rows of fold k) */
121template<class Data>
122bool CvHandler::getKFold( int k, Data const& x,Data& xFold, Data& xTest)
123{
  // Splits the rows of x: rows whose fold index equals k go to xTest, all the
  // other rows go to xFold. Returns false when run() has not been called, when
  // the rows of x differ from rangeData_, or when k is outside the fold indices.
  // NOTE(review): the opening lines of the four early-return branches below
  // (which presumably set msg_error_) are missing from this listing.
124 // check if partitions are determined
125 if (!hasRun_)
127 return false;
128 }
129 // check dimensions
130 if (x.rows() != rangeData_)
132 return false;
133 }
  // k must be a valid index of sizePartitions_ (i.e. a fold index)
134 if (sizePartitions_.begin() > k)
136 return false;
137 }
138 if (sizePartitions_.end() <= k)
140 return false;
141 }
142 // prepare containers
  // training rows: x's row range shortened by the size of fold k
  // (decLast presumably moves the last index down by that amount)
143 Range xFoldRows = x.rows();
144 xFoldRows.decLast(sizePartitions_[k]);
145 xFold.resize(xFoldRows, x.cols());
146 xTest.resize(sizePartitions_[k], x.cols());
147 // copy data
  // rows keep their original relative order; iFoldRow/iTestRow are the next
  // destination rows in xFold and xTest
148 int iFoldRow = xFold.beginRows(), iTestRow = xTest.beginRows();
149 for (int i = partitions_.begin(); i < partitions_.end(); ++i)
150 {
151 if (partitions_[i] == k)
152 {
153 xTest.row(iTestRow) = x.row(i);
154 ++iTestRow;
155 }
156 else
157 {
158 xFold.row(iFoldRow) = x.row(i);
159 ++iFoldRow;
160 }
161 }
162 return true;
163}
/* get the training data sets (all rows except fold k) and the test data sets (rows of fold k) for both x and y */
165template<class xData, class yData>
  // NOTE(review): the first line of this signature
  // (`bool CvHandler::getKFold( int k, xData const& x, xData& xFold, xData& xTest`)
  // is missing from this listing, as are the opening lines of the early-return branches.
167 , yData const& y, yData& yFold, yData& yTest)
168{
  // Splits the rows of x and y in step: rows whose fold index equals k go to
  // xTest/yTest, the others to xFold/yFold. Returns false when run() has not
  // been called, when the rows of x or y differ from rangeData_, or when k is
  // outside the fold indices.
169 // check if partitions are determined
170 if (!hasRun_)
172 return false;
173 }
174 // check dimensions
175 if (x.rows() != rangeData_)
177 return false;
178 }
179 if (y.rows() != rangeData_)
181 return false;
182 }
183 if (sizePartitions_.begin() > k)
185 return false;
186 }
187 if (sizePartitions_.end() <= k)
189 return false;
190 }
191 // prepare containers
  // x and y parts get identical row ranges, so the row counters below are
  // valid for both the x and the y containers
192 Range xFoldRows = x.rows();
193 xFoldRows.decLast(sizePartitions_[k]);
194 xFold.resize(xFoldRows, x.cols());
195 xTest.resize(sizePartitions_[k], x.cols());
196 yFold.resize(xFoldRows, y.cols());
197 yTest.resize(sizePartitions_[k], y.cols());
198 // copy data
199 int iFoldRow = xFold.beginRows(), iTestRow = xTest.beginRows();
200 for (int i = partitions_.begin(); i < partitions_.end(); ++i)
201 {
202 if (partitions_[i] == k)
203 {
204 xTest.row(iTestRow) = x.row(i);
205 yTest.row(iTestRow) = y.row(i);
206 ++iTestRow;
207 }
208 else
209 {
210 xFold.row(iFoldRow) = x.row(i);
211 yFold.row(iFoldRow) = y.row(i);
212 ++iFoldRow;
213 }
214 }
215 return true;
216}
217
/* create a random partition of the data into k folds */
220{
  // NOTE(review): the signature line (`void CvHandler::partition()`, declared
  // `inline void partition();` in the class) and the statements preparing
  // partitions_/sizePartitions_ are missing from this listing.
223 //fill the container with the index of folds
224 for(int i = partitions_.begin() ; i< partitions_.end() ;i++)
225 {
  // NOTE(review): the loop body is missing from this listing.
228 }
229 //make a random rearrangement
  // NOTE(review): this is not a correct Fisher-Yates shuffle.
  // - The loop index starts at end()-2 instead of end()-1.
  // - The swap index is drawn with rand(begin, i+1) instead of rand(begin, i).
  // - With two rows, i == begin on entry, so no swap ever happens.
  // - With three rows, at most 3 of the 6 orderings are reachable.
  // The resulting partition is therefore not uniformly random. Assuming
  // Law::UniformDiscrete::rand(a, b) samples the closed range [a, b] (confirm),
  // the standard form is:
  //   for (int i = partitions_.end()-1; i > begin; --i)
  //   { std::swap(partitions_[i], partitions_[Law::UniformDiscrete::rand(begin, i)]);}
230 int begin = partitions_.begin();
231 for (int i=partitions_.end()-2; i>begin; --i)
232 { std::swap(partitions_[i], partitions_[Law::UniformDiscrete::rand(begin, i+1)]);}
233}
234
235} // namespace STK
236
237#endif /* STK_CVHANDLER_H */
In this file we implement the final class CArrayVector.
In this file we implement the uniform (discrete) law.
#define STKERROR_NO_ARG(Where, Error)
Definition STK_Macros.h:49
#define STKRUNTIME_ERROR_1ARG(Where, Arg, Error)
Definition STK_Macros.h:129
#define STKERROR_1ARG(Where, Arg, Error)
Definition STK_Macros.h:61
This file include all the other header files of the project Sdk.
CvHandler is a utility class for building the sub-matrices/sub-vectors needed when using k-fold cross-validation.
CVectorXi partitions_
repartition of the sample into k-folds
CVectorXi const & partitions() const
Range rangeData_
Range of the data set (number of rows)
void partition()
create a random partition in k folds
bool getKFold(int k, Data const &x, Data &xFold, Data &xTest)
get the data set when setting out fold k and test data set
virtual ~CvHandler()
destructor
void setData(Range const &rangeData, int nbFolds)
CVectorXi sizePartitions_
size of each fold
int nbFolds_
Number of folds.
CvHandler(Range const &rangeData, int nbFolds)
Constructor.
Range const & rangeData() const
virtual bool run()
run the computations.
int nbFolds() const
CVectorXi const & sizePartitions() const
Derived & resize(Range const &I, Range const &J)
resize the Array.
void clear()
clear all allocated memory.
Abstract base class for all classes having a.
Definition STK_IRunner.h:65
String msg_error_
String with the last error message.
Definition STK_IRunner.h:96
bool hasRun_
true if run has been used, false otherwise
Definition STK_IRunner.h:98
virtual int rand() const
Generate a pseudo Uniform random variate.
The MultidimRegression class allows one to regress a multidimensional output variable among a multivariat...
Index sub-vector region: Specialization when the size is unknown.
Definition STK_Range.h:265
int size() const
get the size of the TRange (the number of elements).
Definition STK_Range.h:303
The namespace STK is the main domain space of the Statistical ToolKit project.
TRange< UnknownSize > Range
Definition STK_Range.h:59