Commit 9f14dc72 by linye

Bug fixed: replace the `FloatToFloat16` + pointer-cast idiom with `__float2half` for half-precision scalar conversion (Div, DivDim, MultiplyDim, Sum, Clip, ScaleAndShift); unify the float16 `KernelClip` overload into the template declaration; drop stale `DEFAULT_DTYPE` checks in `_SetDataDim`/`_SetDataIndexed`; switch the enabled test cases accordingly.

parent 3ad0e638
...@@ -187,8 +187,7 @@ void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, in ...@@ -187,8 +187,7 @@ void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, in
int cudaGridSize[3]; int cudaGridSize[3];
int cudaBlockSize[3]; int cudaBlockSize[3];
unsigned short temp = FloatToFloat16(alpha); half alpha1 = __float2half(alpha);
half alpha1 = *((half *)&temp);
if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) { if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
GDevs.GetCudaThread(a->devID, c->unitNum, cudaGridSize, cudaBlockSize); GDevs.GetCudaThread(a->devID, c->unitNum, cudaGridSize, cudaBlockSize);
......
...@@ -170,8 +170,7 @@ void _CudaDivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE ...@@ -170,8 +170,7 @@ void _CudaDivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE
} }
} }
else if (a->dataType == X_FLOAT16) { else if (a->dataType == X_FLOAT16) {
unsigned short temp = FloatToFloat16(alpha); half alpha1 = __float2half(alpha);
half alpha1 = *((half *)&temp);
if (stride > 1){ if (stride > 1){
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks); GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if (alpha == (DTYPE)0.0F) if (alpha == (DTYPE)0.0F)
......
...@@ -170,8 +170,7 @@ void _CudaMultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, ...@@ -170,8 +170,7 @@ void _CudaMultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n,
} }
} }
else if (a->dataType == X_FLOAT16) { else if (a->dataType == X_FLOAT16) {
unsigned short temp = FloatToFloat16(alpha); half alpha1 = __float2half(alpha);
half alpha1 = *((half *)&temp);
if (stride > 1) { if (stride > 1) {
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks); GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if (alpha == 0.0F) if (alpha == 0.0F)
......
...@@ -128,8 +128,8 @@ void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta) ...@@ -128,8 +128,8 @@ void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
dim3 blocks(gridSize[0]); dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]); dim3 threads(blockSize[0]);
unsigned short temp = FloatToFloat16(beta); half beta1 = __float2half(beta);
half beta1 = *((half *)&temp);
KernelADD << <blocks, threads >> >((__half *)a->data, (__half *)b->data, (__half *)c->data, a->unitNum, beta1); KernelADD << <blocks, threads >> >((__half *)a->data, (__half *)b->data, (__half *)c->data, a->unitNum, beta1);
} }
else if (a->dataType == X_INT && else if (a->dataType == X_INT &&
......
...@@ -245,7 +245,6 @@ void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p) ...@@ -245,7 +245,6 @@ void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
{ {
int n = tensor->order; int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim >= 0, "Illegal dimension!"); CheckNTErrors(dim < n && dim >= 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!"); CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!"); CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
...@@ -298,7 +297,6 @@ void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index) ...@@ -298,7 +297,6 @@ void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index)
int order = source->order; int order = source->order;
int size = source->GetDim(dim); int size = source->GetDim(dim);
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!"); CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
CheckNTErrors(index >= 0 && index < size, "Illegal index!"); CheckNTErrors(index >= 0 && index < size, "Illegal index!");
......
...@@ -79,11 +79,8 @@ void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper) ...@@ -79,11 +79,8 @@ void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper)
KernelClip << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, lower, upper, a->unitNum); KernelClip << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, lower, upper, a->unitNum);
} }
else if (a->dataType == X_FLOAT16) { else if (a->dataType == X_FLOAT16) {
unsigned short temp1 = FloatToFloat16(lower); half lower1 = __float2half(lower);
unsigned short temp2 = FloatToFloat16(upper); half upper1 = __float2half(upper);
half lower1 = *((half *)&temp1);
half upper1 = *((half *)&temp2);
KernelClip << <blocks, threads >> >((__half*)a->data, (__half*)b->data, lower1, upper1, a->unitNum); KernelClip << <blocks, threads >> >((__half*)a->data, (__half*)b->data, lower1, upper1, a->unitNum);
} }
......
...@@ -29,12 +29,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -29,12 +29,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA #ifdef USE_CUDA
/* set each entry to its clip value (CUDA Kernel) */ /* set each entry to its clip value (CUDA Kernel) */
template <class T> __global__ template <class T>
void KernelClip(T * a, T * b, T lower, T upper, int size);
/* set each entry to its clip value (CUDA Kernel) with float16 data type*/
__global__ __global__
void KernelClip(__half * a, __half * b, DTYPE lower, DTYPE upper, int size); void KernelClip(T * a, T * b, T lower, T upper, int size);
/* set each entry to its clip value */ /* set each entry to its clip value */
void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper); void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper);
......
...@@ -96,19 +96,17 @@ void _CudaScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift ...@@ -96,19 +96,17 @@ void _CudaScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift
KernelScaleAndShift<DTYPE, false, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift); KernelScaleAndShift<DTYPE, false, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
} }
else if(a->dataType == X_FLOAT16){ else if(a->dataType == X_FLOAT16){
unsigned short scale2 = FloatToFloat16(scale); half scale1 = __float2half(scale);
unsigned short shift2 = FloatToFloat16(shift); half shift1 = __float2half(shift);
__half * scaleft16p = (__half*)&scale2;
__half * shiftft16p = (__half*)&shift2;
if (scale == 1.0F && shift == 0) if (scale == 1.0F && shift == 0)
KernelScaleAndShift<__half, true, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p); KernelScaleAndShift<__half, true, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, scale1, shift1);
else if (scale == 1.0F && shift != 0) else if (scale == 1.0F && shift != 0)
KernelScaleAndShift<__half, true, false><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p); KernelScaleAndShift<__half, true, false><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, scale1, shift1);
else if (scale != 1.0F && shift == 0) else if (scale != 1.0F && shift == 0)
KernelScaleAndShift<__half, false, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p); KernelScaleAndShift<__half, false, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, scale1, shift1);
else else
KernelScaleAndShift<__half, false, false> << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p); KernelScaleAndShift<__half, false, false> << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum, scale1, shift1);
} }
else if (a->dataType == X_INT) { else if (a->dataType == X_INT) {
int scale2 = int(scale); int scale2 = int(scale);
......
...@@ -30,7 +30,7 @@ bool Test() ...@@ -30,7 +30,7 @@ bool Test()
XPRINT(0, stdout, "Testing the XTensor utilites ... \n\n"); XPRINT(0, stdout, "Testing the XTensor utilites ... \n\n");
//wrong = !TestAbsolute() || wrong; //wrong = !TestAbsolute() || wrong;
//wrong = !TestClip() || wrong; wrong = !TestClip() || wrong;
//wrong = !TestCompare() || wrong; //wrong = !TestCompare() || wrong;
//wrong = !TestConcatenate() || wrong; //wrong = !TestConcatenate() || wrong;
//wrong = !TestConcatenateSolely() || wrong; //wrong = !TestConcatenateSolely() || wrong;
...@@ -38,18 +38,18 @@ bool Test() ...@@ -38,18 +38,18 @@ bool Test()
//wrong = !TestConvertDataType() || wrong; //wrong = !TestConvertDataType() || wrong;
//wrong = !TestCopyIndexed() || wrong; //wrong = !TestCopyIndexed() || wrong;
//wrong = !TestCopyValues() || wrong; //wrong = !TestCopyValues() || wrong;
//wrong = !TestDiv() || wrong; wrong = !TestDiv() || wrong;
//wrong = !TestDivDim() || wrong; wrong = !TestDivDim() || wrong;
//wrong = !TestExp() || wrong; //wrong = !TestExp() || wrong;
//wrong = !TestGather() || wrong; //wrong = !TestGather() || wrong;
//wrong = !TestLog() || wrong; //wrong = !TestLog() || wrong;
wrong = !TestMatrixMul() || wrong; //wrong = !TestMatrixMul() || wrong;
//wrong = !TestMatrixMul2D() || wrong; //wrong = !TestMatrixMul2D() || wrong;
//wrong = !TestMatrixMul2DParallel() || wrong; //wrong = !TestMatrixMul2DParallel() || wrong;
//wrong = !TestMatrixMulBatched() || wrong; //wrong = !TestMatrixMulBatched() || wrong;
//wrong = !TestMerge() || wrong; //wrong = !TestMerge() || wrong;
//wrong = !TestMultiply() || wrong; //wrong = !TestMultiply() || wrong;
//wrong = !TestMultiplyDim() || wrong; wrong = !TestMultiplyDim() || wrong;
//wrong = !TestNegate() || wrong; //wrong = !TestNegate() || wrong;
//wrong = !TestNormalize() || wrong; //wrong = !TestNormalize() || wrong;
//wrong = !TestPower() || wrong; //wrong = !TestPower() || wrong;
...@@ -60,7 +60,7 @@ bool Test() ...@@ -60,7 +60,7 @@ bool Test()
//wrong = !TestReduceSumSquared() || wrong; //wrong = !TestReduceSumSquared() || wrong;
//wrong = !TestReduceVariance() || wrong; //wrong = !TestReduceVariance() || wrong;
//wrong = !TestRound() || wrong; //wrong = !TestRound() || wrong;
//wrong = !TestScaleAndShift() || wrong; wrong = !TestScaleAndShift() || wrong;
//wrong = !TestSelect() || wrong; //wrong = !TestSelect() || wrong;
//wrong = !TestSetAscendingOrder() || wrong; //wrong = !TestSetAscendingOrder() || wrong;
//wrong = !TestSetData() || wrong; //wrong = !TestSetData() || wrong;
...@@ -70,7 +70,7 @@ bool Test() ...@@ -70,7 +70,7 @@ bool Test()
//wrong = !TestSplit() || wrong; //wrong = !TestSplit() || wrong;
//wrong = !TestSpread() || wrong; //wrong = !TestSpread() || wrong;
//wrong = !TestSub() || wrong; //wrong = !TestSub() || wrong;
//wrong = !TestSum() || wrong; wrong = !TestSum() || wrong;
//wrong = !TestSumByColumnTV() || wrong; //wrong = !TestSumByColumnTV() || wrong;
//wrong = !TestSumByColumnVT() || wrong; //wrong = !TestSumByColumnVT() || wrong;
//wrong = !TestSumDim() || wrong; //wrong = !TestSumDim() || wrong;
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论