-
Tầm quan trọng của Attention
- Attention là lớp cốt lõi trong kiến trúc Transformer, đồng thời gây ra nút thắt cổ chai trong các mô hình ngôn ngữ lớn và các ứng dụng ngữ cảnh dài.
- FlashAttention và FlashAttention-2 đã tiên phong trong cách tiếp cận tăng tốc Attention bằng cách giảm thiểu việc đọc/ghi bộ nhớ trên GPU.
- Nhờ đó, độ dài ngữ cảnh của LLM đã tăng lên đáng kể.
-
Các kỹ thuật chính của FlashAttention-3
- Khai thác tính bất đồng bộ: Tận dụng tính bất đồng bộ của Tensor Cores và TMA để chồng lấp toàn bộ quá trình tính toán và di chuyển dữ liệu.
- Tính toán theo khối: Thực hiện xen kẽ phép nhân ma trận và phép toán softmax theo từng khối.
- Xử lý độ chính xác thấp: Tận dụng hỗ trợ độ chính xác thấp FP8 để cải thiện hiệu năng.
-
Cải thiện hiệu năng của FlashAttention-3
- Hiệu quả khai thác GPU: Tận dụng tới 75% hiệu năng tối đa của GPU H100, nhanh hơn 1,5-2 lần so với phiên bản trước.
- Hiệu năng với độ chính xác thấp: Dùng FP8 để tăng tốc độ xử lý và giảm mức sử dụng bộ nhớ.
- Xử lý ngữ cảnh dài: Tăng tốc cơ chế Attention để có thể xử lý văn bản dài hơn một cách hiệu quả.
-
Tóm tắt về FlashAttention
- FlashAttention sắp xếp lại phép tính Attention và tận dụng tiling cùng tính toán lại để tăng tốc đáng kể và giảm mức sử dụng bộ nhớ.
- Thông qua tiling, hệ thống nạp các khối đầu vào, thực hiện Attention trên các khối đó rồi cập nhật đầu ra.
- Bằng cách không ghi ma trận Attention trung gian vào bộ nhớ, lượng đọc/ghi bộ nhớ được giảm xuống.
-
Các tính năng phần cứng mới của GPU Hopper
- WGMMA: Tận dụng Tensor Cores mới để cung cấp thông lượng cao.
- TMA: Đơn vị phần cứng tăng tốc truyền dữ liệu giữa bộ nhớ toàn cục và bộ nhớ chia sẻ.
- FP8 độ chính xác thấp: Sử dụng FP8 để tăng gấp đôi thông lượng của Tensor Core.
-
Tính bất đồng bộ: chồng lấp GEMM và Softmax
- Sự cần thiết của chồng lấp: Thực hiện GEMM và softmax song song để tối đa hóa hiệu năng.
- Lập lịch ping-pong: Hai nhóm warp luân phiên thực hiện GEMM và softmax để cải thiện hiệu năng.
- Chồng lấp trong nhóm warp: Thực hiện GEMM và softmax song song trong cùng một nhóm warp để tăng thông lượng.
-
Độ chính xác thấp: giảm lỗi lượng tử hóa bằng xử lý incoherent
- Xử lý incoherent: Sử dụng phép biến đổi Hadamard để giảm lỗi lượng tử hóa.
- Kết quả thực nghiệm: Xử lý incoherent giúp giảm lỗi lượng tử hóa 2,6 lần.
-
Benchmark Attention
- FP16: Nhanh hơn khoảng 1,6-1,8 lần so với FlashAttention-2.
- FP8: Đạt tối đa 1.2 PFLOPS.
Tổng kết của GN⁺
- FlashAttention-3 tận dụng các tính năng phần cứng mới của GPU để cải thiện mạnh mẽ hiệu năng của cơ chế Attention.
- Nó có thể xử lý ngữ cảnh dài một cách hiệu quả, từ đó tối đa hóa hiệu năng của các mô hình ngôn ngữ lớn.
- Khả năng cao sẽ được tích hợp vào các framework lớn như PyTorch, qua đó tạo ảnh hưởng lớn đến nghiên cứu và ứng dụng AI trong tương lai.
- Các dự án cung cấp tính năng tương tự gồm có Triton và cuDNN.
1 bình luận
Ý kiến trên Hacker News
Có vẻ như Tri Dao đã bắt đầu làm FA3 từ tháng 4 năm 2022
Tò mò không biết thuật toán Flash Attention phụ thuộc vào phần cứng đến mức nào
Tò mò liệu trình biên dịch có thể tự tìm ra các tối ưu hóa như FlashAttention hay không
Người nào muốn port sang ROCm/AMD MI300x thì hãy liên hệ
TMA (Tensor Memory Accelerator) là một đơn vị phần cứng tăng tốc truyền dữ liệu giữa bộ nhớ toàn cục và bộ nhớ chia sẻ
FlashAttention-3 được tối ưu cho GPU Hopper (ví dụ: H100)
Có nhắc đến việc các hàm kích hoạt như sigmoid rất chậm trong LLM hiện đại
Tò mò vì sao Flash Attention lại chậm hơn 5 lần khi có masking biến thiên so với khi không có
Tò mò liệu FlashAttention có thể thay thế phép toán attention trong LLM hay không
Cần phần cứng đắt tiền