37 #ifndef VIGRA_RF3_IMPEX_HDF5_HXX
38 #define VIGRA_RF3_IMPEX_HDF5_HXX
46 #include "random_forest_3/random_forest.hxx"
47 #include "random_forest_3/random_forest_common.hxx"
48 #include "random_forest_3/random_forest_visitors.hxx"
49 #include "hdf5impex.hxx"
// Object names shared by random_forest_import_HDF5 and random_forest_export_HDF5.

// Group holding the problem description (feature/class counts, labels, weights).
static char const * const rf_hdf5_ext_param     = "_ext_param";
// Group holding the training options (mtry, tree count, bootstrap flag, ...).
static char const * const rf_hdf5_options       = "_options";
// Per-tree dataset of node-type words and offsets (written as UInt32).
static char const * const rf_hdf5_topology      = "topology";
// Per-tree dataset of node payloads (thresholds, class responses) as doubles.
static char const * const rf_hdf5_parameters    = "parameters";
// Prefix of every per-tree group name; the suffix is a zero-padded tree number.
static char const * const rf_hdf5_tree          = "Tree_";
// Path of the object carrying the format-version attribute ("." = current group).
static char const * const rf_hdf5_version_group = ".";
// Name of the format-version attribute.
static char const * const rf_hdf5_version_tag   = "vigra_random_forest_version";
// Highest format version accepted on import; larger stored versions are rejected.
static double const       rf_hdf5_version       = 0.1;
// Node-type words stored in the per-tree "topology" dataset: tag bits live in
// the high nibble (rf_tag_mask), the node kind in the low two bits
// (rf_type_mask).
// Apparently unused by the import/export functions below.
70 rf_AllColumns = 0x00000000,
// Tag bit; apparently unused by the import/export functions below.
71 rf_ToBePrunedTag = 0x80000000,
// Marks a leaf record: the exporter writes it as a leaf's type word and the
// importer tests this bit to recognize leaves.
72 rf_LeafNodeTag = 0x40000000,
// Internal-node kinds (no tag bits set). Only rf_i_ThresholdNode is written by
// the exporter and accepted by the importer.
74 rf_i_ThresholdNode = 0,
75 rf_i_HyperplaneNode = 1,
76 rf_i_HypersphereNode = 2,
// Leaf kinds (leaf tag set). rf_e_ConstProbNode has the same value as the
// bare rf_LeafNodeTag word that the exporter writes for leaves.
77 rf_e_ConstProbNode = 0 | rf_LeafNodeTag,
78 rf_e_LogRegProbNode = 1 | rf_LeafNodeTag
// Bit masks for decoding a node-type word from a tree's topology dataset.
// Tag bits (e.g. rf_LeafNodeTag) occupy the high nibble.
static unsigned int const rf_tag_mask  = 0xf0000000u;
// The node-kind code (e.g. rf_i_ThresholdNode) occupies the low two bits.
static unsigned int const rf_type_mask = 0x00000003u;
// Every bit outside the tag and kind fields; a valid node-type word has all of
// these cleared (value 0x0ffffffc).
static unsigned int const rf_zero_mask = 0xffffffffu & ~(rf_tag_mask | rf_type_mask);
87 inline std::string get_cwd(HDF5File & h5context)
89 return h5context.get_absolute_path(h5context.pwd());
// Imports a random forest stored in the HDF5 layout that
// random_forest_export_HDF5 writes.
//
// FEATURES - only used to select DefaultRF<FEATURES, LABELS>::type.
// LABELS   - LABELS::value_type is used as the forest's label type.
// h5ctx    - HDF5 file to read from; its current group is changed while
//            reading (h5ctx.cd into the ext-param and options groups).
// pathname - presumably the group that holds the forest; when non-empty the
//            current group is first recorded via detail::get_cwd.
//
// Builds an RF from the graph, split tests, leaf responses and problem spec
// read from the file, and copies the stored options into it.
// Fails a vigra_precondition if the stored format version is newer than
// rf_hdf5_version, if a tree's topology header disagrees with the stored
// feature/class counts, or if a node carries an unexpected type word.
93 template <
typename FEATURES,
typename LABELS>
94 typename DefaultRF<FEATURES, LABELS>::type
95 random_forest_import_HDF5(HDF5File & h5ctx, std::string
const & pathname =
"")
97 typedef typename DefaultRF<FEATURES, LABELS>::type RF;
98 typedef typename RF::Graph Graph;
99 typedef typename RF::Node Node;
100 typedef typename RF::SplitTests SplitTest;
101 typedef typename LABELS::value_type LabelType;
102 typedef typename RF::AccInputType AccInputType;
103 typedef typename AccInputType::value_type AccValueType;
// Record the caller's current group when a sub-path is given.
107 if (pathname.size()) {
108 cwd = detail::get_cwd(h5ctx);
// Reject files whose version attribute is newer than this reader supports.
// Files without the attribute are accepted unchecked.
112 if (h5ctx.existsAttribute(rf_hdf5_version_group, rf_hdf5_version_tag)) {
114 h5ctx.readAttribute(rf_hdf5_version_group, rf_hdf5_version_tag, version);
115 vigra_precondition(version <= rf_hdf5_version,
"random_forest_import_HDF5(): unexpected file format version.");
// Problem description, read from the ext-param group. The dataset names match
// the ones written by random_forest_export_HDF5.
120 size_t num_instances;
125 MultiArray<1, LabelType> distinct_labels_marray;
126 MultiArray<1, double> class_weights_marray;
128 h5ctx.cd(rf_hdf5_ext_param);
129 h5ctx.read(
"actual_msample_", msample);
130 h5ctx.read(
"actual_mtry_", actual_mtry);
131 h5ctx.read(
"class_count_", num_classes);
132 h5ctx.readAndResize(
"class_weights_", class_weights_marray);
133 h5ctx.read(
"column_count_", num_features);
134 h5ctx.read(
"is_weighted_", is_weighted_int);
135 h5ctx.readAndResize(
"labels", distinct_labels_marray);
136 h5ctx.read(
"row_count_", num_instances);
// The flag is stored as an integer; only the value 1 counts as true.
139 bool is_weighted = is_weighted_int == 1 ?
true :
false;
// Training options, read from the options group.
142 size_t min_num_instances;
145 int bootstrap_sampling_int;
147 h5ctx.cd(rf_hdf5_options);
148 h5ctx.read(
"min_split_node_size_", min_num_instances);
149 h5ctx.read(
"mtry_", mtry);
150 h5ctx.read(
"mtry_switch_", mtry_switch_int);
151 h5ctx.read(
"sample_with_replacement_", bootstrap_sampling_int);
152 h5ctx.read(
"tree_count_", tree_count);
// The stored integer is cast back to the option enum; a stored 1 means
// bootstrap sampling was enabled.
155 RandomForestOptionTags mtry_switch = (RandomForestOptionTags)mtry_switch_int;
156 bool bootstrap_sampling = bootstrap_sampling_int == 1 ?
true :
false;
158 std::vector<LabelType>
const distinct_labels(distinct_labels_marray.begin(), distinct_labels_marray.end());
159 std::vector<double>
const class_weights(class_weights_marray.begin(), class_weights_marray.end());
// Rebuild the problem specification from the ext-param values.
161 auto const pspec = ProblemSpec<LabelType>()
162 .num_features(num_features)
163 .num_instances(num_instances)
164 .num_classes(num_classes)
165 .distinct_classes(distinct_labels)
166 .actual_mtry(actual_mtry)
167 .actual_msample(msample);
// Rebuild the options. The mtry fields are assigned directly rather than
// through the builder chain.
169 auto options = RandomForestOptions()
170 .min_num_instances(min_num_instances)
171 .bootstrap_sampling(bootstrap_sampling)
172 .tree_count(tree_count);
173 options.features_per_node_switch_ = mtry_switch;
174 options.features_per_node_ = mtry;
176 options.class_weights(class_weights);
// Per-node payloads: split tests for internal nodes, class responses for leaves.
179 typename RF::template NodeMap<SplitTest>::type split_tests;
180 typename RF::template NodeMap<AccInputType>::type leaf_responses;
// Each child group whose name starts with rf_hdf5_tree ("Tree_") holds one
// tree. All other groups (e.g. the ext-param and options groups) are filtered
// out by the prefix test.
182 auto const groups = h5ctx.ls();
183 for (
auto const & groupname : groups) {
184 if (groupname.substr(0, std::char_traits<char>::length(rf_hdf5_tree)).compare(rf_hdf5_tree) != 0) {
// topology: node-type words and integer offsets; parameters: node payloads.
188 MultiArray<1, unsigned int> topology;
189 MultiArray<1, double> parameters;
191 h5ctx.readAndResize(rf_hdf5_topology, topology);
192 h5ctx.readAndResize(rf_hdf5_parameters, parameters);
// The first two topology entries repeat the feature and class counts.
195 vigra_precondition(topology[0] == num_features,
"random_forest_import_HDF5(): number of features mismatch.");
196 vigra_precondition(topology[1] == num_classes,
"random_forest_import_HDF5(): number of classes mismatch.");
// Root node of this tree in the forest graph gr.
198 Node
const n = gr.addNode();
// Breadth-first rebuild: each queue entry pairs a topology offset with the
// graph node that the record at that offset describes.
200 std::queue<std::pair<unsigned int, Node> > q;
203 auto const el = q.front();
205 unsigned int const index = el.first;
206 Node
const parent = el.second;
// No bits outside the tag nibble and the two type bits may be set.
208 vigra_precondition((topology[index] & rf_zero_mask) == 0,
"random_forest_import_HDF5(): unexpected node type: type & zero_mask > 0");
// Leaf record: [type, parameter offset]. parameters[offset] is skipped (the
// exporter stores the sum of the class responses there), and the next
// num_classes values are the class responses.
// NOTE(review): the leaf's type bits are not checked, so an
// rf_e_LogRegProbNode word would be read like a constant-probability leaf;
// confirm whether it should be rejected instead.
210 if (topology[index] & rf_LeafNodeTag) {
211 unsigned int const probs_start = topology[index+1] + 1;
213 vigra_precondition((topology[index] & rf_tag_mask) == rf_LeafNodeTag,
"random_forest_import_HDF5(): unexpected node type: additional tags in leaf node");
215 std::vector<AccValueType> node_response;
217 for (
unsigned int i = 0; i < num_classes; ++i) {
218 node_response.push_back(parameters[probs_start + i]);
221 leaf_responses.insert(parent, node_response);
// Internal record: [type, parameter offset, left child offset, right child
// offset, feature index]. parameters[offset] is skipped and
// parameters[offset + 1] is the split threshold. Only threshold splits are
// supported.
224 vigra_precondition(topology[index] == rf_i_ThresholdNode,
"random_forest_import_HDF5(): unexpected node type.");
226 Node
const left = gr.addNode();
227 Node
const right = gr.addNode();
229 gr.addArc(parent, left);
230 gr.addArc(parent, right);
232 split_tests.insert(parent, SplitTest(topology[index+4], parameters[topology[index+1]+1]));
// Enqueue the children together with the offsets of their own records.
234 q.push(std::make_pair(topology[index+2], left));
235 q.push(std::make_pair(topology[index+3], right));
// Assemble the forest and attach the options read above.
246 RF rf(gr, split_tests, leaf_responses, pspec);
247 rf.options_ = options;
// Formats integers as zero-padded decimal strings. Used by
// random_forest_export_HDF5 to build the per-tree group names
// (rf_hdf5_tree + number).
// NOTE(review): the single-argument constructor is not `explicit`, so an int
// converts implicitly to this type; consider marking it explicit.
253 class PaddedNumberString
// The pad width is the length of ss_'s contents once construction has written
// to it. Presumably that is the decimal form of n or of a value derived from
// it; confirm against the statement that fills ss_.
257 PaddedNumberString(
int n)
260 width_ = ss_.str().size();
// Writes k into ss_ using std::setw(width_) and fill character '0', so shorter
// numbers are left-padded with zeros. Presumably returns the resulting string.
// NOTE(review): ss_ is a mutable stream shared by all calls, so this const
// member function is not safe to call concurrently on one object.
263 std::string operator()(
int k)
const
266 ss_ << std::setw(width_) << std::setfill(
'0') << k;
// mutable so that the const operator() can reuse the stream.
272 mutable std::ostringstream ss_;
// Writes a random forest `rf` into h5context in the layout read by
// random_forest_import_HDF5. The output consists of:
//   - the rf_hdf5_version_tag attribute;
//   - an ext-param group (problem description);
//   - an options group;
//   - one group per tree, named rf_hdf5_tree + a zero-padded index, holding
//     "topology" (UInt32) and "parameters" (double) datasets.
//
// RF        - forest type; RF::LabelType and RF::Node are used.
// rf        - the forest to export. It is read via problem_spec_, options_,
//             graph_, split_tests_, node_responses_ and num_trees().
// h5context - HDF5 file to write to; groups are entered with cd_mk.
// pathname  - when non-empty, the current group is recorded via
//             detail::get_cwd, and pathname is entered with cd_mk.
277 template <
typename RF>
278 void random_forest_export_HDF5(
280 HDF5File & h5context,
281 std::string
const & pathname =
""
283 typedef typename RF::LabelType LabelType;
284 typedef typename RF::Node Node;
287 if (pathname.size()) {
288 cwd = detail::get_cwd(h5context);
289 h5context.cd_mk(pathname);
// Format-version attribute, checked by the importer.
293 h5context.writeAttribute(rf_hdf5_version_group, rf_hdf5_version_tag,
297 auto const & p = rf.problem_spec_;
298 auto const & opts = rf.options_;
// Class labels, plus one weight per class. The default weight is 1.0; the
// option weights replace it when present.
// NOTE(review): class_weights has p.num_classes_ entries, but the copy loop
// runs over opts.class_weights_.size(). Confirm the two sizes are guaranteed to
// match, otherwise class_weights(i) writes out of bounds.
299 MultiArray<1, LabelType> distinct_classes(Shape1(p.distinct_classes_.size()), p.distinct_classes_.data());
300 MultiArray<1, double> class_weights(Shape1(p.num_classes_), 1.0);
302 if (opts.class_weights_.size() > 0)
305 for (
size_t i = 0; i < opts.class_weights_.size(); ++i)
306 class_weights(i) = opts.class_weights_[i];
// Ext-param group (problem description). The entries from "precision_" to
// "used_" are written with fixed values; presumably this keeps the layout
// compatible with the older vigra random forest format (confirm).
310 h5context.cd_mk(rf_hdf5_ext_param);
311 h5context.write(
"column_count_", p.num_features_);
312 h5context.write(
"row_count_", p.num_instances_);
313 h5context.write(
"class_count_", p.num_classes_);
314 h5context.write(
"actual_mtry_", p.actual_mtry_);
315 h5context.write(
"actual_msample_", p.actual_msample_);
316 h5context.write(
"labels", distinct_classes);
317 h5context.write(
"is_weighted_", is_weighted);
318 h5context.write(
"class_weights_", class_weights);
319 h5context.write(
"precision_", 0.0);
320 h5context.write(
"problem_type_", 1.0);
321 h5context.write(
"response_size_", 1.0);
322 h5context.write(
"used_", 1.0);
// Options group. The entries written with literal constants are not taken
// from opts.
// The bootstrap flag is stored as the integer 1 or 0.
326 h5context.cd_mk(rf_hdf5_options);
327 h5context.write(
"min_split_node_size_", opts.min_num_instances_);
328 h5context.write(
"mtry_", opts.features_per_node_);
329 h5context.write(
"mtry_func_", 0.0);
330 h5context.write(
"mtry_switch_", opts.features_per_node_switch_);
331 h5context.write(
"predict_weighted_", 0.0);
332 h5context.write(
"prepare_online_learning_", 0.0);
333 h5context.write(
"sample_with_replacement_", opts.bootstrap_sampling_ ? 1 : 0);
334 h5context.write(
"stratification_method_", 3.0);
335 h5context.write(
"training_set_calc_switch_", 1.0);
336 h5context.write(
"training_set_func_", 0.0);
337 h5context.write(
"training_set_proportion_", 1.0);
338 h5context.write(
"training_set_size_", 0.0);
339 h5context.write(
"tree_count_", opts.tree_count_);
// One group per tree; tree numbers are zero-padded to a width derived from
// rf.num_trees().
343 detail::PaddedNumberString tree_number(rf.num_trees());
344 for (
size_t i = 0; i < rf.num_trees(); ++i)
// The topology starts with the header [num_features, num_classes], followed by
// one record per node.
347 std::vector<UInt32> topology;
348 std::vector<double> parameters;
349 topology.push_back(p.num_features_);
350 topology.push_back(p.num_classes_);
352 auto const & probs = rf.node_responses_;
353 auto const & splits = rf.split_tests_;
354 auto const & gr = rf.graph_;
355 auto const root = gr.getRoot(i);
// Depth-first traversal with an explicit stack. The second pair element is the
// slot in the parent's record that must receive this node's record offset;
// it is -1 for the root.
361 std::stack<std::pair<Node, std::ptrdiff_t> > stack;
362 stack.emplace(root, -1);
363 while (!stack.empty())
// This `i` (parent slot) shadows the tree-loop index. The group name built
// after the traversal uses the outer tree index.
365 auto const n = stack.top().first;
366 auto const i = stack.top().second;
// Patch the parent's child-offset placeholder with this record's offset.
// NOTE(review): for the root entry i is -1; confirm this assignment is guarded
// by an i != -1 check.
371 topology[i] = topology.size();
// Leaf record: [rf_LeafNodeTag, parameter offset]. The parameters receive the
// sum of the class responses followed by the responses themselves.
373 if (gr.numChildren(n) == 0)
378 topology.push_back(rf_LeafNodeTag);
379 topology.push_back(parameters.size());
380 auto const & prob = probs.at(n);
381 auto const weight = std::accumulate(prob.begin(), prob.end(), 0.0);
382 parameters.push_back(weight);
383 parameters.insert(parameters.end(), prob.begin(), prob.end());
// Internal record: [rf_i_ThresholdNode, parameter offset, left offset, right
// offset, feature index].
// The two -1 child offsets (stored as 0xffffffff in UInt32) are placeholders,
// patched when the children are visited.
// The parameters receive 1.0 followed by the split threshold.
390 topology.push_back(rf_i_ThresholdNode);
391 topology.push_back(parameters.size());
392 topology.push_back(-1);
393 topology.push_back(-1);
394 topology.push_back(splits.at(n).dim_);
395 parameters.push_back(1.0);
396 parameters.push_back(splits.at(n).val_);
// size()-3 / size()-2 are the left/right child-offset slots just pushed.
399 stack.emplace(gr.getChild(n, 0), topology.size()-3);
400 stack.emplace(gr.getChild(n, 1), topology.size()-2);
// Wrap the collected vectors and write them into the tree's group.
405 MultiArray<1, UInt32> topo(Shape1(topology.size()), topology.data());
406 MultiArray<1, double> para(Shape1(parameters.size()), parameters.data());
408 auto const name = rf_hdf5_tree + tree_number(i);
409 h5context.cd_mk(name);
410 h5context.write(rf_hdf5_topology, topo);
411 h5context.write(rf_hdf5_parameters, para);
424 #endif // VIGRA_RF3_IMPEX_HDF5_HXX