Felafax BlogTune Llama3 405B on AMD MI300x (hành trình của chúng tôi)
Giới thiệu
- Khi các mô hình mã nguồn mở ngày càng lớn hơn, nhu cầu về hạ tầng mạnh mẽ để xử lý huấn luyện AI quy mô lớn cũng tăng lên
- Felafax đã tinh chỉnh mô hình LLaMA 3.1 405B trên GPU AMD để chứng minh hiệu quả của phần cứng AMD
- Toàn bộ công việc đã được công khai mã nguồn mở trên GitHub
- GPU AMD MI300X mang lại hiệu năng cao so với phần cứng AI của NVIDIA
- Dự án có thể thực hiện được nhờ sự hỗ trợ của TensorWave
JAX là gì và tại sao lại chọn nó
- JAX là một thư viện machine learning mạnh mẽ kết hợp API tương tự NumPy, tính vi phân tự động và trình biên dịch XLA của Google
- Nó cung cấp API xuất sắc cho xử lý song song mô hình, rất phù hợp cho huấn luyện mô hình quy mô lớn
Ưu điểm của JAX
- Hàm thuần: JAX khuyến khích viết các hàm thuần, giúp mã dễ cấu trúc, gỡ lỗi và đọc hơn
- Xử lý song song nâng cao: API JIT linh hoạt của JAX hỗ trợ xử lý song song dữ liệu và mô hình ở mức nâng cao, điều thiết yếu cho huấn luyện quy mô lớn
- Codebase gọn gàng: Triết lý thiết kế của JAX khuyến khích viết mã có thể di chuyển giữa các nền tảng phần cứng
Vì sao JAX vượt trội trên phần cứng không phải NVIDIA
- Cách tiếp cận độc lập phần cứng: JAX tận dụng trình biên dịch XLA để biên dịch phép tính sang biểu diễn trung gian độc lập phần cứng
- Tối ưu hóa độc lập nền tảng: Trình biên dịch XLA thực hiện tối ưu hóa mà không phụ thuộc vào phần cứng
- Khả năng di chuyển dễ dàng: Khi dùng JAX, việc chuyển từ NVIDIA sang AMD chỉ cần thay đổi mã ở mức tối thiểu
Thiết lập JAX trên GPU AMD
- Kéo Docker image, khởi động container rồi xác minh cài đặt
- Huấn luyện mô hình LLaMA 405B bằng 8 GPU AMD MI300x
Huấn luyện LLaMA 405B: hiệu năng và khả năng mở rộng
- Huấn luyện mô hình LLaMA 405B trên GPU AMD bằng JAX
- Thông qua tinh chỉnh LoRA, trọng số mô hình và tham số LoRA được điều chỉnh với độ chính xác bfloat16
- Kích thước mô hình: chiếm khoảng 800GB VRAM
- Trọng số LoRA và trạng thái optimizer: chiếm khoảng 400GB VRAM
- Tổng mức dùng VRAM: khoảng 1200GB
- Tốc độ huấn luyện: khoảng 35 token mỗi giây
- Hiệu quả bộ nhớ: duy trì khoảng 70%
- Khả năng mở rộng: với JAX, hệ thống mở rộng gần như tuyến tính trên 8 GPU
Cấu hình huấn luyện của chúng tôi
- Chuyển đổi LLaMA 3.1 từ PyTorch sang JAX
- Phân phối hiệu quả thông qua nạp mô hình và sharding tham số
Sharding tham số trong JAX
- Dùng tính năng device mesh của JAX để phân phối mô hình hiệu quả trên 8 GPU AMD
- Định nghĩa quy tắc sharding tham số để shard các chiều của từng tensor theo các trục của mesh
Triển khai huấn luyện LoRA
- LoRA giảm số lượng tham số có thể huấn luyện bằng cách phân rã cập nhật trọng số thành các ma trận hạng thấp
- Triển khai lớp
LoRADense để bao gồm các tham số LoRA
- Phân phối tham số LoRA một cách hiệu quả để tối ưu mức dùng bộ nhớ và hiệu quả tính toán
Kết luận
- Trải nghiệm tinh chỉnh mô hình LLaMA 3.1 405B bằng GPU AMD và JAX là rất tích cực
- Mô hình được phân phối hiệu quả nhờ tận dụng khả năng xử lý song song mạnh mẽ và cách tiếp cận độc lập phần cứng của JAX
- Điều này chứng minh GPU AMD là một lựa chọn thay thế mạnh mẽ cho huấn luyện AI quy mô lớn
- Có thể xem toàn bộ mã nguồn trong kho GitHub và tự chạy thử
Tóm tắt của GN⁺
- Bài viết này giải thích cách huấn luyện hiệu quả các mô hình AI quy mô lớn bằng GPU AMD và JAX
- Bài viết nhấn mạnh rằng phần cứng AMD là một lựa chọn thay thế hiệu quả về chi phí so với NVIDIA
- Cách tiếp cận độc lập phần cứng của JAX giúp tăng tính di động của mã và đơn giản hóa bảo trì
- Cung cấp thông tin hữu ích và mã thực hành cho những ai quan tâm đến huấn luyện mô hình quy mô lớn
- Các dự án có chức năng tương tự gồm CUDA và PyTorch của NVIDIA
1 bình luận
Ý kiến Hacker News
Chia sẻ kết quả tinh chỉnh mô hình Llama3.1 405B trên 8 GPU AMD MI300x bằng JAX
Đề xuất tìm cách vượt qua giới hạn bộ nhớ và chạy phiên bản được biên dịch JIT
Chia sẻ trải nghiệm về GPU AMD và hỗ trợ ROCm
Chia sẻ kinh nghiệm thử nghiệm ở khía cạnh suy luận của mô hình 405B
torch.cudakhông tệ đến vậyrocm:pytorchcũng dễ như dùng containerrocm:jaxCâu hỏi về việc thiếu dữ liệu hiệu năng
Thắc mắc vì sao Obsidian (ứng dụng ghi chú) lại làm việc này
Yêu cầu @dang thêm tên người dùng vào URL