Commit db29e01d by xuchen

modify the implementation of lowercase and removing the punctuations

parent 8b44456c
......@@ -9,5 +9,6 @@ with open(in_file, "r", encoding="utf-8") as f:
line = line.strip().lower()
for w in string.punctuation:
line = line.replace(w, "")
line = line.replace(" ", "")
print(line)
......@@ -134,6 +134,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
mkdir -p ${data_dir}/data
for split in ${train_subset} ${valid_subset} ${test_subset}; do
{
cmd="cat ${org_data_dir}/${lang}/data/${split}.${src_lang}"
if [[ ${lc_rm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${org_data_dir}/${lang}/data/${split}.${src_lang}"
......@@ -154,7 +155,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo -e "\033[34mRun command: \n${cmd} \033[0m"
[[ $eval -eq 1 ]] && eval ${cmd}
}&
done
wait
cmd="python ${root_dir}/fairseq_cli/preprocess.py
--source-lang ${src_lang} --target-lang ${tgt_lang}
......
import sys
import string
in_file = sys.argv[1]
with open(in_file, "r", encoding="utf-8") as f:
for line in f.readlines():
line = line.strip().lower()
for w in string.punctuation:
line = line.replace(w, "")
line = line.replace(" ", "")
print(line)
......@@ -37,6 +37,7 @@ task=translation
vocab_type=unigram
vocab_size=32000
share_dict=1
lc_rm=1
use_specific_dict=0
specific_prefix=st_share10k
......@@ -62,7 +63,7 @@ train_config=train.yaml
# training setting
fp16=1
max_tokens=4096
step_valid=0
step_valid=1
bleu_valid=0
# decoding setting
......@@ -116,6 +117,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--splits ${train_subset},${valid_subset},${test_subset}
--src-lang ${src_lang}
--tgt-lang ${tgt_lang}
--lowercase-src
--rm-punc-src
--vocab-type ${vocab_type}
--vocab-size ${vocab_size}"
if [[ $share_dict -eq 1 ]]; then
......@@ -133,10 +136,13 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
mkdir -p ${data_dir}/data
for split in ${train_subset} ${valid_subset} ${test_subset}; do
{
cmd="spm_encode
--model ${data_dir}/${src_vocab_prefix}.model
cmd="cat ${org_data_dir}/${lang}/data/${split}.${src_lang}"
if [[ ${lc_rm} -eq 1 ]]; then
cmd="python local/lower_rm.py ${org_data_dir}/${lang}/data/${split}.${src_lang}"
fi
cmd="${cmd}
| spm_encode --model ${data_dir}/${src_vocab_prefix}.model
--output_format=piece
< ${org_data_dir}/${lang}/data/${split}.${src_lang}
> ${data_dir}/data/${split}.${src_lang}"
echo -e "\033[34mRun command: \n${cmd} \033[0m"
......
......@@ -10,11 +10,10 @@ import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Tuple
import string
import pandas as pd
from examples.speech_to_text.data_utils import (
gen_vocab,
save_df_to_tsv,
)
from torch.utils.data import Dataset
from tqdm import tqdm
......@@ -75,6 +74,13 @@ def process(args):
dataset = MTData(args.data_root, src_lang, tgt_lang, split)
for src_text, tgt_text in tqdm(dataset):
if args.lowercase_src:
src_text = src_text.lower()
if args.rm_punc_src:
for w in string.punctuation:
src_text = src_text.replace(w, "")
src_text = src_text.replace(" ", "")
manifest["src_text"].append(src_text)
manifest["tgt_text"].append(tgt_text)
......@@ -154,6 +160,8 @@ def main():
parser.add_argument("--vocab-size", default=10000, type=int)
parser.add_argument("--size", default=-1, type=int)
parser.add_argument("--splits", default="train,dev,test", type=str)
parser.add_argument("--lowercase-src", action="store_true", help="lowercase the source text")
parser.add_argument("--rm-punc-src", action="store_true", help="remove the punctuation of the source text")
parser.add_argument("--src-lang", required=True, type=str)
parser.add_argument("--tgt-lang", required=True, type=str)
parser.add_argument("--share", action="store_true", help="share the source and target vocabulary")
......
......@@ -246,6 +246,7 @@ def process(args):
if args.rm_punc_src:
for w in string.punctuation:
src_utt = src_utt.replace(w, "")
src_utt = src_utt.replace(" ", "")
manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt)
if args.task == "st" and args.add_src:
manifest["src_text"].append(src_utt)
......@@ -285,7 +286,9 @@ def process(args):
if args.lowercase_src:
src_utt = src_utt.lower()
if args.rm_punc_src:
src_utt = src_utt.translate(None, string.punctuation)
for w in string.punctuation:
src_utt = src_utt.replace(w, "")
src_utt = src_utt.replace(" ", "")
train_text.append(src_utt)
train_text.append(tgt_utt)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论