Commit 408e2b95 by xuchen

fix bugs in data preparation and CTC decoding

parent 8f084189
@@ -185,7 +185,7 @@ class AudioDataset(Dataset):
         if need_waveform:
             offset = item.get('offset', False)
-            if offset:
+            if offset is not False:
                 waveform, sample_rate = torchaudio.load(audio,
                                                         frame_offset=offset,
                                                         num_frames=item["n_frames"])
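Why this fix matters: `if offset:` is false for `offset == 0`, so utterances that start at the very beginning of a recording skipped the offset branch entirely. A minimal sketch of the corrected check (item values and audio path are hypothetical; torchaudio.load keywords as in the released API):

    import torchaudio

    item = {"offset": 0, "n_frames": 16000}  # segment starting at sample 0

    offset = item.get("offset", False)
    # Old check: `if offset:` is False for 0, so the branch was skipped.
    if offset is not False:  # fixed check: 0 is a valid frame offset
        waveform, sample_rate = torchaudio.load(
            "utt.wav",                    # hypothetical audio path
            frame_offset=offset,          # start reading at this sample
            num_frames=item["n_frames"],  # read exactly this many samples
        )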
@@ -331,7 +331,7 @@ def process(args):
             audio_path = item["audio"]
             # add offset and frames info
-            if item.get("offset", False):
+            if item.get("offset", False) is not False:
                 audio_path = f"{audio_path}:{item['offset']}:{n_frames}"
                 manifest["audio"].append(audio_path)
             else:
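The same truthiness bug dropped the `path:offset:n_frames` suffix from the manifest whenever the offset was exactly 0. A short illustration with hypothetical values:

    item = {"audio": "rec1.wav", "offset": 0}
    n_frames = 48000

    audio_path = item["audio"]
    # Old `if item.get("offset", False):` lost the suffix for offset 0.
    if item.get("offset", False) is not False:
        audio_path = f"{audio_path}:{item['offset']}:{n_frames}"
    print(audio_path)  # rec1.wav:0:48000 -- suffix kept even at offset 0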
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
 @dataclass
 class CtcCriterionConfig(FairseqDataclass):
     zero_infinity: bool = field(
-        default=False,
+        default=True,
         metadata={"help": "zero inf loss when source length <= target length"},
     )
     sentence_avg: bool = II("optimization.sentence_avg")
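Background for the new default: CTC assigns an infinite loss when an utterance's input is shorter than its label sequence, and a single such utterance turns the whole batch loss into inf; `zero_infinity=True` zeroes those terms instead. A self-contained sketch with PyTorch's built-in CTC loss (toy shapes, not the fairseq criterion itself):

    import torch

    T, N, C = 2, 1, 5                    # 2 input steps, 1 sample, 5 classes
    log_probs = torch.randn(T, N, C).log_softmax(-1)
    targets = torch.tensor([[1, 2, 3]])  # 3 labels, 2 input steps: no valid alignment
    input_lengths = torch.tensor([T])
    target_lengths = torch.tensor([3])

    ctc_inf = torch.nn.CTCLoss(zero_infinity=False)
    ctc_zero = torch.nn.CTCLoss(zero_infinity=True)
    print(ctc_inf(log_probs, targets, input_lengths, target_lengths))   # tensor(inf)
    print(ctc_zero(log_probs, targets, input_lengths, target_lengths))  # tensor(0.)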
@@ -882,7 +882,8 @@ class PDSS2TTransformerEncoder(FairseqEncoder):
         if self.inter_ctc:
             logger.info("Intermedia CTC loss in layer %d" % self.ctc_layer)
-        embed_dim = self.pds_embed_dims[-1]
+        # embed_dim = self.pds_embed_dims[-1]
+        embed_dim = self.embed_dim
         if self.inter_ctc:
             ctc_layer = self.ctc_layer
         for i in range(self.pds_stages):
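The point of the fix: the CTC projection must be as wide as the hidden states it consumes, which is the encoder's output dimension `self.embed_dim`, not necessarily the last PDS stage's configured width. A minimal sketch of the invariant (dimensions and vocabulary size are hypothetical):

    import torch
    from torch import nn

    encoder_out_dim = 512              # stands in for self.embed_dim
    vocab_size = 1000                  # hypothetical dictionary size

    # in_features must equal the width of the states fed to the projection
    ctc_proj = nn.Linear(encoder_out_dim, vocab_size)

    x = torch.randn(10, 4, encoder_out_dim)  # (time, batch, dim) encoder states
    ctc_logit = ctc_proj(x)                  # (time, batch, vocab)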
@@ -638,8 +638,9 @@ class CTCDecoder(object):
                                           src_lengths=src_lengths)
         ctc_logit = encoder_outs["ctc_logit"][0].transpose(0, 1)
+        logit_length = (~encoder_outs["encoder_padding_mask"][0]).long().sum(-1)
         beam_results, beam_scores, time_steps, out_lens = self.ctc_decoder.decode(
-            utils.softmax(ctc_logit, -1), src_lengths
+            utils.softmax(ctc_logit, -1), logit_length
         )
         finalized = []
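`src_lengths` counts raw input frames, but the encoder subsamples in time, so the CTC logit sequence is shorter; feeding the raw lengths lets the beam search index past the valid logits. The true per-utterance logit lengths fall out of the encoder padding mask, as in the added line (shapes below are assumptions):

    import torch

    # encoder_padding_mask: (batch, reduced_time), True at padded positions
    encoder_padding_mask = torch.tensor([[False, False, False, True],
                                         [False, False, True,  True]])
    logit_length = (~encoder_padding_mask).long().sum(-1)
    print(logit_length)  # tensor([3, 2]): valid logit steps per utterance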
@@ -10,6 +10,7 @@ import math
 import torch
 from torch import nn
 import torch.nn.functional as F
+import logging
 from fairseq.modules.rotary_positional_embedding import (
     RotaryPositionalEmbedding,
     apply_rotary_pos_emb,
@@ -76,12 +77,17 @@ class ESPNETMultiHeadedAttention(nn.Module):
                 -1e8 if scores.dtype == torch.float32 else -1e4
                 # float("-inf"),  # (batch, head, time1, time2)
             )
         # self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+        scores = scores.clamp(min=-1e8 if scores.dtype == torch.float32 else -1e4,
+                              max=1e8 if scores.dtype == torch.float32 else 1e4)
         self.attn = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)  # (batch, head, time1, time2)
         if torch.isnan(self.attn).any():
-            import logging
-            logging.error("Tensor attention scores has nan.")
+            logging.warning("Tensor attention scores has nan.")
             # torch.save(scores, "scores.pt")
             # torch.save(self.attn, "attn.pt")
             # exit()
         p_attn = self.dropout(self.attn)
         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
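What the clamp guards against: in fp16 the raw scores can overflow to inf, and a softmax row containing inf evaluates to NaN (inf minus inf in the max-subtraction step); clamping to a dtype-safe range and running the softmax in float32 keeps the probabilities finite. A small standalone demonstration with synthetic scores:

    import torch
    import torch.nn.functional as F

    scores = torch.tensor([[65504.0, 65504.0, -1e4]], dtype=torch.float16)
    scores = scores * 2                # overflows to inf in fp16
    print(F.softmax(scores, dim=-1))   # rows containing inf come out as nan

    bound = 1e8 if scores.dtype == torch.float32 else 1e4
    safe = scores.clamp(min=-bound, max=bound)
    attn = F.softmax(safe, dim=-1, dtype=torch.float32).type_as(safe)
    print(attn)                        # finite probabilities, no nan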
@@ -350,6 +350,9 @@ class MultiheadAttention(nn.Module):
         if before_softmax:
             return attn_weights, v
+        attn_weights = attn_weights.clamp(min=-1e8 if attn_weights.dtype == torch.float32 else -1e4,
+                                          max=1e8 if attn_weights.dtype == torch.float32 else 1e4)
         attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
         attn_weights = attn_weights_float.type_as(attn_weights)
         attn_probs = self.dropout_module(attn_weights)
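Same guard as in ESPNETMultiHeadedAttention above; the bound is dtype-dependent because fp16 saturates at 65504, so +/-1e4 stays safely representable there while +/-1e8 is fine for fp32. A quick standalone check:

    import torch

    for dtype in (torch.float16, torch.float32):
        bound = 1e8 if dtype == torch.float32 else 1e4
        w = torch.full((2, 2), float("inf"), dtype=dtype)
        w = w.clamp(min=-bound, max=bound)   # inf pulled back to the bound
        assert torch.isfinite(w).all()
        print(dtype, w.max().item())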