Commit 8b50c392 by xuchen

fix the bugs

parent 244e506e
......@@ -44,10 +44,10 @@ lcrm=1
tokenizer=0
use_specific_dict=1
specific_prefix=asr5k_st10k
specific_dir=${root_dir}/data/${dataset}/st_lcrm_asr
src_vocab_prefix=spm_unigram5000_asr
tgt_vocab_prefix=spm_unigram10000_st
specific_prefix=unified
specific_dir=${root_dir}/data/${dataset}/vocab
src_vocab_prefix=spm_en
tgt_vocab_prefix=spm_zh
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/mt
......
arch: s2t_dual
asr-encoder: pds
mt-encoder-layers: 30
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 1000
lr: 5e-4
#lr: 1e-5
adam_betas: (0.9,0.98)
criterion: join_speech_and_text_loss
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
subsmapling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.15
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 18
#text-encoder-layers: 30
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#acoustic-encoder: pds
#adapter: league
encoder-embed-dim: 512
#ctc-layer: 12
pds-stages: 4
pds-layers: 6_3_3_6
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt
load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
arch: s2t_sate
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 1000
lr: 5e-4
#lr: 1e-5
adam_betas: (0.9,0.98)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
encoder-normalize-before: True
decoder-normalize-before: True
subsampling-type: conv1d
subsmapling-layers: 2
subsampling-filter: 2048
subsampling-kernel: 5
subsampling-stride: 2
subsampling-norm: none
subsampling-activation: glu
dropout: 0.15
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 18
text-encoder-layers: 30
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
acoustic-encoder: pds
adapter: league
encoder-embed-dim: 512
#ctc-layer: 12
pds-stages: 4
pds-layers: 6_3_3_6
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_384_384_512
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_4_4_4
pds-attn-heads: 4_6_6_8
#load-pretrained-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/iwslt2022/st/0308_lcrm_unified_sate_big_pds_grow_conformer_ctc_pretrain_con/checkpoint_best.pt
load-pretrained-acoustic-encoder-from: /home/xuchen/st/checkpoints/iwslt2022/asr/0308_lcrm_unified_pds_base_8_grow_conformer_ctc_baseline_clamp/avg_10_checkpoint.pt
load-pretrained-text-encoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
load-pretrained-decoder-from: /home/xuchen/st/checkpoints/wmt20/mt/0304_unified_lcrm_tok_deep_baseline/avg_5_checkpoint.pt
......@@ -45,10 +45,10 @@ tokenizer=0
use_raw_audio=0
use_specific_dict=1
specific_prefix=asr
specific_dir=${root_dir}/data/${dataset}/asr
asr_vocab_prefix=spm_unigram5000_asr
st_vocab_prefix=
specific_prefix=unified
specific_dir=${root_dir}/data/${dataset}/vocab
asr_vocab_prefix=spm_en
st_vocab_prefix=spm_zh
org_data_dir=${root_dir}/data/${dataset}
data_dir=${root_dir}/data/${dataset}/st
......
arch: s2t_dual
#pds-ctc: 0_1_1_0
#intermedia-adapter: league
#intermedia-ctc-weight: 0.1
#encoder-attention-type: reduced
#pds-attn-ds-ratios: 4_2_2_1
#attention-reduced-method: pool
#attention-reduced-q: True
#inter-mixup: True
#inter-mixup-layer: 0
#inter-mixup-beta: 0.5
asr-encoder: sate
mt-encoder-layers: 3
mt-encoder-dim: 256
mt-encoder-layers: 6
mt-encoder: transformer
encoder-drop-net: True
encoder-drop-net-prob: 0.8
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 15
encoder-layers: 6
pds-layers: 2_1_1_2
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 192_256_256_384
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_4
pds-attn-heads: 4_4_4_6
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
......@@ -42,17 +30,22 @@ warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
ctc-weight: 0.3
criterion: join_speech_and_text_loss
ctc-weight: 0.3
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
#load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/0225_st_purectc_pds_base_8_baseline_topctc/avg_10_checkpoint.pt
#load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
\ No newline at end of file
......@@ -9,10 +9,12 @@ lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
ctc-weight: 0.3
intermedia-ctc-layers: 2,4
label_smoothing: 0.1
ctc-weight: 0.2
intermedia-ctc-weight: 0.1
intermedia-ctc-layers: 2,4
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
......
ctc-weight: 0.2
intermedia-ctc-weight: 0.1
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
......
arch: s2t_dual
asr-encoder: pds
mt-encoder-layers: 6
mt-encoder: transformer
encoder-drop-net: True
encoder-drop-net-prob: 0.5
encoder-embed-dim: 256
pds-stages: 4
#ctc-layer: 12
pds-layers: 3_3_3_3
pds-ratios: 2_2_1_2
pds-fusion: True
pds-fusion-method: all_conv
pds-embed-dims: 256_256_256_256
pds-ds-method: conv
pds-embed-norm: True
pds-position-embed: 1_1_1_1
pds-kernel-sizes: 5_5_5_5
pds-ffn-ratios: 8_8_8_8
pds-attn-heads: 4_4_4_4
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 10000
lr: 2e-3
adam_betas: (0.9,0.98)
criterion: join_speech_and_text_loss
label_smoothing: 0.1
dropout: 0.1
activation-fn: relu
encoder-ffn-embed-dim: 2048
encoder-layers: 12
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-asr-encoder-from: /home/xuchen/st/checkpoints/mustc/asr/0225_st_purectc_pds_base_8_baseline_topctc/avg_10_checkpoint.pt
#load-pretrained-mt-encoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
#load-pretrained-decoder-from: /home/xuchen/st/checkpoints/mustc/mt/0223_st_small_baseline/avg_10_checkpoint.pt
\ No newline at end of file
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
......@@ -170,10 +170,11 @@ class CtcCriterion(FairseqCriterion):
tokens = sample["transcript"]["tokens"]
else:
tokens = sample["target"]
# if "ctc_padding_mask" in net_output:
# non_padding_mask = ~net_output["ctc_padding_mask"][0]
# else:
# non_padding_mask = ~net_output["encoder_padding_mask"][0]
if "ctc_padding_mask" in net_output:
non_padding_mask = ~net_output["ctc_padding_mask"][0]
else:
non_padding_mask = ~net_output["encoder_padding_mask"][0]
# non_padding_mask = ~net_output["encoder_padding_mask"][0]
mixup = False
if "mixup" in net_output and net_output["mixup"] is not None:
......@@ -182,7 +183,6 @@ class CtcCriterion(FairseqCriterion):
mixup_idx1 = net_output["mixup"]["index1"]
mixup_idx2 = net_output["mixup"]["index2"]
non_padding_mask = ~net_output["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (tokens != self.pad_idx) & (
......@@ -349,6 +349,7 @@ class CtcCriterion(FairseqCriterion):
self.ctc_weight * ctc_loss + \
self.intermedia_ctc_weight * intermedia_ctc_loss + \
self.target_ctc_weight * target_ctc_loss + \
self.target_intermedia_ctc_weight * target_intermedia_ctc_loss + \
self.ctc_self_distill_weight * ctc_self_distill_loss + \
self.ctc_entropy * ctc_entropy
......@@ -452,6 +453,9 @@ class CtcCriterion(FairseqCriterion):
target_ctc_loss_sum = utils.item(
sum(log.get("target_ctc_loss", 0) for log in logging_outputs)
)
target_intermedia_ctc_loss_sum = utils.item(
sum(log.get("target_intermedia_ctc_loss", 0) for log in logging_outputs)
)
ctc_self_distill_loss_sum = utils.item(
sum(log.get("ctc_self_distill_loss", 0) for log in logging_outputs)
)
......@@ -513,6 +517,13 @@ class CtcCriterion(FairseqCriterion):
sample_size,
round=3,
)
if target_intermedia_ctc_loss_sum > 0:
metrics.log_scalar(
"target_intermedia_ctc_loss",
target_intermedia_ctc_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if ctc_self_distill_loss_sum > 0:
metrics.log_scalar(
......
......@@ -835,13 +835,13 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", "avg")
adapter = Adapter(embed_dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy)
len(task.source_dictionary), strategy=strategy)
inter_adapter = adapter
else:
adapter = inter_adapter
else:
adapter = Adapter(embed_dim, "none",
task.source_dictionary)
len(task.source_dictionary))
else:
ctc = None
adapter = None
......@@ -860,10 +860,11 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
self.fusion_weight = nn.Parameter(torch.Tensor(fusion_stages_num).fill_(1.0))
self.fusion_weight.data = self.fusion_weight.data / self.fusion_weight.data.sum(0, keepdim=True)
self.use_ctc = "sate" in args.arch or \
(getattr(args, "criterion", "") == "ctc") or \
(("ctc" in getattr(args, "criterion", "")) and
(getattr(args, "ctc_weight", False) > 0))
# self.use_ctc = "sate" in args.arch or \
# (getattr(args, "criterion", "") == "ctc") or \
# (("ctc" in getattr(args, "criterion", "")) and
# (getattr(args, "ctc_weight", False) > 0))
self.use_ctc = "sate" in args.arch or (getattr(args, "ctc_weight", 0) > 0)
if self.use_ctc:
# self.ctc_layer = (args.ctc_layer + self.layers) % self.layers
# self.ctc_layer = self.layers if self.ctc_layer == 0 else self.ctc_layer
......
......@@ -143,7 +143,7 @@ class S2TDualModel(FairseqEncoderDecoderModel):
parser.add_argument(
"--mt-encoder-layers",
default=6,
type=str,
type=int,
help="the layers of the MT encoder",
)
parser.add_argument(
......@@ -175,6 +175,18 @@ class S2TDualModel(FairseqEncoderDecoderModel):
help="mix the two input with any probability",
)
parser.add_argument(
"--load-pretrained-asr-encoder-from",
type=str,
metavar="STR",
help="model to take asr encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-mt-encoder-from",
type=str,
metavar="STR",
help="model to take mt encoder weights from (for initialization)",
)
pass
@classmethod
......@@ -190,7 +202,7 @@ class S2TDualModel(FairseqEncoderDecoderModel):
f"{args.load_pretrained_encoder_from}"
)
if getattr(args, "load_pretrained_asr_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
encoder.asr_encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.asr_encoder, checkpoint=args.load_pretrained_asr_encoder_from, strict=False
)
logger.info(
......@@ -198,7 +210,7 @@ class S2TDualModel(FairseqEncoderDecoderModel):
f"{args.load_pretrained_asr_encoder_from}"
)
if getattr(args, "load_pretrained_mt_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
encoder.mt_encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder.mt_encoder, checkpoint=args.load_pretrained_mt_encoder_from, strict=False
)
logger.info(
......@@ -314,23 +326,96 @@ class S2TDualEncoder(FairseqEncoder):
else:
logger.error("Unsupported ASR architecture: %s." % asr_encoder_type)
attn_type = args.encoder_attention_type
setattr(args, "encoder_layers", args.mt_encoder_layers)
setattr(args, "encoder_attention_type", "selfattn")
self.mt_encoder = TransformerS2Encoder(args, task.source_dictionary, embed_tokens)
setattr(args, "encoder_attention_type", attn_type)
def forward(self, speech_src_tokens, speech_src_lengths, text_src_tokens, text_src_lengths, **kwargs):
asr_encoder_out = self.asr_encoder(speech_src_tokens, speech_src_lengths)
ctc_logit = asr_encoder_out["ctc_logit"]
encoder_representation = asr_encoder_out["encoder_out"][0]
encoder_padding_mask = asr_encoder_out["encoder_padding_mask"][0]
encoder_out = self.mt_encoder(text_src_tokens, text_src_lengths,
encoder_representation, encoder_padding_mask)
encoder_out["ctc_logit"] = ctc_logit
encoder_out["ctc_logit"] = asr_encoder_out["ctc_logit"]
encoder_out["ctc_padding_mask"] = asr_encoder_out["encoder_padding_mask"]
return encoder_out
def forward_torchscript(self, net_input: Dict[str, Tensor]):
speech_src_tokens = net_input["src_tokens"]
speech_src_lengths = net_input["src_lengths"]
text_src_tokens = net_input["text_src_tokens"]
text_src_lengths = net_input["text_src_lengths"]
encoder_out = self.forward(speech_src_tokens, speech_src_lengths, text_src_tokens, text_src_lengths)
return encoder_out
def reorder_encoder_out(self, encoder_out, new_order):
self.mt_encoder.reorder_encoder_out(encoder_out, new_order)
return
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if len(encoder_out["encoder_out"]) == 0:
new_encoder_out = []
else:
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
if len(encoder_out["encoder_padding_mask"]) == 0:
new_encoder_padding_mask = []
else:
new_encoder_padding_mask = [
encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
]
if len(encoder_out["encoder_out_s2"]) == 0:
new_encoder_out_s2 = []
else:
new_encoder_out_s2 = [encoder_out["encoder_out_s2"][0].index_select(1, new_order)]
if len(encoder_out["encoder_padding_mask_s2"]) == 0:
new_encoder_padding_mask_s2 = []
else:
new_encoder_padding_mask_s2 = [
encoder_out["encoder_padding_mask_s2"][0].index_select(0, new_order)
]
if len(encoder_out["encoder_embedding"]) == 0:
new_encoder_embedding = []
else:
new_encoder_embedding = [
encoder_out["encoder_embedding"][0].index_select(0, new_order)
]
if len(encoder_out["src_tokens"]) == 0:
src_tokens = []
else:
src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
if len(encoder_out["src_lengths"]) == 0:
src_lengths = []
else:
src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_out_s2": new_encoder_out_s2, # T x B x C
"encoder_padding_mask_s2": new_encoder_padding_mask_s2, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": src_tokens, # B x T
"src_lengths": src_lengths, # B x 1
}
@register_model_architecture(model_name="s2t_dual", arch_name="s2t_dual")
......
......@@ -222,7 +222,7 @@ class TextEncoder(FairseqEncoder):
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
dictionary, embed_tokens=embed_tokens,
len(dictionary), embed_tokens=embed_tokens,
strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
......@@ -301,7 +301,7 @@ class S2TSATEEncoder(FairseqEncoder):
self.adapter = Adapter(args.encoder_embed_dim,
args.adapter,
task.source_dictionary,
len(task.source_dictionary),
decoder_embed_tokens if task.source_dictionary == task.target_dictionary else None,
strategy=strategy)
......@@ -352,13 +352,14 @@ class S2TSATEEncoder(FairseqEncoder):
self.history.push(x)
x, target_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
x, target_ctc_logit, target_intermedia_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [ctc_logit], # T x B x C
"intermedia_ctc_logits": acoustic_encoder_out.get("intermedia_ctc_logits", []), # B x T x C
"target_ctc_logits": target_ctc_logits, # B x T x C
"target_ctc_logit": target_ctc_logit, # B x T x C
"target_intermedia_ctc_logits": target_intermedia_ctc_logits, # B x T x C
"ctc_padding_mask": [ctc_padding_mask], # B x T
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
......
......@@ -597,7 +597,7 @@ class S2TTransformerEncoder(FairseqEncoder):
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(dim, args.intermedia_adapter,
task.source_dictionary, strategy=strategy)
len(task.source_dictionary), strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
......
......@@ -295,42 +295,6 @@ class TransformerModel(FairseqEncoderDecoderModel):
action='store_true',
help="use squeeze and excitation method",
)
# CTC
parser.add_argument(
"--ctc-layer",
type=int,
help="ctc layers for target sentence",
)
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
parser.add_argument(
"--intermedia-temperature",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
)
# fmt: on
@classmethod
......@@ -571,47 +535,6 @@ class TransformerEncoder(FairseqEncoder):
else:
self.history = None
# CTC
self.use_ctc = getattr(args, "ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
self.ctc.ctc_projection.weight = embed_tokens.weight
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
layer_idx = int(layer_idx)
if layer_idx <= 0:
layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
if not self.use_ctc:
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
dropout=args.dropout)
self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
None, embed_tokens=decoder_embed_tokens, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args)
......@@ -732,9 +655,6 @@ class TransformerEncoder(FairseqEncoder):
self.history.push(x)
# encoder layers
layer_idx = 0
ctc_logit = None
intermedia_ctc_logits = []
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
......@@ -742,29 +662,10 @@ class TransformerEncoder(FairseqEncoder):
x = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None
)
layer_idx += 1
if return_all_hiddens:
assert encoder_states is not None
encoder_states.append(x)
# CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone())
# Intermedia CTC
if layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0:
p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob:
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
if self.history is not None:
self.history.push(x)
......@@ -774,16 +675,12 @@ class TransformerEncoder(FairseqEncoder):
if self.layer_norm is not None:
x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
# `forward` so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None.
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
......@@ -1436,12 +1333,6 @@ def base_architecture(args):
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", args.encoder_layers)
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
@register_model_architecture("transformer", "transformer_relative")
def transformer_rpr(args):
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, Dict, List, Optional, Tuple
import logging
import torch
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules.speech_to_text import Adapter, CTC
from fairseq.modules import (
AdaptiveSoftmax,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
RelPositionalEncoding,
LegacyRelPositionalEncoding,
TransformerDecoderLayer,
TransformerEncoderLayer,
DynamicLinearCombination
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@register_model("transformer_ctc")
class TransformerCTCModel(FairseqEncoderDecoderModel):
"""
Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
<https://arxiv.org/abs/1706.03762>`_.
Args:
encoder (TransformerEncoder): the encoder
decoder (TransformerDecoder): the decoder
The Transformer model provides the following named architectures and
command-line arguments:
.. argparse::
:ref: fairseq.models.transformer_parser
:prog:
"""
@classmethod
def hub_models(cls):
# fmt: off
def moses_subword(path):
return {
'path': path,
'tokenizer': 'moses',
'bpe': 'subword_nmt',
}
def moses_fastbpe(path):
return {
'path': path,
'tokenizer': 'moses',
'bpe': 'fastbpe',
}
def spm(path):
return {
'path': path,
'bpe': 'sentencepiece',
'tokenizer': 'space',
}
return {
'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'),
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'),
'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'),
'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'),
'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'),
'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'),
'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'),
'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'),
'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'),
'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'),
'transformer.wmt20.en-ta': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz'),
'transformer.wmt20.en-iu.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'),
'transformer.wmt20.en-iu.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'),
'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'),
'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'),
'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'),
}
# fmt: on
def __init__(self, args, encoder, decoder):
super().__init__(encoder, decoder)
self.args = args
self.supports_align_args = True
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use')
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
help='dropout probability after activation in FFN.')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--decoder-output-dim', type=int, metavar='N',
help='decoder output dimension (extra linear layer '
'if different from decoder embed dim')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--layernorm-embedding', action='store_true',
help='add layernorm to embedding')
parser.add_argument('--no-scale-embedding', action='store_true',
help='if True, dont scale embeddings')
parser.add_argument('--checkpoint-activations', action='store_true',
help='checkpoint activations at each layer, which saves GPU '
'memory usage at the cost of some additional compute')
parser.add_argument('--offload-activations', action='store_true',
help='checkpoint activations at each layer, then save to gpu. '
'Sets --checkpoint-activations.')
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
parser.add_argument('--no-cross-attention', default=False, action='store_true',
help='do not perform cross-attention')
parser.add_argument('--cross-self-attention', default=False, action='store_true',
help='perform cross+self-attention')
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for encoder')
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for decoder')
parser.add_argument('--encoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
parser.add_argument('--decoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
# args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
help='iterative PQ quantization noise at training time')
parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
help='block size of quantization noise at training time')
parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
help='scalar quantization noise and scalar quantization at training time')
parser.add_argument(
"--encoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
"rel_pos",
"rel_pos_legacy"
],
help="transformer encoder self-attention layer type"
)
parser.add_argument(
"--decoder-attention-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"relative",
],
help="transformer decoder self-attention layer type"
)
# DLCL parameters
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument('--init-value', type=str, default='avg', choices=['avg', 'one'],
help='how to init the learned weight matrix')
parser.add_argument('--weight-type', type=str, default='scalar',
help='type of learned weight [scalar, scalar_n(n>1), vector]')
parser.add_argument('--encoder-learnable', type=eval, default='True',
help='enable to learn weights for encoder')
parser.add_argument('--decoder-learnable', type=eval, default='True',
help='enable to learn weights for decoder')
parser.add_argument('--normalize-learned-weight', type=eval, default='False',
help='normalize learned weight by softmax')
parser.add_argument('--normalize-embedding', type=eval, default='False',
help='normalize the input of embedding')
parser.add_argument('--history-dropout', type=float, default=0.0, metavar='D',
help='dropout for history output')
parser.add_argument('--history-window-size', type=int, default='-1',
help='how many past layers are considered. -1 means all')
# relative position representation
parser.add_argument('--max-encoder-relative-length', type=int, default=-1,
help='the max encoder relative length')
parser.add_argument('--max-decoder-relative-length', type=int, default=-1,
help='the max decoder relative length')
parser.add_argument('--k-only', default=False, action='store_true',
help='select the relative mode to map relative position information')
# args for loading pre-trained models
parser.add_argument(
"--load-pretrained-encoder-from",
type=str,
metavar="STR",
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--encoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the encoder",
)
parser.add_argument(
"--decoder-freeze-module",
type=str,
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument('--interleave-dropout', default=0, type=float, metavar='D',
help='interleaved dropout probability')
parser.add_argument(
"--squeeze-excitation",
default=False,
action='store_true',
help="use squeeze and excitation method",
)
# CTC
parser.add_argument(
"--ctc-layer",
type=int,
help="ctc layers for target sentence",
)
parser.add_argument(
"--intermedia-ctc-layers",
default=None,
type=str,
help="the position of the ctc loss, separated by comma ",
)
parser.add_argument(
"--intermedia-adapter",
default="none",
type=str,
help="type of intermedia adapter",
)
parser.add_argument(
"--intermedia-distribution-cutoff",
default=None,
type=int,
help="cutoff of the distribution",
)
parser.add_argument(
"--intermedia-drop-prob",
default=0,
type=float,
help="probability of dropping the followed layers",
)
parser.add_argument(
"--intermedia-temperature",
default=1,
type=float,
help="temperature of the intermedia ctc probability",
)
# fmt: on
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if args.encoder_layers_to_keep:
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
if getattr(args, "max_source_positions", None) is None:
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError("--share-all-embeddings requires a joined dictionary")
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path
):
raise ValueError(
"--share-all-embeddings not compatible with --decoder-embed-path"
)
encoder_embed_tokens = cls.build_embedding(
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = cls.build_embedding(
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = cls.build_embedding(
args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
)
if getattr(args, "offload_activations", False):
args.checkpoint_activations = True # offloading implies checkpointing
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens, decoder_embed_tokens)
if getattr(args, "encoder_freeze_module", None):
utils.freeze_parameters(encoder, args.encoder_freeze_module)
logging.info("freeze the encoder module: {}".format(args.encoder_freeze_module))
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
if getattr(args, "decoder_freeze_module", None):
utils.freeze_parameters(decoder, args.decoder_freeze_module)
logging.info("freeze the decoder module: {}".format(args.decoder_freeze_module))
if not args.share_all_embeddings:
encoder = fsdp_wrap(encoder, min_num_params=1e8)
decoder = fsdp_wrap(decoder, min_num_params=1e8)
return cls(args, encoder, decoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = Embedding(num_embeddings, embed_dim, padding_idx)
# if provided, load from preloaded dictionaries
if path:
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
return emb
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens, decoder_embed_tokens=None):
encoder = TransformerCTCEncoder(args, src_dict, embed_tokens, decoder_embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
return encoder
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
decoder = TransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, "no_cross_attention", False),
)
if getattr(args, "load_pretrained_decoder_from", None):
logger.info(
f"loaded pretrained decoder from: "
f"{args.load_pretrained_decoder_from}"
)
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
)
return decoder
# TorchScript doesn't support optional arguments with variable length (**kwargs).
# Current workaround is to add union of all arguments in child classes.
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens,
return_all_hiddens: bool = True,
features_only: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
"""
Run the forward pass for an encoder-decoder model.
Copied from the base class, but without ``**kwargs``,
which are not supported by TorchScript.
"""
encoder_out = self.encoder(
src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
)
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
features_only=features_only,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
src_lengths=src_lengths,
return_all_hiddens=return_all_hiddens,
)
return decoder_out
# Since get_normalized_probs is in the Fairseq Model which is not scriptable,
# I rewrite the get_normalized_probs from Base Class to call the
# helper function in the Base Class.
@torch.jit.export
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
class TransformerCTCEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens, decoder_embed_tokens=None):
self.args = args
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.encoder_layerdrop = args.encoder_layerdrop
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
self.attn_type = getattr(args, "encoder_attention_type", "selfattn")
if self.attn_type == "rel_pos":
self.embed_positions = RelPositionalEncoding(
args.max_source_positions, args.encoder_embed_dim
)
elif self.attn_type in ["rel_selfattn", "rel_pos_legacy"]:
self.embed_positions = LegacyRelPositionalEncoding(
args.encoder_embed_dim, args.dropout, args.max_source_positions
)
elif self.attn_type == "rope":
self.embed_positions = None
else: # Use absolute positional embedding
self.embed_positions = (
PositionalEmbedding(
args.max_source_positions,
embed_dim,
self.padding_idx,
learned=args.encoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
if not args.adaptive_input and args.quant_noise_pq > 0:
self.quant_noise = apply_quant_noise_(
nn.Linear(embed_dim, embed_dim, bias=False),
args.quant_noise_pq,
args.quant_noise_pq_block_size,
)
else:
self.quant_noise = None
if self.encoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
else:
self.layers = nn.ModuleList([])
self.layers.extend(
[self.build_encoder_layer(args) for i in range(args.encoder_layers)]
)
self.num_layers = len(self.layers)
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
if args.use_enc_dlcl:
self.history = DynamicLinearCombination(args, is_encoder=True)
else:
self.history = None
# CTC
self.use_ctc = getattr(args, "ctc_weight", 0) > 0
if self.use_ctc:
self.ctc_layer = args.ctc_layer
self.inter_ctc = True if self.ctc_layer != args.encoder_layers else False
if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
dropout=args.dropout,
need_layernorm=True if self.inter_ctc else False)
self.ctc.ctc_projection.weight = embed_tokens.weight
self.intermedia_ctc_layers = []
if args.intermedia_ctc_layers is not None:
intermedia_ctc_layers = args.intermedia_ctc_layers.split(",")
for layer_idx in intermedia_ctc_layers:
layer_idx = int(layer_idx)
if layer_idx <= 0:
layer_idx += args.encoder_layers
self.intermedia_ctc_layers.append(layer_idx)
logger.info("Intermedia CTC loss in layer %d" % layer_idx)
if not self.use_ctc:
self.ctc = CTC(embed_dim,
dictionary_size=decoder_embed_tokens.num_embeddings,
dropout=args.dropout)
self.ctc.ctc_projection.weight = embed_tokens.weight
strategy = None
if args.intermedia_adapter == "shrink":
strategy = getattr(args, "ctc_compress_strategy", None)
elif args.intermedia_adapter == "league":
strategy = getattr(args, "intermedia_distribution_cutoff", None)
self.adapter = Adapter(embed_dim, args.intermedia_adapter,
decoder_embed_tokens.num_embeddings, embed_tokens=decoder_embed_tokens, strategy=strategy)
self.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
self.intermedia_temperature = getattr(args, "intermedia_temperature", 1)
def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args)
if getattr(args, "checkpoint_activations", False):
offload_to_cpu = getattr(args, "offload_activations", False)
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
layer = fsdp_wrap(layer, min_num_params=1e8)
return layer
def forward_embedding(
self, src_tokens, token_embedding: Optional[torch.Tensor] = None
):
# embed tokens and positions
if token_embedding is None:
token_embedding = self.embed_tokens(src_tokens)
x = embed = self.embed_scale * token_embedding
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
if self.quant_noise is not None:
x = self.quant_noise(x)
return x, embed
def forward(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
return self.forward_scriptable(src_tokens,
src_lengths,
return_all_hiddens,
token_embeddings)
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
# call the helper function from scriptable Subclass.
def forward_scriptable(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any())
if self.history is not None:
self.history.clean()
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
# account for padding while computing the representation
if encoder_padding_mask is not None:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_states = []
if return_all_hiddens:
encoder_states.append(x)
# add emb into history
if self.history is not None:
self.history.push(x)
# encoder layers
layer_idx = 0
ctc_logit = None
intermedia_ctc_logits = []
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
x = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None
)
layer_idx += 1
if return_all_hiddens:
assert encoder_states is not None
encoder_states.append(x)
# CTC
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(x.clone())
# Intermedia CTC
if layer_idx in self.intermedia_ctc_layers:
if self.intermedia_drop_prob > 0:
p = torch.rand(1).uniform_()
if p < self.intermedia_drop_prob:
break
norm_x = self.layer_norm(x)
logit = self.ctc(norm_x)
intermedia_ctc_logits.append(logit)
prob = utils.softmax(logit / self.intermedia_temperature, dim=-1)
x, encoder_padding_mask = self.adapter([x, prob], encoder_padding_mask)
if self.history is not None:
self.history.push(x)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
if self.use_ctc and ctc_logit is None:
ctc_logit = self.ctc(x)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
# `forward` so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None.
return {
"encoder_out": [x], # T x B x C
"ctc_logit": [] if ctc_logit is None else [ctc_logit], # T x B x C
"intermedia_ctc_logits": intermedia_ctc_logits, # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
@torch.jit.export
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if len(encoder_out["encoder_out"]) == 0:
new_encoder_out = []
else:
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
if len(encoder_out["encoder_padding_mask"]) == 0:
new_encoder_padding_mask = []
else:
new_encoder_padding_mask = [
encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
]
if len(encoder_out["encoder_embedding"]) == 0:
new_encoder_embedding = []
else:
new_encoder_embedding = [
encoder_out["encoder_embedding"][0].index_select(0, new_order)
]
if len(encoder_out["src_tokens"]) == 0:
src_tokens = []
else:
src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
if len(encoder_out["src_lengths"]) == 0:
src_lengths = []
else:
src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": src_tokens, # B x T
"src_lengths": src_lengths, # B x 1
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embed_positions is None:
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions)
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
print("deleting {0}".format(weights_key))
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
for i in range(self.num_layers):
# update layer norms
self.layers[i].upgrade_state_dict_named(
state_dict, "{}.layers.{}".format(name, i)
)
version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
class TransformerDecoder(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
self.args = args
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
self._future_mask = torch.empty(0)
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.decoder_layerdrop = args.decoder_layerdrop
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
embed_dim = args.decoder_embed_dim
self.embed_dim = embed_dim
self.output_embed_dim = args.decoder_output_dim
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
if not args.adaptive_input and args.quant_noise_pq > 0:
self.quant_noise = apply_quant_noise_(
nn.Linear(embed_dim, embed_dim, bias=False),
args.quant_noise_pq,
args.quant_noise_pq_block_size,
)
else:
self.quant_noise = None
self.project_in_dim = (
Linear(input_embed_dim, embed_dim, bias=False)
if embed_dim != input_embed_dim
else None
)
self.embed_positions = (
PositionalEmbedding(
self.max_target_positions,
embed_dim,
self.padding_idx,
learned=args.decoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
self.cross_self_attention = getattr(args, "cross_self_attention", False)
if self.decoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
else:
self.layers = nn.ModuleList([])
self.layers.extend(
[
self.build_decoder_layer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
)
self.num_layers = len(self.layers)
self.attn_type = getattr(args, "decoder_attention_type", "selfattn")
if args.decoder_normalize_before and not getattr(
args, "no_decoder_final_norm", False
):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
if args.use_dec_dlcl:
self.history = DynamicLinearCombination(args, is_encoder=False)
else:
self.history = None
self.project_out_dim = (
Linear(embed_dim, self.output_embed_dim, bias=False)
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
else None
)
self.adaptive_softmax = None
self.output_projection = None
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif self.share_input_output_embed:
self.output_projection = nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
self.output_projection.weight = self.embed_tokens.weight
else:
self.output_projection = nn.Linear(
self.output_embed_dim, len(dictionary), bias=False
)
nn.init.normal_(
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
)
self.gather_attn_weight = getattr(args, "gather_attn_weight", False)
#self.gather_attn_weight = True
self.attn_weights = dict()
def build_decoder_layer(self, args, no_encoder_attn=False):
layer = TransformerDecoderLayer(args, no_encoder_attn)
if getattr(args, "checkpoint_activations", False):
offload_to_cpu = getattr(args, "offload_activations", False)
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
layer = fsdp_wrap(layer, min_num_params=1e8)
return layer
def forward(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
full_context_alignment=full_context_alignment,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
return self.extract_features_scriptable(
prev_output_tokens,
encoder_out,
incremental_state,
full_context_alignment,
alignment_layer,
alignment_heads,
)
"""
A scriptable subclass of this class has an extract_features method and calls
super().extract_features, but super() is not supported in torchscript. A copy of
this function is made to be used in the subclass instead.
"""
def extract_features_scriptable(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if self.history is not None:
self.history.clean()
if alignment_layer is None:
alignment_layer = self.num_layers - 1
# embed positions
positions = None
if self.embed_positions is not None:
positions = self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None and self.attn_type != "rel_selfattn":
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# add emb into history
if self.history is not None:
self.history.push(x)
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
mixup = None
if "mixup" in encoder_out and encoder_out["mixup"] is not None:
mixup = encoder_out["mixup"]
coef = mixup["coef"]
idx1 = mixup["index1"]
idx2 = mixup["index2"]
x1 = x[:, idx1]
x2 = x[:, idx2]
x = coef * x1 + (1 - coef) * x2
if self_attn_padding_mask is not None:
pad1 = self_attn_padding_mask[idx1]
pad2 = self_attn_padding_mask[idx2]
self_attn_padding_mask = pad1 + pad2
# decoder layers
avg_attn = None
attn: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
for idx, layer in enumerate(self.layers):
if self.history is not None:
x = self.history.pop()
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn, _ = layer(
x,
encoder_out["encoder_out"][0]
if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
else None,
encoder_out["encoder_padding_mask"][0]
if (
encoder_out is not None
and len(encoder_out["encoder_padding_mask"]) > 0
)
else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer) or self.gather_attn_weight),
need_head_weights=bool((idx == alignment_layer) or self.gather_attn_weight),
pos_emb=positions
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if self.history is not None:
self.history.push(x)
if self.gather_attn_weight:
if avg_attn is None:
avg_attn = layer_attn
else:
avg_attn += layer_attn
if self.gather_attn_weight:
avg_attn = avg_attn / len(self.layers)
attn = avg_attn.mean(0).sum(-2)
attn = torch.reshape(attn, [attn.numel()])
attn = attn // 0.001
attn = attn.int().cpu()
if len(encoder_out["encoder_padding_mask"]) > 0:
mask = encoder_out["encoder_padding_mask"][0]
mask = torch.reshape(mask, [mask.numel()])
else:
mask = None
i = -1
for item in attn:
i += 1
if mask[i]:
continue
idx = int(item) * 0.001
if idx not in self.attn_weights:
self.attn_weights[idx] = 0
self.attn_weights[idx] += 1
elif attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "inner_states": inner_states, "mixup": mixup}
def output_layer(self, features):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
return self.output_projection(features)
else:
return features
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions)
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
if (
self._future_mask.size(0) == 0
or (not self._future_mask.device == tensor.device)
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
)
self._future_mask = self._future_mask.to(tensor)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
if f"{name}.output_projection.weight" not in state_dict:
if self.share_input_output_embed:
embed_out_key = f"{name}.embed_tokens.weight"
else:
embed_out_key = f"{name}.embed_out"
if embed_out_key in state_dict:
state_dict[f"{name}.output_projection.weight"] = state_dict[
embed_out_key
]
if not self.share_input_output_embed:
del state_dict[embed_out_key]
for i in range(self.num_layers):
# update layer norms
layer_norm_map = {
"0": "self_attn_layer_norm",
"1": "encoder_attn_layer_norm",
"2": "final_layer_norm",
}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
if k in state_dict:
state_dict[
"{}.layers.{}.{}.{}".format(name, i, new, m)
] = state_dict[k]
del state_dict[k]
version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m
@register_model_architecture("transformer_ctc", "transformer_ctc_tiny")
def tiny_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 64)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 64)
args.encoder_layers = getattr(args, "encoder_layers", 2)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2)
args.decoder_layers = getattr(args, "decoder_layers", 2)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2)
return base_architecture(args)
@register_model_architecture("transformer_ctc", "transformer_ctc")
def base_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.no_cross_attention = getattr(args, "no_cross_attention", False)
args.cross_self_attention = getattr(args, "cross_self_attention", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
args.encoder_attention_type = getattr(args, "encoder_attention_type", "selfattn")
args.decoder_attention_type = getattr(args, "decoder_attention_type", "selfattn")
# settings for DLCL
args.use_enc_dlcl = getattr(args, "use_enc_dlcl", False)
args.use_dec_dlcl = getattr(args, "use_dec_dlcl", False)
args.init_value = getattr(args, 'init_value', 'avg')
args.weight_type = getattr(args, 'weight_type', 'scalar')
args.encoder_learnable = getattr(args, 'encoder_learnable', True)
args.decoder_learnable = getattr(args, 'decoder_learnable', True)
args.normalize_embed = getattr(args, 'normalize_embed', False)
args.history_dropout = getattr(args, 'history_dropout', 0.0)
args.history_window_size = getattr(args, 'history_window_size', -1)
# settings for RPR
args.max_encoder_relative_length = getattr(args, 'max_encoder_relative_length', -1)
args.max_decoder_relative_length = getattr(args, 'max_decoder_relative_length', -1)
args.k_only = getattr(args, 'k_only', True)
# CTC
args.ctc_layer = getattr(args, "ctc_layer", args.encoder_layers)
args.intermedia_ctc_layers = getattr(args, "intermedia_ctc_layers", None)
args.intermedia_adapter = getattr(args, "intermedia_adapter", None)
args.intermedia_drop_prob = getattr(args, "intermedia_drop_prob", 0)
@register_model_architecture("transformer_ctc", "transformer_ctc_relative")
def transformer_ctc_rpr(args):
args.max_encoder_relative_length = 20
args.max_decoder_relative_length = 20
args.k_only = True
base_architecture(args)
@register_model_architecture("transformer_ctc", "transformer_ctc_iwslt_de_en")
def transformer_ctc_iwslt_de_en(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.decoder_layers = getattr(args, "decoder_layers", 6)
base_architecture(args)
@register_model_architecture("transformer_ctc", "transformer_ctc_wmt_en_de")
def transformer_ctc_wmt_en_de(args):
base_architecture(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture("transformer_ctc", "transformer_ctc_vaswani_wmt_en_de_big")
def transformer_ctc_vaswani_wmt_en_de_big(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.3)
base_architecture(args)
@register_model_architecture("transformer_ctc", "transformer_ctc_vaswani_wmt_en_fr_big")
def transformer_ctc_vaswani_wmt_en_fr_big(args):
args.dropout = getattr(args, "dropout", 0.1)
transformer_ctc_vaswani_wmt_en_de_big(args)
@register_model_architecture("transformer_ctc", "transformer_ctc_wmt_en_de_big")
def transformer_ctc_wmt_en_de_big(args):
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
transformer_ctc_vaswani_wmt_en_de_big(args)
# default parameters used in tensor2tensor implementation
@register_model_architecture("transformer_ctc", "transformer_ctc_wmt_en_de_big_t2t")
def transformer_ctc_wmt_en_de_big_t2t(args):
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_dropout = getattr(args, "activation_dropout", 0.1)
transformer_ctc_vaswani_wmt_en_de_big(args)
......@@ -57,7 +57,7 @@ class CTCCompressStrategy:
class Adapter(nn.Module):
def __init__(self, dim, adapter_type, dictionary, embed_tokens=None, strategy=None):
def __init__(self, dim, adapter_type, dictionary_size, embed_tokens=None, strategy=None):
super().__init__()
dim = dim
......@@ -71,11 +71,14 @@ class Adapter(nn.Module):
)
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
if embed_tokens is None:
num_embeddings = len(dictionary)
self.embed_adapter = nn.Linear(num_embeddings, dim) # Embedding(num_embeddings, dim, dictionary.pad())
else:
self.embed_adapter = embed_tokens
self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization
if embed_tokens is not None:
self.embed_adapter.weight = embed_tokens.weight
# if embed_tokens is None:
# num_embeddings = len(dictionary)
# self.embed_adapter = nn.Linear(num_embeddings, dim) # Embedding(num_embeddings, dim, dictionary.pad())
# else:
# self.embed_adapter = embed_tokens
if self.adapter_type == "gated_league":
self.gate_linear = nn.Linear(2 * dim, dim)
......@@ -95,45 +98,40 @@ class Adapter(nn.Module):
def forward(self, x, padding):
representation, distribution = x
distribution = distribution.type_as(representation)
seq_len, bsz, dim = representation.size()
org_distribution = distribution
distribution = distribution.view(-1, distribution.size(-1))
lengths = (~padding).long().sum(-1)
if self.adapter_type == "linear":
out = self.linear_adapter(representation)
elif self.adapter_type == "context":
distribution = distribution.view(-1, distribution.size(-1))
out = torch.mm(
distribution, self.embed_adapter.weight.float()
).view(seq_len, bsz, -1).type_as(representation)
out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
elif self.adapter_type == "league":
linear_out = self.linear_adapter(representation)
if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), distribution.size(-1) - 1)
threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
distribution = torch.where(distribution > threshold, distribution, torch.zeros_like(distribution))
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(
distribution, self.embed_adapter.weight.float()
).view(seq_len, bsz, -1).type_as(representation)
cutoff = min(int(self.distribution_cutoff), org_distribution.size(-1) - 1)
threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
distribution = torch.where(
org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution)
)
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
out = linear_out + soft_out
elif self.adapter_type == "gated_league":
linear_out = self.linear_adapter(representation)
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(
distribution, self.embed_adapter.weight.float()
).view(seq_len, bsz, -1).type_as(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league":
distribution = distribution.view(-1, distribution.size(-1))
soft_out = torch.mm(
distribution, self.embed_adapter.weight.float()
).view(seq_len, bsz, -1).type_as(representation)
soft_out = torch.mm(distribution, self.embed_adapter.weight.t()).view(seq_len, bsz, -1)
out = representation + soft_out
elif self.adapter_type == "none":
......
......@@ -197,6 +197,12 @@ class SequenceGenerator(nn.Module):
)
net_input = sample["net_input"]
if "transcript" in sample:
text_src_tokens = sample["transcript"]["tokens"]
text_src_lengths = sample["transcript"]["lengths"]
net_input["text_src_tokens"] = text_src_tokens
net_input["text_src_lengths"] = text_src_lengths
if "src_tokens" in net_input:
src_tokens = net_input["src_tokens"]
# length of the source text being the character length except EndOfSentence and pad
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论