Commit a91f6809 by xiaotong

tensor connections for ReduceMax and ReduceSum

parent ea8c6aea
......@@ -29,6 +29,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_MATMUL "M_MATMUL"
#define MATH_REDUCEMAX "M_REDUCEMAX"
#define MATH_REDUCESUM "M_REDUCESUM"
#define MATH_SELECTRANGE "M_SELECTRANGE"
#define MATH_SORT "M_SORT"
#define MATH_SUM "M_SUM"
......
......@@ -20,6 +20,7 @@
*/
#include "../XTensor.h"
#include "../XName.h"
#include "ReduceMax.h"
#include "ReduceMax.cuh"
......@@ -52,6 +53,10 @@ void ReduceMax(XTensor * input, XTensor * output, int dim)
}
}
/* make tensor connections */
XLink::MakeLink(input, NULL, output, MATH_REDUCEMAX);
XLink::AddParamToHeadInt(output, dim);
if(input->devID >= 0){
#ifdef USE_CUDA
CudaReduceMax(input, output, dim);
......
......@@ -22,6 +22,7 @@
#include <math.h>
#include "ReduceSum.h"
#include "ReduceSum.cuh"
#include "../XName.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -58,6 +59,12 @@ void ReduceSum(XTensor * input, XTensor * output, int dim, XTensor * shift, DTYP
}
}
/* make tensor connections */
XLink::MakeLink(input, shift, output, MATH_REDUCESUM);
XLink::AddParamToHeadInt(output, dim);
XLink::AddParamToHead(output, power);
XLink::AddParamToHeadInt(output, isExp);
if(input->devID >= 0){
#ifdef USE_CUDA
CudaReduceSum(input, output, dim, shift, power, isExp);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论