Commit 779351ac by liyinqiao

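Fix a bug in the Gather functions: the host-side index buffer is a plain `int` array, so it is now read directly instead of through `sIndexCPU->data`.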

parent a838294b
...@@ -138,9 +138,8 @@ void _CudaGather(const XTensor * s, XTensor * t, XTensor * srcIndex) ...@@ -138,9 +138,8 @@ void _CudaGather(const XTensor * s, XTensor * t, XTensor * srcIndex)
XMemCopy(sIndex, devID, srcIndex, -1, sizeof(int) * indexSize); XMemCopy(sIndex, devID, srcIndex, -1, sizeof(int) * indexSize);
} }
else { else {
int * sIndexCPU = new int[sizeof(int) * indexSize]; int * sIndexData = new int[sizeof(int) * indexSize];
XMemCopy(sIndexCPU, -1, srcIndex, srcIndex->devID, sizeof(int) * indexSize); XMemCopy(sIndexData, -1, srcIndex, srcIndex->devID, sizeof(int) * indexSize);
int * sIndexData = (int*)sIndexCPU->data;
for (int i = 0; i < indexSize; i++) { for (int i = 0; i < indexSize; i++) {
int srcIndexValue = sIndexData[i] * stride; int srcIndexValue = sIndexData[i] * stride;
CheckNTErrors(srcIndexValue < s->unitNum, "Wrong index!"); CheckNTErrors(srcIndexValue < s->unitNum, "Wrong index!");
...@@ -148,7 +147,7 @@ void _CudaGather(const XTensor * s, XTensor * t, XTensor * srcIndex) ...@@ -148,7 +147,7 @@ void _CudaGather(const XTensor * s, XTensor * t, XTensor * srcIndex)
sIndex = (int *)srcIndex->data; sIndex = (int *)srcIndex->data;
delete[] sIndexCPU; delete[] sIndexData;
} }
KernelGather<<<blocks, threads >>>(sData, tData, sIndex, indexSize, stride); KernelGather<<<blocks, threads >>>(sData, tData, sIndex, indexSize, stride);
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论