2 điểm bởi GN⁺ 2024-09-24 | 1 bình luận | Chia sẻ qua WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (hành trình của chúng tôi)

Giới thiệu

  • Khi các mô hình mã nguồn mở ngày càng lớn hơn, nhu cầu về hạ tầng mạnh mẽ để xử lý huấn luyện AI quy mô lớn cũng tăng lên
  • Felafax đã tinh chỉnh mô hình LLaMA 3.1 405B trên GPU AMD để chứng minh hiệu quả của phần cứng AMD
  • Toàn bộ công việc đã được công khai mã nguồn mở trên GitHub
  • GPU AMD MI300X mang lại hiệu năng cao so với phần cứng AI của NVIDIA
  • Dự án có thể thực hiện được nhờ sự hỗ trợ của TensorWave

JAX là gì và tại sao lại chọn nó

  • JAX là một thư viện machine learning mạnh mẽ kết hợp API tương tự NumPy, tính vi phân tự động và trình biên dịch XLA của Google
  • Nó cung cấp API xuất sắc cho xử lý song song mô hình, rất phù hợp cho huấn luyện mô hình quy mô lớn

Ưu điểm của JAX

  • Hàm thuần: JAX khuyến khích viết các hàm thuần, giúp mã dễ cấu trúc, gỡ lỗi và đọc hơn
  • Xử lý song song nâng cao: API JIT linh hoạt của JAX hỗ trợ xử lý song song dữ liệu và mô hình ở mức nâng cao, điều thiết yếu cho huấn luyện quy mô lớn
  • Codebase gọn gàng: Triết lý thiết kế của JAX khuyến khích viết mã có thể di chuyển giữa các nền tảng phần cứng

Vì sao JAX vượt trội trên phần cứng không phải NVIDIA

  • Cách tiếp cận độc lập phần cứng: JAX tận dụng trình biên dịch XLA để biên dịch phép tính sang biểu diễn trung gian độc lập phần cứng
  • Tối ưu hóa độc lập nền tảng: Trình biên dịch XLA thực hiện tối ưu hóa mà không phụ thuộc vào phần cứng
  • Khả năng di chuyển dễ dàng: Khi dùng JAX, việc chuyển từ NVIDIA sang AMD chỉ cần thay đổi mã ở mức tối thiểu

Thiết lập JAX trên GPU AMD

  • Kéo Docker image, khởi động container rồi xác minh cài đặt
  • Huấn luyện mô hình LLaMA 405B bằng 8 GPU AMD MI300x

Huấn luyện LLaMA 405B: hiệu năng và khả năng mở rộng

  • Huấn luyện mô hình LLaMA 405B trên GPU AMD bằng JAX
  • Thông qua tinh chỉnh LoRA, trọng số mô hình và tham số LoRA được điều chỉnh với độ chính xác bfloat16
  • Kích thước mô hình: chiếm khoảng 800GB VRAM
  • Trọng số LoRA và trạng thái optimizer: chiếm khoảng 400GB VRAM
  • Tổng mức dùng VRAM: khoảng 1200GB
  • Tốc độ huấn luyện: khoảng 35 token mỗi giây
  • Hiệu quả bộ nhớ: duy trì khoảng 70%
  • Khả năng mở rộng: với JAX, hệ thống mở rộng gần như tuyến tính trên 8 GPU

Cấu hình huấn luyện của chúng tôi

  • Chuyển đổi LLaMA 3.1 từ PyTorch sang JAX
  • Phân phối hiệu quả thông qua nạp mô hình và sharding tham số

Sharding tham số trong JAX

  • Dùng tính năng device mesh của JAX để phân phối mô hình hiệu quả trên 8 GPU AMD
  • Định nghĩa quy tắc sharding tham số để shard các chiều của từng tensor theo các trục của mesh

Triển khai huấn luyện LoRA

  • LoRA giảm số lượng tham số có thể huấn luyện bằng cách phân rã cập nhật trọng số thành các ma trận hạng thấp
  • Triển khai lớp LoRADense để bao gồm các tham số LoRA
  • Phân phối tham số LoRA một cách hiệu quả để tối ưu mức dùng bộ nhớ và hiệu quả tính toán

Kết luận

  • Trải nghiệm tinh chỉnh mô hình LLaMA 3.1 405B bằng GPU AMD và JAX là rất tích cực
  • Mô hình được phân phối hiệu quả nhờ tận dụng khả năng xử lý song song mạnh mẽ và cách tiếp cận độc lập phần cứng của JAX
  • Điều này chứng minh GPU AMD là một lựa chọn thay thế mạnh mẽ cho huấn luyện AI quy mô lớn
  • Có thể xem toàn bộ mã nguồn trong kho GitHub và tự chạy thử

Tóm tắt của GN⁺

  • Bài viết này giải thích cách huấn luyện hiệu quả các mô hình AI quy mô lớn bằng GPU AMD và JAX
  • Bài viết nhấn mạnh rằng phần cứng AMD là một lựa chọn thay thế hiệu quả về chi phí so với NVIDIA
  • Cách tiếp cận độc lập phần cứng của JAX giúp tăng tính di động của mã và đơn giản hóa bảo trì
  • Cung cấp thông tin hữu ích và mã thực hành cho những ai quan tâm đến huấn luyện mô hình quy mô lớn
  • Các dự án có chức năng tương tự gồm CUDA và PyTorch của NVIDIA

1 bình luận

 
GN⁺ 2024-09-24
Ý kiến Hacker News
  • Chia sẻ kết quả tinh chỉnh mô hình Llama3.1 405B trên 8 GPU AMD MI300x bằng JAX

    • Đạt được hiệu năng rất tốt nhờ API sharding nâng cao của JAX
    • Cung cấp liên kết tới bài viết blog và mã nguồn mở: liên kết GitHub
    • Đây là một startup xây dựng hạ tầng AI để tinh chỉnh và phục vụ LLM trên TPU, AMD và Trainium thay vì phần cứng NVIDIA
    • Nhiều công ty đang cố chạy PyTorch trên GPU AMD, nhưng họ cho rằng đó là một con đường khó khăn
    • PyTorch gắn bó rất sâu với hệ sinh thái NVIDIA, nên cần rất nhiều chỉnh sửa để chạy trên phần cứng không phải NVIDIA
    • Họ tin rằng JAX phù hợp hơn với phần cứng không phải NVIDIA
    • Trong JAX, mã mô hình ML được biên dịch thành đồ thị HLO độc lập với phần cứng, còn trình biên dịch XLA sẽ thực hiện tối ưu hóa đặc thù cho từng phần cứng
    • Có thể chạy cùng một mã JAX trên Google TPU và GPU AMD mà không cần thay đổi
    • Chiến lược của công ty là port mô hình sang JAX và tận dụng kernel XLA để khai thác hiệu năng tối đa trên các backend không phải NVIDIA
    • Họ đã port Llama 3.1 từ PyTorch sang JAX lần đầu tiên, và giờ cùng một mô hình JAX hoạt động tốt trên cả TPU và GPU AMD
    • Họ muốn lắng nghe ý kiến về tầm nhìn và kho mã của mình
  • Đề xuất tìm cách vượt qua giới hạn bộ nhớ và chạy phiên bản được biên dịch JIT

    • Điều này có thể mang lại thêm cải thiện về hiệu năng
  • Chia sẻ trải nghiệm về GPU AMD và hỗ trợ ROCm

    • Một năm trước đã thử dùng GPU AMD và hỗ trợ ROCm, nhưng cảm thấy AMD vẫn còn rất xa mới bắt kịp NVIDIA
    • Việc chọn JAX là một hướng tiếp cận thú vị, nhưng họ tò mò không biết đã gặp những khó khăn gì khi rời khỏi PyTorch
  • Chia sẻ kinh nghiệm thử nghiệm ở khía cạnh suy luận của mô hình 405B

    • Họ cho rằng torch.cuda không tệ đến vậy
    • Vì phiên bản PyTorch cho AMD sẽ dịch phần này, nên họ cho rằng đây chỉ là vấn đề về tên gọi
    • Dùng container rocm:pytorch cũng dễ như dùng container rocm:jax
    • Họ chỉ ra rằng chưa có nhiều dữ liệu hiệu năng được công bố
    • Họ muốn biết chỉ số MFU (mức độ tận dụng mô hình)
  • Câu hỏi về việc thiếu dữ liệu hiệu năng

    • Đặt nghi vấn về khả năng khai thác giá trị từ các đơn hàng GPU AMD số lượng lớn
    • Họ có ấn tượng rằng câu trả lời là “không”
  • Thắc mắc vì sao Obsidian (ứng dụng ghi chú) lại làm việc này

    • Ban đầu họ tưởng đây là bài đăng của Obsidian
    • Họ thắc mắc vì sao vẫn chưa phân biệt GitHub.com và GitHub.io
  • Yêu cầu @dang thêm tên người dùng vào URL

    • Bài đăng này nói về một blog do người dùng tạo ra, không phải chính Obsidian**