- diffusion models are generative models
- they turn random noise -> a denoised output over several iterations
- at each step, the model estimates how to get from the current noisy sample to a completely denoised version
- since each step makes only small changes, errors in early steps can be corrected by later steps
Training Procedure
1. Load images from the training data
2. Add noise in varying amounts (from heavily noised to nearly clean)
3. Feed the noisy versions of the inputs into the model
4. Evaluate how well the model denoises the inputs
5. Update the weights
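A toy sketch of the five training steps, under loudly stated assumptions: the "images" are hypothetical 8-pixel Gaussian vectors, the "model" is a hypothetical single linear map with no timestep input, the noise schedule is linear from $0.0001$ to $0.02$ over 1000 steps, and, following DDPM, the model is trained to predict the added noise $\epsilon$ with a mean-squared-error loss. Step 2 uses the closed form $x_t = \sqrt{\alpha_t}\,x_0 + \sqrt{1 - \alpha_t}\,\epsilon$ from the forward-process derivation in these notes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed noise schedule: linear betas from 1e-4 to 0.02 over T = 1000 steps.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha = np.cumprod(1.0 - betas)  # alpha[k] = prod_{i=1}^{k+1} (1 - beta_i)

d = 8  # hypothetical "image" size: 8 pixels

def load_batch(n):
    """Step 1 (stand-in): hypothetical clean data x_0, low-variance Gaussian vectors."""
    return 0.5 * rng.standard_normal((n, d))

def add_noise(x0):
    """Step 2: noise each example by a random amount using the closed form for x_t."""
    k = rng.integers(0, T, size=len(x0))   # random step per example (0-indexed)
    eps = rng.standard_normal(x0.shape)    # the noise the model must predict
    a = alpha[k][:, None]
    xt = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps
    return xt, eps

def mse(W, xt, eps):
    return np.mean((xt @ W - eps) ** 2)

# Hypothetical toy model: eps_hat = x_t @ W (a single linear map, no timestep input).
W = np.zeros((d, d))
lr = 0.05

x_eval, eps_eval = add_noise(load_batch(2048))
loss_before = mse(W, x_eval, eps_eval)

for _ in range(2000):
    xt, eps = add_noise(load_batch(64))            # steps 1-2
    residual = xt @ W - eps                        # step 3: model predictions minus targets
    loss = np.mean(residual ** 2)                  # step 4: evaluate with the MSE loss
    grad = 2.0 * xt.T @ residual / residual.size   # gradient of that MSE w.r.t. W
    W -= lr * grad                                 # step 5: gradient-descent weight update

loss_after = mse(W, x_eval, eps_eval)
print(f"eval MSE before training: {loss_before:.3f}, after: {loss_after:.3f}")
```

A real diffusion model replaces `W` with a neural network (typically a U-Net) that also receives the timestep, but the loop keeps the same shape.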
Forward Diffusion Process (Noising)
$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1 - \beta_t}\,x_{t-1}, \beta_t I\big)$$
noise schedule ($\beta_t$) - controls how much noise is introduced at each step of the process
- $0 < \beta_t < 1$
- typically, $\beta_t$ starts small (so as not to destroy too much information early) and grows over later steps
- $\beta_t$ values are usually small overall, typically in the range $0.0001$ to $0.02$
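A minimal sketch of such a schedule, assuming a linear schedule over $T = 1000$ steps with endpoints $0.0001$ and $0.02$ (the setup used in the original DDPM paper); `linear_beta_schedule` is a hypothetical helper name, not a library function.

```python
import numpy as np

def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Hypothetical helper: T noise levels increasing linearly from beta_start to beta_end."""
    return np.linspace(beta_start, beta_end, T)

betas = linear_beta_schedule(1000)
print(betas[0], betas[499], betas[-1])  # small early, larger late
```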
$\mathcal{N}(0, 1)$ - the standard normal distribution over scalars (a sample is a single number)
$\mathcal{N}(0, I)$ - the multivariate standard normal: a sample is a vector whose components are each independently drawn from $\mathcal{N}(0, 1)$
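To make the distinction concrete, a small NumPy sketch: one scalar draw from $\mathcal{N}(0, 1)$ versus a $3 \times 4$ array (a hypothetical tiny "image") drawn from $\mathcal{N}(0, I)$, plus an empirical check that independent components give a covariance close to the identity.

```python
import numpy as np

rng = np.random.default_rng(0)

z = rng.standard_normal()          # one draw from N(0, 1): a single number
eps = rng.standard_normal((3, 4))  # one draw from N(0, I) for a 3x4 "image": 12 independent N(0, 1) entries

# Many draws of a 3-dimensional N(0, I) vector: the empirical covariance should be close to I.
samples = rng.standard_normal((100_000, 3))
cov = np.cov(samples, rowvar=False)
print(np.round(cov, 2))
```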
For normal distributions, we write $x \sim \mathcal{N}(\mu, \sigma^2)$ for a sample and $\mathcal{N}(x; \mu, \sigma^2)$ for its density evaluated at $x$. To sample from this distribution, we take
$$x = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
Applying this to $q(x_t \mid x_{t-1})$, whose mean is $\sqrt{1-\beta_t}\,x_{t-1}$ and whose standard deviation is $\sqrt{\beta_t}$, our noise updates become
$$x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t} \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
- as we can see, each noising step scales down the current vector/image $x_{t-1}$ before adding random noise
- we use $\sqrt{1 - \beta_t}$ and $\sqrt{\beta_t}$ so that the variance of the image doesn't explode: if $x_{t-1}$ has unit variance and $\epsilon$ is independent of it, then $\operatorname{Var}(x_t) = (1 - \beta_t) + \beta_t = 1$

Mathematical derivation for the mean/variance of the noising process at step $t$:
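A sketch of a single noising step via $x = \mu + \sigma \cdot \epsilon$, assuming NumPy arrays stand in for images; `forward_step` is a hypothetical helper, and $\beta_t = 0.3$ is deliberately far larger than a realistic value so the effect on the variance is easy to see.

```python
import numpy as np

rng = np.random.default_rng(0)

def forward_step(x_prev, beta_t, rng):
    """Sample x_t ~ N(sqrt(1 - beta_t) x_{t-1}, beta_t I) as mean + std * eps."""
    eps = rng.standard_normal(x_prev.shape)
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * eps

# Unit-variance input, one step with a (deliberately large) beta.
x_prev = rng.standard_normal(200_000)
x_t = forward_step(x_prev, beta_t=0.3, rng=rng)

# Without the sqrt(1 - beta_t) scaling, the variance grows to 1 + beta_t instead.
x_unscaled = x_prev + np.sqrt(0.3) * rng.standard_normal(x_prev.shape)
print(f"scaled: {x_t.var():.3f}, unscaled: {x_unscaled.var():.3f}")
```

Repeating the unscaled update would add $\beta_t$ to the variance at every step, which is the blow-up the square-root coefficients prevent.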
$$
\begin{align*}
x_t &= \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t} \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\\
\Rightarrow x_1 &= \sqrt{1-\beta_1}\,x_0 + \sqrt{\beta_1} \cdot \epsilon_1,\\
x_2 &= \sqrt{1-\beta_2}\,x_1 + \sqrt{\beta_2} \cdot \epsilon_2,\\
\Rightarrow x_2 &= \sqrt{1-\beta_2}\left(\sqrt{1-\beta_1}\,x_0 + \sqrt{\beta_1} \cdot \epsilon_1\right) + \sqrt{\beta_2} \cdot \epsilon_2\\
\Rightarrow x_2 &= \sqrt{1-\beta_2}\sqrt{1-\beta_1}\,x_0 + \sqrt{1-\beta_2}\sqrt{\beta_1} \cdot \epsilon_1 + \sqrt{\beta_2} \cdot \epsilon_2
\end{align*}
$$
where $\epsilon_1, \epsilon_2 \sim \mathcal{N}(0, I)$ are independent. Notice the pattern: after $t$ steps, the coefficient of $x_0$ (the signal term, from the original image) becomes $\prod_{i=1}^{t} \sqrt{1 - \beta_i} = \sqrt{\prod_{i=1}^{t} (1 - \beta_i)}$.
Furthermore, we notice that
$$
\begin{align*}
\sqrt{1-\beta_2}\sqrt{\beta_1} \cdot \epsilon_1 + \sqrt{\beta_2} \cdot \epsilon_2 &\sim \mathcal{N}\big(0, (1-\beta_2)\beta_1 I\big) + \mathcal{N}(0, \beta_2 I)\\
&\sim \mathcal{N}\big(0, ((1-\beta_2)\beta_1 + \beta_2) I\big)\\
&\sim \mathcal{N}\big(0, (1 - (1-\beta_2)(1-\beta_1)) I\big)\\
&\sim \sqrt{1 - \textstyle\prod_{i=1}^{2} (1 - \beta_i)} \cdot \mathcal{N}(0, I)
\end{align*}
$$
by the sum-of-Gaussians rule: if $X \sim \mathcal{N}(0, \sigma_1^2 I)$ and $Y \sim \mathcal{N}(0, \sigma_2^2 I)$ are independent, then $X + Y \sim \mathcal{N}\big(0, (\sigma_1^2 + \sigma_2^2) I\big)$. Repeating this argument at every step, the accumulated noise after $t$ steps is distributed as $\sqrt{1 - \prod_{i=1}^{t} (1 - \beta_i)} \cdot \mathcal{N}(0, I)$. Now, letting $\alpha_t = \prod_{i=1}^t (1 - \beta_i)$ (written $\bar{\alpha}_t$ in the DDPM paper, where $\alpha_t$ instead denotes $1 - \beta_t$), we have
$$x_t = \sqrt{\alpha_t}\,x_0 + \sqrt{1 - \alpha_t} \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
Since $0 < \beta_i < 1$, every factor $1 - \beta_i$ lies in $(0, 1)$, so $\alpha_t$ shrinks with every step. If the $\beta_i$ are also bounded away from zero (e.g. $\beta_i \ge \beta_{\min} > 0$), then $\alpha_t \le (1 - \beta_{\min})^t \to 0$, so $\lim_{t \to \infty} \sqrt{\alpha_t} = 0$ while $\lim_{t \to \infty} \sqrt{1 - \alpha_t} = 1$.
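A numerical sanity check of the closed form, assuming the linear schedule from $0.0001$ to $0.02$ over $T = 1000$ steps: `q_sample` (a hypothetical name) jumps straight from $x_0$ to $x_t$, `q_iterate` applies the one-step update $t$ times, and both should give samples with mean $\sqrt{\alpha_t}\,x_0$ and variance $1 - \alpha_t$. The loop over `v` also checks the two-step variance calculation, generalized to $t$ steps.

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # assumed linear schedule
alpha = np.cumprod(1.0 - betas)     # alpha[t - 1] = prod_{i=1}^{t} (1 - beta_i)

def q_sample(x0, t, rng):
    """Closed form: jump from x_0 straight to step t (1-indexed)."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha[t - 1]) * x0 + np.sqrt(1.0 - alpha[t - 1]) * eps

def q_iterate(x0, t, rng):
    """Reference: apply x_t = sqrt(1 - beta_t) x_{t-1} + sqrt(beta_t) eps, t times."""
    x = x0
    for beta in betas[:t]:
        x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * rng.standard_normal(x.shape)
    return x

# Noise variance via the step-by-step recursion v_t = (1 - beta_t) v_{t-1} + beta_t;
# it should equal 1 - alpha_t up to floating-point error.
t = 200
v = 0.0
for beta in betas[:t]:
    v = (1.0 - beta) * v + beta

# Monte Carlo: many copies of one scalar "pixel" x_0 = 2.
x0 = np.full(50_000, 2.0)
closed = q_sample(x0, t, rng)
iterated = q_iterate(x0, t, rng)
print(f"mean:     closed {closed.mean():.3f}, iterated {iterated.mean():.3f}, theory {2 * np.sqrt(alpha[t - 1]):.3f}")
print(f"variance: closed {closed.var():.3f}, iterated {iterated.var():.3f}, theory {1 - alpha[t - 1]:.3f}")
```

Being able to jump directly to any step $t$ is what makes step 2 of the training procedure cheap: a random noise level can be sampled without simulating all the earlier steps.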
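Thus, as the number of noising steps increases, the contribution of the original signal $x_0$ vanishes while the noise takes over the entire image. Furthermore, $x_t$ given $x_0$ is always normally distributed, $q(x_t \mid x_0) = \mathcal{N}\big(x_t; \sqrt{\alpha_t}\,x_0, (1 - \alpha_t) I\big)$, so for large $t$ it approaches $\mathcal{N}(0, I)$ regardless of the starting image.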
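To see this limit in practice, a sketch with the same assumed linear schedule ($0.0001$ to $0.02$, $T = 1000$): $\sqrt{\alpha_T}$ comes out tiny, and two hypothetical constant "images" (all pixels $+1$ and all pixels $-1$) both end up with mean close to 0 and standard deviation close to 1.

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # assumed linear schedule
alpha_T = np.prod(1.0 - betas)      # alpha_T = prod_{i=1}^{T} (1 - beta_i)

print(f"signal coefficient sqrt(alpha_T)     = {np.sqrt(alpha_T):.4f}")
print(f"noise coefficient  sqrt(1 - alpha_T) = {np.sqrt(1.0 - alpha_T):.6f}")

# Two very different hypothetical "images": all pixels +1 versus all pixels -1.
stats = []
for x0 in (np.full(100_000, 1.0), np.full(100_000, -1.0)):
    xT = np.sqrt(alpha_T) * x0 + np.sqrt(1.0 - alpha_T) * rng.standard_normal(x0.shape)
    stats.append((xT.mean(), xT.std()))
    print(f"x_0 = {x0[0]:+.0f}: mean {xT.mean():+.3f}, std {xT.std():.3f}")
```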