Tensorflow: AUC的错误/问题与修正

                       Tensorflow: AUC的错误/问题与修正

AUC是评价模型的常用指标,Tensorflow作为著名的机器学习框架,自然有对这一指标的计算API,其官网API文档为AUC

问题

但是,这一API不是很好用,在此举一个很简单的例子:

import tensorflow as tf

x_1 = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.8, 0.9, 1]
y_1 = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
x_2 = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
y_2 = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]

x_placeholder = tf.placeholder(tf.float64, [10])
y_placeholder = tf.placeholder(tf.bool, [10])
auc = tf.metrics.auc(labels=y_placeholder, predictions=x_placeholder)
initializer = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

with tf.Session() as sess:
    sess.run(initializer)
    for i in range(3):
        auc_value, update_op = sess.run(auc, feed_dict={x_placeholder: x_1, y_placeholder: y_1})
        print('auc_1: ' + str(auc_value) + ", update_op: " + str(update_op))
        auc_value, update_op = sess.run(auc, feed_dict={x_placeholder: x_2, y_placeholder: y_2})
        print('auc_2: ' + str(auc_value) + ", update_op: " + str(update_op))

# output
# auc_1: 0.0, update_op: 1.999999e-07
# auc_2: 1.999999e-07, update_op: 0.48999995
# auc_1: 0.48999995, update_op: 0.32444444
# auc_2: 0.32444444, update_op: 0.48999995
# auc_1: 0.489999, update_op: 0.3904
# auc_2: 0.3904, update_op: 0.48999998

总的来说,上面这段代码反映出的令人困惑的问题有三个:

  1. 使用AUC时,一定要定义local variable initializer,不然就会报错,但是众所周知,AUC这个指标的计算过程不应该和“变量”产生关联
  2. 每次计算输出的update_op和下次输出的auc值总是相等
  3. 同样的输入,计算出的AUC不同

这三个问题我是在2017年夏天发现的,截止2018年8月,TF 1.10rc版本并未解决这一问题,官方文档中也未说明这一现象。

因此,我打算解决这一问题。

解决

经过认真阅读相关的源代码,可以发现,出现这些问题的原因是:Tensorflow在计算AUC这一问题上,采取了一种比较令人困惑的计算真阳性,假阳性,假阴性,真阴性值的计算策略。

为说明这一问题,我简单的贴一段Tensorflow中计算不同阈值下真阳性样本数量的源代码。

# assume num_threshold = 200
# label_is_pos is a tensor with shape (200,) dtype is tf.bool.
# pred_is_pos is a tensor with shape (200,) dtype is tf.bool
true_p = metric_variable([num_thresholds], dtypes.float32, name='true_positives')
is_true_positive = math_ops.to_float(math_ops.logical_and(label_is_pos, pred_is_pos))
update_ops['tp'] = state_ops.assign_add(true_p, math_ops.reduce_sum(is_true_positive, 1))
values['tp'] = true_p
return values, update_ops

这段代码说明了产生上述三个问题的原因:

  1. 计算tp会涉及到一个名为"true_positives"的变量,这是一个tensorflow的局部变量(local variable),这个变量记录了之前调用这一函数时的真阳性数量之和。因此使用auc时,一定要加入local variable initializer,不然就会报错。在第一次调用时,这一变量会被初始化为0。
  2. 可以确定,assign_add函数是造成update_op和下次输出的auc值总是相等的原因,但是官方的assign_add函数文档写的闪烁其词,源代码依赖太复杂,我还不能阐释这其中的具体细节是什么。
  3. update_op计算的其实并不是本次输入数据的auc,而是metrics.auc这一函数自从被第一次调用以来,所有单次计算出的auc的累计平均值。因此,每次计算出的AUC总是在变动。

理论上来说,只有第一次使用auc函数时,输出的update_op值是真正的auc,只要多调用几次,后面输出的AUC全部都是平均值,而非当前输入的数据的AUC。按照Tensorflow的官方说法,输出平均值似乎并不是一种错误,而是有意设计成这样的。但是这一设计就使得我们很难追踪测试集的AUC的实时变化,而且说句实在话我也不明白这到底有什么意义。

如果一定要用原生的tf框架,那么每次计算auc前,都要reset一次local variable。然后取auc函数返回的update_ops而非values作为auc的值

当然,每次计算auc都reset一次local variable非常不优雅,而且有潜在的风险。因此,我参考Tensorflow的实现,做了少量的修改,写了一个可以计算计算当前输入的AUC的函数。函数模块在Github上已经贴出,地址

我修改过的auc函数,返回值只有一个,就是输入参数(标签和预测)所计算出来的AUC值。我删除了和局部变量相关的代码,因此也就不用初始化局部变量了。

参考:https://blog.csdn.net/qq_37747262/article/details/82223155

已标记关键词 清除标记
相关推荐
简介 笔者当初为了学习JAVA,收集了很多经典源码,源码难易程度分为初级、中级、高级等,详情看源码列表,需要的可以直接下载! 这些源码反映了那时那景笔者对未来的盲目,对代码的热情、执着,对IT的憧憬、向往!此时此景,笔者只专注Android、Iphone等移动平台开发,看着这些源码心中有万分感慨,写此文章纪念那时那景! Java 源码包 Applet钢琴模拟程序java源码 2个目标文件,提供基本的音乐编辑功能。编辑音乐软件的朋友,这款实例会对你有所帮助。 Calendar万年历 1个目标文件 EJB 模拟银行ATM流程及操作源代码 6个目标文件,EJB来模拟银行ATM机的流程及操作:获取系统属性,初始化JNDI,取得Home对象的引用,创建EJB对象,并将当前的计数器初始化,调用每一个EJB对象的count()方法,保证Bean正常被激活和钝化,EJB对象是用完毕,从内存中清除,从账户中取出amt,如果amt>账户余额抛出异常,一个实体Bean可以表示不同的数据实例,我们应该通过主键来判断删除哪个数据实例…… ejbCreate函数用于初始化一个EJB实例 5个目标文件,演示Address EJB的实现 ,创建一个EJB测试客户端,得到名字上下文,查询jndi名,通过强制转型得到Home接口,getInitialContext()函数返回一个经过初始化的上下文,用client的getHome()函数调用Home接口函数得到远程接口的引用,用远程接口的引用访问EJB。 EJB中JNDI的使用源码例子 1个目标文件,JNDI的使用例子,有源代码,可以下载参考,JNDI的使用,初始化Context,它是连接JNDI树的起始点,查找你要的对象,打印找到的对象,关闭Context…… ftp文件传输 2个目标文件,FTP的目标是:(1)提高文件的共享性(计算机程序和/或数据),(2)鼓励间接地(通过程序)使用远程计算机,(3)保护用户因主机之间的文件存储系统导致的变化,(4)为了可靠和高效地传输,虽然用户可以在终端上直接地使用它,但是它的主要作用是供程序使用的。本规范尝试满足大型主机、微型主机、个人工作站、和TACs 的不同需求。例如,容易实现协议的设计。 Java EJB中有、无状态SessionBean的两个例子 两个例子,无状态SessionBean可会话Bean必须实现SessionBean,获取系统属性,初始化JNDI,取得Home对象的引用,创建EJB对象,计算利息等;在有状态SessionBean中,用累加器,以对话状态存储起来,创建EJB对象,并将当前的计数器初始化,调用每一个EJB对象的count()方法,保证Bean正常被激活和钝化,EJB对象是用完毕,从内存中清除…… Java Socket 聊天通信演示代码 2个目标文件,一个服务器,一个客户端。 Java Telnet客户端实例源码 一个目标文件,演示Socket的使用。 Java 组播组中发送和接受数据实例 3个目标文件。 Java读写文本文件的示例代码 1个目标文件。 java俄罗斯方块 一个目标文件。 Java非对称加密源码实例 1个目标文件 摘要:Java源码,算法相关,非对称加密   Java非对称加密源程序代码实例,本例中使用RSA加密技术,定义加密算法可用 DES,DESede,Blowfish等。   设定字符串为“张三,你好,我是李四”   产生张三的密钥对(keyPairZhang)   张三生成公钥(publicKeyZhang)并发送给李四,这里发送的是公钥的数组字节   通过网络或磁盘等方式,把公钥编码传送给李四,李四接收到张三编码后的公钥,将其解码,李四用张三的公钥加密信息,并发送给李四,张三用自己的私钥解密从李四处收到的信息…… Java利用DES私钥对称加密代码实例 同上 java聊天室 2个目标文件,简单。 java模拟掷骰子2个 1个目标文件,输出演示。 java凭图游戏 一个目标文件,简单。 java求一个整数的因子 如题。 Java生成密钥的实例 1个目标文件 摘要:Java源码,算法相关,密钥   Java生成密钥、保存密钥的实例源码,通过本源码可以了解到Java如何产生单钥加密的密钥(myKey)、产生双钥的密钥对(keyPair)、如何保存公钥的字节数组、保存私钥到文件privateKey.dat、如何用Java对象序列化保存私钥,通常应对私钥加密后再保存、如何从文件中得到公钥编码的字节数组、如何从字节数组解码公钥。 Java数据压缩与传输实例 1个目标文件 摘要:Java源码,文件操作,数据压缩,文件传输   Java数据压缩与传输实例,可以学习一下实例化套按字、得到文件输入流、压缩输入流、文件输出流、实例化缓冲
©️2020 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页