找回密码
 立即注册
首页 业界区 业界 stable diffusion论文解读

stable diffusion论文解读

缑莺韵 2025-6-3 00:20:14
High-Resolution Image Synthesis with Latent Diffusion Models

论文背景

LDM是Stable Diffusion模型的奠基性论文
于2022年6月在CVPR上发表
1.png

传统生成模型具有局限性:

  • 扩散模型(DM)通过逐步去噪生成图像,质量优于GAN,但直接在像素空间操作导致高计算开销。
  • 随着分辨率提升,扩散模型的优化和推理成本呈指数级增长,限制了实际应用
如DDPM生成的图像分辨率普遍不超过256×256,而LDM生成的图像分辨率可以超过1024×1024.
而LDM通过将扩散过程迁移至潜在空间,解决了传统模型的计算瓶颈,同时保持生成质量与灵活性
论文框架方法

论文中框架示意图如图所示:
2.png

在训练阶段:

  • 预训练自动编码器(AE)和条件生成编码器(如clip)
  • 输入图片x,经过自动编码器压缩到隐空间ε(x)=z
  • 随机采样时间步T,对Z进行加噪到\(Z_{T}\)
  • 对右边框里的条件进行条件编码\(\tau_\theta(y)\)和\(Z_{T}\)一起输入UNet网络中
  • 进行交叉注意力计算,其中\(Z_{T}\)作为Q向量,\(\tau_\theta(y)\)作为K,V向量计算注意力,这样做是让图像的每个位置根据文本的语义来决定关注哪些部分
  • 最后Unet输出两个向量,一个是无条件预测噪声,一个是文本预测噪声。
无条件预测噪声输入是空字符串


  • 使用CFG计算最终预测噪声\(\epsilon_{\text{guided}}(z_t, t, \tau_\theta(y)) = \epsilon_\theta(z_t, t,\tau_\theta(y)) + s \cdot (\epsilon_\theta(z_t, t, \tau_\theta(y)) - \epsilon_\theta(z_t, t, \varnothing))\)
  • 使用损失函数进行反向传播计算
    3.png

在生成阶段:

  • 以随机噪声\(Z_{T}\)作为起点
  • 输入文本作为条件,编码后一起进入Unet进行交叉注意力计算
  • 输出预测噪声\(\epsilon_{\text{guided}}(z_t, t, \tau_\theta(y))\)
  • 使用调度器进行逐步去噪计算(如DDPM,DDIM)成为\(Z_{T-1}\)
  • 重复以上过程,直到Z
  • 通过自动编码器的解码器部分把Z迁移到像素空间,D(z),即生成图像
交叉注意力机制中的维度变换
图像编码后变成 C=4, H'=64, W'=64,展平后作为Q(\(z_t \Rightarrow Q \in \mathbb{R}^{(H'W') \times d}\)),文本通过编码器的编码表示为\(c = [t_1, t_2, ..., t_L] \Rightarrow \text{Embedding} \in \mathbb{R}^{L \times d}\),K和V表示为\(K, V \in \mathbb{R}^{L \times d}\),计算注意力权重\(A = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d}} \right) \in \mathbb{R}^{(H'W') \times L}\),输出为\(\text{Attention}(Q, K, V) = A \cdot V \in \mathbb{R}^{(H'W') \times d}\)
  1.         # 潜在空间输入(prepare_latents生成)
  2. latents.shape = (batch_size * num_images_per_prompt, 4, H//8, W//8)
  3. # 文本嵌入处理(encode_prompt输出)
  4. prompt_embeds.shape = (batch_size, max_sequence_length, embedding_dim)
  5. # IP适配器图像嵌入处理
  6. image_embeds[0].shape = (batch_size * num_images_per_prompt, num_images, emb_dim)
  7. # UNet输入/输出维度
  8. latent_model_input.shape = [batch*2, 4, H//8, W//8]  # 当启用CFG时
  9. noise_pred.shape = [batch*2, 4, H//8, W//8]          # UNet输出噪声预测
  10. 假设参数设置
  11. prompt = "一只坐在月球上的猫"
  12. height = 512
  13. width = 512
  14. num_images_per_prompt = 1
  15. guidance_scale = 7.5
  16. batch_size = 1  # 根据prompt长度自动确定
  17. # 关键计算步骤演示
  18. # ---------------------------
  19. # 步骤1:潜在空间(latents)维度计算
  20. latents_shape = (
  21.     batch_size * num_images_per_prompt,  # 1*1=1
  22.     4,  # UNet输入通道数
  23.     height // 8,  # 512/8=64
  24.     width // 8    # 512/8=64
  25. )
  26. print(f"潜在空间维度: {latents_shape}")  # -> (1, 4, 64, 64)
  27. # 步骤2:文本编码维度(假设使用CLIP模型)
  28. prompt_embeds_shape = (
  29.     batch_size,
  30.     77,  # CLIP最大序列长度
  31.     768  # CLIP文本编码维度
  32. )
  33. print(f"文本嵌入维度: {prompt_embeds_shape}")  # -> (1, 77, 768)
  34. # 步骤3:CFG处理后的嵌入
  35. if guidance_scale > 1:
  36.     prompt_embeds = torch.cat([negative_embeds, positive_embeds])
  37.     print(f"CFG嵌入维度: {prompt_embeds.shape}")  # -> (2, 77, 768)
  38. # 步骤4:UNet输入维度(假设启用CFG)
  39. latent_model_input = torch.cat([latents] * 2)
  40. print(f"UNet输入维度: {latent_model_input.shape}")  # -> (2, 4, 64, 64)
  41. # 步骤5:噪声预测输出
  42. noise_pred = unet(latent_model_input, ...)[0]
  43. print(f"噪声预测维度: {noise_pred.shape}")  # -> (2, 4, 64, 64)
  44. # 步骤6:CFG调整后的噪声
  45. noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
  46. noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
  47. print(f"调整后噪声维度: {noise_pred.shape}")  # -> (1, 4, 64, 64)
  48. # 最终输出图像
  49. image = vae.decode(latents / vae.config.scaling_factor)[0]
  50. print(f"输出图像维度: {image.shape}")  # -> (1, 3, 512, 512)
  51.     def cross_attention(query, key, value):
  52.     # 输入维度说明
  53.     # query: 来自潜在噪声 [batch=2, 4*64*64=16384] → 投影为 [2, 16384, 768]
  54.     # key/value: 来自文本嵌入 [2, 77, 768]
  55.    
  56.     # 步骤1:计算注意力分数
  57.     attention_scores = torch.matmul(
  58.         query,  # [2, 16384, 768]
  59.         key.transpose(-1, -2)  # [2, 768, 77] → 转置后维度
  60.     )  # 矩阵乘法结果 → [2, 16384, 77]
  61.    
  62.     # 步骤2:计算注意力权重
  63.     attention_probs = torch.softmax(
  64.         attention_scores,  # [2, 16384, 77]
  65.         dim=-1  # 对最后一个维度(文本标记维度)做归一化
  66.     )  # 保持维度 [2, 16384, 77]
  67.    
  68.     # 步骤3:应用注意力到value
  69.     output = torch.matmul(
  70.         attention_probs,  # [2, 16384, 77]
  71.         value  # [2, 77, 768]
  72.     )  # 结果维度 → [2, 16384, 768]
  73.    
  74.     # 步骤4:重塑为潜在空间维度
  75.     output = output.view(2, 4, 64, 64, 768)  # 恢复空间结构
  76.     output = output.permute(0, 4, 1, 2, 3)  # [2, 768, 4, 64, 64]
  77.     output = self.to_out(output)  # 通过最后的线性层投影回4通道
  78.     return output  # [2, 4, 64, 64]
复制代码
数据集以及指标介绍

数据集介绍

4.png

CelebA-HQ 256 × 256数据集,是一个大规模的人脸属性数据集,拥有超过200K张名人图片,每张图片都有40个属性注释(如身份,年龄、表情、发型等)。
5.png

从Flickr网站爬取的人脸数据集集合,涵盖多样化的年龄、种族、表情、配饰(如眼镜、帽子)等属性
6.png

这两个数据集都是LSUN(大规模场景理解)数据集的子集,两个数据集分别表示教堂和卧室场景的数据,包含教堂建筑的不同视角、结构和环境条件,覆盖多样化的卧室场景,包括不同装修风格、家具布局和光照条件
指标介绍

IS分数介绍

Inception Score 的定义为:
\(IS(G) = \exp \left( \mathbb{E}_{x \sim p_g} \left[ D_{KL} ( p(y|x) \| p(y) ) \right] \right)\)
x~pg:生成图像样本来自生成模型的分布 。
p(y|x):通过预训练分类器(如Inception v3)对生成图像的类别预测概率分布。
p(y):预测类别的边缘分布。类别可以是猫,狗,猪等诸如此类的动物。
其中如果生成图像明确、质量高,则p(y|x)的熵就会比较低,如果生成图像比较多样,则p(y)的熵就会较高,体现在公式中则IS分数会较高。
使用 class-conditional 采样:遍历多个类别生成图像;用预训练的 Inception v3 网络对生成图像进行分类;计算 p(y∣x)p(y∣x)(每张图的预测概率),p(y)p(y)(所有图的平均预测分布);使用 KL 散度公式计算 IS 分数;通常生成 50,000 张图像用于评估。
FID分数介绍

主要是计算生成图像分布和真实图像分布在特征空间中的距离
公式\(\text{FID} = \| \mu_r - \mu_g \|_2^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2 (\Sigma_r \Sigma_g)^{\frac{1}{2}})\)
\(\mu_r,\Sigma_r\):真实图像分布的均值和协方差矩阵。
\(\mu_g,\Sigma_g\):生成图像分布的均值和协方差矩阵。
\(\| \mu_r - \mu_g \|_2^2\):欧几里得距离的平方。
\(\text{Tr}\):矩阵的迹。
\((\Sigma_r \Sigma_g)^{\frac{1}{2}}\):协方差矩阵的乘积的平方根。
两个分布的均值和协方差越低,FID越低,生成图像质量越接近生成的图像
prec和recall

这里的指标和一般理解的不一样。
会先用Inception网络分别提取真实图像和生成图像的特征点
用集合的角度解释:
Precision ≈ 生成图像中,有多少落在真实图像分布的“支持区域”里(真实性)
Recall    ≈ 真实图像中,有多少被生成图像的“支持区域”覆盖(多样性)
注:支持区域的计算
基于 kNN — 提取 Inception 特征后,设置一个半径 ε
看有多少点落入对方的球内;
也有基于球体体积估计的方法。
实验分析

7.png

研究不同下采样因子f对生成图像质量和训练效率的影响
下采样因子:指的是自动编码器中的参数。
可以看到下采样因子为4或8时表现最好。因为如果因子过小,会导致维度高,计算缓慢,因子过大,会损失很多信息,导致最后生成图像生成质量较差
后续的实验将基于此展开
8.png

在这个实验里可以看到,LDM在CelebA-HQ中取得了最优的FID分数,在其他数据集上的表现也是中规中矩。
9.png

实验使用的是ImageNet数据集
这个实验里展示了LDM在类别生成任务中的表现,可以看到使用cfg引导的LDM展现出了非常优秀的性能,在FID和IS分数上表现优异,虽然recall略低,但是使用的参数量也大幅减少了。

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册