Commit 6358474e by xuchen

fix the yaml file

parent 537d8744
......@@ -35,6 +35,7 @@ class FeedForwardModule(torch.nn.Module):
dropout2,
activation_fn="relu",
bias=True,
output_feat=None,
):
"""
Args:
......@@ -47,8 +48,10 @@ class FeedForwardModule(torch.nn.Module):
"""
super(FeedForwardModule, self).__init__()
if output_feat is None:
output_feat = input_feat
self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias)
self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias)
self.w_2 = torch.nn.Linear(hidden_units, output_feat, bias=bias)
self.dropout1 = torch.nn.Dropout(dropout1)
self.dropout2 = torch.nn.Dropout(dropout2)
self.activation = get_activation_class(activation_fn)
......@@ -132,13 +135,6 @@ class S2TTransformerS2EncoderLayer(nn.Module):
self.conv_module = None
self.final_norm = None
self.ffn = FeedForwardModule(
embed_dim,
ffn_dim,
dropout,
dropout,
activation
)
self.ffn_norm = LayerNorm(self.embed_dim)
self.s2_norm = LayerNorm(self.embed_dim)
......@@ -160,6 +156,15 @@ class S2TTransformerS2EncoderLayer(nn.Module):
self.league_drop_net_prob = args.encoder_league_drop_net_prob
self.league_drop_net_mix = args.encoder_league_drop_net_mix
self.ffn = FeedForwardModule(
embed_dim,
ffn_dim,
dropout,
dropout,
activation,
output_feat=embed_dim * 2 if self.encoder_collaboration_mode == "concat" else None
)
def get_ratio(self):
if self.league_drop_net:
frand = float(uniform(0, 1))
......@@ -319,7 +324,7 @@ class S2TTransformerS2EncoderLayer(nn.Module):
attn_mask=attn_mask,
)
x = self.dropout_module(x)
if s2 is None or self.encoder_collaboration_mode != "parallel":
if s2 is None or self.encoder_collaboration_mode == "serial":
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
......@@ -338,7 +343,7 @@ class S2TTransformerS2EncoderLayer(nn.Module):
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
elif self.encoder_collaboration_mode == "parallel":
elif self.encoder_collaboration_mode in ["parallel", "concat"]:
x2, _ = self.s2_attn(
query=attn_x,
key=s2,
......
......@@ -65,6 +65,7 @@ class Adapter(nn.Module):
self.adapter_type = adapter_type
self.cal_linear = False
self.cal_context = False
self.shrink = False
if self.adapter_type in ["linear", "league", "gated_league", "gated_league2"]:
self.cal_linear = True
......@@ -75,9 +76,10 @@ class Adapter(nn.Module):
LayerNorm(dim),
)
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league"]:
if self.adapter_type in ["context", "league", "gated_league", "gated_league2", "inter_league",
"league_shrink", "inter_league_shrink", "context_shrink"]:
self.cal_context = True
self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization
self.embed_adapter = nn.Linear(dim, dictionary_size, bias=False) # reverse for initialization
nn.init.normal_(self.embed_adapter.weight, mean=0, std=dim ** -0.5)
self.embed_norm = strategy.get("embed_norm", False)
if self.embed_norm:
......@@ -92,10 +94,11 @@ class Adapter(nn.Module):
self.gate_linear2 = nn.Linear(dim, dim)
# additional strategy
if self.adapter_type == "shrink":
if self.adapter_type in ["shrink", "league_shrink", "inter_league_shrink", "context_shrink"]:
assert strategy is not None
ctc_compress_strategy = strategy.get("ctc_compress_strategy", "avg")
self.ctc_compress = getattr(CTCCompressStrategy, ctc_compress_strategy)
self.shrink = True
logger.info("CTC Compress Strategy: %s" % ctc_compress_strategy)
if self.cal_context:
......@@ -125,46 +128,49 @@ class Adapter(nn.Module):
representation, logit = x
seq_len, bsz, dim = representation.size()
linear_out = None
soft_out = None
if self.cal_linear:
linear_out = self.linear_adapter(representation)
if self.cal_context:
distribution = None
if self.cal_context or self.shrink:
if self.training and self.gumbel:
distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard)
else:
distribution = F.softmax(logit / self.distribution_temperature, dim=-1)
linear_out = None
soft_out = None
out = None
if self.cal_linear:
linear_out = self.linear_adapter(representation)
if self.cal_context:
vocab_size = distribution.size(-1)
distribution = distribution.contiguous().view(-1, vocab_size)
org_distribution = distribution
distribution_2d = distribution.contiguous().view(-1, vocab_size)
if self.distribution_cutoff is not None:
cutoff = min(int(self.distribution_cutoff), vocab_size - 1)
pass
# cutoff = min(int(self.distribution_cutoff), vocab_size - 1)
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
# distribution = torch.where(
# org_distribution > threshold, org_distribution, torch.zeros_like(org_distribution)
# threshold = distribution.sort(dim=-1, descending=True)[0][:, :, cutoff:cutoff+1]
# distribution_2d = torch.where(
# distribution > threshold, distribution, torch.zeros_like(distribution)
# )
# threshold = org_distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
# distribution = torch.where(
# threshold > 0.9, org_distribution, torch.zeros_like(org_distribution)
# threshold = distribution.sort(dim=-1, descending=True)[0][:, :, :cutoff].sum(-1, keepdim=True)
# distribution_2d = torch.where(
# threshold > 0.9, distribution, torch.zeros_like(distribution)
# )
# distribution = distribution.view(-1, vocab_size)
# distribution_2d = distribution_2d.view(-1, vocab_size)
distribution[:, 0] = 0
distribution = distribution / distribution.sum(-1, keepdim=True)
# distribution_2d[:, 0] = 0
# distribution_2d = distribution_2d / distribution_2d.sum(-1, keepdim=True)
if self.ground_truth_ratio > 0 and oracle is not None:
oracle = oracle.unsqueeze(-1)
oracle_one_hot = (oracle == torch.arange(vocab_size, device=oracle.device).unsqueeze(0)).\
oracle_one_hot = (oracle == torch.arange(vocab_size, device=oracle.device).unsqueeze(0)). \
to(distribution.dtype).transpose(0, 1)
oracle_mask = oracle_mask.transpose(0, 1).unsqueeze(-1).repeat(1, 1, vocab_size)
modify_dist = oracle_mask * oracle_one_hot + ~oracle_mask * org_distribution
modify_dist = oracle_mask * oracle_one_hot + ~oracle_mask * distribution
soft_out = torch.mm(modify_dist.view(-1, vocab_size), self.embed_adapter.weight).view(seq_len, bsz, -1)
else:
soft_out = torch.mm(distribution, self.embed_adapter.weight).view(seq_len, bsz, -1)
soft_out = torch.mm(distribution_2d, self.embed_adapter.weight).view(seq_len, bsz, -1)
if self.embed_norm:
soft_out = self.embed_ln(soft_out)
......@@ -175,7 +181,7 @@ class Adapter(nn.Module):
elif self.adapter_type == "context":
out = soft_out
elif self.adapter_type == "league":
elif self.adapter_type in ["league", "inter_league_shrink"]:
if self.training and self.drop_prob > 0 and torch.rand(1).uniform_() < self.drop_prob:
if torch.rand(1).uniform_() < 0.5:
out = linear_out
......@@ -188,17 +194,17 @@ class Adapter(nn.Module):
coef = (self.gate_linear(torch.cat([linear_out, soft_out], dim=-1))).sigmoid()
out = coef * linear_out + (1 - coef) * soft_out
elif self.adapter_type == "inter_league":
elif self.adapter_type in ["inter_league", "inter_league_shrink"]:
out = representation + soft_out
elif self.adapter_type == "none":
out = representation
elif self.adapter_type == "shrink":
if self.training and self.gumbel:
distribution = F.gumbel_softmax(logit, tau=self.distribution_temperature, hard=self.distribution_hard)
else:
distribution = F.softmax(logit / self.distribution_temperature, dim=-1)
elif self.adapter_type in ["shrink", "league_shrink", "inter_league_shrink", "context_shrink"]:
if self.adapter_type in ["league_shrink", "inter_league_shrink"]:
representation = out
elif self.adapter_type in ["context_shrink"]:
representation = soft_out
lengths = (~padding).long().sum(-1)
with torch.no_grad():
......
......@@ -79,7 +79,7 @@ class TransformerS2EncoderLayer(nn.Module):
if self.use_se:
self.se_attn = SEAttention(self.embed_dim, 16)
self.use_s2_attn_norm = args.use_s2_attn_norm
self.use_s2_attn_norm = getattr(args, "use_s2_attn_norm", True)
if self.use_s2_attn_norm:
self.s2_norm = LayerNorm(self.embed_dim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论