业务场景算法落地 - 利用预训练&伪标注&蒸馏实现一个通用的分类模型baseline
业务背景:
场景化的智能对话助手下,某场景已经通过模板配置的方式冷启动并且在线上运行了一段时间(即通过线上日志可以收集到一些query),但是由于模板配置过泛导致该场景下误召回的话术过多,因此考虑增加一个分类模型来区分query是否属于该业务场景。
方案设计:
方案设计如上图,主要包含5部分:
1. 基于领域数据微调原生bert(领域数据来源于业务场景)
2. 伪标注数据获取(也可利用主动学习的方案,标注少量的数据,这样效果会更好)
3. 利用领域适应的bert和伪标注的数据,训练教师模型
(教师模型训练过程中,可以采用噪声标签识别方案,过滤出噪声标签,提高教师模型效果)
4. 未标注数据获取(也可采用数据增强的方式,补充训练数据)
5. 利用教师模型蒸馏学生模型(text_cnn)
具体步骤:
1. 基于原生bert(google发布的bert-base)使用领域下的数据,使用MLM任务进行继续预训练
2. 伪标注数据的打标,基于一个假设:该场景下用户使用频次高的话术,认为正例数据;负例数据来源于其他场景话术;
(也可采用主动学习的方案,标注少量的数据来训练,从实验来看,伪标注的数据始终不如人工标注的数据表现好)
3. 教师模型训练:使用微调过的bert和伪标注数据,来训练教师模型。经过对比实验,在训练数据较少(5w以下),且正负样本不均衡(正例占比10%-20%之间)时,使用12层的bert效果明显好于6层的bert。
伪标注数据训练教师模型,效果比标注数据差的原因,主要是由于伪标注数据中存在噪声标签,因此,若能够对噪声标签进行清洗,理论上会有效提升模型效果;一个可用的噪声标签识别的方法:clean_lab 开源的代码,但是试验发现,该方法效果提升有限(有时甚至会变差),但是这个方法仍有可探索和优化的空间。
4. 未标注数据,将该场景下所有query当做未标注数据,来训练学生模型;
5. 利用教师模型和未标注数据来蒸馏text_cnn得到学生模型,具体蒸馏方法如下:
由于bert和text_cnn为异构模型,所以只能使用最后一层(sigmoid之前的一层)来蒸馏学生模型。
教师bert预测未标注数据,得到最后一层的logits,在学生text_cnn中,使用教师logits和cnn的预测的logits计算loss来训练。
(之所以使用最后一层的logits,而不是直接使用bert预测得到的y_label,则是由于y_label包含的信息不如logits多)
业务效果
经过此方案,可以快速上线一个分类模型,相对于线上版本,尾部垂类满意度由平均满意度50%提升至80%,可以解决60%-80%的场景错误问题。