Commit 561a9e58 by xiaotong

update cuda code for Sum

parent f5161a0d
...@@ -37,7 +37,7 @@ return a pointer ...@@ -37,7 +37,7 @@ return a pointer
*/ */
void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta) void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{ {
CheckNTErrors(a && b && c, "Empty tensors in addition!"); CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == b->unitNum && a->unitNum == c->unitNum, CheckNTErrors(a->unitNum == b->unitNum && a->unitNum == c->unitNum,
"Unmatched tensors in addition!"); "Unmatched tensors in addition!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType, CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
......
...@@ -51,11 +51,9 @@ tensor summation c = a + b * \beta (cuda version) ...@@ -51,11 +51,9 @@ tensor summation c = a + b * \beta (cuda version)
>> c - where we put a+b*\beta. we save it in a if c is NULL >> c - where we put a+b*\beta. we save it in a if c is NULL
>> beta - the scaling factor >> beta - the scaling factor
*/ */
void _CudaSum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta) void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{ {
if (c == NULL) CheckNTErrors(a && b && c, "Empty tensor input!");
c = a;
CheckNTErrors((a->unitNum == b->unitNum && a->unitNum == c->unitNum), CheckNTErrors((a->unitNum == b->unitNum && a->unitNum == c->unitNum),
"Unmatched tensors in addition!"); "Unmatched tensors in addition!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType), CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
......
...@@ -34,7 +34,7 @@ void KernelADD(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1. ...@@ -34,7 +34,7 @@ void KernelADD(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.
/* tensor summation c = a + b * \beta (cuda version) */ /* tensor summation c = a + b * \beta (cuda version) */
extern "C" extern "C"
void _CudaSum(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0); void _CudaSum(const XTensor * a, const XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta (cuda version) with an input handle */ /* tensor summation c = a + b * \beta (cuda version) with an input handle */
extern "C" extern "C"
......
...@@ -35,7 +35,7 @@ copy a range of elements from a source vector to a target vector ...@@ -35,7 +35,7 @@ copy a range of elements from a source vector to a target vector
>> stream - the stream for creating the job pipeline >> stream - the stream for creating the job pipeline
<< return - succeed or not << return - succeed or not
*/ */
bool CudaCopyValues(XTensor * s, XTensor * t, XStream * stream) bool CudaCopyValues(const XTensor * s, XTensor * t, XStream * stream)
{ {
if (s == NULL || t == NULL) if (s == NULL || t == NULL)
return false; return false;
...@@ -70,7 +70,7 @@ bool CudaCopyValues(XTensor * s, XTensor * t, XStream * stream) ...@@ -70,7 +70,7 @@ bool CudaCopyValues(XTensor * s, XTensor * t, XStream * stream)
s->dataType == DEFAULT_DTYPE && s->dataType == DEFAULT_DTYPE &&
t->dataType == DEFAULT_DTYPE) t->dataType == DEFAULT_DTYPE)
{ {
int num = s->GetNonzeroSize(); int num = s->unitNumNonZero;
int size = sizeof(int) + num * (s->unitSize + sizeof(int)); int size = sizeof(int) + num * (s->unitSize + sizeof(int));
if (stream == NULL) if (stream == NULL)
......
...@@ -30,7 +30,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -30,7 +30,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy all elements from a source matrix to a target matrix */ /* copy all elements from a source matrix to a target matrix */
extern "C" extern "C"
bool CudaCopyValues(XTensor * s, XTensor * t, XStream * stream = NULL); bool CudaCopyValues(const XTensor * s, XTensor * t, XStream * stream = NULL);
#endif // USE_CUDA #endif // USE_CUDA
......
...@@ -73,7 +73,7 @@ bool TestSum1() ...@@ -73,7 +73,7 @@ bool TestSum1()
bGPU->SetData(bData, unitNum); bGPU->SetData(bData, unitNum);
/* call sum function */ /* call sum function */
_Sum(aGPU, bGPU); _Sum(aGPU, bGPU, aGPU);
/* check results */ /* check results */
gpuTest = aGPU->CheckData(answer, unitNum); gpuTest = aGPU->CheckData(answer, unitNum);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论