Commit e9611d9c by xuchen

Merge branch 'xuchen' into xiaotong-working

parents 30c3a629 f7f33b29
......@@ -42,7 +42,7 @@ NiuTrans.Tensor is a toolkit developed by the NiuTrans open-source project, providing complete
## Development Team
The NiuTrans.Tensor tensor computation library is jointly developed by the Natural Language Processing Lab at Northeastern University, NiuTrans (小牛翻译), and NiuTrans YaZhi (小牛雅智), and is dedicated to providing complete tensor definition and computation functionality for deep-learning research and the development of industrial systems.
The NiuTrans.Tensor tensor computation library is developed by the NiuTrans team, whose members come from the Natural Language Processing Lab at Northeastern University, NiuTrans (小牛翻译), and NiuTrans YaZhi (小牛雅智); it is dedicated to providing complete tensor definition and computation functionality for deep-learning research and the development of industrial systems.
## Updates
......
......@@ -2,7 +2,7 @@
## Notes
The latest CUDA release (9.2) does not yet support the latest version of VS2017, so we recommend using CUDA 9.0 or 9.1 together with VS2015, or installing the v140 toolset when using VS2017.
The latest CUDA release (9.2) does not yet support the latest version of VS2017, so we recommend using CUDA 9.0 or 9.1 together with VS2015, or installing the v140 toolset when using VS2017, and setting the solution platform to x64.
## CUDA Configuration
......@@ -29,7 +29,7 @@ The latest CUDA release (9.2) does not yet support the latest version of VS2017, so we recommend using CUDA version
Under **C/C++ -> Preprocessor -> Preprocessor Definitions**, add
>USE_CUDA;USE_BLAS;WIN32;MKL;DEBUG;CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS_
>USE_CUDA;USE_BLAS;WIN32;MKL;_DEBUG;_CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS_
CONSOLE;
Set **Linker -> System -> SubSystem** to Console.
......
......@@ -24,6 +24,7 @@
#include "../tensor/XUtility.h"
#include "../tensor/function/FHeader.h"
#include "../tensor/core/CHeader.h"
#include "../tensor/test/Test.h"
#include "../sample/fnnlm/FNNLM.h"
#include "../sample/transformer/Transformer.h"
......@@ -31,18 +32,24 @@
//#include <stdlib.h>
//#include <crtdbg.h>
void BackwardTest();
void TransposeTest();
void SumDimTest();
using namespace nts;
using namespace fnnlm;
using namespace transformer;
using namespace GAN;
int main( int argc, const char ** argv )
{
//_CrtSetBreakAlloc(896);
//BackwardTest();
//return 0;
if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
if(argc > 1 && !strcmp(argv[1], "-test"))
Test();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1);
else if(argc > 1 && !strcmp(argv[1], "-t2t"))
TransformerMain(argc - 1, argv + 1);
......@@ -58,6 +65,41 @@ int main( int argc, const char ** argv )
return 0;
}
void BackwardTest()
{
XNet net;
XTensor a;
XTensor b;
XTensor c;
XTensor mean;
XTensor origin;
InitTensor2D(&a, 2, 3);
InitTensor1D(&b, 2);
a.SetZeroAll();
b.SetZeroAll();
a.Set2D(1.0F, 0, 0);
a.Set2D(2.0F, 0, 1);
a.Set2D(3.0F, 0, 2);
a.Set2D(4.0F, 1, 0);
a.Set2D(5.0F, 1, 1);
a.Set2D(6.0F, 1, 2);
b.Set1D(2.0F, 0);
b.Set1D(1.0F, 1);
c = DivDim(a, b, 0);
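/* DivDim broadcasts b over dimension 0, so row i of a is divided by b[i]:
   c = {{0.5, 1.0, 1.5}, {4.0, 5.0, 6.0}} for the values set above */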
c.Dump(stderr, "c:");
XLink::ShowNetwork(stderr, &c);
net.Backward(c);
net.Dump(stderr);
}
void TransposeTest()
{
#ifdef USE_CUDA
......
......@@ -22,7 +22,11 @@
#include "XBackwardLoss.h"
#include "../tensor/XName.h"
#include "../tensor/function/HardTanH.h"
#include "../tensor/function/Identity.h"
#include "../tensor/function/LogSoftmax.h"
#include "../tensor/function/Rectify.h"
#include "../tensor/function/Sigmoid.h"
#include "../tensor/function/Softmax.h"
namespace nts{
......@@ -49,10 +53,22 @@ void XLossGrad::Compute(XTensor * gold, XTensor * y, XTensor * x,
if(funcID == FUNC_HARDTANH){
_HardTanHBackward(gold, y, x, dedy, dedx, lossName);
}
else if(funcID == FUNC_IDENTITY){
_IdentityBackward(gold, y, x, dedy, dedx, lossName);
}
else if(funcID == FUNC_LOGSOFTMAX){
int leadDim = *(int*)params;
_LogSoftmaxBackward(gold, y, x, dedy, dedx, leadDim, lossName);
}
else if(funcID == FUNC_RECTIFY){
_RectifyBackward(gold, y, x, dedy, dedx, lossName);
}
else if(funcID == FUNC_SIGMOID){
_SigmoidBackward(gold, y, x, dedy, dedx, lossName);
}else if(funcID == FUNC_SOFTMAX){
int leadDim = *(int*)params;
_SoftmaxBackward(gold, y, x, dedy, dedx, leadDim, lossName);
}
else{
ShowNTErrors("wrong function found when call the backward process!");
}
......
......@@ -40,18 +40,50 @@ public:
bool IsMathOP(XTensor * node);
private:
/* gradient for sum: c = a + b * \beta */
/* gradient for absolute */
static
void GradSum(XTensor * node);
void GradAbsolute(XTensor * node);
/* gradient for cos */
static
void GradCos(XTensor * node);
/* gradient for exp */
static
void GradExp(XTensor * node);
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
/* gradient for log: c = log(a) */
static
void GradSumDim(XTensor * node);
void GradLog(XTensor * node);
/* gradient for round */
static
void GradRound(XTensor * node);
/* gradient for sign */
static
void GradSign(XTensor * node);
/* gradient for multiply (dot product): c = a * b * \alpha */
/* gradient for sin */
static
void GradMultiply(XTensor * node);
void GradSin(XTensor * node);
/* gradient for tan */
static
void GradTan(XTensor * node);
/* gradient for clip */
static
void GradClip(XTensor * node);
/* gradient for Divide */
static
void GradDiv(XTensor * node);
/* gradient for DivideDim */
static
void GradDivDim(XTensor * node);
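/* note: for c = a / b with b broadcast along dimension n, the gradients are
   dE/da = (dE/dc) / b (broadcast in the same way) and
   dE/db = -sum over the broadcast positions of (dE/dc) * a / (b * b) */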
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
......@@ -68,17 +100,26 @@ private:
static
void GradMatrixMulBatched(XTensor * node);
/* gradient for log: c = log(a) */
/* gradient for multiply (dot product): c = a * b * \alpha */
static
void GradLog(XTensor * node);
void GradMultiply(XTensor * node);
/* gradient for power */
/* gradient for multiply one dimension: c = a * b * \alpha
where the size of b is equal to that of one dimension of a */
static
void GradPower(XTensor * node);
void GradMultiplyDim(XTensor * node);
/* gradient for negate */
static
void GradNegate(XTensor * node);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
/* gradient for power */
static
void GradPower(XTensor * node);
/* gradient for ScaleAndShift */
static
......@@ -87,10 +128,20 @@ private:
/* gradient for Minus */
static
void GradSub(XTensor * node);
/* gradient for sub with one dimension: c = a - b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSubDim(XTensor * node);
/* gradient for Divide */
/* gradient for sum: c = a + b * \beta */
static
void GradDiv(XTensor * node);
void GradSum(XTensor * node);
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSumDim(XTensor * node);
/* gradient for reduceMean */
static
......@@ -107,42 +158,6 @@ private:
/* gradient for reduceVariance */
static
void GradReduceVariance(XTensor * node);
/* gradient for sin */
static
void GradSin(XTensor * node);
/* gradient for cos */
static
void GradCos(XTensor * node);
/* gradient for tan */
static
void GradTan(XTensor * node);
/* gradient for exp */
static
void GradExp(XTensor * node);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
/* gradient for absolute */
static
void GradAbsolute(XTensor * node);
/* gradient for sign */
static
void GradSign(XTensor * node);
/* gradient for clip */
static
void GradClip(XTensor * node);
/* gradient for round */
static
void GradRound(XTensor * node);
};
}
......
......@@ -99,7 +99,7 @@ arguments:
(how many words)
-shuffle: shuffle the training data
-devid D: the id of the device used
-1: GPU, >=0: GPUs
-1: CPU, >=0: GPUs
-mempool: use memory pools for memory management
-autodiff: use automatic differentiation for training
......
......@@ -45,7 +45,7 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(123);
/* a tiny test */
SmallTest();
//SmallTest();
//_CrtDumpMemoryLeaks();
//return 0;
......
......@@ -43,7 +43,7 @@
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
#define _XINLINE_ inline
#define _XINLINE_
//#define DOUBELPRICSION
......
......@@ -45,12 +45,16 @@ const char * GetOPName(int type)
return "M_CLIP";
else if (type == MATH_DIV)
return "M_DIV";
else if (type == MATH_DIVDIM)
return "M_DIVDIM";
else if (type == MATH_MATRIXMUL)
return "M_MATRIXMUL";
else if (type == MATH_MATRIXMULBATCHED)
return "M_MATRIXMULBATCHED";
else if (type == MATH_MULTIPLY)
return "M_MULTIPLY";
else if (type == MATH_MULTIPLYDIM)
return "M_MULTIPLYDIM";
else if (type == MATH_NEGATE)
return "M_NEGATE";
else if (type == MATH_NORMALIZE)
......@@ -61,10 +65,12 @@ const char * GetOPName(int type)
return "M_SCALEANDSHIFT";
else if (type == MATH_SIGN)
return "M_SIGN";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_SUB)
return "M_SUB";
else if (type == MATH_SUBDIM)
return "M_SUBDIM";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_SUMDIM)
return "M_SUMDIM";
else if (type == REDUCE_REDUCEMAX)
......
......@@ -41,17 +41,20 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_CLIP MATH_ROUND + 1
#define MATH_DIV MATH_CLIP + 1
#define MATH_MATRIXMUL MATH_DIV + 1
#define MATH_DIVDIM MATH_DIV + 1
#define MATH_MATRIXMUL MATH_DIVDIM + 1
#define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
#define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1
#define MATH_NEGATE MATH_MULTIPLY + 1
#define MATH_MULTIPLYDIM MATH_MULTIPLY + 1
#define MATH_NEGATE MATH_MULTIPLYDIM + 1
#define MATH_NORMALIZE MATH_NEGATE + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_SIGN MATH_SCALEANDSHIFT + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_SUB MATH_SUM + 1
#define MATH_SUMDIM MATH_SUB + 1
#define MATH_SUB MATH_SIGN + 1
#define MATH_SUBDIM MATH_SUB + 1
#define MATH_SUM MATH_SUBDIM + 1
#define MATH_SUMDIM MATH_SUM + 1
#define REDUCE MATH_SUMDIM + 1
#define REDUCE_REDUCEMAX REDUCE + 1
......
......@@ -27,15 +27,18 @@
#include "../XTensor.h"
#include "arithmetic/Div.h"
#include "arithmetic/DivDim.h"
#include "arithmetic/MatrixMul.h"
#include "arithmetic/MatrixMul2D.h"
#include "arithmetic/MatrixMul2DMultiTheading.h"
#include "arithmetic/MatrixMul2DParallel.h"
#include "arithmetic/MatrixMulBatched.h"
#include "arithmetic/Multiply.h"
#include "arithmetic/MultiplyDim.h"
#include "arithmetic/Negate.h"
#include "arithmetic/Sign.h"
#include "arithmetic/Sub.h"
#include "arithmetic/SubDim.h"
#include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h"
......
......@@ -23,6 +23,7 @@
#include "../../XName.h"
#include "Div.h"
#include "Div.cuh"
#include "DivDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -137,6 +138,33 @@ void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
_Div(a, b, a, alpha, leadingDim);
}
/*
return a dimension if the division is performed as DivDim (in more details in DivDim.h)
>> a - a tensor
>> b - another tensor for division
*/
int GetDivDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
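/* e.g., for a of size (2, 4) and b of size (4) the only match is the trailing
   dimension, so 1 is returned; for a of size (2, 4) and b of size (2) no trailing
   dimension of a matches b, so -1 is returned and Div does not take the DivDim route */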
/*
element-wise division of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it
......@@ -146,23 +174,41 @@ where i is the index of the item
>> a - tensor a
>> b - tensor b
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim)
XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
{
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a);
c.SetTMP();
int n = GetDivDimIndex(a, b);
if(n == -1){
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Div function */
_Div(&a, &b, &c, alpha, leadingDim);
/* call _Multiply function */
_Div(&a, &b, &c, 0, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHeadInt(&c, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
else if(n >= 0 && n < a.order){
/* call _DivDim function */
_DivDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
else{
ShowNTErrors("Something is wrong!");
}
return c;
}
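/*
a minimal usage sketch (ours, not part of this file) of the shape-based dispatch above:
b matches the trailing dimension of a, so GetDivDimIndex returns 1 and Div routes the call
to _DivDim with n = 1; the wrapper function name is chosen here for illustration only
*/
void DivDispatchSketch()
{
    XTensor a;
    XTensor b;
    XTensor c;

    InitTensor2D(&a, 2, 3);
    InitTensor1D(&b, 3);

    a.Set2D(2.0F, 0, 0); a.Set2D(4.0F, 0, 1); a.Set2D(6.0F, 0, 2);
    a.Set2D(3.0F, 1, 0); a.Set2D(6.0F, 1, 1); a.Set2D(9.0F, 1, 2);
    b.Set1D(1.0F, 0); b.Set1D(2.0F, 1); b.Set1D(3.0F, 2);

    /* GetDivDimIndex(a, b) == 1, so this call runs _DivDim(&a, &b, &c, 1, 0.0) */
    c = Div(a, b);

    /* column j of a is divided by b[j]: c = {{2, 2, 2}, {3, 3, 3}} */
    c.Dump(stderr, "c:");
}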
......
......@@ -31,7 +31,7 @@ element-wise division of two tensors:
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the element
*/
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise division of two tensors (do it on site)
......@@ -39,7 +39,7 @@ keep the result in the input tensor a and return nothing
a(i) = a(i)/b(i) + \alpha * a(i)
where i is the index of the element
*/
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha = 0, int leadingDim = 0);
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise division of two tensors (return a XTensor structure)
......@@ -47,7 +47,7 @@ make a new tensor to keep the result and return it
c(i) = a(i)/b(i)
where i is the index of the element
*/
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim = 0);
XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-15
*/
#include "Div.h"
#include "DivDim.h"
#include "DivDim.cuh"
#include "../../XName.h"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor division
c = a / b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put result. we save it in a if c is NULL
>> n - the dimension index
>> alpha - the scaling factor
*/
void _DivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in division!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in addition!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in division!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
if(XTensor::IsSameShaped(a, b)){
_Div(a, b, c, alpha);
return;
}
if(a->devID >= 0 || b->devID >= 0 || c->devID >= 0){
#ifdef USE_CUDA
_CudaDivDim(a, b, c, n, alpha);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
else{
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for(int i = a->order - 1; i >= 0; i--){
if(i > n)
stride *= a->dimSize[i];
else if(i < n)
blockNum *= a->dimSize[i];
}
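/* a now decomposes into blockNum blocks, each holding blockSize runs of `stride`
   contiguous elements; run j within a block (its index along dimension n) is divided by b[j] */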
if (a->dataType == DEFAULT_DTYPE){
int num = a->unitNum;
if(stride > 1){
for(int i = 0, j = 0; i < num; i += stride, j++){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE bv = *((DTYPE*)b->data + j % blockSize);
DTYPE * cp = (DTYPE*)c->data + i;
for(int k = 0; k < stride; k++){
if(alpha == 0.0F)
cp[k] = ap[k] / bv;
else
cp[k] = ap[k] / bv + alpha * cp[k];
}
}
}
else if(stride == 1){
DTYPE * bp = (DTYPE*)b->data;
for(int i = 0; i < num; i += blockSize){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE * cp = (DTYPE*)c->data + i;
if(alpha == 0.0F){
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] / bp[j];
}
else{
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] / bp[j] + alpha * cp[j];
}
}
}
else{
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
}
}
/*
tensor division of two tensors (do it on site)
keep the result in the input tensor and return nothing
a = a/b + \alpha * a
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> alpha - the scaling factor
*/
void _DivDim(XTensor * a, const XTensor * b, int n, DTYPE alpha)
{
_DivDim(a, b, a, n, alpha);
}
/*
tensor division of two tensors (return a XTensor structure and make tensor connections)
make a new tensor to keep the result and return it
c = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> alpha - the scaling factor
<< return - the result tensor by tensor division
*/
XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
{
XTensor c(&a);
c.SetTMP();
/* call _DivDim function */
_DivDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
return c;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-15
*/
#include "DivDim.cuh"
#include "../../XDevice.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor division of a tensor and a row vector
c = a / b + alpha * c
where a is a tensor and b is a row vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c
>> colNum - number of columns of a and c (i.e., the size of b)
>> alpha - the scaling factor
*/
template <class T, bool alphaFired>
__global__
void KernelDivWithRow(T * a, T * b, T * c, int rowNum, int colNum, T alpha)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
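/* thread (x, 0) of each block caches b[col] in shared memory below, so the remaining
   rows of the block read it from on-chip storage rather than from global memory */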
int col = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
if(col >= colNum || row >= rowNum)
return;
if(threadIdx.y == 0)
bv[threadIdx.x] = b[col];
__syncthreads();
int offset = colNum * row + col;
if(alphaFired)
c[offset] = a[offset] / bv[threadIdx.x] + c[offset] * alpha;
else
c[offset] = a[offset] / bv[threadIdx.x];
}
/*
tensor division of a tensor and a column vector
c = a / b + alpha * c
where a is a tensor and b is a column vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c (i.e., the size of b)
>> colNum - number of columns of a and c
>> blockSize - size of a block (matrix), i.e., rowNum * colNum
>> blockNum - number of matrices
>> alpha - the scaling factor
*/
template <class T, bool alphaFired>
__global__
void KernelDivWithCol(T * a, T * b, T * c, int rowNum, int colNum, int blockSize, int blockNum, T alpha)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int colIndex = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
int col = colIndex % colNum;
int block = colIndex / colNum;
if(row >= rowNum || block >= blockNum)
return;
if(threadIdx.x == 0)
bv[threadIdx.y] = b[row];
__syncthreads();
int offset = block * blockSize + row * colNum + col;
if(alphaFired)
c[offset] = a[offset] / bv[threadIdx.y] + c[offset] * alpha;
else
c[offset] = a[offset] / bv[threadIdx.y];
}
/*
tensor division
c = a / b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a / b + \alpha * c. we save it in a if c is NULL
>> n - the dimension index
>> alpha - the scaling factor
*/
void _CudaDivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in division!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in division!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in division!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for(int i = a->order - 1; i >= 0; i--){
if(i > n)
stride *= a->dimSize[i];
else if(i < n)
blockNum *= a->dimSize[i];
}
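/* when stride == 1, dimension n is the innermost one: a is viewed as a (blockNum x blockSize)
   matrix and b acts as a row vector; when stride > 1, each of the blockNum blocks is viewed
   as a (blockSize x stride) matrix and b acts as a column vector */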
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE){
if(stride > 1){
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if(alpha == (DTYPE)0.0F)
KernelDivWithCol<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha);
else
KernelDivWithCol<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha);
}
else if(stride == 1){
GDevs.GetCudaThread2D(a->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
if(alpha == (DTYPE)0.0F)
KernelDivWithRow<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, alpha);
else
KernelDivWithRow<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, alpha);
}
else{
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-15
*/
#ifndef __DIVDIM_CUH__
#define __DIVDIM_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor division
c(i) = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting (cuda version)
*/
void _CudaDivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha = (DTYPE)0.0);
#endif
} // namespace nts(NiuTrans.Tensor)
#endif // __DIVDIM_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-15
*/
#ifndef __DIVDIM_H__
#define __DIVDIM_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor division of two tensors:
c(i) = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
*/
void _DivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha = (DTYPE)0.0);
/*
tensor division of two tensors:
c(i) = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
we keep the result in the input tensor a and return nothing
*/
void _DivDim(XTensor * a, const XTensor * b, int n, DTYPE alpha = (DTYPE)0.0);
/*
tensor division of two tensors:
c(i) = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
we make a new tensor c to keep the result and return it
*/
XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha = (DTYPE)0.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __DIVDIM_H__
......@@ -23,6 +23,7 @@
#include "../../XName.h"
#include "Multiply.h"
#include "Multiply.cuh"
#include "MultiplyDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -138,6 +139,33 @@ void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
_Multiply(a, b, a, alpha, leadingDim);
}
/*
return a dimension if the multiplication is performed as MultiplyDim (in more details in MultiplyDim.h)
>> a - a tensor
>> b - another tensor for multiplication
*/
int GetMultiplyDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/*
element-wise product of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it
......@@ -150,20 +178,38 @@ where i is the index of the item
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim)
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
{
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a);
c.SetTMP();
/* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim);
int n = GetMultiplyDimIndex(a, b);
if(n == -1){
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHeadInt(&c, leadingDim);
/* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
else if(n >= 0 && n < a.order){
/* call _MultiplyDim function */
_MultiplyDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
else{
ShowNTErrors("Something is wrong!");
}
return c;
}
......
......@@ -31,7 +31,7 @@ element-wise product of two tensors:
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the element
*/
void _Multiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
void _Multiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise product of two tensors (do it on site)
......@@ -39,7 +39,7 @@ keep the result in the input tensor a and return nothing
a(i) = a(i)*b(i) + \alpha * a(i)
where i is the index of the element
*/
void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha = 0, int leadingDim = 0);
void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise product of two tensors (return a XTensor structure)
......@@ -47,7 +47,7 @@ make a new tensor to keep the result and return it
c(i) = a(i)*b(i)
where i is the index of the element
*/
XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim = 0);
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2018-08-14
*/
#include "Multiply.h"
#include "MultiplyDim.h"
#include "MultiplyDim.cuh"
#include "../../XName.h"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor multiplication
c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a * b + \alpha * c. we save it in a if c is NULL
>> n - the dimension index
>> alpha - the scaling factor
*/
void _MultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha) {
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in multiplication!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in multiplication!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in multiplication!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
if(XTensor::IsSameShaped(a, b)){
_Multiply(a, b, c, alpha);
return;
}
if(a->devID >= 0 || b->devID >= 0 || c->devID >= 0){
#ifdef USE_CUDA
_CudaMultiplyDim(a, b, c, n, alpha);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
else{
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for(int i = a->order - 1; i >= 0; i--){
if(i > n)
stride *= a->dimSize[i];
else if(i < n)
blockNum *= a->dimSize[i];
}
if(a->dataType == DEFAULT_DTYPE){
int num = a->unitNum;
if(stride > 1){
for(int i = 0, j = 0; i < num; i += stride, j++){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE bv = *((DTYPE*)b->data + j % blockSize);
DTYPE * cp = (DTYPE*)c->data + i;
for(int k = 0; k < stride; k++)
if(alpha == 0.0F)
cp[k] = ap[k] * bv;
else
cp[k] = ap[k] * bv + alpha * cp[k];
}
}
else if(stride == 1){
DTYPE * bp = (DTYPE*)b->data;
for(int i = 0; i < num; i += blockSize){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE * cp = (DTYPE*)c->data + i;
if(alpha == 0.0F){
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] * bp[j];
}
else{
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] * bp[j] + alpha * cp[j];
}
}
}
else{
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
}
}
/*
tensor multiplication (do it on site)
keep the result in the input tensor a and return nothing
a = a * b + \alpha * a
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> alpha - the scaling factor
*/
void _MultiplyDimMe(XTensor * a, const XTensor * b, int n, DTYPE alpha)
{
_MultiplyDim(a, b, a, n, alpha);
}
/*
tensor multiplication (return a XTensor structure and make tensor connections)
make a new tensor to keep the result and return it
c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> alpha - the scaling factor
<< return - the result tensor by tensor multiplication
*/
XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
{
XTensor c(&a);
c.SetTMP();
/* call _MultiplyDim function */
_MultiplyDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
return c;
}
}
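/*
a small usage sketch (ours, not part of this file): scaling every column of a matrix by a
per-column factor is a typical use of MultiplyDim; the function and variable names are
illustrative only
*/
void ColumnScaleSketch()
{
    XTensor w;
    XTensor s;
    XTensor scaled;

    InitTensor2D(&w, 2, 2);
    InitTensor1D(&s, 2);

    w.Set2D(1.0F, 0, 0); w.Set2D(2.0F, 0, 1);
    w.Set2D(3.0F, 1, 0); w.Set2D(4.0F, 1, 1);
    s.Set1D(10.0F, 0); s.Set1D(0.5F, 1);

    /* broadcast s over dimension 1: scaled[i][j] = w[i][j] * s[j] = {{10, 1}, {30, 2}} */
    scaled = MultiplyDim(w, s, 1);
}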
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2018-08-14
*/
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "MultiplyDim.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor multiplication of a tensor and a row vector
c = a * b + \alpha * c
where a is a tensor and b is a row vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c
>> colNum - number of columns of a and c (i.e., the size of b)
>> alpha - the scaling factor
*/
template <class T, bool alphaFired>
__global__
void KernelMultiplyWithRow(T * a, T * b, T * c, int rowNum, int colNum, T alpha)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int col = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
if (col >= colNum || row >= rowNum)
return;
if (threadIdx.y == 0)
bv[threadIdx.x] = b[col];
__syncthreads();
int offset = colNum * row + col;
if (alphaFired)
c[offset] = a[offset] * bv[threadIdx.x] + c[offset] * alpha;
else
c[offset] = a[offset] * bv[threadIdx.x];
}
/*
tensor multiplication of a tensor and a column vector
c = a * b + \alpha * c
where a is a tensor and b is a column vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c (i.e., the size of b)
>> colNum - number of columns of a and c
>> blockSize - size of a block (matrix), i.e., rowNum * colNum
>> blockNum - number of matrices
>> alpha - the scaling factor
*/
template <class T, bool alphaFired>
__global__
void KernelMultiplyWithCol(T * a, T * b, T * c, int rowNum, int colNum, int blockSize, int blockNum, T alpha)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int colIndex = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
int col = colIndex % colNum;
int block = colIndex / colNum;
if (row >= rowNum || block >= blockNum)
return;
if (threadIdx.x == 0)
bv[threadIdx.y] = b[row];
__syncthreads();
int offset = block * blockSize + row * colNum + col;
if (alphaFired)
c[offset] = a[offset] * bv[threadIdx.y] + c[offset] * alpha;
else
c[offset] = a[offset] * bv[threadIdx.y];
}
/*
tensor multiplication
c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a * b + \alpha * c. we save it in a if c is NULL
>> n - the dimension index
>> alpha - the scaling factor
*/
void _CudaMultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in multiplication!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in multiplication!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in multiplication!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for (int i = a->order - 1; i >= 0; i--) {
if (i > n)
stride *= a->dimSize[i];
else if (i < n)
blockNum *= a->dimSize[i];
}
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
if (stride > 1) {
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if(alpha == (DTYPE)0.0F)
KernelMultiplyWithCol<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha);
else
KernelMultiplyWithCol<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha);
}
else if (stride == 1) {
GDevs.GetCudaThread2D(a->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
if(alpha == (DTYPE)0.0F)
KernelMultiplyWithRow<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, alpha);
else
KernelMultiplyWithRow<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, alpha);
}
else {
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2018-08-14
*/
#ifndef __MULTIPLYDIM_CUH__
#define __MULTIPLYDIM_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* tensor multiplication c = a * b + \alpha * c where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting (cuda version) */
void _CudaMultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha = 0);
#endif
} // namespace nts(NiuTrans.Tensor)
#endif // __MULTIPLYDIM_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2018-08-14
*/
#ifndef __MULTIPLYDIM_H__
#define __MULTIPLYDIM_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* tensor multiplication c = a * b + \alpha * c where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting */
void _MultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha = 0.0);
/* tensor multiplication a = a * b + \alpha * a where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting. we keep the result in the input tensor a and return nothing */
void _MultiplyDimMe(XTensor * a, const XTensor * b, int n, DTYPE alpha = 0.0);
/* tensor multiplication c = a * b + \alpha * c where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting. We make a new tensor c to keep the result and return it */
XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha = 0.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __MULTIPLYDIM_H__
......@@ -24,6 +24,7 @@
#include "../../XUtility.h"
#include "Sub.h"
#include "Sub.cuh"
#include "SubDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -123,7 +124,34 @@ void _SubMe(XTensor * a, const XTensor * b, DTYPE beta)
{
_Sub(a, b, a, beta);
}
/*
return a dimension if the subtraction is performed as SubDim (in more details in SubDim.h)
>> a - a tensor
>> b - another tensor for subtraction
*/
int GetSubDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/*
tensor subtraction c = a - b * \beta (return a XTensor structure)
make a new tensor c to keep the result and return it
......@@ -138,12 +166,28 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
XTensor c(&a);
c.SetTMP();
/* call _Sub function */
_Sub(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
int n = GetSubDimIndex(a, b);
if(n == -1){
/* call _Sub function */
_Sub(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
}
else if(n >= 0 && n < a.order){
/* call _SubDim function */
_SubDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
else{
ShowNTErrors("Something is wrong!");
}
return c;
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#include "Sub.h"
#include "SubDim.h"
#include "SubDim.cuh"
#include "../../XName.h"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor subtraction
c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a-b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
*/
void _SubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in subtraction!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in subtraction!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in subtraction!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
if (beta == 0) {
_CopyValues(a, c);
return;
}
if (XTensor::IsSameShaped(a, b)) {
_Sub(a, b, c, beta);
return;
}
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
_CudaSubDim(a, b, c, n, beta);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
else {
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for (int i = a->order - 1; i >= 0; i--) {
if (i > n)
stride *= a->dimSize[i];
else if (i < n)
blockNum *= a->dimSize[i];
}
if (a->dataType == DEFAULT_DTYPE) {
int num = a->unitNum;
if (stride > 1) {
for (int i = 0, j = 0; i < num; i += stride, j++) {
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE bv = *((DTYPE*)b->data + j % blockSize) * beta;
DTYPE * cp = (DTYPE*)c->data + i;
for (int k = 0; k < stride; k++)
cp[k] = ap[k] - bv;
}
}
else if (stride == 1) {
DTYPE * bp = (DTYPE*)b->data;
for (int i = 0; i < num; i += blockSize) {
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE * cp = (DTYPE*)c->data + i;
if (beta == 1.0F) {
for (int j = 0; j < blockSize; j++)
cp[j] = ap[j] - bp[j];
}
else {
for (int j = 0; j < blockSize; j++)
cp[j] = ap[j] - bp[j] * beta;
}
}
}
else {
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
}
}
/*
tensor subtraction (do it on site)
keep the result in the input tensor and return nothing
c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> beta - the scaling factor
*/
void _SubDim(XTensor * a, const XTensor * b, int n, DTYPE beta)
{
_SubDim(a, b, a, n, beta);
}
/*
tensor subtraction (return a XTensor structure and make tensor connections)
make a new tensor to keep the result and return it
c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> beta - the scaling factor
<< return - the result tensor by tensor subtraction
*/
XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
/* call _SubDim function */
_SubDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
return c;
}
}
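/*
a small usage sketch (ours, not part of this file): subtracting a per-column offset from
every row of a matrix is a typical use of SubDim; variable names are illustrative only
*/
void CenterRowsSketch()
{
    XTensor x;
    XTensor mean;
    XTensor centered;

    InitTensor2D(&x, 2, 3);
    InitTensor1D(&mean, 3);

    x.Set2D(1.0F, 0, 0); x.Set2D(2.0F, 0, 1); x.Set2D(3.0F, 0, 2);
    x.Set2D(3.0F, 1, 0); x.Set2D(4.0F, 1, 1); x.Set2D(5.0F, 1, 2);
    mean.Set1D(2.0F, 0); mean.Set1D(3.0F, 1); mean.Set1D(4.0F, 2);

    /* broadcast mean over dimension 1 with the default beta = 1.0:
       centered[i][j] = x[i][j] - mean[j] */
    centered = SubDim(x, mean, 1);
}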
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#include "SubDim.cuh"
#include "../../XDevice.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor subtraction of a tensor and a row vector
c = a - b * \beta
where a is a tensor and b is a row vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c
>> colNum - number of columns of a and c (i.e., the size of b)
>> beta - the scaling factor
*/
template <class T, bool betaFired>
__global__
void KernelSubWithRow(T * a, T * b, T * c, int rowNum, int colNum, T beta)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int col = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
if (col >= colNum || row >= rowNum)
return;
if (threadIdx.y == 0)
bv[threadIdx.x] = b[col];
__syncthreads();
int offset = colNum * row + col;
if (betaFired)
c[offset] = a[offset] - bv[threadIdx.x] * beta;
else
c[offset] = a[offset] - bv[threadIdx.x];
}
/*
tensor subtraction of a tensor and a column vector
c = a - b * \beta
where a is a tensor and b is a column vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c (i.e., the size of b)
>> colNum - number of columns of a and c
>> blockSize - size of a block (matrix), i.e., rowNum * colNum
>> blockNum - number of matrices
>> beta - the scaling factor
*/
template <class T, bool betaFired>
__global__
void KernelSubWithCol(T * a, T * b, T * c, int rowNum, int colNum, int blockSize, int blockNum, T beta)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int colIndex = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
int col = colIndex % colNum;
int block = colIndex / colNum;
if (row >= rowNum || block >= blockNum)
return;
if (threadIdx.x == 0)
bv[threadIdx.y] = b[row];
__syncthreads();
int offset = block * blockSize + row * colNum + col;
if (betaFired)
c[offset] = a[offset] - bv[threadIdx.y] * beta;
else
c[offset] = a[offset] - bv[threadIdx.y];
}
/*
tensor subtraction (cuda version)
c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a - b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
*/
void _CudaSubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in subtraction!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in subtraction!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in subtraction!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for (int i = a->order - 1; i >= 0; i--) {
if (i > n)
stride *= a->dimSize[i];
else if (i < n)
blockNum *= a->dimSize[i];
}
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
if (stride > 1) {
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if (beta == (DTYPE)1.0F)
KernelSubWithCol<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, beta);
else
KernelSubWithCol<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, beta);
}
else if (stride == 1) {
GDevs.GetCudaThread2D(a->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
if (beta == (DTYPE)1.0F)
KernelSubWithRow<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, beta);
else
KernelSubWithRow<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, beta);
}
else {
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#ifndef __SUBDIM_CUH__
#define __SUBDIM_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* tensor subtraction c = a - b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting (cuda version) */
void _CudaSubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta = (DTYPE)1.0);
#endif
} // namespace nts(NiuTrans.Tensor)
#endif // __SUBDIM_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#ifndef __SUBDIM_H__
#define __SUBDIM_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* tensor subtraction c = a - b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting*/
void _SubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting. we keep the result in the input tensor a and return nothing */
void _SubDim(XTensor * a, const XTensor * b, int n, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting. We make a new tensor c to keep the result and return it */
XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUBDIM_H__
......@@ -131,7 +131,7 @@ void _SumMe(XTensor * a, const XTensor * b, DTYPE beta)
}
/*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h
return a dimension if the sum is performed as SumDim (in more details in SumDim.h)
>> a - a tensor
>> b - another tensor for sum
*/
......@@ -182,7 +182,7 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
XLink::AddParamToHead(&c, beta);
}
else if(n >= 0 && n < a.order){
/* call _Sum function */
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
/* tensor connections */
......
......@@ -49,7 +49,7 @@ void _SelectRange(const XTensor * a, XTensor * c, int dim, int low, int high)
for(int i = 0; i < a->order; i++){
if(i == dim){
CheckNTErrors(low > 0 && low < a->dimSize[dim], "Illegal range specified!");
CheckNTErrors(low >= 0 && low < a->dimSize[dim], "Illegal range specified!");
CheckNTErrors(high > 0 && high <= a->dimSize[dim], "Illegal range specified!");
}
else{
......@@ -101,7 +101,7 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
for(int i = 0; i < a.order; i++){
if(i == dim){
CheckNTErrors(low > 0 && low < a.dimSize[dim], "Illegal range specified!");
CheckNTErrors(low >= 0 && low < a.dimSize[dim], "Illegal range specified!");
CheckNTErrors(high > 0 && high <= a.dimSize[dim], "Illegal range specified!");
dimSize[i] = high - low;
}
......@@ -118,6 +118,7 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
/* tensor connection */
XLink::MakeLink(&a, NULL, &c, GETANDSET_SELECT);
XLink::AddParamToHeadInt(&c, dim);
XLink::AddParamToHeadInt(&c, low);
XLink::AddParamToHeadInt(&c, high);
......
......@@ -376,7 +376,7 @@ void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
DTYPE variance = upper - lower;
if (tensor->dataType == X_FLOAT)
KernelSetDataRandFloat <<<blocks, threads >>>((float*)tensor->data, tensor->unitNum, lower, variance);
KernelSetDataRandFloat <<<blocks, threads >>>((float*) tensor->data, tensor->unitNum, lower, variance);
else if (tensor->dataType == X_DOUBLE)
KernelSetDataRandDouble <<<blocks, threads >>>((double*)tensor->data, tensor->unitNum, lower, variance);
......
......@@ -55,7 +55,7 @@ void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift);
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a normal distribution with specified mean and standard deviation */
void _SetDataRandN(XTensor * tensor, DTYPE mean, DTYPE standardDeviation);
void _SetDataRandN(XTensor * tensor, DTYPE mean = 0.0F, DTYPE standardDeviation = 1.0F);
} // namespace nts(NiuTrans.Tensor)
......
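/* Editorial note, not part of the original file: a minimal sketch of the effect of
   the new default arguments above. The tensor name t and its 2 * 4 shape are
   assumptions, and the snippet assumes the nts declarations (XTensor, InitTensor2D,
   _SetDataRandN) are in scope. */
void SetDataRandNSketch()
{
    XTensor t;
    InitTensor2D(&t, 2, 4);

    /* with the defaults this draws from a standard normal distribution,
       the same as calling _SetDataRandN(&t, 0.0F, 1.0F) */
    _SetDataRandN(&t);
}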
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-14
*/
#include "TDivDim.h"
#include "../core/arithmetic/DivDim.h"
#include "../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor division c = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting.
In this case, (2, 4) / (2) = (2, 4), n = 0, alpha = 0.0.
*/
bool TestDivDim1()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2] = {1.0F, -1.0F};
DTYPE answer[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{-4.0F, -5.0F, -6.0F, -7.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call DivDim function */
_DivDim(a, b, c, 0);
_DivDim(cMe, b, 0);
cUser = DivDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call DivDim function */
_DivDim(aGPU, bGPU, cGPU, 0);
_DivDim(cMeGPU, bGPU, 0);
cUserGPU = DivDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: tensor division c = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting.
In this case, (2, 4) / (2, 2) = (2, 4), n = 1.
*/
bool TestDivDim2()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 2) */
int bOrder = 2;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2][2] = { {1.0F, -1.0F},
{-1.0F, 1.0F} };
DTYPE answer[2][4] = { {0.0F, -1.0F, -2.0F, 3.0F},
{4.0F, -5.0F, -6.0F, 7.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call DivDim function */
_DivDim(a, b, c, 1);
_DivDim(cMe, b, 1);
cUser = DivDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call DivDim function */
_DivDim(aGPU, bGPU, cGPU, 1);
_DivDim(cMeGPU, bGPU, 1);
cUserGPU = DivDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for DivDim Function */
bool TestDivDim()
{
XPRINT(0, stdout, "[TEST DIVDIM] tensor division c(i) = a/b + \alpha * c by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestDivDim1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestDivDim2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-14
*/
#ifndef __TEST_DIVDIM_H__
#define __TEST_DIVDIM_H__
#include "../core/arithmetic/DivDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for DivDim Function */
bool TestDivDim();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_DIVDIM_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-30
*/
#include "TMultiplyDim.h"
#include "../core/arithmetic/MultiplyDim.h"
#include "../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor multiplication c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
In this case, (2, 4) * (2) = (2, 4), n = 0.
*/
bool TestMultiplyDim1()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2] = {1.0F, -1.0F};
DTYPE answer[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{-4.0F, -5.0F, -6.0F, -7.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call MultiplyDim function */
_MultiplyDim(a, b, c, 0);
_MultiplyDimMe(cMe, b, 0);
cUser = MultiplyDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call MultiplyDim function */
_MultiplyDim(aGPU, bGPU, cGPU, 0);
_MultiplyDimMe(cMeGPU, bGPU, 0);
cUserGPU = MultiplyDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: tensor multiplication c = a*b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting.
In this case, (2, 4) * (4) = (2, 4), n = 1.
*/
bool TestMultiplyDim2()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (4) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 4;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[4] = {1.0F, -1.0F , 1.0F, -1.0F};
DTYPE answer[2][4] = { {0.0F, -1.0F, 2.0F, -3.0F},
{4.0F, -5.0F, 6.0F, -7.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call MultiplyDim function */
_MultiplyDim(a, b, c, 1);
_MultiplyDimMe(cMe, b, 1);
cUser = MultiplyDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call MultiplyDim function */
_MultiplyDim(aGPU, bGPU, cGPU, 1);
_MultiplyDimMe(cMeGPU, bGPU, 1);
cUserGPU = MultiplyDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* test for MultiplyDim Function */
bool TestMultiplyDim()
{
XPRINT(0, stdout, "[TEST MULTIPLYDIM] tensor multiplication c = a * b + \alpha * c by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestMultiplyDim1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestMultiplyDim2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-30
*/
#ifndef __TEST_MULTIPLYDIM_H__
#define __TEST_MULTIPLYDIM_H__
#include "../core/arithmetic/MultiplyDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for MultiplyDim Function */
bool TestMultiplyDim();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_MULTIPLYDIM_H__
\ No newline at end of file
......@@ -24,7 +24,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: sum the items along a dimension of the tensor.
case 1: test ReduceSum function.
Sum the items along a dimension of the tensor.
In this case,
(2, 4) -> (4), dim = 0
(2, 4) -> (2), dim = 1
......@@ -90,8 +91,8 @@ bool TestReduceSum1()
tUser2 = ReduceSum(*s, 1, *shift2);
/* check results */
cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1.CheckData(answer1, tUnitNum1)
&& t2->CheckData(answer2, tUnitNum2) && tUser2.CheckData(answer2, tUnitNum2);
cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1.CheckData(answer1, tUnitNum1) &&
t2->CheckData(answer2, tUnitNum2) && tUser2.CheckData(answer2, tUnitNum2);
#ifdef USE_CUDA
/* GPU test */
......@@ -120,8 +121,8 @@ bool TestReduceSum1()
tUserGPU2 = ReduceSum(*sGPU, 1, *shiftGPU2);
/* check results */
gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && tUserGPU1.CheckData(answer1, tUnitNum1)
&& tGPU2->CheckData(answer2, tUnitNum2) && tUserGPU2.CheckData(answer2, tUnitNum2);
gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && tUserGPU1.CheckData(answer1, tUnitNum1) &&
tGPU2->CheckData(answer2, tUnitNum2) && tUserGPU2.CheckData(answer2, tUnitNum2);
/* destroy variables */
delete s;
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#include "TSubDim.h"
#include "../core/arithmetic/SubDim.h"
#include "../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor subtraction c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting.
In this case, (2, 4) - (2) = (2, 4), n = 0.
*/
bool TestSubDim1()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2] = {1.0F, -1.0F};
DTYPE answer[2][4] = { {-1.0F, 0.0F, 1.0F, 2.0F},
{5.0F, 6.0F, 7.0F, 8.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call SubDim function */
_SubDim(a, b, c, 0);
_SubDim(cMe, b, 0);
cUser = SubDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call SubDim function */
_SubDim(aGPU, bGPU, cGPU, 0);
_SubDim(cMeGPU, bGPU, 0);
cUserGPU = SubDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: tensor subtraction c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting.
In this case, (2, 4) - (2, 2) = (2, 4), n = 1.
*/
bool TestSubDim2()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 2) */
int bOrder = 2;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2][2] = { {1.0F, -1.0F},
{-1.0F, 1.0F} };
DTYPE answer[2][4] = { {-1.0F, 2.0F, 3.0F, 2.0F},
{3.0F, 6.0F, 7.0F, 6.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call SubDim function */
_SubDim(a, b, c, 1);
_SubDim(cMe, b, 1);
cUser = SubDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call SubDim function */
_SubDim(aGPU, bGPU, cGPU, 1);
_SubDim(cMeGPU, bGPU, 1);
cUserGPU = SubDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for SubDim Function */
bool TestSubDim()
{
XPRINT(0, stdout, "[TEST SUBDIM] tensor subtraction c = a - b * beta by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestSubDim1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestSubDim2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#ifndef __TEST_SUBDIM_H__
#define __TEST_SUBDIM_H__
#include "../core/arithmetic/SubDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for SubDim Function */
bool TestSubDim();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_SUBDIM_H__
......@@ -28,7 +28,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor summation c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
i.e., a is summed with b by broadcasting.
In this case, (2, 4) + (2) = (2, 4), n = 0.
*/
bool TestSumDim1()
{
......@@ -79,9 +80,9 @@ bool TestSumDim1()
cUser = SumDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum)
&& cMe->CheckData(answer, aUnitNum)
&& cUser.CheckData(answer, aUnitNum);
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
......@@ -106,9 +107,9 @@ bool TestSumDim1()
cUserGPU = SumDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum)
&& cMeGPU->CheckData(answer, aUnitNum)
&& cUserGPU.CheckData(answer, aUnitNum);
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
......@@ -139,7 +140,8 @@ bool TestSumDim1()
/*
case 2: tensor summation c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
i.e., a is summed with b by broadcasting.
In this case, (2, 4) + (2, 2) = (2, 4), n = 1.
*/
bool TestSumDim2()
{
......@@ -192,9 +194,9 @@ bool TestSumDim2()
cUser = SumDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum)
&& cMe->CheckData(answer, aUnitNum)
&& cUser.CheckData(answer, aUnitNum);
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
......@@ -219,9 +221,9 @@ bool TestSumDim2()
cUserGPU = SumDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum)
&& cMeGPU->CheckData(answer, aUnitNum)
&& cUserGPU.CheckData(answer, aUnitNum);
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-30
*/
#include "TTmp.h"
#include "../XTensor.h"
#include "../../xc/ultility.h"
#include "../../xc/myCode.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
bool TestTmp1()
{
XTensor a;
XTensor b;
InitTensor4D(&a, 8, 32, 106, 106, X_FLOAT, -1, NULL);
FILE * fReadOrigin = fopen("V:/temp/input.dump", "rb");
a.Read(fReadOrigin, "a-plus-bias");
fclose(fReadOrigin);
b = Softmax(a, 3);
XTensor c;
InitTensor4D(&c, 8, 32, 106, 106, X_FLOAT, -1, NULL);
FILE * fReadResult = fopen("V:/temp/input.dump.result", "rb");
c.Read(fReadResult, "");
fclose(fReadResult);
printf("\n\nThis is CPU!\n");
b.Dump(stderr, "b", 100);
printf("\n\n");
c.Dump(stderr, "c", 100);
bool cpuTest;
cpuTest = b.CheckData(c.data, b.unitNum, 1e-6F);
if(cpuTest == true)
printf("CPU Yeah!");
else
printf("CPU ops..");
exit(1);
#ifdef USE_CUDA
XTensor aGPU;
XTensor bGPU;
InitTensor4D(&aGPU, 8, 32, 106, 106, X_FLOAT, 0, NULL);
InitTensor4D(&bGPU, 8, 32, 106, 106, X_FLOAT, 0, NULL);
fReadOrigin = fopen("V:/temp/input.dump", "rb");
aGPU.Read(fReadOrigin, "a-plus-bias");
fclose(fReadOrigin);
//bGPU = Softmax(aGPU, 3);
_Softmax(&aGPU, &bGPU, 3);
printf("\n\nThis is GPU\n");
bGPU.Dump(stderr, "bGPU", 100);
bool gpuTest;
gpuTest = bGPU.CheckData(c.data, bGPU.unitNum, 1e-4F);
if(gpuTest == true)
printf("GPU Yeah!");
else
printf("GPU ops..");
#endif // USE_CUDA
exit(1);
return 0;
}
bool TestTmp2()
{
XTensor a;
XTensor b;
InitTensor4D(&a, 8, 32, 106, 106, X_FLOAT, -1, NULL);
InitTensor4D(&b, 8, 32, 106, 106, X_FLOAT, -1, NULL);
//FILE * fReadResultGold = fopen("V:/temp/input.dump.gold", "rb");
//a.Read(fReadResultGold, "input");
//fclose(fReadResultGold);
FILE * fReadResult = fopen("V:/temp/input.dump", "rb");
b.Read(fReadResult, "a-plus-bias");
fclose(fReadResult);
ShowData(&b, "");
bool flag = CheckTensorData(a, b, 1e-3F);
if (flag)
printf("yeah");
else
printf("ops.");
exit(1);
return 0;
}
/* other cases */
/*
TODO!!
*/
/* test for Tmp Function */
bool TestTmp()
{
XPRINT(0, stdout, "[TEST Temp] temporary test\n");
bool returnFlag = true, caseFlag = true;
///* case 1 test */
//caseFlag = TestTmp1();
//if (!caseFlag) {
// returnFlag = false;
// XPRINT(0, stdout, ">> case 1 failed!\n");
//}
//else
// XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestTmp2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-13
*/
#ifndef __TEST_TMP_H__
#define __TEST_TMP_H__
#include "../core/CHeader.h"
#include "../function/FHeader.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
bool TestTmp();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_TMP_H__
......@@ -99,8 +99,8 @@ bool TestUnsqueeze1()
tUser2 = Unsqueeze(*s, 2, 2);
/* check results */
cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1.CheckData(answer1, tUnitNum1)
&& t2->CheckData(answer2, tUnitNum2) && tUser2.CheckData(answer2, tUnitNum2);
cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1.CheckData(answer1, tUnitNum1) &&
t2->CheckData(answer2, tUnitNum2) && tUser2.CheckData(answer2, tUnitNum2);
#ifdef USE_CUDA
/* GPU test */
......
......@@ -29,6 +29,8 @@ bool Test()
bool wrong = false;
XPRINT(0, stdout, "Testing the XTensor utilites ... \n\n");
//wrong = !TestTmp() || wrong;
wrong = !TestAbsolute() || wrong;
wrong = !TestClip() || wrong;
wrong = !TestConcatenate() || wrong;
......@@ -38,6 +40,7 @@ bool Test()
wrong = !TestCopyIndexed() || wrong;
wrong = !TestCopyValues() || wrong;
wrong = !TestDiv() || wrong;
wrong = !TestDivDim() || wrong;
wrong = !TestExp() || wrong;
wrong = !TestLog() || wrong;
wrong = !TestMatrixMul() || wrong;
......@@ -46,6 +49,7 @@ bool Test()
wrong = !TestMatrixMulBatched() || wrong;
wrong = !TestMerge() || wrong;
wrong = !TestMultiply() || wrong;
wrong = !TestMultiplyDim() || wrong;
wrong = !TestNegate() || wrong;
wrong = !TestNormalize() || wrong;
wrong = !TestPower() || wrong;
......
......@@ -22,6 +22,8 @@
#ifndef __TEST_H__
#define __TEST_H__
#include "TTmp.h"
#include "TAbsolute.h"
#include "TClip.h"
#include "TConcatenate.h"
......@@ -31,6 +33,7 @@
#include "TCopyIndexed.h"
#include "TCopyValues.h"
#include "TDiv.h"
#include "TDivDim.h"
#include "TExp.h"
#include "TLog.h"
#include "TMatrixMul.h"
......@@ -39,6 +42,7 @@
#include "TMatrixMulBatched.h"
#include "TMerge.h"
#include "TMultiply.h"
#include "TMultiplyDim.h"
#include "TNegate.h"
#include "TNormalize.h"
#include "TPower.h"
......