# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


def gen_forward():
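    """Generate lightconv_cuda_forward.cu.

    The generated file dispatches on sequence length (one thread-block size
    per 32-element bucket up to 512, plus a fallback branch), then on filter
    size and padding, and instantiates the templated lightconv_forward_kernel
    for each supported combination.
    """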

    # Supported filter sizes and sequence-length buckets (multiples of 32 up
    # to 512); each bucket value is also used as the CUDA thread-block size.
    kernels = [3, 5, 7, 15, 31, 63, 127, 255]
    seqs = [32 * x for x in range(1, 17)]

    head = """
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "lightconv_cuda.cuh"

std::vector<at::Tensor> lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l) {

    at::DeviceGuard g(input.device());
    const auto minibatch = input.size(0);
    const auto numFeatures = input.size(1);
    const auto sequenceLength = input.size(2);

    const auto numHeads = filters.size(0);
    const auto filterSize = filters.size(1);

    const auto numFiltersInBlock = numFeatures / numHeads;

    const dim3 blocks(minibatch, numFeatures);

    auto output = at::zeros_like(input);
    auto stream = at::cuda::getCurrentCUDAStream();
"""

    sequence_if = """
    if (sequenceLength <= {seq}) {{
        switch(filterSize) {{
"""

    case_k = """
            case {k}:
"""

    main_block = """
                if (padding_l == {pad}) {{
                    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_forward", ([&] {{
                        lightconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t>
                        <<<blocks, {b_size}, 0, stream>>>(
                                input.data<scalar_t>(),
                                filters.data<scalar_t>(),
                                minibatch,
                                sequenceLength,
                                numFeatures,
                                numFiltersInBlock,
                                output.data<scalar_t>());
                    }}));
                }} else
"""

    bad_padding = """
                {
                    std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl;
                }
                break;
"""

    bad_filter = """
            default:
                std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl;
        }
"""

    con_else = """
    } else
"""

    final_else = """
    {
        switch(filterSize) {
"""

    final_return = """
    }

    return {output};
}
"""

    with open("lightconv_cuda_forward.cu", "w") as forward:
        forward.write(head)
        for seq in seqs:
            forward.write(sequence_if.format(seq=seq))
            for k in kernels:
                forward.write(case_k.format(k=k))
                for pad in [k // 2, k - 1]:
                    forward.write(main_block.format(k=k, b_size=seq, pad=pad))
                forward.write(bad_padding)
            forward.write(bad_filter)
            forward.write(con_else)

        forward.write(final_else)
        # Fallback branch for sequenceLength > seqs[-1]: launch with the
        # largest supported block size (512 threads).
        for k in kernels:
            forward.write(case_k.format(k=k))
            for pad in [k // 2, k - 1]:
                forward.write(main_block.format(k=k, b_size=seqs[-1], pad=pad))
            forward.write(bad_padding)
        forward.write(bad_filter)
        forward.write(final_return)


def gen_backward():
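    """Generate lightconv_cuda_backward.cu.

    For each filter size, the generated file dispatches on sequence length
    and padding and instantiates the gradient kernels with respect to the
    input and the filter weights. Short sequences use the per-head
    "firstpass_short" weight-gradient kernel with a block size matching the
    sequence bucket; longer sequences fall back to the per-feature first-pass
    kernel launched with a 32-thread block.
    """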

    head = """
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "lightconv_cuda.cuh"

std::vector<at::Tensor> lightconv_cuda_backward(
        at::Tensor gradOutput,
        int padding_l,
        at::Tensor input,
        at::Tensor filters) {

    // gradWrtInput
    const int minibatch = input.size(0);
    const int numFeatures = input.size(1);
    const int sequenceLength = input.size(2);

    const int numHeads = filters.size(0);
    const int filterSize = filters.size(1);

    const dim3 gradBlocks(minibatch, numFeatures);
    const dim3 weightGradFirstpassShortBlocks(minibatch, numHeads);
    const dim3 weightGradSecondpassBlocks(numHeads, filterSize);

    const int numFiltersInBlock = numFeatures / numHeads;

    auto gradInput = at::zeros_like(input);
    auto gradFilters = at::zeros_like(filters);

    at::DeviceGuard g(input.device());
    auto stream = at::cuda::getCurrentCUDAStream();

    switch(filterSize) {
"""

    sequence_if = """
            if (sequenceLength <= {seq}) {{
"""

    case_k = """
        case {k}:
"""

    main_block = """
                if (padding_l == {p}) {{
                    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_backward", ([&] {{
                        lightconv_grad_wrt_input_kernel<{k}, {b_size}, {p}, scalar_t>
                        <<<gradBlocks, {b_size}, 0, stream>>>(
                                gradOutput.data<scalar_t>(),
                                filters.data<scalar_t>(),
                                minibatch,
                                sequenceLength,
                                numFeatures,
                                numFiltersInBlock,
                                gradInput.data<scalar_t>());

"""

    weight_grad_short = """
                        at::Tensor tempSumGradFilters = at::zeros({{minibatch, numHeads, filterSize}}, input.options().dtype(at::kFloat));
                        lightconv_grad_wrt_weights_firstpass_short_kernel<{k}, {b_size}, {p}, scalar_t>
                        <<<weightGradFirstpassShortBlocks, {b_size}, 0, stream>>>(
                                input.data<scalar_t>(),
                                gradOutput.data<scalar_t>(),
                                minibatch,
                                sequenceLength,
                                numFeatures,
                                numFiltersInBlock,
                                numHeads,
                                tempSumGradFilters.data<float>()
                        );

                        lightconv_grad_wrt_weights_secondpass_short_kernel<{k}, {b_size}, scalar_t>
                        <<<weightGradSecondpassBlocks, {b_size}, 0, stream>>>(
                                tempSumGradFilters.data<float>(),
                                minibatch,
                                numFiltersInBlock,
                                gradFilters.data<scalar_t>()
                        );
                    }}));
                }} else
"""

    weight_grad = """
                        at::Tensor tempSumGradFilters = at::zeros({{minibatch, numFeatures, filterSize}}, input.options().dtype(at::kFloat));
                        lightconv_grad_wrt_weights_firstpass_kernel<{k}, {b_size}, {p}, scalar_t>
                        <<<gradBlocks, {b_size}, 0, stream>>>(
                                input.data<scalar_t>(),
                                gradOutput.data<scalar_t>(),
                                minibatch,
                                sequenceLength,
                                numFeatures,
                                numFiltersInBlock,
                                tempSumGradFilters.data<float>()
                        );

                        lightconv_grad_wrt_weights_secondpass_kernel<{k}, {b_size}, scalar_t>
                        <<<weightGradSecondpassBlocks, {b_size}, 0, stream>>>(
                                tempSumGradFilters.data<float>(),
                                minibatch,
                                numFiltersInBlock,
                                gradFilters.data<scalar_t>()
                        );
                    }}));
                }} else
"""

    bad_padding = """
                {
                    std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl;
                }
"""

    breakout = """
                break;
"""

    bad_filter = """
        default:
            std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl;
"""

    con_else = """
            } else
"""

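    # Unused in this generator: the backward dispatch emits its fallback
    # inside each filter-size case rather than through a trailing else branch.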
    final_else = """
    {
        switch(filterSize) {
"""

    last_return = """
    }
    return {gradInput, gradFilters};
}
"""

    kernels = [3, 5, 7, 15, 31, 63, 127, 255]
    seqs = [32 * x for x in range(1, 17)]
    # Per filter size: the "short" per-head weight-gradient kernel is emitted
    # while seq <= thresh (-1 means no threshold) and seq < max_mem (-1 means
    # no cap); beyond that, a single fallback using the per-feature first-pass
    # kernel with a 32-thread block covers all longer sequences.
    thresh = [32, 32, 64, 128, 256, -1, -1, -1]
    max_mem = [-1, -1, -1, -1, -1, 192, 96, 64]

    with open("lightconv_cuda_backward.cu", "w") as backward:
        backward.write(head)
        for (k, t, mem) in zip(kernels, thresh, max_mem):
            backward.write(case_k.format(k=k))
            for seq in seqs:
                if (t == -1 or seq <= t) and (mem == -1 or seq < mem):
                    backward.write(sequence_if.format(seq=seq))
                    for p in [k // 2, k - 1]:
                        backward.write(main_block.format(k=k, b_size=seq, p=p))
                        backward.write(weight_grad_short.format(k=k, b_size=seq, p=p))
                    backward.write(bad_padding)
                else:
                    for p in [k // 2, k - 1]:
                        backward.write(main_block.format(k=k, b_size=32, p=p))
                        backward.write(weight_grad.format(k=k, b_size=32, p=p))
                    backward.write(bad_padding)
                    backward.write(breakout)
                    break
                backward.write(con_else)
        backward.write(bad_filter)
        backward.write(last_return)
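
# Running this module writes lightconv_cuda_forward.cu and
# lightconv_cuda_backward.cu into the current working directory. Both files
# include lightconv_cuda.cuh and are meant to be compiled as part of the
# lightconv_cuda extension; the exact build setup (e.g. a
# torch.utils.cpp_extension.CUDAExtension) lives outside this script.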


if __name__ == "__main__":
    gen_forward()
    gen_backward()