Commit 03076942 by xuchen

fix the bugs

parent 1d60b3a6
@@ -6,6 +6,7 @@ max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
+post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
...
@@ -31,6 +31,8 @@ decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
+cnn-module-norm: layer_norm
load-pretrained-encoder-from: /home/xuchen/after.pt
load-pretrained-decoder-from: /home/xuchen/after.pt
#load-pretrained-decoder-from:
ctc-weight: 0.3
-post-process: sentencepiece
+share-ctc-and-embed: True
\ No newline at end of file
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
interleaved-ctc-layers: 6,9
interleaved-ctc-temperature: 1.0
interleaved-ctc-drop-prob: 0
sae-adapter: league
sae-drop-prob: 0.2
sae-distribution-cutoff: 10
share-ctc-and-sae: False
ctc-self-distill-weight: 0
inter_mixup: True
inter_mixup_layer: -1
inter_mixup_prob: 1.0
inter_mixup_ratio: 0.2
\ No newline at end of file
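The inter_mixup_* options above enable inter-sequence mixup inside the encoder; the Python hunks further down store a mixup dict with "coef", "index1", "index2" and now also "ratio" (the criterion uses the ratio to rescale token counts). As a rough standalone sketch of the general idea only (the Beta-sampled coefficient and random pairing are assumptions; the exact meaning of inter_mixup_ratio and inter_mixup_layer is defined by the model code, not here):

import torch

def inter_mixup_sketch(x, beta=0.2, apply_prob=1.0):
    """Mix each sequence in the batch with a randomly paired one.

    x: encoder states of shape (T, B, C). Returns the mixed states and a dict
    mirroring the fields ("coef", "index1", "index2") kept by the encoder.
    """
    if torch.rand(1).item() > apply_prob:
        return x, None
    batch_size = x.size(1)
    coef = torch.distributions.Beta(beta, beta).sample().item()  # assumed sampling scheme
    idx1 = torch.arange(batch_size)
    idx2 = torch.randperm(batch_size)
    mixed = coef * x[:, idx1, :] + (1 - coef) * x[:, idx2, :]
    return mixed, {"coef": coef, "index1": idx1, "index2": idx2}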
@@ -2,9 +2,9 @@
# training the model
-gpu_num=8
+gpu_num=2
update_freq=1
-max_tokens=40000
+max_tokens=160000
extra_tag=
extra_parameter=
@@ -13,12 +13,12 @@ extra_parameter=
exp_tag=
-#config_list=(base ctc)
+config_list=(base ctc)
-config_list=(base ctc conformer)
+#config_list=(base ctc conformer)
-config_list=(big ctc conformer)
+#config_list=(big ctc conformer)
#config_list=(pds_base_16)
-config_list=(pds_base_16 conformer)
+#config_list=(pds_base_16 conformer)
# exp full name
exp_name=
...
-arch: transformer_ctc
+arch: transformer
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
@@ -8,7 +8,7 @@ warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
-criterion: label_smoothed_cross_entropy_with_ctc
+criterion: label_smoothed_cross_entropy
label_smoothing: 0.1
dropout: 0.3
...
arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.3
attention-dropout: 0.0
activation-dropout: 0.0
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 1024
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 512
decoder-ffn-embed-dim: 1024
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
@@ -6,6 +6,7 @@ max-update: 50000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
+post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
...
-#ctc-weight: 0.2
-intermedia-ctc-weight: 0.3
-intermedia-ctc-layers: 2,4
-#target-ctc-weight: 0.3
-#target-ctc-layer: 6
-#target-intermedia-ctc-weight: 0.1
-#target-intermedia-ctc-layers: 2,4
-intermedia-adapter: league
-#intermedia-drop-prob: 0.2
-#intermedia-temperature: 5
-post-process: sentencepiece
+#ctc-layer:
+ctc-weight: 0.2
+interleaved-ctc-weight: 0.1
+interleaved-ctc-layers: 6,9
+interleaved-ctc-temperature: 1.0
+interleaved-ctc-drop-prob: 0
+interleaved_ctc_upsampling_ratio: 2
+sae-adapter: league
+sae-drop-prob: 0.0
+#sae-distribution-cutoff: 10
+share-ctc-and-sae: False
+ctc-self-distill-weight: 0
\ No newline at end of file
-arch: s2t_sate
+arch: s2t_transformer
share-decoder-input-output-embed: True
optimizer: adam
clip-norm: 10.0
@@ -37,9 +37,9 @@ activation-dropout: 0.1
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
-#inter_mixup: True
+inter_mixup: True
-#inter_mixup_layer: -1
+inter_mixup_layer: -1
-#inter_mixup_ratio: 0.2
+inter_mixup_ratio: 0.2
ctc-weight: 0.2
interleaved-ctc-weight: 0.1
@@ -48,8 +48,8 @@ interleaved-temperature: 2
#target-ctc-weight: 0.3
#target-ctc-layer: 6
-target-interleaved-ctc-weight: 0.1
+#target-interleaved-ctc-weight: 0.1
-target-interleaved-ctc-layers: 2,4
+#target-interleaved-ctc-layers: 2,4
sae-adapter: league
share-ctc-and-sae: False
...
@@ -6,6 +6,7 @@ max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
+post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
@@ -17,4 +18,3 @@ log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
-post-process: sentencepiece
\ No newline at end of file
ctc-weight: 0.3
-post-process: sentencepiece
+share-ctc-and-embed: True
\ No newline at end of file
arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 512
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 8
decoder-embed-dim: 512
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
@@ -6,6 +6,7 @@ max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
+post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
...
@@ -33,4 +33,19 @@ decoder-ffn-embed-dim: 2048
decoder-attention-heads: 8
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
+#ctc-layer:
+#ctc-weight: 0.2
+interleaved-ctc-weight: 0.3
+interleaved-ctc-layers: 6,9
+interleaved-ctc-temperature: 1.0
+interleaved-ctc-drop-prob: 0
+interleaved_ctc_upsampling_ratio: 2
+sae-adapter: league
+sae-drop-prob: 0.0
+#sae-distribution-cutoff: 10
+share-ctc-and-sae: True
+ctc-self-distill-weight: 0
\ No newline at end of file
+#ctc-layer:
#ctc-weight: 0.2
-intermedia-ctc-weight: 0.3
-intermedia-ctc-layers: 2,4
-#target-ctc-weight: 0.3
-#target-ctc-layer: 6
-#target-intermedia-ctc-weight: 0.1
-#target-intermedia-ctc-layers: 2,4
-intermedia-adapter: league
-#intermedia-drop-prob: 0.2
-#intermedia-temperature: 5
-post-process: sentencepiece
+interleaved-ctc-weight: 0.3
+interleaved-ctc-layers: 6,9
+interleaved-ctc-temperature: 1.0
+interleaved-ctc-drop-prob: 0
+interleaved_ctc_upsampling_ratio: 2
+sae-adapter: league
+sae-drop-prob: 0.0
+#sae-distribution-cutoff: 10
+share-ctc-and-sae: False
+ctc-self-distill-weight: 0
\ No newline at end of file
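For orientation, the interleaved-ctc-* and sae-* options above correspond to the encoder logic visible in the Python hunks further down: an intermediate layer produces CTC logits, their temperature-scaled softmax is passed through the SAE adapter together with the hidden states, and the adapted representation continues through the remaining layers. A minimal sketch of that flow under simplifying assumptions (the 'league' combination shown here as a plain additive soft-embedding is an assumption, not the repository's Adapter code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class InterleavedCTCBlockSketch(nn.Module):
    """Toy version of an interleaved CTC step followed by an SAE-style adapter."""

    def __init__(self, dim, vocab_size):
        super().__init__()
        self.ctc_projection = nn.Linear(dim, vocab_size)           # interleaved CTC logits
        self.embed_adapter = nn.Linear(dim, vocab_size, bias=False)
        # With share-ctc-and-sae: True the diff ties these two weights together:
        # self.embed_adapter.weight = self.ctc_projection.weight

    def forward(self, x, temperature=1.0):
        # x: (T, B, C) hidden states from an intermediate encoder layer
        logits = self.ctc_projection(x)
        prob = F.softmax(logits / temperature, dim=-1)
        soft_embed = prob @ self.embed_adapter.weight               # (T, B, C)
        return x + soft_embed, logits                               # adapted states + logits for the CTC loss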
arch: transformer_ctc
share-all-embeddings: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
warmup-init-lr: 1e-7
warmup-updates: 8000
lr: 1e-3
adam_betas: (0.9,0.997)
criterion: label_smoothed_cross_entropy_with_ctc
label_smoothing: 0.1
dropout: 0.1
attention-dropout: 0.1
activation-dropout: 0.1
activation-fn: relu
encoder-normalize-before: True
decoder-normalize-before: True
encoder-embed-dim: 256
encoder-ffn-embed-dim: 2048
encoder-layers: 6
decoder-layers: 6
encoder-attention-heads: 4
decoder-embed-dim: 256
decoder-ffn-embed-dim: 2048
decoder-attention-heads: 4
#load-pretrained-encoder-from:
#load-pretrained-decoder-from:
\ No newline at end of file
arch: s2t_transformer_s
share-decoder-input-output-embed: True
+share-ctc-and-embed: True
optimizer: adam
clip-norm: 10.0
lr-scheduler: inverse_sqrt
...
@@ -6,6 +6,7 @@ max-update: 100000
patience: 20
best_checkpoint_metric: loss
maximize_best_checkpoint_metric: False
+post-process: sentencepiece
no-epoch-checkpoints: True
#keep-last-epochs: 10
@@ -16,5 +17,4 @@ no-progress-bar: True
log-interval: 100
seed: 1
report-accuracy: True
skip-invalid-size-inputs-valid-test: True
-post-process: sentencepiece
\ No newline at end of file
ctc-weight: 0.3
-post-process: sentencepiece
+share-ctc-and-embed: True
\ No newline at end of file
@@ -263,7 +263,7 @@ class CtcCriterion(FairseqCriterion):
target_interleaved_ctc_loss = 0
# calculate the target CTC loss
-if self.target_ctc_weight > 0 or self.target_interleaved_ctc_weight:
+if self.target_ctc_weight > 0 or self.target_interleaved_ctc_weight > 0:
target = sample["target"]
pad_mask = (target != self.pad_idx) & (target != self.eos_idx)
@@ -297,27 +297,28 @@ class CtcCriterion(FairseqCriterion):
target_interleaved_ctc_num = 0
if "target_interleaved_ctc_logits" in net_output:
target_interleaved_ctc_num = len(net_output["target_interleaved_ctc_logits"])
-for i in range(target_interleaved_ctc_num):
-out = net_output["target_interleaved_ctc_logits"][i]
-if type(out) == list:
-inter_ctc_logit = out[0]
-padding = ~out[1]
-tgt_input_lengths = padding.long().sum(-1)
-else:
-inter_ctc_logit = out
-tgt_input_lengths = input_lengths
-tgt_inter_lprobs = model.get_normalized_probs(
-[inter_ctc_logit], log_probs=True
-).contiguous()  # (T, B, C) from the encoder
-tgt_inter_lprobs.batch_first = False
-for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
-target_interleaved_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths, lengths) * coef
+if target_interleaved_ctc_num != 0 and self.target_interleaved_ctc_weight > 0:
+for i in range(target_interleaved_ctc_num):
+out = net_output["target_interleaved_ctc_logits"][i]
+if type(out) == list:
+inter_ctc_logit = out[0]
+padding = ~out[1]
+tgt_input_lengths = padding.long().sum(-1)
+else:
+inter_ctc_logit = out
+tgt_input_lengths = input_lengths
+tgt_inter_lprobs = model.get_normalized_probs(
+[inter_ctc_logit], log_probs=True
+).contiguous()  # (T, B, C) from the encoder
+tgt_inter_lprobs.batch_first = False
+for flat, lengths, coef in zip(target_flat, target_length, loss_coef):
+target_interleaved_ctc_loss += self.get_loss(tgt_inter_lprobs, flat, tgt_input_lengths,
+lengths) * coef
target_interleaved_ctc_loss /= target_interleaved_ctc_num
logging_output["target_interleaved_ctc_loss"] = utils.item(target_interleaved_ctc_loss.data)
# calculate the self distillation CTC loss
ctc_self_distill_loss = 0
@@ -358,7 +359,7 @@ class CtcCriterion(FairseqCriterion):
logging_output["all_ctc_loss"] = utils.item(loss.data)
if torch.isnan(loss) or torch.isinf(loss) or utils.item(loss.data) < 0:
-logger.warning("Illegal loss %f!" % loss)
+# logger.warning("Illegal loss %f!" % loss)
if self.ctc_weight != 0:
logger.warning("CTC loss %f!" % ctc_loss)
if self.interleaved_ctc_weight != 0:
@@ -366,7 +367,7 @@ class CtcCriterion(FairseqCriterion):
if self.target_ctc_weight != 0:
logger.warning("Target CTC loss %f!" % target_ctc_loss)
-if not model.training and self.ctc_weight > 0:
+if not model.training and self.ctc_weight + self.interleaved_ctc_weight > 0:
import editdistance
with torch.no_grad():
...
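Two small robustness changes sit in the hunks above: the target-CTC guard now uses an explicit `> 0` comparison instead of relying on float truthiness, and the target interleaved CTC block only runs when such logits are actually present. The last hunk also widens the evaluation gate so CTC error rates are reported whenever either ctc_weight or interleaved_ctc_weight is non-zero. The truthiness point in isolation (a standalone illustration, not repository code):

# A weight of 0.0 is falsy, so the old `or self.target_interleaved_ctc_weight`
# and the new `or self.target_interleaved_ctc_weight > 0` pick the same branch;
# the explicit comparison just matches the surrounding `> 0` checks.
for weight in (0.0, 0.1, 0.3):
    assert bool(weight) == (weight > 0)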
@@ -55,12 +55,12 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T
)
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
):
super().__init__(task)
self.sentence_avg = sentence_avg
@@ -99,11 +99,11 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
-lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
+lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
-target = target[:, self.ignore_prefix_size :].contiguous()
+target = target[:, self.ignore_prefix_size:].contiguous()
else:
-lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
+lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
-target = target[self.ignore_prefix_size :, :].contiguous()
+target = target[self.ignore_prefix_size:, :].contiguous()
if "mixup" in net_output[1] and net_output[1]["mixup"] is not None:
mixup = net_output[1]["mixup"]
idx1 = mixup["index1"]
...
@@ -69,9 +69,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
n_tokens = sample["ntokens"]
n_sentences = sample["target"].size(0)
if use_mixup:
-sample_size //= 2
+sample_size //= net_output[0].size(0) if self.sentence_avg else encoder_out["mixup"]["ratio"]
-n_tokens //= 2
+n_tokens //= encoder_out["mixup"]["ratio"]
-n_sentences //= 2
+n_sentences //= net_output[0].size(0)
logging_output = {
"trans_loss": utils.item(loss.data) if reduce else loss.data,
@@ -88,7 +88,8 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
if self.ctc_criterion.all_ctc_weight > 0:
ctc_loss, logging_output = self.ctc_criterion.compute_ctc_loss(model, sample, encoder_out, logging_output)
-loss = (1 - self.ctc_weight) * loss + ctc_loss
+# loss = (1 - self.ctc_weight) * loss + ctc_loss
+loss = loss + ctc_loss
# if hasattr(model.encoder, "get_loss"):
# encoder_loss = model.encoder.get_loss()
...
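A note on the loss combination change in the last hunk: the translation loss is no longer scaled down by `(1 - ctc_weight)` before the CTC term is added. For example, with ctc_weight = 0.3, a translation loss of 2.0 and a CTC loss of 1.2, the old total was 0.7 * 2.0 + 1.2 = 2.6, while the new total is 2.0 + 1.2 = 3.2.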
@@ -259,11 +259,11 @@ class TextEncoder(FairseqEncoder):
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
-self.sae_adapter = Adapter(embed_dim, args.sae_adapter,
+self.sae = Adapter(embed_dim, args.sae_adapter,
len(dictionary),
strategy=strategy)
-if args.share_target_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
+if args.share_target_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
-self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
+self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
self.interleaved_ctc_drop_prob = args.interleaved_ctc_drop_prob
@@ -297,7 +297,7 @@ class TextEncoder(FairseqEncoder):
target_interleaved_ctc_logits.append(logit)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
-x, encoder_padding_mask = self.sae_adapter([x, prob], encoder_padding_mask)
+x, encoder_padding_mask = self.sae([x, prob], encoder_padding_mask)
if history is not None:
history.push(x)
@@ -376,8 +376,8 @@ class S2TSATEEncoder(FairseqEncoder):
encoder_out = acoustic_encoder_out["encoder_out"][0]
encoder_padding_mask = acoustic_encoder_out["encoder_padding_mask"][0]
ctc_padding_mask = encoder_padding_mask
-if "mixup" in encoder_out:
+if "mixup" in acoustic_encoder_out:
-mixup = encoder_out["mixup"]
+mixup = acoustic_encoder_out["mixup"]
else:
mixup = None
@@ -406,7 +406,8 @@ class S2TSATEEncoder(FairseqEncoder):
x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask,
self.history)
else:
-x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask, self.history)
+x, target_ctc_logit, target_interleaved_ctc_logits = self.text_encoder(x, encoder_padding_mask,
+self.history)
return {
"encoder_out": [x],  # T x B x C
...
@@ -657,12 +657,12 @@ class S2TTransformerEncoder(FairseqEncoder):
"drop_prob": getattr(args, "sae_drop_prob", 0),
}
-self.sae_adapter = Adapter(dim, args.sae_adapter,
+self.sae = Adapter(dim, args.sae_adapter,
len(task.source_dictionary),
strategy=strategy,
)
-if args.share_ctc_and_sae and hasattr(self.sae_adapter, "embed_adapter"):
+if args.share_ctc_and_sae and hasattr(self.sae, "embed_adapter"):
-self.ctc.ctc_projection.weight = self.sae_adapter.embed_adapter.weight
+self.sae.embed_adapter.weight = self.ctc.ctc_projection.weight
# mixup
self.mixup = getattr(args, "inter_mixup", False)
@@ -734,6 +734,7 @@ class S2TTransformerEncoder(FairseqEncoder):
input_lengths = (~encoder_padding_mask).sum(-1)
mixup = {
+"ratio": self.mixup_ratio,
"coef": coef,
"index1": idx1,
"index2": idx2,
@@ -766,12 +767,12 @@ class S2TTransformerEncoder(FairseqEncoder):
# down-sampling
x, input_lengths = self.subsample(x, input_lengths)
+encoder_padding_mask = lengths_to_padding_mask(input_lengths)
# embedding scaling
x = self.embed_scale * x
-# padding and position embedding
+# position embedding
-encoder_padding_mask = lengths_to_padding_mask(input_lengths)
if self.attn_type in ["rel_pos", "rel_pos_legacy", "rel_selfattn"]:
positions = self.embed_positions(x)
@@ -836,7 +837,7 @@ class S2TTransformerEncoder(FairseqEncoder):
max=1e8 if logit.dtype == torch.float32 else 1e4)
prob = utils.softmax(logit / self.interleaved_ctc_temperature, dim=-1)
-x, encoder_padding_mask = self.sae_adapter([x, prob], encoder_padding_mask)
+x, encoder_padding_mask = self.sae([x, prob], encoder_padding_mask)
# gather cosine similarity
if self.gather_cos_sim:
...
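The two encoder hunks above (TextEncoder with share_target_ctc_and_sae, S2TTransformerEncoder with share_ctc_and_sae) reverse the direction of the weight tying between the CTC projection and the SAE adapter's embedding: the adapter's weight is now assigned from the CTC projection instead of overwriting it. Either assignment leaves the two modules sharing a single Parameter, but the direction decides whose initialization (and any loaded pretrained value) survives. A minimal standalone check of that behaviour (hypothetical sizes, not repository code):

import torch
import torch.nn as nn

vocab, dim = 100, 16  # hypothetical sizes
ctc_projection = nn.Linear(dim, vocab)             # stands in for CTC.ctc_projection, which the diff now re-initializes
embed_adapter = nn.Linear(dim, vocab, bias=False)  # stands in for Adapter.embed_adapter

# New direction: the adapter reuses the CTC projection's Parameter object.
embed_adapter.weight = ctc_projection.weight
assert embed_adapter.weight is ctc_projection.weight

# Gradients from both uses accumulate in the single shared tensor.
x = torch.randn(4, dim)
(ctc_projection(x).sum() + embed_adapter(x).sum()).backward()
assert ctc_projection.weight.grad is embed_adapter.weight.grad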
@@ -58,7 +58,7 @@ class ConvolutionModule(nn.Module):
elif norm_type == "layer_norm":
self.norm = LayerNorm(expand_embed_dim)
else:
-assert False, "Unsupported normalization type in convolution module"
+assert False, "Unsupported normalization type %s in convolution module" % norm_type
self.activation = get_activation_class(activation_fn)
self.pointwise_conv2 = torch.nn.Conv1d(
expand_embed_dim,
...
@@ -77,6 +77,7 @@ class Adapter(nn.Module):
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
self.cal_context = True
self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False)  # reverse for initialization
+nn.init.normal_(self.embed_adapter.weight, mean=0, std=dim ** -0.5)
if embed_tokens is not None:
self.embed_adapter.weight = embed_tokens.weight
...
@@ -20,7 +20,8 @@ class CTC(nn.Module):
self.embed_dim = embed_dim
self.ctc_projection = nn.Linear(embed_dim, dictionary_size)
-# nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
+nn.init.normal_(self.ctc_projection.weight, mean=0, std=embed_dim ** -0.5)
+nn.init.constant_(self.ctc_projection.bias, 0.0)
self.ctc_dropout_module = FairseqDropout(
p=dropout, module_name=self.__class__.__name__
...
@@ -198,7 +198,11 @@ class Conv2dSubsampling(nn.Module):
transpose=True if norm == "layer" else False),
get_activation_class(act, dim=1)
) for layer_id in range(num_layers)])
-self.linear = nn.Linear(filters[-1] * in_dim // 2 ** num_layers, filters[-1])
+dim = in_dim
+for _ in range(num_layers):
+dim = (dim - 1) // 2
+self.linear = nn.Linear(dim*filters[-1], filters[-1])
def forward(self, x, x_len):
@@ -211,11 +215,12 @@ class Conv2dSubsampling(nn.Module):
# Update Sequence Lengths
if x_len is not None:
-x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
+x_len = torch.div(x_len - 1, 2, rounding_mode='floor')
# (B, C, D // S, T // S) -> (B, C * D // S, T // S)
batch_size, channels, subsampled_dim, subsampled_length = x.size()
-assert subsampled_length == max(x_len), "The lengths are mismatched."
+assert subsampled_length == max(x_len), \
+("The lengths are mismatched: %d and %d." % (subsampled_length, max(x_len)))
x = x.reshape(batch_size, channels * subsampled_dim, subsampled_length).permute(2, 0, 1)
x = self.linear(x)
...
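The constructor change above replaces the closed-form `in_dim // 2 ** num_layers` with a per-layer `(dim - 1) // 2`, i.e. the output size of a stride-2 convolution without padding (the kernel-3, no-padding layer configuration is assumed here), and the forward pass drops the `+ 1` from the length update so the sequence lengths follow the same rule. A quick worked check of why the old estimate was off (standalone arithmetic; the 80-dimensional input and two layers are just example values):

in_dim, num_layers = 80, 2   # example values

# Old estimate: plain halving per layer.
old_dim = in_dim // 2 ** num_layers        # 80 // 4 = 20

# New estimate: (dim - 1) // 2 per layer, as in the updated constructor.
new_dim = in_dim
for _ in range(num_layers):
    new_dim = (new_dim - 1) // 2           # 80 -> 39 -> 19

print(old_dim, new_dim)                    # 20 19 -> the old Linear expected the wrong input width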
@@ -197,11 +197,11 @@ class SequenceGenerator(nn.Module):
)
net_input = sample["net_input"]
-if "transcript" in sample:
+# if "transcript" in sample:
-text_src_tokens = sample["transcript"]["tokens"]
+# text_src_tokens = sample["transcript"]["tokens"]
-text_src_lengths = sample["transcript"]["lengths"]
+# text_src_lengths = sample["transcript"]["lengths"]
-net_input["text_src_tokens"] = text_src_tokens
+# net_input["text_src_tokens"] = text_src_tokens
-net_input["text_src_lengths"] = text_src_lengths
+# net_input["text_src_lengths"] = text_src_lengths
if "src_tokens" in net_input:
src_tokens = net_input["src_tokens"]
...