Commit a91f6809 by xiaotong

tensor connections for ReduceMax and ReduceSum

parent ea8c6aea
...@@ -29,6 +29,8 @@ ...@@ -29,6 +29,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_MATMUL "M_MATMUL" #define MATH_MATMUL "M_MATMUL"
#define MATH_REDUCEMAX "M_REDUCEMAX"
#define MATH_REDUCESUM "M_REDUCESUM"
#define MATH_SELECTRANGE "M_SELECTRANGE" #define MATH_SELECTRANGE "M_SELECTRANGE"
#define MATH_SORT "M_SORT" #define MATH_SORT "M_SORT"
#define MATH_SUM "M_SUM" #define MATH_SUM "M_SUM"
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
*/ */
#include "../XTensor.h" #include "../XTensor.h"
#include "../XName.h"
#include "ReduceMax.h" #include "ReduceMax.h"
#include "ReduceMax.cuh" #include "ReduceMax.cuh"
...@@ -52,6 +53,10 @@ void ReduceMax(XTensor * input, XTensor * output, int dim) ...@@ -52,6 +53,10 @@ void ReduceMax(XTensor * input, XTensor * output, int dim)
} }
} }
/* make tensor connections */
XLink::MakeLink(input, NULL, output, MATH_REDUCEMAX);
XLink::AddParamToHeadInt(output, dim);
if(input->devID >= 0){ if(input->devID >= 0){
#ifdef USE_CUDA #ifdef USE_CUDA
CudaReduceMax(input, output, dim); CudaReduceMax(input, output, dim);
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <math.h> #include <math.h>
#include "ReduceSum.h" #include "ReduceSum.h"
#include "ReduceSum.cuh" #include "ReduceSum.cuh"
#include "../XName.h"
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
...@@ -58,6 +59,12 @@ void ReduceSum(XTensor * input, XTensor * output, int dim, XTensor * shift, DTYP ...@@ -58,6 +59,12 @@ void ReduceSum(XTensor * input, XTensor * output, int dim, XTensor * shift, DTYP
} }
} }
/* make tensor connections */
XLink::MakeLink(input, shift, output, MATH_REDUCESUM);
XLink::AddParamToHeadInt(output, dim);
XLink::AddParamToHead(output, power);
XLink::AddParamToHeadInt(output, isExp);
if(input->devID >= 0){ if(input->devID >= 0){
#ifdef USE_CUDA #ifdef USE_CUDA
CudaReduceSum(input, output, dim, shift, power, isExp); CudaReduceSum(input, output, dim, shift, power, isExp);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论