Distributional Robustness Loss for Long-tail Learning
- 概
- 符号说明
- 主要内容
- Representation-learning loss
- Robust loss
- 可行的上界
- Joint loss
- 细节
- 代码
Samuel D. and Chechik G. Distributional robustness loss for long-tail learning. In International Conference on Computer Vision (ICCV), 2021.
概
本文利用 Distributionally Robust Optimization (DRO) 来试图解决长尾问题, 出发点是, 小样本的类内中心由于缺乏数据, 和真实的类内中心往往有很大差距, 故作者用 DRO 来优化一定区域内最坏的情况来缓解这一问题.
符号说明
- \((x_i, y_i), i=1,2,\cdots, n\), 共 \(n\) 组数据;
- \(y_i \in \{c_1, c_2, \cdots, c_k\}\), 共 \(k\) 个类别;
- \(f_{\theta}: x \rightarrow z\), 将样本转换为特征 \(z\);
- \(Z := \{z_1, z_2, \cdots, z_n\}\) 为训练样本特征的集合;
- \(S_c := \{z_i | y_i = c\}\), 为某一类特征的集合;
- \(\hat{\mu}_c := \frac{1}{|S_c|} \sum_{z_i \in S_c} z_i\) 为一类的经验类内中心;
- \(\mu_c := \mathbb{E}_{x \sim P|y=c} [z]\) 为真实的类内样本的中心.
主要内容
Representation-learning loss
启发自对比损失, 我们可以定义
\[P(z_i | \mu_c) := \frac{\exp(-d(\mu_c, z_i))}{\sum_{z' \in Z} e^{-d(\mu_c, z')}}, \]这里 \(d(\cdot, \cdot)\) 可以是常见的欧式距离或者 cos 相似度, 看代码应该选择的是前者.
我们可以通过如下损失进行训练:
\[\mathcal{L}_{NLL}(Z; P; \theta) = \sum_{c \in C} w(c) (-\log P(S_c|\mu_c)) = -\sum_{c \in C} w(c) \sum_{z_i \in S_c} \log \frac{e^{-d(\mu_c, z_i)}}{\sum_{z' \in Z} e^{-d(\mu_c, z')}}. \]通常设定 \(w(c) = \frac{1}{|S_c|}\) 来缓解头部类别的主宰效应.
Robust loss
但是上面的损失有个问题, 在实际中, 我们无法预先知道类内中心 \(\mu_c\), 所以, 我们只能通过 \(\hat{\mu}_c\) 来估计, 但是这个效果的好坏取决于该类的样本的个数. 对于小样本来说, 肯定是没法很好满足的.
我们定义 \(\hat{p}_c = \mathcal{N}(\hat{\mu}_c, \sigma^2 I)\), 表示对条件分布 \(p(x|y=c)\)的一个经验估计.
\[U_c := \{q| D(q\|\hat{p}_c) \le \epsilon_c\}, \]其中 \(D\) 是两个分布的距离度量, 比如常见的 KL 的散度 (本文的选择). 倘若我们仅在服从正态分布 \(\mathcal{N}(\mu, \sigma_c^2 I)\)上进行讨论. 则 \(\mathcal{N}(\mu_q, \sigma_c^2I), \mathcal{N}(\hat{\mu}_c, \sigma_c^2 I)\) 之间的 KL 散度容易证得为:
\[\frac{1}{2\sigma_c^2} d(\mu_q, \hat{\mu}_c)^2. \]我们希望优化
\[\min_{\theta} \sum_{c \in C} \sup_{q_c \in U_c} \mathbb{E}_{x \sim q_c} [\ell (z; Q_c;\theta)], \]其在 \(U\) 内的最坏的情况.
可行的上界
在推导上界之前, 我们注意到一个性质:
\[D(q\|\hat{p}_c) = \frac{d(\mu_q, \hat{\mu}_c)^2}{2\sigma_c^2} \le \epsilon_c \rightarrow d(\mu_q, \hat{\mu}_c) \le \sqrt{2\epsilon_c} \sigma_c =: \Delta_c. \]于是有:
\[d(\mu_q, z) \le d(\hat{\mu}_c, z) + d(\hat{\mu}_c, \mu_q), \\ d(\hat{\mu}_c, z) \le d(\mu_q, z) + d(\hat{\mu}_c, \mu_q). \\ \]于是
\[\begin{array}{ll} P(z | \mu_q) :=Q_c(z) &= \frac{e^{-d(\mu_q, z)}}{\sum_{z' \in Z} e^{-d(\mu_q, z')}} \\ &= \frac{e^{-d(\mu_q, z)}}{\sum_{z_+ \in S_c} e^{-d(\mu_q, z_+)} + \sum_{z_- \not \in S_c} e^{-d(\mu_q, z_-)}} \\ &\ge \frac{e^{-d(\hat{\mu}_c, z) - \Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - \Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\mu_q, z_-)}} \\ &\ge \frac{e^{-d(\hat{\mu}_c, z) - \Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - \Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-) + \Delta_c}} \\ &= \frac{e^{-d(\hat{\mu}_c, z) - 2\Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - 2\Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-)}}. \\ \end{array} \]相应的
\[\sup_{q_c \in U_c} \ell(z; Q_c; \theta) \le -\log \frac{e^{-d(\hat{\mu}_c, z) - 2\Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - 2\Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-)}}. \\ \]于是我们可以优化此上界, 定义为:
\[\tag{1} \mathcal{L}_{Robust} = -\sum_{c \in C} w(c) \sum_{z \in S_c}\log \frac{e^{-d(\hat{\mu}_c, z) - 2\Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - 2\Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-)}}. \\ \]Joint loss
最后, 作者采用的是如下的一个联合损失:
\[\mathcal{L} = \lambda \mathcal{L}_{CE} + (1 - \lambda) \mathcal{L}_{Robust}. \]细节
-
注意到 (1) 中的分母部分是遍历 \(Z\) 的, 实际中是采取一个 batch 的特征;
-
为了 \(\hat{mu}_c\), 作者选择在每个 epoch 开始前, 遍历数据以估计 \(\hat{\mu}_c\);
-
实际训练采取的是长尾分布中常见的两阶段训练;
-
关于 \(\Delta_c\) 的选取, 可以有
- 不同类别共享超参数 \(\Delta\);
- 按照 \(\Delta / \sqrt{n}\) 的方式定义的超参数;
- 可学习的 \(\Delta_c\)
通过实现来看, 似乎可学习的 \(\Delta\) 的效果是最好的;
-
\(Z\) 以及 \(\hat{\mu}_c\) 会首先通过标准训练进行一个初始化.
代码
[official]