SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StructuredOutputMachine.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Written (W) 2013 Thoralf Klein
9  * Written (W) 2012 Fernando José Iglesias García
10  * Copyright (C) 2012 Fernando José Iglesias García
11  */
12 
14 
15 using namespace shogun;
16 
18 : CMachine(), m_model(NULL), m_surrogate_loss(NULL)
19 {
// Constructor body; its signature line is missing from this listing
// (presumably the default constructor -- confirm against the original file).
// The initializer list sets m_model and m_surrogate_loss to NULL;
// register_parameters() then registers the members and assigns the initial
// values of m_verbose and m_helper.
20  register_parameters();
21 }
22 
24  CStructuredModel* model,
25  CStructuredLabels* labs)
26 : CMachine(), m_model(model), m_surrogate_loss(NULL)
27 {
// Constructor taking a model and labels (the first line of its signature is
// missing from this listing).
// m_model is stored by the initializer list and then referenced with SG_REF.
// set_labels() is called only after m_model has been assigned; a REQUIRE
// elsewhere in this file reports "please call set_model() before
// set_labels()" when m_model is NULL, so this order looks intentional.
// register_parameters() runs last and also assigns m_verbose and m_helper.
28  SG_REF(m_model);
29  set_labels(labs);
30  register_parameters();
31 }
32 
34 {
38 }
39 
41 {
// Body of a setter for m_model (presumably set_model(); the signature line
// and listing line 43 are missing from this extract).
// A reference is taken on the incoming model before it is stored.
// NOTE(review): the missing line 43 sits between the SG_REF and the
// assignment -- verify in the original file that the previously held model
// is released there.
42  SG_REF(model);
44  m_model = model;
45 }
46 
48 {
// Body of a getter for m_model (presumably get_model(); the signature line
// is missing from this extract).
// A reference is added to m_model before it is returned, so the caller
// presumably has to release it with SG_UNREF -- confirm against the header.
49  SG_REF(m_model);
50  return m_model;
51 }
52 
// Registers this machine's members through SG_ADD and then assigns the
// initial values of m_verbose and m_helper. Both constructors visible in this
// file call it as the last step of their bodies.
53 void CStructuredOutputMachine::register_parameters()
54 {
// Each SG_ADD call receives the member's address, a name string, a
// description string and MS_NOT_AVAILABLE.
// NOTE(review): the first two names carry an "m_" prefix while "verbose" and
// "helper" do not; confirm whether renaming is safe before unifying them.
55  SG_ADD((CSGObject**)&m_model, "m_model", "Structured model", MS_NOT_AVAILABLE);
56  SG_ADD((CSGObject**)&m_surrogate_loss, "m_surrogate_loss", "Surrogate loss", MS_NOT_AVAILABLE);
57  SG_ADD(&m_verbose, "verbose", "Verbosity flag", MS_NOT_AVAILABLE);
58  SG_ADD((CSGObject**)&m_helper, "helper", "Training helper", MS_NOT_AVAILABLE);
59 
// m_verbose and m_helper do not appear in the constructors' initializer
// lists visible in this file, so they receive their initial values here.
60  m_verbose = false;
61  m_helper = NULL;
62 }
63 
65 {
// Fragment of a function body (presumably set_labels(), going by the message
// below; the signature line and listing lines 66 and 68 are missing from
// this extract).
// The only visible statement is a REQUIRE check that m_model is non-NULL,
// with a message telling the caller to call set_model() first.
67  REQUIRE(m_model != NULL, "please call set_model() before set_labels()\n");
69 }
70 
72 {
74 }
75 
77 {
// Body of an accessor (presumably get_features(); the signature line is
// missing from this extract). It forwards to m_model->get_features().
// m_model is dereferenced here without a NULL check.
78  return m_model->get_features();
79 }
80 
82 {
// Body of a setter for m_surrogate_loss (the signature line and listing
// line 84 are missing from this extract).
// A reference is taken on the incoming loss before it is stored.
// NOTE(review): the missing line 84 sits between the SG_REF and the
// assignment -- verify in the original file that the previously held loss
// is released there.
83  SG_REF(loss);
85  m_surrogate_loss = loss;
86 }
87 
89 {
// Body of a getter for m_surrogate_loss (the signature line and listing
// line 90 are missing from this extract).
// NOTE(review): the other getters visible in this file add a reference
// before returning; check whether the missing line 90 does the same here.
91  return m_surrogate_loss;
92 }
93 
95 {
// Body of a risk computation whose signature line is missing from this
// extract (presumably risk_nslack_margin_rescale(), the only risk_* function
// called by risk() that is not one of the "not implemented" stubs below).
// It uses subgrad, W and info, which must come from that missing signature.
// The function overwrites subgrad[0..dim) and returns the sum of
// result->score over the processed examples.
96  int32_t dim = m_model->get_dim();
97 
98  int32_t from=0, to=0;
// features is dereferenced below without a NULL check.
99  CFeatures* features = get_features();
100  if (info)
101  {
// With info given, examples [m_from, m_from + m_N) are processed; m_N == 0
// means "up to the last feature vector".
102  from = info->m_from;
103  to = (info->m_N == 0) ? features->get_num_vectors() : from+info->m_N;
104  }
105  else
106  {
// Without info, every feature vector is processed.
107  from = 0;
108  to = features->get_num_vectors();
109  }
// Releases the object obtained from get_features() above (presumably that
// call hands back an extra reference -- confirm).
110  SG_UNREF(features);
111 
112  float64_t R = 0.0;
// Clear the output buffer before accumulating into it.
113  for (int32_t i=0; i<dim; i++)
114  subgrad[i] = 0;
115 
116  for (int32_t i=from; i<to; i++)
117  {
// W is wrapped in an SGVector of length dim for the call; the trailing
// `false` presumably disables reference counting so W is not freed -- confirm.
118  CResultSet* result = m_model->argmax(SGVector<float64_t>(W,dim,false), i, true);
119  SGVector<float64_t> psi_pred = result->psi_pred;
120  SGVector<float64_t> psi_truth = result->psi_truth;
// Going by the helper's name, these two calls add psi_pred to subgrad and
// subtract psi_truth from it, over dim entries.
121  SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, 1.0, psi_pred.vector, dim);
122  SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, -1.0, psi_truth.vector, dim);
123  R += result->score;
124  SG_UNREF(result);
125  }
126 
127  return R;
128 }
129 
131 {
// Stub body (signature line missing from this extract; the function name is
// taken from the message below). It reports through SG_ERROR that this risk
// variant is not implemented; the return statement follows the error call.
132  SG_ERROR("%s::risk_nslack_slack_rescale() has not been implemented!\n", get_name());
133  return 0.0;
134 }
135 
137 {
// Stub body (signature line missing from this extract; the function name is
// taken from the message below). It reports through SG_ERROR that this risk
// variant is not implemented; the return statement follows the error call.
138  SG_ERROR("%s::risk_1slack_margin_rescale() has not been implemented!\n", get_name());
139  return 0.0;
140 }
141 
143 {
// Stub body (signature line missing from this extract; the function name is
// taken from the message below). It reports through SG_ERROR that this risk
// variant is not implemented; the return statement follows the error call.
144  SG_ERROR("%s::risk_1slack_slack_rescale() has not been implemented!\n", get_name());
145  return 0.0;
146 }
147 
149 {
// Stub body (signature line missing from this extract; the function name is
// taken from the message below). It reports through SG_ERROR that the
// customized formulation is not implemented; the return statement follows
// the error call.
150  SG_ERROR("%s::risk_customized_formulation() has not been implemented!\n", get_name());
151  return 0.0;
152 }
153 
155  TMultipleCPinfo* info, EStructRiskType rtype)
156 {
// Dispatcher: selects one of the risk_* member functions according to rtype,
// forwards (subgrad, W, info) to it and returns its result.
// The first line of the signature (listing line 154) and the case labels of
// the first four branches (listing lines 160, 163, 166, 169) are missing
// from this extract; only CUSTOMIZED_RISK and default are visible.
157  float64_t ret = 0.0;
158  switch(rtype)
159  {
161  ret = risk_nslack_margin_rescale(subgrad, W, info);
162  break;
164  ret = risk_nslack_slack_rescale(subgrad, W, info);
165  break;
167  ret = risk_1slack_margin_rescale(subgrad, W, info);
168  break;
170  ret = risk_1slack_slack_rescale(subgrad, W, info);
171  break;
172  case CUSTOMIZED_RISK:
173  ret = risk_customized_formulation(subgrad, W, info);
174  break;
175  default:
// Unknown risk type: report the error and set the result to -1.
176  SG_ERROR("%s::risk(): cannot recognize the risk type!\n", get_name());
177  ret = -1;
178  break;
179  }
180  return ret;
181 }
182 
184 {
// Body of the accessor for m_helper (signature line missing from this
// extract; the function name is taken from the message below).
// When m_helper is NULL an error is reported through SG_ERROR; otherwise a
// reference is added to m_helper before it is returned.
185  if (m_helper == NULL)
186  {
// The two adjacent literals are concatenated by the compiler, so the first
// one ends with a space to keep the two sentences of the message apart.
187  SG_ERROR("%s::get_helper(): no helper has been created! "
188  "Please set verbose before training!\n", get_name());
189  }
190 
191  SG_REF(m_helper);
192  return m_helper;
193 }
194 
196 {
// Body of the setter for the verbosity flag (signature line missing from
// this extract): stores the given value in m_verbose.
197  m_verbose = verbose;
198 }
199 
201 {
// Body of the getter for the verbosity flag (signature line missing from
// this extract): returns the current value of m_verbose.
202  return m_verbose;
203 }

SHOGUN Machine Learning Toolbox - Documentation