Commit 0187e5d6 by xuchen

Fix the CTC bug in the pyramid transformer

parent c0e06600
......@@ -13,8 +13,10 @@ pyramid-position-embed: 1_1_1_1
pyramid-kernel-sizes: 5_5_5_5
pyramid-ffn-ratios: 8_8_8_4
pyramid-heads: 2_2_4_8
#ctc-layer: 8
train-subset: train-clean-100,train-clean-360,train-other-500
#train-subset: train-clean-100,train-clean-360,train-other-500
train-subset: train-clean-100
valid-subset: dev-clean
max-epoch: 100
......@@ -39,7 +41,8 @@ warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy
criterion: label_smoothed_cross_entropy_with_ctc
ctc-weight: 0.3
label_smoothing: 0.1
conv-channels: 1024
......
......@@ -41,7 +41,8 @@ warmup-updates: 10000
lr: 2e-3
#adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy
criterion: label_smoothed_cross_entropy_with_ctc
ctc-weight: 0
label_smoothing: 0.1
conv-channels: 1024
......
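Note: the ctc-weight value set in these configs drives a simple interpolation inside the criterion (see the Python hunks further down). A minimal sketch of that behaviour, using a hypothetical helper interpolate_loss rather than the repository code:

def interpolate_loss(ce_loss, ctc_loss, ctc_weight):
    # With a positive ctc-weight the CTC term is mixed in; otherwise only the
    # label-smoothed cross-entropy survives (ctc-weight: 0 leaves it unchanged).
    if ctc_weight > 0 and ctc_loss is not None:
        return (1 - ctc_weight) * ce_loss + ctc_weight * ctc_loss
    return (1 - ctc_weight) * ce_loss

For example, with ctc-weight: 0.3, a cross-entropy loss of 2.5 and a CTC loss of 3.0 combine to 0.7 * 2.5 + 0.3 * 3.0 = 2.65.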
......@@ -33,7 +33,7 @@ if [[ -n ${data_dir} ]]; then
fi
if [[ ${#test_subset[@]} -ne 0 ]]; then
subsets=$(echo ${test_subset[*]} | sed 's/ /,/g')
cmd="$cmd --test_subset ${test_subset}"
cmd="$cmd --test_subset ${subsets}"
fi
echo $cmd
......
gpu_num=8
gpu_num=4
cmd="sh train.sh"
while :
......
......@@ -94,16 +94,15 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
ctc_loss = self.compute_ctc_loss(model, sample, encoder_out)
logging_output["ctc_loss"] = utils.item(ctc_loss.data)
loss = (1 - self.ctc_weight) * loss + self.ctc_weight * ctc_loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
else:
loss = (1 - self.ctc_weight) * loss
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
def compute_ctc_loss(self, model, sample, encoder_out):
transcript = sample["transcript"]
if "ctc_logit" in encoder_out:
ctc_logit = encoder_out["ctc_logit"][0]
else:
ctc_logit = model.encoder.compute_ctc_logit(encoder_out)
ctc_logit = model.encoder.compute_ctc_logit(encoder_out)
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
).contiguous() # (T, B, C) from the encoder
......@@ -189,8 +188,7 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
return loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
def reduce_metrics(self, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
trans_loss_sum = utils.item(
......@@ -199,9 +197,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs)
)
ctc_loss_sum = utils.item(
sum(log.get("ctc_loss", 0) for log in logging_outputs)
)
if self.ctc_weight > 0:
ctc_loss_sum = utils.item(
sum(log.get("ctc_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
......@@ -216,12 +215,13 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_scalar(
"ctc_loss",
ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if self.ctc_weight > 0:
metrics.log_scalar(
"ctc_loss",
ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
......
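For context, the reduce_metrics change above only aggregates and logs the ctc_loss scalar when the criterion was configured with a positive CTC weight. A rough, standalone sketch of that aggregation (hypothetical function aggregate_ctc_loss, not the fairseq API):

import math

def aggregate_ctc_loss(logging_outputs, ctc_weight, sample_size):
    # Nothing to report when the CTC branch never ran.
    if ctc_weight <= 0:
        return None
    ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
    # Same normalization as the criterion: per sample_size unit, in bits.
    return ctc_loss_sum / sample_size / math.log(2)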
......@@ -358,7 +358,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
(getattr(args, "ctc_weight", False) > 0))
if self.use_ctc:
self.ctc_layer = (args.encoder_layers + args.ctc_layer) % args.encoder_layers
self.inter_ctc = True if self.ctc_layer != args.encoder_layers - 1 else False
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if task.source_dictionary == task.target_dictionary and getattr(args, "share_all_embeddings", False):
self.ctc_projection = nn.Linear(
......@@ -376,6 +376,7 @@ class PyS2TTransformerEncoder(FairseqEncoder):
ctc_layer -= self.pyramid_layers[i]
if ctc_layer <= 0:
embed_dim = self.pyramid_embed_dims[i]
break
self.ctc_layer_norm = LayerNorm(embed_dim)
self.ctc_projection = nn.Linear(embed_dim, len(task.source_dictionary), bias=False)
......@@ -431,15 +432,15 @@ class PyS2TTransformerEncoder(FairseqEncoder):
x = layer(x, encoder_padding_mask, pos_emb=positions)
layer_idx += 1
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc_layer_norm(x)
prev_state.append(x)
prev_padding.append(encoder_padding_mask)
if block_attn is not None:
x = block_attn(x, prev_state[-1], prev_padding[-1])
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc_layer_norm(x)
if self.use_ppm:
pool_state = [x]
seq_len, bsz, dim = x.size()
......@@ -473,8 +474,8 @@ class PyS2TTransformerEncoder(FairseqEncoder):
def compute_ctc_logit(self, encoder_out):
assert self.use_ctc, "CTC is not available!"
if isinstance(encoder_out, dict) and "encoder_out" in encoder_out:
encoder_state = encoder_out["encoder_out"][0]
if isinstance(encoder_out, dict) and "ctc_logit" in encoder_out:
encoder_state = encoder_out["ctc_logit"][0]
else:
encoder_state = encoder_out
ctc_logit = self.ctc_projection(self.ctc_dropout_module(encoder_state))
......
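The encoder changes above normalize the configured ctc-layer index, decide whether CTC is attached to an intermediate layer, and stop scanning pyramid stages once the stage containing that layer is found. A compact sketch with assumed helper names (resolve_ctc_layer, ctc_embed_dim) and argument layout, not the encoder class itself:

def resolve_ctc_layer(ctc_layer, encoder_layers):
    # A negative ctc-layer counts back from the top of the encoder stack;
    # the fixed comparison treats CTC as intermediate whenever the resolved
    # index differs from encoder_layers (the 1-based index of the last layer).
    layer = (encoder_layers + ctc_layer) % encoder_layers
    inter_ctc = layer != encoder_layers
    return layer, inter_ctc

def ctc_embed_dim(ctc_layer, pyramid_layers, pyramid_embed_dims):
    # Walk the pyramid stages until the stage holding the CTC layer is reached;
    # the added break keeps later stages from overwriting embed_dim.
    embed_dim = pyramid_embed_dims[-1]  # fallback assumed here
    for num_layers, dim in zip(pyramid_layers, pyramid_embed_dims):
        ctc_layer -= num_layers
        if ctc_layer <= 0:
            embed_dim = dim
            break
    return embed_dim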