MLPACK  1.0.10
svd_complete_incremental_learning.hpp
Go to the documentation of this file.
1 #ifndef SVD_COMPLETE_INCREMENTAL_LEARNING_HPP_INCLUDED
2 #define SVD_COMPLETE_INCREMENTAL_LEARNING_HPP_INCLUDED
3 
4 #include <mlpack/core.hpp>
5 
6 namespace mlpack
7 {
8 namespace amf
9 {
10 
11 template <class MatType>
13 {
14  public:
16  double kw = 0,
17  double kh = 0)
18  : u(u), kw(kw), kh(kh)
19  {}
20 
21  void Initialize(const MatType& dataset, const size_t rank)
22  {
23  (void)rank;
24  n = dataset.n_rows;
25  m = dataset.n_cols;
26 
27  currentUserIndex = 0;
28  currentItemIndex = 0;
29  }
30 
55  inline void WUpdate(const MatType& V,
56  arma::mat& W,
57  const arma::mat& H)
58  {
59  arma::mat deltaW(1, W.n_cols);
60  deltaW.zeros();
61  while(true)
62  {
63  double val;
64  if((val = V(currentItemIndex, currentUserIndex)) != 0)
65  {
66  deltaW += (val - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
67  * arma::trans(H.col(currentUserIndex));
68  if(kw != 0) deltaW -= kw * W.row(currentItemIndex);
69  break;
70  }
72  if(currentUserIndex == n)
73  {
74  currentUserIndex = 0;
76  }
77  }
78 
79  W.row(currentItemIndex) += u*deltaW;
80  }
81 
91  inline void HUpdate(const MatType& V,
92  const arma::mat& W,
93  arma::mat& H)
94  {
95  arma::mat deltaH(H.n_rows, 1);
96  deltaH.zeros();
97 
98  while(true)
99  {
100  double val;
101  if((val = V(currentItemIndex, currentUserIndex)) != 0)
102  deltaH += (val - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
103  * arma::trans(W.row(currentItemIndex));
104  if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
105 
107  if(currentUserIndex == n)
108  {
109  currentUserIndex = 0;
111  }
112  }
113 
114  H.col(currentUserIndex++) += u * deltaH;
115  }
116 
117  private:
118  double u;
119  double kw;
120  double kh;
121 
122  size_t n;
123  size_t m;
124 
127 };
128 
129 template<>
131 {
132  public:
134  double kw = 0,
135  double kh = 0)
136  : u(u), kw(kw), kh(kh), it(NULL)
137  {}
138 
140  {
141  delete it;
142  }
143 
144  void Initialize(const arma::sp_mat& dataset, const size_t rank)
145  {
146  (void)rank;
147  n = dataset.n_rows;
148  m = dataset.n_cols;
149 
150  it = new arma::sp_mat::const_iterator(dataset.begin());
151  isStart = true;
152  }
153 
163  inline void WUpdate(const arma::sp_mat& V,
164  arma::mat& W,
165  const arma::mat& H)
166  {
167  if(!isStart) (*it)++;
168  else isStart = false;
169 
170  if(*it == V.end())
171  {
172  delete it;
173  it = new arma::sp_mat::const_iterator(V.begin());
174  }
175 
176  size_t currentUserIndex = it->col();
177  size_t currentItemIndex = it->row();
178 
179  arma::mat deltaW(1, W.n_cols);
180  deltaW.zeros();
181 
182  deltaW += (**it - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
183  * arma::trans(H.col(currentUserIndex));
184  if(kw != 0) deltaW -= kw * W.row(currentItemIndex);
185 
186  W.row(currentItemIndex) += u*deltaW;
187  }
188 
198  inline void HUpdate(const arma::sp_mat& V,
199  const arma::mat& W,
200  arma::mat& H)
201  {
202  (void)V;
203 
204  arma::mat deltaH(H.n_rows, 1);
205  deltaH.zeros();
206 
207  size_t currentUserIndex = it->col();
208  size_t currentItemIndex = it->row();
209 
210  deltaH += (**it - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
211  * arma::trans(W.row(currentItemIndex));
212  if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
213 
214  H.col(currentUserIndex++) += u * deltaH;
215  }
216 
217  private:
218  double u;
219  double kw;
220  double kh;
221 
222  size_t n;
223  size_t m;
224 
225  arma::sp_mat dummy;
226  arma::sp_mat::const_iterator* it;
227 
228  bool isStart;
229 };
230 
231 }
232 }
233 
234 
235 #endif // SVD_COMPLETE_INCREMENTAL_LEARNING_HPP_INCLUDED
236 
void HUpdate(const arma::sp_mat &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:31
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void WUpdate(const arma::sp_mat &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
SVDCompleteIncrementalLearning(double u=0.0001, double kw=0, double kh=0)
void Initialize(const MatType &dataset, const size_t rank)
void Initialize(const arma::sp_mat &dataset, const size_t rank)