扩散模型完全指南:从数学原理到工程实践


章节 核心问题 关键点
1. 基础篇 为什么要用扩散模型? 理解"渐进变换"的核心思想
2. 数学篇 加噪/去噪的数学本质? 掌握条件分布的高斯近似证明
3. 算法篇 DDPM和DDIM如何工作? 实现两种采样器 + 方差缩减技巧
4. 理论篇 为什么扩散模型有效? 理解投影解释与流匹配统一框架
5. 实践篇 如何训练实用的扩散模型? 掌握噪声调度、参数化、潜空间等技巧
6. 进阶篇 扩散模型的理论边界? 了解误差分析、泛化能力与前沿方向

1. 基础篇:生成建模的核心问题

1.1 问题形式化

生成建模的目标:给定未知分布 $p^*(x)$ 的 i.i.d. 样本 $\{x_0^{(1)}, \ldots, x_0^{(n)}\}$,构造一个采样器,能够生成近似相同分布的新样本。

$$ \text{Input: } \{x_0^{(i)}\}_{i=1}^n \sim p^*(x) \quad \xrightarrow{\text{训练}} \quad \text{Output: } \text{Sample}() \to \hat{x}_0 \sim p^*(x) $$

1.2 扩散模型的核心洞察

关键思想:将"从噪声一步生成数据"的困难问题,分解为"多步小变换"的简单问题序列。

┌─────────────────────────────────────────────────┐
│  传统生成模型:                                  │
│  N(0, I) ──[复杂非线性映射]──► p*(x)  ❌ 难学    │
│                                                 │
│  扩散模型:                                      │
│  N(0, I) ──[小步1]─[小步2]─...─[小步T]──► p*(x)  ✅ 易学  │
└─────────────────────────────────────────────────┘

1.3 高斯扩散前向过程(直观版)

对于目标分布 $p^*$ 中的样本 $x_0 \in \mathbb{R}^d$,定义高斯扩散前向过程

$$ x_{t+\Delta t} := x_t + \eta_t, \quad \eta_t \sim \mathcal{N}(0, \sigma_q^2 \Delta t) \tag{1} $$

其中:

关键性质:可直接采样任意时刻

$$ x_t \sim \mathcal{N}(x_0, \sigma_t^2), \quad \sigma_t := \sigma_q \sqrt{t} \tag{2} $$
# 代码:前向加噪过程(可直接采样任意时刻)
def forward_diffusion(x0, t, sigma_q):
    """
    Args:
        x0: 原始数据 [B, d]
        t: 时间步 [B] 或标量,范围 [0, 1]
        sigma_q: 终端噪声标准差
    Returns:
        x_t: 加噪后的数据 [B, d]
        noise: 添加的噪声 [B, d]
    """
    sigma_t = sigma_q * torch.sqrt(t)  # σ_t = σ_q * √t
    noise = torch.randn_like(x0)
    return x0 + sigma_t * noise, noise

1.4 抽象视角:扩散的通用框架

扩散模型可抽象为以下通用流程 [Nakkiran et al., 2024, Sec 1.2]:

1. 选择目标分布: p₀ = p* (如图像分布)
2. 选择基分布:   pₜ = q (如标准高斯,易于采样)
3. 构造插值序列: p₀, p₁, ..., pₜ,相邻分布"足够接近"
4. 学习反向采样器: Fₜ: pₜ → pₜ₋₁
5. 生成: 从 pₜ 采样 → 迭代应用 Fₜ → 得到 p₀ 的样本

定义 1(反向采样器):给定边际分布序列 $\{p_t\}$,步 $t$ 的反向采样器 $F_t$ 是(可能随机的)函数,满足:

$$ > \{F_t(z): z \sim p_t\} \equiv p_{t-1} > $$

即:若输入服从 $p_t$,则输出边际分布恰好为 $p_{t-1}$。


2. 数学篇:高斯扩散的严格推导

2.1 反向过程的核心事实

Fact 1(扩散反向过程) [Nakkiran et al., 2024, Claim 1]:
对于足够小的 $\sigma$,高斯扩散的条件分布 $p(x_{t-\Delta t} | x_t = z)$ 近似为高斯分布:

$$ p(x_{t-\Delta t} | x_t = z) \approx \mathcal{N}(x_{t-\Delta t}; \mu_z, \sigma_q^2 \Delta t) \tag{3} $$

其中均值 $\mu_z$ 仅依赖于 $z$,具体为:

$$ \mu_z = \mathbb{E}[x_{t-\Delta t} | x_t = z] = z + \sigma_q^2 \Delta t \cdot \nabla \log p_t(z) \tag{4} $$

直观理解:为什么小噪声下近似高斯?

图:不同噪声水平下的反向条件分布

    p(x_{t-1})          p(x_{t-1} | x_t=z)
         │                      │
    ┌────┴────┐          ┌──────┴──────┐
    │  ╱╲    │   σ大    │  ~~~~~~    │  ← 非高斯,多峰
    │ ╱  ╲   │  ───►   │ ~        ~ │
    └─┴────┴─┘          └────────────┘
    
    ┌────┴────┐          ┌──────┴──────┐
    │  ╱╲    │   σ小    │    /‾‾\    │  ← 近似高斯 ✓
    │ ╱  ╲   │  ───►   │   /      \  │
    └─┴────┴─┘          └───┴────┴───┘
                         μ_z → 均值

启发式推导(贝叶斯 + Taylor展开)

从贝叶斯规则开始:
p(x_{t-Δt}|x_t) = p(x_t|x_{t-Δt})·p_{t-Δt}(x_{t-Δt}) / p_t(x_t)

取对数,忽略仅含 x_t 的常数项:
log p(x_{t-Δt}|x_t) 
  = log p(x_t|x_{t-Δt}) + log p_t(x_{t-Δt}) + O(Δt)
  = -‖x_{t-Δt} - x_t‖²/(2σ_q²Δt) + log p_t(x_{t-Δt}) + O(Δt)

对 log p_t(x_{t-Δt}) 在 x_t 处 Taylor 展开:
  ≈ -‖x_{t-Δt} - x_t‖²/(2σ_q²Δt) 
     + log p_t(x_t) + ⟨∇log p_t(x_t), x_{t-Δt}-x_t⟩

配方完成平方:
  = -‖x_{t-Δt} - [x_t + σ_q²Δt·∇log p_t(x_t)]‖²/(2σ_q²Δt) + C

⇒ 近似为均值 μ = x_t + σ_q²Δt·∇log p_t(x_t) 的高斯分布 ✓

技术细节:Lemma 1 [Appendix B.1] 证明,当 $\sigma^2$ 为步方差时,每步高斯近似的 KL 误差为 $O(\sigma^4)$,足够快以保证累积误差可控。

2.2 学习均值:回归问题简化

均值函数可通过标准回归学习:

$$ \mu_{t-\Delta t}(z) = \mathbb{E}[x_{t-\Delta t} | x_t = z] = \arg\min_f \mathbb{E}\|f(x_t) - x_{t-\Delta t}\|_2^2 \tag{5} $$

训练流程(简化版)

# 代码:扩散训练损失(基础版)
def diffusion_loss(model, x0, schedule, batch_size):
    """
    Args:
        model: 神经网络 f_θ(x_t, t) → μ_t(x_t)
        schedule: 噪声调度器,提供 σ_t
        x0: 原始数据批次 [B, d]
    """
    # 1. 随机采样时间和噪声
    t = torch.rand(batch_size, device=x0.device)  # t ~ Uniform[0,1]
    sigma_t = schedule(t)  # σ_t
    eps = torch.randn_like(x0)  # ε ~ N(0, I)
    
    # 2. 前向加噪:利用可直接采样性质 (2)
    x_t = x0 + sigma_t.unsqueeze(1) * eps  # [B, d]
    
    # 3. 预测去噪后的值(即条件期望)
    mu_pred = model(x_t, t)  # f_θ(x_t, t) ≈ E[x_{t-Δt}|x_t]
    
    # 4. 回归损失(L2)
    # 注意:实际训练中通常预测的是 x_{t+Δt} 给定 x_t
    loss = torch.mean((mu_pred - x0)**2)
    return loss

2.3 离散化细节:方差缩放

为确保终端分布 $p_T$ 的方差与步数 $T$ 无关,需缩放每步噪声方差:

$$ \sigma = \sigma_q \sqrt{\Delta t}, \quad \Delta t = 1/T \tag{6} $$

此时前向过程更新为:

$$ x_{t+\Delta t} := x_t + \eta_t, \quad \eta_t \sim \mathcal{N}(0, \sigma_q^2 \Delta t) \tag{7} $$

且边际分布满足:

$$ x_t \sim \mathcal{N}(x_0, \sigma_t^2), \quad \sigma_t = \sigma_q \sqrt{t} \tag{8} $$

3. 算法篇:从DDPM到DDIM

3.1 DDPM:随机反向采样器

采样算法 [Algorithm 1, Nakkiran et al.]

输入:噪声样本 x₁ ~ N(0, σ_q²), 训练好的模型 f_θ ≈ μ_{t-Δt}
输出:生成样本 x̂₀

x ← Sample from N(0, σ_q²)  # 从纯噪声开始
for t = 1, 1-Δt, 1-2Δt, ..., Δt:
    η ~ N(0, σ_q²Δt)                    # 添加随机噪声
    x ← f_θ(x, t) + η                  # 去噪 + 随机扰动
return x
# 代码:DDPM采样(随机版本)
@torch.no_grad()
def ddpm_sample(model, schedule, steps=1000, batch_size=1):
    """
    DDPM随机采样器实现
    """
    device = next(model.parameters()).device
    sigma_q = schedule.sigma_max
    
    # 1. 从纯噪声开始: x₁ ~ N(0, σ_q²)
    x = torch.randn(batch_size, *model.input_shape, device=device) * sigma_q
    
    # 2. 时间步调度
    sigmas = schedule.get_sigmas(steps)  # [steps+1]
    
    # 3. 反向迭代
    for i in range(steps):
        t = (steps - i) / steps  # 当前时间
        sigma_t = sigmas[i]
        sigma_next = sigmas[i+1]
        
        # 预测条件期望: μ_{t-Δt}(x_t)
        mu_pred = model(x, torch.full((batch_size,), t, device=device))
        
        # 添加随机噪声(关键!保证采样多样性)
        noise = torch.randn_like(x) * torch.sqrt(sigma_t**2 - sigma_next**2)
        x = mu_pred + noise
    
    return x

正确性证明要点

通过贝叶斯规则 + Taylor展开可证明 [Claim 1]:

$$ \log p(x_{t-\Delta t}|x_t) = -\frac{1}{2\sigma_q^2\Delta t}\|x_{t-\Delta t} - \mu\|^2 + C $$

其中 $\mu = x_t + \sigma_q^2\Delta t \nabla \log p_t(x_t)$,即score function自然出现!

3.2 DDIM:确定性反向采样器

核心思想差异

特性 DDPM(随机) DDIM(确定性)
输出分布 点态:$F_t(x_t) \sim p(x_{t-1} x_t)$
采样结果 同一 $x_1$ → 不同 $x_0$(随机) 同一 $x_1$ → 相同 $x_0$(确定)
物理类比 布朗运动粒子 流体速度场传输

采样算法 [Algorithm 2, Nakkiran et al.]

$$ \hat{x}_{t-\Delta t} = x_t + \lambda \left( \mu_{t-\Delta t}(x_t) - x_t \right), \quad \lambda = \frac{\sigma_t}{\sigma_t - \sigma_{t-\Delta t}} \tag{9} $$
# 代码:DDIM采样(确定性版本)
@torch.no_grad()
def ddim_sample(model, schedule, steps=50, batch_size=1):
    """
    DDIM确定性采样器实现
    """
    device = next(model.parameters()).device
    sigma_q = schedule.sigma_max
    
    # 1. 从纯噪声开始
    x = torch.randn(batch_size, *model.input_shape, device=device) * sigma_q
    sigmas = schedule.get_sigmas(steps)
    
    for i in range(steps):
        t = (steps - i) / steps
        sigma_t = sigmas[i]
        sigma_next = sigmas[i+1]
        
        # 预测条件期望
        mu_pred = model(x, torch.full((batch_size,), t, device=device))
        
        # DDIM更新:确定性传输
        lambda_factor = sigma_t / (sigma_t - sigma_next + 1e-8)
        x = x + lambda_factor * (mu_pred - x)
    
    return x

速度场解释(物理直觉)

定义速度场:

$$ v_t(x_t) = \frac{\lambda}{\Delta t} \left( \mathbb{E}[x_{t-\Delta t}|x_t] - x_t \right) \tag{10} $$

则DDIM更新可写为:

$$ \hat{x}_{t-\Delta t} = x_t + v_t(x_t) \Delta t \tag{11} $$
物理类比:气体粒子流动

    t=1.0          t=0.5          t=0.0
    ┌─────┐       ┌─────┐       ┌─────┐
    │  ·  │  ──►  │ ↗↖ │  ──►  │  ★  │
    │ ··  │       │↗  ↖│       │ ★★  │
    │·····│       │ →v→ │       │★★★★★│
    └─────┘       └─────┘       └─────┘
    
    高斯分布    速度场引导    目标分布
    (噪声)     (传输方向)    (数据)

3.3 方差缩减技巧:预测 $x_0$ vs $\epsilon$

关键关系 [Claim 2, Nakkiran et al.]

$$ \mathbb{E}[x_{t-\Delta t} - x_t | x_t] = \frac{\Delta t}{t} \mathbb{E}[x_0 - x_t | x_t] \tag{12} $$

直观解释:给定 $x_t$,最后一步噪声 $\eta_{t-\Delta t}$ 与之前所有噪声步骤"看起来一样"(对称性),因此可用平均噪声估计单步噪声,显著降低方差。

三种参数化对比

预测目标 损失函数 优点 适用场景
$x_{t-1}$ $\|f_\theta - x_{t-1}\|^2$ 直接对应采样 理论分析
$x_0$ ⭐ $\|f_\theta - x_0\|^2$ 方差小,训练稳定 实际使用
$\epsilon$ $\|f_\theta - \epsilon\|^2$ 与噪声直接相关 高噪声阶段
$v$ [Salimans & Ho] $\|f_\theta - (\alpha_t\epsilon - \sigma_t x_0)\|^2$ 高低噪声平衡 高级调优
# 代码:统一参数化框架
class DiffusionModel(nn.Module):
    def __init__(self, prediction_type='epsilon'):
        """
        prediction_type: 'x0', 'epsilon', 'v', or 'x_prev'
        """
        super().__init__()
        self.prediction_type = prediction_type
        self.net = UNet(...)  # 实际网络架构
    
    def compute_loss(self, x0, t, schedule):
        """计算训练损失"""
        sigma_t = schedule(t)
        eps = torch.randn_like(x0)
        
        # 前向加噪
        x_t = x0 + sigma_t.unsqueeze(1) * eps
        
        # 根据参数化类型确定目标
        if self.prediction_type == 'x0':
            target = x0
        elif self.prediction_type == 'epsilon':
            target = eps
        elif self.prediction_type == 'v':
            # v-prediction: v = α_t·ε - σ_t·x₀
            alpha_t = torch.sqrt(1 - sigma_t**2)
            target = alpha_t.unsqueeze(1) * eps - sigma_t.unsqueeze(1) * x0
        elif self.prediction_type == 'x_prev':
            # 预测 x_{t-Δt}
            sigma_prev = schedule(t - 1/1000)  # 简化
            target = (sigma_prev/sigma_t) * x0 + ...  # 略
        
        # 网络预测
        pred = self.net(x_t, t)
        return torch.mean((pred - target)**2)

4. 理论篇:优化视角与流匹配

4.1 去噪作为近似投影(优化视角)[Yuan & Permenter]

核心观点

学习到的去噪器 $\epsilon_\theta(x, \sigma)$ 可解释为数据流形 $\mathcal{K}$ 的近似投影梯度

📐 距离函数与投影

关键定理

$$ \frac{1}{2}\nabla_x \text{dist}_\mathcal{K}^2(x, \sigma) = \sigma \cdot \epsilon^*(x, \sigma) \tag{13} $$

其中 $\text{dist}_\mathcal{K}^2(x, \sigma)$ 是平滑后的平方距离函数。

几何解释:

              数据流形 𝒦
              ╱╲
            ╱    ╲
          ╱   ★    ╲    ← 投影点 proj_𝒦(x)
        ╱  ╱││╲   ╲
      x ─●──┼┼─────►  ∇dist²(x,σ) 方向
          ╲ ││╱
            ╲╱
             
    去噪器输出:x - σ·ε_θ(x,σ) ≈ proj_𝒦(x)

相对误差模型(实践指导)

假设当 $\frac{1}{\nu}\text{dist}_\mathcal{K}(x) \leq \sqrt{n}\sigma \leq \nu \cdot \text{dist}_\mathcal{K}(x)$ 时:

$$ \|x - \sigma\epsilon_\theta(x,\sigma) - \text{proj}_\mathcal{K}(x)\| \leq \eta \cdot \text{dist}_\mathcal{K}(x) \tag{14} $$

工程意义

4.2 Flow Matching:扩散的推广

什么是流(Flow)?

= 时间索引的向量场集合 $\{v_t\}_{t \in [0,1]}$,定义粒子轨迹:

$$ \frac{dx_t}{dt} = -v_t(x_t), \quad x_1 \sim q \to x_0 \sim p \tag{15} $$

点wise流 → 边际流

  1. 点wise流:对任意点对 $(x_1, x_0)$,定义连接它们的流 $v^{[x_1,x_0]}$
  2. 边际流:通过加权平均组合所有点wise流:
$$ v_t^*(x_t) = \mathbb{E}_{x_0,x_1|x_t}\left[ v_t^{[x_1,x_0]}(x_t) \mid x_t \right] \tag{16} $$

线性流(最简单选择)[Liu et al., 2022]

$$ v_t^{[x_1,x_0]}(x_t) = x_0 - x_1 \quad \Rightarrow \quad x_t = t x_1 + (1-t) x_0 \tag{17} $$
# 代码:Flow Matching训练(线性流)
def flow_matching_loss(model, x0, x1, t):
    """
    Args:
        x0: 目标数据 ~ p
        x1: 源噪声 ~ q (如高斯)
        t: 时间 ~ Uniform[0,1]
    """
    # 线性插值轨迹
    x_t = t * x1 + (1 - t) * x0
    
    # 线性流的速度场(常数)
    v_target = x0 - x1
    
    # 回归损失
    v_pred = model(x_t, t)
    return torch.mean((v_pred - v_target)**2)

DDIM作为Flow Matching的特例

Claim 4 [Nakkiran et al., Appendix B.7]: DDIM等价于使用"扩散耦合"的线性流匹配,仅时间参数化不同($\sqrt{t}$ vs $t$)

┌────────────────────────────────────┐
│  DDIM轨迹:     x_t = x₀ + (x₁-x₀)√t  │
│  线性流轨迹:   x_t = x₀ + (x₁-x₀)·t  │
│                                     │
│  关系: DDIM在时间t处 = 线性流在时间√t处 │
└────────────────────────────────────┘

4.3 概率流ODE(连续时间视角)

从离散到连续

当 $\Delta t \to 0$,DDIM收敛到概率流ODE [Song et al., 2020]:

$$ \frac{dx_t}{dt} = -\frac{1}{2t} \mathbb{E}[x_0 - x_t | x_t] \tag{18} $$

SDE ↔ ODE 对应关系

形式 方程 特点
前向SDE $dx = \sigma_q dw$ 纯扩散,零漂移
反向SDE $dx = -\sigma_q^2 \nabla \log p_t(x) dt + \sigma_q dw$ 随机,需学习score
概率流ODE $dx = -\frac{1}{2}\sigma_q^2 \nabla \log p_t(x) dt$ 确定性,相同边际分布
# 代码:ODE求解器采样(使用torchdiffeq)
from torchdiffeq import odeint

class ProbabilityFlowODE:
    def __init__(self, score_model):
        self.score_model = score_model  # 学习 ∇log p_t(x)
    
    def reverse_ode(self, t, x):
        """ODE的右端函数: dx/dt = f(x,t)"""
        score = self.score_model(x, t)  # ≈ ∇log p_t(x)
        return -0.5 * score  # 简化版本(σ_q=1)
    
    def sample(self, x1, t_span):
        """从t=1积分到t=0"""
        return odeint(self.reverse_ode, x1, t_span, method='dopri5')

5. 实践篇:工程实现与设计选择

5.1 噪声调度(Noise Schedule)

常见调度对比与实现

# 代码:多种噪声调度实现
class NoiseSchedule:
    """噪声调度基类"""
    def __call__(self, t):
        raise NotImplementedError

class ScheduleLogLinear(NoiseSchedule):
    """对数线性调度:σ ∈ [σ_min, σ_max] 对数均匀 [Karras et al., 2022]"""
    def __init__(self, N, sigma_min=0.02, sigma_max=10.0):
        self.sigmas = torch.logspace(
            math.log10(sigma_min), 
            math.log10(sigma_max), 
            N
        )
    
    def __call__(self, t):
        idx = (t * (len(self.sigmas)-1)).long().clamp(0, len(self.sigmas)-2)
        return self.sigmas[idx]

class ScheduleDDPM(NoiseSchedule):
    """DDPM方差保持调度:Var(x_t) ≈ 常数 [Ho et al., 2020]"""
    def __init__(self, N, beta_min=0.0001, beta_max=0.02):
        betas = torch.linspace(beta_min, beta_max, N)
        alphas = 1 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
    
    def __call__(self, t):
        idx = (t * (len(self.alphas_cumprod)-1)).long()
        return torch.sqrt(1 - self.alphas_cumprod[idx])

class ScheduleKarras(NoiseSchedule):
    """Karras调度:σ(t) = t, 配合整体缩放 [Karras et al., 2022]"""
    def __init__(self, N, sigma_min=0.002, sigma_max=80.0):
        self.sigmas = ((sigma_max**(1/3) + 
                       torch.linspace(0, 1, N) * (sigma_min**(1/3) - sigma_max**(1/3)))**3)
    
    def __call__(self, t):
        idx = (t * (len(self.sigmas)-1)).long()
        return self.sigmas[idx]

调度选择建议

图像生成推荐:
├── 像素空间低分辨率:ScheduleDDPM(方差保持,训练稳定)
├── 潜空间/高分辨率:ScheduleLogLinear 或 ScheduleKarras(灵活,适合多尺度)
├── 快速采样:自定义稀疏调度(如 [1.0, 0.5, 0.2, 0.05, 0])
└── 理论研究:固定σ(简化分析)

调度可视化:
    log(σ)
       ↑
    2 ┤    ╱╲╱╲╱╲     ← LogLinear(对数均匀)
      │   ╱    ╲
    1 ┤  ╱      ╲╱╲   ← DDPM(前期缓,后期陡)
      │ ╱
    0 ┼╱─────────────→ t
      0    0.5    1.0

5.2 潜空间扩散(Latent Diffusion)[Sander Dieleman]

为什么需要潜空间?

┌─────────────────────────────────────┐
│ 问题:高分辨率图像直接在像素空间扩散   │
│ • 计算成本高:256×256×3 = 196,608维  │
│ • 迭代次数多:每步需网络前向传播      │
│ • 细节建模难:纹理/高频信息难学习    │
└─────────────────────────────────────┘

解决方案:两阶段架构
┌─────────┐     ┌─────────┐     ┌─────────┐
│ 编码器  │────►│ 潜空间  │────►│ 扩散模型│
│ E: x→z  │     │ z∈ℝ^d  │     │ p(z)    │
└─────────┘     └─────────┘     └─────────┘
                     │
                     ▼
              ┌─────────┐
              │ 解码器  │
              │ D: z→x  │
              └─────────┘

Stable Diffusion架构示意

# 概念代码:潜空间扩散训练流程
class LatentDiffusion:
    def __init__(self, autoencoder, unet, schedule):
        self.autoencoder = autoencoder  # 预训练,冻结
        self.unet = unet                 # 可训练
        self.schedule = schedule
    
    def train_step(self, x0, text_condition=None):
        # 1. 编码到潜空间(冻结)
        with torch.no_grad():
            z0 = self.autoencoder.encode(x0)
        
        # 2. 潜空间扩散训练
        t = torch.rand(z0.shape[0], device=z0.device)
        sigma_t = self.schedule(t)
        eps = torch.randn_like(z0)
        z_t = z0 + sigma_t.unsqueeze(1) * eps
        
        # 3. UNet预测(可加入text condition)
        eps_pred = self.unet(z_t, t, context=text_condition)
        
        return torch.mean((eps_pred - eps)**2)
    
    @torch.no_grad()
    def sample(self, condition=None, steps=50):
        # 1. 潜空间采样
        z_1 = torch.randn(1, *self.latent_shape)
        z_0 = ddim_sample(self.unet, self.schedule, steps)(z_1, condition)
        
        # 2. 解码回像素空间
        x_0 = self.autoencoder.decode(z_0)
        return x_0

5.3 完整实现示例:2D螺旋数据集

# 完整可运行示例:2D螺旋数据扩散
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 1. 螺旋数据生成
def swiss_roll(n_samples=1000):
    """生成2D螺旋数据集"""
    n = np.sqrt(np.random.rand(n_samples)) * 5 * np.pi
    d = np.random.rand(n_samples) * 2 - 1
    x = n * np.cos(n) + 0.5 * d * np.sin(n)
    y = n * np.sin(n) + 0.5 * d * np.cos(n)
    return torch.stack([torch.tensor(x), torch.tensor(y)], dim=1).float()

# 2. 时间嵌入 + MLP去噪器
class TimeEmbedding(nn.Module):
    def __call__(self, sigma):
        """简单的2维时间嵌入"""
        sigma = sigma.unsqueeze(1)
        return torch.cat([
            torch.sin(torch.log(sigma) / 2),
            torch.cos(torch.log(sigma) / 2)
        ], dim=1)

class DenoiserMLP(nn.Module):
    def __init__(self, dim=2, hidden_dims=[16, 128, 128, 16]):
        super().__init__()
        layers = []
        dims = [dim + 2] + hidden_dims  # +2 for time embedding
        for i in range(len(dims)-1):
            layers.extend([
                nn.Linear(dims[i], dims[i+1]),
                nn.GELU()
            ])
        layers.append(nn.Linear(hidden_dims[-1], dim))
        self.net = nn.Sequential(*layers)
        self.time_emb = TimeEmbedding()
    
    def forward(self, x, sigma):
        t_emb = self.time_emb(sigma)  # [B, 2]
        inp = torch.cat([x, t_emb], dim=1)  # [B, dim+2]
        return self.net(inp)

# 3. 训练循环
def train_diffusion(model, data_loader, schedule, epochs=5000, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        for x0 in data_loader:
            optimizer.zero_grad()
            
            # 采样训练样本
            t = torch.rand(x0.shape[0])
            sigma_t = schedule(t)
            eps = torch.randn_like(x0)
            x_t = x0 + sigma_t.unsqueeze(1) * eps
            
            # 预测噪声(这里用epsilon参数化)
            eps_pred = model(x_t, sigma_t)
            loss = torch.mean((eps_pred - eps)**2)
            
            loss.backward()
            optimizer.step()
        
        if epoch % 500 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# 4. 采样与可视化
@torch.no_grad()
def sample_and_plot(model, schedule, n_samples=200, steps=20):
    model.eval()
    sigmas = schedule.get_sigmas(steps)
    
    # 从噪声开始
    x = torch.randn(n_samples, 2) * sigmas[0]
    
    # DDIM采样轨迹记录
    trajectory = [x.clone()]
    for i in range(steps):
        t = (steps - i) / steps
        sigma_t = sigmas[i]
        sigma_next = sigmas[i+1]
        
        # 预测x0(epsilon参数化转x0)
        eps_pred = model(x, torch.full((n_samples,), t))
        x0_pred = (x - sigma_t * eps_pred) / torch.sqrt(1 - sigma_t**2 + 1e-8)
        
        # DDIM更新
        direction = (x - sigma_t * x0_pred) / torch.sqrt(1 - sigma_t**2 + 1e-8)
        x = sigma_next * x0_pred + torch.sqrt(1 - sigma_next**2 + 1e-8) * direction
        trajectory.append(x.clone())
    
    # 可视化
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for idx, t_idx in enumerate([0, 7, 14, 20]):
        ax = axes[idx]
        ax.scatter(trajectory[t_idx][:,0], trajectory[t_idx][:,1], 
                  s=2, alpha=0.5, c='blue')
        ax.set_title(f'Step {t_idx}/{steps} (t={1-t_idx/steps:.2f})')
        ax.set_xlim(-10, 10)
        ax.set_ylim(-10, 10)
        ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()

训练效果可视化

采样轨迹演变(20步DDIM):

Step 0 (t=1.0)          Step 7 (t=0.65)      
┌─────────┐            ┌─────────┐
│  · · ·  │            │  ╱╲    │
│ · · · · │   ──►     │ ╱  ╲   │
│  · · ·  │            │╱    ╲╱ │
└─────────┘            └─────────┘
  高斯噪声              开始形成螺旋

Step 14 (t=0.3)         Step 20 (t=0.0)
┌─────────┐            ┌─────────┐
│  ╱╲╱╲  │            │ ★★★★★ │
│ ╱    ╲ │   ──►     │★     ★│
│╱      ╲│            │ ★★★★★ │
└─────────┘            └─────────┘
  螺旋轮廓清晰          收敛到数据分布

5.4 常见陷阱与调试技巧

⚠️ 训练不稳定?

# 技巧1:梯度裁剪(防止梯度爆炸)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 技巧2:学习率调度(余弦退火)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs, eta_min=1e-5
)

# 技巧3:损失加权(不同噪声水平不同权重)
def weighted_loss(eps_pred, eps, sigma_t, weighting='uniform'):
    if weighting == 'uniform':
        weights = torch.ones_like(sigma_t)
    elif weighting == 'snr':  # Signal-to-Noise Ratio weighting
        weights = 1 / (sigma_t**2 + 1e-6)
    elif weighting == 'min_snr':  # Min-SNR weighting [Hang et al.]
        snr = 1 / sigma_t**2
        weights = torch.minimum(snr, torch.ones_like(snr) * 5)
    return torch.mean(weights * (eps_pred - eps)**2)

⚠️ 采样质量差?

# 技巧1:增加采样步数(但计算成本↑)
# 技巧2:使用更高级的ODE求解器(DPM-Solver, Heun, etc.)
# 技巧3:添加引导(classifier-free guidance)

def classifier_free_guidance(model, x_t, t, cond, uncond, guidance_scale=7.5):
    """CFG: 增强条件生成的质量"""
    # 条件预测
    eps_cond = model(x_t, t, context=cond)
    # 无条件预测(训练时随机drop condition)
    eps_uncond = model(x_t, t, context=uncond)
    # 引导组合
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

# 技巧4:动态阈值(防止采样发散)
def dynamic_thresholding(x, percentile=99.5):
    """防止采样值过大导致图像异常"""
    s = torch.quantile(x.abs().reshape(x.shape[0], -1), percentile/100, dim=1)
    s = torch.clamp(s, min=1.0)
    return x.clamp(-s.unsqueeze(1), s.unsqueeze(1))

⚠️ 过拟合/记忆训练数据?

🔍 诊断方法:
1. 检查生成样本与训练样本的L2距离(近邻搜索)
2. 可视化训练/验证损失曲线(验证损失上升→过拟合)
3. 使用FID/IS等指标评估泛化能力
4. 人工检查生成样本多样性

✅ 缓解策略:
• 增加数据增强(随机裁剪、翻转、颜色抖动)
• 使用更大的模型(隐式正则化)
• 早停(early stopping)+ 验证集监控
• 添加dropout/weight decay
• 增加训练数据量(最有效)

📊 记忆化现象可视化 [Nakkiran et al., Fig 6]:
    N=10样本:轨迹坍缩到训练点 → 记忆化
    N=40样本:轨迹学习螺旋流形 → 泛化

6. 进阶篇:学习理论与前沿方向

6.1 误差来源分析 [Nakkiran et al., Sec 5]

扩散模型的生成误差可分解为:

总误差 = 训练误差 + 采样误差 + 模型偏差

1️⃣ 训练误差(统计误差)
   • 有限样本:回归目标估计不准
   • 有限模型:函数类表达能力不足
   • 优化误差:未收敛到全局最优

2️⃣ 采样误差(数值误差)
   • 离散化:Δt不够小,高斯近似失效
   • ODE/SDE求解器:数值积分误差
   • 截断:早期停止采样

3️⃣ 模型偏差(结构性误差)
   • 高斯假设:真实反向条件分布非高斯
   • 马尔可夫假设:忽略长程依赖
   • 参数化选择:不同预测目标影响收敛

6.2 泛化与记忆化的平衡

关键点:完美拟合训练数据 ≠ 好的生成模型

# 概念验证:记忆化检测
def check_memorization(generated_samples, train_samples, threshold=0.1):
    """检测生成样本是否记忆训练数据"""
    from scipy.spatial.distance import cdist
    
    # 计算生成样本与训练样本的最小L2距离
    distances = cdist(
        generated_samples.reshape(len(generated_samples), -1),
        train_samples.reshape(len(train_samples), -1),
        metric='euclidean'
    )
    min_distances = distances.min(axis=1)
    
    # 统计"太接近"训练样本的比例
    memorized_ratio = (min_distances < threshold).mean()
    return memorized_ratio, min_distances

# 实践建议:
# • 记忆化比例 < 1% 通常可接受
# • 高记忆化 → 增加数据/正则化/早停

6.3 前沿方向速览

方向 核心思想 代表工作 实用价值
一致性模型 学习一步映射,蒸馏多步扩散 [Song et al., 2023] ⭐⭐⭐⭐⭐ 快速采样
整流流 学习直线轨迹,减少采样步数 [Liu et al., 2022] ⭐⭐⭐⭐ 训练稳定
潜空间优化 更高效的潜空间编码/解码 [Esser et al., 2024] ⭐⭐⭐⭐ 高分辨率
理论分析 泛化边界、收敛速率 [Chen et al., 2022-2024] ⭐⭐⭐ 指导设计
离散扩散 文本/图/离散结构的扩散 [Austin et al., 2021] ⭐⭐⭐⭐ 多模态

总结

                    扩散模型
                       │
     ┌─────────┬───────┴───────┬─────────┐
     ▼         ▼               ▼         ▼
  数学基础   核心算法       理论解释   工程实践
     │         │               │         │
     ▼         ▼               ▼         ▼
 • 高斯分布  • DDPM        • 投影解释  • 噪声调度
 • 条件期望  • DDIM        • 流匹配   • 潜空间
 • 贝叶斯规则• Flow Matching• SDE/ODE  • 参数化选择
 • Taylor展开• 概率流ODE   • 相对误差  • 采样加速