- Mô hình khuếch tán được dùng không chỉ cho tạo ảnh mà còn cho các bài toán cần lấy mẫu từ phân phối đa mode như âm thanh, video, 3D, thiết kế protein và lập kế hoạch đường đi cho robot; bài hướng dẫn này kết nối huấn luyện và lấy mẫu từ góc nhìn tối ưu hóa
- Quá trình huấn luyện tạo dữ liệu nhiễu (x_\sigma=x_0+\sigma\epsilon), rồi tối thiểu hóa sai số bình phương trung bình để mạng nơ-ron (\epsilon_\theta(x,\sigma)) dự đoán hướng nhiễu
- Denoiser đã huấn luyện có thể được diễn giải như một phép chiếu xấp xỉ lên tập dữ liệu (\mathcal{K}), và denoiser lý tưởng liên hệ với gradient của hàm khoảng cách bình phương đã được làm mượt theo (\sigma)
- Lấy mẫu DDIM có thể được xem là gradient descent xấp xỉ trên (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2), còn lịch (\sigma_t) quyết định số vòng lặp và chi phí đánh giá denoiser
- Khi kết hợp cập nhật ước lượng gradient với việc thêm nhiễu, có thể mô tả chung DDIM, DDPM và sampler cải tiến của tác giả bằng các tham số
gamvàmu, rồi mở rộng sang ví dụ toy model và latent diffusion
Mô hình khuếch tán dưới góc nhìn tối ưu hóa
- Mô hình khuếch tán có thế mạnh trong việc sinh mẫu từ phân phối đa mode, và được áp dụng không chỉ cho công cụ tạo văn bản-thành-ảnh như Stable Diffusion mà còn cho âm thanh, video, tạo 3D, thiết kế protein và lập kế hoạch đường đi cho robot
- Nền tảng lý thuyết của bài hướng dẫn là diễn giải theo tối ưu hóa từ bài báo ICML 2024 và bài báo liên quan
- Phần triển khai chủ yếu tham chiếu
smalldiffusion, còn mã trong bài được đơn giản hóa hơn thư viện gốc để phục vụ mục đích học tập
Huấn luyện: dự đoán hướng nhiễu
- Mô hình khuếch tán học tập dữ liệu (\mathcal{K}) từ các ví dụ huấn luyện và hướng tới việc sinh mẫu từ tập đó
- Với ảnh, (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) là tập các giá trị pixel tương ứng với những ảnh thực tế
- Cùng khuôn khổ này cũng áp dụng cho các miền rời rạc như âm thanh, video, quỹ đạo robot và văn bản
- Quy trình huấn luyện có thể xem theo ba bước
- Lấy mẫu (x_0 \sim \mathcal{K}), (\sigma), và (\epsilon \sim N(0,I))
- Tạo dữ liệu có trộn nhiễu bằng (x_\sigma=x_0+\sigma\epsilon)
- Tối thiểu hóa hàm mất mát bình phương để (\epsilon_\theta(x_\sigma,\sigma)) dự đoán (\epsilon)
- Trong mã,
training_loopvới mỗi batchx0sẽ tạosigmavàepsquagenerate_train_sample, rồi tối ưu MSE giữa đầu ra củamodel(x0 + sigma * eps, sigma)vàeps - Thay vì lấy mẫu (\sigma) đều trên một khoảng liên tục, người ta lấy từ lịch (\sigma) đã rời rạc hóa thành (N) giá trị
- Lớp
Schedulebao bọc danh sáchsigmaskhả dụng và lấy mẫu giá trị theo từng batch trong lúc huấn luyện - Ví dụ trong bài dùng
ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10) ScheduleDDPMdành cho mô hình khuếch tán trong không gian pixel, cònScheduleLDMdành cho latent diffusion như Stable Diffusion
- Lớp
Ví dụ toy Swissroll
- Bộ dữ liệu toy là tập điểm xoắn ốc dùng trong một trong những bài báo khuếch tán đầu tiên, Sohl-Dickstein et al. 2015, với (\mathcal{K}\subset\mathbb{R}^2)
- Với bộ dữ liệu đơn giản này, denoiser được cài đặt bằng MLP
- Đầu vào là phép nối giữa (x\in\mathbb{R}^2) và embedding 2 chiều của (\sigma)
- Đầu ra là giá trị dự đoán nhiễu (\epsilon\in\mathbb{R}^2)
- Nhiều mô hình khuếch tán dùng sinusoidal positional embedding cho (\sigma), nhưng ví dụ này cho thấy embedding 2 chiều đơn giản cũng hoạt động tốt
- Thiết lập huấn luyện của ví dụ dùng
ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)vàepochs=15000 - Denoiser đã huấn luyện có thể được trực quan hóa thành trường vector bằng cách vẽ (x-\sigma\epsilon_\theta(x,\sigma))
- Khi (\sigma) lớn, denoiser có xu hướng dự đoán trung bình dữ liệu
- Khi (\sigma) nhỏ và đầu vào (x) gần dữ liệu, nó dự đoán các điểm dữ liệu thực tế
Diễn giải Denoising như phép chiếu
- Hàm khoảng cách tới tập dữ liệu (\mathcal{K}) được định nghĩa là (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}})
- Phép chiếu của (x) lên (\mathcal{K}), (\mathrm{proj}_{\mathcal{K}}(x)), là tập các điểm trong (\mathcal{K}) đạt được khoảng cách đó
- Nếu (\mathcal{K}) là tập đóng, (x\notin\mathcal{K}) và phép chiếu là duy nhất, thì gradient của hàm khoảng cách bình phương là (x-\mathrm{proj}_{\mathcal{K}}(x))
- Vì hàm khoảng cách (\mathrm{dist}_{\mathcal{K}}) không khả vi ở mọi nơi, bài viết thay
minbằng softmin để đưa vào hàm khoảng cách bình phương đã được làm mượt theo (\sigma) - Gradient của hàm khoảng cách đã làm mượt hướng về phía trung bình có trọng số của các điểm trong (\mathcal{K}), với trọng số do (x) quyết định
Denoiser lý tưởng và mô hình sai số tương đối
- Denoiser lý tưởng (\epsilon^*) là denoiser tối thiểu hóa chính xác hàm mất mát huấn luyện tại một (\sigma) nhất định
- Nếu dữ liệu là phân phối đều rời rạc trên tập hữu hạn (\mathcal{K}), denoiser lý tưởng có thể được biểu diễn bằng công thức dạng đóng
- Trọng số của mỗi điểm dữ liệu được xác định theo khoảng cách giữa (x_\sigma) và điểm đó
- Với bộ dữ liệu nhỏ, có thể tính trực tiếp bằng
IdealDenoiser
- Trên dữ liệu toy, denoiser lý tưởng hướng về trung bình dữ liệu khi (\sigma) lớn, và hướng về điểm dữ liệu gần nhất khi (\sigma) nhỏ
- Định lý cốt lõi thiết lập quan hệ (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma)) với mọi (\sigma>0), (x\in\mathbb{R}^n)
- Mô hình sai số tương đối dùng điều kiện rằng (x-\sigma\epsilon_\theta(x,\sigma)) xấp xỉ tốt (\mathrm{proj}_{\mathcal{K}}(x))
- Áp dụng khi (\sqrt{n}\sigma) ước lượng tốt (\mathrm{dist}_{\mathcal{K}}(x)) trong một hệ số hằng số
- Giả định sai số bị chặn bởi (\eta\mathrm{dist}_{\mathcal{K}}(x))
- Ở mức nhiễu thấp, theo manifold hypothesis, phần lớn nhiễu bổ sung vuông góc với đa tạp dữ liệu nên denoising xấp xỉ phép chiếu
- Ở mức nhiễu cao, nếu (\sigma) lớn hơn đường kính của (\mathcal{K}), một denoiser dự đoán trung bình có trọng số của dữ liệu vẫn có sai số tương đối nhỏ
- CIFAR-10 có quy mô đủ nhỏ để tính denoiser lý tưởng, và trong thực nghiệm sai số tương đối giữa phép chiếu chính xác và đầu ra denoiser lý tưởng trên quỹ đạo lấy mẫu được ghi nhận là nhỏ
Lấy mẫu: denoising lặp và DDIM
- Khi đã có denoiser được huấn luyện, từ dữ liệu nhiễu (x_t) và mức nhiễu (\sigma_t) ta dự đoán (x_0) bằng (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t))
- Điểm khởi đầu được chọn sao cho (\sigma_T) lớn hơn nhiều so với đường kính của (\mathcal{K}), rồi lấy mẫu độc lập (x_T) từ (N(0,\sigma_T)) để nó nằm xa (\mathcal{K})
- Ở mức nhiễu cao, một lần gọi denoiser có thể vẫn có sai số tuyệt đối lớn dù sai số tương đối nhỏ, và dự đoán của denoiser lý tưởng thường gần trung bình dữ liệu
- Vì vậy, quá trình lấy mẫu lặp lại việc gọi denoiser theo lịch (\sigma_t) để tạo chuỗi (x_T,\ldots,x_0)
- Cập nhật (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) tương đương với thuật toán lấy mẫu DDIM tất định sau một phép đổi tọa độ
- Chứng minh tính tương đương với DDIM nằm ở Appendix A của bài báo
DDIM dưới góc nhìn tối thiểu hóa khoảng cách
- DDIM có thể được diễn giải như gradient descent xấp xỉ trên (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2)
- Kích thước bước là (1-\sigma_{t-1}/\sigma_t)
- (\nabla f(x_t)) được ước lượng bằng (\epsilon_\theta(x_t,\sigma_t))
- Lịch (\sigma_t) quyết định số lượng và độ lớn của các bước gradient trong lúc lấy mẫu
- Nếu quá ít bước, (\mathrm{dist}_{\mathcal{K}}(x_t)) có thể không giảm nên không hội tụ
- Nếu dùng nhiều bước nhỏ, số lần đánh giá denoiser tăng lên và chi phí tính toán cao hơn
- admissible schedule là lịch mà ở mỗi vòng lặp, (\sqrt{n}\sigma_t) khớp với (\mathrm{dist}_{\mathcal{K}}(x_t)) trong một hệ số hằng số
- Chuỗi (\sigma_t) log-linear giảm theo cấp số nhân là một admissible schedule
- Theo định lý, nếu trên các (x_t) do DDIM sinh ra, (\nabla\mathrm{dist}{\mathcal{K}}(x)) tồn tại và (\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T), thì (x_t) được sinh bởi gradient descent trên hàm khoảng cách bình phương và duy trì (\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t)
- Trong ví dụ toy, bài viết cài đặt sampler DDIM 20 bước bằng cách lấy mẫu con từ lịch log-linear ban đầu; đa số mẫu sinh ra gần với dữ liệu gốc nhưng vẫn còn chỗ để cải thiện
Sampler cải tiến dựa trên ước lượng gradient
- Dựa trên việc (\nabla\mathrm{dist}{\mathcal{K}}(x)) là bất biến giữa (x) và (\mathrm{proj}{\mathcal{K}}(x)), tác giả dùng cập nhật trộn giữa ước lượng hiện tại và ước lượng trước đó
- Cập nhật (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) là cách hiệu chỉnh sai số của bước trước bằng ước lượng hiện tại
- Trên các mẫu toy model, cách này hội tụ nhanh hơn DDIM và các mẫu gần dữ liệu gốc hơn
- So với DDIM, sampler này có thể được diễn giải như thêm momentum; quỹ đạo có thể overshoot nhưng cũng có thể hội tụ nhanh hơn
- Việc thêm nhiễu trong quá trình sinh mẫu cải thiện chất lượng lấy mẫu theo kinh nghiệm
- Để giữ nguyên lịch (\sigma_t) ban đầu, ta denoise đến (\sigma_{t'}) nhỏ hơn rồi thêm lại nhiễu (w_t\sim N(0,I))
- Khi (\mu=\frac{1}{2}), nó khôi phục chính xác DDPM sampler
- Cập nhật đầy đủ (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) tổng quát hóa ba sampler
- DDIM:
gam=1, mu=0 - DDPM:
gam=1, mu=0.5 - Sampler ước lượng gradient:
gam=2, mu=0
- DDIM:
Mô hình lớn hơn và tài liệu tham khảo
- Mã huấn luyện ở trên không chỉ dùng cho dữ liệu toy mà còn có thể dùng để huấn luyện mô hình khuếch tán ảnh từ đầu
- Ví dụ FashionMNIST huấn luyện trên bộ dữ liệu FashionMNIST và được đưa ra như một ví dụ đạt điểm FID đứng thứ 2 trên bảng xếp hạng Papers with Code
- Mã lấy mẫu cũng có thể dùng nguyên trạng cho các mô hình latent diffusion đã huấn luyện sẵn
- Ví dụ dùng
ScheduleLDM(1000)vàModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base') - Điều kiện văn bản được đặt là
An astronaut riding a horse, lấy mẫu với 50 bước (\sigma) rồi giải mã latent
- Ví dụ dùng
- Tác động của hạng tử momentum (\gamma) được minh họa qua các hình so sánh trong bài toán sinh ảnh văn bản-thành-ảnh độ phân giải cao
- Một số tài liệu đáng xem thêm
- What are diffusion models: giới thiệu mô hình khuếch tán theo góc nhìn thời gian rời rạc đảo ngược Markov process
- Generative modeling by estimating gradients of the data distribution: giới thiệu mô hình khuếch tán theo góc nhìn thời gian liên tục đảo ngược phương trình vi phân ngẫu nhiên
- The annotated diffusion model: giải thích chi tiết cách triển khai mô hình khuếch tán bằng PyTorch
1 bình luận
Ý kiến trên Hacker News
Nếu có câu hỏi thì tôi có thể trả lời.
Tôi đặc biệt thích phần bàn về quỹ đạo, vì nó tạo động lực để hiểu những chỗ mà nhiều người gặp khó ở các chủ đề như scheduler. Dù không đầy đủ như bài của Song hay Lilian, nó dễ tiếp cận hơn nhiều nên tôi định giới thiệu cho người khác.
Nhân tiện, một người bạn của tôi từng viết một triển khai khuếch tán tối giản; xét theo góc nhìn DDPM thì nó “đầy đủ” hơn một chút và khá hữu ích: https://github.com/VSehwag/minimal-diffusion/
Với tư cách người từng thử nghiệm một chút quy trình lấy mẫu trong Stable Diffusion, tôi cũng muốn xem so sánh thời gian hội tụ và số bước so với DDIM. Tôi tò mò liệu có mối liên hệ nào giữa momentum, hội tụ và sai số không. Chẳng hạn, sẽ rất hay nếu có so sánh kiểu sampler có momentum 16 bước gần như tương đương DDIM 20 bước ± hạng sai số hay không.
get_sigma_embeds(batches, sigma)không dùng đầu vào thứ nhất. Tôi tự hỏi có phải ý định là broadcastsigmathành dạng(batches, 1)không.Bài này đi sâu hơn nhiều vào chi tiết toán học, đồng thời đi kèm một triển khai tối giản rất dễ hiểu dưới 500 dòng.
Sẽ rất hay nếu mở rộng sang phiên bản diffusion transformer đang vận hành Sora và các mô hình tạo video khác. Có lẽ có thể kết hợp bài này với https://jaykmody.com/blog/gpt-from-scratch/ để tạo một bài nhập môn “diffusion transformer từ con số không”.
Ngược lại, nếu thật sự muốn đào sâu, tôi khuyên đọc các công trình của Kingma, Gao, Ricky Tian Qi Chen và các học trò của Max Welling (Tomczak là postdoc, Hoogeboom, v.v.), cùng với người có công thầm lặng Aapo Hyvärinen. Một ví dụ thuộc nhánh tương đối nhẹ nhàng hơn của Kingma & Gao, đồng thời cũng liên quan đến bài báo SD3, ở đây: https://arxiv.org/abs/2303.00848
Điểm đáng tiếc là khả năng tiếp cận thấp hơn vì phụ thuộc nhiều vào việc biết và hiểu các nghiên cứu trước đó, nhưng cũng khó gọi đây là một phê bình có ý nghĩa. Vì đó là nghiên cứu, không phải tài liệu giáo dục cho đại chúng.
n_embd; bản thân quá trình khuếch tán có thể giữ nguyên.[1] https://yang-song.net/blog/2021/score/
[2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
Theo góc nhìn của chúng tôi, lý do mô hình khuếch tán dễ huấn luyện là vì chúng dùng mục tiêu huấn luyện dự đoán gradient của hàm khoảng cách đã được làm trơn, thay vì dự đoán gradient của hàm khoảng cách chính xác. Lấy mẫu bằng mô hình khuếch tán giống như thực hiện nhiều bước gradient xấp xỉ.
Để hiểu sâu hơn về mô hình khuếch tán, tôi khuyên nên đọc tất cả các bài blog này và học các cách diễn giải khác nhau.
Tuy nhiên, cách tiếp cận trong bài này có vẻ cho phép các thí nghiệm thú vị hơn, chẳng hạn phân tích sai số của bộ khử nhiễu.
[1] https://arxiv.org/pdf/2305.03486.pdf
Ví dụ, vì sao trình tạo ảnh lại khó tạo phím đàn piano? Có vẻ để tạo được cấu trúc các phím đen xen kẽ theo nhóm hai và ba, nó cần biểu diễn tốt hơn các ràng buộc khoảng cách trung gian.