7 điểm bởi GN⁺ 2025-02-07 | 1 bình luận | Chia sẻ qua WhatsApp
  • Tối ưu hiệu năng deep learning ở quy mô lớn có thể trông như một kiểu “giả kim thuật”, nhưng trên thực tế có thể cải thiện hiệu quả mô hình bằng những nguyên tắc đơn giản, có thể hiểu được
  • Từ một bộ tăng tốc đơn lẻ đến hàng chục nghìn bộ tăng tốc, các nguyên tắc tương đối đơn giản đều áp dụng được ở mọi nơi; hiểu được chúng sẽ giúp thực hiện các công việc hữu ích sau:
    • Ước lượng sơ bộ mỗi phần của mô hình đang tiến gần mức tối ưu lý thuyết đến đâu
    • Có cơ sở để lựa chọn các kỹ thuật song song hóa khác nhau ở nhiều quy mô
    • Ước tính chi phí và thời gian cần thiết để huấn luyện và chạy các mô hình Transformer lớn
    • Thiết kế thuật toán tận dụng được đặc tính của phần cứng cụ thể
    • Thiết kế phần cứng với hiểu biết rõ ràng về giới hạn của hiệu năng thuật toán hiện tại
  • Kiến thức nền cần có
    • Cần hiểu các khái niệm cơ bản về LLM và kiến trúc Transformer
    • Không bắt buộc phải hiểu cách vận hành ở quy mô lớn
    • Nếu có kiến thức cơ bản về huấn luyện LLM và kinh nghiệm dùng JAX thì càng tốt
    • Khuyến nghị tham khảo các bài blog về kiến trúc Transformer và slide về scaling LLM trong JAX
  • Mục tiêu
    • Rèn luyện khả năng ước lượng nên song song hóa mô hình như thế nào trên phần cứng được cấp
    • Rèn luyện khả năng tính toán gần đúng thời gian và chi phí cho huấn luyện và suy luận

Vì sao nên quan tâm

  • Chỉ khoảng 3~4 năm trước, phần lớn nhà nghiên cứu ML vẫn chưa cần hiểu sâu về kiểu tối ưu quy mô lớn này
    • Hiện nay, ngay cả các mô hình “nhỏ” cũng vận hành sát giới hạn phần cứng, nên việc hiểu cách làm việc hiệu quả ở quy mô lớn đã trở thành điều thiết yếu
    • Lịch sử ML có thể được xem là quá trình phát triển đan xen giữa đổi mới hệ thống và cải tiến phần mềm
    • Khi các mô hình Transformer gần đây đã khai thác tới sát giới hạn phần cứng, nếu không hiểu hiệu quả mô hình thì kiến trúc hay nghiên cứu mới rất dễ thất bại khi triển khai thực tế
    • Dù đạt cải thiện 20% trên benchmark, nếu hiệu quả phần cứng giảm 20% thì cuối cùng tính thực dụng vẫn thấp
  • Mục tiêu cốt lõi của scaling mô hình là làm cho thông lượng tăng tuyến tính khi tăng số lượng chip (bộ tăng tốc)
    • Điều này được gọi là "strong scaling"
    • Thêm chip giúp giảm thời gian tính toán nhưng phát sinh chi phí giao tiếp giữa các chip
    • Nếu giao tiếp mất nhiều thời gian hơn tính toán, hệ thống sẽ rơi vào trạng thái "communication bound" và không thể strong scaling
    • Nếu hiểu đủ rõ phần cứng để dự đoán nơi xuất hiện các điểm nghẽn này, ta có thể thiết kế hoặc tái cấu trúc mô hình để tránh chúng
  • Mục tiêu của cuốn sách này là giải thích cách phần cứng TPU (và GPU) hoạt động, cũng như cách kiến trúc Transformer đã phát triển để vận hành tốt trên phần cứng hiện tại
    • Tác giả hy vọng nội dung này sẽ hữu ích cho cả các nhà nghiên cứu thiết kế kiến trúc mới lẫn các kỹ sư đang cố gắng chạy thật nhanh các thế hệ LLM hiện nay

Tổng quan toàn bộ

  • Bài viết này được cấu trúc như sau
  • Phần 1 giải thích các yếu tố quyết định giới hạn hiệu năng của mô hình (giao tiếp, tính toán, bộ nhớ) thông qua phân tích roofline
  • Phần 2, Phần 3 trình bày cấu trúc bên trong của TPU và GPU cũng như cách kết nối giữa các chip
    • Qua đó trả lời các câu hỏi như sau
      • Về mặt lý thuyết, một phép nhân ma trận với kích thước nhất định có thể được thực hiện nhanh đến mức nào
      • Tại điểm nào việc tính toán sẽ bị ràng buộc bởi băng thông bộ nhớ hoặc băng thông giao tiếp
      • Cụm TPU được kết nối theo cấu trúc nào, và mất khoảng bao lâu để chuyển dữ liệu từ chip này sang chip khác
      • Làm thế nào để nhân các ma trận phân tán một cách hiệu quả
  • Phần 4 đi sâu vào công thức của kiến trúc Transformer (kích thước ma trận, số lượng tham số, FLOPs)
  • Phần 5Phần 7 là trọng tâm, giới thiệu nhiều cách khác nhau để song song hóa mô hình trên nhiều chip
    • Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
    • Cũng đề cập các kỹ thuật tiết kiệm bộ nhớ như ZeRO, Rematerialisation, Host offload, Gradient accumulation
  • Phần 6, Phần 8 lấy ví dụ huấn luyện và suy luận mô hình LLaMA-3 trên TPU để đưa ra chi phí, thời gian và cách cấu hình thực tế
  • Cuối cùng, Phần 9, Phần 10 trình bày cách thực tế để profile, debug và áp dụng xử lý song song cho mô hình trong JAX

Chi tiết hơn: tóm tắt các phần chính của sách

  • Phần 1: Preliminaries

  • Phần 2: Transformers

    • Phần 4: Tổng hợp các công thức Transformer cần thiết

      • Cụ thể các phép nhân ma trận trong Transformer có dạng như thế nào
      • Cách tính số lượng tham số, FLOPs, kích thước KV cache, v.v.
      • Xác định attention đòi hỏi lượng tính toán nhiều hơn bao nhiêu so với khối Feed-Forward
    • Phần 5: Chiến lược song song hóa huấn luyện Transformer

      • Giới thiệu các kỹ thuật Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
      • Các biện pháp tiết kiệm bộ nhớ như ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload
      • Xây dựng khái niệm về cách cấu hình song song hóa phù hợp với kích thước mô hình và số lượng chip cụ thể
    • Phần 6: Ứng dụng huấn luyện LLaMA 3 trên TPU

      • Giả sử huấn luyện mô hình LLaMA 3 trong môi trường TPU thực tế, ước tính thời gian và chi phí cần thiết
      • Đưa ra ví dụ cụ thể về batch size, cách song song hóa, mức sử dụng bộ nhớ, v.v.
    • Phần 7: Mọi thứ về suy luận Transformer

      • Khi suy luận, độ trễ (latency) trở thành một yếu tố mới quan trọng
      • Vấn đề bộ nhớ và giao tiếp do KV cache gây ra
      • Thảo luận cách phân bổ và kết nối nhiều chip để phục vụ mô hình
    • Phần 8: Ứng dụng phục vụ LLaMA 3 trên TPU

      • Giả sử phục vụ LLaMA 3 trên TPU v5e, phân tích gần đúng về chi phí, độ trễ và đánh đổi thông lượng
  • Phần 3: Practical Tutorials

1 bình luận

 
GN⁺ 2025-02-07
Ý kiến trên Hacker News
  • Có kỳ vọng rằng JAX sẽ thay thế pytorch/cuda trong vài năm tới. Vấn đề PTX với nhóm Deepseek cho thấy giá trị của việc đầu tư vào cách tiếp cận ở mức thấp hơn để khai thác tối đa hiệu năng phần cứng
    • Tài liệu này từng được dùng nội bộ tại Google như một cẩm nang cho công việc tối ưu hiệu năng. Việc nó được công khai là điều đáng ngạc nhiên, nhưng có vẻ các chi tiết liên quan đến Gemini đã bị loại bỏ
    • Điểm hay của hướng dẫn này là nhờ JAX/XLA nên có thể chuyển thẳng sang GPU
    • Có ý kiến thắc mắc vì sao JAX lại dùng tracing thay vì AST
    • Có chia sẻ liên kết tới chuỗi tweet của tác giả
    • Có người đang tìm cách chuyển một trang Jekyll thành PDF
    • Có lời khen đây là một bài viết rất hay cùng lời cảm ơn
    • Có ý kiến tò mò không biết họ làm các hoạt ảnh đẹp mắt đó như thế nào