import numpy
import argparse
def write_vocab():
  """Merge the source and target vocabularies into ``online.vocab``.

  The input paths are taken from the module-level names ``src_vocab_name``
  and ``tgt_vocab_name``, which must be set before this function is called.

  The output file, written to the current working directory, contains one
  ``<index> <token>`` line per input line, with indices starting at 0 for
  each vocabulary. The source section comes first, then a separator line
  made of ``=`` characters, then the target section.

  Both the inputs and the output are handled as UTF-8.
  """
  print('write vocab file ...')

  def _write_to_file(vocab, online_vocab):
    """Append the entries of the vocabulary at path `vocab` to `online_vocab`."""
    with open(vocab, 'r', encoding='utf-8') as f:
      for index, line in enumerate(f):
        if index == 0:
          # The first input line is always emitted as the padding symbol.
          token = u'<PAD>'
        elif index == 1:
          # The second input line is always emitted as end-of-sentence.
          token = u'<EOS>'
        else:
          # Drop surrounding whitespace, then the first and last character
          # (presumably quotes around the token -- confirm against the
          # vocabulary file format).
          token = line.strip()[1:-1]
        print(u'{} {}'.format(index, token), file=online_vocab)

  # The context manager closes the output even if reading an input fails;
  # the explicit encoding keeps non-ASCII tokens independent of the locale.
  with open('online.vocab', 'w', encoding='utf-8') as vocab_file:
    _write_to_file(src_vocab_name, vocab_file)
    print(u"================================", file=vocab_file)
    _write_to_file(tgt_vocab_name, vocab_file)


def convert_dict(vocab_input, vocab_output):
  """Split a merged vocabulary file into source and target dictionaries.

  Args:
    vocab_input: path of a UTF-8 file holding ``<index> <token>`` lines,
      with exactly one separator line starting with ``=====`` between the
      source section and the target section.
    vocab_output: path prefix; source tokens are written to
      ``vocab_output + '.src'`` and target tokens to
      ``vocab_output + '.tgt'``.

  Each kept token is written as ``<token> 1``. The tokens ``<PAD>``,
  ``<UNK>``, ``<EOS>`` and ``<LUA>`` are skipped.

  Raises:
    AssertionError: if a second separator line is found, or if a line does
      not consist of exactly two space-separated fields.
  """
  specials = {'<PAD>', '<UNK>', '<EOS>', '<LUA>'}
  # Explicit UTF-8 so non-ASCII tokens do not depend on the locale encoding.
  with open(vocab_input, 'r', encoding='utf-8') as infile, \
    open(vocab_output + '.src', 'w', encoding='utf-8') as src_ofile, \
    open(vocab_output + '.tgt', 'w', encoding='utf-8') as tgt_ofile:
    is_src = True
    for line in infile:
      # explicitly strip '\n', by default, strip() will delete more characters, which may be the token
      line = line.strip('\n')
      if line.startswith('====='):
        # Only one separator is allowed: it switches from source to target.
        assert is_src, 'unexpected second separator line in %s' % vocab_input
        is_src = False
        continue

      # explicitly split by ' ', or it may be wrong to split due to some special symbols
      items = line.split(' ')
      assert len(items) == 2, 'malformed vocabulary line: %r' % line
      _, token = items
      if token in specials:
        continue
      print('%s %d' % (token, 1), file=src_ofile if is_src else tgt_ofile)

if __name__ == '__main__':
  # Script entry point: parse the three required paths, build the merged
  # 'online.vocab' file, then split it into the output dictionaries.
  cli = argparse.ArgumentParser()
  for flag, description in (
      ('-src_vocab', "source vocabulary path"),
      ('-tgt_vocab', "target vocabulary path"),
      ('-vocab_output', 'fairseq dict path'),
  ):
    cli.add_argument(flag, required=True, help=description)
  args = cli.parse_args()

  # write_vocab() reads these two module-level names, so they must be
  # bound before it is called.
  src_vocab_name, tgt_vocab_name = args.src_vocab, args.tgt_vocab

  write_vocab()
  convert_dict('online.vocab', args.vocab_output)