Commit 8b50c392 by xuchen

fix the bugs

parent 244e506e
@@ -44,10 +44,10 @@ lcrm=1
 tokenizer=0
 use_specific_dict=1
-specific_prefix=asr5k_st10k
-specific_dir=${root_dir}/data/${dataset}/st_lcrm_asr
-src_vocab_prefix=spm_unigram5000_asr
-tgt_vocab_prefix=spm_unigram10000_st
+specific_prefix=unified
+specific_dir=${root_dir}/data/${dataset}/vocab
+src_vocab_prefix=spm_en
+tgt_vocab_prefix=spm_zh
 org_data_dir=${root_dir}/data/${dataset}
 data_dir=${root_dir}/data/${dataset}/mt
......
arch: s2t_dual
asr-encoder: pds
mt-encoder-layers: 30
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 1000
lr: 5e-4
#lr: 1e-5
adam_betas: (0.9,0.98)
criterion: join_speech_and_text_loss
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.15
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 18
#text-encoder-layers: 30
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#acoustic-encoder: pds
#adapter: league
encoder-embed-dim: 512
#ctc-layer: 12
pds-stages: 4
pds-layers: 6_3_3_6
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt
load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
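
Note: configs like the one above are flat key: value files. In this style of repo the run scripts flatten each pair into a fairseq-train flag; a minimal sketch of that convention (the helper and its exact behavior are illustrative, not part of this commit):

# Illustrative helper, not part of this commit: turn a flat "key: value"
# training config like the one above into fairseq-train command-line flags.
def config_to_flags(path):
    flags = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):  # "#lr: 1e-5" stays disabled
                continue
            key, value = (s.strip() for s in line.split(":", 1))
            if value == "True":  # boolean switches become bare flags
                flags.append("--" + key)
            else:
                flags.append("--" + key + "=" + value)
    return flags

# e.g. ["--arch=s2t_dual", "--asr-encoder=pds", "--mt-encoder-layers=30", ...]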
arch: s2t_sate
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 1000
lr: 5e-4
#lr: 1e-5
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
subsampling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.15
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 18
text-encoder-layers: 30
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
acoustic-encoder: pds
adapter: league
encoder-embed-dim: 512
#ctc-layer: 12
pds-stages: 4
pds-layers: 6_3_3_6
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt
load-pretrained-text-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
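
Note: the PDS settings in the config above are internally consistent: each per-stage list has pds-stages entries, the stage layers sum to encoder-layers, and the stage ratios multiply to the overall temporal downsampling. A quick sanity-check sketch (the model code enforces its own constraints):

pds_stages = 4
pds_layers = [int(n) for n in "6_3_3_6".split("_")]
pds_ratios = [int(n) for n in "2_2_1_2".split("_")]

assert len(pds_layers) == len(pds_ratios) == pds_stages
assert sum(pds_layers) == 18  # matches encoder-layers: 18

downsampling = 1
for r in pds_ratios:
    downsampling *= r
print(downsampling)  # 8, i.e. 8x temporal reduction across the encoder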
@@ -45,10 +45,10 @@ tokenizer=0
 use_raw_audio=0
 use_specific_dict=1
-specific_prefix=asr
-specific_dir=${root_dir}/data/${dataset}/asr
-asr_vocab_prefix=spm_unigram5000_asr
-st_vocab_prefix=
+specific_prefix=unified
+specific_dir=${root_dir}/data/${dataset}/vocab
+asr_vocab_prefix=spm_en
+st_vocab_prefix=spm_zh
 org_data_dir=${root_dir}/data/${dataset}
 data_dir=${root_dir}/data/${dataset}/st
......
 arch: s2t_dual
-#pds-ctc: 0_1_1_0
-#intermedia-adapter: league
-#intermedia-ctc-weight: 0.1
-#encoder-attention-type: reduced
-#pds-attn-ds-ratios: 4_2_2_1
-#attention-reduced-method: pool
-#attention-reduced-q: True
-#inter-mixup: True
-#inter-mixup-layer: 0
-#inter-mixup-beta: 0.5
 asr-encoder: sate
-mt-encoder-layers: 3
-mt-encoder-dim: 256
+mt-encoder-layers: 6
+mt-encoder: transformer
+encoder-drop-net: True
+encoder-drop-net-prob: 0.8
 encoder-embed-dim: 256
 pds-stages: 4
-#ctc-layer: 15
-encoder-layers: 6
-pds-layers: 2_1_1_2
+#ctc-layer: 12
+pds-layers: 3_3_3_3
 pds-ratios: 2_2_1_2
 pds-fusion: True
 pds-fusion-method: all_conv
-pds-embed-dims: 192_256_256_384
+pds-embed-dims: 256_256_256_256
 pds-ds-method: conv
 pds-embed-norm: True
 pds-position-embed: 1_1_1_1
 pds-kernel-sizes: 5_5_5_5
-pds-ffn-ratios: 8_8_8_4
-pds-attn-heads: 4_4_4_6
+pds-ffn-ratios: 8_8_8_8
+pds-attn-heads: 4_4_4_4
 share-decoder-input-output-embed: True
 optimizer: adam
@@ -42,17 +30,22 @@ warmup-updates: 10000
 lr: 2e-3
 adam_betas: (0.9,0.98)
-ctc-weight: 0.3
 criterion: join_speech_and_text_loss
+ctc-weight: 0.3
 label_smoothing: 0.1
 dropout: 0.1
 activation-fn: relu
+encoder-ffn-embed-dim: 2048
+encoder-layers: 12
 decoder-layers: 6
+encoder-attention-heads: 4
 decoder-embed-dim: 256
 decoder-ffn-embed-dim: 2048
 decoder-attention-heads: 4
 #load-pretrained-encoder-from:
-#load-pretrained-decoder-from:
+#load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/0225_st_purectc_pds_base_8_baseline_topctc/avg_10_checkpoint.pt
+#load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
+#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
\ No newline at end of file
@@ -9,10 +9,12 @@ lr: 1e-3
 adam_betas: (0.9,0.997)
 criterion: label_smoothed_cross_entropy_with_ctc
-ctc-weight: 0.3
-intermedia-ctc-layers: 2,4
 label_smoothing: 0.1
+ctc-weight: 0.2
+intermedia-ctc-weight: 0.1
+intermedia-ctc-layers: 2,4
 dropout: 0.1
 attention-dropout: 0.1
 activation-dropout: 0.1
......
-ctc-weight: 0.2
-intermedia-ctc-weight: 0.1
+#ctc-weight: 0.2
+intermedia-ctc-weight: 0.3
 intermedia-ctc-layers: 2,4
 #target-ctc-weight: 0.3
......
arch: s2t_dual
asr-encoder: pds
mt-encoder-layers: 6
mt-encoder: transformer
encoder-drop-net: True
encoder-drop-net-prob: 0.5
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: join_speech_and_text_loss
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/0225_st_purectc_pds_base_8_baseline_topctc/avg_10_checkpoint.pt
#load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
\ No newline at end of file
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
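
Note: post-process: sentencepiece in the config above detokenizes hypotheses by joining subwords and turning the U+2581 marker back into spaces; fairseq's post_process utility implements essentially this logic:

# Sketch of sentencepiece post-processing applied to a hypothesis string.
def spm_post_process(sentence: str) -> str:
    return sentence.replace(" ", "").replace("\u2581", " ").strip()

print(spm_post_process("\u2581he llo \u2581wor ld"))  # -> "hello world"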
@@ -170,9 +170,10 @@ class CtcCriterion(FairseqCriterion):
             tokens = sample["transcript"]["tokens"]
         else:
             tokens = sample["target"]
-        # if "ctc_padding_mask" in net_output:
-        #     non_padding_mask = ~net_output["ctc_padding_mask"][0]
-        # else:
+        if "ctc_padding_mask" in net_output:
+            non_padding_mask = ~net_output["ctc_padding_mask"][0]
+        else:
+            non_padding_mask = ~net_output["encoder_padding_mask"][0]
         # non_padding_mask = ~net_output["encoder_padding_mask"][0]

         mixup = False
@@ -182,7 +183,6 @@ class CtcCriterion(FairseqCriterion):
             mixup_idx1 = net_output["mixup"]["index1"]
             mixup_idx2 = net_output["mixup"]["index2"]

-        non_padding_mask = ~net_output["encoder_padding_mask"][0]
         input_lengths = non_padding_mask.long().sum(-1)

         pad_mask = (tokens != self.pad_idx) & (
@@ -349,6 +349,7 @@ class CtcCriterion(FairseqCriterion):
                self.ctc_weight * ctc_loss + \
                self.intermedia_ctc_weight * intermedia_ctc_loss + \
                self.target_ctc_weight * target_ctc_loss + \
+               self.target_intermedia_ctc_weight * target_intermedia_ctc_loss + \
                self.ctc_self_distill_weight * ctc_self_distill_loss + \
                self.ctc_entropy * ctc_entropy
@@ -452,6 +453,9 @@ class CtcCriterion(FairseqCriterion):
         target_ctc_loss_sum = utils.item(
             sum(log.get("target_ctc_loss", 0) for log in logging_outputs)
         )
+        target_intermedia_ctc_loss_sum = utils.item(
+            sum(log.get("target_intermedia_ctc_loss", 0) for log in logging_outputs)
+        )
         ctc_self_distill_loss_sum = utils.item(
             sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
         )
@@ -513,6 +517,13 @@ class CtcCriterion(FairseqCriterion):
                 sample_size,
                 round=3,
             )
+        if target_intermedia_ctc_loss_sum > 0:
+            metrics.log_scalar(
+                "target_intermedia_ctc_loss",
+                target_intermedia_ctc_loss_sum / sample_size / math.log(2),
+                sample_size,
+                round=3,
+            )
         if ctc_self_distill_loss_sum > 0:
             metrics.log_scalar(
......
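Note: the hunks above add a fourth CTC-style term to the criterion, intermediate CTC on the target-side text encoder, with its own weight, logging sum, and metric. The overall objective now combines the terms as sketched here (key names mirror the CtcCriterion fields in the diff):

def combined_ctc_loss(w, losses):
    """Sketch: w and losses are dicts keyed like the CtcCriterion fields."""
    return (w["ctc"] * losses["ctc"]
            + w["intermedia_ctc"] * losses["intermedia_ctc"]
            + w["target_ctc"] * losses["target_ctc"]
            + w["target_intermedia_ctc"] * losses["target_intermedia_ctc"]  # new term
            + w["ctc_self_distill"] * losses["ctc_self_distill"]
            + w["ctc_entropy"] * losses["entropy"])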
@@ -835,13 +835,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
                 if args.intermedia_adapter == "shrink":
                     strategy = getattr(args, "ctc_compress_strategy", "avg")
                 adapter = Adapter(embed_dim, args.intermedia_adapter,
-                                  task.source_dictionary, strategy=strategy)
+                                  len(task.source_dictionary), strategy=strategy)
                 inter_adapter = adapter
             else:
                 adapter = inter_adapter
         else:
             adapter = Adapter(embed_dim, "none",
-                              task.source_dictionary)
+                              len(task.source_dictionary))
     else:
         ctc = None
         adapter = None
@@ -860,10 +860,11 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
             self.fusion_weight = nn.Parameter(torch.Tensor(fusion_stages_num).fill_(1.0))
             self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)

-        self.use_ctc = "sate" in args.arch or \
-                       (getattr(args, "criterion", "") == "ctc") or \
-                       (("ctc" in getattr(args, "criterion", "")) and
-                        (getattr(args, "ctc_weight", False) > 0))
+        # self.use_ctc = "sate" in args.arch or \
+        #                (getattr(args, "criterion", "") == "ctc") or \
+        #                (("ctc" in getattr(args, "criterion", "")) and
+        #                 (getattr(args, "ctc_weight", False) > 0))
+        self.use_ctc = "sate" in args.arch or (getattr(args, "ctc_weight", 0) > 0)
         if self.use_ctc:
             # self.ctc_layer = (args.ctc_layer + self.layers) % self.layers
             # self.ctc_layer = self.layers if self.ctc_layer == 0 else self.ctc_layer
......
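Note: the use_ctc gate above is simplified. Instead of also matching the criterion name, the encoder now builds its CTC head for SATE architectures or whenever a positive ctc-weight is configured. The two predicates side by side, as a sketch:

def use_ctc_old(args):
    crit = getattr(args, "criterion", "")
    return ("sate" in args.arch
            or crit == "ctc"
            or ("ctc" in crit and getattr(args, "ctc_weight", False) > 0))

def use_ctc_new(args):
    return "sate" in args.arch or getattr(args, "ctc_weight", 0) > 0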
@@ -143,7 +143,7 @@ class S2TDualModel(FairseqEncoderDecoderModel):
         parser.add_argument(
             "--mt-encoder-layers",
             default=6,
-            type=str,
+            type=int,
             help="the layers of the MT encoder",
         )
         parser.add_argument(
@@ -175,6 +175,18 @@ class S2TDualModel(FairseqEncoderDecoderModel):
             help="mix the two input with any probability",
         )
+        parser.add_argument(
+            "--load-pretrained-asr-encoder-from",
+            type=str,
+            metavar="STR",
+            help="model to take asr encoder weights from (for initialization)",
+        )
+        parser.add_argument(
+            "--load-pretrained-mt-encoder-from",
+            type=str,
+            metavar="STR",
+            help="model to take mt encoder weights from (for initialization)",
+        )
         pass

     @classmethod
@@ -190,7 +202,7 @@ class S2TDualModel(FairseqEncoderDecoderModel):
                 f"{args.load_pretrained_encoder_from}"
             )
         if getattr(args, "load_pretrained_asr_encoder_from", None):
-            encoder = checkpoint_utils.load_pretrained_component_from_model(
+            encoder.asr_encoder = checkpoint_utils.load_pretrained_component_from_model(
                 component=encoder.asr_encoder, checkpoint=args.load_pretrained_asr_encoder_from, strict=False
             )
             logger.info(
@@ -198,7 +210,7 @@ class S2TDualModel(FairseqEncoderDecoderModel):
                 f"{args.load_pretrained_asr_encoder_from}"
             )
         if getattr(args, "load_pretrained_mt_encoder_from", None):
-            encoder = checkpoint_utils.load_pretrained_component_from_model(
+            encoder.mt_encoder = checkpoint_utils.load_pretrained_component_from_model(
                 component=encoder.mt_encoder, checkpoint=args.load_pretrained_mt_encoder_from, strict=False
             )
             logger.info(
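
Note: the two loading hunks above fix a real bug. checkpoint_utils.load_pretrained_component_from_model returns the component it just loaded, so assigning the result to encoder rebound the whole dual encoder to one sub-encoder; assigning to encoder.asr_encoder / encoder.mt_encoder keeps the wrapper intact. A sketch of the pattern (build_dual_encoder and ckpt are placeholders, not names from this diff):

encoder = build_dual_encoder(args)  # placeholder: has .asr_encoder and .mt_encoder

# Buggy: rebinds the name `encoder` to the ASR branch; the MT branch and
# the dual wrapper are lost for the rest of build_model().
# encoder = checkpoint_utils.load_pretrained_component_from_model(
#     component=encoder.asr_encoder, checkpoint=ckpt, strict=False)

# Fixed: update only the ASR branch, keep the dual encoder object.
encoder.asr_encoder = checkpoint_utils.load_pretrained_component_from_model(
    component=encoder.asr_encoder, checkpoint=ckpt, strict=False)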
@@ -314,23 +326,96 @@ class S2TDualEncoder(FairseqEncoder):
         else:
             logger.error("Unsupported ASR architecture: %s." % asr_encoder_type)

+        attn_type = args.encoder_attention_type
+        setattr(args, "encoder_layers", args.mt_encoder_layers)
+        setattr(args, "encoder_attention_type", "selfattn")
         self.mt_encoder = TransformerS2Encoder(args, task.source_dictionary, embed_tokens)
+        setattr(args, "encoder_attention_type", attn_type)

     def forward(self, speech_src_tokens, speech_src_lengths, text_src_tokens, text_src_lengths, **kwargs):
         asr_encoder_out = self.asr_encoder(speech_src_tokens, speech_src_lengths)
-        ctc_logit = asr_encoder_out["ctc_logit"]
         encoder_representation = asr_encoder_out["encoder_out"][0]
         encoder_padding_mask = asr_encoder_out["encoder_padding_mask"][0]

         encoder_out = self.mt_encoder(text_src_tokens, text_src_lengths,
                                       encoder_representation, encoder_padding_mask)
-        encoder_out["ctc_logit"] = ctc_logit
+        encoder_out["ctc_logit"] = asr_encoder_out["ctc_logit"]
+        encoder_out["ctc_padding_mask"] = asr_encoder_out["encoder_padding_mask"]

         return encoder_out

+    def forward_torchscript(self, net_input: Dict[str, Tensor]):
+        speech_src_tokens = net_input["src_tokens"]
+        speech_src_lengths = net_input["src_lengths"]
+        text_src_tokens = net_input["text_src_tokens"]
+        text_src_lengths = net_input["text_src_lengths"]
+        encoder_out = self.forward(speech_src_tokens, speech_src_lengths, text_src_tokens, text_src_lengths)
+        return encoder_out

     def reorder_encoder_out(self, encoder_out, new_order):
-        self.mt_encoder.reorder_encoder_out(encoder_out, new_order)
-        return
+        """
+        Reorder encoder output according to *new_order*.
+
+        Args:
+            encoder_out: output from the ``forward()`` method
+            new_order (LongTensor): desired order
+
+        Returns:
+            *encoder_out* rearranged according to *new_order*
+        """
+        if len(encoder_out["encoder_out"]) == 0:
+            new_encoder_out = []
+        else:
+            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
+        if len(encoder_out["encoder_padding_mask"]) == 0:
+            new_encoder_padding_mask = []
+        else:
+            new_encoder_padding_mask = [
+                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
+            ]
+        if len(encoder_out["encoder_out_s2"]) == 0:
+            new_encoder_out_s2 = []
+        else:
+            new_encoder_out_s2 = [encoder_out["encoder_out_s2"][0].index_select(1, new_order)]
+        if len(encoder_out["encoder_padding_mask_s2"]) == 0:
+            new_encoder_padding_mask_s2 = []
+        else:
+            new_encoder_padding_mask_s2 = [
+                encoder_out["encoder_padding_mask_s2"][0].index_select(0, new_order)
+            ]
+        if len(encoder_out["encoder_embedding"]) == 0:
+            new_encoder_embedding = []
+        else:
+            new_encoder_embedding = [
+                encoder_out["encoder_embedding"][0].index_select(0, new_order)
+            ]
+        if len(encoder_out["src_tokens"]) == 0:
+            src_tokens = []
+        else:
+            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
+        if len(encoder_out["src_lengths"]) == 0:
+            src_lengths = []
+        else:
+            src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
+
+        encoder_states = encoder_out["encoder_states"]
+        if len(encoder_states) > 0:
+            for idx, state in enumerate(encoder_states):
+                encoder_states[idx] = state.index_select(1, new_order)
+
+        return {
+            "encoder_out": new_encoder_out,  # T x B x C
+            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
+            "encoder_out_s2": new_encoder_out_s2,  # T x B x C
+            "encoder_padding_mask_s2": new_encoder_padding_mask_s2,  # B x T
+            "encoder_embedding": new_encoder_embedding,  # B x T x C
+            "encoder_states": encoder_states,  # List[T x B x C]
+            "src_tokens": src_tokens,  # B x T
+            "src_lengths": src_lengths,  # B x 1
+        }

 @register_model_architecture(model_name="s2t_dual", arch_name="s2t_dual")
......
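Note: the new reorder_encoder_out replaces a no-op delegation and matters for beam search, which tiles and reorders the batch between decoding steps. Each cached tensor must be index-selected along its batch axis: dim 1 for T x B x C outputs, dim 0 for B x T masks. A toy sketch of the axis conventions:

import torch

T, B, C, beam = 7, 2, 4, 3
enc = torch.randn(T, B, C)                  # "encoder_out": T x B x C
mask = torch.zeros(B, T, dtype=torch.bool)  # "encoder_padding_mask": B x T

new_order = torch.arange(B).repeat_interleave(beam)  # [0, 0, 0, 1, 1, 1]

print(enc.index_select(1, new_order).shape)   # torch.Size([7, 6, 4])
print(mask.index_select(0, new_order).shape)  # torch.Size([6, 7])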
@@ -222,7 +222,7 @@ class TextEncoder(FairseqEncoder):
         elif args.intermedia_adapter == "league":
             strategy = getattr(args, "intermedia_distribution_cutoff", None)
         self.adapter = Adapter(embed_dim, args.intermedia_adapter,
-                               dictionary, embed_tokens=embed_tokens,
+                               len(dictionary), embed_tokens=embed_tokens,
                                strategy=strategy)
         self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
         self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
@@ -301,7 +301,7 @@ class S2TSATEEncoder(FairseqEncoder):
         self.adapter = Adapter(args.encoder_embed_dim,
                                args.adapter,
-                               task.source_dictionary,
+                               len(task.source_dictionary),
                                decoder_embed_tokens if task.source_dictionary == task.target_dictionary else None,
                                strategy=strategy)
@@ -352,13 +352,14 @@ class S2TSATEEncoder(FairseqEncoder):
             self.history.push(x)

-        x, target_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
+        x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)

         return {
             "encoder_out": [x],  # T x B x C
             "ctc_logit": [ctc_logit],  # T x B x C
             "intermedia_ctc_logits": acoustic_encoder_out.get("intermedia_ctc_logits", []),  # B x T x C
-            "target_ctc_logits": target_ctc_logits,  # B x T x C
+            "target_ctc_logit": target_ctc_logit,  # B x T x C
+            "target_intermedia_ctc_logits": target_intermedia_ctc_logits,  # B x T x C
             "ctc_padding_mask": [ctc_padding_mask],  # B x T
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
             "encoder_embedding": [],  # B x T x C
......
@@ -597,7 +597,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         elif args.intermedia_adapter == "league":
             strategy = getattr(args, "intermedia_distribution_cutoff", None)
         self.adapter = Adapter(dim, args.intermedia_adapter,
-                               task.source_dictionary, strategy=strategy)
+                               len(task.source_dictionary), strategy=strategy)
         self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
         self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
......
@@ -295,42 +295,6 @@ class TransformerModel(FairseqEncoderDecoderModel):
             action='store_true',
             help="use squeeze and excitation method",
         )
-        # CTC
-        parser.add_argument(
-            "--ctc-layer",
-            type=int,
-            help="ctc layers for target sentence",
-        )
-        parser.add_argument(
-            "--intermedia-ctc-layers",
-            default=None,
-            type=str,
-            help="the position of the ctc loss, separated by comma ",
-        )
-        parser.add_argument(
-            "--intermedia-adapter",
-            default="none",
-            type=str,
-            help="type of intermedia adapter",
-        )
-        parser.add_argument(
-            "--intermedia-distribution-cutoff",
-            default=None,
-            type=int,
-            help="cutoff of the distribution",
-        )
-        parser.add_argument(
-            "--intermedia-drop-prob",
-            default=0,
-            type=float,
-            help="probability of dropping the followed layers",
-        )
-        parser.add_argument(
-            "--intermedia-temperature",
-            default=1,
-            type=float,
-            help="temperature of the intermedia ctc probability",
-        )
         # fmt: on

     @classmethod
@@ -571,47 +535,6 @@ class TransformerEncoder(FairseqEncoder):
         else:
             self.history = None

-        # CTC
-        self.use_ctc = getattr(args, "ctc_weight", 0) > 0
-        if self.use_ctc:
-            self.ctc_layer = args.ctc_layer
-            self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
-            if self.inter_ctc:
-                logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
-            self.ctc = CTC(embed_dim,
-                           dictionary_size=decoder_embed_tokens.num_embeddings,
-                           dropout=args.dropout,
-                           need_layernorm=True if self.inter_ctc else False)
-            self.ctc.ctc_projection.weight = embed_tokens.weight
-
-        self.intermedia_ctc_layers = []
-        if args.intermedia_ctc_layers is not None:
-            intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
-            for layer_idx in intermedia_ctc_layers:
-                layer_idx = int(layer_idx)
-                if layer_idx <= 0:
-                    layer_idx += args.encoder_layers
-                self.intermedia_ctc_layers.append(layer_idx)
-                logger.info("Intermedia CTC loss in layer %d" % layer_idx)
-            if not self.use_ctc:
-                self.ctc = CTC(embed_dim,
-                               dictionary_size=decoder_embed_tokens.num_embeddings,
-                               dropout=args.dropout)
-                self.ctc.ctc_projection.weight = embed_tokens.weight
-
-            strategy = None
-            if args.intermedia_adapter == "shrink":
-                strategy = getattr(args, "ctc_compress_strategy", None)
-            elif args.intermedia_adapter == "league":
-                strategy = getattr(args, "intermedia_distribution_cutoff", None)
-
-            self.adapter = Adapter(embed_dim, args.intermedia_adapter,
-                                   None, embed_tokens=decoder_embed_tokens, strategy=strategy)
-            self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
-            self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)

     def build_encoder_layer(self, args):
         layer = TransformerEncoderLayer(args)
@@ -732,9 +655,6 @@ class TransformerEncoder(FairseqEncoder):
             self.history.push(x)

         # encoder layers
-        layer_idx = 0
-        ctc_logit = None
-        intermedia_ctc_logits = []
         for layer in self.layers:
             if self.history is not None:
                 x = self.history.pop()
@@ -742,29 +662,10 @@ class TransformerEncoder(FairseqEncoder):
             x = layer(
                 x, encoder_padding_mask=encoder_padding_mask if has_pads else None
             )
-            layer_idx += 1

             if return_all_hiddens:
                 assert encoder_states is not None
                 encoder_states.append(x)

-            # CTC
-            if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
-                ctc_logit = self.ctc(x.clone())
-
-            # Intermedia CTC
-            if layer_idx in self.intermedia_ctc_layers:
-                if self.intermedia_drop_prob > 0:
-                    p = torch.rand(1).uniform_()
-                    if p < self.intermedia_drop_prob:
-                        break
-
-                norm_x = self.layer_norm(x)
-                logit = self.ctc(norm_x)
-                intermedia_ctc_logits.append(logit)
-
-                prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
-                x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)

             if self.history is not None:
                 self.history.push(x)
@@ -774,16 +675,12 @@ class TransformerEncoder(FairseqEncoder):
         if self.layer_norm is not None:
             x = self.layer_norm(x)

-        if self.use_ctc and ctc_logit is None:
-            ctc_logit = self.ctc(x)

         # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
         # `forward` so we use a dictionary instead.
         # TorchScript does not support mixed values so the values are all lists.
         # The empty list is equivalent to None.
         return {
             "encoder_out": [x],  # T x B x C
-            "ctc_logit": [] if ctc_logit is None else [ctc_logit],  # T x B x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
             "encoder_embedding": [encoder_embedding],  # B x T x C
             "encoder_states": encoder_states,  # List[T x B x C]
@@ -1436,12 +1333,6 @@ def base_architecture(args):
     args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
     args.k_only = getattr(args, 'k_only', True)

-    # CTC
-    args.ctc_layer = getattr(args, "ctc_layer", args.encoder_layers)
-    args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
-    args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
-    args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)

 @register_model_architecture("transformer", "transformer_relative")
 def transformer_rpr(args):
......
@@ -57,7 +57,7 @@ class CTCCompressStrategy:

 class Adapter(nn.Module):
-    def __init__(self, dim, adapter_type, dictionary, embed_tokens=None, strategy=None):
+    def __init__(self, dim, adapter_type, dictionary_size, embed_tokens=None, strategy=None):
         super().__init__()

         dim = dim
@@ -71,11 +71,14 @@ class Adapter(nn.Module):
         )

         if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
-            if embed_tokens is None:
-                num_embeddings = len(dictionary)
-                self.embed_adapter = nn.Linear(num_embeddings, dim)  # Embedding(num_embeddings, dim, dictionary.pad())
-            else:
-                self.embed_adapter = embed_tokens
+            self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False)  # reverse for initialization
+            if embed_tokens is not None:
+                self.embed_adapter.weight = embed_tokens.weight
+            # if embed_tokens is None:
+            #     num_embeddings = len(dictionary)
+            #     self.embed_adapter = nn.Linear(num_embeddings, dim)  # Embedding(num_embeddings, dim, dictionary.pad())
+            # else:
+            #     self.embed_adapter = embed_tokens

         if self.adapter_type == "gated_league":
             self.gate_linear = nn.Linear(2 * dim, dim)
@@ -95,45 +98,40 @@ class Adapter(nn.Module):
     def forward(self, x, padding):

         representation, distribution = x
+        distribution = distribution.type_as(representation)
         seq_len, bsz, dim = representation.size()
         org_distribution = distribution
+        distribution = distribution.view(-1, distribution.size(-1))
         lengths = (~padding).long().sum(-1)

         if self.adapter_type == "linear":
             out = self.linear_adapter(representation)

         elif self.adapter_type == "context":
-            distribution = distribution.view(-1, distribution.size(-1))
-            out = torch.mm(
-                distribution, self.embed_adapter.weight.float()
-            ).view(seq_len, bsz, -1).type_as(representation)
+            out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)

         elif self.adapter_type == "league":
             linear_out = self.linear_adapter(representation)
             if self.distribution_cutoff is not None:
-                cutoff = min(int(self.distribution_cutoff), distribution.size(-1) - 1)
-                threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
-                distribution = torch.where(distribution > threshold, distribution, torch.zeros_like(distribution))
+                cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
+                threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
+                distribution = torch.where(
+                    org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution)
+                )
                 distribution = distribution.view(-1, distribution.size(-1))
-            soft_out = torch.mm(
-                distribution, self.embed_adapter.weight.float()
-            ).view(seq_len, bsz, -1).type_as(representation)
+            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
             out = linear_out + soft_out

         elif self.adapter_type == "gated_league":
             linear_out = self.linear_adapter(representation)
-            distribution = distribution.view(-1, distribution.size(-1))
-            soft_out = torch.mm(
-                distribution, self.embed_adapter.weight.float()
-            ).view(seq_len, bsz, -1).type_as(representation)
+            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
             coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
             out = coef * linear_out + (1 - coef) * soft_out

         elif self.adapter_type == "inter_league":
-            distribution = distribution.view(-1, distribution.size(-1))
-            soft_out = torch.mm(
-                distribution, self.embed_adapter.weight.float()
-            ).view(seq_len, bsz, -1).type_as(representation)
+            soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
             out = representation + soft_out

         elif self.adapter_type == "none":
......
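Note: after this change embed_adapter is an nn.Linear(dim, vocab) whose (vocab, dim) weight can be tied to the embedding table ("reverse for initialization"), so projecting a CTC posterior back to hidden size is a single matmul with the weight itself, no transpose. Core of the league path as a sketch (linear_adapter's exact definition is assumed; it is not shown in the diff):

import torch
import torch.nn as nn

T, B, C, V = 5, 2, 8, 100
representation = torch.randn(T, B, C)
distribution = torch.softmax(torch.randn(T, B, V), dim=-1)  # CTC posteriors

embed_adapter = nn.Linear(C, V, bias=False)  # weight: (V, C), tied to embeddings
linear_adapter = nn.Sequential(nn.Linear(C, C), nn.ReLU())  # assumed form

flat = distribution.view(-1, V)                                # (T*B, V)
soft_out = torch.mm(flat, embed_adapter.weight).view(T, B, C)  # (T*B, V) @ (V, C)
out = linear_adapter(representation) + soft_out                # league = linear + soft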
@@ -197,6 +197,12 @@ class SequenceGenerator(nn.Module):
         )
         net_input = sample["net_input"]

+        if "transcript" in sample:
+            text_src_tokens = sample["transcript"]["tokens"]
+            text_src_lengths = sample["transcript"]["lengths"]
+            net_input["text_src_tokens"] = text_src_tokens
+            net_input["text_src_lengths"] = text_src_lengths
+
         if "src_tokens" in net_input:
             src_tokens = net_input["src_tokens"]
             # length of the source text being the character length except EndOfSentence and pad
......
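Note: with the dual encoder consuming both speech and text at inference, the generator now copies the transcript from the sample into net_input, where S2TDualEncoder.forward_torchscript expects text_src_tokens / text_src_lengths. A toy sketch of the handoff (tensor contents are illustrative):

import torch

sample = {
    "net_input": {
        "src_tokens": torch.randn(2, 100, 80),  # speech features, B x T x F
        "src_lengths": torch.tensor([100, 74]),
    },
    "transcript": {
        "tokens": torch.tensor([[5, 9, 2], [7, 2, 1]]),  # text, B x T_txt
        "lengths": torch.tensor([3, 2]),
    },
}

net_input = sample["net_input"]
if "transcript" in sample:
    net_input["text_src_tokens"] = sample["transcript"]["tokens"]
    net_input["text_src_lengths"] = sample["transcript"]["lengths"]
# encoder_out = model.encoder.forward_torchscript(net_input)  # dual inputs present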