Tinh chỉnh Llama 405B bằng GPU AMD
(publish.obsidian.md)Felafax BlogTune Llama3 405B on AMD MI300x (hành trình của chúng tôi)
Giới thiệu
- Khi các mô hình mã nguồn mở ngày càng lớn hơn, nhu cầu về hạ tầng mạnh mẽ để xử lý huấn luyện AI quy mô lớn cũng tăng lên
- Felafax đã tinh chỉnh mô hình LLaMA 3.1 405B trên GPU AMD để chứng minh hiệu quả của phần cứng AMD
- Toàn bộ công việc đã được công khai mã nguồn mở trên GitHub
- GPU AMD MI300X mang lại hiệu năng cao so với phần cứng AI của NVIDIA
- Dự án có thể thực hiện được nhờ sự hỗ trợ của TensorWave
JAX là gì và tại sao lại chọn nó
- JAX là một thư viện machine learning mạnh mẽ kết hợp API tương tự NumPy, tính vi phân tự động và trình biên dịch XLA của Google
- Nó cung cấp API xuất sắc cho xử lý song song mô hình, rất phù hợp cho huấn luyện mô hình quy mô lớn
Ưu điểm của JAX
- Hàm thuần: JAX khuyến khích viết các hàm thuần, giúp mã dễ cấu trúc, gỡ lỗi và đọc hơn
- Xử lý song song nâng cao: API JIT linh hoạt của JAX hỗ trợ xử lý song song dữ liệu và mô hình ở mức nâng cao, điều thiết yếu cho huấn luyện quy mô lớn
- Codebase gọn gàng: Triết lý thiết kế của JAX khuyến khích viết mã có thể di chuyển giữa các nền tảng phần cứng
Vì sao JAX vượt trội trên phần cứng không phải NVIDIA
- Cách tiếp cận độc lập phần cứng: JAX tận dụng trình biên dịch XLA để biên dịch phép tính sang biểu diễn trung gian độc lập phần cứng
- Tối ưu hóa độc lập nền tảng: Trình biên dịch XLA thực hiện tối ưu hóa mà không phụ thuộc vào phần cứng
- Khả năng di chuyển dễ dàng: Khi dùng JAX, việc chuyển từ NVIDIA sang AMD chỉ cần thay đổi mã ở mức tối thiểu
Thiết lập JAX trên GPU AMD
- Kéo Docker image, khởi động container rồi xác minh cài đặt
- Huấn luyện mô hình LLaMA 405B bằng 8 GPU AMD MI300x
Huấn luyện LLaMA 405B: hiệu năng và khả năng mở rộng
- Huấn luyện mô hình LLaMA 405B trên GPU AMD bằng JAX
- Thông qua tinh chỉnh LoRA, trọng số mô hình và tham số LoRA được điều chỉnh với độ chính xác bfloat16
- Kích thước mô hình: chiếm khoảng 800GB VRAM
- Trọng số LoRA và trạng thái optimizer: chiếm khoảng 400GB VRAM
- Tổng mức dùng VRAM: khoảng 1200GB
- Tốc độ huấn luyện: khoảng 35 token mỗi giây
- Hiệu quả bộ nhớ: duy trì khoảng 70%
- Khả năng mở rộng: với JAX, hệ thống mở rộng gần như tuyến tính trên 8 GPU
Cấu hình huấn luyện của chúng tôi
- Chuyển đổi LLaMA 3.1 từ PyTorch sang JAX
- Phân phối hiệu quả thông qua nạp mô hình và sharding tham số
Sharding tham số trong JAX
- Dùng tính năng device mesh của JAX để phân phối mô hình hiệu quả trên 8 GPU AMD
- Định nghĩa quy tắc sharding tham số để shard các chiều của từng tensor theo các trục của mesh
Triển khai huấn luyện LoRA
- LoRA giảm số lượng tham số có thể huấn luyện bằng cách phân rã cập nhật trọng số thành các ma trận hạng thấp
- Triển khai lớp
LoRADenseđể bao gồm các tham số LoRA - Phân phối tham số LoRA một cách hiệu quả để tối ưu mức dùng bộ nhớ và hiệu quả tính toán
Kết luận
- Trải nghiệm tinh chỉnh mô hình LLaMA 3.1 405B bằng GPU AMD và JAX là rất tích cực
- Mô hình được phân phối hiệu quả nhờ tận dụng khả năng xử lý song song mạnh mẽ và cách tiếp cận độc lập phần cứng của JAX
- Điều này chứng minh GPU AMD là một lựa chọn thay thế mạnh mẽ cho huấn luyện AI quy mô lớn
- Có thể xem toàn bộ mã nguồn trong kho GitHub và tự chạy thử
Tóm tắt của GN⁺
- Bài viết này giải thích cách huấn luyện hiệu quả các mô hình AI quy mô lớn bằng GPU AMD và JAX
- Bài viết nhấn mạnh rằng phần cứng AMD là một lựa chọn thay thế hiệu quả về chi phí so với NVIDIA
- Cách tiếp cận độc lập phần cứng của JAX giúp tăng tính di động của mã và đơn giản hóa bảo trì
- Cung cấp thông tin hữu ích và mã thực hành cho những ai quan tâm đến huấn luyện mô hình quy mô lớn
- Các dự án có chức năng tương tự gồm CUDA và PyTorch của NVIDIA
1 bình luận
Các ý kiến trên Hacker News
Gần đây đã tinh chỉnh mô hình llama3.1 405B trên 8x GPU AMD MI300x bằng JAX thay vì PyTorch
Nhờ API sharding nâng cao của JAX nên đạt hiệu năng tốt, và kỹ thuật sharding đã dùng được tổng hợp trong blog. Mã nguồn cũng đã được công khai: https://github.com/felafax/felafax
Chúng tôi là một startup nhỏ xây dựng hạ tầng AI cho việc tinh chỉnh và phục vụ LLM trên phần cứng không phải NVIDIA (TPU, AMD, Trainium)
Nhiều công ty đang cố chạy PyTorch trên GPU AMD, nhưng PyTorch gắn rất sâu với hệ sinh thái NVIDIA, như
torch.cudahayscaled_dot_product_attention, nên tôi nghĩ cần rất nhiều công sức để “thoát NVIDIA”Tôi cho rằng JAX phù hợp hơn với phần cứng không phải NVIDIA, vì mã mô hình được biên dịch thành đồ thị HLO độc lập phần cứng, sau đó trình biên dịch XLA tối ưu hóa rồi áp dụng các tối ưu hóa riêng cho từng phần cứng. Cùng một mã LLaMA3 bằng JAX đã chạy trên Google TPU và GPU AMD mà không cần chỉnh sửa
Chiến lược của công ty là trước hết port mô hình sang JAX, rồi tận dụng framework JAX và kernel XLA để khai thác hiệu năng tối đa trên các backend không phải NVIDIA. Vì vậy chúng tôi đã chuyển Llama 3.1 từ PyTorch sang JAX trước, và cùng một mô hình JAX chạy tốt trên TPU lẫn GPU AMD
Cá nhân tôi dùng PyTorch chủ yếu vì mô hình gốc được tạo bằng PyTorch. Dù logic giữa các phiên bản mô hình khác nhau trông có vẻ giống nhau, ở quy mô dữ liệu khổng lồ, những sai số dấu phẩy động rất nhỏ có thể tích lũy và gây drift mô hình
Việc debug các sai lệch độ chính xác như vậy trên mô hình lớn gần như còn khổ hơn vòng địa ngục thứ 10
hipblaslt, Composable Kernel FATôi không rành JAX lắm, nhưng tôi cho rằng một phần đáng kể lý do hiệu năng huấn luyện PyTorch trên MI300x tệ hại là do các thư viện ROCm dùng bên trong chậm
“Chạy được” ở đây không có nghĩa là mất 2 tuần vật lộn với driver rồi sau đó không bao giờ dám cập nhật server nữa
Tôi cũng tò mò về các vấn đề kỹ thuật đã gặp phải
Nói thẳng thì hiệu năng này khá tệ. Có lẽ là do chưa làm cho biên dịch hoạt động đúng
Mô hình 405B đạt 35 token/giây, tương đương khoảng 85 teraflops. 8 GPU MI300x ở mức khoảng 10,4 petaflops, nên MFU chỉ khoảng 0,8%
Con số này thấp hơn 40–50 lần so với hiệu năng huấn luyện decent là 30–40% MFU, nên có lẽ AMD sẽ hy vọng nút thắt nằm ở stack phần mềm
Trang GitHub nói rằng có thể tinh chỉnh LLaMa3.1 trên Google Cloud TPU với chi phí thấp hơn 30%, nhưng không nhắc đến hiệu năng
Công việc rất tuyệt. Khoảng một năm trước tôi có nghịch một chút với GPU AMD và hỗ trợ ROCm, và rõ ràng AMD vẫn còn một chặng đường dài để bắt kịp Nvidia
Cách tiếp cận chọn JAX khá thú vị; tôi tò mò việc rời xa PyTorch, vốn gần như là thư viện tiêu chuẩn cho machine learning, đã gặp những khó khăn gì
Ban đầu mục tiêu là tinh chỉnh LLaMA 3 trên TPU, nhưng PyTorch XLA khá thô, nên chúng tôi quyết định viết lại mô hình bằng JAX
Như đã nói ở trên, chúng tôi xem JAX là nền tảng tốt hơn cho GPU không phải NVIDIA, và muốn xây dựng hạ tầng cho GPU không phải NVIDIA trên JAX+openXLA
Làm tốt đấy. Cuối tuần trước tôi cũng đang nghịch phần suy luận của 405B [0]
Tôi không chắc
torch.cudatệ đến thế. PyTorch cho AMD chuyển đổi thay nó mà. Đây có vẻ là vấn đề tên gọi hơn là vấn đề bản chấtThực tế, kéo container
rocm:pytorchcũng dễ như kéo containerrocm:jaxKhông có nhiều số liệu được công bố, tôi tò mò MFU đạt bao nhiêu
[0] https://x.com/HotAisle/status/1837580046732874026
MFU cần phải tính. Chi tiết GPU và VRAM có trong repository: https://dub.sh/amd-405b-res
Cuối tuần tới tôi định thử chạy huấn luyện lại, đồng thời JIT compile toàn bộ bước huấn luyện, và lúc đó sẽ tính MFU
Khi chúng tôi đo ở ZML, MI300X nhanh hơn H100 30%. Những con chip này rất tuyệt
Tôi tò mò có nhà cung cấp cloud nào cho thuê host 8xAMD MI300 không
Công việc của tôi dùng AWS nhiều, nhưng tôi từng muốn thử GPU AMD
Dữ liệu hiệu năng ở đâu?
Do mã và giới hạn VRAM, chúng tôi không chạy được phiên bản JIT compile của mô hình 405B. Phần này cần điều tra thêm
Toàn bộ lần chạy huấn luyện được thực hiện ở chế độ eager execution của JAX, nên còn nhiều dư địa cải thiện hiệu năng
Ngay cả ở chế độ eager execution, mức sử dụng GPU nhìn chung khoảng 30–40%, khá ổn. Tôi nghĩ dùng JIT có thể dễ dàng đẩy mức sử dụng GPU lên 50–60%
Nếu có thể, sẽ rất thú vị nếu khám phá cách vượt qua giới hạn bộ nhớ để chạy phiên bản JIT compile. Việc này có thể mang lại cải thiện hiệu năng thêm
Cần bước huấn luyện được JIT compile, nạp dữ liệu và sharding tối ưu hơn, gradient accumulation, activation checkpointing
Chúng tôi sẽ tiếp tục xây dựng, triển khai mọi cải tiến, rồi sớm đăng blog lại
Tôi tò mò liệu AMD đã tiến gần hơn chút nào tới việc khai thác giá trị ở đây thông qua các đơn hàng GPU số lượng lớn và tình trạng thiếu nguồn cung chưa
Ấn tượng của tôi thì gần như là “chưa”
Đối thủ có lợi thế đi trước khổng lồ, và rõ ràng còn rất nhiều việc phải làm ở phía phần mềm. Cần thời gian
Tại sao ứng dụng ghi chú Obsidian lại làm việc này?