41 using namespace internal;
45 SG_SDEBUG(
"Data manager instance initialized with %d data sources!\n", num_distributions);
46 fetchers.resize(num_distributions);
47 std::fill(fetchers.begin(), fetchers.end(),
nullptr);
49 train_test_mode=default_train_test_mode;
50 train_mode=default_train_mode;
51 train_test_ratio=default_train_test_ratio;
52 cross_validation_mode=default_cross_validation_mode;
63 typedef const std::unique_ptr<DataFetcher> fetcher_type;
64 if (std::any_of(fetchers.begin(), fetchers.end(), [](fetcher_type& f) {
return f->m_num_samples==0; }))
65 SG_SERROR(
"number of samples from all the distributions are not set!")
67 std::for_each(fetchers.begin(), fetchers.end(), [&n](fetcher_type& f) { n+=f->m_num_samples; });
76 typedef const std::unique_ptr<DataFetcher> fetcher_type;
77 if (std::any_of(fetchers.begin(), fetchers.end(), [](fetcher_type& f) {
return f->m_num_samples==0; }))
78 SG_SERROR(
"number of samples from all the distributions are not set!")
82 for (
size_t i=0; i<fetchers.size(); ++i)
83 divisor=
CMath::gcd(divisor, fetchers[i]->m_num_samples);
86 SG_SDEBUG(
"min blocksize is %d!", min_blocksize);
97 "Total number of samples is 0! Please set the number of samples!\n");
98 REQUIRE(blocksize>0 && blocksize<=n,
99 "The blocksize has to be within [0, %d], given = %d!\n",
102 "Total number of samples (%d) has to be divisble by the blocksize (%d)!\n",
105 for (
size_t i=0; i<fetchers.size(); ++i)
107 index_t m=fetchers[i]->m_num_samples;
109 "Blocksize (%d) cannot be even distributed with a ratio of %f!\n",
111 fetchers[i]->fetch_blockwise().with_blocksize(blocksize*m/n);
112 SG_SDEBUG(
"block[%d].size = ", i, blocksize*m/n);
120 REQUIRE(num_blocks_per_burst>0,
121 "Number of blocks per burst (%d) has to be greater than 0!\n",
122 num_blocks_per_burst);
125 typedef std::unique_ptr<DataFetcher> fetcher_type;
126 std::for_each(fetchers.begin(), fetchers.end(), [&blocksize](fetcher_type& f)
128 blocksize+=f->m_block_details.m_blocksize;
131 "Blocksizes are not set!\n");
134 if (num_blocks_per_burst>max_num_blocks_per_burst)
136 SG_SINFO(
"There can only be %d blocks per burst given the blocksize (%d)!\n", max_num_blocks_per_burst, blocksize);
137 SG_SINFO(
"Setting num blocks per burst to be %d instead!\n", max_num_blocks_per_burst);
138 num_blocks_per_burst=max_num_blocks_per_burst;
141 for (
size_t i=0; i<fetchers.size(); ++i)
142 fetchers[i]->fetch_blockwise().with_num_blocks_per_burst(num_blocks_per_burst);
149 REQUIRE(i<(int64_t)fetchers.size(),
150 "Value of i (%d) should be between 0 and %d, inclusive!",
151 i, fetchers.size()-1);
153 return InitPerFeature(fetchers[i]);
159 REQUIRE(i<(int64_t)fetchers.size(),
160 "Value of i (%d) should be between 0 and %d, inclusive!",
161 i, fetchers.size()-1);
163 if (fetchers[i]!=
nullptr)
164 return fetchers[i]->m_samples;
172 REQUIRE(i<(int64_t)fetchers.size(),
173 "Value of i (%d) should be between 0 and %d, inclusive!",
174 i, fetchers.size()-1);
176 return fetchers[i]->m_num_samples;
182 REQUIRE(i<(int64_t)fetchers.size(),
183 "Value of i (%d) should be between 0 and %d, inclusive!",
184 i, fetchers.size()-1);
186 if (fetchers[i]!=
nullptr)
187 return fetchers[i]->get_num_samples();
195 REQUIRE(i<(int64_t)fetchers.size(),
196 "Value of i (%d) should be between 0 and %d, inclusive!",
197 i, fetchers.size()-1);
199 if (fetchers[i]!=
nullptr)
200 return fetchers[i]->m_block_details.m_blocksize;
205 void DataManager::set_blockwise(
bool blockwise)
208 for (
size_t i=0; i<fetchers.size(); ++i)
209 fetchers[i]->set_blockwise(blockwise);
213 const bool DataManager::is_blockwise()
const 217 for (
size_t i=0; i<fetchers.size(); ++i)
218 blockwise&=!fetchers[i]->m_block_details.m_full_data;
223 void DataManager::set_train_test_mode(
bool on)
226 if (!train_test_mode)
228 train_mode=default_train_mode;
229 train_test_ratio=default_train_test_ratio;
230 cross_validation_mode=default_cross_validation_mode;
232 REQUIRE(fetchers.size()>0,
"Features are not set!");
233 typedef std::unique_ptr<DataFetcher> fetcher_type;
234 std::for_each(fetchers.begin(), fetchers.end(), [
this, on](fetcher_type& f)
236 f->set_train_test_mode(on);
239 f->set_train_mode(train_mode);
240 f->set_train_test_ratio(train_test_ratio);
245 bool DataManager::is_train_test_mode()
const 247 return train_test_mode;
250 void DataManager::set_train_mode(
bool on)
256 SG_SERROR(
"Train mode cannot be used without turning on Train/Test mode first!" 257 "Please call set_train_test_mode(True) before using this method!\n");
261 bool DataManager::is_train_mode()
const 266 void DataManager::set_cross_validation_mode(
bool on)
269 cross_validation_mode=on;
272 SG_SERROR(
"Cross-validation mode cannot be used without turning on Train/Test mode first!" 273 "Please call set_train_test_mode(True) before using this method!\n");
277 bool DataManager::is_cross_validation_mode()
const 279 return cross_validation_mode;
282 void DataManager::set_train_test_ratio(
float64_t ratio)
285 train_test_ratio=ratio;
288 SG_SERROR(
"Train-test ratio cannot be set without turning on Train/Test mode first!" 289 "Please call set_train_test_mode(True) before using this method!\n");
293 float64_t DataManager::get_train_test_ratio()
const 295 return train_test_ratio;
298 index_t DataManager::get_num_folds()
const 300 return ceil(get_train_test_ratio())+1;
303 void DataManager::shuffle_features()
306 REQUIRE(fetchers.size()>0,
"Features are not set!");
307 typedef std::unique_ptr<DataFetcher> fetcher_type;
308 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->shuffle_features(); });
312 void DataManager::unshuffle_features()
315 REQUIRE(fetchers.size()>0,
"Features are not set!");
316 typedef std::unique_ptr<DataFetcher> fetcher_type;
317 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->unshuffle_features(); });
321 void DataManager::init_active_subset()
326 "Train-test subset cannot be used without turning on Train/Test mode first!" 327 "Please call set_train_test_mode(True) before using this method!\n");
328 REQUIRE(fetchers.size()>0,
"Features are not set!");
330 typedef std::unique_ptr<DataFetcher> fetcher_type;
331 std::for_each(fetchers.begin(), fetchers.end(), [
this](fetcher_type& f)
333 f->set_train_mode(train_mode);
334 f->set_train_test_ratio(train_test_ratio);
335 f->init_active_subset();
340 void DataManager::use_fold(
index_t idx)
345 "Fold subset cannot be used without turning on Train/Test mode first!" 346 "Please call set_train_test_mode(True) before using this method!\n");
347 REQUIRE(fetchers.size()>0,
"Features are not set!");
348 REQUIRE(idx>=0,
"Fold index has to be in [0, %d]!", get_num_folds()-1);
349 REQUIRE(idx<get_num_folds(),
"Fold index has to be in [0, %d]!", get_num_folds()-1);
351 typedef std::unique_ptr<DataFetcher> fetcher_type;
352 std::for_each(fetchers.begin(), fetchers.end(), [
this, idx](fetcher_type& f)
354 f->set_train_mode(train_mode);
355 f->set_train_test_ratio(train_test_ratio);
361 void DataManager::start()
364 REQUIRE(fetchers.size()>0,
"Features are not set!");
366 if (train_test_mode && !cross_validation_mode)
367 init_active_subset();
369 typedef std::unique_ptr<DataFetcher> fetcher_type;
370 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->start(); });
382 for (
size_t i=0; i<fetchers.size(); ++i)
384 auto feats=fetchers[i]->next();
387 auto blocksize=fetchers[i]->m_block_details.m_blocksize;
388 auto num_blocks_curr_burst=feats->get_num_vectors()/blocksize;
391 if (next_samples.m_num_blocks==0)
392 next_samples.m_num_blocks=num_blocks_curr_burst;
394 ASSERT(next_samples.m_num_blocks==num_blocks_curr_burst);
404 void DataManager::end()
407 REQUIRE(fetchers.size()>0,
"Features are not set!");
408 typedef std::unique_ptr<DataFetcher> fetcher_type;
409 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->end(); });
413 void DataManager::reset()
416 REQUIRE(fetchers.size()>0,
"Features are not set!");
417 typedef std::unique_ptr<DataFetcher> fetcher_type;
418 std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->reset(); });
static int32_t gcd(int32_t a, int32_t b)
static std::vector< Block > create_blocks(CFeatures *feats, index_t num_blocks, index_t size)
index_t get_num_samples() const
DataManager(index_t num_distributions)
index_t & num_samples_at(index_t i)
void set_blocksize(index_t blocksize)
InitPerFeature samples_at(index_t i)
void set_num_blocks_per_burst(index_t num_blocks_per_burst)
all of classes and functions are contained in the shogun namespace
The class Features is the base class of all feature objects.
index_t get_min_blocksize() const
class NextSamples is the return type for next() call in DataManager. If there are no more samples (fr...
const index_t blocksize_at(index_t i) const