【论文阅读】Layer-wise Conditioning Analysis in Exploring the Learning Dynamics of DNNs

Posted by UUQ on 2024-10-09
Estimated Reading Time 7 Minutes
Words 1.9k In Total
Viewed Times

这篇文章提出通过层间分析法(layer-wise analysis) 实现在神经网络训练中“debug”,并解释了Batch Norm后的网络到底哪好?

概述

一、Hessian矩阵

从这里开始可能会多次混用Hessian Matrix和FIM,这两个之间有数学上的关系:数学补充

具体来讲,在训练DNN时,Hessian矩阵 或 FIM(Fisher Information Matrix)非常重要。但是计算Hessian矩阵实际上是需要很大的内存和时间,所以试图使用效率高的方法去近似,比如K-FAC算法。不过K-FAC算法有两个假设:1) weight-gradients in different layers are assumed to be uncorrelated; 2) the input and output-gradient in each layer are approximated as independent.

于是就可以成功地把full FIM表示为一个分块对角阵:$\pmb{F} = diag(F_1,…,F_K)$,其中每一个$F_k$称为sub-FIM(具体到每一层的FIM),并可通过以下公式计算/估算:

image-20241018233650629

上面这个式子其实有点像文中比较开始部分中从一维线性回归的Hessian矩阵计算,到多维。其形式都是一个Covariance Matrix 然后$\otimes$ 一个东西(这个$\otimes$是Kronecker product,链接中的博客有简单介绍)

二、layer-wise条件分析的基础——从full到sub:取其精华

1. FIM的重要信息

为了让sub-FIM也能表征出full FIM能给我们的信息,首先我们要知道full FIM中哪些信息比较重要,然后再看看sub-FIM能不能也拿到这些:

其中,黑塞矩阵的两个具体特征更重要,他们分别是1. 矩阵最大特征值$\lambda_{max}$ 2. 矩阵的条件数(conditional number) $\kappa (\pmb{H})$

因为:

Property 1 :$\lambda_{max}(F_k)$indicates the magnitude of the weight-gradient in each layer, which shows the steepness of the landscape w.r.t.different layers.
Property 2 : $\kappa (F_k)$indicates how easy it is to optimize the corresponding layer.

这里的大K表示网络层数,小k表示第几层。

2. 实验

**目标:**证明sub-FIMs能反映出full FIM的信息和趋势

**任务:**MNIST

**变量:**BN 和 无BN(plained)

观察变量:某个layer(图中用了第三层和第六层)的sub-FIM和full FIM的两变量($\lambda_{max}, \kappa_{percentage}$)趋势是否有明显一致性。

注:这里的$\kappa$用的是$\frac{\lambda_{max}}{\lambda_{p}}$的定义,其中$\lambda_p$代表第p大的特征值,比如$\kappa_{100%}$其实就是条件数的最初定义。

结果(Figure 1):

image-20241021210641220

还是比较明显的,两方面结论:

  1. plained出现梯度消失(最大特征值1e-5),而BN组梯度就很好;BN提升了conditioning。
  2. 上文提出的layer-wise conditioning分析有展现网络学习动态的潜力。

3. 高效地直接计算 两个值

既然已经验证通过sub-FIMs来展现网络学习动态的能力(主要通过这两个属性),那么考虑直接高效算出这两个值,而不是先近似sub-FIM再计算。

定义:$\Sigma_{x}=\mathbb{E}{p(x)}(xx^T), \Sigma{\nabla\pmb{h}} = \mathbb{E}{q(y|x)}(\frac{\partial l}{\partial \pmb{h}}^T\frac{\partial l}{\partial \pmb{h}})$ 分别是输入的协方差矩阵、反向传播的梯度协方差矩阵,该层的sub-FIM $F=\Sigma_x \otimes \Sigma{\nabla\pmb{h}}$,则有
$$
\lambda_{max}(F)=\lambda_{max}(\Sigma_x)\cdot \lambda_{max}(\Sigma_{\nabla \pmb{h}})\
\kappa(F)=\kappa(\Sigma_x)\cdot \kappa(\Sigma_{\nabla \pmb{h}})
$$

三、BN好吗?到底怎么好?

1. BN的层间尺度不变性

在原文第四节中,使用上文的layer-wise 条件分析法,分析了BatchNorm怎么样。

首先提出了两个Theorem,描述了Normalization方法使网络具有一种“缩放不变性”。(这里有待再理解)

实验

然后就是基于实验进行分析了,同样还是MNIST,20层MLP,每层256 neurons。

变量: 1. initialization方法:random / HE-init 2. Plain / BN

image-20241022095906625

结果表明,BN能很好地维持住input/output-gradient在各层的取值,而plain就会使得这两个值指数级的下降,HE-initialization缓解了这一点。

不过我有一个问题就是,如果只是从上文的层间分析来看,我们想要得到$\lambda_{max}(F)=\lambda_{max}(\Sigma_{x})\cdot \lambda_{max}(\Sigma_{\nabla \pmb{h}})$ ,这里虽然拆开了两方向的最大特征值,发现了明显变化,但是乘起来不就抵消了吗?如果单纯只用之前的方法去算$\lambda_{max}(F), \kappa(F)$岂不是看不出什么差别???

2. Weight Domination

定义

某一层发生Weight Domination现象为:权重矩阵$W_k$、梯度$\frac{\partial L}{\partial W_k}$,有$\lambda_{max}(W_k)>> \lambda_{max}(\frac{\partial L}{\partial W_k})$,这时候的$\lambda$就只能是奇异值了。

其实这种weight domination本质上是某一层的梯度太小了,参数可能几乎不更新了,这可能是因为W的不断增加或input的不断减少(难道说这里呼应了上一节对$\lambda_{max}(F)$的拆分?)

然而,即使BN稳定了layer input,但可能也同样会因为weight的magnitude增加而表现变差。

后果&实验

Weight Domination使得网络学习能力变差,这里又做了个实验,通过禁止其权重更新的方式模拟weight domination(这对吗…)

image-20241022104632084

观察到0的越多,表现越差(但是我觉得这是肯定的啊,你都不让更新权重了,有的时候可能显然是可以很好地更新的,这种模拟是否真的hold)

3. BN改善了DNN各层的条件(condition)

使用layer-wise分析方法验证“BN中的whitening input可以提升优化的条件” 这个proposition,结果发现不仅提升了每层输入的condition,还提升了梯度的condition。

另外还有 dying neurons 问题,就是ReLU以后未激活(输入小于等于0),full neurons则是对每个example都activate;猜测正是这个dying neurons导致plain组有很多小/零的eigenvalue,而BN组因为center了,所以没有什么dying/full(这里存疑啊,为什么考虑full,以及真就没dying了吗,怎么保证对所有一定都激活?猜测?)

四、 深度残差网络

本文认为,训练深度残差网络时随着层数增加,表现变差不仅仅是因为过拟合,同样也有优化困难的原因。

引入实验

ResNet实验,复现Resnet论文的实验。观察到loss先增加、再在random guess中几乎不变(weight domination),最后足够多的迭代后减小。

层间条件分析

分析最后一个线性层(交叉熵之前),观察输入的$\lambda_{max}(\Sigma_x)$,参数矩的二范式$||W||2$,和梯度协方差的$\lambda{max}$。

认为:输入x的$\lambda$大,使得梯度大,后面是W大,进而W使得loss增大(这到底是为什么。。。

1. 解决方法

在最后一个线性层前加一个BN。

实验复现

TODO:

  1. 构建正确的模型结构,train方法,到底怎么写。
  2. 如何正确地挂上hook,以及正确算出想要的值(这些值是每层的?还是整个网络的?怎么计算)
  3. 画图和可视化
  4. 对比,以及使用两个工具。

备注

TODO:

  1. BP推导 和 3.2节理解
  2. 3.1节数学推导
  3. Pytorch torch.hook 学习 + 两节实验复现
  4. .lua文件学习
  5. Structure tensor? second moment matrix

Unsolved problem

  1. Eqn 1,2, 4到底怎么来的
  2. condition number到底有什么用
  3. 3.1节高效计算,理解proposition1
  4. 3.2节 BP,弄清楚输出梯度、输入梯度、权重梯度什么的。
  5. 第四节的Theorem1和2, 以及weight domination的definition
  6. 残差网络:定义、性质和训练。
  7. 如何理解第五节5.1之前那段,解释loss变化的Analysis of Learning Dynamics.

PDF with note


如果您喜欢此博客或发现它对您有用,则欢迎对此发表评论。 也欢迎您共享此博客,以便更多人可以参与。 如果博客中使用的图像侵犯了您的版权,请与作者联系以将其删除。 谢谢 !