Commit fa4d308e by xuchen

Merge branch 'xuchen'

parents b8bcae5b 98db6f24
......@@ -2,7 +2,7 @@
## 注意事项
CUDA最新版本9.2尚且不支持VS2017最新版本,因此建议使用CUDA版本为9.0或9.1,建议使用VS版本为VS2015,或使用VS2017时安装v140工具集。
CUDA最新版本9.2尚且不支持VS2017最新版本,因此建议使用CUDA版本为9.0或9.1,建议使用VS版本为VS2015,或使用VS2017时安装v140工具集,解决方案平台设置为×64
## CUDA配置
......@@ -29,7 +29,7 @@ CUDA最新版本9.2尚且不支持VS2017最新版本,因此建议使用CUDA版
**C/C++->预处理器->预处理器定义** 中,添加
>USE_CUDA;USE_BLAS;WIN32;MKL;DEBUG;CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS_
>USE_CUDA;USE_BLAS;WIN32;MKL;_DEBUG;_CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS_
CONSOLE;
**链接器->系统->子系统**,设置为控制台。
......
......@@ -24,6 +24,7 @@
#include "../tensor/XUtility.h"
#include "../tensor/function/FHeader.h"
#include "../tensor/core/CHeader.h"
#include "../tensor/test/Test.h"
#include "../sample/fnnlm/FNNLM.h"
#include "../sample/transformer/Transformer.h"
......@@ -31,18 +32,24 @@
//#include <stdlib.h>
//#include <crtdbg.h>
void BackwardTest();
void TransposeTest();
void SumDimTest();
using namespace nts;
using namespace fnnlm;
using namespace transformer;
using namespace GAN;
int main( int argc, const char ** argv )
{
//_CrtSetBreakAlloc(896);
//BackwardTest();
//return 0;
if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
if(argc > 1 && !strcmp(argv[1], "-test"))
Test();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1);
else if(argc > 1 && !strcmp(argv[1], "-t2t"))
TransformerMain(argc - 1, argv + 1);
......@@ -58,6 +65,41 @@ int main( int argc, const char ** argv )
return 0;
}
void BackwardTest()
{
XNet net;
XTensor a;
XTensor b;
XTensor c;
XTensor mean;
XTensor origin;
InitTensor2D(&a, 2, 3);
InitTensor1D(&b, 2);
a.SetZeroAll();
b.SetZeroAll();
a.Set2D(1.0F, 0, 0);
a.Set2D(2.0F, 0, 1);
a.Set2D(3.0F, 0, 2);
a.Set2D(4.0F, 1, 0);
a.Set2D(5.0F, 1, 1);
a.Set2D(6.0F, 1, 2);
b.Set1D(2.0F, 0);
b.Set1D(1.0F, 1);
c = DivDim(a, b, 0);
c.Dump(stderr, "c:");
XLink::ShowNetwork(stderr, &c);
net.Backward(c);
net.Dump(stderr);
}
void TransposeTest()
{
#ifdef USE_CUDA
......
......@@ -50,6 +50,7 @@ void XFuncGrad::MakeGrad(XTensor * node)
_IdentityBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
else if(operID == FUNC_LOGSOFTMAX){
int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in logsoftmax!");
_LogSoftmaxBackward(NULL, output, input, output->grad, input->grad, leadDim, NOLOSS);
}
else if(operID == FUNC_RECTIFY)
......@@ -58,6 +59,7 @@ void XFuncGrad::MakeGrad(XTensor * node)
_SigmoidBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
else if(operID == FUNC_SOFTMAX){
int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in softmax!");
_SoftmaxBackward(NULL, output, input, output->grad, input->grad, leadDim, NOLOSS);
}
else{
......
......@@ -22,7 +22,11 @@
#include "XBackwardLoss.h"
#include "../tensor/XName.h"
#include "../tensor/function/HardTanH.h"
#include "../tensor/function/Identity.h"
#include "../tensor/function/LogSoftmax.h"
#include "../tensor/function/Rectify.h"
#include "../tensor/function/Sigmoid.h"
#include "../tensor/function/Softmax.h"
namespace nts{
......@@ -49,10 +53,22 @@ void XLossGrad::Compute(XTensor * gold, XTensor * y, XTensor * x,
if(funcID == FUNC_HARDTANH){
_HardTanHBackward(gold, y, x, dedy, dedx, lossName);
}
else if(funcID == FUNC_IDENTITY){
_IdentityBackward(gold, y, x, dedy, dedx, lossName);
}
else if(funcID == FUNC_LOGSOFTMAX){
int leadDim = *(int*)params;
_LogSoftmaxBackward(gold, y, x, dedy, dedx, leadDim, lossName);
}
else if(funcID == FUNC_RECTIFY){
_RectifyBackward(gold, y, x, dedy, dedx, lossName);
}
else if(funcID == FUNC_SIGMOID){
_SigmoidBackward(gold, y, x, dedy, dedx, lossName);
}else if(funcID == FUNC_SOFTMAX){
int leadDim = *(int*)params;
_SoftmaxBackward(gold, y, x, dedy, dedx, leadDim, lossName);
}
else{
ShowNTErrors("wrong function found when call the backward process!");
}
......
......@@ -40,18 +40,50 @@ public:
bool IsMathOP(XTensor * node);
private:
/* gradient for sum: c = a + b * \beta */
/* gradient for absolute */
static
void GradSum(XTensor * node);
void GradAbsolute(XTensor * node);
/* gradient for cos */
static
void GradCos(XTensor * node);
/* gradient for exp */
static
void GradExp(XTensor * node);
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
/* gradient for log: c = log(a) */
static
void GradSumDim(XTensor * node);
void GradLog(XTensor * node);
/* gradient for round */
static
void GradRound(XTensor * node);
/* gradient for sign */
static
void GradSign(XTensor * node);
/* gradient for multiply (dot production): c = a * b * \alpha */
/* gradient for sin */
static
void GradMultiply(XTensor * node);
void GradSin(XTensor * node);
/* gradient for tan */
static
void GradTan(XTensor * node);
/* gradient for clip */
static
void GradClip(XTensor * node);
/* gradient for Divide */
static
void GradDiv(XTensor * node);
/* gradient for DivideDim */
static
void GradDivDim(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
......@@ -68,17 +100,26 @@ private:
static
void GradMatrixMulBatched(XTensor * node);
/* gradient for log: c = log(a) */
/* gradient for multiply (dot production): c = a * b * \alpha */
static
void GradLog(XTensor * node);
void GradMultiply(XTensor * node);
/* gradient for power */
/* gradient for multiply one dimension: c = a * b * \alpha
where the size of b is equal to that of one dimension of a */
static
void GradPower(XTensor * node);
void GradMultiplyDim(XTensor * node);
/* gradient for negate */
static
void GradNegate(XTensor * node);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
/* gradient for power */
static
void GradPower(XTensor * node);
/* gradient for ScaleAndShift */
static
......@@ -87,10 +128,20 @@ private:
/* gradient for Minus */
static
void GradSub(XTensor * node);
/* gradient for sub with one dimension: c = a - b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSubDim(XTensor * node);
/* gradient for Divide */
/* gradient for sum: c = a + b * \beta */
static
void GradDiv(XTensor * node);
void GradSum(XTensor * node);
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSumDim(XTensor * node);
/* gradient for reduceMean */
static
......@@ -107,42 +158,6 @@ private:
/* gradient for reduceVariance */
static
void GradReduceVariance(XTensor * node);
/* gradient for sin */
static
void GradSin(XTensor * node);
/* gradient for cos */
static
void GradCos(XTensor * node);
/* gradient for tan */
static
void GradTan(XTensor * node);
/* gradient for exp */
static
void GradExp(XTensor * node);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
/* gradient for absolute */
static
void GradAbsolute(XTensor * node);
/* gradient for sign */
static
void GradSign(XTensor * node);
/* gradient for clip */
static
void GradClip(XTensor * node);
/* gradient for round */
static
void GradRound(XTensor * node);
};
}
......
......@@ -99,7 +99,7 @@ arguments:
(how many words)
-shuffle: shuffle the training data
-devid D: the id of the device used
-1: GPU, >=0: GPUs
-1: CPU, >=0: GPUs
-mempool: use memory pools for memory management
-autodiff: use automatic differentiation for training
......
......@@ -35,6 +35,8 @@ T2TAttention::T2TAttention()
dk = -1;
dv = -1;
d = -1;
isMasked = false;
ignored = 0;
}
/* deconstructor */
......@@ -46,29 +48,39 @@ T2TAttention::~T2TAttention()
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIgnored - number of position ignored in attention (from the begining)
>> myIsMasked - indicates whether the attention is with a mask
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void T2TAttention::InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
isMasked = myIsMasked;
ignored = myIgnored;
float minmax = 0;
LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamInt(argc, argv, "d", &dk, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &dv, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_BEDDING_SIZE);
LoadParamFloat(argc, argv, "attminmax", &minmax, 0.08F);
LoadParamInt(argc, argv, "d", &dk, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &dv, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "attminmax", &minmax, 0.1F);
InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
InitTensor2D(&wq, d, dk, X_FLOAT, devID, mem);
InitTensor2D(&wv, d, dv, X_FLOAT, devID, mem);
float scale = 1.0F;
float finfoutk = (float)sqrt(6.0F * scale/(d + dk));
float finfoutv = (float)sqrt(6.0F * scale/(d + dv));
wk.SetDataRand(-minmax, minmax);
wq.SetDataRand(-minmax, minmax);
wv.SetDataRand(-minmax, minmax);
wk.SetDataRand(-finfoutk, finfoutk);
wq.SetDataRand(-finfoutk, finfoutk);
wv.SetDataRand(-finfoutv, finfoutv);
}
/*
......@@ -78,9 +90,10 @@ make the network
and H = vector size of each position
>> q - queries
>> v - values
>> maske - as it is
<< return - multi-attention result
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
{
XTensor k2;
XTensor q2;
......@@ -101,10 +114,18 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
vheads = Split(v2, v2.order - 1, nhead);
XTensor att;
XTensor dot;
XTensor scalar;
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */
scalar = Softmax(Linear(BMMul(qheads, X_NOTRANS, kheads, X_TRANS), 1/sqrt((float)dk)), -1);
dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
if(isMasked)
dot = dot + mask;
scalar = Softmax(Linear(dot, 1/(float)sqrt((float)dk)), -1);
if(ignored > 0)
_SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F);
att = BMMul(scalar, vheads);
/* concatenate the heads */
......
......@@ -66,6 +66,13 @@ public:
/* size of input Q, K and V */
int d;
/* indicates whether the attention is masked */
bool isMasked;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
special design for the attention model. */
int ignored;
public:
/* constructor */
T2TAttention();
......@@ -74,10 +81,12 @@ public:
~T2TAttention();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v);
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask);
};
}
......
......@@ -53,16 +53,14 @@ void T2TEmbedder::InitModel(int argc, const char ** argv, int myDevID, XMem * my
devID = myDevID;
mem = myMem;
int d = 0;
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "maxlen", &maxLength, 256);
LoadParamInt(argc, argv, "d", &eSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "maxlen", &maxLength, 512);
LoadParamInt(argc, argv, "d", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID, mem);
w.SetDataRandn(0, sqrt((float)eSize));
w.SetDataRandn(0, 1.0F/(float)sqrt((float)eSize));
/* create the positional embedding matrix */
MakePosEmbedding(eSize, d, maxLength);
......@@ -84,11 +82,11 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
for(int k = 0; k < eSize; k++){
if(k % 2 == 0){
int i = k/2;
dp[k] = sin(pos/pow(10000.0F, 2.0F*i/d));
dp[k] = (float)sin(pos/pow(10000.0F, 2.0F*i/d));
}
else{
int i = (k - 1)/2;
dp[k] = cos(pos/pow(10000.0F, 2.0F*i/d));
dp[k] = (float)cos(pos/pow(10000.0F, 2.0F*i/d));
}
}
}
......@@ -135,7 +133,7 @@ XTensor T2TEmbedder::Make(XTensor &input)
XTensor wordEmbedding;
/* then we make word embeddings */
wordEmbedding = MMul(&input, w);
wordEmbedding = Linear(MMul(input, w), (float)sqrt((float)d));
/* we sum over the two embeddings */
return wordEmbedding + posEmbedding;
......
......@@ -29,7 +29,7 @@ using namespace nts;
namespace transformer
{
#define DEFAULT_BEDDING_SIZE 512
#define DEFAULT_EMBEDDING_SIZE 512
/*
embedding (of word at position i):
......@@ -53,6 +53,9 @@ public:
/* maximum length of the sequence */
int maxLength;
/* dimension size of the hidden layers in the t2t model */
int d;
/* word embedding matrix */
XTensor w;
......
......@@ -38,28 +38,33 @@ AttEncoder::~AttEncoder()
{
delete[] attentions;
delete[] fnns;
delete[] layerNorms;
delete[] attLayerNorms;
delete[] fnnLayerNorms;
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id
>> myMem - the memory pool
*/
void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void AttEncoder::InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
ignored = myIgnored;
LoadParamInt(argc, argv, "nstack", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, 512);
LoadParamInt(argc, argv, "esize", &eSize, 512);
LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "vsize", &vSize, -1);
CheckNTErrors(nlayer > 1, "We have one encoding layer at least!");
CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!");
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\"");
/* embedding model */
......@@ -67,22 +72,26 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM
attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer];
layerNorms = new T2TLN[nlayer];
attLayerNorms = new T2TLN[nlayer];
fnnLayerNorms = new T2TLN[nlayer];
/* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, myDevID, myMem);
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem);
layerNorms[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
}
}
/*
make the encoding network
>> input - the input tensor of the encoder
>> mask - the mask that indicate each position is valid
>> skipInputRes - indicates whether we skip the residual connection of the first layer
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input)
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes)
{
XTensor x;
......@@ -94,19 +103,27 @@ XTensor AttEncoder::Make(XTensor &input)
XTensor fnn;
XTensor res;
/* self attention */
att = attentions[i].Make(x, x, x);
if(skipInputRes && i == 0){
/* self attention */
att = attentions[i].Make(x, x, x, mask);
/* residual connection */
res = Sum(att, x);
/* TODO: dropout */
/* TODO: dropout */
/* layer normalization */
x = attLayerNorms[i].Make(att);
}
else{
/* self attention */
att = attentions[i].Make(x, x, x, mask);
/* layer normalization */
ln = layerNorms[i].Make(res);
/* residual connection */
res = Sum(att, x);
/* TODO: dropout */
/* input of next layer */
x = ln;
/* layer normalization */
x = attLayerNorms[i].Make(res);
}
/* fnn */
fnn = fnns[i].Make(x);
......@@ -117,10 +134,7 @@ XTensor AttEncoder::Make(XTensor &input)
/* TODO: dropout */
/* layer normalization */
ln = layerNorms[i].Make(res);
/* input of next layer */
x = ln;
x = fnnLayerNorms[i].Make(res);
}
return x;
......
......@@ -40,7 +40,7 @@ class T2TEncoder
{
public:
virtual
XTensor Make(XTensor &input) = 0;
XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes) = 0;
};
/*
......@@ -49,7 +49,7 @@ the encoder based on RNN
class RNNEncoder : T2TEncoder
{
public:
XTensor Make(XTensor &input);
XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes);
};
......@@ -77,6 +77,10 @@ public:
/* vocabulary size */
int vSize;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
special design for the attention model. */
int ignored;
/* embedding of word at each position */
T2TEmbedder embedder;
......@@ -86,8 +90,11 @@ public:
/* attention model of each layer */
T2TAttention * attentions;
/* layer normalization */
T2TLN * layerNorms;
/* layer normalization for fnn */
T2TLN * fnnLayerNorms;
/* layer normalization for attention */
T2TLN * attLayerNorms;
/* input tensor of the encoder */
XTensor * input;
......@@ -103,10 +110,12 @@ public:
~AttEncoder();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */
XTensor Make(XTensor &input);
XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes);
};
......
......@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TFNN.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
......@@ -55,10 +56,10 @@ void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
float minmax = 0;
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_BEDDING_SIZE);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.08F);
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);
InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID, mem);
InitTensor1D(&b1, hSize, X_FLOAT, devID, mem);
......@@ -66,10 +67,14 @@ void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
InitTensor2D(&w2, hSize, outSize, X_FLOAT, devID, mem);
InitTensor1D(&b2, outSize, X_FLOAT, devID, mem);
w1.SetDataRand(-minmax, minmax);
b1.SetDataRand(-minmax, minmax);
w2.SetDataRand(-minmax, minmax);
b2.SetDataRand(-minmax, minmax);
float scale = 1.0F;
float finfout1 = (float)sqrt(6.0F * scale/(inSize + hSize));
float finfout2 = (float)sqrt(6.0F * scale/(hSize + outSize));
w1.SetDataRand(-finfout1, finfout1);
b1.SetZeroAll();
w2.SetDataRand(-finfout2, finfout2);
b2.SetZeroAll();
}
/*
......@@ -83,10 +88,10 @@ XTensor T2TFNN::Make(XTensor &input)
XTensor t1;
/* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MMul(input, X_NOTRANS, w1, X_NOTRANS) + b1);
t1 = Rectify(MMul(input, w1) + b1);
/* result = t1 * w2 + b2 */
return MMul(t1, X_NOTRANS, w2, X_NOTRANS) + b2;
return MMul(t1, w2) + b2;
}
......
......@@ -19,7 +19,10 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TLayerNormal.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
......@@ -48,6 +51,18 @@ void T2TLN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
int d = 0;
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
InitTensor2D(&w, d, d, X_FLOAT, devID, mem);
InitTensor1D(&b, d, X_FLOAT, devID, mem);
float scale = 1.0F;
float finfout = (float)sqrt(6.0F * scale / (d + d));
w.SetDataRand(-finfout, finfout);
b.SetZeroAll();
}
/*
......@@ -60,6 +75,7 @@ y =
XTensor T2TLN::Make(XTensor &input)
{
XTensor &x = input;
XTensor xn;
XTensor mean;
XTensor variance;
XTensor standard;
......@@ -67,21 +83,23 @@ XTensor T2TLN::Make(XTensor &input)
XTensor standardFilled;
/* \mu = (sum_i x_i)/m */
mean = ReduceSum(x, x.order - 1);
mean = ReduceMean(x, x.order - 1);
/* \sigma = (sum_i (x_i - \mu)^2)/m */
variance = ReduceVariance(x, x.order - 1, mean);
/* standard = sqrt(variance) */
standard = Power(variance, 0.5F);
/* unsqueeze mean and standard deviation to fit them into
the same size of x */
the same shape of x */
meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1));
standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
/* x' = (x - \mu)/standard */
return (x - meanFilled)/standardFilled;
xn = (x - meanFilled)/standardFilled;
/* result = x' * w + b */
return MMul(xn, w) + b;
}
}
......@@ -29,6 +29,8 @@ using namespace nts;
namespace transformer
{
/* layer normalization: y = norm(x) * w + b
where norm(x) = (x - mean)/standardDeviation */
class T2TLN
{
public:
......@@ -37,6 +39,12 @@ public:
/* memory pool */
XMem * mem;
/* the transformation matrix w */
XTensor w;
/* the bias term b */
XTensor b;
public:
/* constructor */
......
......@@ -22,6 +22,7 @@
#include "T2TModel.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
......@@ -33,6 +34,7 @@ T2TModel::T2TModel()
mem = NULL;
isLM = false;
isMT = false;
nhead = 1;
}
/* de-constructor */
......@@ -54,24 +56,27 @@ void T2TModel::InitModel(int argc, const char ** argv)
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamInt(argc, argv, "nhead", &nhead, 8);
if(useMem){
delete mem;
mem = new XMem(devID);
}
encoder.InitModel(argc, argv, devID, mem);
encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
}
/*
make the encoding network
>> input - input tensor
>> mask - the mask for positions that are/not involved in computation
>> skipInputRes - indicates whether we skip the residual connection of the first layer
<< return - encoding result
*/
XTensor T2TModel::MakeEncoding(XTensor &input)
XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes)
{
return encoder.Make(input);
return encoder.Make(input, mask, skipInputRes);
}
/*
......@@ -81,15 +86,30 @@ make the entire network (with the output softmax layer)
*/
void T2TModel::Make(XTensor &input, XTensor &output)
{
XTensor encoding;
if(isLM){
XTensor encoding;
/* generate mask to see "previous" words only */
int len = input.GetDim(input.order - 2);
int * dims = new int[input.order + 1];
for(int i = 0; i < input.order; i++)
dims[i + 1] = input.GetDim(i);
dims[0] = nhead;
dims[input.order] = len;
XTensor mask(input.order + 1, dims, X_FLOAT, 1.0F, input.devID, input.mem);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9 */
_SetDataLowTri(&mask, 1e9F, -1);
_ScaleAndShiftMe(&mask, 1.0F, -1e9F);
encoding = MakeEncoding(input);
encoding = MakeEncoding(input, mask, true);
outputLayer.Make(encoding, output);
delete[] dims;
}
else{
ShowNTErrors("TODO!");
}
}
}
\ No newline at end of file
}
......@@ -55,6 +55,9 @@ public:
/* indicates whether the model is running for machine translation */
bool isMT;
/* number of heads in the attention model */
int nhead;
public:
/* constructor */
T2TModel();
......@@ -66,7 +69,7 @@ public:
void InitModel(int argc, const char ** argv);
/* make the encoding network */
XTensor MakeEncoding(XTensor &input);
XTensor MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes);
/* make the entire network (with the output softmax layer) */
void Make(XTensor &input, XTensor &output);
......
......@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TOutput.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
......@@ -56,12 +57,15 @@ void T2TOutput::InitModel(int argc, const char ** argv, int myDevID, XMem * myMe
float minmax = 0;
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &hSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "outputminmax", &minmax, 0.08F);
InitTensor2D(&w, hSize, vSize, X_FLOAT, devID, mem);
w.SetDataRand(-minmax, minmax);
float scale = 1.0F;
float finfout = (float)sqrt(6.0F * scale/(hSize + vSize));
w.SetDataRand(-finfout, finfout);
}
/*
......@@ -89,4 +93,4 @@ void T2TOutput::Make(XTensor &input, XTensor &output)
output = LogSoftmax(MMul(x, w), -1);
}
}
\ No newline at end of file
}
......@@ -59,6 +59,8 @@ void T2TTrainer::Init(int argc, const char ** argv)
LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
LoadParamInt(argc, argv, "nepoch", &nepoch, 1);
LoadParamInt(argc, argv, "nstep", &nstep, 1);
LoadParamInt(argc, argv, "d", &d, 512);
LoadParamInt(argc, argv, "nwarmup", &nwarmup, 4000);
LoadParamInt(argc, argv, "vsize", &vSize, 1);
LoadParamBool(argc, argv, "sorted", &isLenSorted, false);
LoadParamInt(argc, argv, "bufsize", &bufSize, 50000);
......@@ -82,6 +84,7 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
int wordCountTotal = 0;
bool isEnd = false;
float loss = 0;
float lr = 0;
XNet net;
......@@ -108,8 +111,12 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
/* back-propagation for obtaining gradients */
net.Backward(output, batch, CROSSENTROPY);
/* learning rate */
lr = (1 / (float)sqrt((float)d)) * (float)MIN(pow(step + 1, -0.5), (step + 1) * pow(nwarmup, -1.5));
//lr = 0.00005F;
/* update the parameters */
Update(model);
Update(model, lr);
/* get probabilities */
float prob = GetProb(&output, &batch, NULL);
......@@ -125,18 +132,21 @@ void T2TTrainer::Train(const char * fn, T2TModel * model)
if (step % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
lr, elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
}
}
fclose(file);
if (isEnd)
break;
}
double elapsed = GetClockSec() - startT;
XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
elapsed, step, epoch);
}
......@@ -317,10 +327,14 @@ float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
}
/*
update the model by delta rule
update the model by delta rule
\theta_new = \theta - \lrate * grad
where
\lrate = d^-0.5 * min(stepNum^-0.5, stepNum * warmupStepNum^-1.5)
>> model - the t2t model
>> lr - learning rate
*/
void T2TTrainer::Update(T2TModel * model)
void T2TTrainer::Update(T2TModel * model, const float lr)
{
XList ws(100);
......@@ -331,6 +345,13 @@ void T2TTrainer::Update(T2TModel * model)
ws.Add(&model->encoder.fnns[i].b1);
ws.Add(&model->encoder.fnns[i].w2);
ws.Add(&model->encoder.fnns[i].b2);
ws.Add(&model->encoder.attentions[i].wk);
ws.Add(&model->encoder.attentions[i].wq);
ws.Add(&model->encoder.attentions[i].wv);
ws.Add(&model->encoder.fnnLayerNorms[i].w);
ws.Add(&model->encoder.fnnLayerNorms[i].b);
ws.Add(&model->encoder.attLayerNorms[i].w);
ws.Add(&model->encoder.attLayerNorms[i].b);
}
ws.Add(&model->encoder.embedder.w);
......@@ -339,11 +360,37 @@ void T2TTrainer::Update(T2TModel * model)
XTensor * para = (XTensor*)ws.Get(i);
XTensor * paraGrad = para->grad;
if (para == NULL || paraGrad == NULL)
continue;
CheckNTErrors(para != NULL, "NULL parameter tensor!");
CheckNTErrors(paraGrad != NULL, "NULL gradient tensor!");
/*
DTYPE * d = new DTYPE[para->unitNum * para->unitSize];
DTYPE * g = new DTYPE[para->unitNum * para->unitSize];
XMemCopy(d, -1, para->data, para->devID, para->unitNum * para->unitSize);
XMemCopy(g, -1, paraGrad->data, paraGrad->devID, para->unitNum * para->unitSize);
for (int i = 0; i < para->unitNum; i++) {
if (IsNAN(d[i]) || IsINF(d[i])) {
int nnn = 0;
}
if (IsNAN(g[i]) || IsINF(g[i])) {
int nnn = 0;
}
}
delete[] d;
delete[] g;
*/
/* the delta rule */
_Sum(para, paraGrad, para, -lrate);
_Sum(para, paraGrad, para, -lr);
/* clear gradient */
paraGrad->SetZeroAll();
}
}
......
......@@ -63,6 +63,12 @@ public:
/* indicates whether the sequence is sorted by length */
bool isLenSorted;
/* dimension size of each inner layer */
int d;
/* step number of warm-up for training */
int nwarmup;
/* vocabulary size of the source side */
int vSize;
......@@ -105,7 +111,7 @@ public:
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
/* update the model by delta rule */
void Update(T2TModel * model);
void Update(T2TModel * model, const float lr);
};
......
......@@ -26,6 +26,8 @@
namespace transformer
{
FILE * tmpFILE;
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP)
{
char vname[128];
......@@ -98,7 +100,9 @@ void ShowParams(int argc, const char ** argv)
{
fprintf(stderr, "args:\n");
for(int i = 0; i < argc; i++){
if(argv[i][0] == '-'){
if(argv[i][1] == 0)
continue;
if(argv[i][0] == '-' && (argv[i][1] < '1' || argv[i][1] > '9')){
if(i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else
......
......@@ -27,6 +27,8 @@
namespace transformer
{
extern FILE * tmpFILE;
/* load arguments */
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP);
......
......@@ -33,6 +33,8 @@ int TransformerMain(int argc, const char ** argv)
if(argc == 0)
return 1;
tmpFILE = fopen("tmp.txt", "wb");
ShowParams(argc, argv);
char * trainFN = new char[MAX_LINE_LENGTH];
......@@ -51,6 +53,8 @@ int TransformerMain(int argc, const char ** argv)
delete[] trainFN;
fclose(tmpFILE);
return 0;
}
......
......@@ -45,7 +45,7 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(123);
/* a tiny test */
SmallTest();
//SmallTest();
//_CrtDumpMemoryLeaks();
//return 0;
......
......@@ -43,7 +43,7 @@
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
#define _XINLINE_ inline
#define _XINLINE_
//#define DOUBELPRICSION
......
......@@ -45,12 +45,16 @@ const char * GetOPName(int type)
return "M_CLIP";
else if (type == MATH_DIV)
return "M_DIV";
else if (type == MATH_DIVDIM)
return "M_DIVDIM";
else if (type == MATH_MATRIXMUL)
return "M_MATRIXMUL";
else if (type == MATH_MATRIXMULBATCHED)
return "M_MATRIXMULBATCHED";
else if (type == MATH_MULTIPLY)
return "M_MULTIPLY";
else if (type == MATH_MULTIPLYDIM)
return "M_MULTIPLYDIM";
else if (type == MATH_NEGATE)
return "M_NEGATE";
else if (type == MATH_NORMALIZE)
......@@ -61,10 +65,12 @@ const char * GetOPName(int type)
return "M_SCALEANDSHIFT";
else if (type == MATH_SIGN)
return "M_SIGN";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_SUB)
return "M_SUB";
else if (type == MATH_SUBDIM)
return "M_SUBDIM";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_SUMDIM)
return "M_SUMDIM";
else if (type == REDUCE_REDUCEMAX)
......
......@@ -41,17 +41,20 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_CLIP MATH_ROUND + 1
#define MATH_DIV MATH_CLIP + 1
#define MATH_MATRIXMUL MATH_DIV + 1
#define MATH_DIVDIM MATH_DIV + 1
#define MATH_MATRIXMUL MATH_DIVDIM + 1
#define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
#define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1
#define MATH_NEGATE MATH_MULTIPLY + 1
#define MATH_MULTIPLYDIM MATH_MULTIPLY + 1
#define MATH_NEGATE MATH_MULTIPLYDIM + 1
#define MATH_NORMALIZE MATH_NEGATE + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_SIGN MATH_SCALEANDSHIFT + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_SUB MATH_SUM + 1
#define MATH_SUMDIM MATH_SUB + 1
#define MATH_SUB MATH_SIGN + 1
#define MATH_SUBDIM MATH_SUB + 1
#define MATH_SUM MATH_SUBDIM + 1
#define MATH_SUMDIM MATH_SUM + 1
#define REDUCE MATH_SUMDIM + 1
#define REDUCE_REDUCEMAX REDUCE + 1
......
......@@ -1046,7 +1046,7 @@ bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 1 is out of range!");
CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
int dims[3] = {d0, d1, d1};
int dims[3] = {d0, d1, d2};
return SetToDevice(devID, GetCell(dims, 3), value);
}
......
......@@ -27,15 +27,18 @@
#include "../XTensor.h"
#include "arithmetic/Div.h"
#include "arithmetic/DivDim.h"
#include "arithmetic/MatrixMul.h"
#include "arithmetic/MatrixMul2D.h"
#include "arithmetic/MatrixMul2DMultiTheading.h"
#include "arithmetic/MatrixMul2DParallel.h"
#include "arithmetic/MatrixMulBatched.h"
#include "arithmetic/Multiply.h"
#include "arithmetic/MultiplyDim.h"
#include "arithmetic/Negate.h"
#include "arithmetic/Sign.h"
#include "arithmetic/Sub.h"
#include "arithmetic/SubDim.h"
#include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h"
......
......@@ -23,6 +23,7 @@
#include "../../XName.h"
#include "Div.h"
#include "Div.cuh"
#include "DivDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -137,6 +138,33 @@ void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
_Div(a, b, a, alpha, leadingDim);
}
/*
return a dimension if the division is performed as DivDim (in more details in DivDim.h)
>> a - a tensor
>> b - another tensor for division
*/
int GetDivDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/*
element-wise division of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it
......@@ -146,23 +174,41 @@ where i is the index of the item
>> a - tensor a
>> b - tensor b
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim)
XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
{
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a);
c.SetTMP();
int n = GetDivDimIndex(a, b);
if(n == -1){
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Div function */
_Div(&a, &b, &c, alpha, leadingDim);
/* call _Multiply function */
_Div(&a, &b, &c, 0, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHeadInt(&c, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
else if(n >= 0 && n < a.order){
/* call _DivDim function */
_DivDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadInt(&c, alpha);
}
else{
ShowNTErrors("Something is wrong!");
}
return c;
}
......
......@@ -31,7 +31,7 @@ element-wise division of two tensors:
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the element
*/
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise division of two tensors (do it on site)
......@@ -39,7 +39,7 @@ keep the result in the input tensor a and return nothing
a(i) = a(i)/b(i) + \alpha * a(i)
where i is the index of the element
*/
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha = 0, int leadingDim = 0);
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise division of two tensors (return a XTensor structure)
......@@ -47,7 +47,7 @@ make a new tensor to keep the result and return it
c(i) = a(i)/b(i)
where i is the index of the element
*/
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim = 0);
XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-15
*/
#include "Div.h"
#include "DivDim.h"
#include "DivDim.cuh"
#include "../../XName.h"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor division
c = a / b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put result. we save it in a if c is NULL
>> n - the dimension index
>> alpha - the scaling factor
*/
void _DivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in division!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in addition!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in division!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
if(XTensor::IsSameShaped(a, b)){
_Div(a, b, c, alpha);
return;
}
if(a->devID >= 0 || b->devID >= 0 || c->devID >= 0){
#ifdef USE_CUDA
_CudaDivDim(a, b, c, n, alpha);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
else{
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for(int i = a->order - 1; i >= 0; i--){
if(i > n)
stride *= a->dimSize[i];
else if(i < n)
blockNum *= a->dimSize[i];
}
if (a->dataType == DEFAULT_DTYPE){
int num = a->unitNum;
if(stride > 1){
for(int i = 0, j = 0; i < num; i += stride, j++){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE bv = *((DTYPE*)b->data + j % blockSize);
DTYPE * cp = (DTYPE*)c->data + i;
for(int k = 0; k < stride; k++){
if(alpha == 0.0F)
cp[k] = ap[k] / bv;
else
cp[k] = ap[k] / bv + alpha * cp[k];
}
}
}
else if(stride == 1){
DTYPE * bp = (DTYPE*)b->data;
for(int i = 0; i < num; i += blockSize){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE * cp = (DTYPE*)c->data + i;
if(alpha == 0.0F){
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] / bp[j];
}
else{
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] / bp[j] + alpha * cp[j];
}
}
}
else{
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
}
}
/*
tensor division of two tensors (do it on site)
keep the result in the input tensor and return nothing
a = a/b + \alpha * a
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> alpha - the scaling factor
*/
void _DivDim(XTensor * a, const XTensor * b, int n, DTYPE alpha)
{
_DivDim(a, b, a, n, alpha);
}
/*
tensor division of two tensors (return a XTensor structure and make tensor connections)
make a new tensor to keep the result and return it
c = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> alpha - the scaling factor
<< return - the result tensor by tensor division
*/
XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
{
XTensor c(&a);
c.SetTMP();
/* call _Div function */
_DivDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
return c;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-15
*/
#include "DivDim.cuh"
#include "../../XDevice.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor division of a tensor and a row vector
c = a / b + alpha * c
where a is a tensor and b is a row vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c
>> colNum - number of columns of a and c (i.e., the size of b)
>> alpha - the scaling factor
*/
template <class T, bool alphaFired>
__global__
void KernelDivWithRow(T * a, T * b, T * c, int rowNum, int colNum, T alpha)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int col = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
if(col >= colNum || row >= rowNum)
return;
if(threadIdx.y == 0)
bv[threadIdx.x] = b[col];
__syncthreads();
int offset = colNum * row + col;
if(alphaFired)
c[offset] = a[offset] / bv[threadIdx.x] + c[offset] * alpha;
else
c[offset] = a[offset] / bv[threadIdx.x];
}
/*
tensor division of a tensor and a colum vector
c = a / b + alpha * c
where a is a tensor and b is a colum vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c (i.e., the size of b)
>> colNum - number of columns of a and c
>> blockNum - size of a block (matrix), i.e., rowNum * colNum
>> blockNum - number of matrics
>> alpha - the scaling factor
*/
template <class T, bool alphaFired>
__global__
void KernelDivWithCol(T * a, T * b, T * c, int rowNum, int colNum, int blockSize, int blockNum, T alpha)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int colIndex = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
int col = colIndex % colNum;
int block = colIndex / colNum;
if(row >= rowNum || block >= blockNum)
return;
if(threadIdx.x == 0)
bv[threadIdx.y] = b[row];
__syncthreads();
int offset = block * blockSize + row * colNum + col;
if(alphaFired)
c[offset] = a[offset] / bv[threadIdx.y] + c[offset] * alpha;
else
c[offset] = a[offset] / bv[threadIdx.y];
}
/*
tensor division
c = a / b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a / b + \alpha * c. we save it in a if c is NULL
>> n - the dimension index
>> alpha - the scaling factor
*/
void _CudaDivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in division!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in division!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in division!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for(int i = a->order - 1; i >= 0; i--){
if(i > n)
stride *= a->dimSize[i];
else if(i < n)
blockNum *= a->dimSize[i];
}
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE){
if(stride > 1){
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if(alpha == (DTYPE)0.0F)
KernelDivWithCol<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha);
else
KernelDivWithCol<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha);
}
else if(stride == 1){
GDevs.GetCudaThread2D(a->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
if(alpha == (DTYPE)0.0F)
KernelDivWithRow<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, alpha);
else
KernelDivWithRow<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, alpha);
}
else{
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-15
*/
#ifndef __DIVDIM_CUH__
#define __DIVDIM_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor division
c(i) = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting (cuda version)
*/
void _CudaDivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha = (DTYPE)0.0);
#endif
} // namespace nts(NiuTrans.Tensor)
#endif // __DIVDIM_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-15
*/
#ifndef __DIVDIM_H__
#define __DIVDIM_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor division of two tensors:
c(i) = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
*/
void _DivDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha = (DTYPE)0.0);
/*
tensor division of two tensors:
c(i) = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
we keep the result in the input tensor a and return nothing
*/
void _DivDim(XTensor * a, const XTensor * b, int n, DTYPE alpha = (DTYPE)0.0);
/*
tensor division of two tensors:
c(i) = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
we make a new tensor c to keep the result and return it
*/
XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha = (DTYPE)0.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __DIVDIM_H__
......@@ -23,6 +23,7 @@
#include "../../XName.h"
#include "Multiply.h"
#include "Multiply.cuh"
#include "MultiplyDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -138,6 +139,33 @@ void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
_Multiply(a, b, a, alpha, leadingDim);
}
/*
return a dimension if the multiplication is performed as MultiplyDim (in more details in MultiplyDim.h)
>> a - a tensor
>> b - another tensor for multiplication
*/
int GetMultiplyDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/*
element-wise product of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it
......@@ -150,20 +178,38 @@ where i is the index of the item
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim)
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
{
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a);
c.SetTMP();
/* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim);
int n = GetMultiplyDimIndex(a, b);
if(n == -1){
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHeadInt(&c, leadingDim);
/* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
else if(n >= 0 && n < a.order){
/* call _MultiplyDim function */
_MultiplyDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadInt(&c, alpha);
}
else{
ShowNTErrors("Something is wrong!");
}
return c;
}
......
......@@ -31,7 +31,7 @@ element-wise product of two tensors:
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the element
*/
void _Multiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
void _Multiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise product of two tensors (do it on site)
......@@ -39,7 +39,7 @@ keep the result in the input tensor a and return nothing
a(i) = a(i)*b(i) + \alpha * a(i)
where i is the index of the element
*/
void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha = 0, int leadingDim = 0);
void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise product of two tensors (return a XTensor structure)
......@@ -47,7 +47,7 @@ make a new tensor to keep the result and return it
c(i) = a(i)*b(i)
where i is the index of the element
*/
XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim = 0);
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2018-08-14
*/
#include "Multiply.h"
#include "MultiplyDim.h"
#include "MultiplyDim.cuh"
#include "../../XName.h"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor multiplication
c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a * b + \alpha * c. we save it in a if c is NULL
>> n - the dimension index
>> alpha - the scaling factor
*/
void _MultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha) {
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in multiplication!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in multiplication!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in multiplication!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
if(XTensor::IsSameShaped(a, b)){
_Multiply(a, b, c, alpha);
return;
}
if(a->devID >= 0 || b->devID >= 0 || c->devID >= 0){
#ifdef USE_CUDA
_CudaMultiplyDim(a, b, c, n, alpha);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
else{
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for(int i = a->order - 1; i >= 0; i--){
if(i > n)
stride *= a->dimSize[i];
else if(i < n)
blockNum *= a->dimSize[i];
}
if(a->dataType == DEFAULT_DTYPE){
int num = a->unitNum;
if(stride > 1){
for(int i = 0, j = 0; i < num; i += stride, j++){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE bv = *((DTYPE*)b->data + j % blockSize);
DTYPE * cp = (DTYPE*)c->data + i;
for(int k = 0; k < stride; k++)
if(alpha == 0.0F)
cp[k] = ap[k] * bv;
else
cp[k] = ap[k] * bv + alpha * cp[k];
}
}
else if(stride == 1){
DTYPE * bp = (DTYPE*)b->data;
for(int i = 0; i < num; i += blockSize){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE * cp = (DTYPE*)c->data + i;
if(alpha == 0.0F){
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] * bp[j];
}
else{
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] * bp[j] + alpha * cp[j];
}
}
}
else{
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
}
}
/*
tensor multiplication(do it on site)
make a new tensor to keep the result and return it
c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> alpha - the scaling factor
*/
void _MultiplyDimMe(XTensor * a, const XTensor * b, int n, DTYPE alpha)
{
_MultiplyDim(a, b, a, n, alpha);
}
/*
tensor multiplication (return a XTensor structure and make tensor connections)
make a new tensor to keep the result and return it
c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> alpha - the scaling factor
<< return - the result tensor by tensor multiplication
*/
XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
{
XTensor c(&a);
c.SetTMP();
/* call _Multiply function */
_MultiplyDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
return c;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2018-08-14
*/
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "MultiplyDim.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor multiplication of a tensor and a row vector
c = a * b + \alpha * c
where a is a tensor and b is a row vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c
>> colNum - number of columns of a and c (i.e., the size of b)
>> alpha - the scaling factor
*/
template <class T, bool alphaFired>
__global__
void KernelMultiplyWithRow(T * a, T * b, T * c, int rowNum, int colNum, T alpha)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int col = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
if (col >= colNum || row >= rowNum)
return;
if (threadIdx.y == 0)
bv[threadIdx.x] = b[col];
__syncthreads();
int offset = colNum * row + col;
if (alphaFired)
c[offset] = a[offset] * bv[threadIdx.x] + c[offset] * alpha;
else
c[offset] = a[offset] * bv[threadIdx.x];
}
/*
tensor multiplication of a tensor and a colum vector
c = a * b + \alpha * c
where a is a tensor and b is a colum vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c (i.e., the size of b)
>> colNum - number of columns of a and c
>> blockNum - size of a block (matrix), i.e., rowNum * colNum
>> blockNum - number of matrics
>> alpha - the scaling factor
*/
template <class T, bool alphaFired>
__global__
void KernelMultiplyWithCol(T * a, T * b, T * c, int rowNum, int colNum, int blockSize, int blockNum, T alpha)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int colIndex = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
int col = colIndex % colNum;
int block = colIndex / colNum;
if (row >= rowNum || block >= blockNum)
return;
if (threadIdx.x == 0)
bv[threadIdx.y] = b[row];
__syncthreads();
int offset = block * blockSize + row * colNum + col;
if (alphaFired)
c[offset] = a[offset] * bv[threadIdx.y] + c[offset] * alpha;
else
c[offset] = a[offset] * bv[threadIdx.y];
}
/*
tensor multiplication
c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a * b + \alpha * c. we save it in a if c is NULL
>> n - the dimension index
>> alpha - the scaling factor
*/
void _CudaMultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in multiplication!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in multiplication!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in multiplication!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for (int i = a->order - 1; i >= 0; i--) {
if (i > n)
stride *= a->dimSize[i];
else if (i < n)
blockNum *= a->dimSize[i];
}
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
if (stride > 1) {
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if(alpha == (DTYPE)0.0F)
KernelMultiplyWithCol<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha);
else
KernelMultiplyWithCol<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, alpha);
}
else if (stride == 1) {
GDevs.GetCudaThread2D(a->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
if(alpha == (DTYPE)0.0F)
KernelMultiplyWithRow<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, alpha);
else
KernelMultiplyWithRow<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, alpha);
}
else {
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2018-08-14
*/
#ifndef __MULTIPLYDIM_CUH__
#define __MULTIPLYDIM_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* tensor summation a * b + \alpha * c where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting (cuda version) */
void _CudaMultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha = 0);
#endif
} // namespace nts(NiuTrans.Tensor)
#endif // __MULTIPLYDIM_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2018-08-14
*/
#ifndef __MULTIPLYDIM_H__
#define __MULTIPLYDIM_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* tensor multiplication c = a * b + \alpha * c where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting */
void _MultiplyDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE alpha = 0.0);
/* tensor multiplication a = a * b + \alpha * c where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting. we keep the result in the input tensor a and return nothing */
void _MultiplyDimMe(XTensor * a, const XTensor * b, int n, DTYPE alpha = 0.0);
/* tensor multiplication c = a * b + \alpha * c where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting. We make a new tensor c to keep the result and return it */
XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha = 0.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __MULTIPLYDIM_H__
......@@ -24,6 +24,7 @@
#include "../../XUtility.h"
#include "Sub.h"
#include "Sub.cuh"
#include "SubDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -123,7 +124,34 @@ void _SubMe(XTensor * a, const XTensor * b, DTYPE beta)
{
_Sub(a, b, a, beta);
}
/*
return a dimension if the subtraction is performed as SubDim (in more details in SubDim.h)
>> a - a tensor
>> b - another tensor for subtraction
*/
int GetSubDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/*
tensor subtraction c = a - b * \beta (return a XTensor structure)
make a new tensor c to keep the result and return it
......@@ -138,12 +166,28 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
XTensor c(&a);
c.SetTMP();
/* call _Sub function */
_Sub(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
int n = GetSubDimIndex(a, b);
if(n == -1){
/* call _Sub function */
_Sub(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
}
else if(n >= 0 && n < a.order){
/* call _SubDim function */
_SubDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
else{
ShowNTErrors("Something is wrong!");
}
return c;
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#include "Sub.h"
#include "SubDim.h"
#include "SubDim.cuh"
#include "../../XName.h"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor subtraction
c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a-b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
*/
void _SubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in subtraction!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in subtraction!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in subtraction!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
if (beta == 0) {
_CopyValues(a, c);
return;
}
if (XTensor::IsSameShaped(a, b)) {
_Sub(a, b, c, beta);
return;
}
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
_CudaSubDim(a, b, c, n, beta);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
else {
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for (int i = a->order - 1; i >= 0; i--) {
if (i > n)
stride *= a->dimSize[i];
else if (i < n)
blockNum *= a->dimSize[i];
}
if (a->dataType == DEFAULT_DTYPE) {
int num = a->unitNum;
if (stride > 1) {
for (int i = 0, j = 0; i < num; i += stride, j++) {
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE bv = *((DTYPE*)b->data + j % blockSize) * beta;
DTYPE * cp = (DTYPE*)c->data + i;
for (int k = 0; k < stride; k++)
cp[k] = ap[k] - bv;
}
}
else if (stride == 1) {
DTYPE * bp = (DTYPE*)b->data;
for (int i = 0; i < num; i += blockSize) {
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE * cp = (DTYPE*)c->data + i;
if (beta == 1.0F) {
for (int j = 0; j < blockSize; j++)
cp[j] = ap[j] - bp[j];
}
else {
for (int j = 0; j < blockSize; j++)
cp[j] = ap[j] - bp[j] * beta;
}
}
}
else {
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
}
}
/*
tensor subtraction (do it on site)
keep the result in the input tensor and return nothing
c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> beta - the scaling factor
*/
void _SubDim(XTensor * a, const XTensor * b, int n, DTYPE beta)
{
_SubDim(a, b, a, n, beta);
}
/*
tensor subtraction (return a XTensor structure and make tensor connections)
make a new tensor to keep the result and return it
c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> beta - the scaling factor
<< return - the result tensor by tensor subtraction
*/
XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
/* call _Sub function */
_SubDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
return c;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#include "SubDim.cuh"
#include "../../XDevice.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor subtraction of a tensor and a row vector
c = a - b * \beta
where a is a tensor and b is a row vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c
>> colNum - number of columns of a and c (i.e., the size of b)
>> beta - the scaling factor
*/
template <class T, bool betaFired>
__global__
void KernelSubWithRow(T * a, T * b, T * c, int rowNum, int colNum, T beta)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int col = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
if (col >= colNum || row >= rowNum)
return;
if (threadIdx.y == 0)
bv[threadIdx.x] = b[col];
__syncthreads();
int offset = colNum * row + col;
if (betaFired)
c[offset] = a[offset] - bv[threadIdx.x] * beta;
else
c[offset] = a[offset] - bv[threadIdx.x];
}
/*
tensor subtraction of a tensor and a colum vector
c = a - b * \beta
where a is a tensor and b is a colum vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c (i.e., the size of b)
>> colNum - number of columns of a and c
>> blockNum - size of a block (matrix), i.e., rowNum * colNum
>> blockNum - number of matrics
>> beta - the scaling factor
*/
template <class T, bool betaFired>
__global__
void KernelSubWithCol(T * a, T * b, T * c, int rowNum, int colNum, int blockSize, int blockNum, T beta)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int colIndex = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
int col = colIndex % colNum;
int block = colIndex / colNum;
if (row >= rowNum || block >= blockNum)
return;
if (threadIdx.x == 0)
bv[threadIdx.y] = b[row];
__syncthreads();
int offset = block * blockSize + row * colNum + col;
if (betaFired)
c[offset] = a[offset] - bv[threadIdx.y] * beta;
else
c[offset] = a[offset] - bv[threadIdx.y];
}
/*
tensor subtraction (cuda version)
c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
*/
void _CudaSubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in subtraction!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in subtraction!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in subtraction!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for (int i = a->order - 1; i >= 0; i--) {
if (i > n)
stride *= a->dimSize[i];
else if (i < n)
blockNum *= a->dimSize[i];
}
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
if (stride > 1) {
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if (beta == (DTYPE)1.0F)
KernelSubWithCol<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, beta);
else
KernelSubWithCol<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, beta);
}
else if (stride == 1) {
GDevs.GetCudaThread2D(a->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
if (beta == (DTYPE)1.0F)
KernelSubWithRow<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, beta);
else
KernelSubWithRow<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, beta);
}
else {
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#ifndef __SUBDIM_CUH__
#define __SUBDIM_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* tensor subtraction c = a - b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting (cuda version) */
void _CudaSubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta = (DTYPE)1.0);
#endif
} // namespace nts(NiuTrans.Tensor)
#endif // __SUBDIM_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#ifndef __SUBDIM_H__
#define __SUBDIM_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* tensor subtraction c = a - b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting*/
void _SubDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting. we keep the result in the input tensor a and return nothing */
void _SubDim(XTensor * a, const XTensor * b, int n, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting. We make a new tensor c to keep the result and return it */
XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUBDIM_H__
......@@ -131,7 +131,7 @@ void _SumMe(XTensor * a, const XTensor * b, DTYPE beta)
}
/*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h
return a dimension if the sum is performed as SumDim (in more details in SumDim.h)
>> a - a tensor
>> b - another tensor for sum
*/
......@@ -182,7 +182,7 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
XLink::AddParamToHead(&c, beta);
}
else if(n >= 0 && n < a.order){
/* call _Sum function */
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
/* tensor connections */
......
......@@ -49,7 +49,7 @@ void _SelectRange(const XTensor * a, XTensor * c, int dim, int low, int high)
for(int i = 0; i < a->order; i++){
if(i == dim){
CheckNTErrors(low > 0 && low < a->dimSize[dim], "Illegal range specified!");
CheckNTErrors(low >= 0 && low < a->dimSize[dim], "Illegal range specified!");
CheckNTErrors(high > 0 && high <= a->dimSize[dim], "Illegal range specified!");
}
else{
......@@ -101,7 +101,7 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
for(int i = 0; i < a.order; i++){
if(i == dim){
CheckNTErrors(low > 0 && low < a.dimSize[dim], "Illegal range specified!");
CheckNTErrors(low >= 0 && low < a.dimSize[dim], "Illegal range specified!");
CheckNTErrors(high > 0 && high <= a.dimSize[dim], "Illegal range specified!");
dimSize[i] = high - low;
}
......@@ -118,6 +118,7 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
/* tensor connection */
XLink::MakeLink(&a, NULL, &c, GETANDSET_SELECT);
XLink::AddParamToHeadInt(&c, dim);
XLink::AddParamToHeadInt(&c, low);
XLink::AddParamToHeadInt(&c, high);
......
......@@ -70,8 +70,8 @@ void _SetDataFanInOut(XTensor * tensor, DTYPE gain)
fanOut = numOutputFmaps * receptiveFieldSize;
}
DTYPE std = gain * sqrt(2.0/(fanIn + fanOut));
DTYPE a = sqrt(3.0) * std;
DTYPE std = gain * (float)sqrt(2.0/(fanIn + fanOut));
DTYPE a = (DTYPE)sqrt(3.0) * std;
_SetDataRand(tensor, -a, a);
}
......@@ -213,6 +213,106 @@ void _SetDataFixedDouble(XTensor * tensor, double p)
_SetDataFixed(tensor, &p);
}
/*
set data items along with a given dimension (and keep the remaining items unchanged)
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length along with the given dimension
>> dim - the dimension along which we set the data
e.g., given a 3 * 3 tensor
1 2 3
4 5 6
7 8 9
when beg = 1, len = 1, dim = 0 and p = 0, we have
1 2 3
0 0 0
7 8 9
i.e., we set all entries of row 1 to 0
*/
void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim > 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
if(tensor->devID < 0){
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = n - 1; i > dim; i--){
stride *= tensor->GetDim(i);
}
blockSize = stride * tensor->GetDim(dim);
blockNum = tensor->unitNum / blockSize;
int l = len * stride;
for(int i = 0; i < blockNum; i++){
DTYPE * d = (DTYPE*)tensor->data + blockSize * i + beg * stride;
for(int j = 0; j < l; j++)
d[j] = p;
}
}
else{
#ifdef USE_CUDA
_CudaSetDataDim(tensor, beg, len, dim, p);
#endif
}
}
/*
generate data as lower triangular matrics for last two dimensions
>> tensor - the tensor whose data to be set
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
1 1 1
when p = 2 and shift = -1, we have
0 0 0
2 0 0
2 2 0
*/
void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(n >= 2, "The tensor must have a order no less than 2!");
CheckNTErrors(tensor->GetDim(n - 1) == tensor->GetDim(n - 2),
"The last two dimensions must be of the same size!");
if(tensor->devID < 0){
int l = tensor->GetDim(-1);
int blockNum = 1;
int blockSize = l * l;
for(int i = 0; i < n - 2; i++)
blockNum *= tensor->GetDim(i);
for(int i = 0; i < blockNum; i++){
DTYPE * d = (DTYPE*)tensor->data + i * blockSize;
for(int row = 0; row < l; row++){
for(int col = 0; col <= row + shift; col++){
d[row * l + col] = p;
}
for(int col = MAX(0, row + shift + 1); col < l; col++){
d[row * l + col] = 0;
}
}
}
}
else{
#ifdef USE_CUDA
_CudaSetDataLowTri(tensor, p, shift);
#endif
}
}
/*
generate data items with a uniform distribution in [lower, upper]
>> tensor - the tensor whose data array would be initialized
......
......@@ -184,6 +184,169 @@ void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
}
}
/*
set data items along with a given dimension (and keep the remaining items unchanged) - kernel version
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length of the segment to be set
>> blockSize - size of a data block
>> blockNum - number of data blocks
*/
__global__
void KernelSetDataDim(DTYPE * d, int beg, int len, int blockSize, int blockNum, DTYPE p)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* block id */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockSize || j > blockNum)
return;
if(i < beg || i >= beg + len)
return;
d[blockSize * j + i] = p;
}
/*
set data items along with a given dimension (and keep the remaining items unchanged) - cuda version
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length along with the given dimension
>> dim - the dimension along which we set the data
e.g., given a 3 * 3 tensor
1 2 3
4 5 6
7 8 9
when beg = 1, len = 1, dim = 0 and p = 0, we have
1 2 3
0 0 0
7 8 9
i.e., we set all entries of row 1 to 0
*/
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim > 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = n - 1; i > dim; i--){
stride *= tensor->GetDim(i);
}
blockSize = stride * tensor->GetDim(dim);
blockNum = tensor->unitNum / blockSize;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(tensor->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataDim<<<blocks, threads >>>((DTYPE*)tensor->data, beg * stride, len * stride, blockSize, blockNum, p);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
set lower triangular matrics for each block
>> d - pointer to the data array
>> l - row number (or column number) of each block, i.e,
a block is l * l matrix
>> blockSize - size of each block (blockSize = l * l)
>> blockNum - number of the blocks
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
1 1 1
when p = 2 and shift = -1, we have
0 0 0
2 0 0
2 2 0
*/
__global__
void _KernelSetDataLowTri(DTYPE * d, int l, int blockSize, int blockNum, DTYPE p, int shift)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* block id */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockSize || j > blockNum)
return;
int row = i / l;
int col = i % l;
DTYPE * d2 = d + blockSize * j + row * l + col;
if(col <= row + shift)
*d2 = p;
else
*d2 = 0;
}
/*
generate data as lower triangular matrics for last two dimensions (cuda version)
>> tensor - the tensor whose data to be set
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
1 1 1
when p = 2 and shift = -1, we have
0 0 0
2 0 0
2 2 0
*/
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(n >= 2, "The tensor must have a order no less than 2!");
CheckNTErrors(tensor->GetDim(n - 1) == tensor->GetDim(n - 2),
"The last two dimensions must be of the same size!");
int l = tensor->GetDim(-1);
int blockNum = 1;
int blockSize = l * l;
for(int i = 0; i < n - 2; i++)
blockNum *= tensor->GetDim(i);
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(tensor->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
_KernelSetDataLowTri<<<blocks, threads >>>((DTYPE*)tensor->data, l, blockSize, blockNum, p, shift);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
generate data items with a uniform distribution in [lower, upper]
>> tensor - the tensor whose data array would be initialized
......@@ -213,7 +376,7 @@ void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
DTYPE variance = upper - lower;
if (tensor->dataType == X_FLOAT)
KernelSetDataRandFloat <<<blocks, threads >>>((float*)tensor->data, tensor->unitNum, lower, variance);
KernelSetDataRandFloat <<<blocks, threads >>>((float*) tensor->data, tensor->unitNum, lower, variance);
else if (tensor->dataType == X_DOUBLE)
KernelSetDataRandDouble <<<blocks, threads >>>((double*)tensor->data, tensor->unitNum, lower, variance);
......
......@@ -37,6 +37,12 @@ void _CudaSetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void _CudaSetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* generate data as lower triangular matrics for last two dimensions (cuda version) */
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [lower, upper] */
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
......
......@@ -45,11 +45,17 @@ void _SetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void _SetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* generate data as lower triangular matrics for last two dimensions */
void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [lower, upper] */
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a normal distribution with specified mean and standard deviation */
void _SetDataRandN(XTensor * tensor, DTYPE mean, DTYPE standardDeviation);
void _SetDataRandN(XTensor * tensor, DTYPE mean = 0.0F, DTYPE standardDeviation = 1.0F);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -60,8 +60,12 @@ void _Power(const XTensor * a, XTensor * b, DTYPE p)
bData[i] = aData[i] * aData[i];
}
else {
for (int i = 0; i < a->unitNum; i++)
bData[i] = (DTYPE)pow(aData[i], p);
for (int i = 0; i < a->unitNum; i++) {
if (p < 0 && aData[i] == 0)
bData[i] = 1e20F;
else
bData[i] = (DTYPE)pow(aData[i], p);
}
}
}
......
......@@ -77,8 +77,13 @@ void KernelPower(DTYPE * a, DTYPE * b, DTYPE p, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
b[i] = pow(a[i], p);
if (i < size) {
DTYPE v = a[i];
if (p < 0 && v == 0)
b[i] = 1e20;
else
b[i] = pow(a[i], p);
}
}
/*
......@@ -94,8 +99,13 @@ void KernelPower(__half * a, __half * b, __half p, int size)
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
#else
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
b[i] = __float2half(pow(__half2float(a[i]), __half2float(p)));
if (i < size) {
float v = __half2float(a[i]);
if (__half2float(p) < 0 && v == 0)
b[i] = __float2half(1e20);
else
b[i] = __float2half(pow(__half2float(a[i]), __half2float(p)));
}
#endif
}
......
......@@ -176,15 +176,19 @@ make a new tensor to keep the result and return it
*/
XTensor LogSoftmax(const XTensor &x, int leadDim)
{
int ld = leadDim;
if (ld < 0)
ld = x.order - 1;
XTensor y(&x);
y.SetTMP();
/* call _LogSoftmax function */
_LogSoftmax(&x, &y, leadDim);
_LogSoftmax(&x, &y, ld);
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
XLink::AddParamToHeadInt(&y, leadDim);
XLink::AddParamToHeadInt(&y, ld);
return y;
}
......@@ -248,7 +252,7 @@ There are two ways to implement this process.
Method 1. we compute dE/dy and dy/dx resepectively, and then reach dE/dx by dE/dx = dE/dy * dy/dx
(or more precisely dE/dx_j = \sum_{i} {dE/dy_i * dy_i/dx_j})
Method 2. we compute dE/dx (or dE/dx_j) in a single step, rather than resorting to the
sub-models dE/dy and dy/dx. We can do this by using dE/dx_j = -gold_j + exp(y_j)
sub-models of dE/dy and dy/dx. We can do this by using dE/dx_j = -gold_j + exp(y_j)
Here we choose Method 2, i.e., we straightforwardly compute dE/dx_j by
......@@ -257,12 +261,12 @@ dE/dx_j = -gold_j + exp(y_j)
(or dE/dx_j = -\delta(i,j) + exp(y_j) for a Maximum A Posteriori Estimation (MAP))
Method 1 is also fine but is more time consuming due to the summation over dimensions.
Note that this method is not good for the standard version softmax when working with
the cross entropy loss. Because it is numerical unstable. When we use a usual method to
Note that this method is not good for the standard version softmax when we work with
the cross entropy loss because it is numerical unstable. When we use a usual method to
define softmax, we have softmax: y_i = log(e^{x_i} / \sum_{k} e^{x_k}). It is trivial to
know that dy_i/dx_j = y_i * \delta(i,j) - y_i * y_j. As y_i and y_j could be a small number,
y_i * y_i would result in a much smaller one with a risk of lossing precision. This is even
worse we multiply dy_i/dx_j with dE/dy_i. So it is in general to use log softmax instead for
know that dy_i/dx_j = y_i * \delta(i,j) - y_i * y_j. As y_i and y_j could be small numbers,
y_i * y_i would result in a much smaller value with a risk of lossing precision. This is even
worse we multiply dy_i/dx_j with dE/dy_i. So it is in general to use log softmax for
better numerical stability.
>> gold - gold standard to measure error (or loss)
......
......@@ -103,10 +103,10 @@ void _Softmax(const XTensor * x, XTensor * y, int leadDim)
else{
for(int i = 0; i < n; i++){
DTYPE r = (DTYPE)exp(ip[i * m + j] - mp[j])/sp[j];
if(IsNAN(r))
r = DTYPE_MIN;
if(IsINF(r))
r = DTYPE_MIN;
if (r > (DTYPE)1.0F)
r = (DTYPE)1.0F;
else if (r < 0)
r = 0;
op[i * m + j] = r;
}
}
......@@ -143,14 +143,19 @@ make a new tensor to keep the result and return it
*/
XTensor Softmax(const XTensor &x, int leadDim)
{
int ld = leadDim;
if (ld < 0)
ld = x.order - 1;
XTensor y(&x);
y.SetTMP();
/* call _Softmax function */
_Softmax(&x, &y, leadDim);
_Softmax(&x, &y, ld);
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
return y;
}
......
......@@ -85,7 +85,13 @@ void KernelSoftmaxComputeTensor(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y,
if(i < strideSizeTotal && j < strideNum){
int offset = int(i / stride) * blockSize + j * stride + i2[threadIdx.x];
y[offset] = exp(x[offset] - xMax[threadIdx.x])/xSum[threadIdx.x];
DTYPE r = exp(x[offset] - xMax[threadIdx.x])/xSum[threadIdx.x];
if (r >(DTYPE)1.0F)
r = (DTYPE)1.0F;
else if (r < 0)
r = 0;
y[offset] = r;
}
}
......@@ -194,7 +200,12 @@ void KernelSoftmaxComputeTensorUseBroadcast(DTYPE * input, DTYPE * max, DTYPE *
maxData = broadcast(maxData);
if (i < strideNum){
int offset = int(j / stride) * blockSize + i * stride + i2;
output[offset] = exp(input[offset] - maxData) / sumData;
DTYPE r = exp(input[offset] - maxData) / sumData;
if (r > (DTYPE)1.0F)
r = (DTYPE)1.0F;
else if (r < 0)
r = 0;
output[offset] = r;
}
}
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-14
*/
#include "TDivDim.h"
#include "../core/arithmetic/DivDim.h"
#include "../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor division c = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting.
In this case, (2, 4) / (2) = (2, 4), n = 0, alpha = 0.0.
*/
bool TestDivDim1()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2] = {1.0F, -1.0F};
DTYPE answer[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{-4.0F, -5.0F, -6.0F, -7.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call DivDim function */
_DivDim(a, b, c, 0);
_DivDim(cMe, b, 0);
cUser = DivDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call sum function */
_DivDim(aGPU, bGPU, cGPU, 0);
_DivDim(cMeGPU, bGPU, 0);
cUserGPU = DivDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: tensor division c = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting.
In this case, (2, 4) / (2, 2) = (2, 4), n = 1.
*/
bool TestDivDim2()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 2) */
int bOrder = 2;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2][2] = { {1.0F, -1.0F},
{-1.0F, 1.0F} };
DTYPE answer[2][4] = { {0.0F, -1.0F, -2.0F, 3.0F},
{4.0F, -5.0F, -6.0F, 7.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call DivDim function */
_DivDim(a, b, c, 1);
_DivDim(cMe, b, 1);
cUser = DivDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call sum function */
_DivDim(aGPU, bGPU, cGPU, 1);
_DivDim(cMeGPU, bGPU, 1);
cUserGPU = DivDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for DivDim Function */
bool TestDivDim()
{
XPRINT(0, stdout, "[TEST DIVDIM] tensor division c(i) = a/b + \alpha * c by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestDivDim1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestDivDim2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-14
*/
#ifndef __TEST_DIVDIM_H__
#define __TEST_DIVDIM_H__
#include "../core/arithmetic/DivDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for DivDim Function */
bool TestDivDim();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_DIVDIM_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-30
*/
#include "TMultiplyDim.h"
#include "../core/arithmetic/MultiplyDim.h"
#include "../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor multiplication c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
In this case, (2, 4) * (2) = (2, 4), n = 0.
*/
bool TestMultiplyDim1()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2] = {1.0F, -1.0F};
DTYPE answer[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{-4.0F, -5.0F, -6.0F, -7.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call MultiplyDim function */
_MultiplyDim(a, b, c, 0);
_MultiplyDimMe(cMe, b, 0);
cUser = MultiplyDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call MultiplyDim function */
_MultiplyDim(aGPU, bGPU, cGPU, 0);
_MultiplyDimMe(cMeGPU, bGPU, 0);
cUserGPU = MultiplyDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: tensor multiplication c = a*b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting.
In this case, (2, 4) * (4) = (2, 4), n = 1.
*/
bool TestMultiplyDim2()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (4) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 4;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[4] = {1.0F, -1.0F , 1.0F, -1.0F};
DTYPE answer[2][4] = { {0.0F, -1.0F, 2.0F, -3.0F},
{4.0F, -5.0F, 6.0F, -7.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call MultiplyDim function */
_MultiplyDim(a, b, c, 1);
_MultiplyDimMe(cMe, b, 1);
cUser = MultiplyDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call MultiplyDim function */
_MultiplyDim(aGPU, bGPU, cGPU, 1);
_MultiplyDimMe(cMeGPU, bGPU, 1);
cUserGPU = MultiplyDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* test for MultiplyDim Function */
bool TestMultiplyDim()
{
XPRINT(0, stdout, "[TEST MULTIPLYDIM] tensor multiplication c = a * b + \alpha * c by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestMultiplyDim1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestMultiplyDim2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-30
*/
#ifndef __TEST_MULTIPLYDIM_H__
#define __TEST_MULTIPLYDIM_H__
#include "../core/arithmetic/MultiplyDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for MultiplyDim Function */
bool TestMultiplyDim();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_MULTIPLYDIM_H__
\ No newline at end of file
......@@ -24,7 +24,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: sum the items along a dimension of the tensor.
case 1: test ReduceSum function.
Sum the items along a dimension of the tensor.
In this case,
(2, 4) -> (4), dim = 0
(2, 4) -> (2), dim = 1
......@@ -90,8 +91,8 @@ bool TestReduceSum1()
tUser2 = ReduceSum(*s, 1, *shift2);
/* check results */
cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1.CheckData(answer1, tUnitNum1)
&& t2->CheckData(answer2, tUnitNum2) && tUser2.CheckData(answer2, tUnitNum2);
cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1.CheckData(answer1, tUnitNum1) &&
t2->CheckData(answer2, tUnitNum2) && tUser2.CheckData(answer2, tUnitNum2);
#ifdef USE_CUDA
/* GPU test */
......@@ -120,8 +121,8 @@ bool TestReduceSum1()
tUserGPU2 = ReduceSum(*sGPU, 1, *shiftGPU2);
/* check results */
gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && tUserGPU1.CheckData(answer1, tUnitNum1)
&& tGPU2->CheckData(answer2, tUnitNum2) && tUserGPU2.CheckData(answer2, tUnitNum2);
gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && tUserGPU1.CheckData(answer1, tUnitNum1) &&
tGPU2->CheckData(answer2, tUnitNum2) && tUserGPU2.CheckData(answer2, tUnitNum2);
/* destroy variables */
delete s;
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#include "TSubDim.h"
#include "../core/arithmetic/SubDim.h"
#include "../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor subtraction c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
*/
bool TestSubDim1()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2] = {1.0F, -1.0F};
DTYPE answer[2][4] = { {-1.0F, 0.0F, 1.0F, 2.0F},
{5.0F, 6.0F, 7.0F, 8.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call SubDim function */
_SubDim(a, b, c, 0);
_SubDim(cMe, b, 0);
cUser = SubDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call sub function */
_SubDim(aGPU, bGPU, cGPU, 0);
_SubDim(cMeGPU, bGPU, 0);
cUserGPU = SubDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: tensor subtraction c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
*/
bool TestSubDim2()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 2) */
int bOrder = 2;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2][2] = { {1.0F, -1.0F},
{-1.0F, 1.0F} };
DTYPE answer[2][4] = { {-1.0F, 2.0F, 3.0F, 2.0F},
{3.0F, 6.0F, 7.0F, 6.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call SubDim function */
_SubDim(a, b, c, 1);
_SubDim(cMe, b, 1);
cUser = SubDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call sub function */
_SubDim(aGPU, bGPU, cGPU, 1);
_SubDim(cMeGPU, bGPU, 1);
cUserGPU = SubDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for SubDim Function */
bool TestSubDim()
{
XPRINT(0, stdout, "[TEST SUBDIM] tensor subtraction c = a - b * beta by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestSubDim1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestSubDim2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#ifndef __TEST_SUBDIM_H__
#define __TEST_SUBDIM_H__
#include "../core/arithmetic/SubDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for SubDim Function */
bool TestSubDim();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_SUBDIM_H__
......@@ -28,7 +28,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor summation c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
i.e., a is summed with b by broadcasting.
In this case, (2, 4) + (2) = (2, 4), n = 0.
*/
bool TestSumDim1()
{
......@@ -79,9 +80,9 @@ bool TestSumDim1()
cUser = SumDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum)
&& cMe->CheckData(answer, aUnitNum)
&& cUser.CheckData(answer, aUnitNum);
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
......@@ -106,9 +107,9 @@ bool TestSumDim1()
cUserGPU = SumDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum)
&& cMeGPU->CheckData(answer, aUnitNum)
&& cUserGPU.CheckData(answer, aUnitNum);
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
......@@ -139,7 +140,8 @@ bool TestSumDim1()
/*
case 2: tensor summation c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
i.e., a is summed with b by broadcasting.
In this case, (2, 4) + (2, 2) = (2, 4), n = 1.
*/
bool TestSumDim2()
{
......@@ -192,9 +194,9 @@ bool TestSumDim2()
cUser = SumDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum)
&& cMe->CheckData(answer, aUnitNum)
&& cUser.CheckData(answer, aUnitNum);
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
......@@ -219,9 +221,9 @@ bool TestSumDim2()
cUserGPU = SumDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum)
&& cMeGPU->CheckData(answer, aUnitNum)
&& cUserGPU.CheckData(answer, aUnitNum);
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
......
......@@ -99,8 +99,8 @@ bool TestUnsqueeze1()
tUser2 = Unsqueeze(*s, 2, 2);
/* check results */
cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1.CheckData(answer1, tUnitNum1)
&& t2->CheckData(answer2, tUnitNum2) && tUser2.CheckData(answer2, tUnitNum2);
cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1.CheckData(answer1, tUnitNum1) &&
t2->CheckData(answer2, tUnitNum2) && tUser2.CheckData(answer2, tUnitNum2);
#ifdef USE_CUDA
/* GPU test */
......
......@@ -38,6 +38,7 @@ bool Test()
wrong = !TestCopyIndexed() || wrong;
wrong = !TestCopyValues() || wrong;
wrong = !TestDiv() || wrong;
wrong = !TestDivDim() || wrong;
wrong = !TestExp() || wrong;
wrong = !TestLog() || wrong;
wrong = !TestMatrixMul() || wrong;
......@@ -46,6 +47,7 @@ bool Test()
wrong = !TestMatrixMulBatched() || wrong;
wrong = !TestMerge() || wrong;
wrong = !TestMultiply() || wrong;
wrong = !TestMultiplyDim() || wrong;
wrong = !TestNegate() || wrong;
wrong = !TestNormalize() || wrong;
wrong = !TestPower() || wrong;
......
......@@ -31,6 +31,7 @@
#include "TCopyIndexed.h"
#include "TCopyValues.h"
#include "TDiv.h"
#include "TDivDim.h"
#include "TExp.h"
#include "TLog.h"
#include "TMatrixMul.h"
......@@ -39,6 +40,7 @@
#include "TMatrixMulBatched.h"
#include "TMerge.h"
#include "TMultiply.h"
#include "TMultiplyDim.h"
#include "TNegate.h"
#include "TNormalize.h"
#include "TPower.h"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论