Commit e083e9f2 by xiaotong

add SumFilled

parent e5a709dc
...@@ -89,6 +89,7 @@ int main( int argc, const char ** argv ) ...@@ -89,6 +89,7 @@ int main( int argc, const char ** argv )
void TransposeTest() void TransposeTest()
{ {
#ifdef USE_CUDA
XMem mem0(0, UNI_FREE, MILLION * 64, 1024, MILLION * 64); XMem mem0(0, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
//XMem mem1(1, UNI_FREE, MILLION * 64, 1024, MILLION * 64); //XMem mem1(1, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
XTensor x; XTensor x;
...@@ -143,4 +144,5 @@ void TransposeTest() ...@@ -143,4 +144,5 @@ void TransposeTest()
fprintf(stderr, "split:%f merge:%f\n", time1 - time0, time3 - time2); fprintf(stderr, "split:%f merge:%f\n", time1 - time0, time3 - time2);
fprintf(stderr, "split:%f merge:%f\n", elapsedSplit, elapsedMerge); fprintf(stderr, "split:%f merge:%f\n", elapsedSplit, elapsedMerge);
#endif
} }
...@@ -64,10 +64,11 @@ ...@@ -64,10 +64,11 @@
#include "arithmetic/Sum.h" #include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h" #include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h" #include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumFilled.h"
#include "sort/TopK.h" #include "sort/TopK.h"
#include "shape/Transpose.h" #include "shape/Transpose.h"
#include "shape/Unsqueeze.h" #include "shape/Unsqueeze.h"
#include "utilities/XMatrixSegment.h" #include "utilities/XMatrixSegment.h"
#include "arithmetic/XTensorBLAS.h" #include "arithmetic/XTensorBLAS.h"
#endif // __CHEADER_H__ #endif // __CHEADER_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-28
*/
#include "SumFilled.h"
#include "SumFilled.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* placeholder: nothing is defined inside this namespace yet.
   NOTE(review): SumFilled.h declares _SumFilled, _SumFilledMe and SumFilled, but
   no definition of them appears here - confirm they are implemented elsewhere
   or in a follow-up change, otherwise callers will fail at link time. */
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-28
*/
#include "SumFilled.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* placeholder: nothing is defined inside this namespace yet.
   NOTE(review): SumFilled.cuh declares _CudaSumFilled, but no definition of it
   appears here - confirm it is implemented elsewhere or in a follow-up change. */
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-28
*/
#ifndef __SUMFILLED_CUH__
#define __SUMFILLED_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* tensor summation c = a + b * \beta where each dimension of b is equal to that of a or has
a value of 1, i.e., a is summed with b by broadcasting
>> a - the first input tensor (read-only); b is broadcast against its dimensions
>> b - the second input tensor (read-only); each of its dimensions either matches
       the corresponding dimension of a or is 1
>> c - the tensor that receives the result
>> beta - the coefficient that b is scaled by before the sum (1.0 by default)
NOTE(review): presumably the GPU counterpart of _SumFilled (judging by the name
and the .cuh header) - confirm once the definition is added.
NOTE(review): this declaration is not wrapped in a USE_CUDA check - confirm
whether it should be visible in builds without CUDA. */
void _CudaSumFilled(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUMFILLED_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-28
*/
#ifndef __SUMFILLED_H__
#define __SUMFILLED_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* tensor summation c = a + b * \beta where each dimension of b is equal to that of a or has
a value of 1, i.e., a is summed with b by broadcasting
>> a - the first input tensor (read-only); b is broadcast against its dimensions
>> b - the second input tensor (read-only); each of its dimensions either matches
       the corresponding dimension of a or is 1
>> c - the tensor that receives the result
>> beta - the coefficient that b is scaled by before the sum (1.0 by default) */
void _SumFilled(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta where each dimension of b is equal to that of a or has
a value of 1, i.e., a is summed with b by broadcasting
keep the result in the input tensor a and return nothing
>> a - the first input tensor; it is also where the result is stored
>> b - the second input tensor (read-only), broadcast against a as described above
>> beta - the coefficient that b is scaled by before the sum (1.0 by default) */
void _SumFilledMe(XTensor * a, const XTensor * b, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta where each dimension of b is equal to that of a or has
a value of 1, i.e., a is summed with b by broadcasting
make a new tensor c to keep the result and return it
>> a - the first input tensor (read-only)
>> b - the second input tensor (read-only), broadcast against a as described above
>> beta - the coefficient that b is scaled by before the sum (1.0 by default)
<< return - the newly made tensor that holds the result */
XTensor SumFilled(const XTensor &a, const XTensor &b, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUMFILLED_H__
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论