Commit a91f6809 by xiaotong

tensor connections for ReduceMax and ReduceSum

parent ea8c6aea
...@@ -29,6 +29,8 @@ ...@@ -29,6 +29,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_MATMUL "M_MATMUL" #define MATH_MATMUL "M_MATMUL"
#define MATH_REDUCEMAX "M_REDUCEMAX"
#define MATH_REDUCESUM "M_REDUCESUM"
#define MATH_SELECTRANGE "M_SELECTRANGE" #define MATH_SELECTRANGE "M_SELECTRANGE"
#define MATH_SORT "M_SORT" #define MATH_SORT "M_SORT"
#define MATH_SUM "M_SUM" #define MATH_SUM "M_SUM"
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
*/ */
#include "../XTensor.h" #include "../XTensor.h"
#include "../XName.h"
#include "ReduceMax.h" #include "ReduceMax.h"
#include "ReduceMax.cuh" #include "ReduceMax.cuh"
...@@ -34,7 +35,7 @@ get the max value of the items along a dimension of the tensor. ...@@ -34,7 +35,7 @@ get the max value of the items along a dimension of the tensor.
void ReduceMax(XTensor * input, XTensor * output, int dim) void ReduceMax(XTensor * input, XTensor * output, int dim)
{ {
CheckNTErrors((input->devID == output->devID || (input->devID < 0 && output->devID < 0)), CheckNTErrors((input->devID == output->devID || (input->devID < 0 && output->devID < 0)),
"This code must be run on the same device!"); "This code must be run on the same device!");
CheckNTErrors((input && output), "Empty input or output tensors!"); CheckNTErrors((input && output), "Empty input or output tensors!");
CheckNTErrors((input->order == output->order + 1), "Incorrect tensor sizes!"); CheckNTErrors((input->order == output->order + 1), "Incorrect tensor sizes!");
CheckNTErrors((input->order > dim && dim >=0), "Illegal dimension to reduce!"); CheckNTErrors((input->order > dim && dim >=0), "Illegal dimension to reduce!");
...@@ -44,14 +45,18 @@ void ReduceMax(XTensor * input, XTensor * output, int dim) ...@@ -44,14 +45,18 @@ void ReduceMax(XTensor * input, XTensor * output, int dim)
for(int i = 0; i < input->order; i++){ for(int i = 0; i < input->order; i++){
if(i < dimRDI){ if(i < dimRDI){
CheckNTErrors((input->dimSizeRDI[i] == output->dimSizeRDI[i]), CheckNTErrors((input->dimSizeRDI[i] == output->dimSizeRDI[i]),
"Unmatched tensors!"); "Unmatched tensors!");
} }
else if(i > dimRDI){ else if(i > dimRDI){
CheckNTErrors((input->dimSizeRDI[i] == output->dimSizeRDI[i - 1]), CheckNTErrors((input->dimSizeRDI[i] == output->dimSizeRDI[i - 1]),
"Unmatched tensors!"); "Unmatched tensors!");
} }
} }
/* make tensor connections */
XLink::MakeLink(input, NULL, output, MATH_REDUCEMAX);
XLink::AddParamToHeadInt(output, dim);
if(input->devID >= 0){ if(input->devID >= 0){
#ifdef USE_CUDA #ifdef USE_CUDA
CudaReduceMax(input, output, dim); CudaReduceMax(input, output, dim);
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <math.h> #include <math.h>
#include "ReduceSum.h" #include "ReduceSum.h"
#include "ReduceSum.cuh" #include "ReduceSum.cuh"
#include "../XName.h"
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
...@@ -58,6 +59,12 @@ void ReduceSum(XTensor * input, XTensor * output, int dim, XTensor * shift, DTYP ...@@ -58,6 +59,12 @@ void ReduceSum(XTensor * input, XTensor * output, int dim, XTensor * shift, DTYP
} }
} }
/* make tensor connections */
XLink::MakeLink(input, shift, output, MATH_REDUCESUM);
XLink::AddParamToHeadInt(output, dim);
XLink::AddParamToHead(output, power);
XLink::AddParamToHeadInt(output, isExp);
if(input->devID >= 0){ if(input->devID >= 0){
#ifdef USE_CUDA #ifdef USE_CUDA
CudaReduceSum(input, output, dim, shift, power, isExp); CudaReduceSum(input, output, dim, shift, power, isExp);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论