離散拡散モデルについての理解を深めるために既存の実装を参考にしながらタンパク質配列生成を試してみました。

参考にした論文

1. 離散拡散モデルについて

本論文(D3PMと呼びます)で紹介されている離散拡散モデルの仕組みについて紹介します。

順過程

離散拡散モデルの拡散過程では以下のようにノイズを載せて拡散していきます。

$$ q(x_t \mid x_{t-1}) = \text{Cat}(\boldsymbol{x_t}; \boldsymbol{p}=\boldsymbol{x_{t-1}}\boldsymbol{Q_t}) $$

ここで $x_t$ は時刻 $t$ におけるタンパク質配列のワンホットベクトル、$Q_t$ は時刻 $t$ における遷移行列を表します。遷移行列の作成方法は様々ですが、今回は論文中で紹介されている Absorbing-state (吸収状態) と呼ばれるものを使います。吸収状態とはアミノ酸が [MASK] トークンで置換された状態を指します。$Q_t$ は語彙サイズを $K$ として $(K, K)$ の行列であり各要素は以下のようにして決定されます。$m$ は既に [MASK] トークンに置換されている状態を指します。

$$ [\boldsymbol{Q}_t]_{ij} =\begin{cases}1 & \text{if } i = j = m \quad \\1 - \beta_t & \text{if } i = j \neq m \quad \\\beta_t & \text{if } j = m, i \neq m \quad \end{cases} $$

$\beta_t$ の決め方はノイズスケジューリングと呼ばれ様々な方法がありますが、D3PMでは

$$\beta_t = \frac{1}{T - t + 1}$$

とされています。

逆過程

D3PMの逆拡散過程では以下のようにして $x_{t-1}$ を計算します。

$$ q(x_{t-1} | x_t, x_0) = \text{Cat}\left( \boldsymbol{x}_{t-1}; \boldsymbol{p} = \frac{\boldsymbol{x}_t \boldsymbol{Q}_t^\top \odot \boldsymbol{x}_0 \bar{\boldsymbol{Q}}_{t-1}}{\boldsymbol{x}_0 \bar{\boldsymbol{Q}}_t \boldsymbol{x}_t^\top} \right) $$

ただし、ここでの $x_0$ は $x_t$ を入力とするニューラルネットワークによって推定したものです。今回はこのニューラルネットワークとして ESM-2 の事前学習済みモデルに時間情報を加えたものを学習させて用いました。

損失関数

DDPMでは変分下限を式変形していくことで以下のような損失関数となります。

$$ \begin{aligned} & L_{\text{vb}} = L_T + \sum_{t=2}^{T} L_{t-1} + L_0 \\ & \begin{cases} L_T &= \mathbb{E}_{q(x_{1:T}|x_0)} \left( \log \frac{q(x_T|x_0)}{p(x_T)} \right), \\ L_{t-1} &= \mathbb{E}_{q(x_{1:T}|x_0)} \left( \log \frac{q(x_{t-1}|x_t, x_0)}{p_\theta(x_{t-1}|x_t)} \right), \\ L_0 &= -\mathbb{E}_{q(x_{1:T}|x_0)} \left( \log p_\theta(x_0|x_1) \right). \end{cases} \end{aligned} $$

D3PMではこれを踏襲し、変分下限項とクロスエントロピー項の重み付け和を損失関数としています。

$$L_\lambda = L_{\text{vb}} + \lambda \, \mathbb{E}_{q(\boldsymbol{x}_0)} \mathbb{E}_{q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)} \left[ - \log \widetilde{p}_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_t) \right]$$

2. 実装

今回実装を行うのにあたり、以下の資料を参考にしました。

ニューラルネットワーク

D3PMの逆拡散過程では $\tilde{p}_\theta(x_0 \mid x_t)$ で表されるニューラルネットワークが必要となります。今回は以下のように実装しました。

class DiffusedESM(nn.Module):
    def __init__(self, pretrained_model_name="facebook/esm2_t6_8M_UR50D"):
        super().__init__()
        self.esm = EsmForMaskedLM.from_pretrained(pretrained_model_name)
        self.config = self.esm.config
        self.hidden_dim = self.config.hidden_size

        time_emb_dim = self.hidden_dim
        self.time_embed = nn.Sequential(
            PosEncoding(self.hidden_dim),  
            nn.Linear(self.hidden_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, self.hidden_dim),
        )


    def forward(self, x, t, cond=None):
        inputs_embeds = self.esm.esm.embeddings.word_embeddings(x)
        t_emb = self.time_embed(t.float())  
        t_emb = t_emb.unsqueeze(1)

        inputs_embeds = inputs_embeds + t_emb

        outputs = self.esm(inputs_embeds=inputs_embeds)
        return outputs.logits

モデルのバックボーンには、大量のタンパク質配列で訓練されたマスク言語モデルであるESM-2を用いました。ESM-2には EsmForMaskedLM というクラスがあり、引数にアミノ酸配列の埋め込み表現を入力することでロジットが手に入ります。

合わせて、時刻 $t$ を正弦波位置エンコーディングにより変換したベクトルを導入するため PosEncoding クラスを準備します。

class PosEncoding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

拡散過程・逆拡散過程

元のコードは画像データ向け・一様拡散のものであったため、部分的に変更しています。

class D3PM(nn.Module):
    def __init__(
        self,
        x0_model: nn.Module,
        n_T: int,
        num_classes: int = 33, 
        mask_token_id: int = 32,  
        forward_type="absorbing",  
        hybrid_loss_coeff=0.01, 
    ) -> None:
        super(D3PM, self).__init__()
        self.x0_model = x0_model

        self.n_T = n_T
        self.hybrid_loss_coeff = hybrid_loss_coeff
        self.num_classses = num_classes
        self.eps = 1e-8  

        t_steps = torch.arange(1, n_T + 1, dtype=torch.float64)
        self.beta_t = 1.0 / (n_T - t_steps + 1)

        q_onestep_mats = []
        q_mats = []  
        

        for beta in self.beta_t:
            if forward_type == "absorbing":
                # 遷移行列 Q_t を作成する
                mat = torch.eye(num_classes) * (1 - beta)
                mat[:, mask_token_id] += beta
                mat[mask_token_id, :] = 0
                mat[mask_token_id, mask_token_id] = 1.0

                q_onestep_mats.append(mat)
            else:
                raise NotImplementedError

        q_one_step_mats = torch.stack(q_onestep_mats, dim=0)
        q_one_step_transposed = q_one_step_mats.transpose(
            1, 2
        )
        q_mat_t = q_onestep_mats[0].double()
        q_mats = [q_mat_t.float()]

        # 遷移行列の積 \bar{Q} = Q_{T-1}・Q_{T-2} ... Q_2・Q_1 の計算 
        for idx in range(1, self.n_T):
            q_mat_t = q_mat_t @ q_onestep_mats[idx].double()
            q_mats.append(q_mat_t.float())
        q_mats = torch.stack(q_mats, dim=0)
        self.logit_type = "logit"

        self.register_buffer("q_one_step_transposed", q_one_step_transposed)
        self.register_buffer("q_mats", q_mats)

        assert self.q_mats.shape == (
            self.n_T,
            num_classes,
            num_classes,
        ), self.q_mats.shape

    def _at(self, a, t, x):
        bs = t.shape[0]
        t = t.reshape((bs, *[1] * (x.dim() - 1)))
        return a[t - 1, x, :]

    # q(x_{t-1} | x_t, x_0)  
    def q_posterior_logits(self, x_0, x_t, t):
        if x_0.dtype == torch.int64 or x_0.dtype == torch.int32:
            x_0_logits = torch.log(
                torch.nn.functional.one_hot(x_0.long(), self.num_classses) + self.eps
            )
        else:
            x_0_logits = x_0.clone()

        fact1 = self._at(self.q_one_step_transposed, t, x_t)

        softmaxed = torch.softmax(x_0_logits, dim=-1)  
        t_idx = torch.clamp(t - 2, min=0)
        qmats2 = self.q_mats[t_idx].to(dtype=softmaxed.dtype)

        fact2 = torch.einsum("b...c,bcd->b...d", softmaxed, qmats2)
        out = torch.log(fact1 + self.eps) + torch.log(fact2 + self.eps)
        t_broadcast = t.reshape((t.shape[0], *[1] * (x_t.dim())))
        bc = torch.where(t_broadcast == 1, x_0_logits, out)

        return bc

    # 変分下限の計算
    def vb(self, dist1, dist2):
        dist1 = dist1.flatten(start_dim=0, end_dim=-2)
        dist2 = dist2.flatten(start_dim=0, end_dim=-2)

        out = torch.softmax(dist1 + self.eps, dim=-1) * (
            torch.log_softmax(dist1 + self.eps, dim=-1)
            - torch.log_softmax(dist2 + self.eps, dim=-1)
        )
        return out.sum(dim=-1).mean()

    # q(x_t | x_0) の計算 (拡散過程)
    def q_sample(self, x_0, t, noise):
        logits = torch.log(self._at(self.q_mats, t, x_0) + self.eps)
        noise = torch.clip(noise, self.eps, 1.0)
        gumbel_noise = -torch.log(-torch.log(noise))
        return torch.argmax(logits + gumbel_noise, dim=-1)

    # \tilde{p}(x_0 | x_t) の計算 (逆拡散過程)
    def model_predict(self, x_t, t, cond):
        predicted_x0_logits = self.x0_model(x_t, t.float(), cond)
        return predicted_x0_logits

    def forward(self, x: torch.Tensor, cond: torch.Tensor = None) -> torch.Tensor:
        t = torch.randint(1, self.n_T, (x.shape[0],), device=x.device)
        x_t = self.q_sample(
            x, t, torch.rand((*x.shape, self.num_classses), device=x.device)
        )
        # x_{t_1} を学習データ中の x_0 から計算
        true_q_posterior_logits = self.q_posterior_logits(x, x_t, t)

        # x_{t-1} の分布をニューラルネットワークで推測した x_0 から計算
        predicted_x0_logits = self.model_predict(x_t, t, cond)
        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x_t, t)

        # 2つの x_{t-1} 分布から変分下限を計算 
        vb_loss = self.vb(true_q_posterior_logits, pred_q_posterior_logits)

        predicted_x0_logits = predicted_x0_logits.flatten(start_dim=0, end_dim=-2)
        x = x.flatten(start_dim=0, end_dim=-1)

        # 正しい x_0 と推測された x_0 のクロスエントロピー誤差
        ce_loss = torch.nn.CrossEntropyLoss()(predicted_x0_logits, x)

        # 重み付け和
        return vb_loss + self.hybrid_loss_coeff * ce_loss, {
            "vb_loss": vb_loss.detach().item(),
            "ce_loss": ce_loss.detach().item(),
        }

    # 逆拡散操作 
    def p_sample(self, x, t, cond, noise):
        predicted_x0_logits = self.model_predict(x, t, cond)
        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x, t)

        noise = torch.clip(noise, self.eps, 1.0)

        not_first_step = (t != 1).float().reshape((x.shape[0], *[1] * (x.dim())))

        gumbel_noise = -torch.log(-torch.log(noise))
        sample = torch.argmax(
            pred_q_posterior_logits + gumbel_noise * not_first_step, dim=-1
        )
        return sample

    # 推論
    def sample(self, x, cond=None):
        for t in reversed(range(1, self.n_T)):
            t = torch.tensor([t] * x.shape[0], device=x.device)
            x = self.p_sample(
                x, t, cond, torch.rand((*x.shape, self.num_classses), device=x.device)
            )
        return x

訓練

SwissProt と呼ばれるタンパク質のアミノ酸配列のデータセット(の一部)を使って学習しました。

if __name__ == "__main__":
    backbone = DiffusedESM()
    d3pm = D3PM(
        x0_model=backbone,
        n_T=1000,
        num_classes=N,
        mask_token_id=tokenizer.mask_token_id,
        forward_type="absorbing",
        hybrid_loss_coeff=0.01,
    ).cuda()

    dataset = SwissProtDataset(tokenizer, max_len=128, max_samples=50000)
    dataloader = DataLoader(dataset, batch_size=1600, shuffle=True, num_workers=4)

    optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=5e-4)
    d3pm.train()

    n_epoch = 100 
    device = "cuda"

    global_step = 0
    for i in range(n_epoch):
        pbar = tqdm(dataloader)
        loss_ema = None
        for x in pbar:
            optim.zero_grad()
            x = x.to(device)
            cond = None

            loss, info = d3pm(x, cond)

            loss.backward()
            norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 1.0)

            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss.item()
            pbar.set_description(
                f"loss: {loss_ema:.4f}, norm: {norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )
            optim.step()
            global_step += 1

3. 推論

最後に推論を試してみます。完全に拡散された状態から無条件で配列を生成しました。

d3pm.eval()

with torch.no_grad():
    B_sample = 3
    L_sample = 256
    init_noise = torch.full(
        (B_sample, L_sample), tokenizer.mask_token_id, device=device
    )

    generated_ids = d3pm.sample(init_noise, cond=None)
    for seq_ids in generated_ids:
        seq = tokenizer.decode(seq_ids, skip_special_tokens=True)
        print(f"Seq: {seq[:50].replace(" ", "")}...") 

結果

<学習初期>
Seq: IDQAKKKEIGLKSYRAFELPQPTES
Seq: CKGPYDLIKNERKKATEIVYKIEQN
Seq: VLINITAQLNDDASINAEGNVGVTD

<10エポック後>
Seq: MSILNGKDVKEATERVSEMLEKDTG
Seq: MAREHIKALLEILGTLVLLFSAGAV
Seq: MTDLGIDPHVIVCPIGINRIVIGGP

4. 最後に

  • 条件付き生成
  • 部分生成 (inpainting)
  • ガイダンス

なども試してみたいと思っています。