Commit e9f4a75b by xiaotong

bug fixes of binary math functions

parent 97338baf
......@@ -129,6 +129,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
int sizePredict = score.GetDim(-1);
/* pre id !!! */
Descale(preID, sizePredict);
/* mod !!! */
......
......@@ -54,11 +54,10 @@ void _funcName(const XTensor * a, XTensor * b, int num) \
/* run it on GPUs */ \
if (a->devID >= 0) { \
_cudaFuncName(a, b, num); \
b->Dump(stderr, "zxc"); \
return; \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
"Input tensors should have the same data type!"); \
CheckNTErrors((a->dataType == X_INT), "TODO!"); \
int * d = (int*)a->data; \
int * db = (int*)b->data; \
......@@ -66,63 +65,74 @@ void _funcName(const XTensor * a, XTensor * b, int num) \
db[i] = (int)origFunc(d[i], num); \
}
/* in-place wrapper generator: defines funcName(a, num), which calls
   _funcName with the same tensor passed as both input and output */
#define SIMPLE_BINARY_FUNCTION_ME(funcName, _funcName) \
void funcName(XTensor &a, int num) \
{ \
    XTensor * const self = &a; \
    _funcName(self, self, num); \
}
#define SIMPLE_BINARY_FUNCTION(funcName, _funcName) \
XTensor funcName(const XTensor &a, int num) \
void funcName(const XTensor &a, int num) \
{ \
XTensor b(&a); \
b.SetTMPFlag(); \
_funcName(&a, &b, num); \
b.Dump(stderr, "asd"); \
return b; \
}
_SIMPLE_BINARY_FUNCTION(_Scale, _CudaScale, scale)
SIMPLE_BINARY_FUNCTION_ME(Scale, _Scale)
SIMPLE_BINARY_FUNCTION(Scale, _Scale)
_SIMPLE_BINARY_FUNCTION(_DeScale, _CudaDeScale, descale)
SIMPLE_BINARY_FUNCTION(DeScale, _DeScale)
_SIMPLE_BINARY_FUNCTION(_Descale, _CudaDescale, descale)
SIMPLE_BINARY_FUNCTION_ME(Descale, _Descale)
SIMPLE_BINARY_FUNCTION(Descale, _Descale)
_SIMPLE_BINARY_FUNCTION(_Shift, _CudaShift, shift)
SIMPLE_BINARY_FUNCTION_ME(Shift, _Shift)
SIMPLE_BINARY_FUNCTION(Shift, _Shift)
_SIMPLE_BINARY_FUNCTION(_Mod, _CudaMod, mod)
SIMPLE_BINARY_FUNCTION_ME(Mod, _Mod)
SIMPLE_BINARY_FUNCTION(Mod, _Mod)
#else
/* define the three macros separately and specify their respective function names (CPU mode) */
#define _SIMPLE_BINARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
#define _SIMPLE_BINARY_FUNCTION(_funcName, origFunc) \
void _funcName(const XTensor * a, XTensor * b, int num) \
{ \
/* run it on GPUs */ \
if (a->devID >= 0) { \
_cudaFuncName(a, b, num); \
return; \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); \
"Input tensors should have the same data type!"); \
CheckNTErrors((a->dataType == X_INT), "TODO!"); \
int * d = (int*)a->data; \
int * db = (int*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (int)origFunc(d[i], num); \
}
/* in-place wrapper generator (CPU mode): defines funcName(a, num), which
   calls _funcName with the same tensor passed as both input and output */
#define SIMPLE_BINARY_FUNCTION_ME(funcName, _funcName) \
void funcName(XTensor & a, int num) \
{ \
    XTensor * const self = &a; \
    _funcName(self, self, num); \
}
/* out-of-place wrapper generator (CPU mode): defines funcName(a, b, num),
   which forwards the addresses of a (input) and b (output) to _funcName */
#define SIMPLE_BINARY_FUNCTION(funcName, _funcName) \
void funcName(const XTensor & a, XTensor &b, int num) \
{ \
    const XTensor * const src = &a; \
    XTensor * const dst = &b; \
    _funcName(src, dst, num); \
}
_SIMPLE_BINARY_FUNCTION(_Scale, _CudaScale, scale)
_SIMPLE_BINARY_FUNCTION(_Scale, scale)
SIMPLE_BINARY_FUNCTION_ME(Scale, _Scale)
SIMPLE_BINARY_FUNCTION(Scale, _Scale)
_SIMPLE_BINARY_FUNCTION(_DeScale, _CudaDeScale, descale)
SIMPLE_BINARY_FUNCTION(DeScale, _DeScale)
_SIMPLE_BINARY_FUNCTION(_Descale, descale)
SIMPLE_BINARY_FUNCTION_ME(Descale, _Descale)
SIMPLE_BINARY_FUNCTION(Descale, _Descale)
_SIMPLE_BINARY_FUNCTION(_Shift, _CudaShift, shift)
_SIMPLE_BINARY_FUNCTION(_Shift, shift)
SIMPLE_BINARY_FUNCTION_ME(Shift, _Shift)
SIMPLE_BINARY_FUNCTION(Shift, _Shift)
_SIMPLE_BINARY_FUNCTION(_Mod, _CudaMod, mod)
_SIMPLE_BINARY_FUNCTION(_Mod, mod)
SIMPLE_BINARY_FUNCTION_ME(Mod, _Mod)
SIMPLE_BINARY_FUNCTION(Mod, _Mod)
#endif
......
......@@ -93,7 +93,7 @@ void _Cuda##funcName(const XTensor * a, XTensor * b, int num) \
} \
SIMPLE_BINARY_FUNCTION_GPU(Scale, cudascale)
SIMPLE_BINARY_FUNCTION_GPU(DeScale, cudadescale)
SIMPLE_BINARY_FUNCTION_GPU(Descale, cudadescale)
SIMPLE_BINARY_FUNCTION_GPU(Shift, cudashift)
SIMPLE_BINARY_FUNCTION_GPU(Mod, cudamod)
......
......@@ -37,9 +37,9 @@ void _CudaScale(const XTensor * a, XTensor * b, int num);
/* descale each entry (CUDA Kernel) */
__global__
void KernelDeScale(int * a, int * b, int size, int num);
void KernelDescale(int * a, int * b, int size, int num);
/* descale each entry */
void _CudaDeScale(const XTensor * a, XTensor * b, int num);
void _CudaDescale(const XTensor * a, XTensor * b, int num);
/* shift each entry (CUDA Kernel) */
__global__
......
......@@ -30,14 +30,19 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
scale all tensor entries
b = a * scale
*/
void _Scale(const XTensor * a, XTensor * b, int num);
void _Scale(const XTensor * a, XTensor * b, int scale);
/*
scale tensor entries (in place)
b = a * scale
*/
void Scale(XTensor & a, int scale);
/*
scale tensor entries
make a new tensor to keep the result and return it
b = a * scale
*/
XTensor Scale(const XTensor & a, int num);
void Scale(const XTensor & a, XTensor &b, int scale);
//void Scale(const XTensor & a, XTensor & b, int num);
......@@ -45,40 +50,55 @@ XTensor Scale(const XTensor & a, int num);
descale tensor entries
b = a / scale
*/
void _DeScale(const XTensor * a, XTensor * b, int num);
void _Descale(const XTensor * a, XTensor * b, int scale);
/*
descale tensor entries (in place)
b = a / scale
*/
void Descale(XTensor & a, int scale);
/*
descale tensor entries
make a new tensor to keep the result and return it
b = a / scale
*/
XTensor DeScale(const XTensor & a, int num);
void Descale(const XTensor & a, XTensor & b, int scale);
/*
shift tensor entries
b = a + shift
*/
void _Shift(const XTensor * a, XTensor * b, int num);
void _Shift(const XTensor * a, XTensor * b, int shift);
/*
shift tensor entries (in place)
b = a + shift
*/
void Shift(XTensor & a, int shift);
/*
shift tensor entries
make a new tensor to keep the result and return it
b = a + shift
*/
XTensor Shift(const XTensor & a, int num);
void Shift(const XTensor & a, XTensor & b, int shift);
/*
mod tensor entries
b = a % mod
*/
void _Mod(const XTensor * a, XTensor * b, int num);
void _Mod(const XTensor * a, XTensor * b, int base);
/*
mod tensor entries (in place)
b = a % mod
*/
void Mod(XTensor & a, int base);
/*
mod tensor entries
make a new tensor to keep the result and return it
b = a % mod
*/
XTensor Mod(const XTensor & a, int num);
void Mod(const XTensor & a, XTensor & b, int base);
} // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论