# Scaled dot-product attention scores: (Q @ K^T) / sqrt(dim_head), plus position_bias.
# NOTE(review): assumes query is (batch_size, num_heads, len_q, dim_head) and key is
# (batch_size, num_heads, len_k, dim_head), per the shape comment on the final matmul;
# position_bias must broadcast to (batch_size, num_heads, len_q, len_k) — TODO confirm
# whether it is a relative or absolute position bias at the call site.
score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
score = score + position_bias

# Boolean mask of positions that must NOT be attended to (attention_mask == False).
# The view inserts a singleton head dimension so one mask broadcasts across all heads.
# Whether this produces unidirectional (causal) attention depends entirely on the
# mask the caller builds; it is not enforced here.
# Computed once and reused for both fills below.
masked_out = attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False)

# Fill masked positions with -inf so they receive zero weight after softmax.
score = torch.masked_fill(
    score,
    masked_out,
    torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
)
score = self.softmax(score)
# Assuming self.softmax normalizes over the key (last) dimension: a query row whose
# keys are all masked becomes softmax(all -inf) = NaN. Resetting masked positions to
# 0 removes those NaNs and guarantees masked keys contribute nothing to the output.
score = torch.masked_fill(
    score,
    masked_out,
    torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
)

# Weighted sum of values: score = softmax(Q @ K^T / sqrt(dim_head) + bias) @ V.
# (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head)
#   -> (batch_size, num_heads, len_q, dim_head)
score = torch.matmul(score, value)