Commit 0187e5d6 by xuchen

Fix the CTC bug in the pyramid transformer

parent c0e06600
...@@ -13,8 +13,10 @@ pyramid-position-embed: 1_1_1_1 ...@@ -13,8 +13,10 @@ pyramid-position-embed: 1_1_1_1
pyramid-kernel-sizes: 5_5_5_5 pyramid-kernel-sizes: 5_5_5_5
pyramid-ffn-ratios: 8_8_8_4 pyramid-ffn-ratios: 8_8_8_4
pyramid-heads: 2_2_4_8 pyramid-heads: 2_2_4_8
#ctc-layer: 8
train-subset: train-clean-100,train-clean-360,train-other-500 #train-subset: train-clean-100,train-clean-360,train-other-500
train-subset: train-clean-100
valid-subset: dev-clean valid-subset: dev-clean
max-epoch: 100 max-epoch: 100
...@@ -39,7 +41,8 @@ warmup-updates: 10000 ...@@ -39,7 +41,8 @@ warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) #adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy criterion: label_smoothed_cross_entropy_with_ctc
ctc-weight: 0.3
label_smoothing: 0.1 label_smoothing: 0.1
conv-channels: 1024 conv-channels: 1024
......
...@@ -41,7 +41,8 @@ warmup-updates: 10000 ...@@ -41,7 +41,8 @@ warmup-updates: 10000
lr: 2e-3 lr: 2e-3
#adam_betas: (0.9,0.98) #adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy criterion: label_smoothed_cross_entropy_with_ctc
ctc-weight: 0
label_smoothing: 0.1 label_smoothing: 0.1
conv-channels: 1024 conv-channels: 1024
......
...@@ -33,7 +33,7 @@ if [[ -n ${data_dir} ]]; then ...@@ -33,7 +33,7 @@ if [[ -n ${data_dir} ]]; then
fi fi
if [[ ${#test_subset[@]} -ne 0 ]]; then if [[ ${#test_subset[@]} -ne 0 ]]; then
subsets=$(echo ${test_subset[*]} | sed 's/ /,/g') subsets=$(echo ${test_subset[*]} | sed 's/ /,/g')
cmd="$cmd --test_subset ${test_subset}" cmd="$cmd --test_subset ${subsets}"
fi fi
echo $cmd echo $cmd
......
gpu_num=8 gpu_num=4
cmd="sh train.sh" cmd="sh train.sh"
while : while :
......
...@@ -94,16 +94,15 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -94,16 +94,15 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
ctc_loss = self.compute_ctc_loss(model, sample, encoder_out) ctc_loss = self.compute_ctc_loss(model, sample, encoder_out)
logging_output["ctc_loss"] = utils.item(ctc_loss.data) logging_output["ctc_loss"] = utils.item(ctc_loss.data)
loss = (1 - self.ctc_weight) * loss + self.ctc_weight * ctc_loss loss = (1 - self.ctc_weight) * loss + self.ctc_weight * ctc_loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data else:
loss = (1 - self.ctc_weight) * loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output return loss, sample_size, logging_output
def compute_ctc_loss(self, model, sample, encoder_out): def compute_ctc_loss(self, model, sample, encoder_out):
transcript = sample["transcript"] transcript = sample["transcript"]
if "ctc_logit" in encoder_out: ctc_logit = model.encoder.compute_ctc_logit(encoder_out)
ctc_logit = encoder_out["ctc_logit"][0]
else:
ctc_logit = model.encoder.compute_ctc_logit(encoder_out)
lprobs = model.get_normalized_probs( lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True [ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder ).contiguous() # (T, B, C) from the encoder
...@@ -189,8 +188,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -189,8 +188,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
return loss return loss
@staticmethod def reduce_metrics(self, logging_outputs) -> None:
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training.""" """Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
trans_loss_sum = utils.item( trans_loss_sum = utils.item(
...@@ -199,9 +197,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -199,9 +197,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
nll_loss_sum = utils.item( nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs) sum(log.get("nll_loss", 0) for log in logging_outputs)
) )
ctc_loss_sum = utils.item( if self.ctc_weight > 0:
sum(log.get("ctc_loss", 0) for log in logging_outputs) ctc_loss_sum = utils.item(
) sum(log.get("ctc_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item( sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs) sum(log.get("sample_size", 0) for log in logging_outputs)
...@@ -216,12 +215,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -216,12 +215,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metrics.log_scalar( metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
) )
metrics.log_scalar( if self.ctc_weight > 0:
"ctc_loss", metrics.log_scalar(
ctc_loss_sum / sample_size / math.log(2), "ctc_loss",
sample_size, ctc_loss_sum / sample_size / math.log(2),
round=3, sample_size,
) round=3,
)
metrics.log_derived( metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
) )
......
...@@ -358,7 +358,7 @@ class PyS2TTransformerEncoder(FairseqEncoder): ...@@ -358,7 +358,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
(getattr(args, "ctc_weight", False) > 0)) (getattr(args, "ctc_weight", False) > 0))
if self.use_ctc: if self.use_ctc:
self.ctc_layer = (args.encoder_layers + args.ctc_layer) % args.encoder_layers self.ctc_layer = (args.encoder_layers + args.ctc_layer) % args.encoder_layers
self.inter_ctc = True if self.ctc_layer != args.encoder_layers - 1 else False self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if task.source_dictionary == task.target_dictionary and getattr(args, "share_all_embeddings", False): if task.source_dictionary == task.target_dictionary and getattr(args, "share_all_embeddings", False):
self.ctc_projection = nn.Linear( self.ctc_projection = nn.Linear(
...@@ -376,6 +376,7 @@ class PyS2TTransformerEncoder(FairseqEncoder): ...@@ -376,6 +376,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
ctc_layer -= self.pyramid_layers[i] ctc_layer -= self.pyramid_layers[i]
if ctc_layer <= 0: if ctc_layer <= 0:
embed_dim = self.pyramid_embed_dims[i] embed_dim = self.pyramid_embed_dims[i]
break
self.ctc_layer_norm = LayerNorm(embed_dim) self.ctc_layer_norm = LayerNorm(embed_dim)
self.ctc_projection = nn.Linear(embed_dim, len(task.source_dictionary), bias=False) self.ctc_projection = nn.Linear(embed_dim, len(task.source_dictionary), bias=False)
...@@ -431,15 +432,15 @@ class PyS2TTransformerEncoder(FairseqEncoder): ...@@ -431,15 +432,15 @@ class PyS2TTransformerEncoder(FairseqEncoder):
x = layer(x, encoder_padding_mask, pos_emb=positions) x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1 layer_idx += 1
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc_layer_norm(x)
prev_state.append(x) prev_state.append(x)
prev_padding.append(encoder_padding_mask) prev_padding.append(encoder_padding_mask)
if block_attn is not None: if block_attn is not None:
x = block_attn(x, prev_state[-1], prev_padding[-1]) x = block_attn(x, prev_state[-1], prev_padding[-1])
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc_layer_norm(x)
if self.use_ppm: if self.use_ppm:
pool_state = [x] pool_state = [x]
seq_len, bsz, dim = x.size() seq_len, bsz, dim = x.size()
...@@ -473,8 +474,8 @@ class PyS2TTransformerEncoder(FairseqEncoder): ...@@ -473,8 +474,8 @@ class PyS2TTransformerEncoder(FairseqEncoder):
def compute_ctc_logit(self, encoder_out): def compute_ctc_logit(self, encoder_out):
assert self.use_ctc, "CTC is not available!" assert self.use_ctc, "CTC is not available!"
if isinstance(encoder_out, dict) and "encoder_out" in encoder_out: if isinstance(encoder_out, dict) and "ctc_logit" in encoder_out:
encoder_state = encoder_out["encoder_out"][0] encoder_state = encoder_out["ctc_logit"][0]
else: else:
encoder_state = encoder_out encoder_state = encoder_out
ctc_logit = self.ctc_projection(self.ctc_dropout_module(encoder_state)) ctc_logit = self.ctc_projection(self.ctc_dropout_module(encoder_state))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论