Commit 03076942 by xuchen

fix the bugs

parent 1d60b3a6
......@@ -6,6 +6,7 @@ max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
......
......@@ -31,6 +31,8 @@ decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
cnn-module-norm: layer_norm
load-pretrained-encoder-from: /home/xuchen/after.pt
load-pretrained-decoder-from: /home/xuchen/after.pt
#load-pretrained-decoder-from:
ctc-weight: 0.3
post-process: sentencepiece
share-ctc-and-embed: True
\ No newline at end of file
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
sae-adapter: league
sae-drop-prob: 0.2
sae-distribution-cutoff: 10
share-ctc-and-sae: False
ctc-self-distill-weight: 0
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2
\ No newline at end of file
......@@ -2,9 +2,9 @@
# training the model
gpu_num=8
gpu_num=2
update_freq=1
max_tokens=40000
max_tokens=160000
extra_tag=
extra_parameter=
......@@ -13,12 +13,12 @@ extra_parameter=
exp_tag=
#config_list=(base ctc)
config_list=(base ctc conformer)
config_list=(big ctc conformer)
config_list=(base ctc)
#config_list=(base ctc conformer)
#config_list=(big ctc conformer)
#config_list=(pds_base_16)
config_list=(pds_base_16 conformer)
#config_list=(pds_base_16 conformer)
# exp full name
exp_name=
......
arch: transformer_ctc
arch: transformer
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
......@@ -8,7 +8,7 @@ warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.3
......
arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.3
attention-dropout: 0.0
activation-dropout: 0.0
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 1024
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 512
decoder-ffn-embed-dim: 1024
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
......@@ -6,6 +6,7 @@ max-update: 50000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
......
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
#ctc-layer:
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
interleaved_ctc_upsampling_ratio: 2
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
sae-adapter: league
sae-drop-prob: 0.0
#sae-distribution-cutoff: 10
share-ctc-and-sae: False
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
ctc-self-distill-weight: 0
\ No newline at end of file
arch: s2t_sate
arch: s2t_transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
......@@ -37,9 +37,9 @@ activation-dropout: 0.1
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
#inter_mixup: True
#inter_mixup_layer: -1
#inter_mixup_ratio: 0.2
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_ratio: 0.2
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
......@@ -48,8 +48,8 @@ interleaved-temperature: 2
#target-ctc-weight: 0.3
#target-ctc-layer: 6
target-interleaved-ctc-weight: 0.1
target-interleaved-ctc-layers: 2,4
#target-interleaved-ctc-weight: 0.1
#target-interleaved-ctc-layers: 2,4
sae-adapter: league
share-ctc-and-sae: False
......
......@@ -6,6 +6,7 @@ max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
......@@ -17,4 +18,3 @@ log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
post-process: sentencepiece
\ No newline at end of file
ctc-weight: 0.3
post-process: sentencepiece
share-ctc-and-embed: True
\ No newline at end of file
arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
......@@ -6,6 +6,7 @@ max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
......
......@@ -34,3 +34,18 @@ decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
#ctc-layer:
#ctc-weight: 0.2
interleaved-ctc-weight: 0.3
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
interleaved_ctc_upsampling_ratio: 2
sae-adapter: league
sae-drop-prob: 0.0
#sae-distribution-cutoff: 10
share-ctc-and-sae: True
ctc-self-distill-weight: 0
\ No newline at end of file
#ctc-layer:
#ctc-weight: 0.2
intermedia-ctc-weight: 0.3
intermedia-ctc-layers: 2,4
interleaved-ctc-weight: 0.3
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
interleaved_ctc_upsampling_ratio: 2
#target-ctc-weight: 0.3
#target-ctc-layer: 6
#target-intermedia-ctc-weight: 0.1
#target-intermedia-ctc-layers: 2,4
sae-adapter: league
sae-drop-prob: 0.0
#sae-distribution-cutoff: 10
share-ctc-and-sae: False
intermedia-adapter: league
#intermedia-drop-prob: 0.2
#intermedia-temperature: 5
post-process: sentencepiece
\ No newline at end of file
ctc-self-distill-weight: 0
\ No newline at end of file
arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: s2t_transformer_s
share-decoder-input-output-embed: True
share-ctc-and-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
......
......@@ -6,6 +6,7 @@ max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
......@@ -17,4 +18,3 @@ log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
\ No newline at end of file
post-process: sentencepiece
\ No newline at end of file
ctc-weight: 0.3
post-process: sentencepiece
\ No newline at end of file
share-ctc-and-embed: True
\ No newline at end of file
......@@ -263,7 +263,7 @@ class CtcCriterion(FairseqCriterion):
target_interleaved_ctc_loss = 0
# calculate the target CTC loss
if self.target_ctc_weight > 0 or self.target_interleaved_ctc_weight:
if self.target_ctc_weight > 0 or self.target_interleaved_ctc_weight > 0:
target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
......@@ -297,7 +297,7 @@ class CtcCriterion(FairseqCriterion):
target_interleaved_ctc_num = 0
if "target_interleaved_ctc_logits" in net_output:
target_interleaved_ctc_num = len(net_output["target_interleaved_ctc_logits"])
if target_interleaved_ctc_num != 0 and self.target_interleaved_ctc_weight > 0:
for i in range(target_interleaved_ctc_num):
out = net_output["target_interleaved_ctc_logits"][i]
if type(out) == list:
......@@ -314,7 +314,8 @@ class CtcCriterion(FairseqCriterion):
tgt_inter_lprobs.batch_first = False
for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
target_interleaved_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef
target_interleaved_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths,
lengths) * coef
target_interleaved_ctc_loss /= target_interleaved_ctc_num
logging_output["target_interleaved_ctc_loss"] = utils.item(target_interleaved_ctc_loss.data)
......@@ -358,7 +359,7 @@ class CtcCriterion(FairseqCriterion):
logging_output["all_ctc_loss"] = utils.item(loss.data)
if torch.isnan(loss) or torch.isinf(loss) or utils.item(loss.data) < 0:
logger.warning("Illegal loss %f!" % loss)
# logger.warning("Illegal loss %f!" % loss)
if self.ctc_weight != 0:
logger.warning("CTC loss %f!" % ctc_loss)
if self.interleaved_ctc_weight != 0:
......@@ -366,7 +367,7 @@ class CtcCriterion(FairseqCriterion):
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
if not model.training and self.ctc_weight > 0:
if not model.training and self.ctc_weight + self.interleaved_ctc_weight > 0:
import editdistance
with torch.no_grad():
......
......@@ -99,11 +99,11 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
target = target[:, self.ignore_prefix_size:].contiguous()
else:
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
target = target[self.ignore_prefix_size :, :].contiguous()
lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
target = target[self.ignore_prefix_size:, :].contiguous()
if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
mixup = net_output[1]["mixup"]
idx1 = mixup["index1"]
......
......@@ -69,9 +69,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
n_tokens = sample["ntokens"]
n_sentences = sample["target"].size(0)
if use_mixup:
sample_size //= 2
n_tokens //= 2
n_sentences //= 2
sample_size //= net_output[0].size(0) if self.sentence_avg else encoder_out["mixup"]["ratio"]
n_tokens //= encoder_out["mixup"]["ratio"]
n_sentences //= net_output[0].size(0)
logging_output = {
"trans_loss": utils.item(loss.data) if reduce else loss.data,
......@@ -88,7 +88,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
if self.ctc_criterion.all_ctc_weight > 0:
ctc_loss, logging_output = self.ctc_criterion.compute_ctc_loss(model, sample, encoder_out, logging_output)
loss = (1 - self.ctc_weight) * loss + ctc_loss
# loss = (1 - self.ctc_weight) * loss + ctc_loss
loss = loss + ctc_loss
# if hasattr(model.encoder, "get_loss"):
# encoder_loss = model.encoder.get_loss()
......
......@@ -259,11 +259,11 @@ class TextEncoder(FairseqEncoder):
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
self.sae_adapter = Adapter(embed_dim, args.sae_adapter,
self.sae = Adapter(embed_dim, args.sae_adapter,
len(dictionary),
strategy=strategy)
if args.share_target_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
if args.share_target_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
......@@ -297,7 +297,7 @@ class TextEncoder(FairseqEncoder):
target_interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae_adapter([x, prob], encoder_padding_mask)
x, encoder_padding_mask = self.sae([x, prob], encoder_padding_mask)
if history is not None:
history.push(x)
......@@ -376,8 +376,8 @@ class S2TSATEEncoder(FairseqEncoder):
encoder_out = acoustic_encoder_out["encoder_out"][0]
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
ctc_padding_mask = encoder_padding_mask
if "mixup" in encoder_out:
mixup = encoder_out["mixup"]
if "mixup" in acoustic_encoder_out:
mixup = acoustic_encoder_out["mixup"]
else:
mixup = None
......@@ -406,7 +406,8 @@ class S2TSATEEncoder(FairseqEncoder):
x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask,
self.history)
else:
x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask,
self.history)
return {
"encoder_out": [x], # T x B x C
......
......@@ -657,12 +657,12 @@ class S2TTransformerEncoder(FairseqEncoder):
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
self.sae_adapter = Adapter(dim, args.sae_adapter,
self.sae = Adapter(dim, args.sae_adapter,
len(task.source_dictionary),
strategy=strategy,
)
if args.share_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
if args.share_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
# mixup
self.mixup = getattr(args, "inter_mixup", False)
......@@ -734,6 +734,7 @@ class S2TTransformerEncoder(FairseqEncoder):
input_lengths = (~encoder_padding_mask).sum(-1)
mixup = {
"ratio": self.mixup_ratio,
"coef": coef,
"index1": idx1,
"index2": idx2,
......@@ -766,12 +767,12 @@ class S2TTransformerEncoder(FairseqEncoder):
# down-sampling
x, input_lengths = self.subsample(x, input_lengths)
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
# embedding scaling
x = self.embed_scale * x
# padding and position embedding
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
# position embedding
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
positions = self.embed_positions(x)
......@@ -836,7 +837,7 @@ class S2TTransformerEncoder(FairseqEncoder):
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
x, encoder_padding_mask = self.sae_adapter([x, prob], encoder_padding_mask)
x, encoder_padding_mask = self.sae([x, prob], encoder_padding_mask)
# gather cosine similarity
if self.gather_cos_sim:
......
......@@ -36,7 +36,6 @@ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
......@@ -87,23 +86,37 @@ class TransformerCTCModel(FairseqEncoderDecoderModel):
}
return {
'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'),
'transformer.wmt14.en-fr': moses_subword(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'),
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'),
'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'),
'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'),
'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'),
'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'),
'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'),
'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'),
'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'),
'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'),
'transformer.wmt18.en-de': moses_subword(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'),
'transformer.wmt19.en-de': moses_fastbpe(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'),
'transformer.wmt19.en-ru': moses_fastbpe(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'),
'transformer.wmt19.de-en': moses_fastbpe(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'),
'transformer.wmt19.ru-en': moses_fastbpe(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'),
'transformer.wmt19.en-de.single_model': moses_fastbpe(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'),
'transformer.wmt19.en-ru.single_model': moses_fastbpe(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'),
'transformer.wmt19.de-en.single_model': moses_fastbpe(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'),
'transformer.wmt19.ru-en.single_model': moses_fastbpe(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'),
'transformer.wmt20.en-ta': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz'),
'transformer.wmt20.en-iu.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'),
'transformer.wmt20.en-iu.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'),
'transformer.wmt20.en-iu.news': spm(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'),
'transformer.wmt20.en-iu.nh': spm(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'),
'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'),
'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'),
'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'),
'transformer.wmt20.iu-en.news': spm(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'),
'transformer.wmt20.iu-en.nh': spm(
'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'),
}
# fmt: on
......@@ -599,7 +612,7 @@ class TransformerCTCEncoder(FairseqEncoder):
self.interleaved_ctc_temperature = args.interleaved_ctc_temperature
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
self.interleaved_ctc_upsampling_ratio = args.interleaved_ctc_upsampling_ratio
self.interleaved_ctc_upsampling_ratio = int(args.interleaved_ctc_upsampling_ratio)
self.interleaved_ctc_layers = []
if args.interleaved_ctc_layers is not None:
interleaved_ctc_layers = args.interleaved_ctc_layers.split(",")
......@@ -624,12 +637,15 @@ class TransformerCTCEncoder(FairseqEncoder):
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
self.sae_adapter = Adapter(embed_dim, args.sae_adapter,
self.sae = Adapter(embed_dim, args.sae_adapter,
decoder_embed_tokens.num_embeddings,
strategy=strategy
)
if args.share_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
if args.share_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
if hasattr(self, "ctc"):
self.pool = nn.MaxPool1d(kernel_size=self.interleaved_ctc_upsampling_ratio,
stride=self.interleaved_ctc_upsampling_ratio)
def build_encoder_layer(self, args):
layer = TransformerEncoderLayer(args)
......@@ -778,7 +794,7 @@ class TransformerCTCEncoder(FairseqEncoder):
if self.use_ctc and self.inter_ctc and self.ctc_layer == layer_idx:
ctc_logit = self.ctc(self.upsampling(x.clone()))
# Intermedia CTC
# Interleaved CTC
if layer_idx in self.interleaved_ctc_layers:
if self.interleaved_ctc_drop_prob > 0:
p = torch.rand(1).uniform_()
......@@ -793,12 +809,10 @@ class TransformerCTCEncoder(FairseqEncoder):
up_prob = utils.softmax(up_logit / self.interleaved_ctc_temperature, dim=-1)
up_prob = up_prob.permute(1, 2, 0)
prob = nn.functional.max_pool1d(up_prob,
kernel_size=self.interleaved_ctc_upsampling_ratio,
stride=self.interleaved_ctc_upsampling_ratio)
prob = self.pool(up_prob)
prob = prob.permute(2, 0, 1)
x, _ = self.adapter([x, prob])
x, _ = self.sae([x, prob])
if self.history is not None:
self.history.push(x)
......@@ -815,7 +829,7 @@ class TransformerCTCEncoder(FairseqEncoder):
ctc_padding_mask = encoder_padding_mask
if ctc_logit is not None or len(interleaved_ctc_logits) != 0:
bsz = encoder_padding_mask.size(0)
ctc_padding_mask = encoder_padding_mask.unsqueeze(-1).\
ctc_padding_mask = encoder_padding_mask.unsqueeze(-1). \
expand(-1, -1, self.interleaved_ctc_upsampling_ratio).reshape(bsz, -1)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
......@@ -1046,7 +1060,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
)
self.gather_attn_weight = getattr(args, "gather_attn_weight", False)
#self.gather_attn_weight = True
# self.gather_attn_weight = True
self.attn_weights = dict()
def build_decoder_layer(self, args, no_encoder_attn=False):
......@@ -1486,8 +1500,9 @@ def base_architecture(args):
args.interleaved_ctc_layers = getattr(args, "interleaved_ctc_layers", None)
args.interleaved_ctc_temperature = getattr(args, "interleaved_ctc_temperature", 1)
args.interleaved_ctc_drop_prob = getattr(args, "interleaved_ctc_drop_prob", 0)
args.interleaved_ctc_upsampling_ratio = getattr(args, "interleaved_ctc_upsampling_ratio", 2)
# Semantics-augmented Encoding (sae)
# Semantics-augmented Encoding (SAE)
args.sae_adapter = getattr(args, "sae_adapter", "none")
args.share_ctc_and_sae = getattr(args, "share_ctc_and_sae", False)
args.sae_drop_prob = getattr(args, "sae_drop_prob", 0)
......
......@@ -58,7 +58,7 @@ class ConvolutionModule(nn.Module):
elif norm_type == "layer_norm":
self.norm = LayerNorm(expand_embed_dim)
else:
assert False, "Unsupported normalization type in convolution module"
assert False, "Unsupported normalization type %s in convolution module" % norm_type
self.activation = get_activation_class(activation_fn)
self.pointwise_conv2 = torch.nn.Conv1d(
expand_embed_dim,
......
......@@ -77,6 +77,7 @@ class Adapter(nn.Module):
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
self.cal_context = True
self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization
nn.init.normal_(self.embed_adapter.weight, mean=0, std=dim ** -0.5)
if embed_tokens is not None:
self.embed_adapter.weight = embed_tokens.weight
......
......@@ -20,7 +20,8 @@ class CTC(nn.Module):
self.embed_dim = embed_dim
self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
# nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
nn.init.constant_(self.ctc_projection.bias, 0.0)
self.ctc_dropout_module = FairseqDropout(
p=dropout, module_name=self.__class__.__name__
......
......@@ -198,7 +198,11 @@ class Conv2dSubsampling(nn.Module):
transpose=True if norm == "layer" else False),
get_activation_class(act, dim=1)
) for layer_id in range(num_layers)])
self.linear = nn.Linear(filters[-1] * in_dim // 2 ** num_layers, filters[-1])
dim = in_dim
for _ in range(num_layers):
dim = (dim - 1) // 2
self.linear = nn.Linear(dim*filters[-1], filters[-1])
def forward(self, x, x_len):
......@@ -211,11 +215,12 @@ class Conv2dSubsampling(nn.Module):
# Update Sequence Lengths
if x_len is not None:
x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
x_len = torch.div(x_len - 1, 2, rounding_mode='floor')
# (B, C, D // S, T // S) -> (B, C * D // S, T // S)
batch_size, channels, subsampled_dim, subsampled_length = x.size()
assert subsampled_length == max(x_len), "The lengths are mismatched."
assert subsampled_length == max(x_len), \
("The lengths are mismatched: %d and %d." % (subsampled_length, max(x_len)))
x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).permute(2, 0, 1)
x = self.linear(x)
......
......@@ -197,11 +197,11 @@ class SequenceGenerator(nn.Module):
)
net_input = sample["net_input"]
if "transcript" in sample:
text_src_tokens = sample["transcript"]["tokens"]
text_src_lengths = sample["transcript"]["lengths"]
net_input["text_src_tokens"] = text_src_tokens
net_input["text_src_lengths"] = text_src_lengths
# if "transcript" in sample:
# text_src_tokens = sample["transcript"]["tokens"]
# text_src_lengths = sample["transcript"]["lengths"]
# net_input["text_src_tokens"] = text_src_tokens
# net_input["text_src_lengths"] = text_src_lengths
if "src_tokens" in net_input:
src_tokens = net_input["src_tokens"]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论