Cách mở rộng mô hình của bạn: Góc nhìn hệ thống về LLM trên TPU
(jax-ml.github.io)- Tối ưu hiệu năng deep learning ở quy mô lớn có thể trông như một kiểu “giả kim thuật”, nhưng trên thực tế có thể cải thiện hiệu quả mô hình bằng những nguyên tắc đơn giản, có thể hiểu được
- Từ một bộ tăng tốc đơn lẻ đến hàng chục nghìn bộ tăng tốc, các nguyên tắc tương đối đơn giản đều áp dụng được ở mọi nơi; hiểu được chúng sẽ giúp thực hiện các công việc hữu ích sau:
- Ước lượng sơ bộ mỗi phần của mô hình đang tiến gần mức tối ưu lý thuyết đến đâu
- Có cơ sở để lựa chọn các kỹ thuật song song hóa khác nhau ở nhiều quy mô
- Ước tính chi phí và thời gian cần thiết để huấn luyện và chạy các mô hình Transformer lớn
- Thiết kế thuật toán tận dụng được đặc tính của phần cứng cụ thể
- Thiết kế phần cứng với hiểu biết rõ ràng về giới hạn của hiệu năng thuật toán hiện tại
- Kiến thức nền cần có
- Cần hiểu các khái niệm cơ bản về LLM và kiến trúc Transformer
- Không bắt buộc phải hiểu cách vận hành ở quy mô lớn
- Nếu có kiến thức cơ bản về huấn luyện LLM và kinh nghiệm dùng JAX thì càng tốt
- Khuyến nghị tham khảo các bài blog về kiến trúc Transformer và slide về scaling LLM trong JAX
- Mục tiêu
- Rèn luyện khả năng ước lượng nên song song hóa mô hình như thế nào trên phần cứng được cấp
- Rèn luyện khả năng tính toán gần đúng thời gian và chi phí cho huấn luyện và suy luận
Vì sao nên quan tâm
- Chỉ khoảng 3~4 năm trước, phần lớn nhà nghiên cứu ML vẫn chưa cần hiểu sâu về kiểu tối ưu quy mô lớn này
- Hiện nay, ngay cả các mô hình “nhỏ” cũng vận hành sát giới hạn phần cứng, nên việc hiểu cách làm việc hiệu quả ở quy mô lớn đã trở thành điều thiết yếu
- Lịch sử ML có thể được xem là quá trình phát triển đan xen giữa đổi mới hệ thống và cải tiến phần mềm
- Khi các mô hình Transformer gần đây đã khai thác tới sát giới hạn phần cứng, nếu không hiểu hiệu quả mô hình thì kiến trúc hay nghiên cứu mới rất dễ thất bại khi triển khai thực tế
- Dù đạt cải thiện 20% trên benchmark, nếu hiệu quả phần cứng giảm 20% thì cuối cùng tính thực dụng vẫn thấp
- Mục tiêu cốt lõi của scaling mô hình là làm cho thông lượng tăng tuyến tính khi tăng số lượng chip (bộ tăng tốc)
- Điều này được gọi là "strong scaling"
- Thêm chip giúp giảm thời gian tính toán nhưng phát sinh chi phí giao tiếp giữa các chip
- Nếu giao tiếp mất nhiều thời gian hơn tính toán, hệ thống sẽ rơi vào trạng thái "communication bound" và không thể strong scaling
- Nếu hiểu đủ rõ phần cứng để dự đoán nơi xuất hiện các điểm nghẽn này, ta có thể thiết kế hoặc tái cấu trúc mô hình để tránh chúng
- Mục tiêu của cuốn sách này là giải thích cách phần cứng TPU (và GPU) hoạt động, cũng như cách kiến trúc Transformer đã phát triển để vận hành tốt trên phần cứng hiện tại
- Tác giả hy vọng nội dung này sẽ hữu ích cho cả các nhà nghiên cứu thiết kế kiến trúc mới lẫn các kỹ sư đang cố gắng chạy thật nhanh các thế hệ LLM hiện nay
Tổng quan toàn bộ
- Bài viết này được cấu trúc như sau
- Phần 1 giải thích các yếu tố quyết định giới hạn hiệu năng của mô hình (giao tiếp, tính toán, bộ nhớ) thông qua phân tích roofline
- Phần 2, Phần 3 trình bày cấu trúc bên trong của TPU và GPU cũng như cách kết nối giữa các chip
- Qua đó trả lời các câu hỏi như sau
- Về mặt lý thuyết, một phép nhân ma trận với kích thước nhất định có thể được thực hiện nhanh đến mức nào
- Tại điểm nào việc tính toán sẽ bị ràng buộc bởi băng thông bộ nhớ hoặc băng thông giao tiếp
- Cụm TPU được kết nối theo cấu trúc nào, và mất khoảng bao lâu để chuyển dữ liệu từ chip này sang chip khác
- Làm thế nào để nhân các ma trận phân tán một cách hiệu quả
- Qua đó trả lời các câu hỏi như sau
- Phần 4 đi sâu vào công thức của kiến trúc Transformer (kích thước ma trận, số lượng tham số, FLOPs)
- Phần 5 và Phần 7 là trọng tâm, giới thiệu nhiều cách khác nhau để song song hóa mô hình trên nhiều chip
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- Cũng đề cập các kỹ thuật tiết kiệm bộ nhớ như ZeRO, Rematerialisation, Host offload, Gradient accumulation
- Phần 6, Phần 8 lấy ví dụ huấn luyện và suy luận mô hình LLaMA-3 trên TPU để đưa ra chi phí, thời gian và cách cấu hình thực tế
- Cuối cùng, Phần 9, Phần 10 trình bày cách thực tế để profile, debug và áp dụng xử lý song song cho mô hình trong JAX
Chi tiết hơn: tóm tắt các phần chính của sách
-
Phần 1: Preliminaries
-
Phần 1: Giới thiệu ngắn gọn về phân tích Roofline
- Ba yếu tố ràng buộc thuật toán: tính toán, giao tiếp, bộ nhớ
- Từ đó học cách ước lượng giới hạn trên của tốc độ tính toán
-
- TPU thực hiện tính toán theo cách nào
- Cấu trúc systolic array là gì
- Hiểu biết cơ bản về cách TPU cung cấp băng thông bộ nhớ và băng thông giao tiếp
-
Phần 3: Ma trận phân tán và phép nhân phân tán
- Kỹ thuật chia nhỏ (sharding) để lưu tham số mô hình trên nhiều chip
- Cách xử lý giao tiếp và điểm nghẽn phát sinh trong các phép toán ma trận phân tán
-
-
Phần 2: Transformers
-
Phần 4: Tổng hợp các công thức Transformer cần thiết
- Cụ thể các phép nhân ma trận trong Transformer có dạng như thế nào
- Cách tính số lượng tham số, FLOPs, kích thước KV cache, v.v.
- Xác định attention đòi hỏi lượng tính toán nhiều hơn bao nhiêu so với khối Feed-Forward
-
Phần 5: Chiến lược song song hóa huấn luyện Transformer
- Giới thiệu các kỹ thuật Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- Các biện pháp tiết kiệm bộ nhớ như ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload
- Xây dựng khái niệm về cách cấu hình song song hóa phù hợp với kích thước mô hình và số lượng chip cụ thể
-
Phần 6: Ứng dụng huấn luyện LLaMA 3 trên TPU
- Giả sử huấn luyện mô hình LLaMA 3 trong môi trường TPU thực tế, ước tính thời gian và chi phí cần thiết
- Đưa ra ví dụ cụ thể về batch size, cách song song hóa, mức sử dụng bộ nhớ, v.v.
-
Phần 7: Mọi thứ về suy luận Transformer
- Khi suy luận, độ trễ (latency) trở thành một yếu tố mới quan trọng
- Vấn đề bộ nhớ và giao tiếp do KV cache gây ra
- Thảo luận cách phân bổ và kết nối nhiều chip để phục vụ mô hình
-
Phần 8: Ứng dụng phục vụ LLaMA 3 trên TPU
- Giả sử phục vụ LLaMA 3 trên TPU v5e, phân tích gần đúng về chi phí, độ trễ và đánh đổi thông lượng
-
-
Phần 3: Practical Tutorials
-
- Hiểu stack JAX+XLA
- Xác định các vấn đề suy giảm hiệu năng thực tế và cách giải quyết
- Cách dùng profiler của JAX/TensorBoard
-
Phần 10: Lập trình TPU với JAX
- Cách tận dụng API song song hóa (primitives) của JAX
- Học các khái niệm tính toán song song qua ví dụ và bài tập
-
Phần 11: Kết luận và tài liệu bổ sung
- Tài liệu đọc thêm về TPU và LLM
- Khép lại ngắn gọn toàn bộ nội dung và đề cập triển vọng tương lai
-
1 bình luận
Ý kiến trên Hacker News