mlpack  2.0.1
information_gain.hpp
Go to the documentation of this file.
1 
15 #ifndef __MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
16 #define __MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
17 
18 namespace mlpack {
19 namespace tree {
20 
22 {
23  public:
33  static double Evaluate(const arma::Mat<size_t>& counts)
34  {
35  // Calculate the number of elements in the unsplit node and also in each
36  // proposed child.
37  size_t numElem = 0;
38  arma::vec splitCounts(counts.n_elem);
39  for (size_t i = 0; i < counts.n_cols; ++i)
40  {
41  splitCounts[i] = arma::accu(counts.col(i));
42  numElem += splitCounts[i];
43  }
44 
45  // Corner case: if there are no elements, the gain is zero.
46  if (numElem == 0)
47  return 0.0;
48 
49  arma::Col<size_t> classCounts = arma::sum(counts, 1);
50 
51  // Calculate the gain of the unsplit node.
52  double gain = 0.0;
53  for (size_t i = 0; i < classCounts.n_elem; ++i)
54  {
55  const double f = ((double) classCounts[i] / (double) numElem);
56  if (f > 0.0)
57  gain -= f * std::log2(f);
58  }
59 
60  // Now calculate the impurity of the split nodes and subtract them from the
61  // overall gain.
62  for (size_t i = 0; i < counts.n_cols; ++i)
63  {
64  if (splitCounts[i] > 0)
65  {
66  double splitGain = 0.0;
67  for (size_t j = 0; j < counts.n_rows; ++j)
68  {
69  const double f = ((double) counts(j, i) / (double) splitCounts[i]);
70  if (f > 0.0)
71  splitGain += f * std::log2(f);
72  }
73 
74  gain += ((double) splitCounts[i] / (double) numElem) * splitGain;
75  }
76  }
77 
78  return gain;
79  }
80 
/**
 * Return the range of the information gain for the given number of classes.
 *
 * @param numClasses Number of classes in the problem.
 * @return log2(numClasses), the width of the interval the gain can take.
 */
static double Range(const size_t numClasses)
{
  // The gain is the parent's entropy minus a weighted child entropy, so it
  // cannot exceed the largest entropy a numClasses-way distribution can have.
  // That maximum is reached by the uniform distribution:
  //   -sum_{k=1}^{n} (1/n) * log2(1/n) = log2(n).
  const double classes = static_cast<double>(numClasses);
  return std::log2(classes);
}
93 };
94 
95 } // namespace tree
96 } // namespace mlpack
97 
98 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
static double Range(const size_t numClasses)
Return the range of the information gain for the given number of classes.
static double Evaluate(const arma::Mat< size_t > &counts)
Given the sufficient statistics of a proposed split, calculate the information gain that would result if that split were made.