/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/*
Kernel implementation for blocking repeated n-grams.
*/

#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>

// Ban repeated n-grams of length 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(long* __restrict__ tokens,
                                  float* __restrict__ lprobs,
                                  int max_predict_len, int vocab_size,
                                  int no_repeat_ngram_size) {
  auto row = blockIdx.x;   // one block per hypothesis (bsz * beam_size rows)
  auto col = threadIdx.x;  // one thread per candidate n-gram start position
  auto start = row * max_predict_len + col;
  // Each thread compares the n-gram starting at its own position with the
  // final n-gram, which starts at position step - no_repeat_ngram_size + 2.
  auto check_start_pos = blockDim.x;
  auto lprob_start = row * vocab_size;
  bool is_banned = true;
  extern __shared__ long tokens_shm[];
  // Stage this hypothesis's tokens in shared memory; the last thread also
  // loads the trailing tokens that no other thread covers.
  tokens_shm[col] = tokens[start];
  if (col == blockDim.x - 1) {
    for (int i = 1; i < no_repeat_ngram_size; i++) {
      if (col + i < max_predict_len) {
        tokens_shm[col + i] = tokens[start + i];
      }
    }
  }
  __syncthreads();

  // Compare the first (n - 1) tokens of this thread's n-gram with the
  // trailing (n - 1) tokens of the sequence.
  for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
    if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
      is_banned = false;
    }
  }
  // On a match, the token that completed the earlier n-gram must not be
  // generated again, so its log-probability is set to -inf.
  if (is_banned) {
    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
  }
}

// Allocate blocks and threads based on batch size and sequence length,
// then launch the kernel.
torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor tokens,
                                              torch::Tensor lprobs, int bsz,
                                              int step, int beam_size,
                                              int no_repeat_ngram_size) {
  int threads = step - no_repeat_ngram_size + 2;
  if (threads <= 0)
    return lprobs;
  int max_predict_len = tokens.size(1);
  int vocab_size = lprobs.size(1);
  auto token_ptr = tokens.data_ptr<long>();
  auto lprob_ptr = lprobs.data_ptr<float>();
  int blocks = bsz * beam_size;
  int shared_mem_size = (step + 1) * sizeof(long);

  // Launch N blocks, where N is the number of hypotheses in the batch
  // (bsz * beam_size), and T threads, where T is the number of previous
  // n-grams in a hypothesis. Shared memory is allocated per block for faster
  // access to the input tokens, since each token is read roughly n times
  // while comparing against the current n-gram (n = no_repeat_ngram_size).
  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
      token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
  return lprobs;
}
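
/*
  Illustrative usage sketch (not part of the upstream file): a minimal host
  program exercising ngram_repeat_block_cuda_forward on synthetic tensors.
  It is guarded by a preprocessor flag so this translation unit builds exactly
  as before unless NGRAM_REPEAT_BLOCK_EXAMPLE is defined (e.g. via
  -DNGRAM_REPEAT_BLOCK_EXAMPLE when compiling with nvcc and linking libtorch).
  The flag name, the concrete shapes, and the reading of 'step' as the index
  of the most recent token in 'tokens' (implied by the kernel's indexing) are
  assumptions for illustration, not part of the upstream API.
*/
#ifdef NGRAM_REPEAT_BLOCK_EXAMPLE
#include <iostream>

int main() {
  const int bsz = 1;
  const int beam_size = 1;
  const int vocab_size = 10;
  const int no_repeat_ngram_size = 2;
  const int step = 4; // index of the last generated token in 'tokens'

  // tokens: (bsz * beam_size, max_predict_len) int64 CUDA tensor holding the
  // sequence generated so far. The bigram "1 2" already occurred and the
  // sequence currently ends in "1", so continuing with token 2 would repeat
  // that bigram and its log-probability should be banned.
  auto tokens = torch::tensor({1, 2, 1, 2, 1}, torch::dtype(torch::kLong))
                    .view({bsz * beam_size, step + 1})
                    .to(torch::kCUDA);

  // lprobs: (bsz * beam_size, vocab_size) float CUDA tensor of next-token
  // log-probabilities; zeros are enough to observe the banning.
  auto lprobs = torch::zeros({bsz * beam_size, vocab_size},
                             torch::dtype(torch::kFloat).device(torch::kCUDA));

  lprobs = ngram_repeat_block_cuda_forward(tokens, lprobs, bsz, step,
                                           beam_size, no_repeat_ngram_size);

  std::cout << lprobs << std::endl; // expect -inf in column 2, zeros elsewhere
  return 0;
}
#endif // NGRAM_REPEAT_BLOCK_EXAMPLE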