Commit a838294b by liyinqiao

Bug fixes.

1. Add a boundary check to the Gather function.
2. Fix a minor error.
parent 1943ce99
...@@ -49,6 +49,8 @@ void _Gather(const XTensor * s, XTensor * t, XTensor * srcIndex, int dim) ...@@ -49,6 +49,8 @@ void _Gather(const XTensor * s, XTensor * t, XTensor * srcIndex, int dim)
return; return;
} }
#endif #endif
ShowNTErrors("TODO!");
return;
} }
/* /*
...@@ -64,27 +66,30 @@ void _Gather(const XTensor * s, XTensor * t, XTensor * srcIndex) ...@@ -64,27 +66,30 @@ void _Gather(const XTensor * s, XTensor * t, XTensor * srcIndex)
CheckNTErrors(s->devID == t->devID, "the data must be kept on the same device!"); CheckNTErrors(s->devID == t->devID, "the data must be kept on the same device!");
CheckNTErrors((s->unitSize == t->unitSize), "Unmatched tensors!"); CheckNTErrors((s->unitSize == t->unitSize), "Unmatched tensors!");
if (s->devID >= 0) {
#ifdef USE_CUDA #ifdef USE_CUDA
if (s->devID >= 0 && t->devID >= 0) {
_CudaGather(s, t, srcIndex); _CudaGather(s, t, srcIndex);
return; #else
} ShowNTErrors("Plesae specify USE_CUDA and recompile the code!");
#endif #endif
}
int stride = 1; else {
int indexSize = 1; int stride = 1;
int indexSize = 1;
stride = s->GetDim(-1);
indexSize = srcIndex->unitNum; stride = s->GetDim(-1);
indexSize = srcIndex->unitNum;
DTYPE * sData = (DTYPE*)s->data;
DTYPE * tData = (DTYPE*)t->data; DTYPE * sData = (DTYPE*)s->data;
int * sIndexData = (int*)srcIndex->data; DTYPE * tData = (DTYPE*)t->data;
int * sIndexData = (int*)srcIndex->data;
for (int i = 0; i < indexSize; i++) {
int sIndex = sIndexData[i] * stride; for (int i = 0; i < indexSize; i++) {
for (int j = 0; j < stride; j++) int sIndex = sIndexData[i] * stride;
tData[i * stride + j] = sData[sIndex + j]; CheckNTErrors(sIndex < s->unitNum, "Wrong index!");
for (int j = 0; j < stride; j++)
tData[i * stride + j] = sData[sIndex + j];
}
} }
} }
......
...@@ -126,14 +126,31 @@ void _CudaGather(const XTensor * s, XTensor * t, XTensor * srcIndex) ...@@ -126,14 +126,31 @@ void _CudaGather(const XTensor * s, XTensor * t, XTensor * srcIndex)
int * sIndex = NULL; int * sIndex = NULL;
if (srcIndex->devID < 0) { if (srcIndex->devID < 0) {
int * sIndexData = (int*)srcIndex->data;
for (int i = 0; i < indexSize; i++) {
int srcIndexValue = sIndexData[i] * stride;
CheckNTErrors(srcIndexValue < s->unitNum, "Wrong index!");
}
sIndex = mem != NULL ? sIndex = mem != NULL ?
(int*)mem->AllocBuf(mem->devID, sizeof(int) * indexSize) : (int*)mem->AllocBuf(mem->devID, sizeof(int) * indexSize) :
(int*)XMemAlloc(mem->devID, sizeof(int) * indexSize); (int*)XMemAlloc(mem->devID, sizeof(int) * indexSize);
XMemCopy(sIndex, devID, srcIndex, -1, sizeof(int) * indexSize); XMemCopy(sIndex, devID, srcIndex, -1, sizeof(int) * indexSize);
} }
else else {
int * sIndexCPU = new int[sizeof(int) * indexSize];
XMemCopy(sIndexCPU, -1, srcIndex, srcIndex->devID, sizeof(int) * indexSize);
int * sIndexData = (int*)sIndexCPU->data;
for (int i = 0; i < indexSize; i++) {
int srcIndexValue = sIndexData[i] * stride;
CheckNTErrors(srcIndexValue < s->unitNum, "Wrong index!");
}
sIndex = (int *)srcIndex->data; sIndex = (int *)srcIndex->data;
delete[] sIndexCPU;
}
KernelGather<<<blocks, threads >>>(sData, tData, sIndex, indexSize, stride); KernelGather<<<blocks, threads >>>(sData, tData, sIndex, indexSize, stride);
if (srcIndex->devID < 0) { if (srcIndex->devID < 0) {
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论