Commit 408e2b95 by xuchen

fix the bugs during prepare and CTC decoding

parent 8f084189
...@@ -185,7 +185,7 @@ class AudioDataset(Dataset): ...@@ -185,7 +185,7 @@ class AudioDataset(Dataset):
if need_waveform: if need_waveform:
offset = item.get('offset', False) offset = item.get('offset', False)
if offset: if offset is not False:
waveform, sample_rate = torchaudio.load(audio, waveform, sample_rate = torchaudio.load(audio,
frame_offset=offset, frame_offset=offset,
num_frames=item["n_frames"]) num_frames=item["n_frames"])
...@@ -331,7 +331,7 @@ def process(args): ...@@ -331,7 +331,7 @@ def process(args):
audio_path = item["audio"] audio_path = item["audio"]
# add offset and frames info # add offset and frames info
if item.get("offset", False): if item.get("offset", False) is not False:
audio_path = f"{audio_path}:{item['offset']}:{n_frames}" audio_path = f"{audio_path}:{item['offset']}:{n_frames}"
manifest["audio"].append(audio_path) manifest["audio"].append(audio_path)
else: else:
......
...@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) ...@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class CtcCriterionConfig(FairseqDataclass): class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field( zero_infinity: bool = field(
default=False, default=True,
metadata={"help": "zero inf loss when source length <= target length"}, metadata={"help": "zero inf loss when source length <= target length"},
) )
sentence_avg: bool = II("optimization.sentence_avg") sentence_avg: bool = II("optimization.sentence_avg")
......
...@@ -882,7 +882,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder): ...@@ -882,7 +882,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
if self.inter_ctc: if self.inter_ctc:
logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer) logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
embed_dim = self.pds_embed_dims[-1] # embed_dim = self.pds_embed_dims[-1]
embed_dim = self.embed_dim
if self.inter_ctc: if self.inter_ctc:
ctc_layer = self.ctc_layer ctc_layer = self.ctc_layer
for i in range(self.pds_stages): for i in range(self.pds_stages):
......
...@@ -638,8 +638,9 @@ class CTCDecoder(object): ...@@ -638,8 +638,9 @@ class CTCDecoder(object):
src_lengths=src_lengths) src_lengths=src_lengths)
ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1) ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
logit_length = (~encoder_outs["encoder_padding_mask"][0]).long().sum(-1)
beam_results, beam_scores, time_steps, out_lens = self.ctc_decoder.decode( beam_results, beam_scores, time_steps, out_lens = self.ctc_decoder.decode(
utils.softmax(ctc_logit, -1), src_lengths utils.softmax(ctc_logit, -1), logit_length
) )
finalized = [] finalized = []
......
...@@ -10,6 +10,7 @@ import math ...@@ -10,6 +10,7 @@ import math
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
import logging
from fairseq.modules.rotary_positional_embedding import ( from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding, RotaryPositionalEmbedding,
apply_rotary_pos_emb, apply_rotary_pos_emb,
...@@ -76,12 +77,17 @@ class ESPNETMultiHeadedAttention(nn.Module): ...@@ -76,12 +77,17 @@ class ESPNETMultiHeadedAttention(nn.Module):
-1e8 if scores.dtype == torch.float32 else -1e4 -1e8 if scores.dtype == torch.float32 else -1e4
# float("-inf"), # (batch, head, time1, time2) # float("-inf"), # (batch, head, time1, time2)
) )
# self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
scores = scores.clamp(min=-1e8 if scores.dtype == torch.float32 else -1e4, scores = scores.clamp(min=-1e8 if scores.dtype == torch.float32 else -1e4,
max=1e8 if scores.dtype == torch.float32 else 1e4) max=1e8 if scores.dtype == torch.float32 else 1e4)
self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores) # (batch, head, time1, time2) self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores) # (batch, head, time1, time2)
if torch.isnan(self.attn).any(): if torch.isnan(self.attn).any():
import logging logging.warning("Tensor attention scores has nan.")
logging.error("Tensor attention scores has nan.") # torch.save(scores, "scores.pt")
# torch.save(self.attn, "attn.pt")
# exit()
p_attn = self.dropout(self.attn) p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
......
...@@ -350,6 +350,9 @@ class MultiheadAttention(nn.Module): ...@@ -350,6 +350,9 @@ class MultiheadAttention(nn.Module):
if before_softmax: if before_softmax:
return attn_weights, v return attn_weights, v
attn_weights = attn_weights.clamp(min=-1e8 if attn_weights.dtype == torch.float32 else -1e4,
max=1e8 if attn_weights.dtype == torch.float32 else 1e4)
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
attn_weights = attn_weights_float.type_as(attn_weights) attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights) attn_probs = self.dropout_module(attn_weights)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论