PyTorch đã chết. JAX muôn năm
(neel04.github.io)- Lý do PyTorch gây ra tổn thất năng suất và lãng phí thời gian phát triển là vì "không phải bản thân framework tệ, mà là nó không được thiết kế phù hợp với các use case hiện đang được áp dụng"
Triết lý của PyTorch
- Triết lý của PyTorch là động, dễ debug và đậm chất Python
- Trong khi đó, TensorFlow 1.x cố gắng trở thành một framework tĩnh nhưng có hiệu năng tốt bằng cách tận dụng mạnh mẽ trình biên dịch XLA
- Các nhà phát triển TensorFlow nhận ra rằng cộng đồng không thích API 1.x, nên quyết định dùng Keras làm giao diện chính và giảm vai trò của trình biên dịch XLA
- PyTorch giữ nguyên gốc rễ của mình, và khác với cách tiếp cận tĩnh, trì hoãn của TensorFlow, đã áp dụng cách tiếp cận "eager execution" năng động hơn, trong đó
torch.Tensorđược đánh giá ngay lập tức - Điều này đã mang lại kết quả, khiến nhiều nghiên cứu chuyển sang PyTorch
- Khi GPT-3 xuất hiện vào năm 2021, hiệu năng và khả năng mở rộng trở thành mối quan tâm chính
- PyTorch đã đáp ứng tương đối tốt các nhu cầu này, nhưng vì không được thiết kế với triết lý đó ngay từ đầu, ngày càng tích tụ nợ kỹ thuật và nền tảng bắt đầu lung lay
- Các nhà phát triển PyTorch không muốn bất kỳ điểm thỏa hiệp nào và đã chọn theo đuổi đồng thời hai hướng
- Dùng trình biên dịch XLA làm backend mặc định có hiệu năng và độ ổn định cao
- Xây dựng stack
torch.compileđể người dùng có quyền tự do gọi trình biên dịch khi cần
- Việc thiếu một chiến lược dài hạn là vấn đề nghiêm trọng
- PyTorch không muốn cam kết với triết lý lấy trình biên dịch làm trung tâm như JAX, nhưng cũng không thấy lựa chọn thay thế tốt
- Các đối thủ giải quyết vấn đề này như thế nào?
Phát triển dựa trên trình biên dịch của JAX
- JAX tận dụng XLA, stack trình biên dịch mạnh mẽ của TensorFlow
- XLA là một trình biên dịch mạnh, nhưng tất cả đều được trừu tượng hóa đối với người dùng cuối
- Chỉ cần hàm là pure, bạn có thể dùng decorator
@jax.jitđể JIT compile hàm và cho phép nó chạy trên XLA - XLA sẽ xử lý phía sau toàn bộ việc xác minh đồ thị được tạo ra có chính xác hay không, GSPMD partitioner để tự động song song hóa bằng sharding trong JAX, tối ưu hóa đồ thị, fusion giữa operator và kernel, lập lịch che giấu độ trễ, chồng lấp truyền thông bất đồng bộ, sinh mã cho các backend khác như triton, v.v.
- Chỉ cần tuân thủ các ràng buộc của JAX, XLA sẽ tự động lo phần còn lại
- Ví dụ, khi song song hóa, không cần các communication primitive như
torch.distributed.barrier() - Hỗ trợ DDP có thể đạt được với mã đơn giản
- Cách tiếp cận của XLA là tính toán đi theo sharding. Vì vậy, nếu mảng đầu vào được shard theo một trục nào đó, XLA sẽ tự động xử lý cho các phép tính con
- Ý tưởng "phát triển dựa trên trình biên dịch" tương tự cách trình biên dịch Rust hoạt động
- Giới hạn của PyTorch
- Tác giả không hài lòng với việc các nhà phát triển PyTorch chọn tích hợp và phụ thuộc vào stack trình biên dịch cho các tính năng mới, thay vì giữ vững triết lý cốt lõi là tính linh hoạt và tự do
- Theo roadmap chính thức của PyTorch 2.x, có nêu rõ kế hoạch dài hạn là tích hợp hoàn toàn XLA vào Torch
- Đây là một ý tưởng tồi tệ. Cũng giống như nói rằng nhét ép mã C++ vào trình biên dịch Rust sẽ cho trải nghiệm tốt hơn dùng chính Rust
- Torch, khác với JAX, không được thiết kế xoay quanh XLA
- Nếu PyTorch đã quyết định dùng stack trình biên dịch dựa trên XLA, chẳng phải framework lý tưởng nên là thứ được thiết kế và xây dựng riêng xoay quanh nó sao?
- Ngay cả khi PyTorch theo đuổi cách tiếp cận "multi-backend" cho phép chọn backend trình biên dịch mong muốn, liệu điều đó có làm trầm trọng thêm vấn đề phân mảnh và phá nát API khi phải cố tôn trọng các giới hạn của mọi stack trình biên dịch hay không?
- Bất kỳ ai từng dùng Torch/XLA trên TPU đều mắc PTSD nghiêm trọng
Multi-Backend đã thất bại
- PyTorch cố làm mọi thứ cùng lúc và đã thất bại thảm hại
- Quyết định thiết kế "multi-backend" khiến vấn đề này tệ hơn theo cấp số nhân
- Về lý thuyết thì nghe như có thể chọn stack mình muốn, nhưng trên thực tế lại là một mớ hỗn độn rối rắm của traceback khó hiểu và các vấn đề không tương thích
- Xung đột giữa ràng buộc của các backend và API PyTorch
- Điều khó không phải là làm cho các backend này chạy được, mà là các ràng buộc mà chúng yêu cầu không phù hợp với API linh hoạt, Pythonic của PyTorch
- Có sự đánh đổi giữa việc giữ API nhất quán và tuân theo các giới hạn của backend
- Kết quả là các nhà phát triển có xu hướng dựa nhiều hơn vào sinh mã thay vì thực sự tích hợp/cam kết với một backend duy nhất
- Sự thiếu chiến lược của PyTorch
- Vì PyTorch từ chối các đánh đổi có ý nghĩa, mọi quyết định đều giống như thỏa hiệp
- Không có tính nhất quán, cũng không có chiến lược tổng thể
- Cuối cùng gây ra rất nhiều thất vọng cho người dùng và tạo cảm giác như một mớ chắp vá các tính năng không ăn nhập với nhau
- Không có cách nào giết chết hệ sinh thái nhanh hơn thế
- Tại sao không nên đi theo cách tiếp cận của JAX
- PyTorch không nên đi theo cách tiếp cận "trình biên dịch và backend tích hợp" của JAX
- Vì JAX được thiết kế rõ ràng để hoạt động cùng XLA
- Thay frontend của PyTorch bằng frontend của JAX không thể là chiến lược
- Gần như không thể nghĩ ra một API tốt hơn JAX trên nền XLA
- Tác giả không trách các nhà phát triển vì muốn thử các ý tưởng mới và khác biệt
- Nhưng nếu PyTorch muốn đứng vững trước thử thách của thời gian, nó cần tập trung hơn vào việc củng cố nền tảng thay vì tung ra các tính năng mới bóng bẩy nhưng sụp đổ ngay ngoài các điều kiện tutorial lý tưởng
Sự phân mảnh của PyTorch và lập trình hàm của JAX
- API hàm của JAX
- Hàm JAX phải là pure, tức là không có side effect toàn cục
- Giống như hàm toán học, với cùng dữ liệu đầu vào thì luôn phải trả về cùng một đầu ra, bất kể ngữ cảnh thực thi
- Nhờ triết lý thiết kế này, các hàm JAX có thể được cấu thành và tương tác với nhau rất tốt
- Độ phức tạp phát triển giảm xuống, và hàm được định nghĩa với signature cụ thể cùng công việc cụ thể được xác định rõ
- Chỉ cần type được giữ đúng thì hàm được đảm bảo sẽ hoạt động ngay
- Điều này phù hợp với tính toán khoa học, đặc biệt là kiểu công việc cần thiết trong deep learning
- Ví dụ về API optax
- Nhờ cách tiếp cận hàm, optax có một khái niệm gọi là "chain"
- Nó bao gồm nhiều hàm được áp dụng tuần tự lên gradient
- Thành phần cấu tạo cơ bản là
GradientTransformation - Điều này tạo nên một API mạnh mẽ nhưng vẫn giàu khả năng biểu đạt
- Ví dụ, các việc như clipping gradient, lấy EMA của gradient, hay kết hợp optimizer đều trở nên rất đơn giản
- Ưu điểm của thiết kế hàm
- Một kết quả tuyệt vời khác của thiết kế hàm là
vmap - Nó có nghĩa là map "vectorized" và mô tả chính xác chức năng đó
- Bạn có thể map mọi thứ, và chỉ cần là
vmapthì XLA sẽ tự động fusion và tối ưu - Khi viết hàm, không cần phải nghĩ về chiều batch
- Chỉ cần
vmaptoàn bộ mã - Điều này có nghĩa là ít cần tới các thao tác ein-* hơn
- Việc nắm bắt thao tác tensor 2D/3D trở nên trực quan hơn và khả năng đọc cũng tốt hơn nhiều
- Vì chỉ cần cô lập từng thành phần riêng lẻ để suy luận, nên việc viết mã phức tạp nhưng hoạt động tốt trở nên dễ hơn
- Chỉ cần tôn trọng các ràng buộc về tính pure và có đúng signature, bạn sẽ nhận được mọi lợi ích khác như khả năng kết hợp
- Một kết quả tuyệt vời khác của thiết kế hàm là
- Vấn đề của hệ sinh thái PyTorch
- Trong
torch, bất kể dùng stack nào (FSDP+ đa nút +torch.compile, v.v.), luôn có khả năng một thứ gì đó sẽ hỏng - Nhiều thứ phải hoạt động đúng cùng nhau, và nếu bất kỳ thành phần nào thất bại thì bạn phải debug đến 3 giờ sáng
- Vì không thể kiểm thử mọi tổ hợp của hàng chục tính năng mà PyTorch cung cấp, nên luôn sẽ có lỗi chưa được phát hiện trong quá trình phát triển
- Gần như không thể viết được mã hoạt động tốt nếu không bỏ ra nỗ lực đáng kể
- Hệ sinh thái
torchđã trở nên cực kỳ phình to và đầy lỗi - Vì không có trừu tượng dùng chung, các thư viện và framework mới xuất hiện nhưng không được thiết kế để giao tiếp với các "giải pháp" khác
- Điều này nhanh chóng thoái hóa thành mớ hỗn độn phụ thuộc và
requirements.txt - 70-80% các issue trên GitHub hoặc thảo luận trên diễn đàn chỉ đơn giản là do lỗi phát sinh giữa các thư viện khác nhau
- Hầu như không có cách nào để giải quyết điều này
- Trong
- Sự thiếu vắng lời giải
- Đây là vấn đề của OOP và thiết kế
- Tác giả cho rằng các đối tượng cơ bản, đậm chất PyTorch như
PyTreelẽ ra có thể giúp xây dựng một nền tảng chung cho trừu tượng hóa - Cũng không thể chuyển sang mô hình lập trình hàm
- Làm vậy sẽ hội tụ thành một phiên bản JAX kém hiệu năng hơn, đồng thời phá vỡ khả năng tương thích ngược của toàn bộ codebase
torchhiện có - PyTorch có vẻ đã hỏng hoàn toàn ở khía cạnh này
Ưu thế về tính tái lập của JAX
- Xử lý seed
- Việc xử lý seed của PyTorch không lý tưởng
- Thông thường phải chạy nhiều dòng mã
- Rất dễ quên hoặc cấu hình sai
- JAX buộc bạn phải tạo key tường minh và truyền nó vào mọi hàm cần randomness
- Cách tiếp cận này loại bỏ hoàn toàn vấn đề vì RNG luôn được seed theo cách tĩnh
- JAX có phiên bản NumPy riêng (
jax.numpy), nên không cần seed riêng biệt - Những quyết định QoL nhỏ như vậy có thể khiến trải nghiệm người dùng của toàn bộ framework tốt hơn rất nhiều
- Tính di động
- Một trong những vấn đề lớn nhất khi dùng codebase PyTorch là thiếu tính di động
- Các codebase viết cho CUDA/GPU thường không chạy tốt khi chuyển sang phần cứng không phải Nvidia như TPU, NPU, AMD GPU, v.v.
- Rất khó port mã PyTorch viết cho 1 nút sang đa nút
- Đa nút thường đòi hỏi hàng chục giờ phát triển và thay đổi mã đáng kể
- Cách tiếp cận lấy trình biên dịch làm trung tâm của JAX có lợi thế ở điểm này
- XLA xử lý việc chuyển đổi giữa các backend thiết bị và hoạt động tốt trên GPU/TPU/đa nút/đa slice với thay đổi mã tối thiểu
- Điều này giúp các nhà cung cấp phần cứng dễ hỗ trợ thiết bị của họ hơn và giúp việc chuyển đổi giữa các thiết bị trở nên dễ dàng hơn
- Không phải ai cũng có quyền truy cập cùng một phần cứng, nên codebase có thể di động trên nhiều loại phần cứng sẽ là một bước nhỏ giúp deep learning dễ tiếp cận hơn với người mới bắt đầu/trình độ trung cấp
- Tự động scale
- Một codebase có thể tự động scale tốt là yếu tố rất hữu ích cho việc tái lập
- Trường hợp lý tưởng là điều đó phải xảy ra tự động với thay đổi mã tối thiểu, không bị giới hạn bởi các ranh giới mạng
- JAX làm điều này rất tốt
- Khi viết mã JAX, không cần chỉ định communication primitive hay rải
torch.distributed.barrier()khắp nơi - XLA sẽ tự động chèn chúng dựa trên phần cứng khả dụng
- Mọi thiết bị mà JAX có thể phát hiện sẽ được dùng tự động, bất kể mạng, topology, cấu hình, v.v.
- Nó tự động đồng bộ hóa và chuẩn bị tính toán, áp dụng các pass tối ưu để tối đa hóa thực thi bất đồng bộ của kernel và giảm độ trễ xuống mức thấp nhất
- Điều duy nhất con người phải làm là chỉ định sharding của tensor muốn phân tán lên các thiết bị, chẳng hạn như chiều batch của mảng đầu vào
- Nhờ cách tiếp cận "tính toán đi theo sharding" của XLA, nó sẽ tự động suy ra phần còn lại
- Điều này cho phép dễ dàng chạy các thí nghiệm đã được kiểm chứng ở quy mô lớn như một thú vui để khám phá và có khả năng lặp lại chúng
- Nó giúp việc khám phá lại những ý tưởng bị lãng quên dễ dàng hơn, và có thể khuyến khích các thí nghiệm như vậy vì chúng có thể được kiểm thử như một hàm ở quy mô lớn hơn với nỗ lực tối thiểu
Nhược điểm của JAX
- Cấu trúc quản trị
- Hiện tại XLA nằm dưới cơ chế quản trị của TensorFlow
- Đã có thảo luận về việc thành lập một tổ chức quản trị riêng tương tự PyTorch, nhưng chưa có nhiều nỗ lực cụ thể
- Mức độ tin tưởng vào Google không cao do danh tiếng khai tử các sản phẩm không còn được ưa chuộng
- Về mặt kỹ thuật, JAX là một dự án của DeepMind và có ý nghĩa cốt lõi đối với nỗ lực AI tổng thể của Google, nhưng một cơ chế quản trị riêng có vẻ sẽ mang lại lợi ích dài hạn lớn cho toàn bộ hệ sinh thái
- Một tổ chức quản trị riêng sẽ cung cấp định hướng cho sự phát triển của dự án
- Nó sẽ mang lại cấu trúc cụ thể và tách khỏi bộ máy quan liêu khét tiếng của Google, từ đó tránh được nhiều vấn đề cùng lúc
- Không phải JAX nhất thiết cần loại cấu trúc chính thức này, nhưng sẽ tốt hơn nếu có đảm bảo rằng việc phát triển JAX sẽ tiếp tục lâu dài bất kể quyết định của ban lãnh đạo Google
- Điều này rõ ràng sẽ giúp việc được các công ty và phòng thí nghiệm lớn chấp nhận dễ hơn, vì họ ngần ngại bỏ tài nguyên để tích hợp một công cụ có thể một ngày nào đó không còn được bảo trì
- Quá trình mở nguồn của XLA
- Trong một thời gian dài, XLA là một dự án mã nguồn đóng
- Tuy nhiên, đã có những nỗ lực để mở nguồn nó, và hiện tại OpenXLA cho hiệu năng tốt hơn nhiều so với bản dựng XLA nội bộ
- Nhưng tài liệu về nội bộ của XLA vẫn còn thiếu
- Phần lớn tài nguyên chỉ là các buổi nói chuyện trực tiếp và đôi khi là các bài báo, mà thường cũng đã cũ
- Nếu có một roadmap công khai về các tính năng sắp tới, mọi người sẽ dễ theo dõi tiến độ và đóng góp vào những phần họ đặc biệt hứng thú hơn
- Sẽ rất tốt nếu có các bài blog mini theo phong cách Edward Yang phân tích từng bước của stack trình biên dịch XLA và giải thích chi tiết, để người thực hành có thể đánh giá rõ hơn XLA làm được gì và không làm được gì
- Tác giả hiểu rằng điều này tốn tài nguyên và có thể nên được truyền đạt tốt hơn ở nơi khác, nhưng mọi người tin tưởng công cụ hơn khi họ hiểu nó, và điều đó tạo ra hiệu ứng lan tỏa tích cực cho toàn bộ hệ sinh thái, mang lợi ích cho tất cả
- Tích hợp hệ sinh thái
flaxlà nỗi đau đầu của hệ sinh thái JAX- Nó có API không trực quan, cú pháp ngắn gọn khó hiểu và là địa ngục tuyệt đối cho người mới chuyển từ PyTorch sang
- Tác giả khuyên nên dùng
equinox - Đã có những nỗ lực từ đội ngũ phát triển nhằm khắc phục nhược điểm của
flax, nhưng cuối cùng đó chỉ là lãng phí thời gian - Nếu muốn API kiểu
equinox, tốt nhất là dùngequinox - Không có nhiều thứ mà
flaxthực sự làm tốt hơn một cách đặc biệt, và cũng không khó để tái tạo chúng bằngequinox - Hiện tại, phần lớn hệ sinh thái JAX được thiết kế xoay quanh
flax - Vì
equinoxvề cơ bản giao tiếp vớiPyTree, nó tương thích chéo với mọi thư viện, dù có cần một chúteqx.partitionvàfilter - Tác giả muốn thay đổi hiện trạng.
equinoxnên được hỗ trợ hạng nhất ở khắp nơi - Đây là một ý kiến gây tranh cãi, nhưng đó là lỗi ngụy biện chi phí chìm kinh điển
equinoxhoạt động tốt hơn theo đúng cách mà framework JAX lẽ ra luôn nên vận hành- Như đã được tóm tắt trong tài liệu
equinox, nếu so sánhequinoxvàflaxthìequinoxtốt hơn - Việc những người quản lý hệ sinh thái JAX nhận ra mức độ phổ biến của
equinoxvà điều chỉnh theo là điều tốt, nhưng tác giả cũng muốn Google và nhómflaxchính thức dành nhiều sự ủng hộ hơn - Nếu muốn thử JAX, tác giả khuyên nên dùng
equinox
- Các góc cạnh sắc bén
- Do các quyết định thiết kế API và các giới hạn của XLA, JAX có những "góc cạnh sắc bén" mà bạn cần lưu ý
- Điều này được giải thích rất súc tích trong tài liệu được viết tốt
- Tác giả khuyên nên đọc ít nhất một lần trước khi dùng JAX
- Cũng như mọi khi, RTFM sẽ giúp tiết kiệm rất nhiều thời gian và công sức
Kết luận
- Bài blog này nhằm chỉnh lại huyền thoại được lặp đi lặp lại rằng PyTorch là lựa chọn phù hợp nhất cho workload nghiên cứu thực tế, đặc biệt là trên GPU. Điều đó không còn đúng nữa
- Thực tế, tác giả cực đoan đến mức cho rằng việc port toàn bộ mã PyTorch sang JAX sẽ cực kỳ có lợi cho cả lĩnh vực
- Tự động song song hóa, tính tái lập, API hàm gọn gàng, v.v. không phải là các tính năng nhỏ nhặt và sẽ rất hữu ích cho nhiều codebase nghiên cứu
- Nếu bạn muốn làm cho lĩnh vực này tốt hơn dù chỉ một chút, hãy cân nhắc viết lại codebase của mình bằng JAX
8 bình luận
Thế sự vẫn cứ trôi đi. haha
So sánh PyTorch và TensorFlow năm 2022
Tôi sẽ trụ lại với
torchvàonnx.Bài này do một sinh viên đại học viết.. ghê thật
PyTorch mà không có Huggingface thì đúng là chết thật haha
JAX muôn năm! Gần đây tôi đã thử dùng, và tôi rất thích API NNX.
Vấn đề lớn nhất của JAX là nó thuộc về Google. Google rất nổi tiếng với việc bỏ rơi mã nguồn mở (Tflite, android things, dart, angular, bazel, v.v.). Ngay cả TensorFlow từ một thời điểm nào đó cũng bắt đầu không còn được cập nhật tốt. Trong khi đó, torch bắt đầu từ Facebook, nơi vận hành một hệ sinh thái mã nguồn mở đồ sộ, nên được quản lý rất tốt và hiện đã do quỹ torch vận hành. Những nhược điểm của torch rõ ràng là có thật, nhưng xét ở khía cạnh ai là người vận hành bền vững dự án mã nguồn mở đó, thì có vẻ JAX đã khởi đầu với một rủi ro lớn.
Ít nhất thì có vẻ Dart vẫn sẽ sống ổn thêm một thời gian nhờ Flutter.
Facebook thì có vẻ vẫn liên tục đóng góp một cách khá “có tình có nghĩa” (?) cho các công nghệ trong stack mà họ sử dụng, như React hay Django, nhưng Google thì dường như chỉ cần hơi lỗi thời một chút là vứt bỏ như đôi dép rách...