Tái hiện Stable Diffusion 3.5 từ đầu bằng PyTorch thuần túy
(github.com/yousef-rafat)- Dự án miniDiffusion là một mã nguồn mở tái hiện mô hình Stable Diffusion 3.5 từ đầu chỉ bằng PyTorch
- Cấu trúc của dự án này nổi bật ở chỗ tập trung vào mục đích giáo dục và phục vụ thử nghiệm, hack
- Toàn bộ codebase chỉ khoảng 2800 dòng, được cấu thành bằng lượng mã tối thiểu, từ VAE đến DiT, cùng các script huấn luyện và dataset
- Các thành phần chính gồm VAE, bộ mã hóa văn bản CLIP, T5, transformer khuếch tán đa phương thức, joint attention
- Hiện vẫn bao gồm các tính năng thử nghiệm và cần thêm nhiều kiểm thử
Giới thiệu dự án miniDiffusion
miniDiffusion là một dự án mã nguồn mở tái hiện các chức năng cốt lõi của Stable Diffusion 3.5 chỉ bằng PyTorch
So với Stable Diffusion 3.5 hiện có, dự án này có những ưu điểm sau
- Codebase chỉ khoảng 2.800 dòng, rất nhỏ gọn nên rất phù hợp để trực tiếp phân tích cấu trúc và học tập
- Có thể được tận dụng hữu ích cho nhiều thử nghiệm machine learning và hack mô hình
- Có rất ít phụ thuộc và chỉ sử dụng số lượng thư viện tối thiểu
Cấu trúc cốt lõi và các tệp thành phần
- dit.py : phần hiện thực mô hình Stable Diffusion chính
- dit_components.py : các thành phần như embedding, chuẩn hóa, patch embedding và hàm phụ trợ cho DiT
- attention.py : phần hiện thực thuật toán Joint Attention
- noise.py : bao gồm bộ lập lịch Euler ODE cho Rectified Flow
- t5_encoder.py, clip.py : hiện thực bộ mã hóa văn bản T5 và CLIP
- tokenizer.py : hiện thực tokenizer Byte-Pair và Unigram
- metrics.py : hiện thực chỉ số đánh giá FID (Fréchet inception distance)
- common.py : cung cấp các hàm phụ trợ cần cho huấn luyện
- common_ds.py : hiện thực dataset iterable chuyển đổi hình ảnh thành dữ liệu huấn luyện cho DiT
- thư mục model : lưu checkpoint mô hình và log sau huấn luyện
- thư mục encoders : lưu checkpoint của các mô-đun riêng như VAE, CLIP
⚠️ Tính năng thử nghiệm và nhu cầu kiểm thử miniDiffusion hiện vẫn bao gồm các tính năng thử nghiệm và cần thêm nhiều kiểm thử
Cấu hình chi tiết theo từng tính năng chính
Core Image Generation Modules
- Hiện thực VAE, bộ mã hóa văn bản CLIP, T5
- Hiện thực tokenizer Byte-Pair, Unigram
SD3 Components
- Multi-Modal Diffusion Transformer Model
- Hiện thực Flow-Matching Euler Scheduler
- Logit-Normal Sampling
- Áp dụng thuật toán Joint Attention
Script huấn luyện và suy luận mô hình
- Cung cấp script huấn luyện và suy luận cho SD3 (Stable Diffusion 3.5)
Giấy phép
- Được phát hành theo giấy phép MIT và được tạo ra cho mục đích giáo dục và thử nghiệm
Ý nghĩa và ưu điểm của dự án mã nguồn mở này
- Có thể trực tiếp huấn luyện và hack cấu trúc mô hình tạo ảnh hiện đại ở cấp độ Stable Diffusion 3.5 chỉ bằng PyTorch thuần túy
- Mã nguồn ngắn gọn và độc lập, được tối ưu cho phân tích cấu trúc / tinh chỉnh mô hình / nghiên cứu thuật toán mới
- Có thể trực tiếp thực hành các kỹ thuật đa phương thức, transformer, attention hiện đại
- Cung cấp nền tảng để thử nghiệm an toàn, tách biệt với các dự án thương mại
1 bình luận
Ý kiến trên Hacker News
Bản triển khai tham chiếu của Flux thực sự có cấu trúc rất tối giản, nên nếu ai quan tâm thì rất đáng xem thử
GitHub Flux
Dự án minRF có ưu điểm là dùng rectified flow nên khá dễ để bắt đầu khi huấn luyện các mô hình diffusion nhỏ
GitHub minRF
Bản triển khai tham chiếu của Stable Diffusion 3.5 cũng được viết khá gọn gàng, phù hợp để tham khảo
GitHub SD 3.5
Các bản triển khai tham chiếu thường không được bảo trì tốt và có khá nhiều lỗi
Tôi tự hỏi liệu có phải dự án miniDiffusion đang dùng mô hình Stable Diffusion 3.5 hay không
Đoạn mã liên quan
Bộ dữ liệu huấn luyện rất nhỏ và chỉ gồm ảnh liên quan đến thời trang
Bộ dữ liệu thời trang
Bộ dữ liệu đó dùng để thực hành fine-tune mô hình diffusion
Tôi tò mò liệu dùng PyTorch thuần có mang lại lợi thế hiệu năng trên GPU không phải của NVIDIA hay không, hay PyTorch đã được tối ưu cho CUDA quá mạnh đến mức các hãng GPU khác không thể cạnh tranh
PyTorch chạy khá ổn trên Apple Silicon
Trên các thiết bị không phải NVIDIA như AMD, vẫn có thể chạy workload ML qua Vulkan
Hỗ trợ ROCm của PyTorch tiến triển rất chậm, và kể cả khi chạy được thì tốc độ cũng chậm
PyTorch đúng là có chạy được trên ROCm, nhưng tôi không rõ nó có thực sự tốt đến mức hoàn toàn "ngang hàng" hay không
Trong mã PyTorch, thay vì
có đề xuất rằng nên thử
như vậy
Làm như vậy thì thay vì các tham số của q, k, v ban đầu được nối độc lập, các tham số giữa q, k, v sẽ được nối với nhau
Có vẻ đây là tài liệu tốt cho người học
Tôi tự hỏi liệu có tutorial hoặc tài liệu giải thích nào mà người mới bắt đầu cũng có thể làm theo không
Có khóa học của fast.ai hướng dẫn tự triển khai Stable Diffusion
Tôi thắc mắc liệu điều này có nghĩa là có thể dùng Stable Diffusion mà không bị ràng buộc bởi giấy phép hay không
Thành thật mà nói thì hơi ngại, nhưng tôi muốn biết chúng ta thực sự có thêm được điều gì mới trước và sau khi kho mã này xuất hiện
Cá nhân tôi vẫn tránh tự làm mô hình và chủ yếu chỉ đứng ngoài quan sát kết quả
Tôi vốn mơ hồ cho rằng từ trước đã có sẵn script suy luận/huấn luyện dựa trên PyTorch được công khai rồi
Ít nhất tôi nghĩ script suy luận sẽ đi kèm khi phát hành mô hình, và chắc cũng sẽ có script fine-tune/huấn luyện
Tôi không chắc dự án này là kiểu "clean room" hay "dirty room" viết lại từ cái có sẵn, hay là vì ngay cả mã PyTorch hiện có cũng quá rối do dính CUDA/C nên một bản PyTorch thuần mới thực sự có ý nghĩa lớn
Dù sao tôi cũng không rõ, nên sẽ rất tốt nếu ai đó có thể giải thích
Giá trị cốt lõi của dự án này là một bản triển khai "ít phụ thuộc nhất có thể"
Stability AI phát hành các mô hình Stable Diffusion theo Stability AI Community License, nên không phải là "hoàn toàn tự do" như MIT
Khi nghĩ về SD 3.5, hoặc bất kỳ phiên bản nào, tôi xem phần cốt lõi nằm ở các trọng số được tạo ra trong quá trình huấn luyện
Tôi tò mò về mức độ có thể dùng trong thực tế của mã nguồn học thuật gốc do nhóm CompViz của Ludwig Maximilian University công bố
Tôi muốn biết liệu phần triển khai diffusion transformer (DiT) ở đây có thực sự triển khai đúng cross-token attention như bản SD 3.5 đầy đủ hay không, hay đã được đơn giản hóa để mã dễ đọc hơn