抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

DiT

  • DiT的主要结构

    DiT是从LDM改进而来的,将LDM中的U-Net替换为Transformer,VAE仍然使用Stable Diffusion的。

    最后一个DiT block之后,将tokens decode为output noise prediction和covariance prediction。

  • DiT的Patchify过程

    与ViT的一样,使用卷积将压缩后的z切分成一个一个的token,并嵌入位置信息。

  • DiT的不同条件注入机制

    • In-context conditioning:将去噪时间步和条件c作为额外两个token附在输入序列前;最后一个block之后再将它们移除。
    • Cross-attnetion:将去噪时间步和条件c连接成一个长度为2的序列,在原始transformer block中插入cross attention并注入条件。
    • adaLN:将原始layerNorm替换为adaLN,即从t和c中回归得到scale和shift。
    • adaLN-Zero:将每个ResBlock初始化为identity function。同时除了回归scale和shift外,同时在每个Residual connection前回归dimension-wise scale的参数。
  • AdaLN的实现

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    class AdaLayerNorm(nn.Module):

    def __init__(self, embedding_dim: int, num_embeddings: int):
    super().__init__()
    self.emb = nn.Embedding(num_embeddings, embedding_dim)
    self.silu = nn.SiLU()
    self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
    self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
    emb = self.linear(self.silu(self.emb(timestep)))
    scale, shift = torch.chunk(emb, 2)
    x = self.norm(x) * (1 + scale) + shift # 用 timestep 进行 affine
    return x

SiT