Score-based Generative Models through SDEs
1. Notations
The meaning of $\nabla\cdot[\cdot]$:
For a scalar function $\phi(x):\mathbb{R}^d\to\mathbb{R}$, $\nabla\phi(x)$ is the gradient.
For a vector function $\phi(x):\mathbb{R}^d\to\mathbb{R}^d$, $\nabla\cdot\phi(x)\in\mathbb{R}$ is the divergence, defined as the sum of the partial derivatives of the components:
$$\nabla\cdot\phi(x)=\sum_{i=1}^d\frac{\partial\phi_i}{\partial x_i}$$
For a matrix function $\phi(x):\mathbb{R}^d\to\mathbb{R}^{d\times d}$, the divergence is a vector, defined as:
$$\nabla\cdot\phi(x)=\left(\sum_{j=1}^d\frac{\partial\phi_{1j}}{\partial x_j},\;\dots,\;\sum_{j=1}^d\frac{\partial\phi_{dj}}{\partial x_j}\right)^\top$$
or, written in a more general form:
$$\nabla\cdot\phi(x):=\big(\nabla\cdot\phi_1(x),\dots,\nabla\cdot\phi_d(x)\big)^\top$$
where $\phi_i(x)$ is the $i$th row of $\phi$.
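The vector-field case above can be sanity-checked numerically. The sketch below uses a linear field $\phi(x)=Ax$ (an illustrative choice, not from the text), whose divergence is exactly $\operatorname{tr}(A)$:

```python
import numpy as np

def divergence(phi, x, eps=1e-5):
    """Central-difference divergence of a vector field phi: R^d -> R^d,
    following the definition sum_i d(phi_i)/d(x_i)."""
    d = x.shape[0]
    div = 0.0
    for i in range(d):
        e = np.zeros(d)
        e[i] = eps
        div += (phi(x + e)[i] - phi(x - e)[i]) / (2 * eps)
    return div

A = np.array([[1.0, 2.0], [3.0, 4.0]])
phi = lambda x: A @ x          # linear field: divergence = trace(A)
x = np.array([0.5, -1.0])
print(divergence(phi, x))      # ≈ trace(A) = 5
```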
2. Score-based Diffusion through SDE (Unconditional)
SDE of the forward process (noising):
$$dx=f(x,t)\,dt+g(x,t)\,dw,\qquad f:\mathbb{R}^d\times[0,T]\to\mathbb{R}^d,\quad g:\mathbb{R}^d\times[0,T]\to\mathbb{R}^{d\times d}$$
SDE of the reverse process (denoising), where $\bar{w}$ is a reverse-time Wiener process:
$$dx=\Big(f(x,t)-\nabla\cdot\big[g(x,t)g(x,t)^\top\big]-g(x,t)g(x,t)^\top\nabla_x\ln p_t(x)\Big)dt+g(x,t)\,d\bar{w}$$
Probability flow ODE:
$$dx=\Big(f(x,t)-\tfrac{1}{2}\nabla\cdot\big[g(x,t)g(x,t)^\top\big]-\tfrac{1}{2}g(x,t)g(x,t)^\top\nabla_x\ln p_t(x)\Big)dt$$
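A forward trajectory under such an SDE can be simulated with the Euler–Maruyama scheme. The drift $f(x,t)=-x/2$, constant $g=1$, and horizon below are illustrative choices (not from the text), picked so each coordinate is an Ornstein–Uhlenbeck process with a known marginal variance:

```python
import numpy as np

def euler_maruyama(x0, f, g, T=1.0, n_steps=1000, seed=0):
    """Simulate dx = f(x,t) dt + g(x,t) dw forward from x0 to time T."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    dt = T / n_steps
    for k in range(n_steps):
        t = k * dt
        # one Euler-Maruyama step: deterministic drift + sqrt(dt)-scaled noise
        x = x + f(x, t) * dt + g(x, t) * np.sqrt(dt) * rng.standard_normal(x.shape)
    return x

# each of the 5000 coordinates is an independent OU path started at 0;
# at T=1 the exact marginal variance is 1 - e^{-1} ≈ 0.632
x_T = euler_maruyama(np.zeros(5000), f=lambda x, t: -0.5 * x, g=lambda x, t: 1.0)
print(x_T.var())
```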
3. Training (Unconditional)
- Sliced Score Matching:
$$\theta^*=\arg\min_\theta\mathbb{E}_t\Big\{\lambda(t)\,\mathbb{E}_{x(0)}\mathbb{E}_{x(t)}\mathbb{E}_{v\sim p_v}\Big[v^\top\nabla_x s_\theta(x(t),t)\,v+\tfrac{1}{2}\big(v^\top s_\theta(x(t),t)\big)^2\Big]\Big\}$$
- Denoising Score Matching:
$$\theta^*=\arg\min_\theta\mathbb{E}_{t\sim U(0,T)}\,\lambda(t)\,\mathbb{E}_{x_0\sim p_0}\mathbb{E}_{x\sim p_{0t}(x\mid x_0)}\Big[\big\|s_\theta(x,t)-\nabla_x\ln p_{0t}(x\mid x_0)\big\|_2^2\Big]$$
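The denoising score matching objective can be checked on a toy setup where the optimum is known in closed form. The setup below is an assumption for illustration: 1-D data $x_0\sim\mathcal N(0,1)$, corruption $x=x_0+\sigma\epsilon$, so $p_t=\mathcal N(0,1+\sigma^2)$ and the true score is $-x/(1+\sigma^2)$. Fitting a linear score model $s_\theta(x)=\theta x$ to the DSM target reduces to least squares:

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.7
x0 = rng.standard_normal(200_000)          # samples from p_0 = N(0, 1)
x = x0 + sigma * rng.standard_normal(x0.shape)
target = -(x - x0) / sigma**2              # nabla_x ln p_0t(x | x0)
theta = (x @ target) / (x @ x)             # least-squares fit of s(x) = theta * x
print(theta, -1 / (1 + sigma**2))          # both ≈ -0.671: DSM recovers the true score
```

The point is that minimizing the *conditional* score-matching loss recovers the score of the *marginal* $p_t$, which is what the sampler needs.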
4. Score-based Diffusion through SDE (Conditional)
$$dx=\Big(f(x,t)-\nabla\cdot\big[g(x,t)g(x,t)^\top\big]-g(x,t)g(x,t)^\top\nabla_x\ln p_t(x\mid y)\Big)dt+g(x,t)\,d\bar{w}$$
$$\nabla_x\ln p_t(x\mid y)=\nabla_x\ln\frac{p_t(y\mid x)\,p_t(x)}{p_t(y)}=\nabla_x\ln p_t(y\mid x)+\nabla_x\ln p_t(x)$$
$$dx=\Big(f(x,t)-\nabla\cdot\big[g(x,t)g(x,t)^\top\big]-g(x,t)g(x,t)^\top\big(\nabla_x\ln p_t(y\mid x)+\nabla_x\ln p_t(x)\big)\Big)dt+g(x,t)\,d\bar{w}$$
We can train a neural network $c(x,t)$ to approximate $\ln p_t(y\mid x)$ and differentiate it.
We can also use prior knowledge to determine $\nabla_x\ln p_t(y\mid x)$ directly.
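The Bayes-rule decomposition of the conditional score can be verified on a conjugate-Gaussian toy model (all densities below are illustrative assumptions): $p(x)=\mathcal N(0,1)$, $p(y\mid x)=\mathcal N(y;x,1)$, hence $p(x\mid y)=\mathcal N(y/2,\,1/2)$:

```python
def score_prior(x):            # nabla_x ln p(x),  p(x) = N(0, 1)
    return -x

def score_likelihood(x, y):    # nabla_x ln p(y|x),  p(y|x) = N(y; x, 1)
    return y - x

def score_posterior(x, y):     # nabla_x ln p(x|y),  p(x|y) = N(y/2, 1/2)
    return -(x - y / 2) / 0.5

x, y = 0.3, 1.2
lhs = score_posterior(x, y)
rhs = score_likelihood(x, y) + score_prior(x)
print(lhs, rhs)  # equal: posterior score = likelihood score + prior score
```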
5. Special Case
If $\exists\,\tilde{g}(t)\in\mathbb{R}$ s.t. $g(x,t)=\tilde{g}(t)\cdot I$, then $gg^\top=\tilde{g}(t)^2 I$ no longer depends on $x$, the divergence term vanishes, and the reverse process simplifies to:
$$dx=\big(f(x,t)-\tilde{g}(t)^2\,\nabla_x\ln p_t(x)\big)dt+\tilde{g}(t)\,d\bar{w}$$
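This special-case reverse SDE can be integrated backwards with Euler–Maruyama. The toy forward process below is an assumption for illustration: $dx=-\tfrac{x}{2}dt+dw$ with $x_0=0$, so $p_t=\mathcal N(0,v_t)$ with $v_t=1-e^{-t}$, and the exact score $s(x,t)=-x/v_t$ stands in for a learned network $s_\theta$:

```python
import numpy as np

rng = np.random.default_rng(0)
T, t_end, n_steps, n_samples = 2.0, 1e-2, 2000, 20_000
dt = (T - t_end) / n_steps

v = lambda t: 1.0 - np.exp(-t)                       # Var[x_t] under the toy forward SDE
x = rng.standard_normal(n_samples) * np.sqrt(v(T))   # start from the prior x_T ~ p_T
t = T
for _ in range(n_steps):
    score = -x / v(t)                  # exact score, replacing s_theta(x, t)
    drift = -0.5 * x - score           # f(x,t) - g~(t)^2 * score, with g~ = 1
    # backward-in-time Euler-Maruyama step of the reverse SDE
    x = x - drift * dt + np.sqrt(dt) * rng.standard_normal(n_samples)
    t -= dt
print(x.mean(), x.var())  # samples collapse back toward x_0 = 0 (var ≈ v(t_end) ≈ 0.01)
```

Integration stops at a small $t_{\text{end}}>0$ because the score $-x/v_t$ blows up as $v_t\to 0$, a standard practical workaround.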
6. Denoising Score Matching
Assume the forward diffusion process can be written as $x_t=a_t x_0+b_t\epsilon,\ \epsilon\sim\mathcal{N}(0,I)$; then
$$x_t\sim\mathcal{N}(a_t x_0,\,b_t^2 I)$$
Minimize
$$\mathbb{E}_{t,\,x_0,\,x_t\sim p_{0t}(x_t\mid x_0)}\big\|s_\theta(x_t,t)-\nabla_{x_t}\ln p_{0t}(x_t\mid x_0)\big\|_2^2$$
- Equivalence of Epsilon Model and Score Model
$$\text{score}=\nabla_{x_t}\ln p_{0t}(x_t\mid x_0)=-\frac{1}{b_t^2}(x_t-a_t x_0)=-\frac{\epsilon}{b_t}$$
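The identity holds exactly, not just in expectation, which a quick numerical check makes concrete (the values of $a_t$ and $b_t$ are arbitrary illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(0)
a_t, b_t = 0.8, 0.6                    # illustrative schedule values
x0 = rng.standard_normal(1000)
eps = rng.standard_normal(1000)
x_t = a_t * x0 + b_t * eps             # forward perturbation x_t = a_t x_0 + b_t eps
score = -(x_t - a_t * x0) / b_t**2     # nabla ln p_0t(x_t | x_0)
assert np.allclose(score, -eps / b_t)  # score = -eps / b_t, elementwise
```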
- Unconditional Score Matching
$$\arg\min\big\|b_t^{-1}\epsilon_\theta(x_t,t)-b_t^{-1}\epsilon\big\|_2^2\iff\arg\min\big\|\epsilon_\theta(x_t,t)-\epsilon\big\|_2^2$$
- Conditional (Loss Guidance) Score Matching
$$\arg\min\big\|b_t^{-1}\epsilon_\theta(x_t,t)-b_t^{-1}\epsilon-\nabla_{x_t}l(x_t,x_0)\big\|_2^2\iff\arg\min\big\|\epsilon_\theta(x_t,t)-\epsilon-b_t\nabla_{x_t}l(x_t,x_0)\big\|_2^2$$
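The translation between the score and epsilon parameterizations of loss guidance is just a rescaling by $-b_t$. A minimal sketch, with an illustrative quadratic guidance loss $l(x)=\|x-y\|^2/2$ that is not from the text:

```python
import numpy as np

b_t = 0.5
x_t = np.array([0.2, -0.4])
y = np.array([1.0, 1.0])
grad_l = x_t - y                       # nabla_x l(x_t) for l(x) = ||x - y||^2 / 2
eps = np.array([0.3, 0.1])
score_guided = -eps / b_t - grad_l     # guided score target
eps_guided = -b_t * score_guided       # convert back to epsilon parameterization
print(eps_guided)                      # equals eps + b_t * grad_l
```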
7. VPSDE (Continuous DDPM) (Unconditional)
$$P(x_t\mid x_{t-1})=\mathcal{N}\big(x_t;\,\sqrt{1-\beta_t}\,x_{t-1},\,\beta_t I\big)$$
$$x_t=\sqrt{1-\beta_t}\,x_{t-1}+\sqrt{\beta_t}\,\epsilon,\qquad\epsilon\sim\mathcal{N}(0,I)$$
$$x_{t+\Delta t}-x_t=\sqrt{1-\beta_{t+\Delta t}}\,x_t+\sqrt{\beta_{t+\Delta t}}\,\epsilon-x_t$$
Because $\sqrt{1-x}=1-\frac{x}{2}+o(x)$, we have:
$$x_{t+\Delta t}-x_t=-\frac{\beta_{t+\Delta t}}{2}\,x_t+\sqrt{\beta_{t+\Delta t}}\,\epsilon+o(\beta_{t+\Delta t})\,x_t$$
Interpreting $\beta_{t+\Delta t}=\beta(t+\Delta t)\,\Delta t$ and letting $\Delta t\to 0$ gives the SDE of the forward process (noising):
$$dx=-\frac{\beta_t}{2}\,x\,dt+\sqrt{\beta_t}\,dw$$
SDE of the reverse process (denoising):
$$dx=\Big(-\frac{\beta_t}{2}\,x-\beta_t\,\nabla_x\ln p_t(x)\Big)dt+\sqrt{\beta_t}\,d\bar{w}$$
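The continuous limit rests on $\sqrt{1-\beta}\approx 1-\beta/2$: for one step, the DDPM update and the Euler step of the VPSDE differ only at order $\beta^2$. A quick numerical check (the values of $\beta$, $x$, and $\epsilon$ are illustrative):

```python
import numpy as np

beta, x, eps = 1e-3, 2.0, 0.7
ddpm_step = np.sqrt(1 - beta) * x + np.sqrt(beta) * eps   # one discrete DDPM update
sde_step = x - 0.5 * beta * x + np.sqrt(beta) * eps       # Euler step of dx = -(beta/2) x dt + sqrt(beta) dw
print(abs(ddpm_step - sde_step))  # on the order of beta^2, i.e. the o(beta) remainder
```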
Parameter comparison:

| model | beta | n_step |
|---|---|---|
| DDPM | 0.0001–0.02 | 1000 |
| VPSDE | 0.1–20 | 1000 |
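The two rows are consistent under the convention of Song et al.'s VPSDE (stated here as an interpretation, since the table itself does not say so): the continuous $\beta(t)$ range is the discrete per-step beta range scaled by the number of steps $N$:

```python
N = 1000
beta_min_ddpm, beta_max_ddpm = 1e-4, 0.02
# scaling the discrete DDPM betas by N recovers the VPSDE range 0.1-20
print(beta_min_ddpm * N, beta_max_ddpm * N)
```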
8. VESDE
$$x_t\sim\mathcal{N}(x_0,\sigma_t^2 I),\qquad\sigma_t=\sigma_{\min}(\sigma_{\max}/\sigma_{\min})^t$$
$$\begin{aligned}
dx&=g(t)\,dw\\
x_t&=x_0+\int_0^t g(s)\,dw_s\\
\operatorname{Var}[x_t]&=\int_0^t g(s)^2\,ds=:\sigma_t^2\\
g(t)&=\sqrt{\frac{d\sigma_t^2}{dt}}=\sigma_t\sqrt{2\ln(\sigma_{\max}/\sigma_{\min})}
\end{aligned}$$
In the discrete case:
$$\begin{aligned}
x_t&=x_{t-1}+g(t)\,\epsilon_t,\qquad\epsilon_t\sim\mathcal{N}(0,I)\\
\sigma_t^2&=\sigma_{t-1}^2+g(t)^2\\
g(t)&=\sqrt{\sigma_t^2-\sigma_{t-1}^2}
\end{aligned}$$
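In the discrete case the per-step noise scales $g(t)=\sqrt{\sigma_t^2-\sigma_{t-1}^2}$ should accumulate back to $\operatorname{Var}[x_t]=\sigma_t^2$, which is easy to verify for a geometric schedule ($\sigma_{\min}$, $\sigma_{\max}$, and $N$ below are illustrative choices):

```python
import numpy as np

sigma_min, sigma_max, N = 0.01, 50.0, 1000
t = np.arange(N + 1) / N
sigma = sigma_min * (sigma_max / sigma_min) ** t     # geometric sigma schedule
g = np.sqrt(sigma[1:] ** 2 - sigma[:-1] ** 2)        # per-step noise scale g(t)
# accumulated variance after each step: sigma_0^2 + sum of g^2 so far
var_accumulated = sigma[0] ** 2 + np.cumsum(g ** 2)
assert np.allclose(var_accumulated, sigma[1:] ** 2)  # telescopes back to sigma_t^2
```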