/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 * 
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */


/*
  Ceiling division: the quotient of a by b, rounded up.

  Usable on both host and device and in constant expressions. The result
  has the type of (a + b).

  a - dividend
  b - divisor
*/
template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
  // Adding (b - 1) before dividing pushes any non-zero remainder up to
  // the next multiple of b.
  return ((a + b) - 1) / b;
}


template<int FS, int SB, int padding_l, typename scalar_t>
__inline__ __device__
void zeroSharedMem(scalar_t* data) {
  /*
    Given an array of length FS + SB, zero out the first padding_l and last
    (FS - padding_l) values in the array, i.e. the index ranges
    [0, padding_l) and [SB + padding_l, SB + FS). The SB values in between
    are left untouched.

    FS        - number of overhang values (left + right)
    SB        - number of values between the two overhangs; also the stride
                each thread advances by
                (presumably equal to blockDim.x - confirm against callers)
    padding_l - number of values in the left overhang
    data      - array of at least FS + SB elements; threads of the block
                cooperate on it, so callers need a barrier before reading
                what other threads have zeroed
  */

  const int tid = threadIdx.x;

  // Each thread handles overhang slots tid, tid + SB, tid + 2 * SB, ...
  // Slots below padding_l belong to the left overhang; the remaining
  // slots belong to the right overhang, which starts SB elements further
  // into the array. Using one loop for every FS / SB combination keeps
  // the FS < SB case on the same indexing as the FS >= SB case.
  for (int slot = tid; slot < FS; slot += SB) {
    if (slot < padding_l) {
      data[slot] = scalar_t(0.0);
    } else {
      data[SB + slot] = scalar_t(0.0);
    }
  }
}

template<typename scalar_t>
__inline__ __device__
scalar_t warpReduce(scalar_t data) {
  /*
    Sum a value across the lanes of a warp using XOR shuffles. After the
    five exchange steps (lane masks 16, 8, 4, 2, 1) each participating lane
    holds the sum of the original values of the lanes it exchanged with;
    for a fully active 32-lane warp that is the sum over the whole warp.

    data - this thread's contribution to the sum (passed by value)

    Returns the accumulated sum for the calling lane.

    SHFL_MASK is defined outside this function (presumably the full-warp
    mask - confirm at its definition).
  */
  #pragma unroll
  for (int laneMask = 16; laneMask > 0; laneMask >>= 1) {
    data += __shfl_xor_sync(SHFL_MASK, data, laneMask);
  }
  return data;
}

template<typename scalar_t>
__inline__ __device__
scalar_t blockReduce(scalar_t data) {
  /*
     Sum a value across the threads of a block in two stages: every warp
     reduces its own values, lane 0 of each warp stores the partial sum in
     shared memory, then warp 0 reduces those partial sums.

     data - this thread's contribution to the sum (passed by value)

     Returns the combined sum in the threads of warp 0. Threads of other
     warps return scalar_t(0.0) (for block sizes up to 1024 threads).

     The function contains __syncthreads() barriers, so every thread of
     the block has to call it. The number of partial sums consumed is
     blockDim.x / 32 rounded down - presumably blockDim.x is expected to
     be a multiple of 32; confirm against callers.
  */

  static __shared__ scalar_t warpSum[32];
  const int tid = threadIdx.x;
  const int warpIdx = tid / 32;
  const int laneIdx = tid % 32;

  // Barrier before warpSum is overwritten below.
  __syncthreads();

  // Stage 1: per-warp sum, published by lane 0 of each warp.
  const scalar_t warpTotal = warpReduce(data);
  if (laneIdx == 0) {
    warpSum[warpIdx] = warpTotal;
  }

  // Make the partial sums visible to the whole block.
  __syncthreads();

  // Stage 2: the first (blockDim.x / 32) threads pick up one partial sum
  // each; every other thread contributes zero.
  scalar_t blockTotal =
      (tid < blockDim.x / 32) ? warpSum[laneIdx] : scalar_t(0.0);

  if (warpIdx == 0) {
    blockTotal = warpReduce(blockTotal);
  }
  __syncthreads();

  return blockTotal;
}

void checkCudaStatus(cudaError_t status, int lineNumber = -1) {
  /*
    Terminate the process if a CUDA call did not succeed.

    status     - result code to check
    lineNumber - line number included in the message (-1 by default)

    On any status other than cudaSuccess, prints the CUDA error string and
    the line number to std::cout, then calls exit(1). Returns normally
    otherwise.
  */

  if (status == cudaSuccess) {
    return;
  }

  const char* description = cudaGetErrorString(status);
  std::cout << description << " at line " << lineNumber << std::endl;
  std::cout << "Exiting" << std::endl;
  exit(1);
}

template<int FS, int SB, int padding_l, typename scalar_t>
__device__
void load_input_to_shared(const scalar_t* input, // global memory
                          int inputOffset, int sequenceLength,
                          int iteration, int numIterations,
                          bool no_prev, scalar_t* output /* shared memory */) {
  /*
    Load a block size of input into shared memory with
    right and left overhang of total size FS. If the previous block was
    already loaded, the overlap is shifted over inside shared memory to
    reduce global memory access.

    Layout written to output:
      [0, padding_l)                 left overhang
      [padding_l, padding_l + SB)    the current block (one value per thread)
      [padding_l + SB, FS + SB)      right overhang

    input - pointer to start of channel sequence
    inputOffset - how far in the sequence to start loading
    sequenceLength - total length of sequence
    iteration - which block of sequence we are loading
    numIterations - total number of blocks to load
    no_prev - when true, the left overhang is read from input instead of
              being shifted over from the previously loaded block
    output - shared memory to write input to

    Positions that fall outside [0, sequenceLength) are filled with zero.
    For the first block (iteration == 0) the left overhang is not written
    here at all.

    The overhang loops advance by SB per step and the current block is
    stored one value per thread - presumably blockDim.x == SB; confirm
    against callers.
  */

  const int tid = threadIdx.x;
  const int rightOverhang = FS - padding_l;

  // Load the left "overhang" of input
  if (iteration > 0) {
    for (int idx = tid; idx < padding_l; idx += SB) {
      if (no_prev) {
        // Read straight from the sequence. Positions before the start of
        // the sequence are padding, so fill them with zero instead of
        // reading in front of the input pointer.
        const int src = inputOffset - padding_l + idx;
        output[idx] = (src >= 0) ? input[src] : scalar_t(0.0);
      } else {
        // Reuse what the previous block left SB elements further right.
        // NOTE(review): these reads overlap the slots written by other
        // threads in the final store of this function and there is no
        // barrier in between - confirm this cannot race for SB > 32.
        output[idx] = output[idx + SB];
      }
    }
  }

  if (iteration < (numIterations - 1)) {
    // Load the right "overhang" of input, zero-filling past the sequence end
    const int elementsLeft = sequenceLength - (iteration + 1) * SB;
    for (int idx = tid; idx < rightOverhang; idx += SB) {
      output[padding_l + SB + idx] =
          (idx < elementsLeft) ? input[inputOffset + SB + idx] : scalar_t(0.0);
    }
  } else if (iteration == (numIterations - 1)) {
    // Last block: nothing follows it, so clear out the right "overhang"
    for (int idx = tid; idx < rightOverhang; idx += SB) {
      output[padding_l + SB + idx] = scalar_t(0.0);
    }
  }

  // Load the current block itself, one value per thread
  output[tid + padding_l] = ((inputOffset + tid) < sequenceLength) ? input[inputOffset + tid] : scalar_t(0.0);
}