NeurIPS 2024 に採択された論文 “Gradient Guidance for Diffusion Models: An Optimization Perspective” (Guo et al. 2024) を読んでみたので紹介します。

拡散モデルにおいて Classifier Guidance / Classifier-free Guidance などを用いたガイダンス手法が広く用いられているなか、この論文はユーザーが定義した任意の目的関数 $f$ を最大化するようにガイダンスを行うための手法について検討しています。

この論文では単純な勾配 $\nabla f$ を用いた勾配法による最適化では生成物が崩壊してしまうことを示し、“Look-ahead Loss” という損失を定義し、生成されたデータによるスコア関数のファインチューニングを繰り返しながら最小化するようなガイダンス手法を提案しています。

1. 背景と課題

拡散モデルにおいて、拡散過程を $\mathrm{d}\mathbf{x} = f( t)\mathbf{x}\mathrm{d}t + g(t)\mathrm{d}\mathbf{w}$ と表す場合の逆拡散過程は次のような確率微分方程式(SDE)で表されます。
($\mathrm{d}t$: 無限小ステップ, $\bar{\mathbf{w}}$: 時刻 $T \rightarrow 0$ まで逆向きにたどった際の標準ウィーナー過程)

$$\mathrm{d}\mathbf{x} = [f(t)\mathbf{x}-g(t)^2\nabla_\mathbf{x}\log p_t(\mathbf{x})]\mathrm{d}t + g(t) \mathrm{d}\bar{\mathbf{w}}$$

論文では拡散過程・逆拡散過程を以下のように表記しています。

$$\mathrm{d}{X}_t = -\frac12 q(t){X}_t\mathrm{d}t + \sqrt{q(t)}\mathrm{d}{W}_t\tag{1}$$

$$\mathrm{d}X_t^{\leftarrow} = \left[ \frac{1}{2}X_t^{\leftarrow} + \nabla \log p_{T-t}(X_t^{\leftarrow}) \right] \mathrm{d}t + \mathrm{d}\overline{W}_t \tag{2}$$

ここで、ベイズの定理より条件付きスコアが以下のように展開できることを考えます。

$$\nabla_{x_t} \log p_t({x_t} \mid y) = \nabla_{x_t} \log p_t(y \mid {x_t}) + \nabla_{x_t} \log p_t({x_t}) \tag{4}$$

これが条件なしスコアと $\nabla_{x_t} \log p_t(y \mid {x_t})$ の和であることから、逆拡散過程で条件をつけるためにはガイダンス項 $\mathrm{G} = \nabla_{x_t} \log p_t(y \mid {x_t})$ を加えればよいと分かります。そしてこれは Classifier Guidance に相当します。

$$\mathrm{d}X_t^{\leftarrow} = \left[ \frac{1}{2}X_t^{\leftarrow} + s_\theta(X_t^{\leftarrow}, T-t) + \mathbf{G}(X_t^{\leftarrow}, T-t) \right] \mathrm{d}t + \mathrm{d}\overline{W}_t$$

2. 提案手法

目的関数 $f$ の最大化

目的関数 $f$ を最大化するような逆拡散過程を考える際、単純に $\nabla f(\mathbf{x}_t)$ に比例するようなガイダンス項を用いると、データが本来の多様体から外れてしまいため生成品質が著しく低下し、うまくいかないことが論文中で示されています。

そこで論文中で提案されたのが、逆拡散時の各タイムステップにおける $\mathbf{x}_t$ ではなく、現在の $\mathbf{x}_t$ から推定されるノイズ除去後のデータ $\mathbf{x}_0$ を「先読み (look-ahead)」し、それに基づいた勾配を計算する手法です。

先読みによるガイダンス

まず以下のように仮定(1),(2)を置きます。

  • 仮定(1)
    • データ $X \in \mathbb{R}^D$ は、未知の行列 $A \in \mathbb{R}^{D\times d}$ と潜在変数 $U \in \mathbb{R}^d$ を用いて $X = AU$ と表せる。ただし、$d \ll D$ であり、行列 $A$ の列ベクトルは正規直交であるとする。
  • 仮定(2)
    • データ $x$ はガウス分布に従う
    • 目的関数 $f(x) = g^\top x$
    • $y = f(x) + \epsilon$

論文によれば、これらの仮定のもとで条件付きスコアには次のような比例関係が成り立ちます。

$$\nabla_{x_t} \log p_t(y\mid x_t) = -(2\sigma_y^2(x_t))^{-1}\cdot\nabla_{x_t}(y - g^\top \mathbb{E}[x_0\mid x_t])^2 \tag{6}$$

これに基づき提案手法のガイダンス項 $\mathrm{G}_{loss}$ は以下のように定義されます。
($\beta(t) > 0, y \in \mathbb{R}$ はハイパーパラメーターとなります)

$$\mathrm{G}_{loss}(x_t, t) := -\beta(t) \cdot \nabla_{x_t}(y - g^\top \mathbb{E}[x_0\mid x_t])^2 \tag{7}$$

初めに置いた仮定(1)のもとでは、以下のような包含関係が成立することが保証されます。これにより、最適化を進めてもデータ多様体から逸脱することを防ぐことができます。

$$\mathrm{G}_{loss}(x_t, t) \in \text{Span}(A) \tag{8}$$

この証明を行うための $\mathbb{E}[x_0\mid x_t]$ の計算には Tweedie の公式が用いられています。

更新アルゴリズム

各種パラメータ

  • $\beta(t)$: 大きさのパラメータ
  • $\{y_k\} \space (k=0,...,K-1)$: 最大化の目的値
  • $K$: イテレーション数
  • $B_k$: バッチサイズ

alt text Source: [Guo et al. 2024]

この Algorithm 1 を用いた生成プロセスは、$f$ が凹関数でかつL平滑性を満たすとき、以下の正則化付き最適化問題と等価であることが Theorem 2 で示されています。

$$x^*_{A, \lambda} = \underset{x \in \text{Span}(A)}{\text{argmax}} \left\{ f(x) - \frac{\lambda}{2} \|x - \bar{\mu}\|^2_{\bar{\Sigma}^{-1}} \right\} \tag{13}$$

ファインチューニングによる大域的最適化

ここまでの Algorithm 1 が抱える問題点として、事前学習済みのスコア関数によって正則化されてしまい、事前学習されたデータ分布の範囲内での最適解しか得ることができません。これを解決するために、生成されたデータによるスコア関数のファインチューニングを行います。 alt text Source: [Guo et al. 2024]

ファインチューニングは、以下の目的関数を最小化するようにスコア関数を最小化することで行います。

$$\min_{s \in \mathcal{S}} \int_0^T \sum_{i=0}^k w_{k,i} \mathbb{E}_{x_0 \in \mathcal{D}_i} \mathbb{E}_{x_t | x_0} \left[ \| \nabla_{x_t} \log \phi_t(x_t | x_0) - s(x_t, t) \|_2^2 \right] \mathrm{d}t \tag{14}$$

このような更新を繰り返すことで最終的に得られる分布の平均 $\mu_K$ が正則化なしの大域的最適解 $f^*_A$ に収束することが Theorem 3 によって保証されています。

3. 実装

公式実装のコードを見てみました。
https://github.com/yukang123/GGDMOptim/tree/main

$\mathrm{G}_{loss}$ の実装

https://github.com/yukang123/GGDMOptim/blob/main/simulations/diffusion.py

def cond_score_v2(self, x, t, classes):
    r"""
    Gradient Guidance G_loss
    Following Definiton 1/ Lemma 1 in the original paper

    G_loss = - beta_t * \nabla_{x_t} (y - g^T * E[x_0 | x_t])^2
    """

    alpha_t = extract(self.sqrt_alphas_cumprod, t, (x.shape[0], 1))
    h_t = extract(self.sqrt_one_minus_alphas_cumprod, t, (x.shape[0], 1)) ** 2

    #### compute the gradient of g^T * E[x_0 | x_t] w.r.t. x_t
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True) # x_t, batched samples
        pred_noise = self.model(x_in, t, classes=None) # predict noise
        noise_svd = extract(
            self.sqrt_one_minus_alphas_cumprod, t, pred_noise.shape)
        uncond_score = -1 * pred_noise * 1 / noise_svd
        x_0_hat = (x_in + h_t * uncond_score) / alpha_t # get E[x_0 | x_t]
        value = torch.sum(x_0_hat * self.g, dim=1, keepdim=True)
        gradient = torch.autograd.grad(value.sum(), x_in)[0] # Get graidents for each sample in a batch parallelly

    ## 
    # beta_t 
    ## theoretical value: a variant of the beta_t shown in Lemma 1
    ## beta_t := alpha_t^2/2 / (alpha_t^2 * \sigma^2 + h(t) ||\nabla_{x_t} (g^T * E[x_0 | x_t])||^2)
    beta_t = (alpha_t ** 2 / 2) / (
        alpha_t ** 2 * self.sigma ** 2 + h_t * torch.norm(gradient, dim=1, keepdim=True) ** 2
        ) if self.beta is None else torch.full((x.shape[0], 1), self.beta, device=x.device)
    beta_t = beta_t * self.beta_coef 

    ## gradient of the loss (y - g^T * E[x_0 | x_t]) w.r.t. x_t
    gradient_2 = 2 * (value - classes.view(-1, 1)) * gradient 
    
    G = uncond_score - beta_t * gradient_2 # get the conditional score
    pred_noise = -1 * G * noise_svd # transform the conditional score to noise
    return pred_noise

最適化ループ

https://github.com/yukang123/GGDMOptim/blob/main/simulations/main.py

Algorithm 1 と 2 はif args.score_matching_finetune:のブロック部分以外は共通しているようです。

def optimize(args, model, diffusion, generator, func, func_type, query_samples, interval, opt_logger, reward_K):
    '''
    Reward optimization with guided diffusion model
    '''
    print("====================Optimization Starts====================")
    if args.score_matching_finetune:
        optimizer_ft_sm = torch.optim.Adam(diffusion.model.parameters(), lr=args.sm_ft_lr, betas=(0.9, 0.99))
        sm_ft_round = args.sm_ft_round if args.sm_ft_round is not None and args.sm_ft_round > 0 else args.opt_rounds - args.sm_ft_start_round
        finetune_rounds = 0
        
    samples_list = []
    reward_list = []
    top_k_samples_list = []

    sample_time = 0
    opt_start_time = time.time()
    for i in range(args.opt_rounds):
        ## 1. get y_k (B_k, d_outer)
        # Query gradient g_k (get gradient for each sample z_(k,i))
        bs_gradient, bs_value = BatchGrad(query_samples, func)         
        diffusion.update_grad(bs_gradient) #, batch_size=bs_gradient.shape[0])
        # y_k = delta + g_k^T * z_k
        guidance = interval + torch.sum(bs_gradient * query_samples, dim=1)
        opt_logger.log(guidance=guidance[0].item(), mean_guidance=torch.mean(guidance).item())

        if args.score_matching_finetune:
            ## Alg.2: Fine-tuning Diffusion Model with Score Matching Loss
            if i >= args.sm_ft_start_round and finetune_rounds < sm_ft_round:
                model.train()
                steps = args.sm_ft_step_per_round # [Deprecated] default: 1
                batch_size = len(query_samples) // steps
                for j in range(steps):
                    loss_diffusion = diffusion(query_samples[j*batch_size: (j+1)*batch_size, :])
                    optimizer_ft_sm.zero_grad()
                    loss_diffusion.backward()
                    optimizer_ft_sm.step()
                    opt_logger.log(loss_diffusion=loss_diffusion.item())
                finetune_rounds += 1
                print("finetune the pretrained score model | round: ", finetune_rounds)
                model.eval()

                if finetune_rounds == sm_ft_round:
                    save_checkpoint(...)
        
        query_samples = diffusion.sample(
            num_samples=args.generate_bs,
            classes= guidance if not args.disable_guidance else None,
            )

        reward = func(query_samples)
        mean_reward = torch.mean(reward)
        ratio = generator.off_support_ratio(query_samples)
        reward_list.append(reward.cpu().numpy())

    return samples_list, reward_list, top_k_samples_list

感想

条件ベクトル $y \in \mathbb{R}^d$ による制御ではなく、目的関数 $f$ を最大化するような方向性のガイダンス手法としてとても良いやり方であると感じた。本論文以前のガイダンス手法の中には、論文中で多様体からの逸脱を理由に否定されていた $\nabla f$ を更新する方向性の研究 (例: Gruver et al. 2023) もあり、面白かった。

参考文献

  • Guo, Y., Yuan, H., Yang, Y., Chen, M., & Wang, M. (2024). Gradient Guidance for Diffusion Models: An Optimization Perspective (No. arXiv:2404.14743). arXiv. https://doi.org/10.48550/arXiv.2404.14743
  • Gruver, N., Stanton, S., Frey, N. C., Rudner, T. G. J., Hotzel, I., Lafrance-Vanasse, J., Rajpal, A., Cho, K., & Wilson, A. G. (2023). Protein Design with Guided Discrete Diffusion (No. arXiv:2305.20009). arXiv. https://doi.org/10.48550/arXiv.2305.20009