(이 글은 Attention, transformer, 텍스트 토큰 전처리에 대한 사전 지식을 요구합니다.)
Positional Embedding이란?
글을 읽을 때 왼쪽에서 오른쪽으로 읽듯이, 텍스트를 처리함에 있어 순서는 글을 이해하는데 중요한 정보입니다. 이는 인공지능에게도 마찬가지입니다. 그래서 과거에는 텍스트를 앞에서부터 순차적으로 처리하는 RNN 기반의 레이어를 사용해 언어 모델을 구현했었습니다.
하지만 RNN 레이어는 텍스트 시퀀스를 병렬적으로 처리할 수 없어 문장이 길어질수록 속도가 느려지는 단점이 있었고, 이로 인해 대량의 데이터를 학습하기 어려웠습니다. 그래서 텍스트를 병렬적으로 처리할 수 있는 attention 레이어 기반의 transformer 구조가 LLM(Large Language Model)에 사용되기 시작했습니다.
Attention 레이어는 텍스트 시퀀스를 병렬적으로 처리할 수 있지만, RNN과 같이 자연스럽게 순서 정보를 익힐 방법이 없습니다. 따라서 모델을 학습시킬 때 직접 위치 정보를 따로 입력해 줘야 합니다. 이 때 사용되는 것이 'Positional Embedding'입니다. 오늘은 Positional Embedding의 종류들에 대해서 간단히 알아보려고 합니다.
1. Absolute Positional Embedding
2. Relative Positional Embedding
3. Rotary Positional Embedding
4. 정리
1. Absolute Positional Embedding
Absolute Positional Embedding은 각 텍스트 토큰의 위치 정보를 절대적인 고유값으로 표현하는 방법입니다. 예를 들면 문장의 앞에서부터 숫자를 1, 2, 3... 과 같이 할당하는 것이죠.
Absolute positional embedding을 구현하는 방법은 2가지가 있습니다.
1. Sinusodial function
Sinusoidal function은 위치 토큰 벡터들을 sin과 cos 함수를 이용해 나타내는 방법입니다.
$$ \left\{\begin{matrix}p_{i,2t}=sin(i/10000^{2t/d}) \\ p_{i,2t+1}=cos(i/10000^{2t/d})\end{matrix}\right.$$
def sinusoidal_positional_embedding(seq_len, dim):
"""
Sinusoidal positional encoding을 생성하는 함수
"""
position = torch.arange(seq_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))
pe = torch.zeros(seq_len, dim)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe
pe = sinusoidal_positional_embedding(128, 256)
Sinusoidal function으로 생성된 positional embedding을 시각화 해보면 아래와 같습니다.
위와 같이 생성된 positional embedding을 토큰 벡터에 더해 위치 정보를 모델에 입력하게 됩니다.
# pseudo code
input_emb = embedding() # 토큰 임베딩
position_emb = sinusoidal_positional_embedding() # 위치 임베딩
input_emb = input_emb + position_emb # 두 임베딩을 더한 뒤,
out = attention(input_emb) # attention 레이어에 입력
2. 학습 가능한 파라미터를 사용
다른 방법은 sinusoidal function 대신 학습 가능한 파라미터를 이용해 위치 정보를 나타내는 것입니다. 임베딩 레이어를 이용해 아래와 같이 구현할 수 있습니다.
import torch.nn as nn
token_embedding = nn.Embedding(512)
position_embedding = nn.Embedding(512)
tokens = [37, 26, 817, 92, 1]
positions = [1, 2, 3, 4, 5]
token_emb = token_embedding(tokens)
pos_emb = position_embedding(positions)
input_emb = token_emb + pos_emb
out = attention(input_emb)
Absolute Positional Embedding의 단점
- 각 위치마다 고정된 값을 사용하기 때문에 학습 시에 사용된 문장 길이보다 긴 문장을 입력하려 할 경우 추가 학습이 필요할 수 있다.
예를 들어 512 길이의 문장들로 학습을 한 모델의 경우, 512번째 토큰까지는 학습된 임베딩 벡터를 사용할 수 있지만, 513번째 토큰부터는 임베딩 벡터가 학습되어 있지 않을 것입니다. 따라서 모델이 제대로 작동하지 않을 수 있습니다.
- 단어의 위치는 모델에 입력되지만 단어 사이의 거리는 고려되지 않는다.
대체로 문장이 길어질수록 앞쪽의 단어들과 뒤쪽의 단어들은 서로 관련이 적어질 것입니다. 하지만 absolute positional embedding은 학습할 때 이런 상대적인 거리는 고려되지 않기 때문에 이런 부분에 대한 고려가 부족할 수 있습니다.
2. Relative Positional Embedding
Relative Positional Embedding은 토큰 간의 '상대적인' 거리 정보를 임베딩 벡터로 입력하는 방법입니다. 단순히 앞에서부터 1, 2, 3...과 같은 고정값을 부여하는 것이 아니라, 각 단어 토큰마다 다른 단어 토큰과의 상대적인 위치 정보를 행렬 형태로 표현합니다.
Relative Positional Embedding을 적용하는 방법은 다양합니다. 대표적으로 T5(Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer, 5페이지)에서는 아래와 같이 attention score에 relative attention bias를 더하는 식으로 계산됩니다.
$$\text{Attention}(Q,K,V)=\text{softmax}({QK^T\over\sqrt{d_k}}+B)V$$
$B$는 아래와 같이 구현할 수 있습니다.
1. relative position id 계산하기
seq_len=128
position_ids = torch.arange(seq_len)
relative_position = position_ids.unsqueeze(0) - position_ids.unsqueeze(1)
'''
[[0, 1, 2, 3 ... , 127]
...
[-127, -126... , 1, 0]]
'''
2. relative position을 bucket 단위로 변환
def compute_relative_position_buckets(relative_position, num_buckets=32, max_distance=128):
"""
상대적 위치 정보를 버킷(bucket)으로 변환하는 함수.
가까운 거리(작은 상대 위치)는 정밀하게 구분하고, 먼 거리(큰 상대 위치)는 대략적으로 표현함.
"""
sign = (relative_position < 0).long() # 음수 여부 (왼쪽에 있는 단어인지 체크)
relative_position = torch.abs(relative_position)
# 작은 거리(근처 토큰)는 정밀하게 표현
min_range = (relative_position < num_buckets // 2).long() * relative_position
# 먼 거리는 로그 스케일로 압축
max_range = (relative_position >= num_buckets // 2).long() * (
num_buckets // 2 +
(torch.log(relative_position.float() / (num_buckets // 2)) /
torch.log(torch.tensor(max_distance / (num_buckets // 2))))
* (num_buckets - num_buckets // 2)
).long()
# 최종 버킷 ID 결정
buckets = min_range + max_range
buckets = torch.clamp(buckets, 0, num_buckets - 1)
return buckets * (1 - 2 * sign) # 방향 고려
# -127~127의 상대 위치를 -31~31 범위로 표현
num_buckets = 32
relative_bias = compute_relative_position_buckets(relative_position, num_buckets, seq_len)
'''
[[0, 1, 2, ..., 31, 31],
...,
[-31, -31, ..., -1, 0]]
'''
이를 통해 중요도가 높은 가까운 거리는 정밀하게 구분하고, 중요도가 떨어지는 먼 거리는 대략적으로 구분합니다.
3. Embedding 레이어 적용
pos_emb = nn.Embedding(num_buckets*2, n_heads)
relative_position_buckets += num_buckets # -31~31을 1~63 범위로 변환
relative_bias = pos_emb(relative_position_buckets)
relative_bias = relative_bias.permute(2, 0, 1) # (n_heads, seq_len, seq_len)
4. attention score에 더함.
attn_scores = (Q @ K.transpose(-2, -1)) / (d_model ** 0.5) # Query와 Key로 attn_score 계산.
attn_scores += relative_bias # relative bias를 더해줌.
attn_weights = torch.softmax(attn_scores, dim=-1)
output = attn_weights @ V
Relative Positional Embedding의 단점
- 계산이 오래걸림
Attention 연산을 할 때마다 positional embedding을 계산해야 하기 때문에 속도가 느려집니다. 속도를 개선하기 위해 positional embedding 값을 저장하여 재사용하기도 하지만 이 경우엔 메모리에 부하가 걸릴 수도 있습니다.
3. Rotary Positional Embedding (RoPE)
Rotary Positional Embedding도 relative positional embedding에 속하는 방법이라 볼 수 있지만, 좀 더 개선된 버전으로 볼 수 있습니다. RoPE라고도 불리는 이 방식은, 벡터의 회전 변환을 통해 상대적인 위치를 나타내는 방식을 사용합니다.
$$R_{\theta_p}\begin{bmatrix}
x_{2i} \\ x_{2i+1}
\end{bmatrix}=\begin{bmatrix}
cos(p\theta) & -sin(p\theta) \\
sin(p\theta) & cos(p\theta) \\
\end{bmatrix}\begin{bmatrix}
x_{2i} \\ x_{2i+1}
\end{bmatrix}$$
어떻게 보면 앞서 봤던 sinusoidal function 방식과 유사하지만, 적용 방식이 다릅니다. Sinusoidal function은 '고정된 위치 값'을 sinusoidal function으로 변환하여 단순히 더해주는 방식이었다면, RoPE는 attention 연산에서 query와 key에 각각 회전 변환(sinusoidal function)을 수행하여 단어 토큰 간의 상대적인 위치 차이를 나타냅니다.
이 방법의 장점은 추가적인 파라미터가 필요 없다는 것입니다. 따라서 앞의 relative positional embedding보다 메모리도 절약하면서 학습 속도도 빠릅니다. 이런 장점들로 최신 LLM들(GPT-3.5, LLama 등)에 사용되고 있습니다.
구현은 아래와 같이 가능합니다.
1. sinusoidal 주기 계산
여기선 sin, cos 함수 대신에 복소수를 통해 회전 변환 행렬을 계산합니다. (회전 변환은 복소수를 통해 나타낼 수 있기 때문)
def precompute_freqs_cis(dim, seqlen):
base = 10000.0
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seqlen)
# torch.outer : 't'와 'freqs'의 외적 계산(outer product)
# freqs : sinusoidal 형태의 주기.
freqs = torch.outer(t, freqs)
'''
torch.polar(abs, angle) : Constructs a complex tensor whose elements are Cartesian coordinates
corresponding to the polar coordinates with absolute value abs and angle angle.
'''
# freqs_cis : freqs를 복소수 형태로 변환.
freqs_cis = torch.polar(abs=torch.ones_like(freqs), angle=freqs)
return freqs_cis
이 'freqs_cis'를 그려보면 아래와 같은 그림을 모습을 확인할 수 있습니다. 앞서 본 sinusoidal 함수와 유사한 형태인 걸 볼 수 있습니다.
2. Query와 Key에 각각 freqs_cis를 곱해 회전 변환을 수행합니다. 그 뒤 attention score 계산을 수행합니다.
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""
Applies rotary positional embeddings to the input tensor.
Args:
x (torch.Tensor): Input tensor with positional embeddings to be applied.
freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
Returns:
torch.Tensor: Tensor with rotary embeddings applied.
"""
# 복소수 행렬곱 연산을 위해 입력 벡터도 복소수로 변환
x = x.float().view(*x.shape[:-1], -1, 2)
x = torch.view_as_complex(x)
# 행렬 연산을 위해 shape 변형
freqs_cis = freqs_cis.repeat(16, 1, 1).transpose(1, 2)
# query/key에 회전 변환을 적용.
y = torch.bmm(x, freqs_cis)
# query/key를 다시 실수로 변환 후 반환.
y = torch.view_as_real(y).flatten(2)
return y
Q = apply_rotary_emb(Q, freqs_cis)
K = apply_rotary_emb(K, freqs_cis)
4. 정리
최종적으로 Absolute, relative, rotary positional embedding의 내용을 정리하고 마무리 하겠습니다.
Absolute Positional Embedding
- 위치 정보를 고정된 값을 사용해 표현.
- Sinusoidal, 학습 가능한 벡터 형태로 구현 가능.
- 모델의 입력 벡터에 더해주는 식으로 적용.
- 단점 : 단어 토큰 사이의 상대적인 위치 정보를 고려하지 못함.
Relative Positional Embedding
- 위치 정보를 상대적인 값을 이용해 표현.
- Attention 레이어에 상대적 위치(relative position)값을 더해주는 식으로 적용.
- 단점 : 계산이 오래 걸리고 메모리 소모도 커짐.
Rotary Positional Embedding (RoPE)
- 상대적인 위치 정보를 '회전 변환'을 통해 표현.
- Attention 레이어의 query와 key에 회전 변환을 적용해 위치 정보를 입력할 수 있음.
- 추가적인 파라미터가 필요 없어 메모리 소모를 줄이고 계산 속도를 향상 시킬 수 있음.
- 사전에 회전 변환 행렬을 계산해 둠으로써 계산 속도를 향상시킬 수 있음.
'딥러닝 관련 이것저것' 카테고리의 다른 글
음성 전처리 관련 지식 총정리 (0) | 2023.10.18 |
---|