2 điểm bởi GN⁺ 2024-03-12 | 1 bình luận | Chia sẻ qua WhatsApp
  • Mô hình khuếch tán được dùng không chỉ cho tạo ảnh mà còn cho các bài toán cần lấy mẫu từ phân phối đa mode như âm thanh, video, 3D, thiết kế protein và lập kế hoạch đường đi cho robot; bài hướng dẫn này kết nối huấn luyện và lấy mẫu từ góc nhìn tối ưu hóa
  • Quá trình huấn luyện tạo dữ liệu nhiễu (x_\sigma=x_0+\sigma\epsilon), rồi tối thiểu hóa sai số bình phương trung bình để mạng nơ-ron (\epsilon_\theta(x,\sigma)) dự đoán hướng nhiễu
  • Denoiser đã huấn luyện có thể được diễn giải như một phép chiếu xấp xỉ lên tập dữ liệu (\mathcal{K}), và denoiser lý tưởng liên hệ với gradient của hàm khoảng cách bình phương đã được làm mượt theo (\sigma)
  • Lấy mẫu DDIM có thể được xem là gradient descent xấp xỉ trên (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2), còn lịch (\sigma_t) quyết định số vòng lặp và chi phí đánh giá denoiser
  • Khi kết hợp cập nhật ước lượng gradient với việc thêm nhiễu, có thể mô tả chung DDIM, DDPM và sampler cải tiến của tác giả bằng các tham số gammu, rồi mở rộng sang ví dụ toy model và latent diffusion

Mô hình khuếch tán dưới góc nhìn tối ưu hóa

  • Mô hình khuếch tán có thế mạnh trong việc sinh mẫu từ phân phối đa mode, và được áp dụng không chỉ cho công cụ tạo văn bản-thành-ảnh như Stable Diffusion mà còn cho âm thanh, video, tạo 3D, thiết kế protein và lập kế hoạch đường đi cho robot
  • Nền tảng lý thuyết của bài hướng dẫn là diễn giải theo tối ưu hóa từ bài báo ICML 2024bài báo liên quan
  • Phần triển khai chủ yếu tham chiếu smalldiffusion, còn mã trong bài được đơn giản hóa hơn thư viện gốc để phục vụ mục đích học tập

Huấn luyện: dự đoán hướng nhiễu

  • Mô hình khuếch tán học tập dữ liệu (\mathcal{K}) từ các ví dụ huấn luyện và hướng tới việc sinh mẫu từ tập đó
    • Với ảnh, (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) là tập các giá trị pixel tương ứng với những ảnh thực tế
    • Cùng khuôn khổ này cũng áp dụng cho các miền rời rạc như âm thanh, video, quỹ đạo robot và văn bản
  • Quy trình huấn luyện có thể xem theo ba bước
    • Lấy mẫu (x_0 \sim \mathcal{K}), (\sigma), và (\epsilon \sim N(0,I))
    • Tạo dữ liệu có trộn nhiễu bằng (x_\sigma=x_0+\sigma\epsilon)
    • Tối thiểu hóa hàm mất mát bình phương để (\epsilon_\theta(x_\sigma,\sigma)) dự đoán (\epsilon)
  • Trong mã, training_loop với mỗi batch x0 sẽ tạo sigmaeps qua generate_train_sample, rồi tối ưu MSE giữa đầu ra của model(x0 + sigma * eps, sigma)eps
  • Thay vì lấy mẫu (\sigma) đều trên một khoảng liên tục, người ta lấy từ lịch (\sigma) đã rời rạc hóa thành (N) giá trị
    • Lớp Schedule bao bọc danh sách sigmas khả dụng và lấy mẫu giá trị theo từng batch trong lúc huấn luyện
    • Ví dụ trong bài dùng ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10)
    • ScheduleDDPM dành cho mô hình khuếch tán trong không gian pixel, còn ScheduleLDM dành cho latent diffusion như Stable Diffusion

Ví dụ toy Swissroll

  • Bộ dữ liệu toy là tập điểm xoắn ốc dùng trong một trong những bài báo khuếch tán đầu tiên, Sohl-Dickstein et al. 2015, với (\mathcal{K}\subset\mathbb{R}^2)
  • Với bộ dữ liệu đơn giản này, denoiser được cài đặt bằng MLP
    • Đầu vào là phép nối giữa (x\in\mathbb{R}^2) và embedding 2 chiều của (\sigma)
    • Đầu ra là giá trị dự đoán nhiễu (\epsilon\in\mathbb{R}^2)
    • Nhiều mô hình khuếch tán dùng sinusoidal positional embedding cho (\sigma), nhưng ví dụ này cho thấy embedding 2 chiều đơn giản cũng hoạt động tốt
  • Thiết lập huấn luyện của ví dụ dùng ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)epochs=15000
  • Denoiser đã huấn luyện có thể được trực quan hóa thành trường vector bằng cách vẽ (x-\sigma\epsilon_\theta(x,\sigma))
    • Khi (\sigma) lớn, denoiser có xu hướng dự đoán trung bình dữ liệu
    • Khi (\sigma) nhỏ và đầu vào (x) gần dữ liệu, nó dự đoán các điểm dữ liệu thực tế

Diễn giải Denoising như phép chiếu

  • Hàm khoảng cách tới tập dữ liệu (\mathcal{K}) được định nghĩa là (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}})
  • Phép chiếu của (x) lên (\mathcal{K}), (\mathrm{proj}_{\mathcal{K}}(x)), là tập các điểm trong (\mathcal{K}) đạt được khoảng cách đó
  • Nếu (\mathcal{K}) là tập đóng, (x\notin\mathcal{K}) và phép chiếu là duy nhất, thì gradient của hàm khoảng cách bình phương là (x-\mathrm{proj}_{\mathcal{K}}(x))
  • Vì hàm khoảng cách (\mathrm{dist}_{\mathcal{K}}) không khả vi ở mọi nơi, bài viết thay min bằng softmin để đưa vào hàm khoảng cách bình phương đã được làm mượt theo (\sigma)
  • Gradient của hàm khoảng cách đã làm mượt hướng về phía trung bình có trọng số của các điểm trong (\mathcal{K}), với trọng số do (x) quyết định

Denoiser lý tưởng và mô hình sai số tương đối

  • Denoiser lý tưởng (\epsilon^*) là denoiser tối thiểu hóa chính xác hàm mất mát huấn luyện tại một (\sigma) nhất định
  • Nếu dữ liệu là phân phối đều rời rạc trên tập hữu hạn (\mathcal{K}), denoiser lý tưởng có thể được biểu diễn bằng công thức dạng đóng
    • Trọng số của mỗi điểm dữ liệu được xác định theo khoảng cách giữa (x_\sigma) và điểm đó
    • Với bộ dữ liệu nhỏ, có thể tính trực tiếp bằng IdealDenoiser
  • Trên dữ liệu toy, denoiser lý tưởng hướng về trung bình dữ liệu khi (\sigma) lớn, và hướng về điểm dữ liệu gần nhất khi (\sigma) nhỏ
  • Định lý cốt lõi thiết lập quan hệ (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma)) với mọi (\sigma>0), (x\in\mathbb{R}^n)
  • Mô hình sai số tương đối dùng điều kiện rằng (x-\sigma\epsilon_\theta(x,\sigma)) xấp xỉ tốt (\mathrm{proj}_{\mathcal{K}}(x))
    • Áp dụng khi (\sqrt{n}\sigma) ước lượng tốt (\mathrm{dist}_{\mathcal{K}}(x)) trong một hệ số hằng số
    • Giả định sai số bị chặn bởi (\eta\mathrm{dist}_{\mathcal{K}}(x))
    • Ở mức nhiễu thấp, theo manifold hypothesis, phần lớn nhiễu bổ sung vuông góc với đa tạp dữ liệu nên denoising xấp xỉ phép chiếu
    • Ở mức nhiễu cao, nếu (\sigma) lớn hơn đường kính của (\mathcal{K}), một denoiser dự đoán trung bình có trọng số của dữ liệu vẫn có sai số tương đối nhỏ
  • CIFAR-10 có quy mô đủ nhỏ để tính denoiser lý tưởng, và trong thực nghiệm sai số tương đối giữa phép chiếu chính xác và đầu ra denoiser lý tưởng trên quỹ đạo lấy mẫu được ghi nhận là nhỏ

Lấy mẫu: denoising lặp và DDIM

  • Khi đã có denoiser được huấn luyện, từ dữ liệu nhiễu (x_t) và mức nhiễu (\sigma_t) ta dự đoán (x_0) bằng (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t))
  • Điểm khởi đầu được chọn sao cho (\sigma_T) lớn hơn nhiều so với đường kính của (\mathcal{K}), rồi lấy mẫu độc lập (x_T) từ (N(0,\sigma_T)) để nó nằm xa (\mathcal{K})
  • Ở mức nhiễu cao, một lần gọi denoiser có thể vẫn có sai số tuyệt đối lớn dù sai số tương đối nhỏ, và dự đoán của denoiser lý tưởng thường gần trung bình dữ liệu
  • Vì vậy, quá trình lấy mẫu lặp lại việc gọi denoiser theo lịch (\sigma_t) để tạo chuỗi (x_T,\ldots,x_0)
  • Cập nhật (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) tương đương với thuật toán lấy mẫu DDIM tất định sau một phép đổi tọa độ
    • Chứng minh tính tương đương với DDIM nằm ở Appendix A của bài báo

DDIM dưới góc nhìn tối thiểu hóa khoảng cách

  • DDIM có thể được diễn giải như gradient descent xấp xỉ trên (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2)
    • Kích thước bước là (1-\sigma_{t-1}/\sigma_t)
    • (\nabla f(x_t)) được ước lượng bằng (\epsilon_\theta(x_t,\sigma_t))
  • Lịch (\sigma_t) quyết định số lượng và độ lớn của các bước gradient trong lúc lấy mẫu
    • Nếu quá ít bước, (\mathrm{dist}_{\mathcal{K}}(x_t)) có thể không giảm nên không hội tụ
    • Nếu dùng nhiều bước nhỏ, số lần đánh giá denoiser tăng lên và chi phí tính toán cao hơn
  • admissible schedule là lịch mà ở mỗi vòng lặp, (\sqrt{n}\sigma_t) khớp với (\mathrm{dist}_{\mathcal{K}}(x_t)) trong một hệ số hằng số
    • Chuỗi (\sigma_t) log-linear giảm theo cấp số nhân là một admissible schedule
  • Theo định lý, nếu trên các (x_t) do DDIM sinh ra, (\nabla\mathrm{dist}{\mathcal{K}}(x)) tồn tại và (\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T), thì (x_t) được sinh bởi gradient descent trên hàm khoảng cách bình phương và duy trì (\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t)
  • Trong ví dụ toy, bài viết cài đặt sampler DDIM 20 bước bằng cách lấy mẫu con từ lịch log-linear ban đầu; đa số mẫu sinh ra gần với dữ liệu gốc nhưng vẫn còn chỗ để cải thiện

Sampler cải tiến dựa trên ước lượng gradient

  • Dựa trên việc (\nabla\mathrm{dist}{\mathcal{K}}(x)) là bất biến giữa (x) và (\mathrm{proj}{\mathcal{K}}(x)), tác giả dùng cập nhật trộn giữa ước lượng hiện tại và ước lượng trước đó
  • Cập nhật (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) là cách hiệu chỉnh sai số của bước trước bằng ước lượng hiện tại
  • Trên các mẫu toy model, cách này hội tụ nhanh hơn DDIM và các mẫu gần dữ liệu gốc hơn
  • So với DDIM, sampler này có thể được diễn giải như thêm momentum; quỹ đạo có thể overshoot nhưng cũng có thể hội tụ nhanh hơn
  • Việc thêm nhiễu trong quá trình sinh mẫu cải thiện chất lượng lấy mẫu theo kinh nghiệm
    • Để giữ nguyên lịch (\sigma_t) ban đầu, ta denoise đến (\sigma_{t'}) nhỏ hơn rồi thêm lại nhiễu (w_t\sim N(0,I))
    • Khi (\mu=\frac{1}{2}), nó khôi phục chính xác DDPM sampler
  • Cập nhật đầy đủ (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) tổng quát hóa ba sampler
    • DDIM: gam=1, mu=0
    • DDPM: gam=1, mu=0.5
    • Sampler ước lượng gradient: gam=2, mu=0

Mô hình lớn hơn và tài liệu tham khảo

  • Mã huấn luyện ở trên không chỉ dùng cho dữ liệu toy mà còn có thể dùng để huấn luyện mô hình khuếch tán ảnh từ đầu
  • Ví dụ FashionMNIST huấn luyện trên bộ dữ liệu FashionMNIST và được đưa ra như một ví dụ đạt điểm FID đứng thứ 2 trên bảng xếp hạng Papers with Code
  • Mã lấy mẫu cũng có thể dùng nguyên trạng cho các mô hình latent diffusion đã huấn luyện sẵn
    • Ví dụ dùng ScheduleLDM(1000)ModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base')
    • Điều kiện văn bản được đặt là An astronaut riding a horse, lấy mẫu với 50 bước (\sigma) rồi giải mã latent
  • Tác động của hạng tử momentum (\gamma) được minh họa qua các hình so sánh trong bài toán sinh ảnh văn bản-thành-ảnh độ phân giải cao
  • Một số tài liệu đáng xem thêm

1 bình luận

 
GN⁺ 2024-03-12
Ý kiến trên Hacker News
  • Tôi là tác giả. Trong lúc cố gắng hiểu mô hình khuếch tán, tôi nhận ra có thể đơn giản hóa đáng kể cả mã lẫn toán học, nên đã viết bài blog này và tạo thư viện khuếch tán này.
    Nếu có câu hỏi thì tôi có thể trả lời.
    • Từ góc nhìn của một nhà nghiên cứu, tôi không thích nhiều bài blog về mô hình khuếch tán, nhưng bài này thật sự rất hay. Nó đi thẳng vào cốt lõi mà vẫn chỉ ra những phần phức tạp thường gặp, không khiến người đọc lạc hướng hay bị phân tán.
      Tôi đặc biệt thích phần bàn về quỹ đạo, vì nó tạo động lực để hiểu những chỗ mà nhiều người gặp khó ở các chủ đề như scheduler. Dù không đầy đủ như bài của Song hay Lilian, nó dễ tiếp cận hơn nhiều nên tôi định giới thiệu cho người khác.
      Nhân tiện, một người bạn của tôi từng viết một triển khai khuếch tán tối giản; xét theo góc nhìn DDPM thì nó “đầy đủ” hơn một chút và khá hữu ích: https://github.com/VSehwag/minimal-diffusion/
    • Ở ảnh ví dụ cuối, hạng mục momentum có vẻ gây tác động bất lợi cho bức vẽ kỹ thuật số ngôi nhà. Trong ảnh gamma = 2.0, cánh cửa biến mất, nên để hiểu trực quan hiệu quả của DDIM sampler dùng thông tin gradient, tôi muốn biết thêm chi tiết về ví dụ đó.
      Với tư cách người từng thử nghiệm một chút quy trình lấy mẫu trong Stable Diffusion, tôi cũng muốn xem so sánh thời gian hội tụ và số bước so với DDIM. Tôi tò mò liệu có mối liên hệ nào giữa momentum, hội tụ và sai số không. Chẳng hạn, sẽ rất hay nếu có so sánh kiểu sampler có momentum 16 bước gần như tương đương DDIM 20 bước ± hạng sai số hay không.
    • Có vẻ get_sigma_embeds(batches, sigma) không dùng đầu vào thứ nhất. Tôi tự hỏi có phải ý định là broadcast sigma thành dạng (batches, 1) không.
    • Tôi tò mò liệu một số khái niệm này có xuất phát từ nguyên lý vật lý không. Có giống kiểu nói mạng nơ-ron được phỏng theo mạng nơ-ron sinh học không, và liệu có insight nào từ góc nhìn đó không.
  • Một bài hay khác cũng có tiêu đề Diffusion Models From Scratch: https://www.tonyduan.com/diffusion/index.html
    Bài này đi sâu hơn nhiều vào chi tiết toán học, đồng thời đi kèm một triển khai tối giản rất dễ hiểu dưới 500 dòng.
  • Thật tốt vì có mã. Các bài báo về khuếch tán nổi tiếng là có nhiều phương trình (https://twitter.com/cto_junior/status/1766518604395155830), nhưng với phần còn lại trong chúng ta, mã dễ đọc hơn nhiều và có thể còn chính xác hơn. Tôi cho rằng mọi bài báo lý thuyết đều nên đi kèm mã triển khai tham chiếu.
    Sẽ rất hay nếu mở rộng sang phiên bản diffusion transformer đang vận hành Sora và các mô hình tạo video khác. Có lẽ có thể kết hợp bài này với https://jaykmody.com/blog/gpt-from-scratch/ để tạo một bài nhập môn “diffusion transformer từ con số không”.
    • Dù các bài báo về khuếch tán đúng là nổi tiếng nhiều phương trình, nhưng thành thật mà nói, hầu hết các nhà nghiên cứu khuếch tán mà tôi biết cũng phản ứng y như vậy. Nhiều người lặp lại cùng những phương trình, và tôi nghĩ các phương trình đó thực chất gần với mục đích ôn lại hơn.
      Ngược lại, nếu thật sự muốn đào sâu, tôi khuyên đọc các công trình của Kingma, Gao, Ricky Tian Qi Chen và các học trò của Max Welling (Tomczak là postdoc, Hoogeboom, v.v.), cùng với người có công thầm lặng Aapo Hyvärinen. Một ví dụ thuộc nhánh tương đối nhẹ nhàng hơn của Kingma & Gao, đồng thời cũng liên quan đến bài báo SD3, ở đây: https://arxiv.org/abs/2303.00848
      Điểm đáng tiếc là khả năng tiếp cận thấp hơn vì phụ thuộc nhiều vào việc biết và hiểu các nghiên cứu trước đó, nhưng cũng khó gọi đây là một phê bình có ý nghĩa. Vì đó là nghiên cứu, không phải tài liệu giáo dục cho đại chúng.
    • Chỉ cần thay U-net bằng transformer encoder là được. Bỏ embedding và chiếu các patch ảnh thành vector kích thước n_embd; bản thân quá trình khuếch tán có thể giữ nguyên.
  • Bài hay, nhưng tôi có cảm giác thiếu một tính chất quan trọng[1]: mô hình khuếch tán mô hình hóa hàm score (đạo hàm của log xác suất), cũng như việc lấy mẫu khuếch tán tương tự động lực học Langevin[2]. Tôi nghĩ các góc nhìn này giải thích tốt vì sao chúng dễ huấn luyện hơn GAN. Vì mục tiêu mô hình hóa dễ hơn.
    [1] https://yang-song.net/blog/2021/score/
    [2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
    • Đúng vậy. Các bài blog này đưa ra những cách diễn giải mô hình khuếch tán khác với góc nhìn “chiếu về dữ liệu” được mô tả trong bài. Có thể xem đó là nhiều cách diễn giải khác nhau cho cùng mục tiêu huấn luyện và quy trình lấy mẫu.
      Theo góc nhìn của chúng tôi, lý do mô hình khuếch tán dễ huấn luyện là vì chúng dùng mục tiêu huấn luyện dự đoán gradient của hàm khoảng cách đã được làm trơn, thay vì dự đoán gradient của hàm khoảng cách chính xác. Lấy mẫu bằng mô hình khuếch tán giống như thực hiện nhiều bước gradient xấp xỉ.
      Để hiểu sâu hơn về mô hình khuếch tán, tôi khuyên nên đọc tất cả các bài blog này và học các cách diễn giải khác nhau.
  • Rất thú vị. Tôi lập tức nghĩ đến Iterative alpha-(de)Blending[1]. Công trình này cũng cố xây dựng một mô hình khuếch tán đơn giản hơn về mặt khái niệm, và đi đến kết luận là có thể chính thức hóa nó như một quá trình chiếu lặp xấp xỉ.
    Tuy nhiên, cách tiếp cận trong bài này có vẻ cho phép các thí nghiệm thú vị hơn, chẳng hạn phân tích sai số của bộ khử nhiễu.
    [1] https://arxiv.org/pdf/2305.03486.pdf
  • Phần giải thích lý thuyết rất hay. Có vẻ đây là lời giải thích độc lập với dataset, nhưng tôi tò mò về các chi tiết cụ thể của tạo ảnh thực tế.
    Ví dụ, vì sao trình tạo ảnh lại khó tạo phím đàn piano? Có vẻ để tạo được cấu trúc các phím đen xen kẽ theo nhóm hai và ba, nó cần biểu diễn tốt hơn các ràng buộc khoảng cách trung gian.
    • Cái này giống vấn đề ngón tay. Mỗi lần phải làm đúng cả số lượng, kích thước, góc, vị trí, v.v.; chỉ cần sai một thứ là mọi người nhận ra rất nhanh. Khác với những đối tượng như cành cây, nơi vị trí phân nhánh dù “sai” thì người ta cũng khó nhận ra.
  • Có phải một phần ý tưởng của khuếch tán là tăng dữ liệu huấn luyện lên rất nhiều không? Kiểu như có thể đối chiếu các ảnh đã được khuếch tán ngẫu nhiên với ảnh gốc chưa khuếch tán?
  • Mọi mô hình học máy đều là tích chập. Cứ chờ mà xem.
    • Hình như bạn đã đăng câu này vài lần rồi; có thể giải thích kỹ hơn không? Ví dụ, tôi thấy khó coi học tăng cường là tích chập.