Commit 4cffbd98 by xuchen

The fixed version of SATE

parent 8645e75b
@@ -100,6 +100,9 @@ class LabelSmoothedCrossEntropyCriterionWithCTC(
     def compute_ctc_loss(self, model, sample, encoder_out):
         transcript = sample["transcript"]
-        ctc_logit = model.encoder.compute_ctc_logit(encoder_out)
+        if "ctc_logit" in encoder_out:
+            ctc_logit = encoder_out["ctc_logit"][0]
+        else:
+            ctc_logit = model.encoder.compute_ctc_logit(encoder_out)
         lprobs = model.get_normalized_probs(
             [ctc_logit], log_probs=True
...
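This hunk lets the criterion reuse a CTC logit the encoder already produced (as the SATE encoder does) instead of always recomputing it. A minimal sketch of that fallback, with dummy tensors and a hypothetical helper name:

    import torch

    def get_ctc_logit(encoder_out, compute_ctc_logit):
        # Reuse the cached logit if the encoder stored one under "ctc_logit"
        # (a list, following fairseq's encoder-out convention); otherwise
        # fall back to recomputing it from the encoder states.
        if "ctc_logit" in encoder_out:
            return encoder_out["ctc_logit"][0]
        return compute_ctc_logit(encoder_out)

    # Dummy (time, batch, vocab) logits standing in for real encoder output.
    cached = {"ctc_logit": [torch.randn(50, 4, 1000)]}
    logit = get_ctc_logit(cached, compute_ctc_logit=lambda eo: None)
    print(logit.shape)  # torch.Size([50, 4, 1000])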
@@ -7,3 +7,4 @@ from .berard import *  # noqa
 from .convtransformer import *  # noqa
 from .s2t_transformer import *  # noqa
 from .s2t_conformer import *  # noqa
+from .s2t_sate import *  # noqa
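The wildcard import matters because fairseq registers model architectures at import time: the register_model decorator in s2t_sate.py only runs once the package imports the module. A toy sketch of that registry pattern (names simplified, not fairseq's actual code):

    MODEL_REGISTRY = {}

    def register_model(name):
        # Decorator that records the class under the given name at import time.
        def wrapper(cls):
            MODEL_REGISTRY[name] = cls
            return cls
        return wrapper

    @register_model("s2t_sate")
    class S2TSATEModel:
        pass

    # Without the wildcard import in __init__.py, this module would never be
    # executed and "s2t_sate" would never reach the registry.
    print("s2t_sate" in MODEL_REGISTRY)  # True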
@@ -92,12 +92,12 @@ class S2TConformerEncoder(S2TTransformerEncoder):
     def __init__(self, args, task=None, embed_tokens=None):
         super().__init__(args, task, embed_tokens)
-        self.conformer_layers = nn.ModuleList(
+        del self.layers
+        self.layers = nn.ModuleList(
             [ConformerEncoderLayer(args) for _ in range(args.encoder_layers)]
         )
-        del self.transformer_layers

     def forward(self, src_tokens, src_lengths):
         x, input_lengths = self.subsample(src_tokens, src_lengths)
         x = self.embed_scale * x
@@ -109,7 +109,7 @@ class S2TConformerEncoder(S2TTransformerEncoder):
         x = self.dropout_module(x)
         positions = self.dropout_module(positions)

-        for layer in self.conformer_layers:
+        for layer in self.layers:
             x = layer(x, encoder_padding_mask, pos_emb=positions)

         if self.layer_norm is not None:
...
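Both hunks follow one renaming: the base Transformer encoder now stores its blocks in self.layers (see the @@ -346,7 hunk further down), so the Conformer subclass deletes the inherited list and rebuilds it under the same attribute name, letting any inherited code that iterates self.layers work unchanged. A toy sketch of the pattern, with stand-in layer types:

    import torch.nn as nn

    class BaseEncoder(nn.Module):
        def __init__(self, dim=8, num_layers=2):
            super().__init__()
            # The parent builds Transformer-style blocks under one name ...
            self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    class ConformerStyleEncoder(BaseEncoder):
        def __init__(self, dim=8, num_layers=2):
            super().__init__(dim, num_layers)
            # ... and the subclass swaps in its own blocks under the same
            # name, so inherited loops over self.layers keep working.
            del self.layers
            self.layers = nn.ModuleList(nn.Identity() for _ in range(num_layers))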
@@ -247,26 +247,28 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
     def build_encoder(cls, args, task=None, embed_tokens=None):
         encoder = S2TTransformerEncoder(args, task, embed_tokens)
         if getattr(args, "load_pretrained_encoder_from", None):
-            encoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
-            )
             logger.info(
                 f"loaded pretrained encoder from: "
                 f"{args.load_pretrained_encoder_from}"
             )
+            encoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=encoder, checkpoint=args.load_pretrained_encoder_from, strict=False
+            )
         return encoder

     @classmethod
     def build_decoder(cls, args, task, embed_tokens):
         decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
         if getattr(args, "load_pretrained_decoder_from", None):
-            decoder = checkpoint_utils.load_pretrained_component_from_model(
-                component=decoder, checkpoint=args.load_pretrained_encoder_from, strict=False
-            )
             logger.info(
                 f"loaded pretrained decoder from: "
                 f"{args.load_pretrained_decoder_from}"
             )
+            decoder = checkpoint_utils.load_pretrained_component_from_model(
+                component=decoder, checkpoint=args.load_pretrained_decoder_from, strict=False
+            )
         return decoder

     @classmethod
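Note the actual fix in build_decoder: the old code passed checkpoint=args.load_pretrained_encoder_from when loading the decoder, a copy-paste error the new code corrects to load_pretrained_decoder_from. The strict=False flag appears to be this fork's extension of the upstream fairseq loader; assuming it is forwarded to load_state_dict, its effect can be shown with plain PyTorch in a minimal runnable sketch:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    state = model.state_dict()
    state["extra.weight"] = torch.zeros(1)  # key the model does not have
    # strict=True would raise on the unexpected key; strict=False reports and
    # skips it, which is what you want when a pretrained ASR/MT checkpoint
    # carries parameters the component being initialized does not use.
    result = model.load_state_dict(state, strict=False)
    print(result.unexpected_keys)  # ['extra.weight']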
@@ -346,7 +348,7 @@ class S2TTransformerEncoder(FairseqEncoder):
             args.max_source_positions, args.encoder_embed_dim, self.padding_idx, pos_emb_type=self.attn_type
         )

-        self.transformer_layers = nn.ModuleList(
+        self.layers = nn.ModuleList(
             [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
         )

         if args.encoder_normalize_before:
@@ -372,6 +374,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         self.ctc_dropout_module = FairseqDropout(
             p=args.dropout, module_name=self.__class__.__name__
         )
+        self.softmax = nn.Softmax(dim=-1)

     def forward(self, src_tokens, src_lengths):
         x, input_lengths = self.subsample(src_tokens, src_lengths)
@@ -384,7 +387,7 @@ class S2TTransformerEncoder(FairseqEncoder):
         x = self.dropout_module(x)
         positions = self.dropout_module(positions)

-        for layer in self.transformer_layers:
+        for layer in self.layers:
             x = layer(x, encoder_padding_mask, pos_emb=positions)

         if self.layer_norm is not None:
@@ -404,17 +407,20 @@ class S2TTransformerEncoder(FairseqEncoder):
     def compute_ctc_logit(self, encoder_out):
         assert self.use_ctc, "CTC is not available!"

-        encoder_state = encoder_out["encoder_out"][0]
+        if isinstance(encoder_out, dict) and "encoder_out" in encoder_out:
+            encoder_state = encoder_out["encoder_out"][0]
+        else:
+            encoder_state = encoder_out
         ctc_logit = self.ctc_projection(self.ctc_dropout_module(encoder_state))

         return ctc_logit

-    def compute_ctc_prob(self, encoder_out):
+    def compute_ctc_prob(self, encoder_out, temperature=1.0):
         assert self.use_ctc, "CTC is not available!"

-        ctc_logit = self.compute_ctc_logit(encoder_out)
+        ctc_logit = self.compute_ctc_logit(encoder_out) / temperature

-        return ctc_logit.Softmax(dim=-1)
+        return self.softmax(ctc_logit)

     def reorder_encoder_out(self, encoder_out, new_order):
         new_encoder_out = (
...
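The last hunk fixes a genuine crash: tensors have no .Softmax attribute (nn.Softmax is a module class), so the old compute_ctc_prob raised AttributeError. The new code applies the nn.Softmax instance registered in __init__, now with optional temperature scaling of the logits. A minimal sketch of that temperature-scaled softmax, with dummy shapes:

    import torch
    import torch.nn as nn

    softmax = nn.Softmax(dim=-1)

    def ctc_prob(ctc_logit: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        # T > 1 flattens the distribution, T < 1 sharpens it;
        # T = 1.0 reproduces the plain softmax.
        return softmax(ctc_logit / temperature)

    logit = torch.randn(50, 4, 1000)  # (time, batch, vocab)
    assert torch.allclose(ctc_prob(logit), logit.softmax(dim=-1))
    print(ctc_prob(logit, temperature=2.0).sum(-1)[0, 0])  # ~1.0 per position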