Commit 976237ec by xuchen

dump

parent 2e6223d1
......@@ -438,40 +438,33 @@ class MultiheadAttention(nn.Module):
localness = 0
window = int(src_len * self.localness_window)
# print(src_len)
# print(window)
# for i in range(window, src_len - window):
# item_localness = 0
# for j in range(-window, window + 1):
# # if j == 0:
# # continue
# item_localness += weights[:, :, i, i + j]
# localness += item_localness
for i in range(bsz):
sum_num = 0
for i in range(window, src_len - window):
item_localness = 0
# print(weights[i, :, :])
for j in range(window, src_len - window):
if key_padding_mask is not None and key_padding_mask[i, j] == True:
continue
unit_localness = 0
for k in range(-window, window + 1):
unit_localness += weights[i, j, j + k]
# print(j)
# print(unit_localness)
item_localness += unit_localness
sum_num += 1
# exit()
if sum_num > 0:
localness += item_localness / sum_num
for j in range(-window, window + 1):
item_localness += weights[:, i, i + j]
localness += item_localness
localness = localness / bsz
# for i in range(bsz):
# sum_num = 0
# item_localness = 0
# for j in range(window, src_len - window):
# if key_padding_mask is not None and key_padding_mask[i, j] == True:
# continue
# unit_localness = 0
# for k in range(-window, window + 1):
# unit_localness += weights[i, j, j + k]
# item_localness += unit_localness
# sum_num += 1
# if sum_num > 0:
# localness += item_localness / sum_num
# localness = localness / bsz
if self.localness_num == 0:
self.localness = localness.mean()
else:
self.localness = (self.localness * self.localness_num + localness.mean()) / (self.localness_num + 1)
# print(self.localness)
self.localness_num += 1
def cal_entropy_func(self, attn_weights_float, bsz, src_len, tgt_len):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论