多模态分类网络不平衡优化的分析&在多模态生成模型的尝试
一、多模态分类网络产生训练不平衡的原因
1.多模态分类网络模型结构
2.参数的梯度分别为
将(5)式带入上面即可获得梯度。我们假设a模态的特征提取的效果更好,根据softmax的特性Wa·φa的张量形式更接近one-hot,假设此时v模态的特征提取不如a模态,Wv·φv的张量类似均匀分布(每个类别的置信度都差不多,无法判断类别)。导致模型在全局收敛的后,v模态仍无法很好提取特征,从而出现模态不平衡优化的现象。
二、解决方案
1.梯度指导
Balanced Multimodal Learning via On-the-fly Gradient Modulation
https://arxiv.org/abs/2203.15332
根据梯度的不平衡,计算指导因子k,通过k去指导梯度下降,降低模态间的影响。
2.单模态教师
Improving Multi-Modal Learningwith Uni-Modal Teachers
https://arxiv.org/abs/2106.11059
通过预训练得到优化相对较好的单模态模型(可以提取一些特征),用单模态网络,在训练的过程中指导多模态网络。
(3)式第一项为普通的分类损失,第二项如下:
三、在多模态生成模型上的尝试
1.在MVAE上基于POE方法,使用梯度分析不平衡的出现
根据最终计算的梯度无法从中看出各个模特的不平衡优化......
2.在MVAE上基于POE方法,用过实验观察不平衡的出现
mnist 单模态
generator_acc: 96.33%
lantents Accuracy: 88.69%
svhn单模态
generator_acc: 62.93%
lantents classified Accuracy: 23.23%
mutimodel_poe
muti_net_generator_mnist_acc: 94.22%
muti_net_generator_svhn_acc: 66.31%
lantents Accuracy: 89.81%
其中
generator_acc 指测试集重构生成的准确度
lantents classified Accuracy 指测试集隐空间分类准确度
四、猜测
基于poe的融合方法是否真的存在不平衡的优化问题,如果存在,应该用什么样的评价指标观测
评论已关闭