dedup_data.py 3.06 KB
Newer Older
xuchen committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
import argparse
from collections import namedtuple
import os

# Root of the training data. main() treats every entry here whose name
# contains "_" as a "<src>_<tgt>" language pair, and dedup() reads
# "<src>_<tgt>/train.<src>" and "<src>_<tgt>/train.<tgt>" from it.
# NOTE(review): the three values below look like placeholders to be edited
# before running -- confirm.
DATADIR = "/path/to/train_data"
# Directory whose files supply the lines to remove from the training data;
# existing_data() loads every line of every file in it.
DEDUP_FROM_DIR = "/path/to/eval/data"
# Directory where dedup() writes "train.dedup.<src>-<tgt>.<lang>" files.
OUTPUT_DIR = "/path/to/output/data"


def main(args):
    """Deduplicate one slice of the language pairs found under ``DATADIR``.

    Every entry of ``DATADIR`` whose name contains an underscore is read as
    ``<src>_<tgt>``. The pairs are sorted, and the ``args.size`` pairs
    starting at position ``args.start_index`` are passed to ``dedup``
    together with the lines loaded by ``existing_data``.

    Args:
        args: parsed command-line namespace providing the integer
            attributes ``start_index`` and ``size``.
    """
    pairs = set()
    for entry in os.listdir(DATADIR):
        if "_" not in entry:
            continue
        # Unpacking into exactly two names raises ValueError when the entry
        # name holds more than one underscore.
        source_lang, target_lang = entry.split("_")
        pairs.add(LanguagePair(src=source_lang, tgt=target_lang))

    held_out_lines = existing_data()

    # Sorting gives every invocation the same ordering, so the
    # (start_index, size) window selects the same pairs on every run.
    first = args.start_index
    last = first + args.size
    for pair in sorted(pairs)[first:last]:
        print(pair)
        dedup(pair, held_out_lines)


# Two-field record with a source-side and a target-side slot. The script uses
# it both for a pair of language codes and for a pair of per-language file
# paths.
LanguagePair = namedtuple("LanguagePair", ("src", "tgt"))


def existing_data():
    """Load every line of every regular file in ``DEDUP_FROM_DIR``.

    Entries that are not regular files (for example sub-directories) are
    skipped instead of making ``open`` fail.

    Returns:
        set: the distinct lines found, each one newline-terminated. A final
        line that has no trailing newline in its file gets one appended, so
        it compares equal to the same sentence read from the middle of
        another file.
    """
    data = set()
    for name in os.listdir(DEDUP_FROM_DIR):
        path = os.path.join(DEDUP_FROM_DIR, name)
        if not os.path.isfile(path):
            continue
        with open(path) as f:
            for line in f:
                # Only the last line of a file can lack its newline; without
                # this it would be stored as a different string than the
                # identical sentence followed by a newline.
                data.add(line if line.endswith("\n") else line + "\n")
    return data
 
def _is_held_out(line, data):
    """Return True if ``line`` is in ``data``, ignoring a final-newline mismatch."""
    if line in data:
        return True
    # The last line of a file may lack its newline, on either side of the
    # comparison; retry with the newline toggled so that sentence still matches.
    alternate = line[:-1] if line.endswith("\n") else line + "\n"
    return alternate in data


def dedup(language_pair, data, verbose=True, output=True):
    """Write a copy of one pair's training data without the lines in ``data``.

    Reads ``{DATADIR}/<src>_<tgt>/train.<src>`` and ``.../train.<tgt>`` and
    drops every line pair whose source line or target line occurs in
    ``data``. The remaining pairs are written, in their original order, to
    ``{OUTPUT_DIR}/train.dedup.<src>-<tgt>.<src>`` and ``...<tgt>``.

    Nothing is read or written when both output files already exist, when
    ``output`` is false, or when either training file is missing.

    Args:
        language_pair: object with ``src`` and ``tgt`` attributes naming the
            two languages.
        data: set of lines to exclude.
        verbose: when true, print a progress line for this pair.
        output: when false, stop after the status message.

    Raises:
        ValueError: if the two training files differ in line count, in which
            case no output is written.
    """
    pair_dir = f"{DATADIR}/{language_pair.src}_{language_pair.tgt}"
    train_filenames = LanguagePair(
        src=f"{pair_dir}/train.{language_pair.src}",
        tgt=f"{pair_dir}/train.{language_pair.tgt}",
    )

    output_prefix = f"{OUTPUT_DIR}/train.dedup.{language_pair.src}-{language_pair.tgt}"
    output_filenames = LanguagePair(
        src=f"{output_prefix}.{language_pair.src}",
        tgt=f"{output_prefix}.{language_pair.tgt}",
    )

    # If output exists, skip this pair. It has already been done.
    if (os.path.exists(output_filenames.src) and
            os.path.exists(output_filenames.tgt)):
        if verbose:
            print(f"{language_pair.src}-{language_pair.tgt} already done.")
        return

    if verbose:
        print(f"{language_pair.src}-{language_pair.tgt} ready, will check dups.")

    # If there is no output, no need to actually do the loop.
    if not output:
        return

    if not (os.path.exists(train_filenames.src)
            and os.path.exists(train_filenames.tgt)):
        if verbose:
            print(f"{language_pair.src}-{language_pair.tgt} training files missing, skipped.")
        return

    with open(train_filenames.src) as f:
        train_source = f.readlines()

    with open(train_filenames.tgt) as f:
        train_target = f.readlines()

    # Check alignment before filtering: pairing lines of files with different
    # lengths would silently produce misaligned output.
    if len(train_source) != len(train_target):
        raise ValueError(
            f"{train_filenames.src} has {len(train_source)} lines but "
            f"{train_filenames.tgt} has {len(train_target)}"
        )

    new_train_source = []
    new_train_target = []
    for source_line, target_line in zip(train_source, train_target):
        # A pair is dropped as soon as either side is held out.
        if _is_held_out(source_line, data) or _is_held_out(target_line, data):
            continue
        new_train_source.append(source_line)
        new_train_target.append(target_line)

    with open(output_filenames.src, "w") as o:
        o.writelines(new_train_source)

    with open(output_filenames.tgt, "w") as o:
        o.writelines(new_train_target)


if __name__ == '__main__':
    # Both options are mandatory integers; main() uses them to pick the
    # window [start_index, start_index + size) of the sorted language pairs.
    cli = argparse.ArgumentParser()
    cli.add_argument("-s", "--start-index", type=int, required=True)
    cli.add_argument("-n", "--size", type=int, required=True)
    parsed_args = cli.parse_args()
    main(parsed_args)