Commit b970c7df by xuchen

Up-sampling the representation for CTC calculation

parent 8b50c392
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
-arch: transformer
+arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 10,20
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
@@ -107,6 +107,7 @@ class AudioDataset(Dataset):
             for idx, u in enumerate(utterances):
                 segments[idx][_lang] = u
+        # split = split.replace("_gen", "")
         # Gather info
         self.data = dict()
         if self.mode == "easy":
......
@@ -156,9 +156,11 @@ class TextEncoder(FairseqEncoder):
         super().__init__(None)
+        self.register_buffer("version", torch.Tensor([3]))  # for consistent
         embed_dim = args.encoder_embed_dim
         layer_num = args.text_encoder_layers
         self.layer_num = layer_num
+        self.embed_tokens = embed_tokens
         self.embed_scale = math.sqrt(embed_dim)
         if args.no_scale_embedding:
......
@@ -672,6 +672,14 @@ class TransformerCTCEncoder(FairseqEncoder):
             return_all_hiddens,
             token_embeddings)
+    def upsample(self, x, ratio=2):
+        if ratio <= 1:
+            return x
+        seq_len, bsz, dim = x.size()
+        x = x.unsqueeze(0).expand(ratio, -1, -1, -1).reshape(-1, bsz, dim)
+        return x
     # TorchScript doesn't support super() method so that the scriptable Subclass
     # can't access the base class model in Torchscript.
     # Current workaround is to add a helper function with different name and
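For clarity, here is a standalone sketch of the upsample helper added above, with an illustrative shape check. The tensor sizes are assumptions for the example; the helper tiles the T x B x C representation along the time axis so the CTC input becomes ratio * T frames long.

import torch

def upsample(x: torch.Tensor, ratio: int = 2) -> torch.Tensor:
    # Tile the (T, B, C) representation `ratio` times along the time axis,
    # producing a (ratio * T, B, C) tensor for the CTC projection.
    if ratio <= 1:
        return x
    seq_len, bsz, dim = x.size()
    return x.unsqueeze(0).expand(ratio, -1, -1, -1).reshape(-1, bsz, dim)

x = torch.randn(5, 2, 8)             # T=5, B=2, C=8 (illustrative sizes)
print(upsample(x, ratio=2).shape)    # torch.Size([10, 2, 8])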
@@ -749,7 +757,7 @@ class TransformerCTCEncoder(FairseqEncoder):
             # CTC
             if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-                ctc_logit = self.ctc(x.clone())
+                ctc_logit = self.ctc(self.upsample(x.clone()))
             # Intermedia CTC
             if layer_idx in self.intermedia_ctc_layers:
@@ -759,11 +767,15 @@ class TransformerCTCEncoder(FairseqEncoder):
                     break
                 norm_x = self.layer_norm(x)
-                logit = self.ctc(norm_x)
-                intermedia_ctc_logits.append(logit)
-                prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
-                x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
+                up_x = self.upsample(norm_x)
+                up_logit = self.ctc(up_x)
+                intermedia_ctc_logits.append(up_logit)
+                up_prob = utils.softmax(up_logit / self.intermedia_temperature, dim=-1)
+                up_prob = up_prob.permute(1, 2, 0)
+                prob = nn.functional.max_pool1d(up_prob, kernel_size=2, stride=2)
+                prob = prob.permute(2, 0, 1)
+                x, _ = self.adapter([x, prob])
             if self.history is not None:
                 self.history.push(x)
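A minimal sketch of the pooling step above, assuming PyTorch and illustrative shapes: the up-sampled CTC distribution (2T x B x V) is permuted to B x V x 2T, max-pooled with kernel and stride 2 back to length T, and permuted to T x B x V before being handed to the adapter.

import torch
import torch.nn.functional as F

up_prob = torch.softmax(torch.randn(10, 2, 100), dim=-1)                   # (2T, B, V), here T=5
pooled = F.max_pool1d(up_prob.permute(1, 2, 0), kernel_size=2, stride=2)   # (B, V, T)
prob = pooled.permute(2, 0, 1)                                             # (T, B, V)
print(prob.shape)                                                          # torch.Size([5, 2, 100])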
@@ -775,7 +787,12 @@ class TransformerCTCEncoder(FairseqEncoder):
             x = self.layer_norm(x)
         if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x)
+            ctc_logit = self.ctc(self.upsample(x))
+        ctc_padding_mask = encoder_padding_mask
+        if ctc_logit is not None or len(intermedia_ctc_logits) != 0:
+            bsz = encoder_padding_mask.size(0)
+            ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).expand(-1, -1, 2).reshape(bsz, -1)
         # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
         # `forward` so we use a dictionary instead.
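The new ctc_padding_mask mirrors the doubled time axis: each frame's flag in the B x T padding mask is repeated twice, yielding a B x 2T mask aligned with the up-sampled logits. A small sketch with an assumed toy mask:

import torch

encoder_padding_mask = torch.tensor([[False, False, True]])  # B x T, True marks padding
bsz = encoder_padding_mask.size(0)
ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).expand(-1, -1, 2).reshape(bsz, -1)
print(ctc_padding_mask)  # tensor([[False, False, False, False,  True,  True]])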
@@ -784,6 +801,7 @@ class TransformerCTCEncoder(FairseqEncoder):
         return {
             "encoder_out": [x],  # T x B x C
             "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # T x B x C
+            "ctc_padding_mask": [ctc_padding_mask],
             "intermedia_ctc_logits": intermedia_ctc_logits,  # T x B x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
             "encoder_embedding": [encoder_embedding],  # B x T x C
......
@@ -95,14 +95,13 @@ class Adapter(nn.Module):
         if self.distribution_cutoff is not None:
             logger.info("Distribution cutoff: %d" % int(strategy))
-    def forward(self, x, padding):
+    def forward(self, x, padding=None):
         representation, distribution = x
         distribution = distribution.type_as(representation)
         seq_len, bsz, dim = representation.size()
         org_distribution = distribution
-        distribution = distribution.view(-1, distribution.size(-1))
-        lengths = (~padding).long().sum(-1)
+        distribution = distribution.contiguous().view(-1, distribution.size(-1))
         if self.adapter_type == "linear":
             out = self.linear_adapter(representation)
@@ -140,6 +139,7 @@ class Adapter(nn.Module):
         elif self.adapter_type == "shrink":
             from itertools import groupby
+            lengths = (~padding).long().sum(-1)
             with torch.no_grad():
                 batch_predicted = []
                 prob_ctc = org_distribution.transpose(0, 1)  # T x B x D -> B x T x D
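Since padding is now optional in forward, the length computation moves into the shrink branch, which still requires a mask. As a reminder of what that line computes, a toy mask is assumed below:

import torch

padding = torch.tensor([[False, False, True],
                        [False, True, True]])  # B x T, True marks padding
lengths = (~padding).long().sum(-1)            # non-padded frames per sequence
print(lengths)  # tensor([2, 1])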
......