Giải thích nguyên lý hoạt động của Transformer: hiểu toán học đằng sau nó
(osanseviero.github.io)- Rút gọn quá trình suy luận của Transformer bằng ví dụ dịch Hello World → Hola Mundo, để có thể lần theo bằng tay từ token hóa, encoder·decoder cho đến tính xác suất token tiếp theo
- Thay vì cấu hình lớn trong bài báo gốc, ví dụ dùng embedding 4 chiều, 2 attention head và tầng feedforward 8 chiều để thu nhỏ luồng nhân ma trận và softmax
- Encoder cộng mã hóa vị trí vào embedding token, rồi đi qua multi-head self-attention và tầng feedforward để tạo biểu diễn ngữ cảnh của chuỗi đầu vào
- Decoder bắt đầu từ
SOS, dùng đồng thời các token đã sinh trước đó và đầu ra của encoder; trong encoder-decoder attention, query được tính từ decoder, còn key/value được tính từ đầu ra encoder - Embedding cuối cùng của decoder đi qua tầng tuyến tính và softmax để thành xác suất token tiếp theo, nhưng ví dụ dùng trọng số ngẫu nhiên nên không kỳ vọng chất lượng dịch thực tế
Mục tiêu và giả định
- Xem một ví dụ end-to-end về cách toán học khi suy luận được nối tiếp bên trong mô hình Transformer
- Giảm mạnh kích thước mô hình để dễ lần theo phép tính bằng tay
- Thay vì chiều embedding 512 trong bài báo gốc, ví dụ dùng 4 chiều
- Thay vì 8 attention head trong bài báo gốc, ví dụ dùng 2 head
- Thay vì chiều feedforward 2048 trong bài báo gốc, ví dụ dùng 8 chiều
- Giả định cần có là đại số tuyến tính cơ bản; phần lớn phép tính được thực hiện bằng nhân ma trận
- Tập trung vào việc phép tính thực tế diễn ra như thế nào hơn là Transformer “là gì”
- Phần giải thích trực quan nên đọc cùng The Illustrated Transformer, còn bài báo gốc là Attention is all you need
Tạo đầu vào cho encoder
-
Token hóa
- Vì mô hình machine learning xử lý số chứ không phải văn bản, nên văn bản đầu vào được chuyển thành ID token
- Để đơn giản hóa, ví dụ tách
"Hello World"thành hai token dạng từ là"Hello"và"World" - Trong thực tế, cách token hóa có thể là theo từ, theo ký tự, hoặc theo subword
- Cách theo từ cần vocabulary lớn và coi
"dog"với"dogs"là hai token khác nhau - Cách theo ký tự có vocabulary nhỏ nhưng có thể chứa ít thông tin ngữ nghĩa hơn
- Token hóa subword là điểm trung gian giữa cách theo từ và theo ký tự, và tokenizer được học bằng một quá trình thống kê
-
Embedding token
- Bản thân ID token không có ý nghĩa, nên mỗi token được chuyển thành embedding là một vector kích thước cố định
- Embedding trong ví dụ dùng các giá trị tùy ý
Hello -> [1, 2, 3, 4]World -> [2, 3, 4, 5]
- Trong Transformer thực tế, ánh xạ embedding cũng được học để mô hình học ra biểu diễn token phù hợp với tác vụ
- Hai embedding được gộp thành một ma trận để dùng cho các phép nhân ma trận về sau
-
Mã hóa vị trí
- Chỉ embedding thôi thì không biết được vị trí trong câu của từ, nên cần cộng thêm mã hóa vị trí
- Bài báo gốc dùng mã hóa vị trí sine/cosine cố định, và ví dụ cũng làm theo cách đó
- Mã hóa vị trí trong ví dụ được tính như sau
Hello -> [0, 1, 0, 1]World -> [0.84, 0.99, 0, 1]
- Cộng embedding token với mã hóa vị trí để tạo ma trận đầu vào encoder
Hello -> [1, 3, 3, 5]World -> [2.84, 3.99, 4, 6]
Tính self-attention
-
Tạo Q, K, V
- self-attention tính query(Q), key(K), value(V) từ embedding đầu vào
- Ví dụ dùng 2 attention head, mỗi head có các ma trận
WQ,WK,WVriêng - Mỗi ma trận trọng số chuyển embedding 4 chiều thành query/key/value 3 chiều
- Ở head đầu tiên, nhân ma trận đầu vào với
WK1,WV1,WQ1để thu đượcK1,V1,Q1
-
Công thức Attention
- Điểm attention được tính qua bốn bước
- Tính tích vô hướng giữa query và từng key
- Chia cho căn bậc hai của số chiều key
- Dùng softmax để đổi thành các trọng số dương có tổng bằng 1
- Dùng các trọng số đó để lấy tổng có trọng số của các vector value
- Quá trình này được nén lại trong công thức của bài báo gốc
- \
- Attention(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V
- \
- Trong ví dụ, vì số chiều nhỏ và giá trị khởi tạo ngẫu nhiên, kết quả softmax gần như lệch hẳn về 0 và 1
- Giá trị dot product lớn có thể bị khuếch đại mạnh hơn trong softmax, nên cần phép scale bằng căn bậc hai của số chiều key
- Để giải thích, tác giả tạm thời còn dùng biến thể chia cho 30 thay vì
sqrt(3), nhưng đó không phải lời giải lâu dài
- Điểm attention được tính qua bốn bước
-
Đầu ra multi-head attention
- Kết quả attention của từng head được concatenate rồi nhân với một ma trận trọng số học được để đưa trở lại chiều embedding
- Trong ví dụ, hai kết quả head được ghép thành ma trận 6 chiều, sau đó chuyển thành đầu ra 4 chiều
- Đầu ra này được chuyển sang tầng tiếp theo của khối encoder là tầng feedforward
Tầng feedforward và khối encoder
-
Tầng feedforward
- Sau self-attention là mạng nơ-ron feedforward (FFN)
- FFN gồm hai phép biến đổi tuyến tính và kích hoạt ReLU ở giữa
- Tầng tuyến tính thứ nhất mở rộng số chiều, còn tầng tuyến tính thứ hai giảm số chiều về kích thước ban đầu
- ReLU biến số âm thành 0 và giữ nguyên số dương để thêm tính phi tuyến
- Trong ví dụ, đầu vào 4 chiều được mở rộng thành 8 chiều rồi thu lại còn 4 chiều
- \
- \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
- \
-
Khối encoder
- Một khối encoder gồm multi-head attention và FFN
- Bài báo gốc xếp chồng 6 encoder, và mã ví dụ cũng lặp encoder với
n=6 - Nếu chỉ đơn giản cho dữ liệu đi qua nhiều khối encoder, giá trị có thể tăng quá lớn khiến phép tính softmax bị overflow và sinh ra
nan
Residual connection và layer normalization
-
Vấn đề giá trị bùng nổ
- Trong ví dụ, khi đi qua 6 encoder thì xuất hiện cảnh báo
overflow encountered in expvàinvalid value encountered in divide, đầu ra trở thànhnan - Hiện tượng giá trị trở nên quá lớn rồi còn tiếp tục tăng ở tầng sau là vấn đề thường gặp trong mạng nơ-ron sâu
- Khi gradient trở nên quá lớn trong quá trình backpropagation, hiện tượng này được gọi là gradient explosion
- Trong ví dụ, khi đi qua 6 encoder thì xuất hiện cảnh báo
-
Residual connection
- residual connection là cách cộng đầu vào của tầng vào đầu ra của chính tầng đó
- \
- \text{Residual}(x) = x + \text{Layer}(x)
- \
- Trong ví dụ, residual connection được áp dụng riêng cho đầu ra attention và đầu ra FFN
- residual connection được dùng để giảm nhẹ vấn đề vanishing gradient
-
Layer normalization
- layer normalization chuẩn hóa để mỗi chiều embedding có trung bình 0 và độ lệch chuẩn 1
- Công thức như sau
- \
- \text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \times \gamma + \beta
- \
- (\epsilon) là giá trị nhỏ để tránh chia cho 0 khi độ lệch chuẩn bằng 0
- (\gamma) và (\beta) là các tham số học để điều khiển scaling và shifting
- Sau khi thêm residual connection và layer normalization, dữ liệu đi qua 6 encoder vẫn cho ra giá trị bình thường mà không có
nan
Cấu trúc decoder
-
Đầu vào decoder và cách sinh
- Decoder nhận đầu ra encoder và chuỗi đầu ra đã được sinh ra cho đến thời điểm hiện tại làm đầu vào
- Khi suy luận, nó bắt đầu từ token SOS(start-of-sequence)
- Decoder sinh từng token một theo cách autoregressive
- Lần 1: nhận
SOSlàm đầu vào và sinh"hola" - Lần 2: nhận
SOS + holalàm đầu vào và sinh"mundo" - Lần 3: nhận
SOS + hola + mundolàm đầu vào và sinhEOS
- Lần 1: nhận
- Khi token
EOS(end-of-sequence)được sinh ra, quá trình giải mã dừng lại - Encoder có thể tạo biểu diễn chỉ với một forward pass, nhưng decoder phải thực hiện nhiều forward pass nên chậm hơn
-
Thành phần khối decoder
- Khối decoder phức tạp hơn khối encoder và gồm các bước theo thứ tự sau
- masked self-attention
- residual connection và layer normalization
- encoder-decoder attention
- residual connection và layer normalization
- tầng feedforward
- residual connection và layer normalization
- Trong ví dụ suy luận, embedding
SOSđược cộng với mã hóa vị trí để dùng giá trị[1, 1, 0, 1] - Trong quá trình huấn luyện, masked self-attention được dùng để che các điểm attention bằng
-infnhằm không nhìn thấy các token tương lai
- Khối decoder phức tạp hơn khối encoder và gồm các bước theo thứ tự sau
Encoder-decoder attention
- encoder-decoder attention là bước giúp decoder tập trung vào những phần liên quan của câu đầu vào
- Cách tính giống self-attention, nhưng đầu vào dùng để tạo Q/K/V thì khác nhau
- query được tính từ đầu ra của tầng decoder trước đó
- key và value được tính từ đầu ra encoder
- Nhờ cấu trúc này, mỗi vị trí trong decoder có thể tham chiếu đến mọi vị trí trong chuỗi đầu vào
- Điều này hữu ích cho các tác vụ như dịch máy, nơi token đầu ra phải phụ thuộc vào những vị trí liên quan trong câu đầu vào
Sinh token đầu ra
-
Tầng Linear và softmax
- Đầu ra decoder chưa phải là từ, nên embedding cuối cùng được đưa qua một tầng tuyến tính để biến thành vector logits có kích thước bằng vocabulary
- Vocabulary trong ví dụ có kích thước 10, và các ứng viên token tiếp theo như sau
hello,mundo,world,how,?,EOS,SOS,a,hola,c
- logits đi qua softmax để trở thành phân phối xác suất của từng token
- Trong xác suất ví dụ,
"hola"có xác suất cao nhất nên được chọn làm token tiếp theo - Cách luôn chọn token có xác suất cao nhất được gọi là greedy decoding, nhưng không phải lúc nào cũng là tốt nhất
- Có thể xem thêm về các kỹ thuật sinh trong bài viết của Hugging Face
-
Toàn bộ vòng lặp sinh
- Toàn bộ quy trình sinh đi theo luồng sau
- Chuyển chuỗi đầu vào thành embedding
- Encoder tạo biểu diễn ngữ cảnh cho toàn bộ đầu vào
- Decoder bắt đầu từ
SOSvà dùng đồng thời các token đã sinh trước đó với đầu ra encoder - Áp dụng tầng linear và softmax lên embedding cuối của decoder
- Chọn token tiếp theo có khả năng cao nhất và thêm vào chuỗi
- Lặp lại cho đến khi xuất hiện
EOShoặc đạt độ dài tối đa
- Khi chạy ví dụ, với đầu vào
hello worldmô hình sinh raSOS hola mundo world - Vì mọi trọng số và embedding đều dùng ngẫu nhiên, kết quả không phải là bản dịch tốt, và đó là hành vi được kỳ vọng
- Toàn bộ quy trình sinh đi theo luồng sau
Kết luận và phạm vi
- Ví dụ này nối các thành phần cốt lõi của Transformer gồm embedding, mã hóa vị trí, self-attention, multi-head attention, FFN, residual connection, layer normalization, encoder-decoder attention và đầu ra softmax thành một luồng thống nhất
- Các kiến trúc Transformer hiện đại có thể bổ sung nhiều kỹ thuật khác, nhưng toán học cốt lõi vẫn dựa trên cấu trúc được trình bày trong ví dụ này
- Tùy loại tác vụ mà stack sử dụng có thể khác nhau
- Với tác vụ thiên về hiểu như phân loại, có thể đặt một tầng linear trên stack encoder
- Với tác vụ thiên về sinh như dịch máy, có thể dùng đồng thời stack encoder và decoder
- Với tác vụ sinh tự do như ChatGPT hay Mistral, có thể chỉ dùng stack decoder
- Bài viết không bàn về quá trình huấn luyện mà tập trung vào việc hiểu toán học suy luận khi sử dụng một mô hình có sẵn
- Có thể tham khảo PDF này nếu muốn tài liệu toán học chính quy hơn
1 bình luận
Ý kiến trên Hacker News
“Điều bí ẩn” của Transformer nằm ở chỗ, thay vì nhân các trọng số và giá trị tĩnh theo thứ tự tuyến tính ở từng tầng, nó tạo ra 3 ma trận thu được bằng cách nhân cùng một đầu vào với các trọng số đã học, rồi nhân các ma trận đó với nhau
Cách này hoạt động tốt nhờ tăng tính song song, nhưng bản thân công thức attention là cố định nên rất hạn chế
Để tiến xa hơn, có vẻ cần một cách tổng quát hóa chính đồ thị tính toán thành các tham số có thể học được. Tôi không biết liệu các phương pháp gradient truyền thống có làm được không, do hiệu ứng hỗn loạn khi những thay đổi nhỏ dẫn đến biến đổi lớn về hiệu năng; có thể bên trong sẽ cần một dạng nào đó như thuật toán di truyền hoặc tối ưu hóa bầy đàn hạt
Lợi thế lý thuyết lớn so với RNN là nó hỗ trợ việc này mà không mất thông tin. Đó là vì mỗi phần tử có thể truy cập toàn bộ thông tin của mọi phần tử khác trong chuỗi, hoặc trong trường hợp theo thứ tự thời gian là toàn bộ thông tin của các phần tử trước đó
Ngược lại, RNN và “Transformer tuyến tính” nén các giá trị quá khứ, nên phần tử cuối của một chuỗi dài thường khó truy cập toàn bộ thông tin của phần tử đầu tiên; điều đó là bất khả thi trừ khi trạng thái nội bộ rất lớn đến mức không bỏ đi bất kỳ thông tin nào
Vấn đề là ta không thu được nhiều từ đó. Các phép toán không phải nhân ma trận nhiều khả năng sẽ chậm hơn hoặc có tốc độ tương tự
Tuy nhiên, nếu đưa điều khiển luồng vào thì có nguy cơ về cơ bản trở thành Turing machine, và khi đó, như đã nói, việc huấn luyện sẽ thành vấn đề. Dù vậy, có thể đây không phải là vấn đề hoàn toàn bất khả trị
Nếu muốn một lời giải thích khô khan hơn, hình thức hơn và súc tích hơn, có bài “The Transformer Model in Equations” [0] của John Thickstun
Toàn bộ nội dung nằm gọn trong một trang bằng ký hiệu toán học chuẩn
[0] https://johnthickstun.com/docs/transformers.pdf
Nhiều lúc các nhà nghiên cứu machine learning trông như chưa từng học toán vậy
Theo hiểu biết của tôi, cách giải thích “NaN xuất hiện, giá trị quá lớn nên phát nổ khi chuyển sang encoder tiếp theo, đây là gradient exploding” là sai
Ở đây không tính gradient tại bất kỳ điểm nào, nên đó không phải là gradient exploding
Vấn đề có vẻ nằm ở cách triển khai softmax, và cách triển khai softmax ổn định về mặt số học được giải thích ở đây [0]
[0]: https://jaykmody.com/blog/stable-softmax/
Tuy nhiên, vì toàn bộ mạng nơ-ron đều nhạy với các giá trị lớn, chỉ softmax ổn định về mặt số học là chưa đủ để giải quyết. Để mạng hoạt động, chuẩn hóa là yếu tố cốt lõi
Các tutorial về Transformer có lẽ sẽ trở thành tutorial Monad mới. Đó là một khái niệm khó hiểu, nhưng thuộc loại phải vật lộn thực hành với ví dụ mới hiểu được
Cũng như rất nhiều thứ trong khoa học máy tính
Tôi mới đọc sáu đoạn mà đã có câu hỏi
Trong
Hello -> [1,2,3,4] World -> [2,3,4,5], dù nói vector là ngẫu nhiên nhưng trông vẫn có mẫu hình. Tôi thắc mắc liệu số2xuất hiện trong cả hai vector có ý nghĩa gì không, hay toàn bộ tập hợp mới tạo nên tính duy nhấtỞ đây chúng cách nhau khoảng 60 độ và phần nào cùng hướng; việc tác giả không muốn đưa số âm vào ví dụ khiến các vector trông giống nhau hơn thực tế
Bản thân việc các con số được dùng lại không có ý nghĩa gì. Số
1ở vị trí đầu tiên gần như không liên quan đến số1ở vị trí thứ hai. Vì ta cũng không thực hiện tích chập trên các vector nàySau khi huấn luyện, các từ tương tự nhau sẽ có một mức độ cosine similarity nhất định, nhưng gần như không có trường hợp nào cosine similarity cao như
[1,2,3,4]và[2,3,4,5]Không hoàn toàn liên quan, nhưng tôi đang tìm bài viết hoặc paper giải thích vì sao Transformer, dù hoạt động đơn giản như một “bộ dự đoán token tiếp theo”, vẫn có thể xử lý những câu hỏi như sau
"sdsfs_ff","fsdf_value"Có vẻ đây là câu hỏi phổ biến, nhưng tôi không tìm được từ khóa để tìm kiếm. Nếu có liên kết nào đi sâu vào positional embedding thì cũng tốt; tôi cũng vẫn chưa có câu trả lời thỏa đáng về lý do dùng sin/cos và về phép nhân so với phép cộng
Nếu mô hình cho là cần thiết, nó có thể sao chép các token ký tự đơn lẻ để tái tạo một chuỗi chưa biết, hoặc nếu hợp ngữ cảnh thì có thể tạo mới
P(X_1=x_1, X_2=x_2, X_3=x_3) = P(X_3=x_3 | X_1=X_1, X_2=x_2) • P(X_1=x_1, X_2=x_2)= P(X_3=x_3 | X_1=X_1, X_2=x_2) • P(X_2=x_2 | X_1=x_1) • P(X_1=x_1)Tức là nếu có phân phối xác suất có điều kiện đúng cho token tiếp theo khi đã cho các token trước đó, thì cũng tạo được phân phối xác suất đúng cho toàn bộ chuỗi token
“Phân phối xác suất đúng cho chuỗi token”, hoặc phân phối xác suất có điều kiện đúng của chuỗi token khi có một điều kiện nào đó, trên thực tế có thể dùng cách nói đó để mô tả gần như mọi loại hành vi đầu vào/đầu ra
Vì vậy, “hoạt động bằng cách dự đoán token tiếp theo” về nguyên tắc không phải là một ràng buộc lớn đối với việc nó có thể thực hiện hành vi đầu vào/đầu ra nào
Dù nó làm một việc ấn tượng đến đâu, điều đó cũng không mâu thuẫn với việc đầu ra đó đến từ
P(X_{n+1}=x_{n+1} | X_1=x_1, ..., X_n=x_n), tức là “dự đoán token tiếp theo”Dự đoán token tiếp theo là một tác vụ thông minh hơn nhiều so với cảm giác khi nghe qua
Tôi đồng ý với câu “độ phức tạp đến từ số bước và số tham số”
Một mô hình Transformer đủ đơn giản để chúng ta hiểu thì không làm được việc thú vị, còn một Transformer đủ phức tạp để làm việc thú vị thì có vẻ quá phức tạp để chúng ta hiểu
Tôi muốn nghiên cứu các mô hình quy mô trung gian vừa đủ đơn giản để hiểu, vừa đủ phức tạp để làm được việc thú vị
Nếu dùng khái niệm mà không định nghĩa hoặc giới thiệu thì rất khó hiểu. Phần Encoder bắt đầu ngay mà không giải thích nó là gì, nằm ở đâu trong toàn bộ quy trình
Tôi hiểu tác giả muốn làm gì, nhưng bài thiếu cấu trúc viết cơ bản: giới thiệu ý tưởng trước, giải thích rồi mới sử dụng
Nếu không phải là người đang học chủ đề này và đã hiểu khoảng một nửa, toàn bộ bài sẽ tạo cảm giác rối rắm
Dù tôi từng viết ANN từ đầu và không dùng TensorFlow, phần giải thích này vẫn gây rối
Tôi đã nhờ ChatGPT giải thích cách sửa một ANN cơ bản để triển khai self-attention mà không dùng các từ
MatrixhayVector, và nó đưa ra một giải thích khá đơn giản. Tôi vẫn chưa thử triển khaiTôi thích nghĩ mọi thứ theo góc nhìn node, trọng số và layer hơn. Ma trận và vector làm cho việc liên hệ với những gì thực sự diễn ra bên trong ANN trở nên khó hơn
Trong cách viết ANN quen thuộc, mỗi node đầu vào là một scalar, nhưng thuật toán forward propagation nhân trọng số với mọi node đầu vào rồi cộng lại, nên trông giống phép nhân vector-ma trận. Tôi có cảm giác mình đang tiếp cận các giải thích này với một tư duy sai, và cũng có thể là tôi thiếu kiến thức nền cần thiết