论文标题

一种有效的方法来训练小型模型,以解决知识蒸馏的回归问题

An Efficient Method of Training Small Models for Regression Problems with Knowledge Distillation

论文作者

Takamoto, Makoto, Morishita, Yusuke, Imaoka, Hitoshi

论文摘要

压缩深神经网络(DNN)模型成为现实应用程序的非常重要且必要的技术,例如在移动设备上部署这些模型。知识蒸馏是模型压缩的最流行方法之一,并且已经对开发这种技术进行了许多研究。但是,这些研究主要集中在分类问题上,尽管在回归问题上有许多DNN的应用,但很少有关于回归问题的尝试。在本文中,我们提出了针对回归问题的知识蒸馏的新形式主义。首先,我们提出了一种新的损失功能,教师的拒绝损失损失,该损失拒绝使用教师模型预测来培训样本中的异常值。其次,我们考虑一个具有两个输出的多任务网络:一个估计训练标签,通常受到嘈杂标签的污染;以及其他估计教师模型的输出,预计将在记忆效果下修改噪声标签。通过考虑多任务网络,对学生模型的特征提取的培训变得更加有效,它使我们获得了比从头开始训练的培训更好的学生模型。我们使用一个简单的玩具模型进行了全面的评估:正弦函数和两个开放数据集:mpiigaze和Multi-Pie。我们的结果表明,无论数据集中的注释误差级别如何,准确性都一致。

Compressing deep neural network (DNN) models becomes a very important and necessary technique for real-world applications, such as deploying those models on mobile devices. Knowledge distillation is one of the most popular methods for model compression, and many studies have been made on developing this technique. However, those studies mainly focused on classification problems, and very few attempts have been made on regression problems, although there are many application of DNNs on regression problems. In this paper, we propose a new formalism of knowledge distillation for regression problems. First, we propose a new loss function, teacher outlier rejection loss, which rejects outliers in training samples using teacher model predictions. Second, we consider a multi-task network with two outputs: one estimates training labels which is in general contaminated by noisy labels; And the other estimates teacher model's output which is expected to modify the noise labels following the memorization effects. By considering the multi-task network, training of the feature extraction of student models becomes more effective, and it allows us to obtain a better student model than one trained from scratch. We performed comprehensive evaluation with one simple toy model: sinusoidal function, and two open datasets: MPIIGaze, and Multi-PIE. Our results show consistent improvement in accuracy regardless of the annotation error level in the datasets.

扫码加入交流群

加入微信交流群

微信交流群二维码

扫码加入学术交流群,获取更多资源