Distributional Robustness Loss for Long-tail Learning


目录
  • 符号说明
  • 主要内容
    • Representation-learning loss
    • Robust loss
    • 可行的上界
    • Joint loss
    • 细节
  • 代码

Samuel D. and Chechik G. Distributional robustness loss for long-tail learning. In International Conference on Computer Vision (ICCV), 2021.

本文利用 Distributionally Robust Optimization (DRO) 来试图解决长尾问题, 出发点是, 小样本的类内中心由于缺乏数据, 和真实的类内中心往往有很大差距, 故作者用 DRO 来优化一定区域内最坏的情况来缓解这一问题.

符号说明

  • \((x_i, y_i), i=1,2,\cdots, n\), 共 \(n\) 组数据;
  • \(y_i \in \{c_1, c_2, \cdots, c_k\}\), 共 \(k\) 个类别;
  • \(f_{\theta}: x \rightarrow z\), 将样本转换为特征 \(z\);
  • \(Z := \{z_1, z_2, \cdots, z_n\}\) 为训练样本特征的集合;
  • \(S_c := \{z_i | y_i = c\}\), 为某一类特征的集合;
  • \(\hat{\mu}_c := \frac{1}{|S_c|} \sum_{z_i \in S_c} z_i\) 为一类的经验类内中心;
  • \(\mu_c := \mathbb{E}_{x \sim P|y=c} [z]\) 为真实的类内样本的中心.

主要内容

Representation-learning loss

启发自对比损失, 我们可以定义

\[P(z_i | \mu_c) := \frac{\exp(-d(\mu_c, z_i))}{\sum_{z' \in Z} e^{-d(\mu_c, z')}}, \]

这里 \(d(\cdot, \cdot)\) 可以是常见的欧式距离或者 cos 相似度, 看代码应该选择的是前者.

我们可以通过如下损失进行训练:

\[\mathcal{L}_{NLL}(Z; P; \theta) = \sum_{c \in C} w(c) (-\log P(S_c|\mu_c)) = -\sum_{c \in C} w(c) \sum_{z_i \in S_c} \log \frac{e^{-d(\mu_c, z_i)}}{\sum_{z' \in Z} e^{-d(\mu_c, z')}}. \]

通常设定 \(w(c) = \frac{1}{|S_c|}\) 来缓解头部类别的主宰效应.

Robust loss

但是上面的损失有个问题, 在实际中, 我们无法预先知道类内中心 \(\mu_c\), 所以, 我们只能通过 \(\hat{\mu}_c\) 来估计, 但是这个效果的好坏取决于该类的样本的个数. 对于小样本来说, 肯定是没法很好满足的.

我们定义 \(\hat{p}_c = \mathcal{N}(\hat{\mu}_c, \sigma^2 I)\), 表示对条件分布 \(p(x|y=c)\)的一个经验估计.

\[U_c := \{q| D(q\|\hat{p}_c) \le \epsilon_c\}, \]

其中 \(D\) 是两个分布的距离度量, 比如常见的 KL 的散度 (本文的选择). 倘若我们仅在服从正态分布 \(\mathcal{N}(\mu, \sigma_c^2 I)\)上进行讨论. 则 \(\mathcal{N}(\mu_q, \sigma_c^2I), \mathcal{N}(\hat{\mu}_c, \sigma_c^2 I)\) 之间的 KL 散度容易证得为:

\[\frac{1}{2\sigma_c^2} d(\mu_q, \hat{\mu}_c)^2. \]

我们希望优化

\[\min_{\theta} \sum_{c \in C} \sup_{q_c \in U_c} \mathbb{E}_{x \sim q_c} [\ell (z; Q_c;\theta)], \]

其在 \(U\) 内的最坏的情况.

可行的上界

在推导上界之前, 我们注意到一个性质:

\[D(q\|\hat{p}_c) = \frac{d(\mu_q, \hat{\mu}_c)^2}{2\sigma_c^2} \le \epsilon_c \rightarrow d(\mu_q, \hat{\mu}_c) \le \sqrt{2\epsilon_c} \sigma_c =: \Delta_c. \]

于是有:

\[d(\mu_q, z) \le d(\hat{\mu}_c, z) + d(\hat{\mu}_c, \mu_q), \\ d(\hat{\mu}_c, z) \le d(\mu_q, z) + d(\hat{\mu}_c, \mu_q). \\ \]

于是

\[\begin{array}{ll} P(z | \mu_q) :=Q_c(z) &= \frac{e^{-d(\mu_q, z)}}{\sum_{z' \in Z} e^{-d(\mu_q, z')}} \\ &= \frac{e^{-d(\mu_q, z)}}{\sum_{z_+ \in S_c} e^{-d(\mu_q, z_+)} + \sum_{z_- \not \in S_c} e^{-d(\mu_q, z_-)}} \\ &\ge \frac{e^{-d(\hat{\mu}_c, z) - \Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - \Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\mu_q, z_-)}} \\ &\ge \frac{e^{-d(\hat{\mu}_c, z) - \Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - \Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-) + \Delta_c}} \\ &= \frac{e^{-d(\hat{\mu}_c, z) - 2\Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - 2\Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-)}}. \\ \end{array} \]

相应的

\[\sup_{q_c \in U_c} \ell(z; Q_c; \theta) \le -\log \frac{e^{-d(\hat{\mu}_c, z) - 2\Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - 2\Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-)}}. \\ \]

于是我们可以优化此上界, 定义为:

\[\tag{1} \mathcal{L}_{Robust} = -\sum_{c \in C} w(c) \sum_{z \in S_c}\log \frac{e^{-d(\hat{\mu}_c, z) - 2\Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - 2\Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-)}}. \\ \]

Joint loss

最后, 作者采用的是如下的一个联合损失:

\[\mathcal{L} = \lambda \mathcal{L}_{CE} + (1 - \lambda) \mathcal{L}_{Robust}. \]

细节

  1. 注意到 (1) 中的分母部分是遍历 \(Z\) 的, 实际中是采取一个 batch 的特征;

  2. 为了 \(\hat{mu}_c\), 作者选择在每个 epoch 开始前, 遍历数据以估计 \(\hat{\mu}_c\);

  3. 实际训练采取的是长尾分布中常见的两阶段训练;

  4. 关于 \(\Delta_c\) 的选取, 可以有

    • 不同类别共享超参数 \(\Delta\);
    • 按照 \(\Delta / \sqrt{n}\) 的方式定义的超参数;
    • 可学习的 \(\Delta_c\)
      通过实现来看, 似乎可学习的 \(\Delta\) 的效果是最好的;
  5. \(Z\) 以及 \(\hat{\mu}_c\) 会首先通过标准训练进行一个初始化.

代码

[official]