扩散模型完全指南:从数学原理到工程实践
| 章节 | 核心问题 | 关键点 |
|---|---|---|
| 1. 基础篇 | 为什么要用扩散模型? | 理解"渐进变换"的核心思想 |
| 2. 数学篇 | 加噪/去噪的数学本质? | 掌握条件分布的高斯近似证明 |
| 3. 算法篇 | DDPM和DDIM如何工作? | 实现两种采样器 + 方差缩减技巧 |
| 4. 理论篇 | 为什么扩散模型有效? | 理解投影解释与流匹配统一框架 |
| 5. 实践篇 | 如何训练实用的扩散模型? | 掌握噪声调度、参数化、潜空间等技巧 |
| 6. 进阶篇 | 扩散模型的理论边界? | 了解误差分析、泛化能力与前沿方向 |
1. 基础篇:生成建模的核心问题
1.1 问题形式化
生成建模的目标:给定未知分布 $p^*(x)$ 的 i.i.d. 样本 $\{x_0^{(1)}, \ldots, x_0^{(n)}\}$,构造一个采样器,能够生成近似相同分布的新样本。
$$ \text{Input: } \{x_0^{(i)}\}_{i=1}^n \sim p^*(x) \quad \xrightarrow{\text{训练}} \quad \text{Output: } \text{Sample}() \to \hat{x}_0 \sim p^*(x) $$1.2 扩散模型的核心洞察
关键思想:将"从噪声一步生成数据"的困难问题,分解为"多步小变换"的简单问题序列。
┌─────────────────────────────────────────────────┐
│ 传统生成模型: │
│ N(0, I) ──[复杂非线性映射]──► p*(x) ❌ 难学 │
│ │
│ 扩散模型: │
│ N(0, I) ──[小步1]─[小步2]─...─[小步T]──► p*(x) ✅ 易学 │
└─────────────────────────────────────────────────┘
1.3 高斯扩散前向过程(直观版)
对于目标分布 $p^*$ 中的样本 $x_0 \in \mathbb{R}^d$,定义高斯扩散前向过程:
$$ x_{t+\Delta t} := x_t + \eta_t, \quad \eta_t \sim \mathcal{N}(0, \sigma_q^2 \Delta t) \tag{1} $$其中:
- $\Delta t = 1/T$ 是离散化步长,$T$ 是总步数
- $\sigma_q^2$ 是期望的终端方差
关键性质:可直接采样任意时刻
$$ x_t \sim \mathcal{N}(x_0, \sigma_t^2), \quad \sigma_t := \sigma_q \sqrt{t} \tag{2} $$# 代码:前向加噪过程(可直接采样任意时刻)
def forward_diffusion(x0, t, sigma_q):
"""
Args:
x0: 原始数据 [B, d]
t: 时间步 [B] 或标量,范围 [0, 1]
sigma_q: 终端噪声标准差
Returns:
x_t: 加噪后的数据 [B, d]
noise: 添加的噪声 [B, d]
"""
sigma_t = sigma_q * torch.sqrt(t) # σ_t = σ_q * √t
noise = torch.randn_like(x0)
return x0 + sigma_t * noise, noise
1.4 抽象视角:扩散的通用框架
扩散模型可抽象为以下通用流程 [Nakkiran et al., 2024, Sec 1.2]:
1. 选择目标分布: p₀ = p* (如图像分布)
2. 选择基分布: pₜ = q (如标准高斯,易于采样)
3. 构造插值序列: p₀, p₁, ..., pₜ,相邻分布"足够接近"
4. 学习反向采样器: Fₜ: pₜ → pₜ₋₁
5. 生成: 从 pₜ 采样 → 迭代应用 Fₜ → 得到 p₀ 的样本
定义 1(反向采样器):给定边际分布序列 $\{p_t\}$,步 $t$ 的反向采样器 $F_t$ 是(可能随机的)函数,满足:
$$ > \{F_t(z): z \sim p_t\} \equiv p_{t-1} > $$即:若输入服从 $p_t$,则输出边际分布恰好为 $p_{t-1}$。
2. 数学篇:高斯扩散的严格推导
2.1 反向过程的核心事实
$$ p(x_{t-\Delta t} | x_t = z) \approx \mathcal{N}(x_{t-\Delta t}; \mu_z, \sigma_q^2 \Delta t) \tag{3} $$Fact 1(扩散反向过程) [Nakkiran et al., 2024, Claim 1]:
对于足够小的 $\sigma$,高斯扩散的条件分布 $p(x_{t-\Delta t} | x_t = z)$ 近似为高斯分布:
其中均值 $\mu_z$ 仅依赖于 $z$,具体为:
$$ \mu_z = \mathbb{E}[x_{t-\Delta t} | x_t = z] = z + \sigma_q^2 \Delta t \cdot \nabla \log p_t(z) \tag{4} $$直观理解:为什么小噪声下近似高斯?
图:不同噪声水平下的反向条件分布
p(x_{t-1}) p(x_{t-1} | x_t=z)
│ │
┌────┴────┐ ┌──────┴──────┐
│ ╱╲ │ σ大 │ ~~~~~~ │ ← 非高斯,多峰
│ ╱ ╲ │ ───► │ ~ ~ │
└─┴────┴─┘ └────────────┘
┌────┴────┐ ┌──────┴──────┐
│ ╱╲ │ σ小 │ /‾‾\ │ ← 近似高斯 ✓
│ ╱ ╲ │ ───► │ / \ │
└─┴────┴─┘ └───┴────┴───┘
μ_z → 均值
启发式推导(贝叶斯 + Taylor展开)
从贝叶斯规则开始:
p(x_{t-Δt}|x_t) = p(x_t|x_{t-Δt})·p_{t-Δt}(x_{t-Δt}) / p_t(x_t)
取对数,忽略仅含 x_t 的常数项:
log p(x_{t-Δt}|x_t)
= log p(x_t|x_{t-Δt}) + log p_t(x_{t-Δt}) + O(Δt)
= -‖x_{t-Δt} - x_t‖²/(2σ_q²Δt) + log p_t(x_{t-Δt}) + O(Δt)
对 log p_t(x_{t-Δt}) 在 x_t 处 Taylor 展开:
≈ -‖x_{t-Δt} - x_t‖²/(2σ_q²Δt)
+ log p_t(x_t) + ⟨∇log p_t(x_t), x_{t-Δt}-x_t⟩
配方完成平方:
= -‖x_{t-Δt} - [x_t + σ_q²Δt·∇log p_t(x_t)]‖²/(2σ_q²Δt) + C
⇒ 近似为均值 μ = x_t + σ_q²Δt·∇log p_t(x_t) 的高斯分布 ✓
技术细节:Lemma 1 [Appendix B.1] 证明,当 $\sigma^2$ 为步方差时,每步高斯近似的 KL 误差为 $O(\sigma^4)$,足够快以保证累积误差可控。
2.2 学习均值:回归问题简化
均值函数可通过标准回归学习:
$$ \mu_{t-\Delta t}(z) = \mathbb{E}[x_{t-\Delta t} | x_t = z] = \arg\min_f \mathbb{E}\|f(x_t) - x_{t-\Delta t}\|_2^2 \tag{5} $$训练流程(简化版)
# 代码:扩散训练损失(基础版)
def diffusion_loss(model, x0, schedule, batch_size):
"""
Args:
model: 神经网络 f_θ(x_t, t) → μ_t(x_t)
schedule: 噪声调度器,提供 σ_t
x0: 原始数据批次 [B, d]
"""
# 1. 随机采样时间和噪声
t = torch.rand(batch_size, device=x0.device) # t ~ Uniform[0,1]
sigma_t = schedule(t) # σ_t
eps = torch.randn_like(x0) # ε ~ N(0, I)
# 2. 前向加噪:利用可直接采样性质 (2)
x_t = x0 + sigma_t.unsqueeze(1) * eps # [B, d]
# 3. 预测去噪后的值(即条件期望)
mu_pred = model(x_t, t) # f_θ(x_t, t) ≈ E[x_{t-Δt}|x_t]
# 4. 回归损失(L2)
# 注意:实际训练中通常预测的是 x_{t+Δt} 给定 x_t
loss = torch.mean((mu_pred - x0)**2)
return loss
2.3 离散化细节:方差缩放
为确保终端分布 $p_T$ 的方差与步数 $T$ 无关,需缩放每步噪声方差:
$$ \sigma = \sigma_q \sqrt{\Delta t}, \quad \Delta t = 1/T \tag{6} $$此时前向过程更新为:
$$ x_{t+\Delta t} := x_t + \eta_t, \quad \eta_t \sim \mathcal{N}(0, \sigma_q^2 \Delta t) \tag{7} $$且边际分布满足:
$$ x_t \sim \mathcal{N}(x_0, \sigma_t^2), \quad \sigma_t = \sigma_q \sqrt{t} \tag{8} $$3. 算法篇:从DDPM到DDIM
3.1 DDPM:随机反向采样器
采样算法 [Algorithm 1, Nakkiran et al.]
输入:噪声样本 x₁ ~ N(0, σ_q²), 训练好的模型 f_θ ≈ μ_{t-Δt}
输出:生成样本 x̂₀
x ← Sample from N(0, σ_q²) # 从纯噪声开始
for t = 1, 1-Δt, 1-2Δt, ..., Δt:
η ~ N(0, σ_q²Δt) # 添加随机噪声
x ← f_θ(x, t) + η # 去噪 + 随机扰动
return x
# 代码:DDPM采样(随机版本)
@torch.no_grad()
def ddpm_sample(model, schedule, steps=1000, batch_size=1):
"""
DDPM随机采样器实现
"""
device = next(model.parameters()).device
sigma_q = schedule.sigma_max
# 1. 从纯噪声开始: x₁ ~ N(0, σ_q²)
x = torch.randn(batch_size, *model.input_shape, device=device) * sigma_q
# 2. 时间步调度
sigmas = schedule.get_sigmas(steps) # [steps+1]
# 3. 反向迭代
for i in range(steps):
t = (steps - i) / steps # 当前时间
sigma_t = sigmas[i]
sigma_next = sigmas[i+1]
# 预测条件期望: μ_{t-Δt}(x_t)
mu_pred = model(x, torch.full((batch_size,), t, device=device))
# 添加随机噪声(关键!保证采样多样性)
noise = torch.randn_like(x) * torch.sqrt(sigma_t**2 - sigma_next**2)
x = mu_pred + noise
return x
正确性证明要点
通过贝叶斯规则 + Taylor展开可证明 [Claim 1]:
$$ \log p(x_{t-\Delta t}|x_t) = -\frac{1}{2\sigma_q^2\Delta t}\|x_{t-\Delta t} - \mu\|^2 + C $$其中 $\mu = x_t + \sigma_q^2\Delta t \nabla \log p_t(x_t)$,即score function自然出现!
3.2 DDIM:确定性反向采样器
核心思想差异
| 特性 | DDPM(随机) | DDIM(确定性) |
|---|---|---|
| 输出分布 | 点态:$F_t(x_t) \sim p(x_{t-1} | x_t)$ |
| 采样结果 | 同一 $x_1$ → 不同 $x_0$(随机) | 同一 $x_1$ → 相同 $x_0$(确定) |
| 物理类比 | 布朗运动粒子 | 流体速度场传输 |
采样算法 [Algorithm 2, Nakkiran et al.]
$$ \hat{x}_{t-\Delta t} = x_t + \lambda \left( \mu_{t-\Delta t}(x_t) - x_t \right), \quad \lambda = \frac{\sigma_t}{\sigma_t - \sigma_{t-\Delta t}} \tag{9} $$# 代码:DDIM采样(确定性版本)
@torch.no_grad()
def ddim_sample(model, schedule, steps=50, batch_size=1):
"""
DDIM确定性采样器实现
"""
device = next(model.parameters()).device
sigma_q = schedule.sigma_max
# 1. 从纯噪声开始
x = torch.randn(batch_size, *model.input_shape, device=device) * sigma_q
sigmas = schedule.get_sigmas(steps)
for i in range(steps):
t = (steps - i) / steps
sigma_t = sigmas[i]
sigma_next = sigmas[i+1]
# 预测条件期望
mu_pred = model(x, torch.full((batch_size,), t, device=device))
# DDIM更新:确定性传输
lambda_factor = sigma_t / (sigma_t - sigma_next + 1e-8)
x = x + lambda_factor * (mu_pred - x)
return x
速度场解释(物理直觉)
定义速度场:
$$ v_t(x_t) = \frac{\lambda}{\Delta t} \left( \mathbb{E}[x_{t-\Delta t}|x_t] - x_t \right) \tag{10} $$则DDIM更新可写为:
$$ \hat{x}_{t-\Delta t} = x_t + v_t(x_t) \Delta t \tag{11} $$物理类比:气体粒子流动
t=1.0 t=0.5 t=0.0
┌─────┐ ┌─────┐ ┌─────┐
│ · │ ──► │ ↗↖ │ ──► │ ★ │
│ ·· │ │↗ ↖│ │ ★★ │
│·····│ │ →v→ │ │★★★★★│
└─────┘ └─────┘ └─────┘
高斯分布 速度场引导 目标分布
(噪声) (传输方向) (数据)
3.3 方差缩减技巧:预测 $x_0$ vs $\epsilon$
关键关系 [Claim 2, Nakkiran et al.]
$$ \mathbb{E}[x_{t-\Delta t} - x_t | x_t] = \frac{\Delta t}{t} \mathbb{E}[x_0 - x_t | x_t] \tag{12} $$直观解释:给定 $x_t$,最后一步噪声 $\eta_{t-\Delta t}$ 与之前所有噪声步骤"看起来一样"(对称性),因此可用平均噪声估计单步噪声,显著降低方差。
三种参数化对比
| 预测目标 | 损失函数 | 优点 | 适用场景 |
|---|---|---|---|
| $x_{t-1}$ | $\|f_\theta - x_{t-1}\|^2$ | 直接对应采样 | 理论分析 |
| $x_0$ ⭐ | $\|f_\theta - x_0\|^2$ | 方差小,训练稳定 | 实际使用 |
| $\epsilon$ | $\|f_\theta - \epsilon\|^2$ | 与噪声直接相关 | 高噪声阶段 |
| $v$ [Salimans & Ho] | $\|f_\theta - (\alpha_t\epsilon - \sigma_t x_0)\|^2$ | 高低噪声平衡 | 高级调优 |
# 代码:统一参数化框架
class DiffusionModel(nn.Module):
def __init__(self, prediction_type='epsilon'):
"""
prediction_type: 'x0', 'epsilon', 'v', or 'x_prev'
"""
super().__init__()
self.prediction_type = prediction_type
self.net = UNet(...) # 实际网络架构
def compute_loss(self, x0, t, schedule):
"""计算训练损失"""
sigma_t = schedule(t)
eps = torch.randn_like(x0)
# 前向加噪
x_t = x0 + sigma_t.unsqueeze(1) * eps
# 根据参数化类型确定目标
if self.prediction_type == 'x0':
target = x0
elif self.prediction_type == 'epsilon':
target = eps
elif self.prediction_type == 'v':
# v-prediction: v = α_t·ε - σ_t·x₀
alpha_t = torch.sqrt(1 - sigma_t**2)
target = alpha_t.unsqueeze(1) * eps - sigma_t.unsqueeze(1) * x0
elif self.prediction_type == 'x_prev':
# 预测 x_{t-Δt}
sigma_prev = schedule(t - 1/1000) # 简化
target = (sigma_prev/sigma_t) * x0 + ... # 略
# 网络预测
pred = self.net(x_t, t)
return torch.mean((pred - target)**2)
4. 理论篇:优化视角与流匹配
4.1 去噪作为近似投影(优化视角)[Yuan & Permenter]
核心观点
学习到的去噪器 $\epsilon_\theta(x, \sigma)$ 可解释为数据流形 $\mathcal{K}$ 的近似投影梯度
📐 距离函数与投影
- 距离函数:$\text{dist}_\mathcal{K}(x) = \min_{x_0 \in \mathcal{K}} \|x - x_0\|$
- 投影:$\text{proj}_\mathcal{K}(x) = \{x_0 \in \mathcal{K}: \|x-x_0\| = \text{dist}_\mathcal{K}(x)\}$
关键定理
$$ \frac{1}{2}\nabla_x \text{dist}_\mathcal{K}^2(x, \sigma) = \sigma \cdot \epsilon^*(x, \sigma) \tag{13} $$其中 $\text{dist}_\mathcal{K}^2(x, \sigma)$ 是平滑后的平方距离函数。
几何解释:
数据流形 𝒦
╱╲
╱ ╲
╱ ★ ╲ ← 投影点 proj_𝒦(x)
╱ ╱││╲ ╲
x ─●──┼┼─────► ∇dist²(x,σ) 方向
╲ ││╱
╲╱
去噪器输出:x - σ·ε_θ(x,σ) ≈ proj_𝒦(x)
相对误差模型(实践指导)
假设当 $\frac{1}{\nu}\text{dist}_\mathcal{K}(x) \leq \sqrt{n}\sigma \leq \nu \cdot \text{dist}_\mathcal{K}(x)$ 时:
$$ \|x - \sigma\epsilon_\theta(x,\sigma) - \text{proj}_\mathcal{K}(x)\| \leq \eta \cdot \text{dist}_\mathcal{K}(x) \tag{14} $$工程意义:
- $\eta$ 越小,去噪器越接近理想投影
- 可通过验证集估计 $\eta$,指导模型选择
- 高维数据中 $\eta$ 通常随维度增长,需更大模型
4.2 Flow Matching:扩散的推广
什么是流(Flow)?
流 = 时间索引的向量场集合 $\{v_t\}_{t \in [0,1]}$,定义粒子轨迹:
$$ \frac{dx_t}{dt} = -v_t(x_t), \quad x_1 \sim q \to x_0 \sim p \tag{15} $$点wise流 → 边际流
- 点wise流:对任意点对 $(x_1, x_0)$,定义连接它们的流 $v^{[x_1,x_0]}$
- 边际流:通过加权平均组合所有点wise流:
线性流(最简单选择)[Liu et al., 2022]
$$ v_t^{[x_1,x_0]}(x_t) = x_0 - x_1 \quad \Rightarrow \quad x_t = t x_1 + (1-t) x_0 \tag{17} $$# 代码:Flow Matching训练(线性流)
def flow_matching_loss(model, x0, x1, t):
"""
Args:
x0: 目标数据 ~ p
x1: 源噪声 ~ q (如高斯)
t: 时间 ~ Uniform[0,1]
"""
# 线性插值轨迹
x_t = t * x1 + (1 - t) * x0
# 线性流的速度场(常数)
v_target = x0 - x1
# 回归损失
v_pred = model(x_t, t)
return torch.mean((v_pred - v_target)**2)
DDIM作为Flow Matching的特例
Claim 4 [Nakkiran et al., Appendix B.7]: DDIM等价于使用"扩散耦合"的线性流匹配,仅时间参数化不同($\sqrt{t}$ vs $t$)
┌────────────────────────────────────┐
│ DDIM轨迹: x_t = x₀ + (x₁-x₀)√t │
│ 线性流轨迹: x_t = x₀ + (x₁-x₀)·t │
│ │
│ 关系: DDIM在时间t处 = 线性流在时间√t处 │
└────────────────────────────────────┘
4.3 概率流ODE(连续时间视角)
从离散到连续
当 $\Delta t \to 0$,DDIM收敛到概率流ODE [Song et al., 2020]:
$$ \frac{dx_t}{dt} = -\frac{1}{2t} \mathbb{E}[x_0 - x_t | x_t] \tag{18} $$SDE ↔ ODE 对应关系
| 形式 | 方程 | 特点 |
|---|---|---|
| 前向SDE | $dx = \sigma_q dw$ | 纯扩散,零漂移 |
| 反向SDE | $dx = -\sigma_q^2 \nabla \log p_t(x) dt + \sigma_q dw$ | 随机,需学习score |
| 概率流ODE | $dx = -\frac{1}{2}\sigma_q^2 \nabla \log p_t(x) dt$ | 确定性,相同边际分布 |
# 代码:ODE求解器采样(使用torchdiffeq)
from torchdiffeq import odeint
class ProbabilityFlowODE:
def __init__(self, score_model):
self.score_model = score_model # 学习 ∇log p_t(x)
def reverse_ode(self, t, x):
"""ODE的右端函数: dx/dt = f(x,t)"""
score = self.score_model(x, t) # ≈ ∇log p_t(x)
return -0.5 * score # 简化版本(σ_q=1)
def sample(self, x1, t_span):
"""从t=1积分到t=0"""
return odeint(self.reverse_ode, x1, t_span, method='dopri5')
5. 实践篇:工程实现与设计选择
5.1 噪声调度(Noise Schedule)
常见调度对比与实现
# 代码:多种噪声调度实现
class NoiseSchedule:
"""噪声调度基类"""
def __call__(self, t):
raise NotImplementedError
class ScheduleLogLinear(NoiseSchedule):
"""对数线性调度:σ ∈ [σ_min, σ_max] 对数均匀 [Karras et al., 2022]"""
def __init__(self, N, sigma_min=0.02, sigma_max=10.0):
self.sigmas = torch.logspace(
math.log10(sigma_min),
math.log10(sigma_max),
N
)
def __call__(self, t):
idx = (t * (len(self.sigmas)-1)).long().clamp(0, len(self.sigmas)-2)
return self.sigmas[idx]
class ScheduleDDPM(NoiseSchedule):
"""DDPM方差保持调度:Var(x_t) ≈ 常数 [Ho et al., 2020]"""
def __init__(self, N, beta_min=0.0001, beta_max=0.02):
betas = torch.linspace(beta_min, beta_max, N)
alphas = 1 - betas
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
def __call__(self, t):
idx = (t * (len(self.alphas_cumprod)-1)).long()
return torch.sqrt(1 - self.alphas_cumprod[idx])
class ScheduleKarras(NoiseSchedule):
"""Karras调度:σ(t) = t, 配合整体缩放 [Karras et al., 2022]"""
def __init__(self, N, sigma_min=0.002, sigma_max=80.0):
self.sigmas = ((sigma_max**(1/3) +
torch.linspace(0, 1, N) * (sigma_min**(1/3) - sigma_max**(1/3)))**3)
def __call__(self, t):
idx = (t * (len(self.sigmas)-1)).long()
return self.sigmas[idx]
调度选择建议
图像生成推荐:
├── 像素空间低分辨率:ScheduleDDPM(方差保持,训练稳定)
├── 潜空间/高分辨率:ScheduleLogLinear 或 ScheduleKarras(灵活,适合多尺度)
├── 快速采样:自定义稀疏调度(如 [1.0, 0.5, 0.2, 0.05, 0])
└── 理论研究:固定σ(简化分析)
调度可视化:
log(σ)
↑
2 ┤ ╱╲╱╲╱╲ ← LogLinear(对数均匀)
│ ╱ ╲
1 ┤ ╱ ╲╱╲ ← DDPM(前期缓,后期陡)
│ ╱
0 ┼╱─────────────→ t
0 0.5 1.0
5.2 潜空间扩散(Latent Diffusion)[Sander Dieleman]
为什么需要潜空间?
┌─────────────────────────────────────┐
│ 问题:高分辨率图像直接在像素空间扩散 │
│ • 计算成本高:256×256×3 = 196,608维 │
│ • 迭代次数多:每步需网络前向传播 │
│ • 细节建模难:纹理/高频信息难学习 │
└─────────────────────────────────────┘
解决方案:两阶段架构
┌─────────┐ ┌─────────┐ ┌─────────┐
│ 编码器 │────►│ 潜空间 │────►│ 扩散模型│
│ E: x→z │ │ z∈ℝ^d │ │ p(z) │
└─────────┘ └─────────┘ └─────────┘
│
▼
┌─────────┐
│ 解码器 │
│ D: z→x │
└─────────┘
Stable Diffusion架构示意
# 概念代码:潜空间扩散训练流程
class LatentDiffusion:
def __init__(self, autoencoder, unet, schedule):
self.autoencoder = autoencoder # 预训练,冻结
self.unet = unet # 可训练
self.schedule = schedule
def train_step(self, x0, text_condition=None):
# 1. 编码到潜空间(冻结)
with torch.no_grad():
z0 = self.autoencoder.encode(x0)
# 2. 潜空间扩散训练
t = torch.rand(z0.shape[0], device=z0.device)
sigma_t = self.schedule(t)
eps = torch.randn_like(z0)
z_t = z0 + sigma_t.unsqueeze(1) * eps
# 3. UNet预测(可加入text condition)
eps_pred = self.unet(z_t, t, context=text_condition)
return torch.mean((eps_pred - eps)**2)
@torch.no_grad()
def sample(self, condition=None, steps=50):
# 1. 潜空间采样
z_1 = torch.randn(1, *self.latent_shape)
z_0 = ddim_sample(self.unet, self.schedule, steps)(z_1, condition)
# 2. 解码回像素空间
x_0 = self.autoencoder.decode(z_0)
return x_0
5.3 完整实现示例:2D螺旋数据集
# 完整可运行示例:2D螺旋数据扩散
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# 1. 螺旋数据生成
def swiss_roll(n_samples=1000):
"""生成2D螺旋数据集"""
n = np.sqrt(np.random.rand(n_samples)) * 5 * np.pi
d = np.random.rand(n_samples) * 2 - 1
x = n * np.cos(n) + 0.5 * d * np.sin(n)
y = n * np.sin(n) + 0.5 * d * np.cos(n)
return torch.stack([torch.tensor(x), torch.tensor(y)], dim=1).float()
# 2. 时间嵌入 + MLP去噪器
class TimeEmbedding(nn.Module):
def __call__(self, sigma):
"""简单的2维时间嵌入"""
sigma = sigma.unsqueeze(1)
return torch.cat([
torch.sin(torch.log(sigma) / 2),
torch.cos(torch.log(sigma) / 2)
], dim=1)
class DenoiserMLP(nn.Module):
def __init__(self, dim=2, hidden_dims=[16, 128, 128, 16]):
super().__init__()
layers = []
dims = [dim + 2] + hidden_dims # +2 for time embedding
for i in range(len(dims)-1):
layers.extend([
nn.Linear(dims[i], dims[i+1]),
nn.GELU()
])
layers.append(nn.Linear(hidden_dims[-1], dim))
self.net = nn.Sequential(*layers)
self.time_emb = TimeEmbedding()
def forward(self, x, sigma):
t_emb = self.time_emb(sigma) # [B, 2]
inp = torch.cat([x, t_emb], dim=1) # [B, dim+2]
return self.net(inp)
# 3. 训练循环
def train_diffusion(model, data_loader, schedule, epochs=5000, lr=1e-3):
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
for x0 in data_loader:
optimizer.zero_grad()
# 采样训练样本
t = torch.rand(x0.shape[0])
sigma_t = schedule(t)
eps = torch.randn_like(x0)
x_t = x0 + sigma_t.unsqueeze(1) * eps
# 预测噪声(这里用epsilon参数化)
eps_pred = model(x_t, sigma_t)
loss = torch.mean((eps_pred - eps)**2)
loss.backward()
optimizer.step()
if epoch % 500 == 0:
print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
# 4. 采样与可视化
@torch.no_grad()
def sample_and_plot(model, schedule, n_samples=200, steps=20):
model.eval()
sigmas = schedule.get_sigmas(steps)
# 从噪声开始
x = torch.randn(n_samples, 2) * sigmas[0]
# DDIM采样轨迹记录
trajectory = [x.clone()]
for i in range(steps):
t = (steps - i) / steps
sigma_t = sigmas[i]
sigma_next = sigmas[i+1]
# 预测x0(epsilon参数化转x0)
eps_pred = model(x, torch.full((n_samples,), t))
x0_pred = (x - sigma_t * eps_pred) / torch.sqrt(1 - sigma_t**2 + 1e-8)
# DDIM更新
direction = (x - sigma_t * x0_pred) / torch.sqrt(1 - sigma_t**2 + 1e-8)
x = sigma_next * x0_pred + torch.sqrt(1 - sigma_next**2 + 1e-8) * direction
trajectory.append(x.clone())
# 可视化
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for idx, t_idx in enumerate([0, 7, 14, 20]):
ax = axes[idx]
ax.scatter(trajectory[t_idx][:,0], trajectory[t_idx][:,1],
s=2, alpha=0.5, c='blue')
ax.set_title(f'Step {t_idx}/{steps} (t={1-t_idx/steps:.2f})')
ax.set_xlim(-10, 10)
ax.set_ylim(-10, 10)
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
训练效果可视化
采样轨迹演变(20步DDIM):
Step 0 (t=1.0) Step 7 (t=0.65)
┌─────────┐ ┌─────────┐
│ · · · │ │ ╱╲ │
│ · · · · │ ──► │ ╱ ╲ │
│ · · · │ │╱ ╲╱ │
└─────────┘ └─────────┘
高斯噪声 开始形成螺旋
Step 14 (t=0.3) Step 20 (t=0.0)
┌─────────┐ ┌─────────┐
│ ╱╲╱╲ │ │ ★★★★★ │
│ ╱ ╲ │ ──► │★ ★│
│╱ ╲│ │ ★★★★★ │
└─────────┘ └─────────┘
螺旋轮廓清晰 收敛到数据分布
5.4 常见陷阱与调试技巧
⚠️ 训练不稳定?
# 技巧1:梯度裁剪(防止梯度爆炸)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 技巧2:学习率调度(余弦退火)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs, eta_min=1e-5
)
# 技巧3:损失加权(不同噪声水平不同权重)
def weighted_loss(eps_pred, eps, sigma_t, weighting='uniform'):
if weighting == 'uniform':
weights = torch.ones_like(sigma_t)
elif weighting == 'snr': # Signal-to-Noise Ratio weighting
weights = 1 / (sigma_t**2 + 1e-6)
elif weighting == 'min_snr': # Min-SNR weighting [Hang et al.]
snr = 1 / sigma_t**2
weights = torch.minimum(snr, torch.ones_like(snr) * 5)
return torch.mean(weights * (eps_pred - eps)**2)
⚠️ 采样质量差?
# 技巧1:增加采样步数(但计算成本↑)
# 技巧2:使用更高级的ODE求解器(DPM-Solver, Heun, etc.)
# 技巧3:添加引导(classifier-free guidance)
def classifier_free_guidance(model, x_t, t, cond, uncond, guidance_scale=7.5):
"""CFG: 增强条件生成的质量"""
# 条件预测
eps_cond = model(x_t, t, context=cond)
# 无条件预测(训练时随机drop condition)
eps_uncond = model(x_t, t, context=uncond)
# 引导组合
return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
# 技巧4:动态阈值(防止采样发散)
def dynamic_thresholding(x, percentile=99.5):
"""防止采样值过大导致图像异常"""
s = torch.quantile(x.abs().reshape(x.shape[0], -1), percentile/100, dim=1)
s = torch.clamp(s, min=1.0)
return x.clamp(-s.unsqueeze(1), s.unsqueeze(1))
⚠️ 过拟合/记忆训练数据?
🔍 诊断方法:
1. 检查生成样本与训练样本的L2距离(近邻搜索)
2. 可视化训练/验证损失曲线(验证损失上升→过拟合)
3. 使用FID/IS等指标评估泛化能力
4. 人工检查生成样本多样性
✅ 缓解策略:
• 增加数据增强(随机裁剪、翻转、颜色抖动)
• 使用更大的模型(隐式正则化)
• 早停(early stopping)+ 验证集监控
• 添加dropout/weight decay
• 增加训练数据量(最有效)
📊 记忆化现象可视化 [Nakkiran et al., Fig 6]:
N=10样本:轨迹坍缩到训练点 → 记忆化
N=40样本:轨迹学习螺旋流形 → 泛化
6. 进阶篇:学习理论与前沿方向
6.1 误差来源分析 [Nakkiran et al., Sec 5]
扩散模型的生成误差可分解为:
总误差 = 训练误差 + 采样误差 + 模型偏差
1️⃣ 训练误差(统计误差)
• 有限样本:回归目标估计不准
• 有限模型:函数类表达能力不足
• 优化误差:未收敛到全局最优
2️⃣ 采样误差(数值误差)
• 离散化:Δt不够小,高斯近似失效
• ODE/SDE求解器:数值积分误差
• 截断:早期停止采样
3️⃣ 模型偏差(结构性误差)
• 高斯假设:真实反向条件分布非高斯
• 马尔可夫假设:忽略长程依赖
• 参数化选择:不同预测目标影响收敛
6.2 泛化与记忆化的平衡
关键点:完美拟合训练数据 ≠ 好的生成模型
# 概念验证:记忆化检测
def check_memorization(generated_samples, train_samples, threshold=0.1):
"""检测生成样本是否记忆训练数据"""
from scipy.spatial.distance import cdist
# 计算生成样本与训练样本的最小L2距离
distances = cdist(
generated_samples.reshape(len(generated_samples), -1),
train_samples.reshape(len(train_samples), -1),
metric='euclidean'
)
min_distances = distances.min(axis=1)
# 统计"太接近"训练样本的比例
memorized_ratio = (min_distances < threshold).mean()
return memorized_ratio, min_distances
# 实践建议:
# • 记忆化比例 < 1% 通常可接受
# • 高记忆化 → 增加数据/正则化/早停
6.3 前沿方向速览
| 方向 | 核心思想 | 代表工作 | 实用价值 |
|---|---|---|---|
| 一致性模型 | 学习一步映射,蒸馏多步扩散 | [Song et al., 2023] | ⭐⭐⭐⭐⭐ 快速采样 |
| 整流流 | 学习直线轨迹,减少采样步数 | [Liu et al., 2022] | ⭐⭐⭐⭐ 训练稳定 |
| 潜空间优化 | 更高效的潜空间编码/解码 | [Esser et al., 2024] | ⭐⭐⭐⭐ 高分辨率 |
| 理论分析 | 泛化边界、收敛速率 | [Chen et al., 2022-2024] | ⭐⭐⭐ 指导设计 |
| 离散扩散 | 文本/图/离散结构的扩散 | [Austin et al., 2021] | ⭐⭐⭐⭐ 多模态 |
总结
扩散模型
│
┌─────────┬───────┴───────┬─────────┐
▼ ▼ ▼ ▼
数学基础 核心算法 理论解释 工程实践
│ │ │ │
▼ ▼ ▼ ▼
• 高斯分布 • DDPM • 投影解释 • 噪声调度
• 条件期望 • DDIM • 流匹配 • 潜空间
• 贝叶斯规则• Flow Matching• SDE/ODE • 参数化选择
• Taylor展开• 概率流ODE • 相对误差 • 采样加速