20 using namespace shogun;
22 CPrimalMosekSOSVM::CPrimalMosekSOSVM()
29 CPrimalMosekSOSVM::CPrimalMosekSOSVM(
38 void CPrimalMosekSOSVM::init()
47 m_regularization = 1.0;
51 CPrimalMosekSOSVM::~CPrimalMosekSOSVM()
55 bool CPrimalMosekSOSVM::train_machine(
CFeatures* data)
57 SG_DEBUG(
"Entering CPrimalMosekSOSVM::train_machine.\n");
61 CFeatures* model_features = get_features();
63 m_model->init_training();
65 m_model->check_training_setup();
66 SG_DEBUG(
"The training setup is correct.\n");
69 int32_t M = m_model->get_dim();
71 int32_t num_aux = m_model->get_num_aux();
73 int32_t num_aux_con = m_model->get_num_aux_con();
77 SG_DEBUG(
"M=%d, N =%d, num_aux=%d, num_aux_con=%d.\n", M, N, num_aux, num_aux_con);
80 CMosek* mosek =
new CMosek(0, M+num_aux+N);
82 REQUIRE(mosek->get_rescode() == MSK_RES_OK,
"Mosek object could not be properly created in PrimalMosekSOSVM training.\n");
87 m_model->init_primal_opt(m_regularization, A, a, B, b, lb, ub, C);
90 "%s::train_machine(): lb.vlen can only be 0 or w.vlen!\n", get_name());
93 "%s::train_machine(): ub.vlen can only be 0 or w.vlen!\n", get_name());
101 SG_DEBUG(
"Regularization used in PrimalMosekSOSVM equal to %.2f.\n", m_regularization);
104 REQUIRE(mosek->init_sosvm(M, N, num_aux, num_aux_con, C, m_lb, m_ub, A, b) == MSK_RES_OK,
105 "Mosek error in PrimalMosekSOSVM initializing SO-SVM.\n")
119 for ( int32_t i = 0 ; i < N ; ++i )
126 int32_t num_con = num_aux_con;
127 int32_t old_num_con = num_con;
128 bool exception =
false;
138 SG_DEBUG(
"Iteration #%d: Cutting plane training with num_con=%d and old_num_con=%d.\n",
139 iteration, num_con, old_num_con);
141 old_num_con = num_con;
143 for ( int32_t i = 0 ; i < N ; ++i )
160 while ( cur_res != NULL )
168 if ( slack > max_slack + m_epsilon )
172 if ( ! insert_result(cur_list, result) )
178 add_constraint(mosek, result, num_con, i);
185 if ( ! insert_result(cur_list, result) )
191 add_constraint(mosek, result, num_con, i);
200 SG_DEBUG(
"Entering Mosek QP solver.\n");
202 mosek->optimize(sol);
203 for ( int32_t i = 0 ; i < M+num_aux+N ; ++i )
207 else if ( i < M+num_aux )
210 m_slacks[i-M-num_aux] = sol[i];
213 SG_DEBUG(
"QP solved. The primal objective value is %.4f.\n", mosek->get_primal_objective_value());
217 }
while ( old_num_con != num_con && ! exception );
219 po_value = mosek->get_primal_objective_value();
231 int32_t M = m_w.vlen;
247 SG_ERROR(
"model(%s) should have either of psi_computed or psi_computed_sparse"
248 "to be set true\n", m_model->get_name());
253 bool CPrimalMosekSOSVM::insert_result(
CList* result_list,
CResultSet* result)
const
259 SG_PRINT(
"ResultSet could not be inserted in the list..."
260 "aborting training of PrimalMosekSOSVM\n");
266 bool CPrimalMosekSOSVM::add_constraint(
272 int32_t M = m_model->get_dim();
277 for (
int i = 0 ; i < M ; ++i )
288 SG_ERROR(
"model(%s) should have either of psi_computed or psi_computed_sparse"
289 "to be set true\n", m_model->get_name());
292 return ( mosek->add_constraint_sosvm(dPsi, con_idx, train_idx,
293 m_model->get_num_aux(), -result->
delta) == MSK_RES_OK );
297 float64_t CPrimalMosekSOSVM::compute_primal_objective()
const
307 void CPrimalMosekSOSVM::set_regularization(
float64_t C)
309 m_regularization = C;
SGVector< float64_t > psi_truth
float64_t loss(float64_t prediction, float64_t label)
static float64_t dot(const bool *v1, const bool *v2, int32_t n)
compute dot product between v1 and v2 (blas optimized)
Base class of the labels used in Structured Output (SO) problems.
virtual bool init(CFeatures *features)=0
CSGObject * get_next_element()
static const float64_t INFTY
infinity
void add_to_dense(T alpha, T *vec, int32_t dim, bool abs_val=false)
virtual int32_t get_num_vectors() const =0
static const float64_t epsilon
CSGObject * get_first_element()
int32_t get_num_elements()
static T max(T a, T b)
return the maximum of two integers
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
SGVector< T > clone() const
Class CStructuredModel that represents the application specific model and contains most of the applic...
T dense_dot(T alpha, T *vec, int32_t dim, T b)
The class Features is the base class of all feature objects.
SGVector< float64_t > psi_pred
CSGObject * get_element(int32_t index) const
CHingeLoss implements the hinge loss function.
SGSparseVector< float64_t > psi_truth_sparse
void push_back(CSGObject *e)
void set_epsilon(float *begin, float max)
SGSparseVector< float64_t > psi_pred_sparse
Class List implements a doubly connected list for low-level-objects.
bool insert_element(CSGObject *data)