MLPACK
1.0.10
Main Page
Related Pages
Namespaces
Classes
Files
File List
File Members
src
mlpack
methods
amf
termination_policies
simple_tolerance_termination.hpp
Go to the documentation of this file.
1
20
#ifndef _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
21
#define _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
22
23
#include <
mlpack/core.hpp
>
24
25
namespace
mlpack {
26
namespace
amf {
27
28
template
<
class
MatType>
29
class
SimpleToleranceTermination
30
{
31
public
:
32
SimpleToleranceTermination
(
const
double
tolerance
= 1e-5,
33
const
size_t
maxIterations
= 10000,
34
const
size_t
reverseStepTolerance
= 3)
35
:
tolerance
(
tolerance
),
36
maxIterations
(
maxIterations
),
37
reverseStepTolerance
(
reverseStepTolerance
) {}
38
39
void
Initialize
(
const
MatType&
V
)
40
{
41
residueOld
= DBL_MAX;
42
iteration
= 1;
43
residue
= DBL_MIN;
44
reverseStepCount
= 0;
45
46
this->V = &
V
;
47
48
c_index
= 0;
49
c_indexOld
= 0;
50
51
reverseStepCount
= 0;
52
}
53
54
bool
IsConverged
(arma::mat&
W
, arma::mat&
H
)
55
{
56
// Calculate norm of WH after each iteration.
57
arma::mat WH;
58
59
WH = W *
H
;
60
61
residueOld
=
residue
;
62
size_t
n =
V
->n_rows;
63
size_t
m =
V
->n_cols;
64
double
sum = 0;
65
size_t
count = 0;
66
for
(
size_t
i = 0;i < n;i++)
67
{
68
for
(
size_t
j = 0;j < m;j++)
69
{
70
double
temp = 0;
71
if
((temp = (*
V
)(i,j)) != 0)
72
{
73
temp = (temp - WH(i, j));
74
temp = temp * temp;
75
sum += temp;
76
count++;
77
}
78
}
79
}
80
residue
= sum / count;
81
residue
= sqrt(
residue
);
82
83
iteration
++;
84
85
if
((
residueOld
-
residue
) / residueOld < tolerance && iteration > 4)
86
{
87
if
(
reverseStepCount
== 0 &&
isCopy
==
false
)
88
{
89
isCopy
=
true
;
90
this->W =
W
;
91
this->H =
H
;
92
c_index
=
residue
;
93
c_indexOld
=
residueOld
;
94
}
95
reverseStepCount
++;
96
}
97
else
98
{
99
reverseStepCount
= 0;
100
if
(
residue
<=
c_indexOld
&&
isCopy
==
true
)
101
{
102
isCopy
=
false
;
103
}
104
}
105
106
if
(
reverseStepCount
==
reverseStepTolerance
||
iteration
>
maxIterations
)
107
{
108
if
(
isCopy
)
109
{
110
W = this->
W
;
111
H = this->
H
;
112
residue
=
c_index
;
113
}
114
return
true
;
115
}
116
else
return
false
;
117
}
118
119
const
double
&
Index
() {
return
residue
; }
120
const
size_t
&
Iteration
() {
return
iteration
; }
121
const
size_t
&
MaxIterations
() {
return
maxIterations
; }
122
123
private
:
124
double
tolerance
;
125
size_t
maxIterations
;
126
127
const
MatType*
V
;
128
129
size_t
iteration
;
130
double
residueOld
;
131
double
residue
;
132
double
normOld
;
133
134
size_t
reverseStepTolerance
;
135
size_t
reverseStepCount
;
136
137
bool
isCopy
;
138
arma::mat
W
;
139
arma::mat
H
;
140
double
c_indexOld
;
141
double
c_index
;
142
};
// class SimpleToleranceTermination
143
144
};
// namespace amf
145
};
// namespace mlpack
146
147
#endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
148
mlpack::amf::SimpleToleranceTermination::Index
const double & Index()
Definition:
simple_tolerance_termination.hpp:119
mlpack::amf::SimpleToleranceTermination::reverseStepCount
size_t reverseStepCount
Definition:
simple_tolerance_termination.hpp:135
mlpack::amf::SimpleToleranceTermination::MaxIterations
const size_t & MaxIterations()
Definition:
simple_tolerance_termination.hpp:121
mlpack::amf::SimpleToleranceTermination::iteration
size_t iteration
Definition:
simple_tolerance_termination.hpp:129
mlpack::amf::SimpleToleranceTermination::tolerance
double tolerance
Definition:
simple_tolerance_termination.hpp:124
mlpack::amf::SimpleToleranceTermination::SimpleToleranceTermination
SimpleToleranceTermination(const double tolerance=1e-5, const size_t maxIterations=10000, const size_t reverseStepTolerance=3)
Definition:
simple_tolerance_termination.hpp:32
mlpack::amf::SimpleToleranceTermination::V
const MatType * V
Definition:
simple_tolerance_termination.hpp:127
mlpack::amf::SimpleToleranceTermination::H
arma::mat H
Definition:
simple_tolerance_termination.hpp:139
mlpack::amf::SimpleToleranceTermination::normOld
double normOld
Definition:
simple_tolerance_termination.hpp:132
mlpack::amf::SimpleToleranceTermination::Iteration
const size_t & Iteration()
Definition:
simple_tolerance_termination.hpp:120
mlpack::amf::SimpleToleranceTermination::residue
double residue
Definition:
simple_tolerance_termination.hpp:131
mlpack::amf::SimpleToleranceTermination::isCopy
bool isCopy
Definition:
simple_tolerance_termination.hpp:137
mlpack::amf::SimpleToleranceTermination::W
arma::mat W
Definition:
simple_tolerance_termination.hpp:138
mlpack::amf::SimpleToleranceTermination::reverseStepTolerance
size_t reverseStepTolerance
Definition:
simple_tolerance_termination.hpp:134
core.hpp
mlpack::amf::SimpleToleranceTermination::maxIterations
size_t maxIterations
Definition:
simple_tolerance_termination.hpp:125
mlpack::amf::SimpleToleranceTermination::c_index
double c_index
Definition:
simple_tolerance_termination.hpp:141
mlpack::amf::SimpleToleranceTermination
Definition:
simple_tolerance_termination.hpp:29
mlpack::amf::SimpleToleranceTermination::Initialize
void Initialize(const MatType &V)
Definition:
simple_tolerance_termination.hpp:39
mlpack::amf::SimpleToleranceTermination::c_indexOld
double c_indexOld
Definition:
simple_tolerance_termination.hpp:140
mlpack::amf::SimpleToleranceTermination::IsConverged
bool IsConverged(arma::mat &W, arma::mat &H)
Definition:
simple_tolerance_termination.hpp:54
mlpack::amf::SimpleToleranceTermination::residueOld
double residueOld
Definition:
simple_tolerance_termination.hpp:130
Generated by
1.8.6