Commit 80da434c by liyinqiao

Bug fixed.

Fix the __half bugs in SetData.
parent 55dd6a78
...@@ -51,7 +51,7 @@ void KernelSetDataFixed(T * d, T v, int size) ...@@ -51,7 +51,7 @@ void KernelSetDataFixed(T * d, T v, int size)
template __global__ void KernelSetDataFixed<int>(int *, int, int); template __global__ void KernelSetDataFixed<int>(int *, int, int);
template __global__ void KernelSetDataFixed<float>(float *, float, int); template __global__ void KernelSetDataFixed<float>(float *, float, int);
template __global__ void KernelSetDataFixed<double>(double *, double, int); template __global__ void KernelSetDataFixed<double>(double *, double, int);
template __global__ void KernelSetDataFixed<__half>(__half*, __half, int); //template __global__ void KernelSetDataFixed<__half>(__half*, __half, int);
/* /*
generate data items with a fixed value generate data items with a fixed value
...@@ -80,8 +80,8 @@ void _CudaSetDataFixed(XTensor * tensor, T value) ...@@ -80,8 +80,8 @@ void _CudaSetDataFixed(XTensor * tensor, T value)
KernelSetDataFixed << <blocks, threads >> > ((float*)tensor->data, (float)value, tensor->unitNum); KernelSetDataFixed << <blocks, threads >> > ((float*)tensor->data, (float)value, tensor->unitNum);
else if (tensor->dataType == X_DOUBLE) else if (tensor->dataType == X_DOUBLE)
KernelSetDataFixed << <blocks, threads >> > ((double*)tensor->data, (double)value, tensor->unitNum); KernelSetDataFixed << <blocks, threads >> > ((double*)tensor->data, (double)value, tensor->unitNum);
else if (tensor->dataType == X_FLOAT16) //else if (tensor->dataType == X_FLOAT16)
KernelSetDataFixed << <blocks, threads >> > ((__half*)tensor->data, (__half)value, tensor->unitNum); // KernelSetDataFixed << <blocks, threads >> > ((__half*)tensor->data, (__half)value, tensor->unitNum);
else else
ShowNTErrors("TODO! Unsupported datatype!") ShowNTErrors("TODO! Unsupported datatype!")
...@@ -111,7 +111,7 @@ void KernelSetDataFixedCond(T * d, T * c, T value, int size) ...@@ -111,7 +111,7 @@ void KernelSetDataFixedCond(T * d, T * c, T value, int size)
template __global__ void KernelSetDataFixedCond<int>(int*, int*, int, int); template __global__ void KernelSetDataFixedCond<int>(int*, int*, int, int);
template __global__ void KernelSetDataFixedCond<float>(float*, float*, float, int); template __global__ void KernelSetDataFixedCond<float>(float*, float*, float, int);
template __global__ void KernelSetDataFixedCond<double>(double*, double*, double, int); template __global__ void KernelSetDataFixedCond<double>(double*, double*, double, int);
template __global__ void KernelSetDataFixedCond<__half>(__half*, __half*, __half, int); //template __global__ void KernelSetDataFixedCond<__half>(__half*, __half*, __half, int);
/* /*
generate data items with a fixed value p generate data items with a fixed value p
...@@ -146,9 +146,9 @@ void _CudaSetDataFixedCond(XTensor* tensor, XTensor* condition, T value) ...@@ -146,9 +146,9 @@ void _CudaSetDataFixedCond(XTensor* tensor, XTensor* condition, T value)
else if (tensor->dataType == X_DOUBLE) else if (tensor->dataType == X_DOUBLE)
KernelSetDataFixedCond <<< blocks, threads >>> ((double*)tensor->data, (double*)condition->data, KernelSetDataFixedCond <<< blocks, threads >>> ((double*)tensor->data, (double*)condition->data,
(double)value, tensor->unitNum); (double)value, tensor->unitNum);
else if (tensor->dataType == X_FLOAT16) //else if (tensor->dataType == X_FLOAT16)
KernelSetDataFixedCond <<< blocks, threads >>> ((__half*)tensor->data, (__half*)condition->data, // KernelSetDataFixedCond <<< blocks, threads >>> ((__half*)tensor->data, (__half*)condition->data,
(__half)value, tensor->unitNum); // (__half)value, tensor->unitNum);
else else
ShowNTErrors("TODO! Unsupported datatype!") ShowNTErrors("TODO! Unsupported datatype!")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论