Anti Math Math Club

Attention is not all you need: pure attention losses rank doubly exponentially with depth 본문

Machine Learning & Deep Learning/Natural Language Processing

Attention is not all you need: pure attention losses rank doubly exponentially with depth

seewoo5 2021. 4. 24. 23:58

이번에는 제목을 보자마자 어그로가 끌려서 읽을 수 밖에 없었던 Attention is not all you need: pure attention losses rank doubly exponentially with depth라는 논문에 대해서 리뷰하겠습니다.

 

Attention is all you need라는 제목으로 NeurIPS에 발표된 논문은 자연어 처리 뿐만 아니라 요즘에는 컴퓨터 비전까지 넘보고 있는 Transformer를 제시한 역사적인 논문입니다. Transformer는 self-attention만을 이용하여 기존의 RNN 기반 모델들에 비해서 자연어 처리에서의 월등한 성능과 훈련 속도를 보여주었고, 이후 GPT와 BERT를 필두로 한 자연어 처리에서의 pre-training & fine-tuning 파라다임의 de facto standard가 됩니다. 하지만, 이 논문의 제목은 뭔가 Transformer를 정면으로 반박하는 것 처럼 보입니다. 결론부터 말하자면 그렇지 않고, 실제 논문의 내용과 좀 더 가까운 제목은 Attention and Skip Connection (and Feed-Forward Network?) is All You Need라고 할 수 있습니다. 내용을 찬찬히 살펴보도록 하겠습니다. 

 

이 논문의 가장 중요한 요지는 다음과 같습니다. Transformer의 구성 요소로는 Multi-head Attention, Skip Connection, Layer Normalization, Feed-Forward Network등이 있는데, 논문 제목에서 말한대로 정말 이 중에서 Multi-head Attention만 사용한 네트워크를 만들면 매우 안좋은 상황이 발생하게 됩니다. 그건 바로, layer가 많아질수록 output matrix가 query에 상관없이 모든 row가 동일해지는, 즉 rank 1 matrix로 수렴하게 됩니다. 이는 당연히 모델의 expressibility가 크게 줄어드는것이고 성능도 매우 나빠질 것을 암시합니다. 이를 이론적으로 설명하는데, 이에 대해 좀 더 자세히 설명하겠습니다.

 

먼저, H개의 head로 이루어진 multi-head (self-)attention layer가 L개가 쌓여있을 때, input X에 대한 output SA(X)는 아래와 같이 쓸 수 있습니다.

여기서 P_h는 attention weight matrix로, row의 합이 1인 row-wise stochastic matrix입니다.

그렇다면, bias를 (잠깐) 무시했을 때, 여러개의 attention layer를 통과한 output은 다음과 같이 쓸 수 있습니다.

근데, row-wise stochastic matrix들을 곱하면 결과도 역시 row-wise stochastic이기 때문에 (연습문제!!) 최종 output을 다음과 같이 쓸 수 있습니다.

여기서 path란 아래의 그림에 잘 나타나 있는데, 각 layer별 몇번째 head를 지나는지 그 tuple이라고 보면 됩니다. 즉, 오로지 attention layer로만 이루어진 network는 아래 그림과 같은 여러 가능한 조합의 path들의 결과를 합한 path decomposition으로써 나타낼 수 있다는 것입니다.

Path decomposition. Dong et al. 2021

이제 논문의 부제인 pure attention losses rank doubly exponentially with depth라는 말을 이해할 준비가 거의 다 되었습니다. 먼저, 어떤 행렬이 (row가 모두 동일한) rank 1 행렬에 얼마나 가까운지를 나타내는 척도를 다음과 같은 residual이라는 값으로 정의합니다.

여기서 norm은 (1, infinity)-norm, 즉 L^1-norm과 L^infinity-norm의 기하평균을 사용하고, res(X)값이 매우 작다는 것은 X가 1x^T, 즉 모든 row vector가 x인 행렬에 매우 가깝다는 것을 의미합니다. 이렇게 residual을 정의하게 되면 논문의 핵심인 다음 부등식을 이해햘 수 있습니다.

여기서는 head가 하나인, single-head attention에서의 경우를 나타낸 것이며 beta는 아래 식과 같이 query, key, value embedding의 weight을 bound하는 상수입니다.

위의 식에서, beta의 값이 충분히 작아서 괄호 안의 값이 1보다 작아지게 되면 L이 커질수록 곱해지는 값은 굉장히 빠르게 0으로 수렴하게 됩니다. 여기서의 굉장히 빠른 정도를 doubly exponential이라고 지칭합니다 (exponential에 exponential이 있기 때문입니다). 일반적으로, head가 하나가 아닌 여러개인 경우에는

의 식으로 일반화가 됩니다. 증명은 L=1인 경우만 보인 뒤 귀납적으로 부등식을 적용하면 되는데, L=1인 경우는 R = res(X)에 대해 P_h를

로 쓸 수 있고, 여기서 softmax 안에 있는 첫번째 항을 Holder inequality 등을 이용해서 bound를 잘 시켜서 증명을 합니다. 이 항이 작은 경우 P_h는

가 되어 rank 1에 가깝다는것을 보일 수 있게 됩니다.

 

그렇다면, transformer의 다른 component들은 이러한 rank collapse 관접에서 어떤 역할을 하고있을까요? 먼저, 가장 중요한 부분은 Skip connection입니다. 이 부분은 input을 그대로 output에 더해주는 역할을 하는데, 위의 path decomposition에서 skip connection이 추가가 되면 layer가 아무리 많아도 모든 layer를 skip하는, 즉 input을 보존하는 path가 존재합니다. 그렇기 때문에, 직관적으로는 attention layer를 지나면서 rank가 1에 가까워진다고 해도 결국 원래 input값이 최종 output에 더해짐으로써 rank가 collapse되는 것을 막아줍니다. 실제로 다음이 성립합니다.

그 다음, Feed-Forward Network는 무슨 역할을 할까요? 결론부터 말하자면, Feed-Forward Network는 (l^1,\infty norm에 대해서) Lipschitz가 되는데, 이때의 Lipschitz constant를 lambda라고 하면 위의 residual에 대한 부등식에 아래와 같이 lambda가 추가가 됩니다.

즉, lambda가 커질수록 rank가 collapse하는 속도가 느려지는 효과가 생깁니다. 

 

마지막으로, Layer Normalization은 rank를 적어도 증가시키지는 않는데, 그 이유는 위의 path decompositon 식에서 W와 b만 바꾸는 역할을 하기 때문입니다. (라고 되어있는데...새로 바뀐 W와 b는 input dependent하기 때문에 단순히 이렇게 얘기하도 되는건진 모르겠습니다.) 참고로 행렬을 곱하는것은 rank를 증가시키지 않습니다.

 

논문에서는 3가지 실험을 통해서 위의 주장들을 검증을 합니다. 먼저 첫번째 실험은 실제로 BERT, ALBERT, XLNet 3가지의 transformer 기반 모델들을 pure-attention based architecture로 바꾸었을 때, 혹은 여기에 skip connection이나 MLP를 추가했을 때 rank degeneration이 실제로 일어나는지를 확인합니다. 아래 그림에는 relative residual (self attention layer의 output의 residual의 norm과 output의 norm의 비율, 이 비율이 작을수록 rank가 작은 matrix에 가깝다고 할 수 있습니다)를 layer별로 나타나 있습니다. 그림에서 볼 수 있듯이 self attention만 사용하게 되면 relative residual이 위쪽 layer로 갈 수록 급격히 감소하는 반면, 기존의 transformer 구조나 attention + skip connection만으로 이루어진 구조에서는 그런 현상이 덜함을 알 수 있습니다. (그럼에도 불구하고 relative residual이 점점 감소하는 경향이 나타나긴 합니다.)

Relative residual of BERT, ALBERT, XLNet for each layer, Dong et al. 2021.

두번째 실험은 각 component의 inductive bias를 알기 위한 실험으로, 1개의 self attention + skip connection layer를 recursive하게 여러번 통과시킨, Universal Transformer와 같은 형태의 모델을 이용합니다. Task는 원 위의 반시계 방향으로 향하는 두개의 반원을 예측하는 synthetic task인데, 모델은 autoregressive하게 이전의 (예측된) 점들의 좌표를 바탕으로 다음 점의 좌표를 예측합니다. 

아래 Figure3의 실험 결과를 보면, MLP나 skip connection 없이 self attention만 이용했을 때 모델의 dimension이 작으면 rank collapse가 일어나서 starting point에 상관없이 같은 점으로 수렴하지만, dimension 128에서는 그런 현상이 발생하지 않는 것을 볼 수 있습니다. 이는 위에서 언급한 논문의 핵심적인 부등식에 부합하는 결과인데, 차원이 커질수록 beta값이 커진다고 생각할 수 있기 때문에 rank collapse하는 속도(정도?)역시 느려진다고 볼 수 있고, 두번째 행의 결과는 MLP(Feed-Forward Network)가 추가되면 rank collapse가 늦춰진다는 주장을 뒷받침합니다. 또한, skip connection이 추가되면 점이 움직이는 정도(?)가 줄어드는것을 볼 수 있는데 (dimension 32일때 특히 두드러지게 나타납니다), 이는 skip connection 자체가 input을 그대로 더하기 때문이라고 생각할 수 있습니다.

Visualizing bias of different architectures, Dong et al. 2021.

 

마지막 세번째 실험은 path effectiveness에 대한 실험입니다. 위에서 기존의 여러 layer의 multi-head self attention을 path들로 쪼갬으로써 일종의 shallow network의 ensemble같은 관점에서 볼 수 있다고 했는데, 실제로 이 관점에서 바라봤을 때 layer의 수 (path의 길이)와 path의 갯수에 따른 self attention network의 효과가 어느정도인지를 확인하는 실험입니다. Task는 다음과 같은 3개의 synthetic task를 가지고 실험을 했습니다.

  1. Sequence memoization: natural language token -> random binary label의 형태의 데이터를 "외우는" task이며, training에서 얼마나 학습이 잘 되는지의 여부만 확인한다고 보면 됩니다. (training accuracy가 얼마나 높아질 수 있는지). 
  2. Learning to sort: 주어진 문자열을 알파벳순으로 정렬하는 task입니다. 길이는 8이고 알파벳은 10개만 사용했습니다.
  3. Convex hull prediction: [0, 1] x [0, 1]의 영역에서 랜덤으로 뽑은 10000개의 점들을 input으로 받아서 각 점이 이 점들의 convex hull에 포함되는지 여부를 예측하는 task입니다.

아래 그림(Figure 4)에 실험 결과가 나타나 있습니다. 먼저, 논문에서 이야기한대로 path length가 길어질수록 skip connection이 있음에도 불구하고 성능이 낮아지는 경향이 나타납니다 (이는 아래 figure에서 실험한 모델 크기와 상관없이 항상 같은 경향을 보였다고 합니다). 또한, 전체 모델에서 특정 몇개의 path만 임의로 골라서 이들의 output의 average를 바탕으로 예측을 했을 때, path의 숫자가 늘어날수록 정확도가 높아집니다. (5 paths < 20 paths)

Path effectiveness for 3 synthetic tasks, Dong et al. 2021.

논문에 대한 개인적인 소감(?)은, 논문의 임팩트 자체가 크지 않다는 생각이 듭니다. 결국에는 self attention만 사용한 transformer는 성능이 좋지 않다는것인데, 애초에 transformer 기반 모델 중에서 self attention만 사용하던 모델을 본 적이 없고 (누군가 해봤는데 이 논문에서처럼 잘 안되었기 때문에 언급이 되지 않았을 수도 있습니다) 실험도 너무 synthetic한 task 위주로 진행된게 아닌가 싶습니다. 또한, layer normalization은 rank collapse에 영향을 끼치지 않는다고 하지만, 실제로 몇몇 transformer 관련 논문들에서 언급하듯이 layer normalization의 순서(skip connection과 layer normalization중 무엇을 먼저 할 것인지 등...)가 성능에 꽤 큰 영향을 미친다고 알려져 있는데 이러한 관점에서의 이론적인 분석은 없는 것이 아쉬웠습니다. (물론, 논문의 주제 자체가 rank collapse에 초점이 맞추어져 있기 때문에 없는게 당연하긴 합니다.) 그래도 "하면 안좋은 것"에 대한 결과를 제시했다는 점에서 어느정도 중요한 논문이라고 생각이 듭니다. (대부분의 논문들이 잘 되는것만 이야기하고 잘 안되는 것은 언급조차 하지 않는 경우가 대부분이라서 이런 종류의 논문도 많이 있어야 한다고 생각합니다.)