Commit a838294b by liyinqiao

Bug fixes.

1. Add a boundary check to the Gather function.
2. Fix a minor error.
parent 1943ce99
......@@ -49,6 +49,8 @@ void _Gather(const XTensor * s, XTensor * t, XTensor * srcIndex, int dim)
return;
}
#endif
ShowNTErrors("TODO!");
return;
}
/*
......@@ -64,27 +66,30 @@ void _Gather(const XTensor * s, XTensor * t, XTensor * srcIndex)
CheckNTErrors(s->devID == t->devID, "the data must be kept on the same device!");
CheckNTErrors((s->unitSize == t->unitSize), "Unmatched tensors!");
if (s->devID >= 0) {
#ifdef USE_CUDA
if (s->devID >= 0 && t->devID >= 0) {
_CudaGather(s, t, srcIndex);
return;
}
#else
ShowNTErrors("Plesae specify USE_CUDA and recompile the code!");
#endif
int stride = 1;
int indexSize = 1;
stride = s->GetDim(-1);
indexSize = srcIndex->unitNum;
DTYPE * sData = (DTYPE*)s->data;
DTYPE * tData = (DTYPE*)t->data;
int * sIndexData = (int*)srcIndex->data;
for (int i = 0; i < indexSize; i++) {
int sIndex = sIndexData[i] * stride;
for (int j = 0; j < stride; j++)
tData[i * stride + j] = sData[sIndex + j];
}
else {
int stride = 1;
int indexSize = 1;
stride = s->GetDim(-1);
indexSize = srcIndex->unitNum;
DTYPE * sData = (DTYPE*)s->data;
DTYPE * tData = (DTYPE*)t->data;
int * sIndexData = (int*)srcIndex->data;
for (int i = 0; i < indexSize; i++) {
int sIndex = sIndexData[i] * stride;
CheckNTErrors(sIndex < s->unitNum, "Wrong index!");
for (int j = 0; j < stride; j++)
tData[i * stride + j] = sData[sIndex + j];
}
}
}
......
......@@ -126,14 +126,31 @@ void _CudaGather(const XTensor * s, XTensor * t, XTensor * srcIndex)
int * sIndex = NULL;
if (srcIndex->devID < 0) {
int * sIndexData = (int*)srcIndex->data;
for (int i = 0; i < indexSize; i++) {
int srcIndexValue = sIndexData[i] * stride;
CheckNTErrors(srcIndexValue < s->unitNum, "Wrong index!");
}
sIndex = mem != NULL ?
(int*)mem->AllocBuf(mem->devID, sizeof(int) * indexSize) :
(int*)XMemAlloc(mem->devID, sizeof(int) * indexSize);
XMemCopy(sIndex, devID, srcIndex, -1, sizeof(int) * indexSize);
}
else
else {
int * sIndexCPU = new int[sizeof(int) * indexSize];
XMemCopy(sIndexCPU, -1, srcIndex, srcIndex->devID, sizeof(int) * indexSize);
int * sIndexData = (int*)sIndexCPU->data;
for (int i = 0; i < indexSize; i++) {
int srcIndexValue = sIndexData[i] * stride;
CheckNTErrors(srcIndexValue < s->unitNum, "Wrong index!");
}
sIndex = (int *)srcIndex->data;
delete[] sIndexCPU;
}
KernelGather<<<blocks, threads >>>(sData, tData, sIndex, indexSize, stride);
if (srcIndex->devID < 0) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论