Fine-Tuning Language Models with Just Forward Passes

Preliminaries

Core Idea

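As a minimal sketch of the estimator the derivation below analyzes: the $n$-sample, two-point zeroth-order gradient estimate perturbs $\theta$ along Gaussian directions $\bm{z}_i$ and uses only forward passes; as $\epsilon \rightarrow 0$, each term converges to $\bm{z}_i \bm{z}_i^T \nabla\mathcal{L}(\theta)$, which is exactly the quantity appearing in the proof. The finite-$\epsilon$ code form and all names here are my own illustration, not the paper's implementation:

```python
import numpy as np

def zo_grad(loss, theta, n=1, eps=1e-3, rng=None):
    """n-sample two-point estimate of grad loss(theta), using 2n forward passes."""
    rng = rng if rng is not None else np.random.default_rng()
    g_hat = np.zeros_like(theta)
    for _ in range(n):
        z = rng.standard_normal(theta.shape)   # z ~ N(0, I_d)
        # directional finite difference: -> z^T grad loss(theta) as eps -> 0
        proj = (loss(theta + eps * z) - loss(theta - eps * z)) / (2 * eps)
        g_hat += proj * z                      # contributes z z^T grad loss(theta)
    return g_hat / n

# Toy usage on a quadratic, where the exact gradient is A @ theta.
A = np.diag(np.arange(1.0, 6.0))
loss = lambda th: 0.5 * th @ A @ th
theta = np.ones(5)
print(zo_grad(loss, theta, n=2000, eps=1e-4))  # roughly [1. 2. 3. 4. 5.]
```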

On the Learning Rate $\eta$

In the limit $\epsilon \rightarrow 0$, the ZO gradient estimate satisfies $\hat{\nabla}\mathcal{L}(\theta; \mathcal{B}) \rightarrow \frac{1}{Bn} \sum_{x \in \mathcal{B}} \sum_{i=1}^n \bm{z}_i \bm{z}_i^T \nabla\mathcal{L}(\theta; x)$ with $\bm{z}_i \overset{\text{i.i.d.}}{\sim} \mathcal{N}(\bm{0}, \bm{I}_d)$, and its expected squared norm is $\frac{d+n+1}{n}$ times that of the true mini-batch gradient; this is why the learning rate $\eta$ must be scaled down accordingly.

Proof: writing $g_k := \nabla\mathcal{L}(\theta; x_k)$ for the per-example gradients,

$$
\begin{align*}
\mathbb{E} \left[ \|\hat{\nabla}\mathcal{L}(\theta; \mathcal{B}) \|^2 \right]
&\overset{\epsilon \rightarrow 0}{=} \frac{1}{B^2n^2} \sum_{x_1, x_2 \in \mathcal{B}} \sum_{i, j=1}^n \mathbb{E}\left[ (\bm{z}_i \bm{z}_i^T g_1)^T (\bm{z}_j \bm{z}_j^T g_2) \right] \\
&= \frac{1}{B^2n^2} \sum_{x_1, x_2 \in \mathcal{B}} \sum_{i, j=1}^n \text{Tr}\left( \mathbb{E}\left[ \bm{z}_i \bm{z}_i^T \bm{z}_j \bm{z}_j^T \right] \mathbb{E}\left[ g_2 g_1^T \right] \right) \\
&= \frac{1}{B^2n^2} \text{Tr}\left( \sum_{x_1, x_2 \in \mathcal{B}} \left( n(n-1) \bm{I}_d + \sum_{i = j} \mathbb{E}\left[ \bm{z} \bm{z}^T \bm{z} \bm{z}^T \right] \right) \mathbb{E}\left[ g_2 g_1^T \right] \right) \\
&= \frac{1}{B^2n^2} \text{Tr}\left( \sum_{x_1, x_2 \in \mathcal{B}} \left( n(n-1) \bm{I}_d + n\, \mathbb{E}\left[ \bm{z} \bm{z}^T \bm{z} \bm{z}^T \right] \right) \mathbb{E}\left[ g_2 g_1^T \right] \right) \\
&= \frac{1}{B^2n} \text{Tr}\left( \sum_{x_1, x_2 \in \mathcal{B}} \left( (n-1) \bm{I}_d + \mathbb{E}\left[ \bm{z} \bm{z}^T \bm{z} \bm{z}^T \right] \right) \mathbb{E}\left[ g_2 g_1^T \right] \right) \\
&= \frac{1}{B^2n} \text{Tr}\left( \sum_{x_1, x_2 \in \mathcal{B}} \left( (n-1) \bm{I}_d + (d + 2) \bm{I}_d \right) \mathbb{E}\left[ g_2 g_1^T \right] \right) \\
&= \frac{d + n + 1}{B^2 n} \sum_{x_1, x_2 \in \mathcal{B}} \mathbb{E}\left[ g_1^T g_2 \right]
\end{align*}
$$

The third equality splits the double sum into the $i \neq j$ pairs, where independence gives $\mathbb{E}[\bm{z}_i \bm{z}_i^T]\, \mathbb{E}[\bm{z}_j \bm{z}_j^T] = \bm{I}_d$ for each of the $n(n-1)$ pairs, and the $i = j$ terms; the sixth uses the Gaussian fourth-moment identity $\mathbb{E}[\bm{z} \bm{z}^T \bm{z} \bm{z}^T] = \mathbb{E}[\|\bm{z}\|^2\, \bm{z} \bm{z}^T] = (d + 2)\, \bm{I}_d$.
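As a quick numerical sanity check of that fourth-moment identity (a throwaway numpy sketch; the dimension and trial count are arbitrary choices of mine, not from the paper):

```python
import numpy as np

# Estimate E[z z^T z z^T] = E[||z||^2 z z^T] for z ~ N(0, I_d) by Monte Carlo.
rng = np.random.default_rng(0)
d, trials = 6, 500_000
z = rng.standard_normal((trials, d))
w = (z ** 2).sum(axis=1)                 # ||z||^2 for each sample
m4 = (z * w[:, None]).T @ z / trials     # average of ||z||^2 z z^T
print(np.round(m4, 1))                   # roughly (d + 2) I = 8 * I_6
```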

Expanding the true mini-batch gradient in the same way, we have

$$ \begin{align*} \mathbb{E}\left[ \|\nabla\mathcal{L}(\theta; \mathcal{B}) \|^2 \right] &= \frac{1}{B^2} \sum_{x_1, x_2 \in \mathcal{B}} \mathbb{E}[g_1^Tg_2] \end{align*} $$

Comparing the two expressions yields $\mathbb{E}\left[ \|\hat{\nabla}\mathcal{L}(\theta; \mathcal{B})\|^2 \right] = \frac{d+n+1}{n}\, \mathbb{E}\left[ \|\nabla\mathcal{L}(\theta; \mathcal{B})\|^2 \right]$, as claimed.

Note: the scale factor derived here is $\frac{d+n+1}{n}$, which differs from the result stated in the paper. I have gone over the proof and cannot find an error; if anyone spots a flaw, please point it out. (One possible source of the discrepancy: if $\bm{z}$ is drawn uniformly from the sphere of radius $\sqrt{d}$ rather than from $\mathcal{N}(\bm{0}, \bm{I}_d)$, then $\mathbb{E}[\bm{z} \bm{z}^T \bm{z} \bm{z}^T] = d\, \bm{I}_d$ and the factor becomes $\frac{d+n-1}{n}$.)
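The end-to-end factor can also be checked numerically. Below is a minimal sketch for the $B = 1$ case, where the claim reduces to $\mathbb{E}\left[ \|\frac{1}{n} \sum_i \bm{z}_i \bm{z}_i^T g\|^2 \right] = \frac{d+n+1}{n} \|g\|^2$ (the constants and variable names are my own choices):

```python
import numpy as np

# Monte-Carlo check of the (d + n + 1) / n scale factor for B = 1.
rng = np.random.default_rng(0)
d, n, trials = 50, 4, 100_000

g = rng.standard_normal(d)              # a fixed "true" gradient
acc = 0.0
for _ in range(trials):
    z = rng.standard_normal((n, d))     # n Gaussian perturbations
    est = z.T @ (z @ g) / n             # (1/n) sum_i z_i z_i^T g
    acc += est @ est                    # squared norm of the estimate
print("empirical :", acc / trials / (g @ g))   # approx 13.75
print("predicted :", (d + n + 1) / n)          # (50 + 4 + 1) / 4 = 13.75
```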

References

  1. Malladi S., Gao T., Nichani E., Damian A., Lee J. D., Chen D., and Arora S. Fine-Tuning Language Models with Just Forward Passes. NeurIPS, 2023.