Commit b970c7df by xuchen

up-sampling the representation for ctc calculation

parent 8b50c392
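Presumably the motivation for the upsampling is CTC's length constraint: the CTC loss is only defined when the input sequence is at least as long as the target, which is easily violated when CTC is attached to a text/MT encoder whose source can be shorter than the target. A minimal illustration with torch's generic ctc_loss (not this repo's criterion) of how doubling the number of input frames avoids an infinite loss:

import torch
import torch.nn.functional as F

vocab, blank = 32, 0
target = torch.randint(1, vocab, (1, 10))            # 10 target tokens
short = torch.randn(6, 1, vocab).log_softmax(-1)     # 6 input frames: too short for CTC
long = torch.randn(12, 1, vocab).log_softmax(-1)     # 2x upsampled: 12 frames

for log_probs in (short, long):
    T = log_probs.size(0)
    loss = F.ctc_loss(log_probs, target,
                      input_lengths=torch.tensor([T]),
                      target_lengths=torch.tensor([10]),
                      blank=blank, zero_infinity=False)
    print(T, loss.item())   # 6 frames -> inf, 12 frames -> finite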
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
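A rough sketch of how intermedia-ctc-weight is typically folded into training, assuming the criterion receives one CTC loss per layer listed in intermedia-ctc-layers (here layers 2 and 4); the helper name below is illustrative, not this repo's API:

def combine_losses(trans_loss, intermedia_ctc_losses, intermedia_ctc_weight=0.3):
    # Average the intermediate CTC losses and add them to the translation
    # loss with the configured weight.
    if not intermedia_ctc_losses:
        return trans_loss
    inter = sum(intermedia_ctc_losses) / len(intermedia_ctc_losses)
    return trans_loss + intermedia_ctc_weight * inter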
-arch: transformer
+arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 10,20
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
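This second config enables the same intermediate-CTC machinery on a deep text encoder (layers 10 and 20 of transformer_ctc). A small sketch of how a comma-separated option such as intermedia-ctc-layers is commonly parsed and validated against the encoder depth (illustrative only, not this repo's option parser):

def parse_intermedia_ctc_layers(option: str, encoder_layers: int):
    # "10,20" -> [10, 20]; every entry must name an existing encoder layer.
    layers = [int(tok) for tok in option.split(",") if tok.strip()]
    for layer in layers:
        assert 0 < layer <= encoder_layers, f"bad intermedia CTC layer: {layer}"
    return layers

print(parse_intermedia_ctc_layers("10,20", encoder_layers=24))  # [10, 20]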
......@@ -107,6 +107,7 @@ class AudioDataset(Dataset):
for idx, u in enumerate(utterances):
segments[idx][_lang] = u
# split = split.replace("_gen", "")
# Gather info
self.data = dict()
if self.mode == "easy":
......
......@@ -156,9 +156,11 @@ class TextEncoder(FairseqEncoder):
super().__init__(None)
self.register_buffer("version", torch.Tensor([3]))  # for consistency
embed_dim = args.encoder_embed_dim
+layer_num = args.text_encoder_layers
+self.layer_num = layer_num
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
if args.no_scale_embedding:
......
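For context on the last two lines of this hunk: in fairseq-style encoders the embedding scale is the usual sqrt(d_model) factor, disabled by no-scale-embedding. A minimal sketch of that pattern (generic fairseq-style code, not copied from this file):

import math
import torch
import torch.nn as nn

embed_dim = 512
embed_tokens = nn.Embedding(10000, embed_dim, padding_idx=1)
no_scale_embedding = False

embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
tokens = torch.randint(2, 10000, (4, 20))       # B x T token ids
x = embed_scale * embed_tokens(tokens)          # B x T x C scaled embeddings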
......@@ -672,6 +672,14 @@ class TransformerCTCEncoder(FairseqEncoder):
return_all_hiddens,
token_embeddings)
+def upsample(self, x, ratio=2):
+    if ratio <= 1:
+        return x
+    seq_len, bsz, dim = x.size()
+    x = x.unsqueeze(0).expand(ratio, -1, -1, -1).reshape(-1, bsz, dim)
+    return x
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
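A toy check of what the new upsample does to a T x B x C tensor: unsqueeze(0).expand(ratio, ...).reshape(...) tiles the whole sequence (x0..xT-1, x0..xT-1) rather than repeating each frame in place. torch.repeat_interleave is shown alongside for comparison, since the stride-2 max-pooling and the padding-mask expansion later in this commit operate on adjacent position pairs:

import torch

x = torch.arange(3.0).view(3, 1, 1)          # T=3, B=1, C=1 -> frames [0, 1, 2]
seq_len, bsz, dim = x.size()

tiled = x.unsqueeze(0).expand(2, -1, -1, -1).reshape(-1, bsz, dim)
print(tiled.flatten().tolist())              # [0, 1, 2, 0, 1, 2]  (whole sequence repeated)

interleaved = torch.repeat_interleave(x, 2, dim=0)
print(interleaved.flatten().tolist())        # [0, 0, 1, 1, 2, 2]  (each frame repeated)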
......@@ -749,7 +757,7 @@ class TransformerCTCEncoder(FairseqEncoder):
# CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-    ctc_logit = self.ctc(x.clone())
+    ctc_logit = self.ctc(self.upsample(x.clone()))
# Intermedia CTC
if layer_idx in self.intermedia_ctc_layers:
......@@ -759,11 +767,15 @@ class TransformerCTCEncoder(FairseqEncoder):
break
norm_x = self.layer_norm(x)
-logit = self.ctc(norm_x)
+up_x = self.upsample(norm_x)
+up_logit = self.ctc(up_x)
-intermedia_ctc_logits.append(logit)
-prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
-x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
+intermedia_ctc_logits.append(up_logit)
+up_prob = utils.softmax(up_logit / self.intermedia_temperature, dim=-1)
+up_prob = up_prob.permute(1, 2, 0)
+prob = nn.functional.max_pool1d(up_prob, kernel_size=2, stride=2)
+prob = prob.permute(2, 0, 1)
+x, _ = self.adapter([x, prob])
if self.history is not None:
self.history.push(x)
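A shape walk-through of the new intermediate-CTC path, using a plain nn.Linear as a stand-in for self.ctc and torch.softmax in place of utils.softmax: the hidden states are upsampled to 2T frames before the CTC projection, and the resulting probabilities are max-pooled back to T frames before being handed to the adapter together with x:

import torch
import torch.nn as nn

T, B, C, V = 5, 2, 8, 11
x = torch.randn(T, B, C)                      # encoder states, T x B x C
ctc = nn.Linear(C, V)                         # stand-in for self.ctc

up_x = x.unsqueeze(0).expand(2, -1, -1, -1).reshape(-1, B, C)       # 2T x B x C
up_logit = ctc(up_x)                                                # 2T x B x V
up_prob = torch.softmax(up_logit / 1.0, dim=-1)                     # temperature = 1 here
up_prob = up_prob.permute(1, 2, 0)                                  # B x V x 2T
prob = nn.functional.max_pool1d(up_prob, kernel_size=2, stride=2)   # B x V x T
prob = prob.permute(2, 0, 1)                                        # T x B x V, fed to the adapter
assert prob.shape == (T, B, V)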
......@@ -775,7 +787,12 @@ class TransformerCTCEncoder(FairseqEncoder):
x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None:
-    ctc_logit = self.ctc(x)
+    ctc_logit = self.ctc(self.upsample(x))
+ctc_padding_mask = encoder_padding_mask
+if ctc_logit is not None or len(intermedia_ctc_logits) != 0:
+    bsz = encoder_padding_mask.size(0)
+    ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).expand(-1, -1, 2).reshape(bsz, -1)
# The Pytorch Mobile lite interpreter does not support returning NamedTuple in
# `forward` so we use a dictionary instead.
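The ctc_padding_mask block above stretches the B x T padding mask to B x 2T so the CTC criterion sees lengths that match the upsampled logits; each position is duplicated in place, as this toy example shows (illustrative only):

import torch

encoder_padding_mask = torch.tensor([[False, False, True],     # sample 0: length 2
                                     [False, False, False]])   # sample 1: length 3
bsz = encoder_padding_mask.size(0)

ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).expand(-1, -1, 2).reshape(bsz, -1)
print(ctc_padding_mask.long().tolist())
# [[0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 0]]  -> every position doubled, lengths 4 and 6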
......@@ -784,6 +801,7 @@ class TransformerCTCEncoder(FairseqEncoder):
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"ctc_padding_mask": [ctc_padding_mask],
"intermedia_ctc_logits": intermedia_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
......
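For orientation, this is roughly how a CTC criterion consumes a T x B x V logit together with the new B x 2T ctc_padding_mask: input lengths are simply the number of unmasked positions. Generic torch usage, not this repo's CTC criterion:

import torch
import torch.nn.functional as F

def ctc_loss_from_encoder_out(ctc_logit, ctc_padding_mask, targets, target_lengths, blank=0):
    # ctc_logit: T x B x V, ctc_padding_mask: B x T (True = padding)
    log_probs = F.log_softmax(ctc_logit, dim=-1)
    input_lengths = (~ctc_padding_mask).long().sum(-1)
    return F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      blank=blank, zero_infinity=True)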
......@@ -95,14 +95,13 @@ class Adapter(nn.Module):
if self.distribution_cutoff is not None:
logger.info("Distribution cutoff: %d" % int(strategy))
-def forward(self, x, padding):
+def forward(self, x, padding=None):
representation, distribution = x
distribution = distribution.type_as(representation)
seq_len, bsz, dim = representation.size()
org_distribution = distribution
-distribution = distribution.view(-1, distribution.size(-1))
-lengths = (~padding).long().sum(-1)
+distribution = distribution.contiguous().view(-1, distribution.size(-1))
if self.adapter_type == "linear":
out = self.linear_adapter(representation)
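The flattened distribution (contiguous().view(-1, V)) is the shape needed if the adapter embeds the CTC distribution, as a "league"-style adapter commonly does: a soft embedding prob @ E combined with the hidden states. That branch is not shown in this hunk, so the following is only a hedged sketch of the general pattern, not this class's exact code:

import torch
import torch.nn as nn

T, B, C, V = 5, 2, 8, 11
representation = torch.randn(T, B, C)
distribution = torch.softmax(torch.randn(T, B, V), dim=-1)   # CTC probabilities
embed = nn.Embedding(V, C)                                   # vocab embedding for the soft lookup

flat = distribution.contiguous().view(-1, V)                 # (T*B) x V
soft = flat.mm(embed.weight).view(T, B, C)                   # expected embedding under the distribution
out = representation + soft                                  # "league": combine both signals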
......@@ -140,6 +139,7 @@ class Adapter(nn.Module):
elif self.adapter_type == "shrink":
from itertools import groupby
+lengths = (~padding).long().sum(-1)
with torch.no_grad():
batch_predicted = []
prob_ctc = org_distribution.transpose(0, 1) # T x B x D -> B x T x D
......
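The shrink branch (which now derives lengths from the padding mask locally) collapses encoder states along CTC predictions. A generic illustration of the groupby step it appears to build on: take the framewise argmax over the valid frames, merge consecutive repeats, and drop blanks (shapes and the blank index here are illustrative):

import torch
from itertools import groupby

blank = 0
prob_ctc = torch.softmax(torch.randn(2, 7, 5), dim=-1)   # B x T x V (after transpose(0, 1))
lengths = torch.tensor([7, 5])                            # valid frames per sample, from (~padding).sum(-1)

batch_predicted = []
for b in range(prob_ctc.size(0)):
    pred = prob_ctc[b, :int(lengths[b])].argmax(-1).tolist()        # framewise argmax
    collapsed = [tok for tok, _ in groupby(pred) if tok != blank]   # merge repeats, drop blanks
    batch_predicted.append(collapsed)
print([len(p) for p in batch_predicted])                  # shrunk length per sample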