Commit 779351ac by liyinqiao

Bug fixed for Gather functions.

parent a838294b
......@@ -138,9 +138,8 @@ void _CudaGather(const XTensor * s, XTensor * t, XTensor * srcIndex)
XMemCopy(sIndex, devID, srcIndex, -1, sizeof(int) * indexSize);
}
else {
int * sIndexCPU = new int[sizeof(int) * indexSize];
XMemCopy(sIndexCPU, -1, srcIndex, srcIndex->devID, sizeof(int) * indexSize);
int * sIndexData = (int*)sIndexCPU->data;
int * sIndexData = new int[sizeof(int) * indexSize];
XMemCopy(sIndexData, -1, srcIndex, srcIndex->devID, sizeof(int) * indexSize);
for (int i = 0; i < indexSize; i++) {
int srcIndexValue = sIndexData[i] * stride;
CheckNTErrors(srcIndexValue < s->unitNum, "Wrong index!");
......@@ -148,7 +147,7 @@ void _CudaGather(const XTensor * s, XTensor * t, XTensor * srcIndex)
sIndex = (int *)srcIndex->data;
delete[] sIndexCPU;
delete[] sIndexData;
}
KernelGather<<<blocks, threads >>>(sData, tData, sIndex, indexSize, stride);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论