Commit d4a68f26 by xuchen

support the multiple config files

parent 41ef5b4e
encoder-attention-type: relative
decoder-attention-type: relative
max-encoder-relative-length: 100
max-decoder-relative-length: 20
\ No newline at end of file
......@@ -242,10 +242,10 @@ class RelativeMultiheadAttention(MultiheadAttention):
)
if self.k_only:
relation_keys = F.embedding(relative_positions_matrix.long().cuda(), self.relative_position_keys)
relation_keys = F.embedding(relative_positions_matrix.long().to(k.device), self.relative_position_keys)
else:
relation_keys = F.embedding(relative_positions_matrix.long().cuda(), self.relative_position_keys)
relation_values = F.embedding(relative_positions_matrix.long().cuda(), self.relative_position_values)
relation_keys = F.embedding(relative_positions_matrix.long().to(k.device), self.relative_position_keys)
relation_values = F.embedding(relative_positions_matrix.long().to(k.device), self.relative_position_values)
attn_weights = self._relative_attention_inner(q, k, relation_keys, transpose=True)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
......
......@@ -96,6 +96,16 @@ def parse_args_and_arch(
is_config_file=True,
help="Configuration YAML filename (for training)",
)
parser.add_argument(
"--train-config1",
is_config_file=True,
help="Configuration YAML filename (for training)",
)
parser.add_argument(
"--train-config2",
is_config_file=True,
help="Configuration YAML filename (for training)",
)
if suppress_defaults:
# Parse args without any default values. This requires us to parse
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论