Commit 9e958e0c by xuchen

receive multiple input in encoder

parent b4e95869
...@@ -277,7 +277,8 @@ class CtcCriterion(FairseqCriterion): ...@@ -277,7 +277,8 @@ class CtcCriterion(FairseqCriterion):
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
# net_output = model(**sample["net_input"]) # net_output = model(**sample["net_input"])
src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values() src_tokens = sample["net_input"]["src_tokens"]
src_lengths = sample["net_input"]["src_lengths"]
if self.training and getattr(model.encoder, "pae_ground_truth_ratio", 0) != 0: if self.training and getattr(model.encoder, "pae_ground_truth_ratio", 0) != 0:
ctc_alignment_oracle = self.get_ground_truth_alignment(model, sample) ctc_alignment_oracle = self.get_ground_truth_alignment(model, sample)
...@@ -364,9 +365,12 @@ class CtcCriterion(FairseqCriterion): ...@@ -364,9 +365,12 @@ class CtcCriterion(FairseqCriterion):
return oracle, best_aligns_pad, mistake_flag, mistake_ratio return oracle, best_aligns_pad, mistake_flag, mistake_ratio
src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values() src_tokens = sample["net_input"]["src_tokens"]
src_lengths = sample["net_input"]["src_lengths"]
tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)
with torch.no_grad(): with torch.no_grad():
encoder_out = model.encoder(src_tokens, src_lengths) encoder_out = model.encoder(src_tokens, src_lengths,
tgt_lang_idx=tgt_lang_idx)
ctc_logit = None ctc_logit = None
if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) != 0: if "ctc_logit" in encoder_out and len(encoder_out["ctc_logit"]) != 0:
......
...@@ -44,7 +44,10 @@ class JoinSpeechTextLoss( ...@@ -44,7 +44,10 @@ class JoinSpeechTextLoss(
2) the sample size, which is used as the denominator for the gradient 2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training 3) logging outputs to display while training
""" """
speech_tokens, speech_lengths, prev_output_tokens = sample["net_input"].values() speech_tokens = sample["net_input"]["src_tokens"]
speech_lengths = sample["net_input"]["src_lengths"]
prev_output_tokens = sample["net_input"]["prev_output_tokens"]
text_src_tokens = sample["transcript"]["tokens"] text_src_tokens = sample["transcript"]["tokens"]
text_src_lengths = sample["transcript"]["lengths"] text_src_lengths = sample["transcript"]["lengths"]
......
...@@ -79,7 +79,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -79,7 +79,10 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
2) the sample size, which is used as the denominator for the gradient 2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training 3) logging outputs to display while training
""" """
src_tokens, src_lengths, prev_output_tokens = sample["net_input"].values() src_tokens = sample["net_input"]["src_tokens"]
src_lengths = sample["net_input"]["src_lengths"]
prev_output_tokens = sample["net_input"]["prev_output_tokens"]
tgt_lang_idx = sample["net_input"].get("tgt_lang_idx", None)
train_enc_only = False train_enc_only = False
if self.training and self.only_train_enc_prob != 0 and self.ctc_criterion.all_ctc_weight > 0: if self.training and self.only_train_enc_prob != 0 and self.ctc_criterion.all_ctc_weight > 0:
...@@ -101,9 +104,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC( ...@@ -101,9 +104,12 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
with utils.set_torch_seed(seed): with utils.set_torch_seed(seed):
ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample) ctc_alignment_oracle = self.ctc_criterion.get_ground_truth_alignment(model, sample)
encoder_out = model.encoder(src_tokens, src_lengths, encoder_out = model.encoder(src_tokens, src_lengths,
ctc_alignment_oracle=ctc_alignment_oracle) ctc_alignment_oracle=ctc_alignment_oracle,
tgt_lang_idx=tgt_lang_idx)
else: else:
encoder_out = model.encoder(src_tokens=src_tokens, src_lengths=src_lengths) encoder_out = model.encoder(src_tokens=src_tokens,
src_lengths=src_lengths,
tgt_lang_idx=tgt_lang_idx)
net_output = model.decoder( net_output = model.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
......
...@@ -887,6 +887,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -887,6 +887,7 @@ class S2TTransformerEncoder(FairseqEncoder):
def __init__(self, args, task=None, embed_tokens=None): def __init__(self, args, task=None, embed_tokens=None):
super().__init__(None) super().__init__(None)
self.embed_tokens = embed_tokens
dim = args.encoder_embed_dim dim = args.encoder_embed_dim
self.source_dictionary = task.source_dictionary self.source_dictionary = task.source_dictionary
self.target_dictionary = task.target_dictionary self.target_dictionary = task.target_dictionary
...@@ -975,9 +976,7 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -975,9 +976,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.inter_ctc_drop_prob = args.inter_ctc_drop_prob self.inter_ctc_drop_prob = args.inter_ctc_drop_prob
self.share_inter_ctc = getattr(args, "share_inter_ctc", False) self.share_inter_ctc = getattr(args, "share_inter_ctc", False)
self.inter_ctc_layers = [] self.inter_ctc_layers = []
self.use_inter_ctc = False
if args.inter_ctc_layers is not None: if args.inter_ctc_layers is not None:
self.use_inter_ctc = True
self.share_inter_ctc_norm = args.share_inter_ctc_norm self.share_inter_ctc_norm = args.share_inter_ctc_norm
if self.share_inter_ctc_norm: if self.share_inter_ctc_norm:
logger.info( logger.info(
...@@ -1749,6 +1748,20 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1749,6 +1748,20 @@ class S2TTransformerEncoder(FairseqEncoder):
else: else:
return False return False
def forward_torchscript(self, net_input: Dict[str, Tensor], ):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
if torch.jit.is_scripting():
return self.forward(
src_tokens=net_input["src_tokens"],
src_lengths=net_input["src_lengths"],
)
else:
return self.forward_non_torchscript(net_input)
def forward(self, src_tokens, src_lengths=None, **kwargs): def forward(self, src_tokens, src_lengths=None, **kwargs):
layer_idx = -1 layer_idx = -1
...@@ -1762,12 +1775,27 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1762,12 +1775,27 @@ class S2TTransformerEncoder(FairseqEncoder):
if self.history is not None: if self.history is not None:
self.history.clean() self.history.clean()
tgt_lang_idx = kwargs.get("tgt_lang_idx", None)
has_add_lang_tag = False
# (B, T, D) -> (T, B, D) # (B, T, D) -> (T, B, D)
x = src_tokens.transpose(0, 1) x = src_tokens.transpose(0, 1)
input_lengths = src_lengths input_lengths = src_lengths
org_bsz = x.size(1) org_bsz = x.size(1)
if ( if (
self.mixup
and layer_idx == mixup_layer
):
if tgt_lang_idx is not None:
assert self.embed_tokens is not None
tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
if mixup is not None:
pass
x = torch.cat((tgt_lang_embed, x), 0)
input_lengths += 1
has_add_lang_tag = True
if (
(self.training or self.mixup_infer) (self.training or self.mixup_infer)
and self.mixup and self.mixup
and layer_idx == mixup_layer and layer_idx == mixup_layer
...@@ -1787,6 +1815,15 @@ class S2TTransformerEncoder(FairseqEncoder): ...@@ -1787,6 +1815,15 @@ class S2TTransformerEncoder(FairseqEncoder):
x, input_lengths = self.subsample(x, input_lengths) x, input_lengths = self.subsample(x, input_lengths)
self.show_debug(x, "x after subsampling") self.show_debug(x, "x after subsampling")
#if tgt_lang_idx is not None and False:
if tgt_lang_idx is not None and not has_add_lang_tag:
assert self.embed_tokens is not None
tgt_lang_embed = self.embed_tokens(tgt_lang_idx).unsqueeze(0)
if mixup is not None:
pass
x = torch.cat((tgt_lang_embed, x), 0)
input_lengths += 1
encoder_padding_mask = lengths_to_padding_mask(input_lengths) encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if encoder_padding_mask.size(1) < x.size(0): if encoder_padding_mask.size(1) < x.size(0):
bsz = encoder_padding_mask.size(0) bsz = encoder_padding_mask.size(0)
......
...@@ -474,7 +474,11 @@ class MultiheadAttention(nn.Module): ...@@ -474,7 +474,11 @@ class MultiheadAttention(nn.Module):
).transpose(1, 0) ).transpose(1, 0)
entropy = Categorical(weights).entropy() entropy = Categorical(weights).entropy()
# mean_entropy = entropy.mean([1, 2])
# length = torch.log(torch.Tensor([weights.size(2)])).to(weights.device).unsqueeze(0).unsqueeze(0)
# entropy = entropy / length
# mean_entropy = entropy.mean([0, 1]) / length
mean_entropy = entropy.mean() mean_entropy = entropy.mean()
if self.entropy_num == 0: if self.entropy_num == 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论