SHOGUN  3.2.1
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
ConditionalProbabilityTree.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Chiyuan Zhang
8  * Copyright (C) 2012 Chiyuan Zhang
9  */
10 
11 #include <vector>
12 #include <stack>
13 
16 
17 using namespace shogun;
18 using namespace std;
19 
21 {
22  if (data)
23  {
24  if (data->get_feature_class() != C_STREAMING_DENSE)
25  SG_ERROR("Expected StreamingDenseFeatures\n")
26  if (data->get_feature_type() != F_SHORTREAL)
27  SG_ERROR("Expected float32_t feature type\n")
28 
29  set_features(dynamic_cast<CStreamingDenseFeatures<float32_t>* >(data));
30  }
31 
32  vector<int32_t> predicts;
33 
34  m_feats->start_parser();
35  while (m_feats->get_next_example())
36  {
37  predicts.push_back(apply_multiclass_example(m_feats->get_vector()));
38  m_feats->release_example();
39  }
40  m_feats->end_parser();
41 
42  CMulticlassLabels *labels = new CMulticlassLabels(predicts.size());
43  for (size_t i=0; i < predicts.size(); ++i)
44  labels->set_int_label(i, predicts[i]);
45  return labels;
46 }
47 
49 {
50  compute_conditional_probabilities(ex);
51  SGVector<float64_t> probs(m_leaves.size());
52  for (map<int32_t,bnode_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
53  {
54  probs[it->first] = accumulate_conditional_probability(it->second);
55  }
56  return SGVector<float64_t>::arg_max(probs.vector, 1, probs.vlen);
57 }
58 
60 {
61  stack<bnode_t *> nodes;
62  nodes.push((bnode_t*) m_root);
63 
64  while (!nodes.empty())
65  {
66  bnode_t *node = nodes.top();
67  nodes.pop();
68  if (node->left())
69  {
70  nodes.push(node->left());
71  nodes.push(node->right());
72 
73  // don't calculate for leaf
74  node->data.p_right = predict_node(ex, node);
75  }
76  }
77 }
78 
80 {
81  float64_t prob = 1;
82  bnode_t *par = (bnode_t*) leaf->parent();
83  while (par != NULL)
84  {
85  if (leaf == par->left())
86  prob *= (1-par->data.p_right);
87  else
88  prob *= par->data.p_right;
89 
90  leaf = par;
91  par = (bnode_t*) leaf->parent();
92  }
93 
94  return prob;
95 }
96 
98 {
99  if (data)
100  {
101  if (data->get_feature_class() != C_STREAMING_DENSE)
102  SG_ERROR("Expected StreamingDenseFeatures\n")
103  if (data->get_feature_type() != F_SHORTREAL)
104  SG_ERROR("Expected float32_t features\n")
105  set_features(dynamic_cast<CStreamingDenseFeatures<float32_t> *>(data));
106  }
107  else
108  {
109  if (!m_feats)
110  SG_ERROR("No data features provided\n")
111  }
112 
113  m_machines->reset_array();
114  SG_UNREF(m_root);
115  m_root = NULL;
116 
117  m_leaves.clear();
118 
119  m_feats->start_parser();
120  for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
121  {
122  while (m_feats->get_next_example())
123  {
124  train_example(m_feats->get_vector(), static_cast<int32_t>(m_feats->get_label()));
125  m_feats->release_example();
126  }
127 
128  if (ipass < m_num_passes-1)
129  m_feats->reset_stream();
130  }
131  m_feats->end_parser();
132 
133  for (int32_t i=0; i < m_machines->get_num_elements(); ++i)
134  {
135  COnlineLibLinear *lll = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(i));
136  lll->stop_train();
137  SG_UNREF(lll);
138  }
139 
140  return true;
141 }
142 
144 {
 // Tree-printing routine (signature line not part of this listing).
 // When m_root is NULL it prints "Empty Tree" to stdout.
145  if (m_root)
 // NOTE(review): listing line 146 -- the statement executed when m_root is
 // non-NULL -- is missing from this extract; restore it from the upstream
 // file before compiling.
147  else
148  printf("Empty Tree\n");
149 }
150 
152 {
153  if (m_root == NULL)
154  {
155  m_root = new bnode_t();
156  m_root->data.label = label;
157  m_leaves.insert(make_pair(label, (bnode_t*) m_root));
158  m_root->machine(create_machine(ex));
159  return;
160  }
161 
162  if (m_leaves.find(label) != m_leaves.end())
163  {
164  train_path(ex, m_leaves[label]);
165  }
166  else
167  {
168  bnode_t *node = (bnode_t*) m_root;
169  while (node->left() != NULL)
170  {
171  // not a leaf
172  bool is_left = which_subtree(node, ex);
173  float64_t node_label;
174  if (is_left)
175  node_label = 0;
176  else
177  node_label = 1;
178  train_node(ex, node_label, node);
179 
180  if (is_left)
181  node = node->left();
182  else
183  node = node->right();
184  }
185 
186  m_leaves.erase(node->data.label);
187 
188  bnode_t *left_node = new bnode_t();
189  left_node->data.label = node->data.label;
190  node->data.label = -1;
191  COnlineLibLinear *node_mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
192  COnlineLibLinear *mch = new COnlineLibLinear(node_mch);
193  SG_UNREF(node_mch);
194  mch->start_train();
195  m_machines->push_back(mch);
196  left_node->machine(m_machines->get_num_elements()-1);
197  m_leaves.insert(make_pair(left_node->data.label, left_node));
198  node->left(left_node);
199 
200  bnode_t *right_node = new bnode_t();
201  right_node->data.label = label;
202  right_node->machine(create_machine(ex));
203  m_leaves.insert(make_pair(label, right_node));
204  node->right(right_node);
205  }
206 }
207 
209 {
210  float64_t node_label = 0;
211  train_node(ex, node_label, node);
212 
213  bnode_t *par = (bnode_t*) node->parent();
214  while (par != NULL)
215  {
216  if (par->left() == node)
217  node_label = 0;
218  else
219  node_label = 1;
220 
221  train_node(ex, node_label, par);
222  node = par;
223  par = (bnode_t*) node->parent();
224  }
225 }
226 
228 {
229  REQUIRE(node, "Node must not be NULL\n");
230  COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
231  REQUIRE(mch, "Instance of %s could not be casted to COnlineLibLinear\n", node->get_name());
232  mch->train_one(ex, label);
233  SG_UNREF(mch);
234 }
235 
237 {
238  REQUIRE(node, "Node must not be NULL\n");
239  COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine()));
240  REQUIRE(mch, "Instance of %s could not be casted to COnlineLibLinear\n", node->get_name());
241  float64_t pred = mch->apply_one(ex.vector, ex.vlen);
242  SG_UNREF(mch);
243  // use sigmoid function to turn the decision value into valid probability
244  return 1.0/(1+CMath::exp(-pred));
245 }
246 
248 {
249  COnlineLibLinear *mch = new COnlineLibLinear();
250  mch->start_train();
251  mch->train_one(ex, 0);
252  m_machines->push_back(mch);
253  return m_machines->get_num_elements()-1;
254 }
float64_t accumulate_conditional_probability(bnode_t *leaf)
void parent(CTreeMachineNode *par)
The node of the tree structure forming a TreeMachine The node contains pointer to its parent and poin...
void machine(int32_t idx)
void train_example(SGVector< float32_t > ex, int32_t label)
#define SG_UNREF(x)
Definition: SGRefObject.h:35
static int32_t arg_max(T *vec, int32_t inc, int32_t len, T *maxv_ptr=NULL)
return arg_max(vec)
Definition: SGVector.cpp:1048
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
#define SG_ERROR(...)
Definition: SGIO.h:130
#define REQUIRE(x,...)
Definition: SGIO.h:207
virtual void train_one(SGVector< float32_t > ex, float64_t label)
float64_t predict_node(SGVector< float32_t > ex, bnode_t *node)
Multiclass Labels for multi-class classification.
static void print_data(const ConditionalProbabilityTreeNodeData &data)
void right(CBinaryTreeMachineNode *r)
int32_t create_machine(SGVector< float32_t > ex)
double float64_t
Definition: common.h:50
virtual float64_t apply_one(int32_t vec_idx)
get output for example "vec_idx"
virtual EFeatureClass get_feature_class() const =0
virtual const char * get_name() const
void train_path(SGVector< float32_t > ex, bnode_t *node)
void train_node(SGVector< float32_t > ex, float64_t label, bnode_t *node)
CBinaryTreeMachineNode< VwConditionalProbabilityTreeNodeData > bnode_t
Class implementing a purely online version of CLibLinear, using the L2R_L1LOSS_SVC_DUAL solver only...
bool set_int_label(int32_t idx, int32_t label)
The class Features is the base class of all feature objects.
Definition: Features.h:68
static float64_t exp(float64_t x)
Definition: Math.h:444
void left(CBinaryTreeMachineNode *l)
void compute_conditional_probabilities(SGVector< float32_t > ex)
virtual EFeatureType get_feature_type() const =0
index_t vlen
Definition: SGVector.h:707
virtual int32_t apply_multiclass_example(SGVector< float32_t > ex)

SHOGUN 机器学习工具包 - 项目文档