Deep Mutual Learning

发布 : 2021-07-18 分类 : 深度学习 浏览 :

《Deep Mutual Learning》论文笔记,该文提出了 DML 策略,在训练过程中模型之间互相学习。实验结果显示,有力的先验教师网络不是必要的,学生网络之间互学习蒸馏的性能比从静态有力的教师网络中蒸馏效果更好。

深度神经网络中的互信息可以参考这个视频

1 引言

深度神经网络在很多问题上表现出色的性能,但这些网络经常包含大量的参数,这个确定造成执行时缓慢或者需要大量内存,限制了应用和平台。这加剧了更快更小模型的研究。
基于知识蒸馏的得到的小模型经常有与大模型相同的表征容量。与大模型相比,这些小模型很难训练去找到正确的参数。知识蒸馏方法以强大的教师网络为基础,训练更小的学生网络去模仿教师网络。模仿教师网络分类的概率或者表征,传递的信息超过了传统的监督学习目标。学习模仿老师的优化问题比直接学习目标函数更容易,而且小得多的学生网络可以匹配甚至超过更大的老师网络。
本文探索了与知识蒸馏思想相关但又不同的互学习。互学习以未训练过的学习网络为基础,同时训练去解决任务。具体来说,每个学生使用两个 loss 来训练:传统监督学习 loss 和模仿 loss,将每个学生的类别后验与其他学生的分类概率对齐。观察发现:

  • (1)训练网络的效率随着网络的数量增加而增加。
  • (2)适用于多种网络架构,适用于由大小混合网络组成的异构队列。
  • (3)与独立训练相比,即使是在队列中相互训练的大型网络也能提高性能。
  • (4)我们注意到,虽然我们的重点是获得一个有效的网络,但整个队列也可以用作高效的集成模型。

2 深度互学习

2.1 Formulation

image.png
上图展示了 DML 方法的整体流程。

  • 传统监督损失(交叉熵)

image.png
传统的监督损失训练网络以预测训练实例的正确标签。为了提高 Θ1 在测试实例上的泛化性能,我们使用另一个对等网络 Θ2 以其后验概率 p2 的形式提供训练经验。为了衡量两个网络的预测 p1 和 p2 的匹配,我们采用 Kullback Leibler (KL) Divergence。

  • KL 距离

image.png
KL 散度不满足对称性,两个网络的损失分别如下:
image.png
通过这种方式,每个网络都学习正确预测训练实例的真实标签(监督损失 LC)以及匹配其对等方的概率估计(KL 模仿损失)。

2.2 Optimisation

优化策略是分别对网络输入相同的样本,依次更新模型参数,
image.png

2.3 Extension to Larger Student Cohorts

扩展到更过学生网络有两种方案。

  • 1.每个网络的 KL 项累加再取均值

image.png
取均值是为了让损失的主体是监督 loss。

  • 2.概率分布取均值再算 KL,相当于看做只有一个 teacher

image.png

作者通过实验发现,第一种方法效果更好,为了什么呢?因为 DML 策略的目标之一是产生更高的后验熵,而(9)方法得到的教师网络后验概率在真正类别上更高,与目标相悖。

3 实验

image.png
相同网络有提升,不同网络也有提升。
image.png
上表数据是两个 MobileNet 采用 DML 策略训练,两者的性能取均值。
image.png
与蒸馏策略对比。
image.png
上图可以发现,学生增高,性能也得到提高。

DML 为什么 work

image.png
主要原因是它考虑了其他类别的熵。具体看原文。

本文作者 : HeoLis
原文链接 : https://ishero.net/Deep%20Mutual%20Learning.html
版权声明 : 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明出处!

学习、记录、分享、获得

微信扫一扫, 向我投食

微信扫一扫, 向我投食