Commit 090afc4d by xuchen

implement DLCL (dynamic linear combination of layers) for the S2T task

parent db29e01d
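For context: DLCL (the dynamic linear combination of layers from Wang et al., ACL 2019) feeds each layer a learned weighted combination of all earlier layer outputs, including the embedding, instead of only the previous layer's output. A minimal, self-contained sketch of the learnable dense variant wired in below; names and shapes are illustrative, not this repository's exact API:

import torch
import torch.nn as nn

class TinyLearnableDenseHistory(nn.Module):
    """Toy history: keep every output seen so far and return a learned,
    row-normalized combination of them (the embedding output is stored
    as-is, later layer outputs are layer-normalized before storage)."""

    def __init__(self, n_layers, dim):
        super().__init__()
        size = n_layers + 1  # +1 for the embedding output
        weight = torch.ones(size, size).tril()
        self.weight = nn.Parameter(weight / weight.sum(1, keepdim=True))
        self.layer_norms = nn.ModuleList(nn.LayerNorm(dim) for _ in range(n_layers))
        self.outputs = []

    def clean(self):
        self.outputs = []

    def add(self, x):
        if self.outputs:  # layer outputs are normalized before being stored
            x = self.layer_norms[len(self.outputs) - 1](x)
        self.outputs.append(x)

    def pop(self):
        n = len(self.outputs)
        stacked = torch.stack(self.outputs, dim=-1)  # (..., n)
        w = self.weight[n - 1, :n]                   # weights for the n outputs so far
        return (stacked * w).sum(dim=-1)

The hunks below follow exactly this protocol: clean() at the start of forward, add() after the embedding, pop()/add() around every layer, and one final pop() before the output layer norm.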
......@@ -55,8 +55,10 @@ class DLCLTransformerModel(TransformerModel):
# dense layer parameters
parser.add_argument('--encoder-history-type',
default="learnable_dense",
help='encoder layer history type')
parser.add_argument('--decoder-history-type',
default="learnable_dense",
help='decoder layer history type')
parser.add_argument('--encoder-integration-type', choices=['avg', 'sum'],
help='encoder layer integration type')
......
......@@ -99,6 +99,9 @@ class S2TConformerEncoder(S2TTransformerEncoder):
)
def forward(self, src_tokens, src_lengths):
if self.history is not None:
self.history.clean()
x, input_lengths = self.subsample(src_tokens, src_lengths)
x = self.embed_scale * x
......@@ -109,8 +112,19 @@ class S2TConformerEncoder(S2TTransformerEncoder):
x = self.dropout_module(x)
positions = self.dropout_module(positions)
# add emb into history
if self.history is not None:
self.history.add(x)
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
x = layer(x, encoder_padding_mask, pos_emb=positions)
if self.history is not None:
self.history.add(x)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
......
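The same wiring appears in each encoder forward pass (and, further down, in the decoder). Schematically, assuming a history object with the clean/add/pop interface sketched above:

def encode(x, layers, history=None, layer_norm=None):
    # `layers`, `history`, `layer_norm` stand in for the module attributes in the diff
    if history is not None:
        history.clean()
        history.add(x)         # embedding output enters the history
    for layer in layers:
        if history is not None:
            x = history.pop()  # learned combination of everything seen so far
        x = layer(x)
        if history is not None:
            history.add(x)     # current layer output joins the history
    if history is not None:
        x = history.pop()      # final combination feeds the output layer norm
    if layer_norm is not None:
        x = layer_norm(x)
    return x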
......@@ -16,13 +16,15 @@ from fairseq.models.speech_to_text import (
S2TTransformerModel,
S2TTransformerEncoder,
S2TConformerEncoder,
S2TConformerModel)
S2TConformerModel
)
from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler
from fairseq.modules import (
FairseqDropout,
LayerNorm,
PositionalEmbedding,
TransformerEncoderLayer,
LearnableDenseLayerHistory
)
logger = logging.getLogger(__name__)
......@@ -208,10 +210,17 @@ class TextEncoder(FairseqEncoder):
else:
self.layer_norm = None
def forward(self, x, encoder_padding_mask=None, positions=None):
def forward(self, x, encoder_padding_mask=None, positions=None, history=None):
for layer in self.layers:
if history is not None:
x = history.pop()
x = layer(x, encoder_padding_mask, pos_emb=positions)
if history is not None:
history.add(x)
if history is not None:
x = history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
......@@ -241,7 +250,16 @@ class S2TSATEEncoder(FairseqEncoder):
# text encoder
self.text_encoder = TextEncoder(args, embed_tokens)
if getattr(args, "use_enc_dlcl", False):
normalize_before = args.encoder_normalize_before
layer_num = args.encoder_layers + args.text_encoder_layers + 1
self.history = LearnableDenseLayerHistory(normalize_before, layer_num, args.encoder_embed_dim, True)
else:
self.history = None
def forward(self, src_tokens, src_lengths):
if self.history is not None:
self.history.clean()
acoustic_encoder_out = self.acoustic_encoder(src_tokens, src_lengths)
......@@ -254,7 +272,18 @@ class S2TSATEEncoder(FairseqEncoder):
x, positions = self.adapter(x, encoder_padding_mask)
x = self.text_encoder(x, encoder_padding_mask, positions)
if self.history is not None:
acoustic_history = self.acoustic_encoder.history
layer_num = acoustic_history.layer_num
idx = torch.arange(layer_num).unsqueeze(0).T.repeat(1, layer_num).to(x.device)
# copy the acoustic weights into the top-left block of the joint weight matrix, in place
self.history.weight.data.scatter_(0, idx, acoustic_history.weight.data)
self.history.layers.extend(acoustic_history.layers)
self.history.count = acoustic_history.count
self.history.sum = acoustic_history.sum
self.history.add(x)
x = self.text_encoder(x, encoder_padding_mask, positions, self.history)
return {
"ctc_logit": [ctc_logit], # T x B x C
......
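The block above hands the acoustic encoder's history over to the joint SATE history before the text encoder runs: the learned acoustic combination weights are copied into the top-left block of the joint weight matrix, and the stored layer outputs, count, and running sum are carried over. Because the index tensor maps row i of the source to row i of the destination, the in-place scatter amounts to a block copy; a small illustration with assumed sizes:

import torch

ac_size = 13   # acoustic history size (acoustic layers + embedding); illustrative
total = 20     # joint history size; illustrative

joint = torch.ones(total, total).tril()
joint = joint / joint.sum(1, keepdim=True)
acoustic = torch.rand(ac_size, ac_size).tril()   # stands in for the learned acoustic weights

idx = torch.arange(ac_size).unsqueeze(1).repeat(1, ac_size)
joint.scatter_(0, idx, acoustic)
# equivalent, more direct form:
# joint[:ac_size, :ac_size] = acoustic
assert torch.equal(joint[:ac_size, :ac_size], acoustic)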
......@@ -19,6 +19,7 @@ from fairseq.modules import (
LayerNorm,
PositionalEmbedding,
TransformerEncoderLayer,
CreateLayerHistory,
)
from torch import Tensor
......@@ -247,6 +248,28 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
metavar="STR",
help="freeze the module of the decoder",
)
parser.add_argument(
"--use-enc-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
"--use-dec-dlcl",
default=False,
action='store_true',
help="use dlcl encoder",
)
parser.add_argument(
'--encoder-history-type',
default="learnable_dense",
help='encoder layer history type'
)
parser.add_argument(
'--decoder-history-type',
default="learnable_dense",
help='decoder layer history type'
)
pass
@classmethod
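With these options registered, DLCL can be enabled from the command line; a hedged example (the data path and the --task/--arch values are placeholders, the last two lines are the flags added here):

fairseq-train ${DATA_DIR} \
    --task speech_to_text --arch s2t_transformer_s \
    --use-enc-dlcl --encoder-history-type learnable_dense \
    --use-dec-dlcl --decoder-history-type learnable_dense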
......@@ -362,6 +385,11 @@ class S2TTransformerEncoder(FairseqEncoder):
else:
self.layer_norm = None
if getattr(args, "use_enc_dlcl", False):
self.history = CreateLayerHistory(args, is_encoder=True)
else:
self.history = None
self.use_ctc = "sate" in args.arch or \
(("ctc" in getattr(args, "criterion", "")) and \
(getattr(args, "ctc_weight", 0) > 0))
......@@ -384,6 +412,10 @@ class S2TTransformerEncoder(FairseqEncoder):
self.softmax = nn.Softmax(dim=-1)
def forward(self, src_tokens, src_lengths):
if self.history is not None:
self.history.clean()
x, input_lengths = self.subsample(src_tokens, src_lengths)
x = self.embed_scale * x
......@@ -394,8 +426,19 @@ class S2TTransformerEncoder(FairseqEncoder):
x = self.dropout_module(x)
positions = self.dropout_module(positions)
# add emb into history
if self.history is not None:
self.history.add(x)
for layer in self.layers:
if self.history is not None:
x = self.history.pop()
x = layer(x, encoder_padding_mask, pos_emb=positions)
if self.history is not None:
self.history.add(x)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
......
......@@ -27,6 +27,7 @@ from fairseq.modules import (
SinusoidalPositionalEmbedding,
TransformerDecoderLayer,
TransformerEncoderLayer,
CreateLayerHistory
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
......@@ -778,6 +779,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else:
self.layer_norm = None
if getattr(args, "use_dec_dlcl", False):
self.history = CreateLayerHistory(args, is_encoder=False)
else:
self.history = None
self.project_out_dim = (
Linear(embed_dim, self.output_embed_dim, bias=False)
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
......@@ -913,6 +919,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if self.history is not None:
self.history.clean()
if alignment_layer is None:
alignment_layer = self.num_layers - 1
......@@ -948,6 +957,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# add emb into history
if self.history is not None:
self.history.add(x)
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
......@@ -956,6 +969,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
attn: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
for idx, layer in enumerate(self.layers):
if self.history is not None:
x = self.history.pop()
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
......@@ -982,6 +998,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if self.history is not None:
self.history.add(x)
if attn is not None:
if alignment_heads is not None:
......@@ -990,6 +1008,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
# average probabilities over heads
attn = attn.mean(dim=0)
if self.history is not None:
x = self.history.pop()
if self.layer_norm is not None:
x = self.layer_norm(x)
......
......@@ -21,7 +21,7 @@ from .grad_multiply import GradMultiply
from .gumbel_vector_quantizer import GumbelVectorQuantizer
from .kmeans_vector_quantizer import KmeansVectorQuantizer
from .layer_drop import LayerDropModuleList
from .layer_history import CreateLayerHistory
from .layer_history import CreateLayerHistory, LearnableDenseLayerHistory
from .layer_norm import Fp32LayerNorm, LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
......@@ -65,6 +65,7 @@ __all__ = [
"KmeansVectorQuantizer",
"LayerDropModuleList",
"LayerNorm",
"LearnableDenseLayerHistory",
"LearnedPositionalEmbedding",
"LightweightConv1dTBC",
"LightweightConv",
......
......@@ -7,35 +7,42 @@ import numpy as np
def CreateLayerHistory(args, is_encoder):
history_type = args.encoder_history_type if is_encoder else args.decoder_history_type
normalize_before = args.encoder_normalize_before if is_encoder else args.decoder_normalize_before
layer_num = args.encoder_layers if is_encoder else args.decoder_layers
dim = args.encoder_embed_dim if is_encoder else args.decoder_embed_dim
if history_type is None:
return None
elif history_type == "residual":
return ResidualLayerHistory(args, is_encoder)
return ResidualLayerHistory(normalize_before, layer_num, dim, is_encoder)
elif history_type == "dense":
return DenseLayerHistory(args, is_encoder)
integration_type = getattr(args, 'encoder_integration_type', 'avg') if is_encoder else \
getattr(args, 'decoder_integration_type', 'avg')
windows_size = getattr(args, 'encoder_windows_size', -1) if is_encoder else \
getattr(args, 'decoder_windows_size', -1)
return DenseLayerHistory(normalize_before, layer_num, dim, is_encoder, integration_type, windows_size)
elif history_type == "learnable_dense":
return LearnableDenseLayerHistory(args, is_encoder)
return LearnableDenseLayerHistory(normalize_before, layer_num, dim, is_encoder)
elif history_type == "learnable_dense_mask":
return LearnableDenseMaskLayerHistory(args, is_encoder)
return LearnableDenseMaskLayerHistory(normalize_before, layer_num, dim, is_encoder)
elif history_type == "learnable_dense_nonorm":
return LearnableDenseNoNormLayerHistory(args, is_encoder)
return LearnableDenseNoNormLayerHistory(normalize_before, layer_num, dim, is_encoder)
elif history_type == "gru":
return GruLayerHistory(args, is_encoder)
return GruLayerHistory(normalize_before, layer_num, dim, is_encoder)
else:
raise ValueError("unknown layer history type: {}".format(history_type))
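A small usage sketch for the factory; the attribute names follow the ones read above, and the values are illustrative:

from argparse import Namespace

from fairseq.modules import CreateLayerHistory

args = Namespace(
    encoder_history_type="learnable_dense",
    encoder_normalize_before=True,
    encoder_layers=12,
    encoder_embed_dim=256,
)
history = CreateLayerHistory(args, is_encoder=True)
# -> LearnableDenseLayerHistory(normalize_before=True, layer_num=12, dim=256, is_encoder=True)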
class BaseLayerHistory(nn.Module):
def __init__(self, args, is_encoder):
def __init__(self, normalize_before, layer_num, dim, is_encoder):
super(BaseLayerHistory, self).__init__()
self.is_encoder = is_encoder
self.normalize_before = args.encoder_normalize_before if is_encoder else args.decoder_normalize_before
self.normalize_before = normalize_before
# the first layer (aka. embedding layer) does not have layer normalization
layers = args.encoder_layers if is_encoder else args.decoder_layers
dim = args.encoder_embed_dim if is_encoder else args.decoder_embed_dim
self.layer_norms = nn.ModuleList(LayerNorm(dim) for _ in range(layers))
self.layer_norms = nn.ModuleList(LayerNorm(dim) for _ in range(layer_num))
def add(self, layer):
raise NotImplementedError
......@@ -52,8 +59,8 @@ class ResidualLayerHistory(BaseLayerHistory):
x_n = x_{n-1} + y_{n-1}
"""
def __init__(self, args, is_encoder):
super(ResidualLayerHistory, self).__init__(args, is_encoder)
def __init__(self, normalize_before, layer_num, dim, is_encoder):
super(ResidualLayerHistory, self).__init__(normalize_before, layer_num, dim, is_encoder)
self.count = 0
self.x = None
self.y = None
......@@ -90,19 +97,17 @@ class DenseLayerHistory(BaseLayerHistory):
x_n = (x_1 + y_1 + y_2 + ... y_{n-1}) / n
"""
def __init__(self, args, is_encoder):
super(DenseLayerHistory, self).__init__(args, is_encoder)
def __init__(self, normalize_before, layer_num, dim, is_encoder, integration_type, windows_size):
super(DenseLayerHistory, self).__init__(normalize_before, layer_num, dim, is_encoder)
self.sum = None
self.count = 0
self.individuals = None # stores past individual outputs, used when windows_size > 0
self.integration_type = getattr(args, 'encoder_integration_type', 'avg') if is_encoder else \
getattr(args, 'decoder_integration_type', 'avg')
self.integration_type = integration_type
# windows_size = 1 means the residual connection is not used
self.windows_size = getattr(args, 'encoder_windows_size', -1) if is_encoder else \
getattr(args, 'decoder_windows_size', -1)
self.windows_size = windows_size
if self.windows_size > 0:
assert self.windows_size <= (args.encoder_layers + 1) if is_encoder else (args.decoder_layers + 1)
assert self.windows_size <= 1 + layer_num
self.individuals = queue.Queue(self.windows_size)
def add(self, layer):
......@@ -151,13 +156,14 @@ class LearnableDenseLayerHistory(BaseLayerHistory):
x_n = w_{n,1} * x_1 + w_{n,2} * y_1 + ... + w_{n,n} * y_{n-1}, with learnable weights w initialized to a uniform average
"""
def __init__(self, args, is_encoder):
super(LearnableDenseLayerHistory, self).__init__(args, is_encoder)
def __init__(self, normalize_before, layer_num, dim, is_encoder):
super(LearnableDenseLayerHistory, self).__init__(normalize_before, layer_num, dim, is_encoder)
self.sum = None
self.count = 0
self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
self.layer_num = 1 + layer_num
self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
self.layers = []
def extra_repr(self):
return 'n_layers={layer_num}, '.format(**self.__dict__)
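For intuition, the initial weight matrix is a row-normalized lower triangle, so before any training each position simply averages everything it has seen so far. With a history of size 3:

import torch

w = torch.ones(3, 3).tril()
w = w / w.sum(1, keepdim=True)
# tensor([[1.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000],
#         [0.3333, 0.3333, 0.3333]])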
......@@ -198,11 +204,11 @@ class LearnableDenseMaskLayerHistory(BaseLayerHistory):
same as LearnableDenseLayerHistory, but the learnable weights are restricted by a fixed 0/1 mask loaded from a text file
"""
def __init__(self, args, is_encoder):
super(LearnableDenseMaskLayerHistory, self).__init__(args, is_encoder)
def __init__(self, normalize_before, layer_num, dim, is_encoder):
super(LearnableDenseMaskLayerHistory, self).__init__(normalize_before, layer_num, dim, is_encoder)
self.sum = None
self.count = 0
self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
self.layer_num = 1 + layer_num
if is_encoder:
self.weight_mask = np.loadtxt("encoder_mask.txt", dtype=float, delimiter=' ')
else:
......@@ -246,11 +252,11 @@ class LearnableDenseNoNormLayerHistory(BaseLayerHistory):
same as LearnableDenseLayerHistory, but the stored outputs are combined without layer normalization
"""
def __init__(self, args, is_encoder):
super(LearnableDenseNoNormLayerHistory, self).__init__(args, is_encoder)
def __init__(self, normalize_before, layer_num, dim, is_encoder):
super(LearnableDenseNoNormLayerHistory, self).__init__(normalize_before, layer_num, dim, is_encoder)
self.sum = None
self.count = 0
self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
self.layer_num = 1 + layer_num
self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
self.layers = []
......@@ -286,13 +292,13 @@ class GruLayerHistory(BaseLayerHistory):
the aggregated state is produced by a GRU cell applied over successive layer outputs
"""
def __init__(self, args, is_encoder):
super(GruLayerHistory, self).__init__(args, is_encoder)
def __init__(self, normalize_before, layer_num, dim, is_encoder):
super(GruLayerHistory, self).__init__(normalize_before, layer_num, dim, is_encoder)
self.count = 0
self.gru = nn.GRUCell(args.encoder_embed_dim, args.encoder_embed_dim)
self.gru = nn.GRUCell(dim, dim)
self.gru_cells = []
self.layer_norms = nn.ModuleList(LayerNorm(args.encoder_embed_dim) for _ in range(args.decoder_layers + 1))
self.decoder_layers = args.decoder_layers
self.layer_norms = nn.ModuleList(LayerNorm(dim) for _ in range(layer_num + 1))
self.decoder_layers = layer_num
def compute_gru(self, layer_output):
if len(self.gru_cells) == 0:
......