Commit a91f6809 by xiaotong

tensor connections for ReduceMax and ReduceSum

parent ea8c6aea
......@@ -29,6 +29,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_MATMUL "M_MATMUL"
#define MATH_REDUCEMAX "M_REDUCEMAX"
#define MATH_REDUCESUM "M_REDUCESUM"
#define MATH_SELECTRANGE "M_SELECTRANGE"
#define MATH_SORT "M_SORT"
#define MATH_SUM "M_SUM"
......
......@@ -20,6 +20,7 @@
*/
#include "../XTensor.h"
#include "../XName.h"
#include "ReduceMax.h"
#include "ReduceMax.cuh"
......@@ -34,7 +35,7 @@ get the max value of the items along a dimension of the tensor.
void ReduceMax(XTensor * input, XTensor * output, int dim)
{
CheckNTErrors((input->devID == output->devID || (input->devID < 0 && output->devID < 0)),
"This code must be run on the same device!");
"This code must be run on the same device!");
CheckNTErrors((input && output), "Empty input or output tensors!");
CheckNTErrors((input->order == output->order + 1), "Incorrect tensor sizes!");
CheckNTErrors((input->order > dim && dim >=0), "Illegal dimension to reduce!");
......@@ -44,14 +45,18 @@ void ReduceMax(XTensor * input, XTensor * output, int dim)
for(int i = 0; i < input->order; i++){
if(i < dimRDI){
CheckNTErrors((input->dimSizeRDI[i] == output->dimSizeRDI[i]),
"Unmatched tensors!");
"Unmatched tensors!");
}
else if(i > dimRDI){
CheckNTErrors((input->dimSizeRDI[i] == output->dimSizeRDI[i - 1]),
"Unmatched tensors!");
"Unmatched tensors!");
}
}
/* make tensor connections */
XLink::MakeLink(input, NULL, output, MATH_REDUCEMAX);
XLink::AddParamToHeadInt(output, dim);
if(input->devID >= 0){
#ifdef USE_CUDA
CudaReduceMax(input, output, dim);
......
......@@ -22,6 +22,7 @@
#include <math.h>
#include "ReduceSum.h"
#include "ReduceSum.cuh"
#include "../XName.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -58,6 +59,12 @@ void ReduceSum(XTensor * input, XTensor * output, int dim, XTensor * shift, DTYP
}
}
/* make tensor connections */
XLink::MakeLink(input, shift, output, MATH_REDUCESUM);
XLink::AddParamToHeadInt(output, dim);
XLink::AddParamToHead(output, power);
XLink::AddParamToHeadInt(output, isExp);
if(input->devID >= 0){
#ifdef USE_CUDA
CudaReduceSum(input, output, dim, shift, power, isExp);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论