MLPACK  1.0.10
simple_tolerance_termination.hpp
Go to the documentation of this file.
1 
20 #ifndef _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
21 #define _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
22 
23 #include <mlpack/core.hpp>
24 
25 namespace mlpack {
26 namespace amf {
27 
28 template <class MatType>
30 {
31  public:
33  const size_t maxIterations = 10000,
34  const size_t reverseStepTolerance = 3)
38 
39  void Initialize(const MatType& V)
40  {
41  residueOld = DBL_MAX;
42  iteration = 1;
43  residue = DBL_MIN;
44  reverseStepCount = 0;
45 
46  this->V = &V;
47 
48  c_index = 0;
49  c_indexOld = 0;
50 
51  reverseStepCount = 0;
52  }
53 
54  bool IsConverged(arma::mat& W, arma::mat& H)
55  {
56  // Calculate norm of WH after each iteration.
57  arma::mat WH;
58 
59  WH = W * H;
60 
62  size_t n = V->n_rows;
63  size_t m = V->n_cols;
64  double sum = 0;
65  size_t count = 0;
66  for(size_t i = 0;i < n;i++)
67  {
68  for(size_t j = 0;j < m;j++)
69  {
70  double temp = 0;
71  if((temp = (*V)(i,j)) != 0)
72  {
73  temp = (temp - WH(i, j));
74  temp = temp * temp;
75  sum += temp;
76  count++;
77  }
78  }
79  }
80  residue = sum / count;
81  residue = sqrt(residue);
82 
83  iteration++;
84 
85  if((residueOld - residue) / residueOld < tolerance && iteration > 4)
86  {
87  if(reverseStepCount == 0 && isCopy == false)
88  {
89  isCopy = true;
90  this->W = W;
91  this->H = H;
92  c_index = residue;
94  }
96  }
97  else
98  {
99  reverseStepCount = 0;
100  if(residue <= c_indexOld && isCopy == true)
101  {
102  isCopy = false;
103  }
104  }
105 
107  {
108  if(isCopy)
109  {
110  W = this->W;
111  H = this->H;
112  residue = c_index;
113  }
114  return true;
115  }
116  else return false;
117  }
118 
119  const double& Index() { return residue; }
120  const size_t& Iteration() { return iteration; }
121  const size_t& MaxIterations() { return maxIterations; }
122 
123  private:
124  double tolerance;
126 
127  const MatType* V;
128 
129  size_t iteration;
130  double residueOld;
131  double residue;
132  double normOld;
133 
136 
137  bool isCopy;
138  arma::mat W;
139  arma::mat H;
140  double c_indexOld;
141  double c_index;
142 }; // class SimpleToleranceTermination
143 
144 }; // namespace amf
145 }; // namespace mlpack
146 
147 #endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
148 
SimpleToleranceTermination(const double tolerance=1e-5, const size_t maxIterations=10000, const size_t reverseStepTolerance=3)