SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
transfer
multitask
MultitaskLinearMachine.h
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Copyright (C) 2012 Sergey Lisitsyn
8
*/
9
10
#ifndef MULTITASKMACHINE_H_
11
#define MULTITASKMACHINE_H_
12
13
#include <
shogun/lib/config.h
>
14
#include <
shogun/machine/LinearMachine.h
>
15
#include <
shogun/transfer/multitask/TaskRelation.h
>
16
#include <
shogun/transfer/multitask/TaskGroup.h
>
17
#include <
shogun/transfer/multitask/TaskTree.h
>
18
#include <
shogun/transfer/multitask/Task.h
>
19
20
#include <vector>
21
#include <set>
22
23
using namespace
std;
24
25
namespace
shogun
26
{
30
class
CMultitaskLinearMachine
:
public
CLinearMachine
31
{
32
33
public
:
35
CMultitaskLinearMachine
();
36
43
CMultitaskLinearMachine
(
44
CDotFeatures
* training_data,
45
CLabels
* training_labels,
CTaskRelation
* task_relation);
46
48
virtual
~
CMultitaskLinearMachine
();
49
51
virtual
const
char
* get_name()
const
52
{
53
return
"MultitaskLinearMachine"
;
54
}
55
59
int32_t get_current_task()
const
;
60
64
void
set_current_task(int32_t task);
65
70
virtual
SGVector<float64_t>
get_w()
const
;
71
76
virtual
void
set_w(
const
SGVector<float64_t>
src_w);
77
82
virtual
void
set_bias(
float64_t
b);
83
88
virtual
float64_t
get_bias();
89
93
CTaskRelation
* get_task_relation()
const
;
94
98
void
set_task_relation(
CTaskRelation
* task_relation);
99
101
virtual
bool
supports_locking
()
const
{
return
true
; }
102
104
virtual
void
post_lock(
CLabels
* labels,
CFeatures
* features_);
105
107
virtual
bool
train_locked(
SGVector<index_t>
indices);
108
110
virtual
CBinaryLabels
* apply_locked_binary(
SGVector<index_t>
indices);
111
113
virtual
float64_t
apply_one(int32_t i);
114
115
protected
:
116
118
virtual
SGVector<float64_t>
apply_get_outputs(
CFeatures
* data=NULL);
119
121
virtual
bool
train_machine(
CFeatures
* data=NULL);
122
124
virtual
bool
train_locked_implementation(
SGVector<index_t>
* tasks);
125
127
SGVector<index_t>
* get_subset_tasks_indices();
128
129
private
:
130
132
void
register_parameters();
133
134
protected
:
135
137
int32_t
m_current_task
;
138
140
CTaskRelation
*
m_task_relation
;
141
143
SGMatrix<float64_t>
m_tasks_w
;
144
146
SGVector<float64_t>
m_tasks_c
;
147
149
vector< set<index_t> >
m_tasks_indices
;
150
151
};
152
}
153
#endif
SHOGUN
机器学习工具包 - 项目文档