Commit 6358474e by xuchen

fix the yaml file

parent 537d8744
...@@ -35,6 +35,7 @@ class FeedForwardModule(torch.nn.Module): ...@@ -35,6 +35,7 @@ class FeedForwardModule(torch.nn.Module):
dropout2, dropout2,
activation_fn="relu", activation_fn="relu",
bias=True, bias=True,
output_feat=None,
): ):
""" """
Args: Args:
...@@ -47,8 +48,10 @@ class FeedForwardModule(torch.nn.Module): ...@@ -47,8 +48,10 @@ class FeedForwardModule(torch.nn.Module):
""" """
super(FeedForwardModule, self).__init__() super(FeedForwardModule, self).__init__()
if output_feat is None:
output_feat = input_feat
self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias) self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias)
self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias) self.w_2 = torch.nn.Linear(hidden_units, output_feat, bias=bias)
self.dropout1 = torch.nn.Dropout(dropout1) self.dropout1 = torch.nn.Dropout(dropout1)
self.dropout2 = torch.nn.Dropout(dropout2) self.dropout2 = torch.nn.Dropout(dropout2)
self.activation = get_activation_class(activation_fn) self.activation = get_activation_class(activation_fn)
...@@ -132,13 +135,6 @@ class S2TTransformerS2EncoderLayer(nn.Module): ...@@ -132,13 +135,6 @@ class S2TTransformerS2EncoderLayer(nn.Module):
self.conv_module = None self.conv_module = None
self.final_norm = None self.final_norm = None
self.ffn = FeedForwardModule(
embed_dim,
ffn_dim,
dropout,
dropout,
activation
)
self.ffn_norm = LayerNorm(self.embed_dim) self.ffn_norm = LayerNorm(self.embed_dim)
self.s2_norm = LayerNorm(self.embed_dim) self.s2_norm = LayerNorm(self.embed_dim)
...@@ -160,6 +156,15 @@ class S2TTransformerS2EncoderLayer(nn.Module): ...@@ -160,6 +156,15 @@ class S2TTransformerS2EncoderLayer(nn.Module):
self.league_drop_net_prob = args.encoder_league_drop_net_prob self.league_drop_net_prob = args.encoder_league_drop_net_prob
self.league_drop_net_mix = args.encoder_league_drop_net_mix self.league_drop_net_mix = args.encoder_league_drop_net_mix
self.ffn = FeedForwardModule(
embed_dim,
ffn_dim,
dropout,
dropout,
activation,
output_feat=embed_dim * 2 if self.encoder_collaboration_mode == "concat" else None
)
def get_ratio(self): def get_ratio(self):
if self.league_drop_net: if self.league_drop_net:
frand = float(uniform(0, 1)) frand = float(uniform(0, 1))
...@@ -319,7 +324,7 @@ class S2TTransformerS2EncoderLayer(nn.Module): ...@@ -319,7 +324,7 @@ class S2TTransformerS2EncoderLayer(nn.Module):
attn_mask=attn_mask, attn_mask=attn_mask,
) )
x = self.dropout_module(x) x = self.dropout_module(x)
if s2 is None or self.encoder_collaboration_mode != "parallel": if s2 is None or self.encoder_collaboration_mode == "serial":
x = self.residual_connection(x, residual) x = self.residual_connection(x, residual)
if not self.normalize_before: if not self.normalize_before:
x = self.self_attn_layer_norm(x) x = self.self_attn_layer_norm(x)
...@@ -338,7 +343,7 @@ class S2TTransformerS2EncoderLayer(nn.Module): ...@@ -338,7 +343,7 @@ class S2TTransformerS2EncoderLayer(nn.Module):
) )
x = self.dropout_module(x) x = self.dropout_module(x)
x = self.residual_connection(x, residual) x = self.residual_connection(x, residual)
elif self.encoder_collaboration_mode == "parallel": elif self.encoder_collaboration_mode in ["parallel", "concat"]:
x2, _ = self.s2_attn( x2, _ = self.s2_attn(
query=attn_x, query=attn_x,
key=s2, key=s2,
......
...@@ -65,6 +65,7 @@ class Adapter(nn.Module): ...@@ -65,6 +65,7 @@ class Adapter(nn.Module):
self.adapter_type = adapter_type self.adapter_type = adapter_type
self.cal_linear = False self.cal_linear = False
self.cal_context = False self.cal_context = False
self.shrink = False
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]: if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.cal_linear = True self.cal_linear = True
...@@ -75,7 +76,8 @@ class Adapter(nn.Module): ...@@ -75,7 +76,8 @@ class Adapter(nn.Module):
LayerNorm(dim), LayerNorm(dim),
) )
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]: if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league",
"league_shrink", "inter_league_shrink", "context_shrink"]:
self.cal_context = True self.cal_context = True
self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization
nn.init.normal_(self.embed_adapter.weight, mean=0, std=dim ** -0.5) nn.init.normal_(self.embed_adapter.weight, mean=0, std=dim ** -0.5)
...@@ -92,10 +94,11 @@ class Adapter(nn.Module): ...@@ -92,10 +94,11 @@ class Adapter(nn.Module):
self.gate_linear2 = nn.Linear(dim, dim) self.gate_linear2 = nn.Linear(dim, dim)
# additional strategy # additional strategy
if self.adapter_type == "shrink": if self.adapter_type in ["shrink", "league_shrink", "inter_league_shrink", "context_shrink"]:
assert strategy is not None assert strategy is not None
ctc_compress_strategy = strategy.get("ctc_compress_strategy", "avg") ctc_compress_strategy = strategy.get("ctc_compress_strategy", "avg")
self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy) self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)
self.shrink = True
logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy) logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
if self.cal_context: if self.cal_context:
...@@ -125,46 +128,49 @@ class Adapter(nn.Module): ...@@ -125,46 +128,49 @@ class Adapter(nn.Module):
representation, logit = x representation, logit = x
seq_len, bsz, dim = representation.size() seq_len, bsz, dim = representation.size()
linear_out = None distribution = None
soft_out = None if self.cal_context or self.shrink:
if self.cal_linear:
linear_out = self.linear_adapter(representation)
if self.cal_context:
if self.training and self.gumbel: if self.training and self.gumbel:
distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard) distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard)
else: else:
distribution = F.softmax(logit / self.distribution_temperature, dim=-1) distribution = F.softmax(logit / self.distribution_temperature, dim=-1)
linear_out = None
soft_out = None
out = None
if self.cal_linear:
linear_out = self.linear_adapter(representation)
if self.cal_context:
vocab_size = distribution.size(-1) vocab_size = distribution.size(-1)
distribution = distribution.contiguous().view(-1, vocab_size) distribution_2d = distribution.contiguous().view(-1, vocab_size)
org_distribution = distribution
if self.distribution_cutoff is not None: if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), vocab_size - 1) pass
# cutoff = min(int(self.distribution_cutoff), vocab_size - 1)
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1] # threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
# distribution = torch.where( # distribution_2d = torch.where(
# org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution) # distribution > threshold, distribution, torch.zeros_like(distribution)
# ) # )
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True) # threshold = distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
# distribution = torch.where( # distribution_2d = torch.where(
# threshold > 0.9, org_distribution, torch.zeros_like(org_distribution) # threshold > 0.9, distribution, torch.zeros_like(distribution)
# ) # )
# distribution = distribution.view(-1, vocab_size) # distribution_2d = distribution_2d.view(-1, vocab_size)
distribution[:, 0] = 0 # distribution_2d[:, 0] = 0
distribution = distribution / distribution.sum(-1, keepdim=True) # distribution_2d = distribution_2d / distribution_2d.sum(-1, keepdim=True)
if self.ground_truth_ratio > 0 and oracle is not None: if self.ground_truth_ratio > 0 and oracle is not None:
oracle = oracle.unsqueeze(-1) oracle = oracle.unsqueeze(-1)
oracle_one_hot = (oracle == torch.arange(vocab_size, device=oracle.device).unsqueeze(0)).\ oracle_one_hot = (oracle == torch.arange(vocab_size, device=oracle.device).unsqueeze(0)). \
to(distribution.dtype).transpose(0, 1) to(distribution.dtype).transpose(0, 1)
oracle_mask = oracle_mask.transpose(0, 1).unsqueeze(-1).repeat(1, 1, vocab_size) oracle_mask = oracle_mask.transpose(0, 1).unsqueeze(-1).repeat(1, 1, vocab_size)
modify_dist = oracle_mask * oracle_one_hot + ~oracle_mask * org_distribution modify_dist = oracle_mask * oracle_one_hot + ~oracle_mask * distribution
soft_out = torch.mm(modify_dist.view(-1, vocab_size), self.embed_adapter.weight).view(seq_len, bsz, -1) soft_out = torch.mm(modify_dist.view(-1, vocab_size), self.embed_adapter.weight).view(seq_len, bsz, -1)
else: else:
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1) soft_out = torch.mm(distribution_2d, self.embed_adapter.weight).view(seq_len, bsz, -1)
if self.embed_norm: if self.embed_norm:
soft_out = self.embed_ln(soft_out) soft_out = self.embed_ln(soft_out)
...@@ -175,7 +181,7 @@ class Adapter(nn.Module): ...@@ -175,7 +181,7 @@ class Adapter(nn.Module):
elif self.adapter_type == "context": elif self.adapter_type == "context":
out = soft_out out = soft_out
elif self.adapter_type == "league": elif self.adapter_type in ["league", "inter_league_shrink"]:
if self.training and self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob: if self.training and self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob:
if torch.rand(1).uniform_() < 0.5: if torch.rand(1).uniform_() < 0.5:
out = linear_out out = linear_out
...@@ -188,17 +194,17 @@ class Adapter(nn.Module): ...@@ -188,17 +194,17 @@ class Adapter(nn.Module):
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid() coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league": elif self.adapter_type in ["inter_league", "inter_league_shrink"]:
out = representation + soft_out out = representation + soft_out
elif self.adapter_type == "none": elif self.adapter_type == "none":
out = representation out = representation
elif self.adapter_type == "shrink": elif self.adapter_type in ["shrink", "league_shrink", "inter_league_shrink", "context_shrink"]:
if self.training and self.gumbel: if self.adapter_type in ["league_shrink", "inter_league_shrink"]:
distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard) representation = out
else: elif self.adapter_type in ["context_shrink"]:
distribution = F.softmax(logit / self.distribution_temperature, dim=-1) representation = soft_out
lengths = (~padding).long().sum(-1) lengths = (~padding).long().sum(-1)
with torch.no_grad(): with torch.no_grad():
......
...@@ -79,7 +79,7 @@ class TransformerS2EncoderLayer(nn.Module): ...@@ -79,7 +79,7 @@ class TransformerS2EncoderLayer(nn.Module):
if self.use_se: if self.use_se:
self.se_attn = SEAttention(self.embed_dim, 16) self.se_attn = SEAttention(self.embed_dim, 16)
self.use_s2_attn_norm = args.use_s2_attn_norm self.use_s2_attn_norm = getattr(args, "use_s2_attn_norm", True)
if self.use_s2_attn_norm: if self.use_s2_attn_norm:
self.s2_norm = LayerNorm(self.embed_dim) self.s2_norm = LayerNorm(self.embed_dim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论