2 điểm bởi GN⁺ 2025-02-27 | Chưa có bình luận nào. | Chia sẻ qua WhatsApp

DeepGEMM

DeepGEMM là một thư viện dành cho phép nhân ma trận tổng quát FP8 (GEMM), hỗ trợ cơ chế scale chi tiết được đề xuất trong DeepSeek-V3. Thư viện này hỗ trợ GEMM thông thường và GEMM được nhóm cho Mix-of-Experts (MoE), được viết bằng CUDA nên không cần biên dịch khi cài đặt. Nó hỗ trợ NVIDIA Hopper tensor core và sử dụng cơ chế tích lũy 2 giai đoạn bằng CUDA core để giải quyết vấn đề thiếu chính xác của tích lũy tensor core FP8. Thư viện có áp dụng một phần các khái niệm từ CUTLASS và CuTe, nhưng giữ sự đơn giản bằng cách giảm thiểu phụ thuộc vào template hay đại số. Với một hàm kernel cốt lõi duy nhất chỉ khoảng 300 dòng mã, đây là tài nguyên phù hợp để học về phép nhân ma trận FP8 trên Hopper và các kỹ thuật tối ưu hóa. Dù có thiết kế gọn nhẹ, nó vẫn đạt hiệu năng ngang bằng hoặc vượt qua các thư viện được tinh chỉnh bởi chuyên gia trên nhiều dạng ma trận khác nhau.

Hiệu năng

Tất cả các dạng có thể được dùng trong suy luận DeepSeek-V3/R1 đã được kiểm thử trên H800 SXM5 với NVCC 12.8. Mọi chỉ số tăng tốc đều được tính bằng cách so sánh với một triển khai được tối ưu hóa nội bộ dựa trên CUTLASS 3.6. Một số dạng có thể cho hiệu năng chưa tốt, và các PR tối ưu hóa luôn được hoan nghênh.

GEMM thông thường (mô hình dense)

  • Kết quả đo hiệu năng của DeepGEMM trên nhiều kích thước ma trận cho thấy mức tăng tốc tối đa lên tới 2.7 lần ở một số kích thước nhất định.

GEMM được nhóm cho mô hình MoE (bố cục liên tục)

  • Tùy theo số lượng nhóm và kích thước ma trận của từng nhóm, mức tăng tốc tối đa đạt 1.2 lần.

GEMM được nhóm cho mô hình MoE (bố cục mask)

  • Sử dụng bố cục mask cho mức tăng tốc tối đa 1.2 lần.

Bắt đầu nhanh

Yêu cầu

  • GPU kiến trúc Hopper, cần hỗ trợ sm_90a
  • Python 3.8 trở lên
  • CUDA 12.3 trở lên (khuyến nghị 12.8 trở lên để có hiệu năng tốt nhất)
  • PyTorch 2.1 trở lên
  • CUTLASS 3.6 trở lên

Phát triển

  • Mô tả quy trình phát triển bao gồm clone submodule, tạo symbolic link, biên dịch JIT và kiểm thử toàn bộ các triển khai GEMM.

Cài đặt

  • Có thể import deep_gemm vào dự án Python để sử dụng.

Giao diện

Lưu ý

  • Thư viện này chỉ bao gồm các kernel GEMM và chỉ hỗ trợ định dạng NT. Các thao tác chuyển vị hoặc ép kiểu FP8 khác cần được triển khai riêng.

GEMM dense thông thường (không nhóm)

  • Cung cấp hàm để thực hiện FP8 GEMM cơ bản không nhóm.

GEMM được nhóm (bố cục liên tục)

  • Được thiết kế cho các kịch bản trong mô hình MoE nơi các expert chia sẻ cùng một dạng.

GEMM được nhóm (bố cục mask)

  • Trong giai đoạn giải mã suy luận, cung cấp tensor mask để chỉ tính toán phần hợp lệ.

Tiện ích

  • Cung cấp nhiều hàm tiện ích và biến môi trường khác nhau để hỗ trợ tối ưu hiệu năng.

Tối ưu hóa

Warp specialization liên tục

  • Tuân theo thiết kế của CUTLASS, chồng lấp việc di chuyển dữ liệu, lệnh tensor core MMA và việc nâng cấp bằng CUDA core.

Tính năng Hopper TMA

  • Tận dụng TMA để tăng tốc di chuyển dữ liệu.

Các tối ưu chi tiết dùng chung

  • Cải thiện hiệu năng thông qua nhiều kỹ thuật tối ưu hóa khác nhau.

Bộ lập lịch block hợp nhất và tối ưu hóa

  • Cung cấp bộ lập lịch cho tất cả các kernel không nhóm và có nhóm.

Thiết kế JIT hoàn chỉnh

  • Cải thiện hiệu năng nhờ thiết kế JIT không yêu cầu biên dịch khi cài đặt.

Kích thước block không căn chỉnh

  • Hỗ trợ kích thước block không căn chỉnh để tối đa hóa mức sử dụng SM trong một số dạng nhất định.

FFMA SASS interleaving

  • Điều chỉnh lệnh FFMA để tăng mức song song ở cấp độ warp nhằm cải thiện hiệu năng.

Lời cảm ơn

  • DeepGEMM lấy cảm hứng từ dự án CUTLASS và bày tỏ sự biết ơn cũng như tôn trọng tới các nhà phát triển của dự án.

Giấy phép

  • Được phát hành theo giấy phép MIT.

Chưa có bình luận nào.

Chưa có bình luận nào.