Commit f7f33b29 by xuchen

code optimization

parent e223c59c
...@@ -40,23 +40,50 @@ public: ...@@ -40,23 +40,50 @@ public:
bool IsMathOP(XTensor * node); bool IsMathOP(XTensor * node);
private: private:
/* gradient for sum: c = a + b * \beta */
/* gradient for absolute */
static static
void GradSum(XTensor * node); void GradAbsolute(XTensor * node);
/* gradient for sum with one dimension: c = a + b * \beta /* gradient for cos */
where the size of b is equal to that of one dimension of a */
static static
void GradSumDim(XTensor * node); void GradCos(XTensor * node);
/* gradient for multiply (dot production): c = a * b * \alpha */ /* gradient for exp */
static static
void GradMultiply(XTensor * node); void GradExp(XTensor * node);
/* gradient for multiply one dimension: c = a * b * \alpha /* gradient for log: c = log(a) */
where the size of b is equal to that of one dimension of a */
static static
void GradMultiplyDim(XTensor * node); void GradLog(XTensor * node);
/* gradient for round */
static
void GradRound(XTensor * node);
/* gradient for sign */
static
void GradSign(XTensor * node);
/* gradient for sin */
static
void GradSin(XTensor * node);
/* gradient for tan */
static
void GradTan(XTensor * node);
/* gradient for clip */
static
void GradClip(XTensor * node);
/* gradient for Divide */
static
void GradDiv(XTensor * node);
/* gradient for DivideDim */
static
void GradDivDim(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */ /* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static static
...@@ -73,18 +100,27 @@ private: ...@@ -73,18 +100,27 @@ private:
static static
void GradMatrixMulBatched(XTensor * node); void GradMatrixMulBatched(XTensor * node);
/* gradient for log: c = log(a) */ /* gradient for multiply (dot production): c = a * b * \alpha */
static static
void GradLog(XTensor * node); void GradMultiply(XTensor * node);
/* gradient for power */ /* gradient for multiply one dimension: c = a * b * \alpha
where the size of b is equal to that of one dimension of a */
static static
void GradPower(XTensor * node); void GradMultiplyDim(XTensor * node);
/* gradient for negate */ /* gradient for negate */
static static
void GradNegate(XTensor * node); void GradNegate(XTensor * node);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
/* gradient for power */
static
void GradPower(XTensor * node);
/* gradient for ScaleAndShift */ /* gradient for ScaleAndShift */
static static
void GradScaleAndShift(XTensor * node); void GradScaleAndShift(XTensor * node);
...@@ -93,13 +129,19 @@ private: ...@@ -93,13 +129,19 @@ private:
static static
void GradSub(XTensor * node); void GradSub(XTensor * node);
/* gradient for Divide */ /* gradient for sub with one dimension: c = a - b * \beta
where the size of b is equal to that of one dimension of a */
static static
void GradDiv(XTensor * node); void GradSubDim(XTensor * node);
/* gradient for DivideDim */ /* gradient for sum: c = a + b * \beta */
static static
void GradDivDim(XTensor * node); void GradSum(XTensor * node);
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSumDim(XTensor * node);
/* gradient for reduceMean */ /* gradient for reduceMean */
static static
...@@ -116,42 +158,6 @@ private: ...@@ -116,42 +158,6 @@ private:
/* gradient for reduceVariance */ /* gradient for reduceVariance */
static static
void GradReduceVariance(XTensor * node); void GradReduceVariance(XTensor * node);
/* gradient for sin */
static
void GradSin(XTensor * node);
/* gradient for cos */
static
void GradCos(XTensor * node);
/* gradient for tan */
static
void GradTan(XTensor * node);
/* gradient for exp */
static
void GradExp(XTensor * node);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
/* gradient for absolute */
static
void GradAbsolute(XTensor * node);
/* gradient for sign */
static
void GradSign(XTensor * node);
/* gradient for clip */
static
void GradClip(XTensor * node);
/* gradient for round */
static
void GradRound(XTensor * node);
}; };
} }
......
...@@ -137,8 +137,6 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss) ...@@ -137,8 +137,6 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
XTensor * x = income.tails[0]; XTensor * x = income.tails[0];
XNoder::MakeGrad(x); XNoder::MakeGrad(x);
lossGrad.Compute(gold, root, x, NULL, x->grad, funcID, params, loss); lossGrad.Compute(gold, root, x, NULL, x->grad, funcID, params, loss);
//XNoder::MakeGrad(root);
//lossGrad.Compute(gold, root, x, root->grad, x->grad, funcID, params, loss);
root->visitMark = NODE_FINISHED; root->visitMark = NODE_FINISHED;
} }
/* we compuate dE/dy (y is the output) if no predefined activation function is used */ /* we compuate dE/dy (y is the output) if no predefined activation function is used */
......
...@@ -35,6 +35,8 @@ T2TAttention::T2TAttention() ...@@ -35,6 +35,8 @@ T2TAttention::T2TAttention()
dk = -1; dk = -1;
dv = -1; dv = -1;
d = -1; d = -1;
isMasked = false;
ignored = 0;
} }
/* deconstructor */ /* deconstructor */
...@@ -46,13 +48,19 @@ T2TAttention::~T2TAttention() ...@@ -46,13 +48,19 @@ T2TAttention::~T2TAttention()
initialize the model initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myIgnored - number of position ignored in attention (from the begining)
>> myIsMasked - indicates whether the attention is with a mask
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool >> myMem - the memory pool
*/ */
void T2TAttention::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem) void T2TAttention::InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{ {
devID = myDevID; devID = myDevID;
mem = myMem; mem = myMem;
isMasked = myIsMasked;
ignored = myIgnored;
float minmax = 0; float minmax = 0;
...@@ -82,9 +90,10 @@ make the network ...@@ -82,9 +90,10 @@ make the network
and H = vector size of each position and H = vector size of each position
>> q - queries >> q - queries
>> v - values >> v - values
>> maske - as it is
<< return - multi-attention result << return - multi-attention result
*/ */
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v) XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
{ {
XTensor k2; XTensor k2;
XTensor q2; XTensor q2;
...@@ -105,10 +114,18 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v) ...@@ -105,10 +114,18 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
vheads = Split(v2, v2.order - 1, nhead); vheads = Split(v2, v2.order - 1, nhead);
XTensor att; XTensor att;
XTensor dot;
XTensor scalar; XTensor scalar;
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */ /* scalar = softmax(Q * K^T / sqrt(dk)) * V */
scalar = Softmax(Linear(BMMul(qheads, X_NOTRANS, kheads, X_TRANS), 1/(float)sqrt((float)dk)), -1); dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
if(isMasked)
dot = dot + mask;
scalar = Softmax(Linear(dot, 1/(float)sqrt((float)dk)), -1);
if(ignored > 0)
_SetDataDim(&scalar, 0, ignored, scalar.order - 2, 1e-9F);
att = BMMul(scalar, vheads); att = BMMul(scalar, vheads);
/* concatenate the heads */ /* concatenate the heads */
......
...@@ -66,6 +66,13 @@ public: ...@@ -66,6 +66,13 @@ public:
/* size of input Q, K and V */ /* size of input Q, K and V */
int d; int d;
/* indicates whether the attention is masked */
bool isMasked;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
special design for the attention model. */
int ignored;
public: public:
/* constructor */ /* constructor */
T2TAttention(); T2TAttention();
...@@ -74,10 +81,12 @@ public: ...@@ -74,10 +81,12 @@ public:
~T2TAttention(); ~T2TAttention();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL); void InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
/* make the network */ /* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v); XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask);
}; };
} }
......
...@@ -136,7 +136,7 @@ XTensor T2TEmbedder::Make(XTensor &input) ...@@ -136,7 +136,7 @@ XTensor T2TEmbedder::Make(XTensor &input)
wordEmbedding = Linear(MMul(input, w), (float)sqrt((float)d)); wordEmbedding = Linear(MMul(input, w), (float)sqrt((float)d));
/* we sum over the two embeddings */ /* we sum over the two embeddings */
return wordEmbedding +posEmbedding; return wordEmbedding + posEmbedding;
} }
} }
...@@ -46,13 +46,18 @@ AttEncoder::~AttEncoder() ...@@ -46,13 +46,18 @@ AttEncoder::~AttEncoder()
initialize the model initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id >> myDevID - device id
>> myMem - the memory pool >> myMem - the memory pool
*/ */
void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem) void AttEncoder::InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{ {
devID = myDevID; devID = myDevID;
mem = myMem; mem = myMem;
ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6); LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
...@@ -72,7 +77,7 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM ...@@ -72,7 +77,7 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM
/* initialize the stacked layers */ /* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){ for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, myDevID, myMem); attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem); fnns[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem); attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem); fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
...@@ -82,9 +87,11 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM ...@@ -82,9 +87,11 @@ void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myM
/* /*
make the encoding network make the encoding network
>> input - the input tensor of the encoder >> input - the input tensor of the encoder
>> mask - the mask that indicate each position is valid
>> skipInputRes - indicates whether we skip the residual connection of the first layer
<< return - the output tensor of the encoder << return - the output tensor of the encoder
*/ */
XTensor AttEncoder::Make(XTensor &input) XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes)
{ {
XTensor x; XTensor x;
...@@ -96,8 +103,18 @@ XTensor AttEncoder::Make(XTensor &input) ...@@ -96,8 +103,18 @@ XTensor AttEncoder::Make(XTensor &input)
XTensor fnn; XTensor fnn;
XTensor res; XTensor res;
if(skipInputRes && i == 0){
/* self attention */ /* self attention */
att = attentions[i].Make(x, x, x); att = attentions[i].Make(x, x, x, mask);
/* TODO: dropout */
/* layer normalization */
x = attLayerNorms[i].Make(att);
}
else{
/* self attention */
att = attentions[i].Make(x, x, x, mask);
/* residual connection */ /* residual connection */
res = Sum(att, x); res = Sum(att, x);
...@@ -106,6 +123,7 @@ XTensor AttEncoder::Make(XTensor &input) ...@@ -106,6 +123,7 @@ XTensor AttEncoder::Make(XTensor &input)
/* layer normalization */ /* layer normalization */
x = attLayerNorms[i].Make(res); x = attLayerNorms[i].Make(res);
}
/* fnn */ /* fnn */
fnn = fnns[i].Make(x); fnn = fnns[i].Make(x);
......
...@@ -40,7 +40,7 @@ class T2TEncoder ...@@ -40,7 +40,7 @@ class T2TEncoder
{ {
public: public:
virtual virtual
XTensor Make(XTensor &input) = 0; XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes) = 0;
}; };
/* /*
...@@ -49,7 +49,7 @@ the encoder based on RNN ...@@ -49,7 +49,7 @@ the encoder based on RNN
class RNNEncoder : T2TEncoder class RNNEncoder : T2TEncoder
{ {
public: public:
XTensor Make(XTensor &input); XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes);
}; };
...@@ -77,6 +77,10 @@ public: ...@@ -77,6 +77,10 @@ public:
/* vocabulary size */ /* vocabulary size */
int vSize; int vSize;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
special design for the attention model. */
int ignored;
/* embedding of word at each position */ /* embedding of word at each position */
T2TEmbedder embedder; T2TEmbedder embedder;
...@@ -106,10 +110,12 @@ public: ...@@ -106,10 +110,12 @@ public:
~AttEncoder(); ~AttEncoder();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL); void InitModel(int argc, const char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */ /* make the encoding network */
XTensor Make(XTensor &input); XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes);
}; };
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/ */
#include <math.h>
#include "T2TLayerNormal.h" #include "T2TLayerNormal.h"
#include "T2TUtility.h" #include "T2TUtility.h"
#include "T2TEmbedding.h" #include "T2TEmbedding.h"
...@@ -89,14 +90,13 @@ XTensor T2TLN::Make(XTensor &input) ...@@ -89,14 +90,13 @@ XTensor T2TLN::Make(XTensor &input)
/* standard = sqrt(variance) */ /* standard = sqrt(variance) */
standard = Power(variance, 0.5F); standard = Power(variance, 0.5F);
/* unsqueeze mean and standard deviation to fit them into /* unsqueeze mean and standard deviation to fit them into
the same shape of x */ the same shape of x */
meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1)); meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1));
standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1)); standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
/* x' = (x - \mu)/standard */ /* x' = (x - \mu)/standard */
xn = (x - meanFilled)/standardFilled ; xn = (x - meanFilled)/standardFilled;
/* result = x' * w + b */ /* result = x' * w + b */
return MMul(xn, w) + b; return MMul(xn, w) + b;
......
...@@ -34,6 +34,7 @@ T2TModel::T2TModel() ...@@ -34,6 +34,7 @@ T2TModel::T2TModel()
mem = NULL; mem = NULL;
isLM = false; isLM = false;
isMT = false; isMT = false;
nhead = 1;
} }
/* de-constructor */ /* de-constructor */
...@@ -55,24 +56,27 @@ void T2TModel::InitModel(int argc, const char ** argv) ...@@ -55,24 +56,27 @@ void T2TModel::InitModel(int argc, const char ** argv)
LoadParamBool(argc, argv, "mem", &useMem, useMem); LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamBool(argc, argv, "lm", &isLM, true); LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false); LoadParamBool(argc, argv, "mt", &isMT, false);
LoadParamInt(argc, argv, "nhead", &nhead, 8);
if(useMem){ if(useMem){
delete mem; delete mem;
mem = new XMem(devID); mem = new XMem(devID);
} }
encoder.InitModel(argc, argv, devID, mem); encoder.InitModel(argc, argv, isLM, isLM ? 1 : 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem); outputLayer.InitModel(argc, argv, devID, mem);
} }
/* /*
make the encoding network make the encoding network
>> input - input tensor >> input - input tensor
>> mask - the mask for positions that are/not involved in computation
>> skipInputRes - indicates whether we skip the residual connection of the first layer
<< return - encoding result << return - encoding result
*/ */
XTensor T2TModel::MakeEncoding(XTensor &input) XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes)
{ {
return encoder.Make(input); return encoder.Make(input, mask, skipInputRes);
} }
/* /*
...@@ -85,8 +89,23 @@ void T2TModel::Make(XTensor &input, XTensor &output) ...@@ -85,8 +89,23 @@ void T2TModel::Make(XTensor &input, XTensor &output)
XTensor encoding; XTensor encoding;
if(isLM){ if(isLM){
encoding = MakeEncoding(input); /* generate mask to see "previous" words only */
int len = input.GetDim(input.order - 2);
int * dims = new int[input.order + 1];
for(int i = 0; i < input.order; i++)
dims[i + 1] = input.GetDim(i);
dims[0] = nhead;
dims[input.order] = len;
XTensor mask(input.order + 1, dims, X_FLOAT, 1.0F, input.devID, input.mem);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9 */
_SetDataLowTri(&mask, 1e9F, -1);
_ScaleAndShiftMe(&mask, 1.0F, -1e9F);
encoding = MakeEncoding(input, mask, true);
outputLayer.Make(encoding, output); outputLayer.Make(encoding, output);
delete[] dims;
} }
else{ else{
ShowNTErrors("TODO!"); ShowNTErrors("TODO!");
......
...@@ -55,6 +55,9 @@ public: ...@@ -55,6 +55,9 @@ public:
/* indicates whether the model is running for machine translation */ /* indicates whether the model is running for machine translation */
bool isMT; bool isMT;
/* number of heads in the attention model */
int nhead;
public: public:
/* constructor */ /* constructor */
T2TModel(); T2TModel();
...@@ -66,7 +69,7 @@ public: ...@@ -66,7 +69,7 @@ public:
void InitModel(int argc, const char ** argv); void InitModel(int argc, const char ** argv);
/* make the encoding network */ /* make the encoding network */
XTensor MakeEncoding(XTensor &input); XTensor MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes);
/* make the entire network (with the output softmax layer) */ /* make the entire network (with the output softmax layer) */
void Make(XTensor &input, XTensor &output); void Make(XTensor &input, XTensor &output);
......
...@@ -100,7 +100,9 @@ void ShowParams(int argc, const char ** argv) ...@@ -100,7 +100,9 @@ void ShowParams(int argc, const char ** argv)
{ {
fprintf(stderr, "args:\n"); fprintf(stderr, "args:\n");
for(int i = 0; i < argc; i++){ for(int i = 0; i < argc; i++){
if(argv[i][0] == '-'){ if(argv[i][1] == 0)
continue;
if(argv[i][0] == '-' && (argv[i][1] < '1' || argv[i][1] > '9')){
if(i + 1 < argc && argv[i + 1][0] != '-') if(i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]); fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else else
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "../../XName.h" #include "../../XName.h"
#include "Div.h" #include "Div.h"
#include "Div.cuh" #include "Div.cuh"
#include "DivDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
...@@ -138,6 +139,33 @@ void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim) ...@@ -138,6 +139,33 @@ void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
} }
/* /*
return a dimension if the division is performed as DivDim (in more details in DivDim.h)
>> a - a tensor
>> b - another tensor for division
*/
int GetDivDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/*
element-wise division of two tensors (return a XTensor structure) element-wise division of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it make a new tensor c to keep the result and return it
...@@ -146,22 +174,40 @@ where i is the index of the item ...@@ -146,22 +174,40 @@ where i is the index of the item
>> a - tensor a >> a - tensor a
>> b - tensor b >> b - tensor b
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting >> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors << return - the product of the tensors
*/ */
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim) XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
{ {
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a); XTensor c(&a);
c.SetTMP(); c.SetTMP();
/* call _Multiply function */ int n = GetDivDimIndex(a, b);
_Div(&a, &b, &c, 0, leadingDim);
if(n == -1){
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Div function */
_Div(&a, &b, &c, alpha, leadingDim);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV); XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim); XLink::AddParamToHeadInt(&c, leadingDim);
}
else if(n >= 0 && n < a.order){
/* call _DivDim function */
_DivDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadInt(&c, alpha);
}
else{
ShowNTErrors("Something is wrong!");
}
return c; return c;
} }
......
...@@ -47,7 +47,7 @@ make a new tensor to keep the result and return it ...@@ -47,7 +47,7 @@ make a new tensor to keep the result and return it
c(i) = a(i)/b(i) c(i) = a(i)/b(i)
where i is the index of the element where i is the index of the element
*/ */
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim = 0); XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "../../XName.h" #include "../../XName.h"
#include "Multiply.h" #include "Multiply.h"
#include "Multiply.cuh" #include "Multiply.cuh"
#include "MultiplyDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
...@@ -139,6 +140,33 @@ void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim) ...@@ -139,6 +140,33 @@ void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
} }
/* /*
return a dimension if the multiplication is performed as MultiplyDim (in more details in MultiplyDim.h)
>> a - a tensor
>> b - another tensor for multiplication
*/
int GetMultiplyDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/*
element-wise product of two tensors (return a XTensor structure) element-wise product of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it make a new tensor c to keep the result and return it
...@@ -150,19 +178,37 @@ where i is the index of the item ...@@ -150,19 +178,37 @@ where i is the index of the item
>> leadingDim - the dimension along which we perform broadcasting >> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors << return - the product of the tensors
*/ */
XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim) XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
{ {
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a); XTensor c(&a);
c.SetTMP(); c.SetTMP();
int n = GetMultiplyDimIndex(a, b);
if(n == -1){
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Multiply function */ /* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim); _Multiply(&a, &b, &c, 0, leadingDim);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY); XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim); XLink::AddParamToHeadInt(&c, leadingDim);
}
else if(n >= 0 && n < a.order){
/* call _MultiplyDim function */
_MultiplyDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadInt(&c, alpha);
}
else{
ShowNTErrors("Something is wrong!");
}
return c; return c;
} }
......
...@@ -47,7 +47,7 @@ make a new tensor to keep the result and return it ...@@ -47,7 +47,7 @@ make a new tensor to keep the result and return it
c(i) = a(i)*b(i) c(i) = a(i)*b(i)
where i is the index of the element where i is the index of the element
*/ */
XTensor Multiply(const XTensor &a, const XTensor &b, int leadingDim = 0); XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -135,15 +135,14 @@ int GetSubDimIndex(const XTensor &a, const XTensor &b) ...@@ -135,15 +135,14 @@ int GetSubDimIndex(const XTensor &a, const XTensor &b)
if(a.order < b.order) if(a.order < b.order)
return -1; return -1;
if(XTensor::IsSameShaped(&a, &b))
return -1;
int hitCount = 0; int hitCount = 0;
int hitDim = -1; int hitDim = -1;
for(int i = 0; i < a.order; i++){ for(int i = 0; i < b.order; i++){
if(a.dimSize[i] == b.unitNum){ if(b.dimSize[b.order - 1 - i] == 1)
hitDim = i; continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++; hitCount++;
hitDim = a.order - b.order + i;
} }
} }
...@@ -173,7 +172,6 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta) ...@@ -173,7 +172,6 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
/* call _Sub function */ /* call _Sub function */
_Sub(&a, &b, &c, beta); _Sub(&a, &b, &c, beta);
/* tensor connections */ /* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB); XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta); XLink::AddParamToHead(&c, beta);
......
...@@ -137,37 +137,17 @@ return a dimension if the sum is performed as SumDim (in more details in SumDim. ...@@ -137,37 +137,17 @@ return a dimension if the sum is performed as SumDim (in more details in SumDim.
*/ */
int GetSumDimIndex(const XTensor &a, const XTensor &b) int GetSumDimIndex(const XTensor &a, const XTensor &b)
{ {
//if(a.order < b.order)
// return -1;
//int hitCount = 0;
//int hitDim = -1;
//for(int i = 0; i < b.order; i++){
// if(b.dimSize[b.order - 1 - i] == 1)
// continue;
// else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
// hitCount++;
// hitDim = a.order - b.order + i;
// }
//}
//if(hitCount == 1)
// return hitDim;
//else
// return -1;
if(a.order < b.order) if(a.order < b.order)
return -1; return -1;
if(XTensor::IsSameShaped(&a, &b))
return -1;
int hitCount = 0; int hitCount = 0;
int hitDim = -1; int hitDim = -1;
for(int i = 0; i < a.order; i++){ for(int i = 0; i < b.order; i++){
if(a.dimSize[i] == b.unitNum){ if(b.dimSize[b.order - 1 - i] == 1)
hitDim = i; continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++; hitCount++;
hitDim = a.order - b.order + i;
} }
} }
......
...@@ -49,7 +49,7 @@ void _SelectRange(const XTensor * a, XTensor * c, int dim, int low, int high) ...@@ -49,7 +49,7 @@ void _SelectRange(const XTensor * a, XTensor * c, int dim, int low, int high)
for(int i = 0; i < a->order; i++){ for(int i = 0; i < a->order; i++){
if(i == dim){ if(i == dim){
CheckNTErrors(low > 0 && low < a->dimSize[dim], "Illegal range specified!"); CheckNTErrors(low >= 0 && low < a->dimSize[dim], "Illegal range specified!");
CheckNTErrors(high > 0 && high <= a->dimSize[dim], "Illegal range specified!"); CheckNTErrors(high > 0 && high <= a->dimSize[dim], "Illegal range specified!");
} }
else{ else{
...@@ -101,7 +101,7 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high) ...@@ -101,7 +101,7 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
for(int i = 0; i < a.order; i++){ for(int i = 0; i < a.order; i++){
if(i == dim){ if(i == dim){
CheckNTErrors(low > 0 && low < a.dimSize[dim], "Illegal range specified!"); CheckNTErrors(low >= 0 && low < a.dimSize[dim], "Illegal range specified!");
CheckNTErrors(high > 0 && high <= a.dimSize[dim], "Illegal range specified!"); CheckNTErrors(high > 0 && high <= a.dimSize[dim], "Illegal range specified!");
dimSize[i] = high - low; dimSize[i] = high - low;
} }
......
...@@ -214,6 +214,106 @@ void _SetDataFixedDouble(XTensor * tensor, double p) ...@@ -214,6 +214,106 @@ void _SetDataFixedDouble(XTensor * tensor, double p)
} }
/* /*
set data items along with a given dimension (and keep the remaining items unchanged)
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length along with the given dimension
>> dim - the dimension along which we set the data
e.g., given a 3 * 3 tensor
1 2 3
4 5 6
7 8 9
when beg = 1, len = 1, dim = 0 and p = 0, we have
1 2 3
0 0 0
7 8 9
i.e., we set all entries of row 1 to 0
*/
void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim > 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
if(tensor->devID < 0){
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = n - 1; i > dim; i--){
stride *= tensor->GetDim(i);
}
blockSize = stride * tensor->GetDim(dim);
blockNum = tensor->unitNum / blockSize;
int l = len * stride;
for(int i = 0; i < blockNum; i++){
DTYPE * d = (DTYPE*)tensor->data + blockSize * i + beg * stride;
for(int j = 0; j < l; j++)
d[j] = p;
}
}
else{
#ifdef USE_CUDA
_CudaSetDataDim(tensor, beg, len, dim, p);
#endif
}
}
/*
generate data as lower triangular matrics for last two dimensions
>> tensor - the tensor whose data to be set
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
1 1 1
when p = 2 and shift = -1, we have
0 0 0
2 0 0
2 2 0
*/
void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(n >= 2, "The tensor must have a order no less than 2!");
CheckNTErrors(tensor->GetDim(n - 1) == tensor->GetDim(n - 2),
"The last two dimensions must be of the same size!");
if(tensor->devID < 0){
int l = tensor->GetDim(-1);
int blockNum = 1;
int blockSize = l * l;
for(int i = 0; i < n - 2; i++)
blockNum *= tensor->GetDim(i);
for(int i = 0; i < blockNum; i++){
DTYPE * d = (DTYPE*)tensor->data + i * blockSize;
for(int row = 0; row < l; row++){
for(int col = 0; col <= row + shift; col++){
d[row * l + col] = p;
}
for(int col = MAX(0, row + shift + 1); col < l; col++){
d[row * l + col] = 0;
}
}
}
}
else{
#ifdef USE_CUDA
_CudaSetDataLowTri(tensor, p, shift);
#endif
}
}
/*
generate data items with a uniform distribution in [lower, upper] generate data items with a uniform distribution in [lower, upper]
>> tensor - the tensor whose data array would be initialized >> tensor - the tensor whose data array would be initialized
>> lower - lower value of the range >> lower - lower value of the range
......
...@@ -185,6 +185,169 @@ void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance) ...@@ -185,6 +185,169 @@ void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
} }
/* /*
set data items along with a given dimension (and keep the remaining items unchanged) - kernel version
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length of the segment to be set
>> blockSize - size of a data block
>> blockNum - number of data blocks
*/
__global__
void KernelSetDataDim(DTYPE * d, int beg, int len, int blockSize, int blockNum, DTYPE p)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* block id */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockSize || j > blockNum)
return;
if(i < beg || i >= beg + len)
return;
d[blockSize * j + i] = p;
}
/*
set data items along with a given dimension (and keep the remaining items unchanged) - cuda version
>> tensor - the tensor whose data array would be initialized
>> beg - the beginning position
>> len - length along with the given dimension
>> dim - the dimension along which we set the data
e.g., given a 3 * 3 tensor
1 2 3
4 5 6
7 8 9
when beg = 1, len = 1, dim = 0 and p = 0, we have
1 2 3
0 0 0
7 8 9
i.e., we set all entries of row 1 to 0
*/
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim > 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = n - 1; i > dim; i--){
stride *= tensor->GetDim(i);
}
blockSize = stride * tensor->GetDim(dim);
blockNum = tensor->unitNum / blockSize;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(tensor->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataDim<<<blocks, threads >>>((DTYPE*)tensor->data, beg * stride, len * stride, blockSize, blockNum, p);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
set lower triangular matrics for each block
>> d - pointer to the data array
>> l - row number (or column number) of each block, i.e,
a block is l * l matrix
>> blockSize - size of each block (blockSize = l * l)
>> blockNum - number of the blocks
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
1 1 1
when p = 2 and shift = -1, we have
0 0 0
2 0 0
2 2 0
*/
__global__
void _KernelSetDataLowTri(DTYPE * d, int l, int blockSize, int blockNum, DTYPE p, int shift)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* block id */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockSize || j > blockNum)
return;
int row = i / l;
int col = i % l;
DTYPE * d2 = d + blockSize * j + row * l + col;
if(col <= row + shift)
*d2 = p;
else
*d2 = 0;
}
/*
generate data as lower triangular matrics for last two dimensions (cuda version)
>> tensor - the tensor whose data to be set
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
1 1 1
when p = 2 and shift = -1, we have
0 0 0
2 0 0
2 2 0
*/
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift)
{
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(n >= 2, "The tensor must have a order no less than 2!");
CheckNTErrors(tensor->GetDim(n - 1) == tensor->GetDim(n - 2),
"The last two dimensions must be of the same size!");
int l = tensor->GetDim(-1);
int blockNum = 1;
int blockSize = l * l;
for(int i = 0; i < n - 2; i++)
blockNum *= tensor->GetDim(i);
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(tensor->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
_KernelSetDataLowTri<<<blocks, threads >>>((DTYPE*)tensor->data, l, blockSize, blockNum, p, shift);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
generate data items with a uniform distribution in [lower, upper] generate data items with a uniform distribution in [lower, upper]
>> tensor - the tensor whose data array would be initialized >> tensor - the tensor whose data array would be initialized
>> lower - lower value of the range >> lower - lower value of the range
......
...@@ -37,6 +37,12 @@ void _CudaSetDataFixedFloat(XTensor * tensor, float p); ...@@ -37,6 +37,12 @@ void _CudaSetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */ /* generate data items with a fixed value p (in double) */
void _CudaSetDataFixedDouble(XTensor * tensor, double p); void _CudaSetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* generate data as lower triangular matrics for last two dimensions (cuda version) */
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [lower, upper] */ /* generate data items with a uniform distribution in [lower, upper] */
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper); void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
......
...@@ -45,6 +45,12 @@ void _SetDataFixedFloat(XTensor * tensor, float p); ...@@ -45,6 +45,12 @@ void _SetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */ /* generate data items with a fixed value p (in double) */
void _SetDataFixedDouble(XTensor * tensor, double p); void _SetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* generate data as lower triangular matrics for last two dimensions */
void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [lower, upper] */ /* generate data items with a uniform distribution in [lower, upper] */
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper); void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#include "TSubDim.h"
#include "../core/arithmetic/SubDim.h"
#include "../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor subtraction c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
*/
bool TestSubDim1()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2] = {1.0F, -1.0F};
DTYPE answer[2][4] = { {-1.0F, 0.0F, 1.0F, 2.0F},
{5.0F, 6.0F, 7.0F, 8.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call SubDim function */
_SubDim(a, b, c, 0);
_SubDim(cMe, b, 0);
cUser = SubDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call sub function */
_SubDim(aGPU, bGPU, cGPU, 0);
_SubDim(cMeGPU, bGPU, 0);
cUserGPU = SubDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: tensor subtraction c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
*/
bool TestSubDim2()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 2) */
int bOrder = 2;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2][2] = { {1.0F, -1.0F},
{-1.0F, 1.0F} };
DTYPE answer[2][4] = { {-1.0F, 2.0F, 3.0F, 2.0F},
{3.0F, 6.0F, 7.0F, 6.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call SubDim function */
_SubDim(a, b, c, 1);
_SubDim(cMe, b, 1);
cUser = SubDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum) &&
cMe->CheckData(answer, aUnitNum) &&
cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call sub function */
_SubDim(aGPU, bGPU, cGPU, 1);
_SubDim(cMeGPU, bGPU, 1);
cUserGPU = SubDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum) &&
cMeGPU->CheckData(answer, aUnitNum) &&
cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for SubDim Function */
bool TestSubDim()
{
XPRINT(0, stdout, "[TEST SUBDIM] tensor subtraction c = a - b * beta by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestSubDim1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestSubDim2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-13
*/
#ifndef __TEST_SUBDIM_H__
#define __TEST_SUBDIM_H__
#include "../core/arithmetic/SubDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for SubDim Function */
bool TestSubDim();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_SUBDIM_H__
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论