Commit acc044b2 by 张裕浩

修复softmax函数BUG,添加Topk优化函数(暂未测试)

parent ece0dc78
...@@ -363,6 +363,139 @@ void KernelTopK2(T * input, int stride, int strideNum, int blockNum, int k, T mi ...@@ -363,6 +363,139 @@ void KernelTopK2(T * input, int stride, int strideNum, int blockNum, int k, T mi
} }
/* /*
get the top-k items
>> input - the input data array
>> stride - number of items we go over when we move to the next item along a given dimension
>> strideNum - size of the given dimension
>> blockNum - number of data blocks
>> k - as it is
>> minValue - min value of an item
>> output - the output data array
>> index - the output index array
*/
template<class T> __global__
void KernelTopK3(T * input, int stride, int strideNum, int blockNum, int k, T minValue, T * output, int * index)
{
__shared__ CudaHeapNode<T> heapData[(SHARED_MEMORY_SIZE - 1024 * sizeof(T)) / sizeof(CudaHeapNode<T>)];
__shared__ T eachHeapMaxValue[1024];
/*optimization k size the parameter must more than half of k*/
int parameter = 0;
/* worker index */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* index of the data arry along the given dimension */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if (i >= strideNum || i >= blockDim.x || j >= stride * blockNum)
return;
int blockIndex = j / stride;
int offsetInBlock = j % stride;
T * d = input + stride * strideNum * blockIndex + offsetInBlock;
CudaXHeap<MIN_HEAP, T> heap(k - parameter, heapData + k * (threadIdx.y * blockDim.x + threadIdx.x));
__syncthreads();
/* go over the data array and build the heap */
int indexOffset = blockDim.x;
int dataOffset = stride * blockDim.x;
if (i + (heap.size - 1) * indexOffset < strideNum) {
int p = i;
int q = i * stride;
for (int m = 0; m < heap.size; m++) {
heap.Push(p, d[q]);
p += indexOffset;
q += dataOffset;
}
for (; p < strideNum; p += indexOffset, q += dataOffset) {
T v = d[q];
if (v > heap.topValue) {
heap.ReplaceTop(p, v);
}
}
}
else {
for (int p = i, q = i * stride; p < strideNum; p += indexOffset, q += dataOffset) {
heap.Push(p, d[q]);
}
}
/* fill the heap if no enough items are processed */
while (heap.count < heap.size) {
heap.Push(-1, minValue);
}
__syncthreads();
/*to merge the heap use another way*/
T minData = minValue;
int heapLimit = heap.count / 2;
if (heapLimit % 2 == 0 && heapLimit != 0) heapLimit -= 1;
for (int counter = heap.count - 1; counter >= heapLimit; --counter)
{
if (minData < heap.items[counter].value)
minData = heap.items[counter].value;
}
eachHeapMaxValue[threadIdx.y * blockDim.x + threadIdx.x] = minData;
//need more optimation
if (i == 0)
{
int threadLimit = (threadIdx.y + 1) * blockDim.x;
CudaXHeap<MIN_HEAP, T> chooseHeap(k, heapData + k * ((blockDim.x * blockDim.y) + threadIdx.y));
int counter = threadIdx.y * blockDim.x;
for (; counter < threadIdx.y * blockDim.x + k; ++counter)
{
chooseHeap.Push(counter, eachHeapMaxValue[counter]);
}
for (; counter < threadLimit; ++counter)
{
if (eachHeapMaxValue[counter]>chooseHeap.items[0].value)
{
chooseHeap.ReplaceTop(counter, eachHeapMaxValue[counter]);
}
}
CudaXHeap<MIN_HEAP, T> ansHeapData(k, k - parameter, heapData + k * chooseHeap.items[0].index);
int miss = parameter;
for (counter = 1; counter < k; ++counter)
{
//printf("%f %d\n",chooseHeap.items[0].value,chooseHeap.items[0].index);
chooseHeap.items[0] = chooseHeap.items[chooseHeap.count - 1];
chooseHeap.count--;
chooseHeap.Down(0);
CudaHeapNode<T> * cmpHeapData = heapData + k * (chooseHeap.items[0].index);
int cmpHeapLimit = 0;
if (counter + heapLimit <= k - parameter)
{
cmpHeapLimit = heapLimit;
}
//take the max data from the minHeap,so start search from the leaf node
for (int iterator = k - 1 - parameter; iterator >= cmpHeapLimit; --iterator)
{
if (miss > 0)
{
ansHeapData.Push(cmpHeapData[iterator].index, cmpHeapData[iterator].value);
miss--;
}
else if (ansHeapData.items[0].value < cmpHeapData[iterator].value)
{
ansHeapData.ReplaceTop(cmpHeapData[iterator].index, cmpHeapData[iterator].value);
}
}
}
int offset = stride * k * blockIndex + offsetInBlock;
T * dOutput = output + offset;
int * indexOutput = index + offset;
for (int q = 0; q < k; ++q)
{
dOutput[stride * q] = ansHeapData.items[q].value;
indexOutput[stride * q] = ansHeapData.items[q].index;
}
}
}
/*
get the top-k items along a given dimension get the top-k items along a given dimension
>> a - input tensor >> a - input tensor
>> b - output tensor (top-k result) >> b - output tensor (top-k result)
...@@ -389,7 +522,12 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -389,7 +522,12 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
blockNum *= a->dimSizeRDI[i]; blockNum *= a->dimSizeRDI[i];
int workerNum = blockNum < 16 ? 64 : 32; // should be tuned for better performance int workerNum = blockNum < 16 ? 64 : 32; // should be tuned for better performance
/*adjust the thread num according size of k for fitting the share memory size*/
if (k< 6) workerNum = 512;
else if (k < 11) workerNum = 256;
else if (k < 22) workerNum = 128;
else if (k < 44) workerNum = 64;
else workerNum = 32;
int cudaGrids[3]; int cudaGrids[3];
int cudaBlocks[3]; int cudaBlocks[3];
...@@ -397,7 +535,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -397,7 +535,7 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
workerNum, stride * blockNum, MAX_INT, workerNum, stride * blockNum, MAX_INT,
cudaGrids, cudaBlocks); cudaGrids, cudaBlocks);
for (int i = 0; i < 2; i++) { /*for (int i = 0; i < 2; i++) {
if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) >= SHARED_MEMORY_SIZE) { if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) >= SHARED_MEMORY_SIZE) {
if (cudaBlocks[1] >= 2 && cudaBlocks[1] % 2 == 0) { if (cudaBlocks[1] >= 2 && cudaBlocks[1] % 2 == 0) {
cudaBlocks[1] /= 2; cudaBlocks[1] /= 2;
...@@ -411,12 +549,14 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k) ...@@ -411,12 +549,14 @@ void _CudaTopK(const XTensor * a, XTensor * b, XTensor * index, int dim, int k)
cudaGrids[0] *= 2; cudaGrids[0] *= 2;
} }
} }
} }*/
int devIDBackup = 0; int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup); ProtectCudaDev(a->devID, devIDBackup);
/* we run the kernel if the heaps can fit into the shared memory */ /* we run the kernel if the heaps can fit into the shared memory */
cudaGrids[1] *= cudaBlocks[1];
cudaBlocks[1] = 1;
if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) < SHARED_MEMORY_SIZE) { if ((cudaBlocks[0] * cudaBlocks[1] + 1) * k * (a->unitSize + sizeof(int)) < SHARED_MEMORY_SIZE) {
if (a->dataType == DEFAULT_DTYPE) { if (a->dataType == DEFAULT_DTYPE) {
KernelTopK2<DTYPE> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> > KernelTopK2<DTYPE> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
......
...@@ -223,24 +223,32 @@ void _CudaSoftmaxSumMax(const XTensor * x, XTensor * y, int leadDim, XTensor * s ...@@ -223,24 +223,32 @@ void _CudaSoftmaxSumMax(const XTensor * x, XTensor * y, int leadDim, XTensor * s
int cudaGridSize[3]; int cudaGridSize[3];
int cudaBlockSize[3]; int cudaBlockSize[3];
//allocate thread num for old function
//GDevs.GetCudaThread2D(x->devID, stride * blockNum, dimensionSize, MAX_INT, cudaGridSize, cudaBlockSize); //GDevs.GetCudaThread2D(x->devID, stride * blockNum, dimensionSize, MAX_INT, cudaGridSize, cudaBlockSize);
//allocate thread num for new function
GDevs.GetCudaThread2D(x->devID, dimensionSize, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread2D(x->devID, dimensionSize, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
if (cudaBlockSize[0] % 32 != 0) if (cudaBlockSize[0] < 32)
cudaBlockSize[0] += (32 - cudaBlockSize[0] % 32); {
/**/ cudaBlockSize[0] = 32;//use at least a warp
if (cudaBlockSize[1] > 32)
{
cudaGridSize[1] = int(ceil(float(stride * blockNum) / 32));
cudaBlockSize[1] = 32;
}
}
int devIDBackup; int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup); ProtectCudaDev(x->devID, devIDBackup);
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){ if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
printf("run here\n");
/*KernelSoftmaxComputeTensor<<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>> /*KernelSoftmaxComputeTensor<<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data, ((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data,
stride, dimensionSize, stride * dimensionSize, blockNum, stride * blockNum);*/ stride, dimensionSize, stride * dimensionSize, blockNum, stride * blockNum);
*/
KernelSoftmaxComputeTensorUseBroadcast << <dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1]) >> > KernelSoftmaxComputeTensorUseBroadcast << <dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1]) >> >
((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data, ((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data,
stride, dimensionSize, blockNum); stride, dimensionSize, blockNum);
//printf("%d %d %d %d %d %d\n", stride, dimensionSize, stride * dimensionSize, blockNum, stride * blockNum); printf("%d %d %d %d\n", cudaGridSize[0], cudaGridSize[1], cudaBlockSize[0], cudaBlockSize[1]);
} }
else if(x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16){ else if(x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16){
KernelSoftmaxComputeTensor<<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>> KernelSoftmaxComputeTensor<<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
......
...@@ -214,7 +214,7 @@ bool TestSoftmax3Gpu() ...@@ -214,7 +214,7 @@ bool TestSoftmax3Gpu()
int order = 2; int order = 2;
int * dimSize = new int[order]; int * dimSize = new int[order];
dimSize[0] = 32; dimSize[0] = 8;
dimSize[1] = 1000; dimSize[1] = 1000;
int unitNum = 1; int unitNum = 1;
...@@ -228,7 +228,7 @@ bool TestSoftmax3Gpu() ...@@ -228,7 +228,7 @@ bool TestSoftmax3Gpu()
/* initialize variables */ /* initialize variables */
FILE *dataFile; FILE *dataFile;
char dataString[32]; char dataString[32];
const int dataSize = 32 * 1000; const int dataSize = 8 * 1000;
DTYPE xData[dataSize]; DTYPE xData[dataSize];
if ((dataFile = fopen("D:\\Work\\TensorFlowLearn\\testdata.in", "r")) == NULL) if ((dataFile = fopen("D:\\Work\\TensorFlowLearn\\testdata.in", "r")) == NULL)
{ {
...@@ -253,7 +253,7 @@ bool TestSoftmax3Gpu() ...@@ -253,7 +253,7 @@ bool TestSoftmax3Gpu()
yGPU->SetZeroAll(); yGPU->SetZeroAll();
/* call Softmax function */ /* call Softmax function */
_Softmax(xGPU, yGPU, 0); _Softmax(xGPU, yGPU, 1);
/* check result */ /* check result */
...@@ -261,11 +261,10 @@ bool TestSoftmax3Gpu() ...@@ -261,11 +261,10 @@ bool TestSoftmax3Gpu()
DTYPE check = 0; DTYPE check = 0;
DTYPE TensorData[dataSize]; DTYPE TensorData[dataSize];
cudaMemcpy(TensorData, yGPU->data, sizeof(DTYPE)* unitNum, cudaMemcpyDeviceToHost); cudaMemcpy(TensorData, yGPU->data, sizeof(DTYPE)* unitNum, cudaMemcpyDeviceToHost);
//float check = 0; for (int i = 0; i < 1000; ++i)
for (int i = 0; i < 32; ++i)
{ {
check += TensorData[i]; check += TensorData[i];
printf("%f ", TensorData[i]); //printf("%f ", TensorData[i]);
} }
printf("\n%f \n", check); printf("\n%f \n", check);
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
*/ */
#include "TTopK.h" #include "TTopK.h"
#include "TSort.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* /*
...@@ -97,12 +97,21 @@ bool TestTopK1() ...@@ -97,12 +97,21 @@ bool TestTopK1()
int dim = 0; int dim = 0;
int k = sDimSize[dim]; int k = sDimSize[dim];
_TopK(s, t1, index1, dim, k); _TopK(s, t1, index1, dim, k);
_SortMe(t1, index1, dim);
TopK(sUser, tUser1, indexUser1, dim, k); TopK(sUser, tUser1, indexUser1, dim, k);
_SortMe(&tUser1, &indexUser1, dim);
t1->Dump(stderr);
tUser1.Dump(stderr);
index1->Dump(stderr);
dim = 1; dim = 1;
k = sDimSize[dim]; k = sDimSize[dim];
_TopK(s, t2, index2, dim, k); _TopK(s, t2, index2, dim, k);
_SortMe(t2, index2, dim);
TopK(sUser, tUser2, indexUser2, dim, k); TopK(sUser, tUser2, indexUser2, dim, k);
_SortMe(&tUser2, &indexUser2, dim);
/* check results */ /* check results */
cpuTest = t1->CheckData(tAnswer1, tUnitNum) && tUser1.CheckData(tAnswer1, tUnitNum) cpuTest = t1->CheckData(tAnswer1, tUnitNum) && tUser1.CheckData(tAnswer1, tUnitNum)
......
...@@ -60,8 +60,8 @@ bool Test() ...@@ -60,8 +60,8 @@ bool Test()
wrong = !TestSplit() || wrong; wrong = !TestSplit() || wrong;
wrong = !TestSum() || wrong; wrong = !TestSum() || wrong;
wrong = !TestSumByColumnTV() || wrong; wrong = !TestSumByColumnTV() || wrong;
wrong = !TestSumByColumnVT() || wrong; wrong = !TestSumByColumnVT() || wrong;*/
wrong = !TestTopK() || wrong; /*wrong = !TestTopK() || wrong;
wrong = !TestUnsqueeze() || wrong; wrong = !TestUnsqueeze() || wrong;
wrong = !TestXMem() || wrong; wrong = !TestXMem() || wrong;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论