#!/usr/bin/env python
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Produces the training and dev data for --problem into --data_dir.

generator.py produces sharded and shuffled TFRecord files of tensorflow.Example
protocol buffers for a variety of datasets registered in this file.

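Datasets are registered either in _SUPPORTED_PROBLEM_GENERATORS below or in the
problem registry (tensor2tensor.utils.registry). Each entry of
_SUPPORTED_PROBLEM_GENERATORS maps a string name (selectable on the
command-line with --problem; a trailing "*" selects every problem sharing that
prefix) to a pair of zero-argument callables - the first for the training data,
the second for the dev data - each returning a generator that yields for each
example a dictionary mapping string feature names to lists of
{string, int, float}. The second element of the pair may instead be an integer
number of training shards, in which case only the training generator is run.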
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import tempfile

# Dependency imports

import numpy as np

from tensor2tensor.data_generators import algorithmic
from tensor2tensor.data_generators import algorithmic_math
from tensor2tensor.data_generators import all_problems  # pylint: disable=unused-import
from tensor2tensor.data_generators import audio
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import image
from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import ptb
from tensor2tensor.data_generators import snli
from tensor2tensor.data_generators import wiki
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing
from tensor2tensor.utils import registry

import tensorflow as tf

# Short aliases for the TensorFlow flags module and its flag-value holder.
flags = tf.flags
FLAGS = flags.FLAGS

# NOTE(review): main() also reads --timit_paths, --parsing_path and
# --ende_bpe_path, which are not defined in this file; presumably they are
# defined by the imported data_generators modules - confirm.

# Output directory; main() falls back to the system temp dir when it is empty.
flags.DEFINE_string("data_dir", "", "Data directory.")
# Scratch directory handed to the individual dataset generators.
flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
                    "Temporary storage directory.")
# Either an exact problem name or a prefix followed by "*" (see main()).
flags.DEFINE_string("problem", "",
                    "The name of the problem to generate data for.")
# Number of training shards when train and dev use separate generators.
flags.DEFINE_integer("num_shards", 10, "How many shards to use.")
# Passed to generate_files for the training generator only; the dev generator
# is run without this limit.
flags.DEFINE_integer("max_cases", 0,
                     "Maximum number of cases to generate (unbounded if 0).")
# Applied to TensorFlow, `random` and NumPy before each problem is generated.
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")

# Mapping from problems that we can generate data for to their generators.
#
# Each value is a 2-tuple consumed by generate_data_for_problem():
#   * element 0 is a zero-argument callable returning the training generator;
#   * element 1 is normally a zero-argument callable returning the dev
#     generator, or an int (see "wiki_32k") giving the number of training
#     shards, in which case only the training generator is run.
# The generator calls are wrapped in lambdas so that FLAGS.tmp_dir and
# FLAGS.data_dir are read when the generator is requested rather than when
# this module is loaded.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
    "algorithmic_shift_decimal40": (
        lambda: algorithmic.shift_generator(20, 10, 40, 100000),
        lambda: algorithmic.shift_generator(20, 10, 80, 10000)),
    "algorithmic_reverse_binary40": (
        lambda: algorithmic.reverse_generator(2, 40, 100000),
        lambda: algorithmic.reverse_generator(2, 400, 10000)),
    "algorithmic_reverse_decimal40": (
        lambda: algorithmic.reverse_generator(10, 40, 100000),
        lambda: algorithmic.reverse_generator(10, 400, 10000)),
    "algorithmic_addition_binary40": (
        lambda: algorithmic.addition_generator(2, 40, 100000),
        lambda: algorithmic.addition_generator(2, 400, 10000)),
    "algorithmic_addition_decimal40": (
        lambda: algorithmic.addition_generator(10, 40, 100000),
        lambda: algorithmic.addition_generator(10, 400, 10000)),
    "algorithmic_multiplication_binary40": (
        lambda: algorithmic.multiplication_generator(2, 40, 100000),
        lambda: algorithmic.multiplication_generator(2, 400, 10000)),
    "algorithmic_multiplication_decimal40": (
        lambda: algorithmic.multiplication_generator(10, 40, 100000),
        lambda: algorithmic.multiplication_generator(10, 400, 10000)),
    "algorithmic_reverse_nlplike_decimal8K": (
        lambda: algorithmic.reverse_generator_nlplike(8000, 70, 100000,
                                                      10, 1.300),
        lambda: algorithmic.reverse_generator_nlplike(8000, 70, 10000,
                                                      10, 1.300)),
    "algorithmic_reverse_nlplike_decimal32K": (
        lambda: algorithmic.reverse_generator_nlplike(32000, 70, 100000,
                                                      10, 1.050),
        lambda: algorithmic.reverse_generator_nlplike(32000, 70, 10000,
                                                      10, 1.050)),
    "algorithmic_algebra_inverse": (
        lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
        lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
    "ice_parsing_tokens": (
        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
                                                   True, "ice", 2**13, 2**8),
        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
                                                   False, "ice", 2**13, 2**8)),
    "ice_parsing_characters": (
        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
    "wmt_parsing_tokens_8k": (
        lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
        lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),
    "wsj_parsing_tokens_16k": (
        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
                                                    2**14, 2**9),
        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
                                                    2**14, 2**9)),
    "wmt_enfr_characters": (
        lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True),
        lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)),
    "wmt_enfr_tokens_8k": (
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13),
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13)
    ),
    "wmt_enfr_tokens_32k": (
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
    ),
    "wmt_ende_characters": (
        lambda: wmt.ende_character_generator(FLAGS.tmp_dir, True),
        lambda: wmt.ende_character_generator(FLAGS.tmp_dir, False)),
    "wmt_ende_bpe32k": (
        lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
        lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
    "wmt_zhen_tokens_32k": (
        lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True,
                                                   2**15, 2**15),
        lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, False,
                                                   2**15, 2**15)
    ),
    "lm1b_32k": (
        lambda: lm1b.generator(FLAGS.tmp_dir, True),
        lambda: lm1b.generator(FLAGS.tmp_dir, False)
    ),
    "wiki_32k": (
        lambda: wiki.generator(FLAGS.tmp_dir, True),
        # An int instead of a dev generator: generate_data_for_problem() uses
        # it as the number of training shards and ignores --num_shards.
        1000
    ),
    # NOTE(review): in the "_tune" entries the dev generator is also called
    # with True and one extra positional argument; presumably it takes a
    # held-out slice of the training split - confirm against the generators.
    "image_mnist_tune": (
        lambda: image.mnist_generator(FLAGS.tmp_dir, True, 55000),
        lambda: image.mnist_generator(FLAGS.tmp_dir, True, 5000, 55000)),
    "image_mnist_test": (
        lambda: image.mnist_generator(FLAGS.tmp_dir, True, 60000),
        lambda: image.mnist_generator(FLAGS.tmp_dir, False, 10000)),
    "image_cifar10_tune": (
        lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 48000),
        lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 2000, 48000)),
    "image_cifar10_test": (
        lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
        lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
    # Problems whose name contains "coco" get 10 dev shards instead of 1 (see
    # generate_data_for_problem()).
    "image_mscoco_characters_test": (
        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000),
        lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)),
    "image_mscoco_tokens_8k_test": (
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            80000,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13),
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            False,
            40000,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13)),
    "image_mscoco_tokens_32k_test": (
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            True,
            80000,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15),
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
            False,
            40000,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15)),
    "snli_32k": (
        lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
        lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
    ),
    # main() drops every problem whose name contains "timit" unless
    # --timit_paths is set.
    "audio_timit_characters_tune": (
        lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1374),
        lambda: audio.timit_generator(FLAGS.tmp_dir, True, 344, 1374)),
    "audio_timit_characters_test": (
        lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1718),
        lambda: audio.timit_generator(FLAGS.tmp_dir, False, 626)),
    "audio_timit_tokens_8k_tune": (
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            1374,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13),
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            344,
            1374,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13)),
    "audio_timit_tokens_8k_test": (
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            1718,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13),
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            False,
            626,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13)),
    "audio_timit_tokens_32k_tune": (
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            1374,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15),
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            344,
            1374,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15)),
    "audio_timit_tokens_32k_test": (
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            True,
            1718,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15),
        lambda: audio.timit_generator(
            FLAGS.tmp_dir,
            False,
            626,
            vocab_filename="tokens.vocab.%d" % 2**15,
            vocab_size=2**15)),
    "lmptb_10k": (
        lambda: ptb.train_generator(
            FLAGS.tmp_dir,
            FLAGS.data_dir,
            False),
        # Passed uncalled: generate_data_for_problem() invokes it with no
        # arguments to obtain the dev generator.
        ptb.valid_generator),
}

# pylint: enable=g-long-lambda


def set_random_seed():
  """Seed TensorFlow, Python's `random` and NumPy from --random_seed."""
  seed = FLAGS.random_seed
  # Same order as before: TensorFlow first, then the stdlib, then NumPy.
  for seed_fn in (tf.set_random_seed, random.seed, np.random.seed):
    seed_fn(seed)


def main(_):
  """Resolve --problem into problem names and generate data for each one.

  Candidate names are the keys of _SUPPORTED_PROBLEM_GENERATORS plus the names
  returned by registry.list_problems().  --problem selects either one exact
  name or, when it ends in "*", every name starting with the given prefix.
  Problems whose names contain "timit", "parsing" or "ende_bpe" are dropped
  unless the matching path flag is set.

  Args:
    _: unused positional argument supplied by the app runner.

  Raises:
    ValueError: if no problem is left after selection and filtering.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Every problem name this tool can generate, in sorted order.
  known_problems = sorted(
      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())

  # Narrow the known names down to what --problem asks for.
  requested = FLAGS.problem
  if not requested:
    selected = []
  elif requested[-1] == "*":
    prefix = requested[:-1]
    selected = [name for name in known_problems if name.startswith(prefix)]
  else:
    selected = [name for name in known_problems if name == requested]

  # Problems containing the marker are removed when their path flag is unset.
  path_requirements = (
      ("timit", FLAGS.timit_paths),
      ("parsing", FLAGS.parsing_path),
      ("ende_bpe", FLAGS.ende_bpe_path),
  )
  for marker, path_value in path_requirements:
    if not path_value:
      selected = [name for name in selected if marker not in name]

  if not selected:
    listing = "\n  * ".join(known_problems)
    raise ValueError(
        "You must specify one of the supported problems to "
        "generate data for:\n  * " + listing + "\n" +
        "TIMIT, ende_bpe and parsing need data_sets specified with "
        "--timit_paths, --ende_bpe_path and --parsing_path.")

  if not FLAGS.data_dir:
    # No output directory given: fall back to the system temp directory.
    FLAGS.data_dir = tempfile.gettempdir()
    tf.logging.warning("It is strongly recommended to specify --data_dir. "
                       "Data will be written to default data_dir=%s.",
                       FLAGS.data_dir)

  tf.logging.info("Generating problems:\n  * %s\n" % "\n  * ".join(selected))
  for name in selected:
    # Re-seed before every problem so each one starts from the same state.
    set_random_seed()
    if name in _SUPPORTED_PROBLEM_GENERATORS:
      generate = generate_data_for_problem
    else:
      generate = generate_data_for_registered_problem
    generate(name)


def generate_data_for_problem(problem):
  """Generate and shuffle data for a problem in _SUPPORTED_PROBLEM_GENERATORS.

  Args:
    problem: key into _SUPPORTED_PROBLEM_GENERATORS.  Its value is a pair whose
      first element builds the training generator and whose second element
      either builds the dev generator or is an int number of training shards.
  """
  make_train_gen, dev_spec = _SUPPORTED_PROBLEM_GENERATORS[problem]
  # Files are first written under this unshuffled name and shuffled at the end.
  unshuffled_name = problem + generator_utils.UNSHUFFLED_SUFFIX

  if not isinstance(dev_spec, int):
    # Usual case: separate generators for the train and dev data.
    tf.logging.info("Generating training data for %s.", problem)
    train_files = generator_utils.train_data_filenames(
        unshuffled_name, FLAGS.data_dir, FLAGS.num_shards)
    generator_utils.generate_files(make_train_gen(), train_files,
                                   FLAGS.max_cases)

    tf.logging.info("Generating development data for %s.", problem)
    if "coco" in problem:
      num_dev_shards = 10
    else:
      num_dev_shards = 1
    dev_files = generator_utils.dev_data_filenames(
        unshuffled_name, FLAGS.data_dir, num_dev_shards)
    # The dev generator is not limited by --max_cases.
    generator_utils.generate_files(dev_spec(), dev_files)
    files_to_shuffle = train_files + dev_files
  else:
    # The dev and test sets come out of the training generator as extra
    # shards.  The int is the number of training shards; --num_shards is
    # ignored.
    tf.logging.info("Generating data for %s.", problem)
    files_to_shuffle = generator_utils.combined_data_filenames(
        unshuffled_name, FLAGS.data_dir, dev_spec)
    generator_utils.generate_files(make_train_gen(), files_to_shuffle,
                                   FLAGS.max_cases)

  tf.logging.info("Shuffling data...")
  generator_utils.shuffle_dataset(files_to_shuffle)


def generate_data_for_registered_problem(problem_name):
  """Generate data for a problem that is looked up in the registry.

  Args:
    problem_name: name passed to registry.problem(); the returned object's
      generate_data() is called with --data_dir and --tmp_dir.
  """
  registry.problem(problem_name).generate_data(FLAGS.data_dir, FLAGS.tmp_dir)


# Script entry point: hand control to TensorFlow's app runner, which is
# expected to parse the command-line flags and then call main() from this
# module (see the tf.app.run documentation).
if __name__ == "__main__":
  tf.app.run()
