*본 포스팅은 자연어처리 스터디를 진행하던 중 생겼던 궁금증을 해결한 후 내용 정리를 하고자 작성하였습니다.
1. Multi-head Attention이란?
트랜스포머 원 논문(Attention is all you need)에서는 한 번의 어텐션을 하는 것보다 여러 번의 어텐션을 병렬로 사용하는 것이 더 효과적이라고 한다.
그래서 d_model의 차원을 num_heads개로 나누어 d_model/num_heads의 차원을 가지는 Q, K, V에 대해서 num_heads개의 병렬 어텐션을 수행한다.
논문에서는 하이퍼파라미터인 num_heads의 값을 8로 지정하였고, 8개의 병렬 어텐션이 이루어지게 된다.
다시 말해 위에서 설명한 어텐션이 8개로 병렬로 이루어지게 되는데, 이때 각각의 어텐션 값 행렬을 어텐션 헤드라고 부른다.
이때 가중치 행렬 $W^q$, $W^k$, $W^v$의 값은 8개의 어텐션 헤드마다 전부 다르다.
2. Multi-head Attention의 목적
왜 어텐션 헤드가 하나 이상 필요할까?
-> 한 헤드의 소프트맥스가 유사도의 한 측면에만 초점을 맞추는 경향이 있기 때문이다.
또한 입력 벡터가 굉장히 크면 softmax 과정에서 큰 값들이 반영이 잘 안된다. 따라서 구간을 나누어 각각을 softmax 해준다.
여러 개의 헤드가 있으면 모델은 동시에 여러 측면에 초점을 맞춘다.
예를 들어 한 헤드는 주어-동사 상호작용에 초점을 맞추고, 다른 헤드는 인접한 형용사를 찾는 식이다.
이러한 관계는 모델에 수동으로 입력되지 않고 모델이 직접 데이터에서 학습한다.
(Computer Vision에서 합성곱 신경망의 필터와 유사하다.)
3. 그래서 Multi-head Attention을 통해서 어떻게 다른 관점의 정보를 반영하나?
여러 개의 헤드를 통해 다양한 측면의 정보를 받아들이는 건 이해를 했다.
그럼 이를 어떻게 구현하나?
이를 설명하기 위해 Multi-head Attention을 직접 구현하는 코드를 작성해보자.
class AttentionHead(nn.Module):
def __init__(self, embed_dim, head_dim):
super().__init__()
self.q = nn.Linear(embed_dim, head_dim)
self.k = nn.Linear(embed_dim, head_dim)
self.v = nn.Linear(embed_dim, head_dim)
def forward(self, hidden_state):
attn_outputs = scaled_dot_product_attention(
self.q(hidden_state), self.k(hidden_state), self.v(hidden_state))
return attn_outputs
세 개의 독립된 선형 층을 만들었다.
head_dim은 투영하려는 차원의 크기이다.
head_dim이 토큰의 임베딩 차원인 embed_dim 보다 더 작을 필요는 없지만, 실전에서는 헤드마다 계산이 일정하도록 embed_dim과 배수가 되게 설정한다.
어텐션 헤드를 준비했으니 각 헤드의 출력을 연결해서 완전한 Multi-head Attention 층을 만들어보자.
class MultiHeadAttention(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config.hidden_size # 768
num_heads = config.num_attention_heads # 12
head_dim = embed_dim // num_heads
self.heads = nn.ModuleList(
[AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
)
self.output_linear = nn.Linear(embed_dim, embed_dim)
def forward(self, hidden_state):
x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
x = self.output_linear(x)
return x
어텐션 헤드의 출력을 연결한 다음 최종 선형 층으로 주입해서 [batch_size, seq_len, hidden_dim] 크기의 출력 텐서를 만든다.
(이 형태는 이후 활용할 피드 포워드 신경망에 사용하기 적절하다.)
그럼 위의 코드들을 통해 어떻게 여러 개의 헤드들을 반영할 수 있는지 확인해 보자.
먼저 MultiHeadAttention 클래스의 self.heads를 보면 헤드의 개수만큼 AttentionHead 클래스를 담고 있는 것을 알 수 있다.
이 AttentionHead 클래스로 거슬러 올라가 보면 self.q, self.k, self.v 가 nn.Linear 함수로 구성되어 있는 것을 확인할 수 있는데,
이 Linear 함수에 대해 개념이 정립되지 않은 분들은 다음 포스팅을 참고하시길 바란다.
https://seungseop.tistory.com/28
다시 돌아와서, 이 Linear() 함수는 weight와 가중치가 갱신이 된다.
따라서 갱신된 가중치로 다시 q, k, v를 구성하면 각 헤드마다 다른 q,k,v를 가지게 되는 것이다.
그렇기 때문에 결과적으로 Multi-head Attention을 통해 다른 관점의 정보들을 받을 수 있다.