Commit 4cffbd98 by xuchen

the fix version of sate

parent 8645e75b
......@@ -100,6 +100,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
def compute_ctc_loss(self, model, sample, encoder_out):
transcript = sample["transcript"]
if "ctc_logit" in encoder_out:
ctc_logit = encoder_out["ctc_logit"][0]
else:
ctc_logit = model.encoder.compute_ctc_logit(encoder_out)
lprobs = model.get_normalized_probs(
[ctc_logit], log_probs=True
......
......@@ -7,3 +7,4 @@ from .berard import * # noqa
from .convtransformer import * # noqa
from .s2t_transformer import * # noqa
from .s2t_conformer import * # noqa
from .s2t_sate import * # noqa
......@@ -92,12 +92,12 @@ class S2TConformerEncoder(S2TTransformerEncoder):
def __init__(self, args, task=None, embed_tokens=None):
super().__init__(args, task, embed_tokens)
self.conformer_layers = nn.ModuleList(
del self.layers
self.layers = nn.ModuleList(
[ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
del self.transformer_layers
def forward(self, src_tokens, src_lengths):
x, input_lengths = self.subsample(src_tokens, src_lengths)
x = self.embed_scale * x
......@@ -109,7 +109,7 @@ class S2TConformerEncoder(S2TTransformerEncoder):
x = self.dropout_module(x)
positions = self.dropout_module(positions)
for layer in self.conformer_layers:
for layer in self.layers:
x = layer(x, encoder_padding_mask, pos_emb=positions)
if self.layer_norm is not None:
......
......@@ -247,26 +247,28 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
def build_encoder(cls, args, task=None, embed_tokens=None):
encoder = S2TTransformerEncoder(args, task, embed_tokens)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info(
f"loaded pretrained encoder from: "
f"{args.load_pretrained_encoder_from}"
)
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
return encoder
@classmethod
def build_decoder(cls, args, task, embed_tokens):
decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None):
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_encoder_from, strict=False
)
logger.info(
f"loaded pretrained decoder from: "
f"{args.load_pretrained_decoder_from}"
)
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
)
return decoder
@classmethod
......@@ -346,7 +348,7 @@ class S2TTransformerEncoder(FairseqEncoder):
args.max_source_positions, args.encoder_embed_dim, self.padding_idx, pos_emb_type=self.attn_type
)
self.transformer_layers = nn.ModuleList(
self.layers = nn.ModuleList(
[TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
if args.encoder_normalize_before:
......@@ -372,6 +374,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.ctc_dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.softmax = nn.Softmax(dim=-1)
def forward(self, src_tokens, src_lengths):
x, input_lengths = self.subsample(src_tokens, src_lengths)
......@@ -384,7 +387,7 @@ class S2TTransformerEncoder(FairseqEncoder):
x = self.dropout_module(x)
positions = self.dropout_module(positions)
for layer in self.transformer_layers:
for layer in self.layers:
x = layer(x, encoder_padding_mask, pos_emb=positions)
if self.layer_norm is not None:
......@@ -404,17 +407,20 @@ class S2TTransformerEncoder(FairseqEncoder):
def compute_ctc_logit(self, encoder_out):
assert self.use_ctc, "CTC is not available!"
if isinstance(encoder_out, dict) and "encoder_out" in encoder_out:
encoder_state = encoder_out["encoder_out"][0]
else:
encoder_state = encoder_out
ctc_logit = self.ctc_projection(self.ctc_dropout_module(encoder_state))
return ctc_logit
def compute_ctc_prob(self, encoder_out):
def compute_ctc_prob(self, encoder_out, temperature=1.0):
assert self.use_ctc, "CTC is not available!"
ctc_logit = self.compute_ctc_logit(encoder_out)
ctc_logit = self.compute_ctc_logit(encoder_out) / temperature
return ctc_logit.Softmax(dim=-1)
return self.softmax(ctc_logit)
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论