Commit b97a0cfd by libei

Add a vocab conversion script that converts a tensor2tensor vocab to a fairseq vocab

parent 35cd255b
import numpy
import argparse
def write_vocab(src_vocab=None, tgt_vocab=None, output='online.vocab'):
    """Merge the source and target vocab files into one indexed vocab file.

    Both inputs are read as UTF-8, one entry per line. Every entry is written
    to ``output`` as an ``<index> <token>`` line (index starting at 0 for each
    input): source entries first, then a separator line made of ``=``
    characters, then the target entries. The output is written as UTF-8 so
    that non-ASCII tokens do not depend on the locale's default encoding.

    Args:
        src_vocab: Path of the source vocabulary file. When ``None``, the
            module-level ``src_vocab_name`` is used.
        tgt_vocab: Path of the target vocabulary file. When ``None``, the
            module-level ``tgt_vocab_name`` is used.
        output: Path of the merged file to create.
    """
    # Fall back to the module-level names so that a bare ``write_vocab()``
    # call keeps working exactly as before.
    if src_vocab is None:
        src_vocab = src_vocab_name
    if tgt_vocab is None:
        tgt_vocab = tgt_vocab_name

    print('write vocab file ...')

    # The first two entries of each input are replaced by these fixed names,
    # whatever the input lines contain.
    reserved = {0: u'<PAD>', 1: u'<EOS>'}

    def _write_to_file(vocab, online_vocab):
        """Append the indexed entries of ``vocab`` to the open ``online_vocab``."""
        with open(vocab, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f):
                token = reserved.get(index)
                if token is None:
                    # Drop the first and last character of the stripped line;
                    # presumably the quotes wrapping each token in the input
                    # vocab -- confirm against the producer of these files.
                    token = line.strip()[1:-1]
                print(u'{} {}'.format(index, token), file=online_vocab)

    # ``with`` guarantees the output handle is closed even if reading one of
    # the inputs fails part-way through.
    with open(output, 'w', encoding='utf-8') as vocab_file:
        _write_to_file(src_vocab, vocab_file)
        print(u"================================", file=vocab_file)
        _write_to_file(tgt_vocab, vocab_file)
def convert_dict(vocab_input, vocab_output):
    """Split a merged vocab file into ``.src`` and ``.tgt`` dictionary files.

    ``vocab_input`` holds ``<index> <token>`` lines; a line starting with
    ``=====`` switches from the source section to the target section. Each
    token that is not one of the special symbols is written as ``<token> 1``
    to ``vocab_output + '.src'`` or ``vocab_output + '.tgt'``. All files are
    opened as UTF-8 so the result does not depend on the locale encoding.

    Args:
        vocab_input: Path of the merged vocab file to read.
        vocab_output: Path prefix of the two dictionary files to create.

    Raises:
        AssertionError: If a second separator line is found, or a line does
            not consist of exactly two space-separated fields.
    """
    # Special symbols are skipped rather than written to the dictionaries.
    specials = {'<PAD>', '<UNK>', '<EOS>', '<LUA>'}
    with open(vocab_input, 'r', encoding='utf-8') as infile, \
            open(vocab_output + '.src', 'w', encoding='utf-8') as src_ofile, \
            open(vocab_output + '.tgt', 'w', encoding='utf-8') as tgt_ofile:
        is_src = True
        for line_no, line in enumerate(infile, 1):
            # Strip only '\n': a bare strip() would also remove other
            # whitespace characters, which may be part of the token.
            line = line.strip('\n')
            if line.startswith('====='):
                assert is_src, \
                    'line %d: more than one separator line' % line_no
                is_src = False
                continue
            # Split on a single space explicitly: a bare split() could split
            # inside tokens that contain other special whitespace symbols.
            items = line.split(' ')
            assert len(items) == 2, \
                'line %d: expected "<index> <token>", got %r' % (line_no, line)
            token = items[1]
            if token in specials:
                continue
            # Every token is written with a constant count of 1.
            print('%s %d' % (token, 1),
                  file=src_ofile if is_src else tgt_ofile)
if __name__ == '__main__':
    # Script entry point: read the three required paths from the command
    # line, build the intermediate 'online.vocab' file, then split it into
    # the output dictionaries.
    parser = argparse.ArgumentParser()
    for flag, description in (
            ('-src_vocab', "source vocabulary path"),
            ('-tgt_vocab', "target vocabulary path"),
            ('-vocab_output', 'fairseq dict path'),
    ):
        parser.add_argument(flag, required=True, help=description)
    args = parser.parse_args()

    # write_vocab() reads these two module-level names.
    src_vocab_name, tgt_vocab_name = args.src_vocab, args.tgt_vocab

    write_vocab()
    convert_dict('online.vocab', args.vocab_output)
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论