SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
evaluation
PRCEvaluation.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2011 Sergey Lisitsyn
8
* Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9
*/
10
11
#include <
shogun/evaluation/PRCEvaluation.h
>
12
#include <
shogun/labels/RegressionLabels.h
>
13
#include <
shogun/labels/BinaryLabels.h
>
14
#include <
shogun/mathematics/Math.h
>
15
16
using namespace
shogun;
17
18
/** Destructor.
 *
 * Nothing is released by hand here. The members written by evaluate()
 * (m_PRC_graph, m_thresholds) are SGMatrix/SGVector objects, which are
 * presumably cleaned up by their own destructors (confirm against the
 * SGMatrix/SGVector implementation).
 */
CPRCEvaluation::~CPRCEvaluation()
{
	// intentionally empty
}
21
22
/** Compute the precision-recall curve and the area under it (auPRC).
 *
 * Examples are ranked by their predicted value, highest first. Each prefix
 * of the top i+1 ranked examples contributes one point of the curve:
 *  - m_PRC_graph[2*i]   : precision = (#positives in prefix) / (i+1)
 *  - m_PRC_graph[2*i+1] : recall    = (#positives in prefix) / (#positives)
 *  - m_thresholds[i]    : predicted value of the i-th ranked example
 * The area is then computed by CMath::area_under_curve with reversed=true.
 *
 * @param predicted    binary labels holding the predicted values
 * @param ground_truth binary labels holding the true labels (> 0 is positive)
 * @return the area under the precision-recall curve (also kept in m_auPRC)
 *
 * Fails an ASSERT if either argument is NULL, the label counts differ,
 * either label object is not LT_BINARY, or ground_truth contains no
 * positive example. All of these checks run before any allocation or
 * member update. A failed call therefore leaks nothing and leaves the
 * results of a previous evaluate() intact.
 */
float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
{
	ASSERT(predicted && ground_truth)
	ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels())
	ASSERT(predicted->get_label_type()==LT_BINARY)
	ASSERT(ground_truth->get_label_type()==LT_BINARY)
	ground_truth->ensure_valid();

	// predicted values and their count
	SGVector<float64_t> orig_labels = predicted->get_values();
	int32_t length = orig_labels.vlen;

	// Count the positive examples in the ground truth. Recall divides by
	// this, so it must be > 0. The check is done up front so that a failure
	// cannot leak the index buffer or leave the members half-overwritten.
	int32_t pos_count = 0;
	for (int32_t i=0; i<length; i++)
	{
		if (ground_truth->get_value(i) > 0)
			pos_count++;
	}
	ASSERT(pos_count>0)

	// Rank example indices by predicted value, highest first.
	// qsort_backward_index sorts the (cloned) values in descending order
	// and permutes idxs alongside. SGVector owns idxs, so no manual free.
	SGVector<int32_t> idxs(length);
	for (int32_t i=0; i<length; i++)
		idxs[i] = i;

	float64_t* labels = SGVector<float64_t>::clone_vector(orig_labels.vector, length);
	CMath::qsort_backward_index(labels, idxs.vector, length);
	SG_FREE(labels);

	// (re)initialize the curve and the thresholds
	m_PRC_graph = SGMatrix<float64_t>(2,length);
	m_thresholds = SGVector<float64_t>(length);

	// Walk down the ranking. After step i, tp is the number of positives
	// among the top i+1 examples. It never decreases, so recall never does.
	float64_t tp = 0.0;
	for (int32_t i=0; i<length; i++)
	{
		if (ground_truth->get_value(idxs[i]) > 0)
			tp += 1.0;

		// precision (x)
		m_PRC_graph[2*i] = tp/float64_t(i+1);
		// recall (y)
		m_PRC_graph[2*i+1] = tp/float64_t(pos_count);

		m_thresholds[i] = predicted->get_value(idxs[i]);
	}

	// Compute auPRC as the area under the stored curve. reversed=true
	// presumably makes area_under_curve treat the second entry of each pair
	// (recall) as the integration axis. Confirm against CMath::area_under_curve.
	m_auPRC = CMath::area_under_curve(m_PRC_graph.matrix, length, true);

	// mark results as available to the get_* accessors
	m_computed = true;

	return m_auPRC;
}
90
91
/** Return the precision-recall curve stored by evaluate().
 *
 * The matrix has 2 x N entries. Entry [2*i] holds the precision and entry
 * [2*i+1] the recall for the top i+1 ranked examples.
 * Raises SG_ERROR if evaluate() has not set m_computed yet.
 */
SGMatrix<float64_t> CPRCEvaluation::get_PRC()
{
	const bool ready = m_computed;
	if (ready)
		return m_PRC_graph;

	SG_ERROR("Uninitialized, please call evaluate first")
	return m_PRC_graph;
}
98
99
/** Return the thresholds stored by evaluate().
 *
 * Entry i is the predicted value of the i-th example in the ranking used
 * to build the precision-recall curve.
 * Raises SG_ERROR if evaluate() has not set m_computed yet.
 */
SGVector<float64_t> CPRCEvaluation::get_thresholds()
{
	const bool ready = m_computed;
	if (ready)
		return m_thresholds;

	SG_ERROR("Uninitialized, please call evaluate first")
	return m_thresholds;
}
106
107
/** Return the area under the precision-recall curve stored by evaluate().
 *
 * Raises SG_ERROR if evaluate() has not set m_computed yet.
 */
float64_t CPRCEvaluation::get_auPRC()
{
	const bool ready = m_computed;
	if (ready)
		return m_auPRC;

	SG_ERROR("Uninitialized, please call evaluate first")
	return m_auPRC;
}
SHOGUN
机器学习工具包 - 项目文档