Commit a3c033d9 by xuchen

modify the implementation of the relative position encoding again

parent dd402ec2
......@@ -2,29 +2,33 @@
# Overview
- Fairseq_ST is based on the original Fairseq, with only a few modifications to improve usability and adapt it to the speech translation task.
+ Fairseq_ST is based on the original Fairseq, improving usability and adapting it to speech-to-text tasks.
Currently supported features:
- - An egs folder is created for each dataset to hold its run scripts
+ - An egs folder is created for each dataset to hold its run scripts; this currently covers the LibriSpeech speech recognition dataset and the MuST-C speech translation dataset
- Training is driven by reading a yaml configuration file (a sketch follows this list)
- CTC multi-task learning is supported
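For illustration, a run script can keep its training options in a yaml file and expand them onto the fairseq-train command line. A minimal sketch, assuming pyyaml is installed; the file name train.yaml and the keys shown are hypothetical, not the repo's actual schema:

```python
import yaml  # pip install pyyaml

# Hypothetical train.yaml:
#   arch: s2t_transformer_s
#   ctc-weight: 0.3
#   max-tokens: 40000
with open("train.yaml") as f:
    cfg = yaml.safe_load(f)

# Expand every key/value pair into a --key value argument for fairseq-train.
extra_args = []
for key, value in cfg.items():
    extra_args.extend([f"--{key}", str(value)])
print(" ".join(extra_args))
```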
Planned next:
- Reworked CNN structure in the input layer
- Relative position representation
- The Conformer architecture
- Loading of pre-trained models
- The SATE architecture
In addition, the raw data for each speech translation task must be downloaded in advance. Apart from the datasets already covered, such as LibriSpeech and MuST-C, other datasets require extra preprocessing code; see the processing scripts under examples/speech_to_text.
# Requirements
1. Python ≥ 3.6
2. torch ≥ 1.4, torchaudio ≥ 0.4.0, CUDA ≥ 10.1
3. apex
```
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
4. nccl
```
make -j src.build CUDA_HOME=<path to cuda install>
```
5. gcc ≥ 4.9
6. Python packages: pandas, sentencepiece, configargparse, gpustat, tensorboard, editdistance
......@@ -54,6 +58,8 @@ The st folder contains the data folder data and the code folder fairseq; the tools
# Code Structure
+ In addition, the raw data for each speech translation task must be downloaded in advance. Apart from the datasets already covered, such as LibriSpeech and MuST-C, other datasets require extra preprocessing code; see the processing scripts under examples/speech_to_text.
+ Run scripts are stored in the egs folder under the fairseq root directory, with a separate folder per dataset; this currently includes scripts for the LibriSpeech speech recognition dataset and the MuST-C speech translation dataset.
Taking the librispeech folder as an example, it contains the following files:
......
......@@ -347,7 +347,6 @@ class S2TTransformerEncoder(FairseqEncoder):
        positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
        if self.attn_type != "rel_selfattn":
            x += positions
-       # x += positions
        x = self.dropout_module(x)
        for layer in self.transformer_layers:
......
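The encoder hunk above stops summing positional embeddings into the input when relative self-attention is selected: with absolute encodings the position signal is folded into the activations once, while a relative-attention layer consumes the positional encoding itself. A minimal sketch of the dispatch, with illustrative names rather than the repo's exact API:

```python
import torch

def apply_positions(x: torch.Tensor, positions: torch.Tensor, attn_type: str) -> torch.Tensor:
    # Absolute encodings: add positions to the activations once, up front.
    # "rel_selfattn": leave x unchanged; each attention layer receives the
    # positional encoding separately and folds it into its score computation.
    if attn_type != "rel_selfattn":
        x = x + positions
    return x
```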
......@@ -853,7 +853,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
        if self.project_in_dim is not None:
            x = self.project_in_dim(x)
-       if positions is not None:
+       if positions is not None and self.attn_type != "rel_selfattn":
            x += positions
        if self.layernorm_embedding is not None:
......
......@@ -55,7 +55,7 @@ class RelPositionMultiheadAttention(MultiheadAttention):
        # linear transformation for positional encoding
        self.linear_pos = quant_noise(
-           nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+           nn.Linear(embed_dim, embed_dim, bias=False), q_noise, qn_block_size
        )
        # these two learnable biases are used in matrix c and matrix d
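For context, the two biases mentioned in the comment are the u and v vectors of Transformer-XL, whose attention score splits into a content-content term (a), a content-position term (b), a global content bias (c), and a global position bias (d); keeping linear_pos bias-free keeps the position branch a pure function of the positional encoding. A minimal sketch of how matrix_ac and matrix_bd bundle those four terms, with illustrative tensor names and shapes:

```python
import torch

def rel_attn_scores(q, k, p, u, v, scaling):
    """q, k: (batch*heads, time, d_head) content projections.
    p: positional encodings after the bias-free linear_pos projection.
    u, v: learnable (d_head,) biases entering matrices c and d."""
    matrix_ac = torch.bmm(q + u, k.transpose(1, 2))  # terms (a) + (c)
    matrix_bd = torch.bmm(q + v, p.transpose(1, 2))  # terms (b) + (d)
    # Combined exactly as in the attn_weights line of the hunk below.
    return (matrix_ac + matrix_bd) * scaling
```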
......@@ -312,25 +312,36 @@ class RelPositionMultiheadAttention(MultiheadAttention):
        def rel_shift(x, zero_triu=False):
            """Compute relative positional encoding.
-           :param torch.Tensor x: (batch, time, size)
-           :param bool zero_triu: return the lower triangular part of the matrix
+           Args:
+               x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+                   time1 means the length of query vector.
+           Returns:
+               torch.Tensor: Output tensor.
            """
            zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
            x_padded = torch.cat([zero_pad, x], dim=-1)
            x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
            x = x_padded[:, :, 1:].view_as(x)
+           # zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+           # x_padded = torch.cat([zero_pad, x], dim=-1)
+           #
+           # x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+           # x = x_padded[:, :, 1:].view_as(x)[
+           #     :, :, :, : x.size(-1) // 2 + 1
+           # ]  # only keep the positions from 0 to time2
            if zero_triu:
-               ones = torch.ones((x.size(2), x.size(3)))
+               ones = torch.ones((x.size(2), x.size(3)), device=x.device)
                x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
            return x
-       matrix_bd = matrix_bd.contiguous().view(bsz, self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))
-       matrix_bd = rel_shift(
-           matrix_bd,
-       ).contiguous().view(bsz * self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))
+       # matrix_bd = matrix_bd.contiguous().view(bsz, self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))
+       # matrix_bd = rel_shift(
+       #     matrix_bd,
+       # ).contiguous().view(bsz * self.num_heads, matrix_bd.size(-2), matrix_bd.size(-1))
        attn_weights = (matrix_ac + matrix_bd) * self.scaling
+       # attn_weights = torch.bmm(q, k.transpose(1, 2))
......
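The active code path of rel_shift above can be run standalone to see the effect of the legacy-style shift this commit switches back to; a small self-contained sanity check (the example tensor is arbitrary):

```python
import torch

def rel_shift(x, zero_triu=False):
    # Same operations as the active code path in the hunk above: pad one zero
    # column, then reinterpret the memory so each row slides one step further
    # than the row below it (the Transformer-XL "skew" trick).
    zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=-1)
    x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
    x = x_padded[:, :, 1:].view_as(x)
    if zero_triu:
        # Optionally zero everything above the (time2 - time1) diagonal.
        ones = torch.ones((x.size(2), x.size(3)), device=x.device)
        x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
    return x

# (batch, head, time1, time2) = (1, 1, 3, 3), values 0..8 for readability.
scores = torch.arange(9.0).view(1, 1, 3, 3)
print(rel_shift(scores, zero_triu=True))
# tensor([[[[2., 0., 0.],
#           [4., 5., 0.],
#           [6., 7., 8.]]]])
```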