Commit 9e958e0c by xuchen

Receive multiple inputs in the encoder

parent b4e95869

@@ -277,7 +277,8 @@ class CtcCriterion(FairseqCriterion):
     def forward(self, model, sample, reduce=True):
         # net_output = model(**sample["net_input"])
-        src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values()
+        src_tokens = sample["net_input"]["src_tokens"]
+        src_lengths = sample["net_input"]["src_lengths"]
         if self.training and getattr(model.encoder, "pae_ground_truth_ratio", 0) != 0:
             ctc_alignment_oracle = self.get_ground_truth_alignment(model, sample)
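
Why the `.values()` unpacking had to go: positional unpacking over `dict.values()` silently depends on the dict having exactly the expected keys in the expected order, and this commit adds a fourth entry (`tgt_lang_idx`) to `net_input`. A minimal, self-contained illustration of the failure mode (the dict contents are toy stand-ins):

```python
# Sketch: why positional unpacking of dict.values() is fragile.
# Keys and values below are toy stand-ins, not the real sample dict.
net_input = {
    "src_tokens": "tokens",
    "src_lengths": "lengths",
    "prev_output_tokens": "prev",
}

# Works only while the dict has exactly three keys in this order.
src_tokens, src_lengths, prev_output_tokens = net_input.values()

# Adding one more input (e.g. a target-language index) breaks it:
net_input["tgt_lang_idx"] = "lang"
try:
    src_tokens, src_lengths, prev_output_tokens = net_input.values()
except ValueError as e:
    print(e)  # too many values to unpack (expected 3)

# Explicit key access keeps working regardless of extra entries.
src_tokens = net_input["src_tokens"]
src_lengths = net_input["src_lengths"]
tgt_lang_idx = net_input.get("tgt_lang_idx", None)
```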

@@ -364,9 +365,12 @@ class CtcCriterion(FairseqCriterion):
             return oracle, best_aligns_pad, mistake_flag, mistake_ratio

-        src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values()
+        src_tokens = sample["net_input"]["src_tokens"]
+        src_lengths = sample["net_input"]["src_lengths"]
+        tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)
         with torch.no_grad():
-            encoder_out = model.encoder(src_tokens, src_lengths)
+            encoder_out = model.encoder(src_tokens, src_lengths,
+                                        tgt_lang_idx=tgt_lang_idx)
             ctc_logit = None
             if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) != 0:
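
One note on the `torch.no_grad()` block above: the alignment-oracle forward pass is only used to derive CTC alignments, so its activations never need gradients and graph construction is skipped entirely. A minimal sketch of the pattern, with a toy `nn.Linear` standing in for the real encoder:

```python
import torch
import torch.nn as nn

# Hypothetical toy encoder standing in for model.encoder.
encoder = nn.Linear(8, 4)
src_tokens = torch.randn(2, 5, 8)  # (B, T, D)

with torch.no_grad():
    # No graph is built here: the output is only used to derive
    # alignments, never to backpropagate into the encoder.
    encoder_out = encoder(src_tokens)

print(encoder_out.requires_grad)  # False
```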
@@ -44,7 +44,10 @@ class JoinSpeechTextLoss(
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
-        speech_tokens, speech_lengths, prev_output_tokens = sample["net_input"].values()
+        speech_tokens = sample["net_input"]["src_tokens"]
+        speech_lengths = sample["net_input"]["src_lengths"]
+        prev_output_tokens = sample["net_input"]["prev_output_tokens"]
         text_src_tokens = sample["transcript"]["tokens"]
         text_src_lengths = sample["transcript"]["lengths"]
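
For readers unfamiliar with this criterion's inputs: it assumes a batch where the speech features live under `net_input` and the paired source transcripts under a separate `transcript` entry. A toy sample built to that shape (all tensor sizes are invented for illustration):

```python
import torch

# Toy batch mirroring the layout this criterion reads from.
# All shapes and values are illustrative only.
sample = {
    "net_input": {
        "src_tokens": torch.randn(2, 100, 80),          # speech features (B, T, F)
        "src_lengths": torch.tensor([100, 73]),
        "prev_output_tokens": torch.randint(0, 50, (2, 12)),
    },
    "transcript": {
        "tokens": torch.randint(0, 50, (2, 20)),        # source-text token ids
        "lengths": torch.tensor([20, 15]),
    },
}

speech_tokens = sample["net_input"]["src_tokens"]
speech_lengths = sample["net_input"]["src_lengths"]
prev_output_tokens = sample["net_input"]["prev_output_tokens"]
text_src_tokens = sample["transcript"]["tokens"]
text_src_lengths = sample["transcript"]["lengths"]
```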
@@ -79,7 +79,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
-        src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values()
+        src_tokens = sample["net_input"]["src_tokens"]
+        src_lengths = sample["net_input"]["src_lengths"]
+        prev_output_tokens = sample["net_input"]["prev_output_tokens"]
+        tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)

         train_enc_only = False
         if self.training and self.only_train_enc_prob != 0 and self.ctc_criterion.all_ctc_weight > 0:
@@ -101,9 +104,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
                 with utils.set_torch_seed(seed):
                     ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample)
                     encoder_out = model.encoder(src_tokens, src_lengths,
-                                                ctc_alignment_oracle=ctc_alignment_oracle)
+                                                ctc_alignment_oracle=ctc_alignment_oracle,
+                                                tgt_lang_idx=tgt_lang_idx)
         else:
-            encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
+            encoder_out = model.encoder(src_tokens=src_tokens,
+                                        src_lengths=src_lengths,
+                                        tgt_lang_idx=tgt_lang_idx)
         net_output = model.decoder(
             prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
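
The `.get("tgt_lang_idx", None)` lookup is what keeps the change backward compatible: tasks that never emit a language index simply yield `None`, and the encoder treats the argument as optional. A self-contained sketch of that contract, with a hypothetical `encoder_forward` in place of the real encoder:

```python
import torch

def encoder_forward(src_tokens, src_lengths, tgt_lang_idx=None, **kwargs):
    """Hypothetical encoder entry point: tgt_lang_idx is optional."""
    if tgt_lang_idx is None:
        return "no language tag"
    return f"tag ids: {tgt_lang_idx.tolist()}"

net_input_old = {"src_tokens": torch.zeros(2, 3), "src_lengths": torch.tensor([3, 3])}
net_input_new = dict(net_input_old, tgt_lang_idx=torch.tensor([7, 7]))

for net_input in (net_input_old, net_input_new):
    # Batches without the key fall back to None instead of raising.
    tgt_lang_idx = net_input.get("tgt_lang_idx", None)
    print(encoder_forward(net_input["src_tokens"],
                          net_input["src_lengths"],
                          tgt_lang_idx=tgt_lang_idx))
```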
@@ -887,6 +887,7 @@ class S2TTransformerEncoder(FairseqEncoder):
     def __init__(self, args, task=None, embed_tokens=None):
         super().__init__(None)

+        self.embed_tokens = embed_tokens
         dim = args.encoder_embed_dim
         self.source_dictionary = task.source_dictionary
         self.target_dictionary = task.target_dictionary
@@ -975,9 +976,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.inter_ctc_drop_prob = args.inter_ctc_drop_prob
         self.share_inter_ctc = getattr(args, "share_inter_ctc", False)
         self.inter_ctc_layers = []
-        self.use_inter_ctc = False
         if args.inter_ctc_layers is not None:
-            self.use_inter_ctc = True
             self.share_inter_ctc_norm = args.share_inter_ctc_norm
             if self.share_inter_ctc_norm:
                 logger.info(
@@ -1749,6 +1748,20 @@ class S2TTransformerEncoder(FairseqEncoder):
         else:
             return False

+    def forward_torchscript(self, net_input: Dict[str, Tensor]):
+        """A TorchScript-compatible version of forward.
+
+        Encoders which use additional arguments may want to override
+        this method for TorchScript compatibility.
+        """
+        if torch.jit.is_scripting():
+            return self.forward(
+                src_tokens=net_input["src_tokens"],
+                src_lengths=net_input["src_lengths"],
+            )
+        else:
+            return self.forward_non_torchscript(net_input)
+
     def forward(self, src_tokens, src_lengths=None, **kwargs):
         layer_idx = -1
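
Why `forward_torchscript` exists at all: TorchScript cannot compile the encoder's flexible `forward(self, src_tokens, src_lengths=None, **kwargs)`, so under scripting the call is routed through a fixed `(src_tokens, src_lengths)` signature, while eager execution keeps the kwargs-based path via fairseq's `forward_non_torchscript`. A minimal toy module demonstrating the fixed-signature dispatch (not the real encoder):

```python
import torch
import torch.nn as nn
from typing import Dict
from torch import Tensor

class TinyEncoder(nn.Module):
    """Toy module showing the scripting-aware dispatch pattern."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, src_tokens: Tensor, src_lengths: Tensor) -> Tensor:
        # src_lengths is unused in this toy; kept to mirror the fixed signature.
        return self.proj(src_tokens)

    @torch.jit.export
    def forward_torchscript(self, net_input: Dict[str, Tensor]) -> Tensor:
        # Under scripting only this fixed-signature path is compiled;
        # flexible **kwargs arguments are not scriptable.
        return self.forward(net_input["src_tokens"], net_input["src_lengths"])

enc = torch.jit.script(TinyEncoder())
out = enc.forward_torchscript({
    "src_tokens": torch.randn(2, 3, 4),
    "src_lengths": torch.tensor([3, 3]),
})
print(out.shape)  # torch.Size([2, 3, 4])
```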
@@ -1762,12 +1775,27 @@ class S2TTransformerEncoder(FairseqEncoder):
         if self.history is not None:
             self.history.clean()

+        tgt_lang_idx = kwargs.get("tgt_lang_idx", None)
+        has_add_lang_tag = False
+
         # (B, T, D) -> (T, B, D)
         x = src_tokens.transpose(0, 1)
         input_lengths = src_lengths
         org_bsz = x.size(1)

-        if (
-            self.mixup
-            and layer_idx == mixup_layer
-        ):
+        if tgt_lang_idx is not None:
+            assert self.embed_tokens is not None
+            tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
+            if mixup is not None:
+                pass
+            x = torch.cat((tgt_lang_embed, x), 0)
+            input_lengths += 1
+            has_add_lang_tag = True
+
+        if (
+            (self.training or self.mixup_infer)
+            and self.mixup
+            and layer_idx == mixup_layer
@@ -1787,6 +1815,15 @@ class S2TTransformerEncoder(FairseqEncoder):
         x, input_lengths = self.subsample(x, input_lengths)
         self.show_debug(x, "x after subsampling")

+        # if tgt_lang_idx is not None and False:
+        if tgt_lang_idx is not None and not has_add_lang_tag:
+            assert self.embed_tokens is not None
+            tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
+            if mixup is not None:
+                pass
+            x = torch.cat((tgt_lang_embed, x), 0)
+            input_lengths += 1
+
         encoder_padding_mask = lengths_to_padding_mask(input_lengths)
         if encoder_padding_mask.size(1) < x.size(0):
             bsz = encoder_padding_mask.size(0)
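
The heart of the commit is this prepend operation: a learned embedding for the target-language tag is concatenated as an extra frame at the front of the time-first feature tensor, every sequence length is bumped by one, and the padding mask is rebuilt from the new lengths (`has_add_lang_tag` just prevents doing it twice, before and after subsampling). A self-contained sketch of the operation; the shapes and the embedding table are made up, and `lengths_to_padding_mask` is re-implemented inline rather than imported from fairseq:

```python
import torch
import torch.nn as nn

def lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    """True at padded positions; inline stand-in for fairseq's helper."""
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

B, T, D, vocab = 2, 5, 8, 16
embed_tokens = nn.Embedding(vocab, D)           # toy embedding table
x = torch.randn(T, B, D)                        # features, time-first
input_lengths = torch.tensor([5, 3])
tgt_lang_idx = torch.tensor([7, 9])             # one tag id per sentence

# Prepend the tag embedding as an extra "frame" at t=0.
tgt_lang_embed = embed_tokens(tgt_lang_idx).unsqueeze(0)   # (1, B, D)
x = torch.cat((tgt_lang_embed, x), 0)                      # (T+1, B, D)
input_lengths = input_lengths + 1

mask = lengths_to_padding_mask(input_lengths)
print(x.shape, mask.shape)  # torch.Size([6, 2, 8]) torch.Size([2, 6])
```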
@@ -473,8 +473,12 @@ class MultiheadAttention(nn.Module):
                 bsz, self.num_heads, tgt_len, src_len
             ).transpose(1, 0)

-            entropy = Categorical(weights).entropy()
-            # mean_entropy = entropy.mean([1, 2])
+            entropy = Categorical(weights).entropy()
+            # length = torch.log(torch.Tensor([weights.size(2)])).to(weights.device).unsqueeze(0).unsqueeze(0)
+            # entropy = entropy / length
+            # mean_entropy = entropy.mean([0, 1]) / length
+            mean_entropy = entropy.mean()
             if self.entropy_num == 0:
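
On the entropy bookkeeping in `MultiheadAttention`: `Categorical(weights)` interprets the last dimension (source positions) of the attention weights as a probability distribution, so `.entropy()` yields one value per (head, batch, target position) and `.mean()` collapses it to a scalar. The commented-out lines record a discarded normalization by log(src_len), the entropy of a uniform distribution over source positions. A standalone sketch (shapes are illustrative):

```python
import torch
from torch.distributions import Categorical

num_heads, bsz, tgt_len, src_len = 4, 2, 3, 7
# Toy attention weights: softmax guarantees each row is a distribution.
weights = torch.softmax(torch.randn(num_heads, bsz, tgt_len, src_len), dim=-1)

# Categorical reads the last dim as event probabilities, so entropy()
# returns one value per (head, batch, target position).
entropy = Categorical(probs=weights).entropy()
print(entropy.shape)            # torch.Size([4, 2, 3])

mean_entropy = entropy.mean()   # scalar summary, as in the commit

# The normalization the commit left commented out: dividing by
# log(src_len), the entropy of a uniform distribution, maps to [0, 1].
normalized = entropy / torch.log(torch.tensor(float(src_len)))
print(mean_entropy.item(), normalized.max().item() <= 1.0)
```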