MLPACK  1.0.10
nmf_mult_div.hpp
Go to the documentation of this file.
1 
32 #ifndef __MLPACK_METHODS_LMF_UPDATE_RULES_NMF_MULT_DIV_HPP
33 #define __MLPACK_METHODS_LMF_UPDATE_RULES_NMF_MULT_DIV_HPP
34 
35 #include <mlpack/core.hpp>
36 
37 namespace mlpack {
38 namespace amf {
39 
41 {
42  public:
43  // Empty constructor required for the WUpdateRule template.
45 
/**
 * Initialization hook for this update rule.
 *
 * This rule keeps no state derived from the input, so the hook does nothing.
 *
 * @param dataset Input matrix (ignored).
 * @param rank Requested rank (ignored).
 */
template<typename MatType>
void Initialize(const MatType& dataset, const size_t rank)
{
  // Explicitly discard both arguments so compilers do not emit
  // unused-parameter warnings.
  static_cast<void>(dataset);
  static_cast<void>(rank);
}
52 
66  template<typename MatType>
67  inline static void WUpdate(const MatType& V,
68  arma::mat& W,
69  const arma::mat& H)
70  {
71  // Simple implementation left in the header file.
72  arma::mat t1;
73  arma::rowvec t2;
74 
75  t1 = W * H;
76  for (size_t i = 0; i < W.n_rows; ++i)
77  {
78  for (size_t j = 0; j < W.n_cols; ++j)
79  {
80  // Writing this as a single expression does not work as of Armadillo
81  // 3.920. This should be fixed in a future release, and then the code
82  // below can be fixed.
83  //t2 = H.row(j) % V.row(i) / t1.row(i);
84  t2.set_size(H.n_cols);
85  for (size_t k = 0; k < t2.n_elem; ++k)
86  {
87  t2(k) = H(j, k) * V(i, k) / t1(i, k);
88  }
89 
90  W(i, j) = W(i, j) * sum(t2) / sum(H.row(j));
91  }
92  }
93  }
94 
108  template<typename MatType>
109  inline static void HUpdate(const MatType& V,
110  const arma::mat& W,
111  arma::mat& H)
112  {
113  // Simple implementation left in the header file.
114  arma::mat t1;
115  arma::colvec t2;
116 
117  t1 = W * H;
118  for (size_t i = 0; i < H.n_rows; i++)
119  {
120  for (size_t j = 0; j < H.n_cols; j++)
121  {
122  // Writing this as a single expression does not work as of Armadillo
123  // 3.920. This should be fixed in a future release, and then the code
124  // below can be fixed.
125  //t2 = W.col(i) % V.col(j) / t1.col(j);
126  t2.set_size(W.n_rows);
127  for (size_t k = 0; k < t2.n_elem; ++k)
128  {
129  t2(k) = W(k, i) * V(k, j) / t1(k, j);
130  }
131 
132  H(i,j) = H(i,j) * sum(t2) / sum(W.col(i));
133  }
134  }
135  }
136 };
137 
138 }; // namespace amf
139 }; // namespace mlpack
140 
141 #endif
static void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void Initialize(const MatType &dataset, const size_t rank)
static void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.