Commit d2c7e39a by linye

1. Fixed an XTensorBLAS bug; 2. added float16 support for the gather backward pass.

parent 01747e0d
@@ -161,11 +161,11 @@ void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
         if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
             cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
-            cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
             cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_TRANS)
-            cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_16F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
     }
     else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT) {
@@ -173,11 +173,11 @@ void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
         if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
             cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
-            cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
             cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_TRANS)
-            cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, a, CUDA_R_16F, ma, (void*)&beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
     }
     else if (dataTypeA == X_INT8 && dataTypeB == X_INT8 && dataTypeC == X_FLOAT) {
@@ -193,11 +193,11 @@ void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
         if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
             cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
-            cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
             cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_TRANS)
-            cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha, b, CUDA_R_8I, mb, a, CUDA_R_8I, ma, &beta, c, CUDA_R_32F, mc, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
     }
     else {
@@ -246,11 +246,11 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
         if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
             cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
-            cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
             cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_TRANS)
-            cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, (void*)&alpha, b, CUDA_R_16F, mb, strideB, a, CUDA_R_16F, ma, strideA, (void*)&beta, c, CUDA_R_16F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
     }
     else if (dataTypeA == X_FLOAT16 && dataTypeB == X_FLOAT16 && dataTypeC == X_FLOAT) {
@@ -278,11 +278,11 @@ void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
         if (transposedA == X_NOTRANS && transposedB == X_NOTRANS)
             cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_N, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_NOTRANS)
-            cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_N, CUBLAS_OP_T, mc, nc, na, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_NOTRANS && transposedB == X_TRANS)
             cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_N, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         else if (transposedA == X_TRANS && transposedB == X_TRANS)
-            cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, ma, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+            cublasGemmStridedBatchedEx(*handle, CUBLAS_OP_T, CUBLAS_OP_T, mc, nc, na, &alpha, b, CUDA_R_8I, mb, strideB, a, CUDA_R_8I, ma, strideA, &beta, c, CUDA_R_32F, mc, strideC, count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
         cublasSetMathMode(*handle, CUBLAS_DEFAULT_MATH);
     }
     else {
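Note on the XTensorBLAS fix above: in every branch where A is transposed, the k argument of the cublasGemm*BatchedEx call changes from ma to na. Assuming, as the surrounding calls suggest, that na/ma are the row/column counts of the row-major matrix A (and nb/mb, nc/mc those of B and C), and that the row-major product C = op(A) op(B) is issued through the column-major cuBLAS API with the A and B arguments swapped (C^T = op(B)^T op(A)^T), the shared dimension k must be the number of rows of A once A is transposed, which is what the patch now passes. The helper below is only a sketch of that mapping written for this note; MapRowMajorToCublas is not a repository function.

#include <cassert>

/* Sketch: derive the (m, n, k) triple handed to cublasGemm*Ex when a row-major
   product C = op(A) op(B) goes through the column-major cuBLAS API with the
   A and B arguments swapped. na/ma = rows/columns of A, nb/mb = rows/columns
   of B, nc/mc = rows/columns of C (all row-major). */
struct GemmDims { int m, n, k; };

GemmDims MapRowMajorToCublas(bool transA, bool transB,
                             int na, int ma, int nb, int mb, int nc, int mc)
{
    GemmDims d;
    d.m = mc;                   /* columns of C */
    d.n = nc;                   /* rows of C */
    d.k = transA ? na : ma;     /* inner dimension: rows of A when A is
                                   transposed, columns of A otherwise;
                                   this is the value the patch corrects */
    assert(d.k == (transB ? mb : nb));   /* the B operand must agree on k */
    return d;
}

int main()
{
    /* A is 4 x 3 (na = 4, ma = 3) and is transposed, B is 4 x 5, so C = A^T B
       is 3 x 5 and k must be na = 4, not ma = 3 as the old code passed. */
    GemmDims d = MapRowMajorToCublas(true, false, 4, 3, 4, 5, 3, 5);
    assert(d.m == 5 && d.n == 3 && d.k == 4);
    return 0;
}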
@@ -330,12 +330,13 @@ Care of the operator "+=" instead of "=".
 >> indexSize - the number of index
 >> stride - stride of a data block
 */
+template <class T, TENSOR_DATA_TYPE datatype>
 __global__
-void KernelSpreadForGather(DTYPE * sData, DTYPE * cData, int * srcIndex,
+void KernelSpreadForGather(T * sData, T * cData, int * srcIndex,
                            int indexSize, int stride)
 {
-    __shared__ DTYPE * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
-    __shared__ DTYPE * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
+    __shared__ T * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
+    __shared__ T * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];

     /* block id */
     int i = blockDim.x * blockIdx.x + threadIdx.x;
@@ -353,13 +354,18 @@ void KernelSpreadForGather(DTYPE * sData, DTYPE * cData, int * srcIndex,
     __syncthreads();

-    DTYPE * s = sp[threadIdx.x];
-    DTYPE * c = cp[threadIdx.x];
+    T * s = sp[threadIdx.x];
+    T * c = cp[threadIdx.x];

     //DTYPE * s = sData + srcIndex[i] * stride;
     //DTYPE * c = cData + i * stride;

-    atomicAdd(s + offset, c[offset]);
+    if (datatype == X_FLOAT) {
+        atomicAdd(((DTYPE*)s + offset), *((DTYPE*)c + offset));
+    }
+    else if (datatype == X_FLOAT16) {
+        atomicAdd(((__half2*)s + offset), *((__half2*)c + offset));
+    }
 }
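Note on the float16 branch added to KernelSpreadForGather: the row data is reinterpreted as __half2 and accumulated with the vectorized atomicAdd(__half2 *, __half2) overload. That overload is only available on devices of compute capability 6.0 and higher, and each packed half is added atomically while the pair is not guaranteed to be a single 32-bit atomic access. A minimal standalone sketch of the overload follows; it is not part of the patch (the kernel name and sizes are invented) and must be built with -arch=sm_60 or newer.

#include <cuda_fp16.h>
#include <cuda_runtime.h>

/* Sketch: accumulate packed half pairs into a target buffer with the
   half-precision vectorized atomicAdd (compute capability >= 6.0). */
__global__ void AccumulateHalf2(__half2 * target, const __half2 * values, int n)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < n)
        atomicAdd(target + i, values[i]);   /* each packed half is added atomically */
}

int main()
{
    const int n = 8;
    __half2 * target = nullptr;
    __half2 * values = nullptr;
    cudaMalloc(&target, n * sizeof(__half2));
    cudaMalloc(&values, n * sizeof(__half2));
    cudaMemset(target, 0, n * sizeof(__half2));   /* all-zero bits are valid halves */
    cudaMemset(values, 0, n * sizeof(__half2));
    AccumulateHalf2<<<1, n>>>(target, values, n);
    cudaDeviceSynchronize();
    cudaFree(target);
    cudaFree(values);
    return 0;
}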
 /*
@@ -372,6 +378,10 @@ And this is a special spread function for backward computation of gather function
 */
 void _CudaSpreadForGather(XTensor * source, XTensor * collection, XTensor * srcIndex)
 {
+    CheckNTErrors((source->dataType == X_FLOAT) ||
+                  (source->dataType == X_FLOAT16),
+                  "Unmatched tensors in gather!");
+
     int devID = source->devID;
     XMem * mem = source->mem;
@@ -384,8 +394,6 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, XTensor * srcIndex)
     int devIDBackup;
     ProtectCudaDev(source->devID, devIDBackup);

-    DTYPE * sData = (DTYPE*)source->data;
-    DTYPE * cData = (DTYPE*)collection->data;
-
     int * sIndex = NULL;

     GDevs.GetCudaThread2D(devID, indexSize, stride, MAX_INT, cudaGrids, cudaBlocks);
@@ -402,7 +410,19 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, XTensor * srcIndex)
     else
         sIndex = (int *)srcIndex->data;

-    KernelSpreadForGather<<<blocks, threads >>>(sData, cData, sIndex, indexSize, stride);
+    if (source->dataType == DEFAULT_DTYPE && collection->dataType == DEFAULT_DTYPE)
+    {
+        DTYPE * sData = (DTYPE*)source->data;
+        DTYPE * cData = (DTYPE*)collection->data;
+        KernelSpreadForGather<DTYPE, X_FLOAT> <<<blocks, threads >>>(sData, cData, sIndex, indexSize, stride);
+    }
+    else if (source->dataType == X_FLOAT16 && collection->dataType == X_FLOAT16)
+    {
+        __half2 * sData = (__half2*)source->data;
+        __half2 * cData = (__half2*)collection->data;
+        KernelSpreadForGather<__half2, X_FLOAT16> <<<blocks, threads >>>(sData, cData, sIndex, indexSize, stride);
+    }

     if (srcIndex->devID < 0) {
         if(mem != NULL)
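For reference, the semantics that both the DTYPE and the __half2 paths of the gather backward implement can be stated with a plain CPU version: each gathered row's gradient is added back onto the source row it was gathered from, hence the "+=" rather than "=" that the kernel's doc comment warns about, because several indices may point at the same source row. The sketch below is for illustration only; SpreadForGatherCPU is not repository code, though its parameters mirror the kernel's.

#include <cassert>

/* Sketch: CPU reference for the spread used in gather's backward pass. */
template <class T>
void SpreadForGatherCPU(T * sData, const T * cData, const int * srcIndex,
                        int indexSize, int stride)
{
    for (int i = 0; i < indexSize; i++) {
        T * s = sData + srcIndex[i] * stride;   /* row of the source gradient */
        const T * c = cData + i * stride;       /* gradient of the i-th gathered row */
        for (int j = 0; j < stride; j++)
            s[j] += c[j];                       /* accumulate, never overwrite */
    }
}

int main()
{
    /* both gathered rows came from source row 0, so its gradient receives
       the sum of the two contributions */
    float source[4] = { 0, 0, 0, 0 };        /* 2 rows, stride 2 */
    float collection[4] = { 1, 2, 3, 4 };    /* 2 gathered rows */
    int srcIndex[2] = { 0, 0 };
    SpreadForGatherCPU(source, collection, srcIndex, 2, 2);
    assert(source[0] == 4 && source[1] == 6);
    return 0;
}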