Simple Guidance Mechanisms for Discrete Diffusion Models (Schriff et al., 2025) という ICLR2025 に採択された論文で離散拡散モデルに分類器ガイダンス(CG)および分類器フリーガイダンス(CFG)を適用する手法が提案されていたので読んでみました。

連続拡散モデルでのCG/CFG

連続拡散モデルでは $x_{t-1}$ の条件付き確率分布が以下のように表されます。

$$p^\gamma(x_{t-1} \mid y, x_t) \propto p(y \mid x_{t-1})^\gamma \, p_\theta(x_{t-1} \mid x_t)$$

対数勾配をとると以下のようになります。(分類器ガイダンス)

$$\nabla_{x_{t-1}} \log p^\gamma(x_{t-1} \mid y, x_t) = \gamma \nabla_{x_{t-1}}\log p(y \mid x_{t-1}) + \nabla_{x_{t-1}}\log p_\theta(x_{t-1} \mid x_t)$$

ここでベイズの定理より

$$\log p_\theta(x_{t-1} \mid y, x_t) = \log p(y \mid x_{t-1}) + \log p_\theta​(x_{t−1}\mid x_t​)−\log p(y\mid x_t​)$$

が成り立ちます。 両辺の勾配をとって

$$\nabla_{x_{t-1}}\log p_\theta(x_{t-1}\mid y,x_t) = \nabla_{x_{t-1}}\log p(y\mid x_{t-1}) + \nabla_{x_{t-1}}\log p_\theta(x_{t-1}\mid x_t)$$

これを先ほどの分類器ガイダンスの式に代入して

$$\nabla_{x_{t-1}} \log p^\gamma(x_{t-1} \mid y, x_t) = \gamma \nabla_{x_{t-1}}\log p_\theta(x_{t-1}\mid y,x_t) + (1-\gamma) \nabla_{x_{t-1}}\log p_\theta(x_{t-1}\mid x_t)$$

が得られます。(分類器フリーガイダンス)

提案手法

離散拡散モデルでは潜在変数 $x_t$ が離散であるため対数尤度の勾配が計算できず、上のような式を直接使うことができません。これを解決するための手法として論文で提案されていた手法が以下の通りです。

分類器フリーガイダンス

連続拡散モデルの分類器フリーガイダンスの式

$$\log p^\gamma(x_{t-1} \mid y, x_t) = \gamma \log p(y \mid x_{t-1}) + \log p_\theta(x_{t-1} \mid x_t) + c$$

は、ベイズの定理より以下のように変形できました。

$$ \log p^\gamma(x_{t-1} \mid y, x_t) = \gamma \log p_\theta(x_{t-1}\mid y,x_t) + (1-\gamma) \log p_\theta(x_{t-1}\mid x_t) + c$$

離散拡散モデルではこれがトークン単位の計算となります。これを配列長 $L$ のトークン列に拡張すると以下のようになります。

$$\log p^\gamma\!\left(x_{t-1}^{(1:L)} \mid y, x_t^{(1:L)}\right) =\gamma \log p\!\left(x_{t-1}^{(1:L)} \mid y, x_t^{(1:L)}\right)+(1-\gamma)\log p\!\left(x_{t-1}^{(1:L)} \mid x_t^{(1:L)}\right)+ c$$

さらに確率の積の形に直すとこのようになります。

$$p_\theta^\gamma\!\left(x_{t-1}^{(1:L)} \mid y, x_t^{(1:L)} \right)=\prod_{\ell=1}^{L}\frac{1}{Z^{(\ell)}}\, p_\theta\!\left(x_{t-1}^{(\ell)} \mid y, x_t^{(1:L)} \right)^{\gamma}\, p_\theta\!\left(x_{t-1}^{(\ell)} \mid x_t^{(1:L)}\right)^{1-\gamma}$$

ただし

$$Z^{(\ell)}=\sum_{x_{t-1}'}p_\theta\!\left(x_{t-1}' \mid y, x_t^{(1:L)}\right)^{\gamma}p_\theta\!\left(x_{t-1}' \mid x_t^{(1:L)}\right)^{1-\gamma}$$

※ここで関数 $Z^{(\ell)}$ は分配関数であり、正規化定数の役割を果たしています。また、$x'_{t-1}$ は語彙の集合 $\nu$ に含まれる全ての種類のトークンを表しています。

分類器ガイダンス

トークン単位の計算は以下のようになります。

$$p^\gamma(x_{t-1} \mid x_t, y) = \frac{p(y \mid x_{t-1}, x_t)^\gamma p(x_{t-1} \mid x_t)}{Z^{(\ell)}}$$

ただし

$$Z^{(\ell)} = \sum_{x'_{t-1}} p(y \mid x'_{t-1}, x_t)^\gamma p(x'_{t-1} \mid x_t)$$

これを配列全体に拡張すると以下のようになります。

$$p_{\phi, \theta}^\gamma(x_{t-1}^{(1:L)} \mid x_t^{(1:L)}, y) = \prod_{\ell=1}^{L} \frac{p_\phi(y \mid \tilde{x}^{(1:L)})^\gamma p_\theta(x_{t-1}^{(\ell)} \mid x_t^{(1:L)})}{\sum_{\hat{x} \in \tilde{x}_\ell(x_t^{(1:L)})} p_\phi(y \mid \hat{x})^\gamma p_\theta(\hat{x}^{(\ell)} \mid x_t^{(1:L)})}$$

ただし

  • $\tilde{x}^{(1:L)}$: $x_t$ の $l$ 目のトークンを $x_{t-1}$ の $l$ 番目のトークンに置き換えたもの
  • $\hat{x}$: $x_t$ の $l$ 番目のトークンを $\nu$ に含まれるトークンのうちの1つに入れ替えたもの

実装

実際の実装方法を確認しました。公式の実装はGitHubで公開されています。
https://github.com/kuleshov-group/discrete-diffusion-guidance

分類器フリーガイダンス