Commit 9f14dc72 by linye

Bug fixed: replace the FloatToFloat16 bit-cast with __float2half when preparing half-precision kernel arguments, unify the KernelClip declarations into one template, relax the data-type assertions in _SetDataDim/_SetDataIndexed, and update which tests are enabled.

parent 3ad0e638
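
For reference, below is a minimal, self-contained sketch of the float-to-half conversion this commit switches to. It is illustrative only: it assumes a CUDA toolkit whose cuda_fp16.h exposes __float2half and __half2float as host-callable functions, and the variable names are made up for the example.

#include <cuda_fp16.h>
#include <cstdio>

int main()
{
    float alpha = 0.5F;

    /* old pattern in this code base: FloatToFloat16(alpha) produced a 16-bit
       pattern that was then reinterpreted via *((half *)&temp) */

    /* new pattern used by this commit: convert directly through the CUDA API */
    half alpha1 = __float2half(alpha);

    /* round-trip back to float just to inspect the value */
    printf("alpha = %f, back from half = %f\n", alpha, __half2float(alpha1));
    return 0;
}

Compiled with nvcc, this prints the same value twice for numbers that are exactly representable in half precision.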
@@ -187,8 +187,7 @@ void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, in
int cudaGridSize[3];
int cudaBlockSize[3];
- unsigned short temp = FloatToFloat16(alpha);
- half alpha1 = *((half *)&temp);
+ half alpha1 = __float2half(alpha);
if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
GDevs.GetCudaThread(a->devID, c->unitNum, cudaGridSize, cudaBlockSize);
......
@@ -170,8 +170,7 @@ void _CudaDivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE
}
}
else if (a->dataType == X_FLOAT16) {
- unsigned short temp = FloatToFloat16(alpha);
- half alpha1 = *((half *)&temp);
+ half alpha1 = __float2half(alpha);
if (stride > 1){
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if (alpha == (DTYPE)0.0F)
......
@@ -170,8 +170,7 @@ void _CudaMultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n,
}
}
else if (a->dataType == X_FLOAT16) {
- unsigned short temp = FloatToFloat16(alpha);
- half alpha1 = *((half *)&temp);
+ half alpha1 = __float2half(alpha);
if (stride > 1) {
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if (alpha == 0.0F)
......
@@ -128,8 +128,8 @@ void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
- unsigned short temp = FloatToFloat16(beta);
- half beta1 = *((half *)&temp);
+ half beta1 = __float2half(beta);
KernelADD << <blocks, threads >> >((__half *)a->data, (__half *)b->data, (__half *)c->data, a->unitNum, beta1);
}
else if (a->dataType == X_INT &&
......
@@ -245,7 +245,6 @@ void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
{
int n = tensor->order;
- CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim >= 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
@@ -298,7 +297,6 @@ void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index)
int order = source->order;
int size = source->GetDim(dim);
- CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
CheckNTErrors(index >= 0 && index < size, "Illegal index!");
......
@@ -79,11 +79,8 @@ void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper)
KernelClip << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, lower, upper, a->unitNum);
}
else if (a->dataType == X_FLOAT16) {
- unsigned short temp1 = FloatToFloat16(lower);
- unsigned short temp2 = FloatToFloat16(upper);
- half lower1 = *((half *)&temp1);
- half upper1 = *((half *)&temp2);
+ half lower1 = __float2half(lower);
+ half upper1 = __float2half(upper);
KernelClip << <blocks, threads >> >((__half*)a->data, (__half*)b->data, lower1, upper1, a->unitNum);
}
......
@@ -29,12 +29,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its clip value (CUDA Kernel) */
template <class T> __global__
void KernelClip(T * a, T * b, T lower, T upper, int size);
- /* set each entry to its clip value (CUDA Kernel) with float16 data type*/
- template <class T>
- __global__
- void KernelClip(__half * a, __half * b, DTYPE lower, DTYPE upper, int size);
+ void KernelClip(T * a, T * b, T lower, T upper, int size);
/* set each entry to its clip value */
void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper);
......
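
The Clip.cuh hunk above removes the float16-specific declaration in favor of a single templated KernelClip. As a rough illustration of what such a templated clip kernel can look like, here is a sketch (not the repository's implementation): the name KernelClipSketch and the launch lines are invented, and instantiating it for __half assumes device code compiled for sm_53 or newer so that the __half comparison operators exist.

#include <cuda_fp16.h>

/* illustrative templated clip kernel: each thread clamps one element of a
   into [lower, upper] and writes the result to b */
template <class T> __global__
void KernelClipSketch(T * a, T * b, T lower, T upper, int size)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size) {
        T v = a[i];
        if (v > upper)
            v = upper;
        if (v < lower)
            v = lower;
        b[i] = v;
    }
}

/* one source then serves both data types, e.g. (hypothetical launches):
   KernelClipSketch<float>  <<<blocks, threads>>> (fa, fb, -1.0F, 1.0F, n);
   KernelClipSketch<__half> <<<blocks, threads>>> (ha, hb, __float2half(-1.0F),
                                                   __float2half(1.0F), n);     */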
@@ -96,19 +96,17 @@ void _CudaScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift
KernelScaleAndShift<DTYPE, false, false> << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum, scale, shift);
}
else if(a->dataType == X_FLOAT16){
- unsigned short scale2 = FloatToFloat16(scale);
- unsigned short shift2 = FloatToFloat16(shift);
- __half * scaleft16p = (__half*)&scale2;
- __half * shiftft16p = (__half*)&shift2;
+ half scale1 = __float2half(scale);
+ half shift1 = __float2half(shift);
if (scale == 1.0F && shift == 0)
- KernelScaleAndShift<__half, true, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
+ KernelScaleAndShift<__half, true, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, scale1, shift1);
else if (scale == 1.0F && shift != 0)
- KernelScaleAndShift<__half, true, false><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
+ KernelScaleAndShift<__half, true, false><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, scale1, shift1);
else if (scale != 1.0F && shift == 0)
- KernelScaleAndShift<__half, false, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
+ KernelScaleAndShift<__half, false, true><<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum, scale1, shift1);
else
- KernelScaleAndShift<__half, false, false> << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum, *scaleft16p, *shiftft16p);
+ KernelScaleAndShift<__half, false, false> << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum, scale1, shift1);
}
else if (a->dataType == X_INT) {
int scale2 = int(scale);
......
@@ -30,7 +30,7 @@ bool Test()
XPRINT(0, stdout, "Testing the XTensor utilites ... \n\n");
//wrong = !TestAbsolute() || wrong;
- //wrong = !TestClip() || wrong;
+ wrong = !TestClip() || wrong;
//wrong = !TestCompare() || wrong;
//wrong = !TestConcatenate() || wrong;
//wrong = !TestConcatenateSolely() || wrong;
@@ -38,18 +38,18 @@ bool Test()
//wrong = !TestConvertDataType() || wrong;
//wrong = !TestCopyIndexed() || wrong;
//wrong = !TestCopyValues() || wrong;
- //wrong = !TestDiv() || wrong;
- //wrong = !TestDivDim() || wrong;
+ wrong = !TestDiv() || wrong;
+ wrong = !TestDivDim() || wrong;
//wrong = !TestExp() || wrong;
//wrong = !TestGather() || wrong;
//wrong = !TestLog() || wrong;
- wrong = !TestMatrixMul() || wrong;
+ //wrong = !TestMatrixMul() || wrong;
//wrong = !TestMatrixMul2D() || wrong;
//wrong = !TestMatrixMul2DParallel() || wrong;
//wrong = !TestMatrixMulBatched() || wrong;
//wrong = !TestMerge() || wrong;
//wrong = !TestMultiply() || wrong;
- //wrong = !TestMultiplyDim() || wrong;
+ wrong = !TestMultiplyDim() || wrong;
//wrong = !TestNegate() || wrong;
//wrong = !TestNormalize() || wrong;
//wrong = !TestPower() || wrong;
@@ -60,7 +60,7 @@ bool Test()
//wrong = !TestReduceSumSquared() || wrong;
//wrong = !TestReduceVariance() || wrong;
//wrong = !TestRound() || wrong;
- //wrong = !TestScaleAndShift() || wrong;
+ wrong = !TestScaleAndShift() || wrong;
//wrong = !TestSelect() || wrong;
//wrong = !TestSetAscendingOrder() || wrong;
//wrong = !TestSetData() || wrong;
@@ -70,7 +70,7 @@ bool Test()
//wrong = !TestSplit() || wrong;
//wrong = !TestSpread() || wrong;
//wrong = !TestSub() || wrong;
- //wrong = !TestSum() || wrong;
+ wrong = !TestSum() || wrong;
//wrong = !TestSumByColumnTV() || wrong;
//wrong = !TestSumByColumnVT() || wrong;
//wrong = !TestSumDim() || wrong;
......