Commit ceb5b101 by xuchen


1. add gather function
2. add cross entropy forward computation and backward computation
3. code optimization
4. merge with xiaotong-working branch
parent 102db468
......@@ -29,10 +29,8 @@
namespace nts{
/* compute dE/dx of a node */
void XFuncGrad::MakeGrad(XTensor * node)
void XFuncGrad::MakeGrad(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
int operID = income.typeID;
......
......@@ -35,7 +35,7 @@ class XFuncGrad
public:
/* compute dE/dx of a node */
static
void MakeGrad(XTensor * node);
void MakeGrad(XTensor * node, bool isEfficient);
/* indicates whether the node is for an activation function */
static
......
......@@ -33,7 +33,7 @@ class XMathGrad
public:
/* compute dE/dx of a node */
static
void MakeGrad(XTensor * node);
void MakeGrad(XTensor * node, bool isEfficient);
/* indicates whether the node is for a math operation */
static
......@@ -43,121 +43,121 @@ private:
/* gradient for absolute */
static
void GradAbsolute(XTensor * node);
void GradAbsolute(XTensor * node, bool isEfficient);
/* gradient for cos */
static
void GradCos(XTensor * node);
void GradCos(XTensor * node, bool isEfficient);
/* gradient for exp */
static
void GradExp(XTensor * node);
void GradExp(XTensor * node, bool isEfficient);
/* gradient for log: c = log(a) */
static
void GradLog(XTensor * node);
void GradLog(XTensor * node, bool isEfficient);
/* gradient for round */
static
void GradRound(XTensor * node);
void GradRound(XTensor * node, bool isEfficient);
/* gradient for sign */
static
void GradSign(XTensor * node);
void GradSign(XTensor * node, bool isEfficient);
/* gradient for sin */
static
void GradSin(XTensor * node);
void GradSin(XTensor * node, bool isEfficient);
/* gradient for tan */
static
void GradTan(XTensor * node);
void GradTan(XTensor * node, bool isEfficient);
/* gradient for clip */
static
void GradClip(XTensor * node);
void GradClip(XTensor * node, bool isEfficient);
/* gradient for Divide */
static
void GradDiv(XTensor * node);
void GradDiv(XTensor * node, bool isEfficient);
/* gradient for DivideDim */
static
void GradDivDim(XTensor * node);
void GradDivDim(XTensor * node, bool isEfficient);
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * node);
void GradMatrixMul(XTensor * node, bool isEfficient);
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
XTensor * dedc, DTYPE alpha);
XTensor * dedc, DTYPE alpha, bool isEfficient);
/* gradient for matrix multiply in batch mode.
for each batch: c_i = matmul(a_i, b_i) * \alpha */
static
void GradMatrixMulBatched(XTensor * node);
void GradMatrixMulBatched(XTensor * node, bool isEfficient);
/* gradient for multiply (dot product): c = a * b * \alpha */
static
void GradMultiply(XTensor * node);
void GradMultiply(XTensor * node, bool isEfficient);
/* gradient for multiply one dimension: c = a * b * \alpha
where the size of b is equal to that of one dimension of a */
static
void GradMultiplyDim(XTensor * node);
void GradMultiplyDim(XTensor * node, bool isEfficient);
/* gradient for negate */
static
void GradNegate(XTensor * node);
void GradNegate(XTensor * node, bool isEfficient);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
void GradNormalize(XTensor * node, bool isEfficient);
/* gradient for power */
static
void GradPower(XTensor * node);
void GradPower(XTensor * node, bool isEfficient);
/* gradient for ScaleAndShift */
static
void GradScaleAndShift(XTensor * node);
void GradScaleAndShift(XTensor * node, bool isEfficient);
/* gradient for Minus */
static
void GradSub(XTensor * node);
void GradSub(XTensor * node, bool isEfficient);
/* gradient for sub with one dimension: c = a - b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSubDim(XTensor * node);
void GradSubDim(XTensor * node, bool isEfficient);
/* gradient for sum: c = a + b * \beta */
static
void GradSum(XTensor * node);
void GradSum(XTensor * node, bool isEfficient);
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSumDim(XTensor * node);
void GradSumDim(XTensor * node, bool isEfficient);
/* gradient for reduceMean */
static
void GradReduceMean(XTensor * node);
void GradReduceMean(XTensor * node, bool isEfficient);
/* gradient for reduceSum */
static
void GradReduceSum(XTensor * node);
void GradReduceSum(XTensor * node, bool isEfficient);
/* gradient for reduceSumSquared */
static
void GradReduceSumSquared(XTensor * node);
void GradReduceSumSquared(XTensor * node, bool isEfficient);
/* gradient for reduceVariance */
static
void GradReduceVariance(XTensor * node);
void GradReduceVariance(XTensor * node, bool isEfficient);
};
}
......
......@@ -34,7 +34,7 @@ class XShapeGrad
public:
/* compute dE/dx of a node */
static
void MakeGrad(XTensor * node);
void MakeGrad(XTensor * node, bool isEfficent);
/* indicates whether the node is for a shaping operation */
static
......@@ -42,39 +42,48 @@ public:
/* post processing of a node */
static
void PostProcessing(XTensor * node, int typeId);
void PostProcessing(XTensor * node, int typeId, bool isEfficent);
private:
/* gradient computation for copying indexed sub-tensors: b = copyindexed(a, srcIndex, indexSize, tgtIndex, copyNum) */
static
void GradCopyIndexed(XTensor * node, bool isEfficent);
/* gradient computation for merge: c = merge(a, b, ...) */
static
void GradMerge(XTensor * node);
void GradMerge(XTensor * node, bool isEfficent);
/* gradient computation for merging a list of tensors : c = merge(list(a, b, ...)) */
static
void GradMergeList(XTensor * node);
void GradMergeList(XTensor * node, bool isEfficent);
/* gradient computation for transposing a tensor : b = transpose(a) */
static
void GradTranspose(XTensor * node, bool isEfficent);
/* gradient computation for reshaping a tensor: c = reshape(a) */
static
void GradReshape(XTensor * node, bool isEfficent);
/* gradient computation for split: c = split(a) */
static
void GradSplit(XTensor * node);
void GradSplit(XTensor * node, bool isEfficent);
/* gradient computation for splitting. we return the list of the splits : list(c_1, ...) = split(a) */
static
void GradSplitList(XTensor * node);
void GradSplitList(XTensor * node, bool isEfficent);
/* gradient computation for splitting. we return the list of the splits : list(c_1, ...) = split(a).
this method is called only when all nodes of splitting have been processed. We do this in a post-processing
manner because we can fuse multiple memory-copy jobs at one time, which is good for speed. */
static
void GradSplitListPost(XTensor * node);
void GradSplitListPost(XTensor * node, bool isEfficent);
/* gradient computation for unsqueezing a tensor : c = unsqueeze(a) */
static
void GradUnsqueeze(XTensor * node);
void GradUnsqueeze(XTensor * node, bool isEfficent);
/* gradient computation for transposing a tensor : b = transpose(a) */
static
void GradTranspose(XTensor * node);
};
}
......
......@@ -55,6 +55,7 @@ void XNetClearAll()
XNet::XNet()
{
nodes.Clear();
isGradEfficient = true;
}
/* de-constructor */
......@@ -115,6 +116,10 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
{
Traverse(roots);
/* label tensors where the backward computation is necessary */
if(isGradEfficient)
MakeEfficientNet();
for(int i = 0; i < nodes.count; i++){
XTensor * node = (XTensor*)nodes.Get(i);
node->visitMark = NODE_UNFINISHED;
......@@ -154,10 +159,19 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
CheckNTErrors(node->mem->bufUsed < BUF_PITCH, "Illegal access of buffer!");
}
if(node->visitMark == NODE_FINISHED)
continue;
if(node->visitMark != NODE_FINISHED)
BackwardNode(node, isGradEfficient);
if(isGradEfficient){
XLink & outgo = node->outgo;
for(int i = 0; i < outgo.tailNum; i++){
XTensor * parent = outgo.tails[i];
ClearGrad(parent);
}
BackwardNode(node);
if(XNoder::IsLeaf(node))
ClearGrad(node);
}
}
}
......@@ -179,27 +193,32 @@ void XNet::Backward(XList &roots, LOSS_FUNCTION_NAME loss)
/*
backward computation for a given node
>> node - the node that keeps the result of an operation (e.g., activation function)
>> isEfficient - indicates whether the back-propagation is computed in an
efficient manner
*/
void XNet::BackwardNode(XTensor * node)
void XNet::BackwardNode(XTensor * node, bool isEfficent)
{
if(node == NULL || node->visitMark == NODE_FINISHED)
return;
if(!XNoder::IsLeaf(node)){
/* post processing for parent nodes */
BackwardNodePost(node);
BackwardNodePost(node, isEfficent);
/* process the current node */
if(XMathGrad::IsMathOP(node))
XMathGrad::MakeGrad(node);
XMathGrad::MakeGrad(node, isEfficent);
else if(XFuncGrad::IsFunc(node))
XFuncGrad::MakeGrad(node);
XFuncGrad::MakeGrad(node, isEfficent);
else if(XShapeGrad::IsShapeOP(node))
XShapeGrad::MakeGrad(node);
XShapeGrad::MakeGrad(node, isEfficent);
else{
ShowNTErrors("Wrong node type!");
}
}
else{
node->visitMark = NODE_FINISHED;
}
}
/*
......@@ -207,7 +226,7 @@ backward computation (in post processing) for a given node
>> node - the node whose parent nodes are not processed yet. So
we do the job at the child node.
*/
void XNet::BackwardNodePost(XTensor * node)
void XNet::BackwardNodePost(XTensor * node, bool isEfficent)
{
bool isSplitList = false;
XLink &outgo = node->outgo;
......@@ -217,7 +236,7 @@ void XNet::BackwardNodePost(XTensor * node)
}
if(isSplitList)
XShapeGrad::PostProcessing(node, SHAPE_SPLIT_LIST);
XShapeGrad::PostProcessing(node, SHAPE_SPLIT_LIST, isEfficent);
}
/*
......@@ -284,6 +303,8 @@ void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
node->visitMark = code + 2;
orders.Add(node);
}
else if(node->visitMark == code + 2){
}
}
/*
......@@ -304,4 +325,62 @@ void XNet::Dump(FILE * file)
}
}
/*
set the gradient-efficient flag
>> flag - the flag
*/
void XNet::SetGradEfficientFlag(bool flag)
{
isGradEfficient = flag;
}
/* generate the gradient-efficient flag for every node */
void XNet::MakeEfficientNet()
{
/* back-propagation from output to input */
for(int i = 0; i < nodes.count; i++){
XTensor * node = (XTensor*)nodes.Get(i);
XLink &income = node->income;
for(int j = 0; j < income.tailNum; j++){
XTensor * child = income.tails[j];
if(child->isGrad || child->isVar){
node->SetGradFlag(true);
break;
}
}
}
}
/*
clear the gradient information if the node is no longer used
>> node - the node that we want to clear
*/
void XNet::ClearGrad(XTensor * node)
{
if(node->isVar)
return;
if(node->grad == NULL)
return;
if(node->visitMark != NODE_FINISHED)
return;
XLink & income = node->income;
bool finished = true;
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
if(child->visitMark != NODE_FINISHED){
finished = false;
break;
}
}
if(finished){
//fprintf(stderr, "del %d %ld\n", node->id, node->grad->unitNum);
delete node->grad;
node->grad = NULL;
}
}
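Taken together, SetGradEfficientFlag(), MakeEfficientNet() and ClearGrad() implement the gradient-efficient mode: only tensors on a path to a variable keep their gradients, and the rest are freed as soon as all of their children are finished. A minimal sketch of the intended call pattern (the caller code below, including the tensor names, is hypothetical):
XNet net;
net.SetGradEfficientFlag(true);          /* already the default (see the constructor above) */
weight.SetVarFlag();                     /* parameters are variables, so their gradients are kept */
XTensor output = LogSoftmax(MMul(input, weight), 1);
net.Backward(output, gold, CROSSENTROPY);
/* here weight.grad is valid, while the gradients of intermediate
   nodes have already been deleted by ClearGrad() */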
}
\ No newline at end of file
......@@ -47,6 +47,9 @@ struct XNet
/* input nodes of the network */
XList inputs;
/* indicates whether the network just keeps the gradient for parameter tensors */
bool isGradEfficient;
/* constructor */
XNet();
......@@ -71,10 +74,10 @@ struct XNet
void Backward(XList &roots, LOSS_FUNCTION_NAME loss = NOLOSS);
/* backward computation for a given node */
void BackwardNode(XTensor * node);
void BackwardNode(XTensor * node, bool isEfficent = false);
/* backward computation (in post processing) for a given node */
void BackwardNodePost(XTensor * node);
void BackwardNodePost(XTensor * node, bool isEfficent = false);
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
......@@ -89,6 +92,15 @@ struct XNet
/* dump network information */
void Dump(FILE * file);
/* set the gradient-efficient flag */
void SetGradEfficientFlag(bool flag = true);
/* generate the gradient-efficient flag for every node */
void MakeEfficientNet();
/* clear the gradient information if the node is no longer used */
void ClearGrad(XTensor * node);
};
/* we make a unique id for every tensor */
......
......@@ -74,6 +74,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net);
void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME loss,
FNNModel &model, FNNModel &grad, FNNNet &net);
void ForwardAutoDiff(XTensor inputs[], XTensor &output, FNNModel &model);
void ForwardAutoDiff(NGram * ngrams, int batch, XTensor &output, FNNModel &model);
/*
entry of the program
......@@ -476,7 +477,12 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
Clear(model, true);
/* forward + backward process */
ForwardAutoDiff(inputs, output, model);
/* this is implemented by the gather function */
ForwardAutoDiff(ngrams, ngramNum, output, model);
/* this is implemented by the multiply function */
//ForwardAutoDiff(inputs, output, model);
/* automatic differentiation */
autoDiffer.Backward(output, gold, CROSSENTROPY);
......@@ -975,7 +981,55 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
}
/*
forward process (with tensor connections)
forward process (with tensor connections), implemented with the gather function
>> ngrams - the loaded ngrams
>> batch - the number of ngrams in the batch
>> output - output probability
>> model - the fnn model
*/
void ForwardAutoDiff(NGram * ngrams, int batch, XTensor &output, FNNModel &model)
{
int n = model.n;
int depth = model.hDepth;
XTensor words;
XTensor embeddingBig;
XTensor hidden;
XTensor b;
int size = batch * (n-1);
int * index = new int[size];
for(int i = 0; i < batch; i++){
for (int j = 0; j < n-1; j++){
int a = i * (n - 1) + j;
index[a] = ngrams[i].words[j];
}
}
XTensor embedding;
embedding = Gather(model.embeddingW, 0, index, size);
delete[] index;
int dimSize[2];
dimSize[0] = embedding.GetDim(0) / (n - 1);
dimSize[1] = embedding.GetDim(1) * (n - 1);
hidden = Reshape(embedding, embedding.order, dimSize);
/* hidden layers */
for(int i = 0; i < depth; i++)
hidden = MMul(hidden, model.hiddenW[i]) + model.hiddenB[i];
/* output layer */
output = LogSoftmax(MMul(hidden, model.outputW) + model.outputB, 1);
//XLink::ShowNetwork(stderr, &output);
}
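To make the index construction concrete, here is a toy walk-through with hypothetical values (batch = 2, n = 3):
/* ngrams[0].words = {5, 7, ...}, ngrams[1].words = {2, 9, ...}
   => index = {5, 7, 2, 9} with size = batch * (n - 1) = 4;
   Gather(model.embeddingW, 0, index, 4) returns a (4, eSize) tensor whose
   rows are the embeddings of words 5, 7, 2 and 9;
   Reshape(embedding, 2, dimSize) with dimSize = {2, 2 * eSize} then packs
   the (n - 1) context embeddings of each ngram into a single row */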
/*
forward process (with tensor connections), implemented with the multiply function
>> inputs - input word representations
>> output - output probability
>> model - the fnn model
......@@ -1122,8 +1176,12 @@ void Test(const char * test, const char * result, FNNModel &model)
/* forward computation */
Forward(inputs, output, model, net);
}
else {
ForwardAutoDiff(inputs, output, model);
else {
/* this is implemented by the gather function */
ForwardAutoDiff(ngrams, ngramNum, output, model);
/* this is implemented by the multiply function */
//ForwardAutoDiff(inputs, output, model);
}
/* prediction probabilities */
......
......@@ -53,7 +53,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, const char ** argv,
void T2TAttention::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{
......@@ -69,18 +69,22 @@ void T2TAttention::InitModel(int argc, const char ** argv,
LoadParamInt(argc, argv, "d", &dv, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "attminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutatt", &dropoutP, 0);
InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
InitTensor2D(&wq, d, dk, X_FLOAT, devID, mem);
InitTensor2D(&wv, d, dv, X_FLOAT, devID, mem);
InitTensor2D(&wa, d, d, X_FLOAT, devID, mem);
float scale = 1.0F;
float finfoutk = (float)sqrt(6.0F * scale/(d + dk));
float finfoutv = (float)sqrt(6.0F * scale/(d + dv));
float finfouta = (float)sqrt(6.0F * scale / (d + d));
wk.SetDataRand(-finfoutk, finfoutk);
wq.SetDataRand(-finfoutk, finfoutk);
wv.SetDataRand(-finfoutv, finfoutv);
wa.SetDataRand(-finfouta, finfouta);
}
/*
......@@ -90,10 +94,11 @@ make the network
and H = vector size of each position
>> q - queries
>> v - values
>> maske - as it is
>> mask - as it is
>> isTraining - indicates whether the model is used for training
<< return - multi-attention result
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
XTensor k2;
XTensor q2;
......@@ -123,14 +128,17 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
if(isMasked)
dot = dot + mask;
dot = Linear(dot, 1.0F/(float)sqrt((float)dk));
dot = Linear(dot, 1.0F/(float)sqrt((float)dk/nhead));
scalar = Softmax(dot, -1);
if(isTraining && dropoutP > 0)
scalar = Dropout(scalar, dropoutP);
att = BMMul(scalar, vheads);
/* concatenate the heads */
return Merge(att, att.order - 1);
return MMul(Merge(att, att.order - 1), wa);
}
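In formula form (notation mine; Q_h, K_h and V_h are the per-head slices of the transformed q, k and v, and M is the mask), the attention above computes
\mathrm{att}_h = \mathrm{Softmax}\!\left(\frac{Q_h K_h^{\top} + M}{\sqrt{d_k / n_{head}}}\right) V_h,
\qquad
\mathrm{Make}(k, q, v, M) = \mathrm{Merge}(\mathrm{att}_1, \dots, \mathrm{att}_{n_{head}})\, W_a
with dropout applied to the softmax weights during training. The two changes in this commit are the per-head scaling factor sqrt(d_k / n_head) in place of sqrt(d_k), and the output projection by the new matrix wa.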
}
......@@ -57,6 +57,9 @@ public:
/* transformation matrix for V */
XTensor wv;
/* transformation after dot-product attention */
XTensor wa;
/* size of transformed Q and K */
int dk;
......@@ -75,6 +78,9 @@ public:
/* indicates whether the model is used for training */
bool isTraining;
/* dropout probability */
DTYPE dropoutP;
public:
/* constructor */
......@@ -84,12 +90,12 @@ public:
~T2TAttention();
/* initialize the model */
void InitModel(int argc, const char ** argv,
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask);
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
};
}
......
......@@ -34,7 +34,7 @@ class AttDecoder : T2TDecoder
{
public:
/* initialize the model */
void InitModel(int argc, const char ** argv);
void InitModel(int argc, char ** argv);
};
}
......
......@@ -48,7 +48,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TEmbedder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
......@@ -60,7 +60,8 @@ void T2TEmbedder::InitModel(int argc, const char ** argv, int myDevID, XMem * my
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID, mem);
w.SetDataRandn(0, 1.0F/(float)sqrt((float)eSize));
DTYPE v = 1.0F/(float)sqrt((float)eSize);
w.SetDataRand(-v, v);
/* create the positional embedding matrix */
MakePosEmbedding(eSize, d, maxLength);
......@@ -79,6 +80,17 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
for(int pos = 0; pos < length; pos++){
float * dp = data + pos * eSize;
int channelSize = eSize / 2;
int offset = 0;
for(int i = 0; i < channelSize; i++){
dp[offset++] = (float)sin(pos/pow(10000.0F, 2.0F*i/(d - 2)));
}
for(int i = 0; i < channelSize; i++){
dp[offset++] = (float)cos(pos/pow(10000.0F, 2.0F*i/(d - 2)));
}
/*
for(int k = 0; k < eSize; k++){
if(k % 2 == 0){
int i = k/2;
......@@ -89,6 +101,7 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
dp[k] = (float)cos(pos/pow(10000.0F, 2.0F*i/d));
}
}
*/
}
posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
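Written out (notation mine, with c = eSize/2), the loop above places all sine channels first and all cosine channels second:
\mathrm{PE}(pos,\, i) = \sin\!\left(\frac{pos}{10000^{2i/(d-2)}}\right),
\qquad
\mathrm{PE}(pos,\, c + i) = \cos\!\left(\frac{pos}{10000^{2i/(d-2)}}\right),
\qquad 0 \le i < c
The commented-out block above is the interleaved even/odd layout that this replaces.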
......@@ -135,7 +148,7 @@ XTensor T2TEmbedder::Make(XTensor &input)
}
/* then we make word embeddings */
wordEmbedding = Linear(MMul(input, w), (float)sqrt((float)d));
wordEmbedding = Linear(MMul(input, w), (float)sqrt((float)eSize));
/* we sum over the two embeddings */
return wordEmbedding + posEmbedding;
......
......@@ -71,7 +71,7 @@ public:
~T2TEmbedder();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make positional embeddings */
void MakePosEmbedding(int eSize, int d, int length);
......
......@@ -51,7 +51,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void AttEncoder::InitModel(int argc, const char ** argv,
void AttEncoder::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{
......@@ -89,16 +89,17 @@ void AttEncoder::InitModel(int argc, const char ** argv,
make the encoding network
>> input - the input tensor of the encoder
>> mask - the mask that indicates whether each position is valid
>> skipInputRes - indicates whether we skip the residual connection of the first layer
>> isTraining - indicates whether the model is for training
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining)
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
{
XTensor x;
x = embedder.Make(input);
//x.Dump(tmpFILE, "embedding: ");
/* dropout */
if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
......@@ -109,37 +110,21 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
XTensor fnn;
XTensor res;
/* we skip the residual connection for the first layer if
the encoder is used in language modeling. */
if(skipInputRes && i == 0){
/* self attention */
att = attentions[i].Make(x, x, x, mask);
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP);
/* layer normalization */
x = attLayerNorms[i].Make(att);
}
else{
/* self attention */
att = attentions[i].Make(x, x, x, mask, isTraining);
/* self attention */
att = attentions[i].Make(x, x, x, mask);
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP);
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP);
/* residual connection */
res = Sum(att, x);
/* residual connection */
res = Sum(att, x);
/* layer normalization */
x = attLayerNorms[i].Make(res);
}
/* layer normalization */
x = attLayerNorms[i].Make(res);
/* fnn */
fnn = fnns[i].Make(x);
fnn = fnns[i].Make(x, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
......@@ -150,9 +135,6 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
/* layer normalization */
x = fnnLayerNorms[i].Make(res);
if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
}
return x;
......
......@@ -40,7 +40,7 @@ class T2TEncoder
{
public:
virtual
XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining) = 0;
XTensor Make(XTensor &input, XTensor &mask, bool isTraining) = 0;
};
/*
......@@ -49,7 +49,7 @@ the encoder based on RNN
class RNNEncoder : T2TEncoder
{
public:
XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
};
......@@ -113,12 +113,12 @@ public:
~AttEncoder();
/* initialize the model */
void InitModel(int argc, const char ** argv,
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */
XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
};
......
......@@ -49,7 +49,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
......@@ -58,8 +58,9 @@ void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_EMBEDDING_SIZE * 4);
LoadParamInt(argc, argv, "fnnh", &hSize, outSize * 4);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0);
InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID, mem);
InitTensor1D(&b1, hSize, X_FLOAT, devID, mem);
......@@ -83,12 +84,15 @@ y = max(0, x * w1 + b1) * w2 + b2
>> input - the input tensor
<< return - the output tensor
*/
XTensor T2TFNN::Make(XTensor &input)
XTensor T2TFNN::Make(XTensor &input, bool isTraining)
{
XTensor t1;
/* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MMul(input, w1) + b1);
if(isTraining && dropoutP > 0)
t1 = Dropout(t1, dropoutP);
/* result = t1 * w2 + b2 */
return MMul(t1, w2) + b2;
......
......@@ -59,6 +59,9 @@ public:
/* bias of transformation 2 */
XTensor b2;
/* dropout probability */
DTYPE dropoutP;
public:
......@@ -69,10 +72,10 @@ public:
~T2TFNN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
XTensor Make(XTensor &input, bool isTraining);
};
......
......@@ -32,7 +32,8 @@ namespace transformer
T2TLN::T2TLN()
{
devID = -1;
mem = NULL;
d = 0;
}
/* de-constructor */
......@@ -47,28 +48,28 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TLN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
int d = 0;
d = 0;
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
InitTensor2D(&w, d, d, X_FLOAT, devID, mem);
InitTensor1D(&w, d, X_FLOAT, devID, mem);
InitTensor1D(&b, d, X_FLOAT, devID, mem);
float scale = 1.0F;
float finfout = (float)sqrt(6.0F * scale / (d + d));
float finfout = (float)sqrt(6.0F * scale / d);
w.SetDataRand(-finfout, finfout);
b.SetZeroAll();
}
/*
make the network
for each layer representation x, we have
y = w * (x - \mu) / \sigma + b
>> input - the input tensor
>> return - layer normalization output
*/
......@@ -90,16 +91,17 @@ XTensor T2TLN::Make(XTensor &input)
/* standard = sqrt(variance) */
standard = Power(variance, 0.5F);
/* unsqueeze mean and standard deviation to fit them into
the same shape of x */
meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1));
standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
/* x' = (x - \mu)/standard */
xn = (x - meanFilled)/standardFilled;
xn = (x - meanFilled) / standardFilled;
/* result = x' * w + b */
return MMul(xn, w) + b;
return xn * w + b;
}
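Per element (notation mine), with mean \mu and standard deviation \sigma taken along the last dimension of x, the normalization above is
y = w \odot \frac{x - \mu}{\sigma} + b
Note that this commit changes w from a d-by-d matrix (applied with MMul) to a d-dimensional vector applied elementwise, which is the standard layer-normalization gain.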
}
......@@ -45,6 +45,9 @@ public:
/* the bias term b */
XTensor b;
/* dimension size of the model */
int d;
public:
/* constructor */
......@@ -54,7 +57,7 @@ public:
~T2TLN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
......
......@@ -48,7 +48,7 @@ initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TModel::InitModel(int argc, const char ** argv)
void T2TModel::InitModel(int argc, char ** argv)
{
bool useMem = false;
int memSize = 0;
......@@ -64,25 +64,32 @@ void T2TModel::InitModel(int argc, const char ** argv)
if(useMem){
delete mem;
mem = new XMem(devID, isMemFreeOTF ? FREE_ON_THE_FLY : UNI_FREE, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem = new XMem(devID, FREE_ON_THE_FLY, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
}
encoder.InitModel(argc, argv, isLM, 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
XList params(10);
GetParams(params);
for(int i = 0; i < params.count; i++){
XTensor * param = (XTensor*)params.Get(i);
param->SetVarFlag();
}
}
/*
make the encoding network
>> input - input tensor
>> mask - the mask for positions that are/not involved in computation
>> skipInputRes - indicates whether we skip the residual connection of the first layer
>> isTraining - indicates whether we are training the model
<< return - encoding result
*/
XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining)
XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool isTraining)
{
return encoder.Make(input, mask, skipInputRes, isTraining);
return encoder.Make(input, mask, isTraining);
}
/*
......@@ -134,9 +141,9 @@ void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding, bool isTr
_ScaleAndShiftMe(padding3, 1e9F, -1e9F);
//_Sum(&mask, padding3, &mask);
_Sum(&mask, padding3, &mask);
encoding = MakeEncoding(input, mask, true, isTraining);
encoding = MakeEncoding(input, mask, isTraining);
outputLayer.Make(encoding, output);
delete[] dims;
......@@ -167,6 +174,7 @@ void T2TModel::GetParams(XList &list)
list.Add(&encoder.attentions[i].wk);
list.Add(&encoder.attentions[i].wq);
list.Add(&encoder.attentions[i].wv);
list.Add(&encoder.attentions[i].wa);
list.Add(&encoder.fnnLayerNorms[i].w);
list.Add(&encoder.fnnLayerNorms[i].b);
list.Add(&encoder.attLayerNorms[i].w);
......
......@@ -66,10 +66,10 @@ public:
~T2TModel();
/* initialize the model */
void InitModel(int argc, const char ** argv);
void InitModel(int argc, char ** argv);
/* make the encoding network */
XTensor MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
XTensor MakeEncoding(XTensor &input, XTensor &mask, bool isTraining);
/* make the entire network (with the output softmax layer) */
void Make(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
......
......@@ -49,7 +49,7 @@ initialize the model
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TOutput::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
......
......@@ -59,7 +59,7 @@ public:
~T2TOutput();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
......
......@@ -37,15 +37,27 @@ namespace transformer
class T2TTrainer
{
public:
/* number of parameters */
int argNum;
/* parameter array */
char ** argArray;
/* buffer for loading words */
int * buf;
/* another buffer */
int * buf2;
/* buffer size */
int bufSize;
/* length of each sequence */
int * seqLen;
/* another array */
int * seqLen2;
/* offset of the first word for each sequence */
int * seqOffset;
......@@ -101,6 +113,24 @@ public:
/* list of the 2nd order moment of the parameter matrices */
XList moments2nd;
/* indicates whether the data file is shuffled for training */
bool isShuffled;
/* the factor of label smoothing */
DTYPE labelSmoothingP;
/* number of steps after which we make a checkpoint */
int nStepCheckpoint;
/* indicates whether we make a checkpoint after each training epoch */
bool useEpochCheckpoint;
/* number of batches on which we do model update */
int updateStep;
/* indicates whether we double the </s> symbol for the output of LMs */
bool isDoubledEnd;
public:
/* constructor */
T2TTrainer();
......@@ -109,14 +139,17 @@ public:
~T2TTrainer();
/* initialize the trainer */
void Init(int argc, const char ** argv);
void Init(int argc, char ** argv);
/* train the model */
void Train(const char * fn, T2TModel * model);
void Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
/* test the model */
void Test(const char * fn, const char * ofn, T2TModel * model);
/* make a checkpoint */
void MakeCheckpoint(T2TModel * model, const char * validFN, const char * modelFN, const char * label, int id);
/* load data to buffer */
int LoadBuf(FILE * file, bool isSorted, int step);
......@@ -130,6 +163,9 @@ public:
int step, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem);
/* shuffle the data file */
void Shuffle(const char * srcFile, const char * tgtFile);
/* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
......@@ -141,7 +177,13 @@ public:
void PrepareModel(T2TModel * model);
/* do padding on the output */
void PadOutput(XTensor * output, XTensor * padding);
void PadOutput(XTensor * output, XTensor * gold, XTensor * padding);
/* rescale the output and gold tensors for normalized loss */
void RescaleOutput(XTensor * output, XTensor * gold, XTensor * padding);
/* perform label smoothing */
void LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE p);
};
......
......@@ -30,7 +30,7 @@ FILE * tmpFILE;
int llnum = 0;
FILE * tf = NULL;
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP)
void LoadParamString(int argc, char ** argv, const char * name, char * p, const char * defaultP)
{
char vname[128];
vname[0] = '-';
......@@ -47,7 +47,7 @@ void LoadParamString(int argc, const char ** argv, const char * name, char * p,
strcpy(p, defaultP);
}
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP)
void LoadParamInt(int argc, char ** argv, const char * name, int * p, int defaultP)
{
char vname[128];
vname[0] = '-';
......@@ -64,7 +64,7 @@ void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int
*p = defaultP;
}
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP)
void LoadParamBool(int argc, char ** argv, const char * name, bool * p, bool defaultP)
{
char vname[128];
vname[0] = '-';
......@@ -81,7 +81,7 @@ void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bo
*p = defaultP;
}
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP)
void LoadParamFloat(int argc, char ** argv, const char * name, float * p, float defaultP)
{
char vname[128];
vname[0] = '-';
......@@ -98,7 +98,7 @@ void LoadParamFloat(int argc, const char ** argv, const char * name, float * p,
*p = defaultP;
}
void ShowParams(int argc, const char ** argv)
void ShowParams(int argc, char ** argv)
{
fprintf(stderr, "args:\n");
for(int i = 0; i < argc; i++){
......
......@@ -30,13 +30,13 @@ namespace transformer
extern FILE * tmpFILE;
/* load arguments */
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP);
void LoadParamString(int argc, char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, char ** argv, const char * name, float * p, float defaultP);
/* show arguments */
void ShowParams(int argc, const char ** argv);
void ShowParams(int argc, char ** argv);
extern int llnum;
extern FILE * tf;
......
......@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "Transformer.h"
#include "T2TModel.h"
#include "T2TUtility.h"
......@@ -32,31 +33,39 @@ int TransformerMain(int argc, const char ** argv)
{
if(argc == 0)
return 1;
fprintf(stderr, "%e\n", log(1e-8F));
char ** args = new char*[argc];
for(int i = 0; i < argc; i++){
args[i] = new char[strlen(argv[i]) + 1];
strcpy(args[i], argv[i]);
}
tmpFILE = fopen("tmp.txt", "wb");
ShowParams(argc, argv);
ShowParams(argc, args);
char * trainFN = new char[MAX_LINE_LENGTH];
char * modelFN = new char[MAX_LINE_LENGTH];
char * testFN = new char[MAX_LINE_LENGTH];
char * outputFN = new char[MAX_LINE_LENGTH];
LoadParamString(argc, argv, "train", trainFN, "");
LoadParamString(argc, argv, "model", modelFN, "");
LoadParamString(argc, argv, "test", testFN, "");
LoadParamString(argc, argv, "output", outputFN, "");
LoadParamString(argc, args, "train", trainFN, "");
LoadParamString(argc, args, "model", modelFN, "");
LoadParamString(argc, args, "test", testFN, "");
LoadParamString(argc, args, "output", outputFN, "");
T2TTrainer trainer;
trainer.Init(argc, argv);
trainer.Init(argc, args);
T2TModel model;
model.InitModel(argc, argv);
model.InitModel(argc, args);
/* learn model parameters */
if(strcmp(trainFN, ""))
trainer.Train(trainFN, &model);
trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model);
/* save the final model */
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
......@@ -66,18 +75,25 @@ int TransformerMain(int argc, const char ** argv)
if(strcmp(modelFN, ""))
model.Read(modelFN);
T2TTrainer tester;
tester.Init(argc, args);
/* test the model on the new data */
if(strcmp(testFN, "") && strcmp(outputFN, ""))
trainer.Test(testFN, outputFN, &model);
tester.Test(testFN, outputFN, &model);
delete[] trainFN;
delete[] modelFN;
delete[] testFN;
delete[] outputFN;
for(int i = 0; i < argc; i++)
delete[] args[i];
delete[] args;
fclose(tmpFILE);
return 0;
}
}
\ No newline at end of file
}
......@@ -55,6 +55,9 @@ namespace nts {
#define DTYPE_MIN (DTYPE)-3.40E+38
#endif
#define LOGPROB_MIN (DTYPE)-2E+1
#define GRAD_MAX (DTYPE)1E+5
#if WIN32
#define DELIMITER '\\'
#else
......@@ -148,6 +151,7 @@ extern bool useCUDA;
#define XPRINT5(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5);FFLUSH(FILEH);}}
#define XPRINT6(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6);FFLUSH(FILEH);}}
#define XPRINT7(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7);FFLUSH(FILEH);}}
#define XPRINT8(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8);FFLUSH(FILEH);}}
#define B2I(V) V==0?false:true
......
......@@ -263,6 +263,18 @@ int XLink::GetParamInt(int i)
char * p = (char*)params + i * paramSize;
return *(int*)p;
}
/*
get a parameter in pointer
>> i - id of the parameter
<< return - the parameter in pointer
*/
void * XLink::GetParamPointer(int i)
{
CheckNTErrors(params != NULL, "parameter array cannot be empty!");
char * p = (char*)params + i * paramSize;
return *(int **)p;
}
/*
get a parameter in MATRIX_TRANS_TYPE
......@@ -401,8 +413,7 @@ add a boolean parameter
*/
void XLink::AddParamToHeadBool(XTensor * h, bool param)
{
if(h != NULL)
return;
CheckNTErrors(h != NULL, "head tensor cannot be empty!");
h->income.AddParam(&param, sizeof(bool));
}
......@@ -413,8 +424,7 @@ add a pointer parameter
*/
void XLink::AddParamToHeadPointer(XTensor * h, void * param)
{
if(h != NULL)
return;
CheckNTErrors(h != NULL, "head tensor cannot be empty!");
h->income.AddParam(&param, sizeof(param));
}
......@@ -589,9 +599,24 @@ show the network encoded in a root node (tensor)
*/
void XLink::ShowNetwork(FILE * file, XTensor * root)
{
fprintf(file, "node %d - ", root->id);
XLink &income = root->income;
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
ShowNetwork(file, child);
}
}
/*
show a node
>> file - file to dump information
>> node - pointer to the node
*/
void XLink::ShowNode(FILE * file, XTensor * node)
{
fprintf(file, "node %d - ", node->id);
XLink &income = node->income;
if(income.head == NULL){
fprintf(file, "income[%d]: null ", income.tailNum);
}
......@@ -607,7 +632,7 @@ void XLink::ShowNetwork(FILE * file, XTensor * root)
}
fprintf(stderr, ", ");
XLink &outgo = root->outgo;
XLink &outgo = node->outgo;
if(outgo.head == NULL || outgo.tailNum == 0){
fprintf(file, "outgo[%d]: null ", outgo.tailNum);
}
......@@ -623,11 +648,6 @@ void XLink::ShowNetwork(FILE * file, XTensor * root)
}
fprintf(stderr, "\n");
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
ShowNetwork(file, child);
}
}
} // namespace nts(NiuTrans.Tensor)
......
......@@ -127,6 +127,9 @@ struct XLink
/* get a parameter in integer */
int GetParamInt(int i);
/* get a parameter in pointer */
void * GetParamPointer(int i);
/* get a parameter in MATRIX_TRANS_TYPE */
MATRIX_TRANS_TYPE GetParamTrans(int i);
......@@ -178,6 +181,10 @@ struct XLink
/* show the network encoded in a root node (tensor) */
static
void ShowNetwork(FILE * file, XTensor * root);
/* show a node */
static
void ShowNode(FILE * file, XTensor * node);
};
} // namespace nts(NiuTrans.Tensor)
......
......@@ -600,7 +600,7 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
void * result = NULL;
/* search for the memory piece available for the allocation */
for(int i = indexEntryNum; i > index; i--){
for(int i = index; i <= indexEntryNum; i++){
if(i == indexEntryNum){
entry = memIndex + index;
CheckNTErrors(mySize >= minSizeIndex[index], "Wrong index!");
......@@ -667,7 +667,7 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
hit->size = mySize;
hit->head.state = 2;
hit->pReal = beg;
blocks[hit->head.blockID].used += mySize;
blocks[hit->head.blockID].used += head->size;
RemoveFreeIndexNode(hit);
AddAllocIndexNode(hit);
......@@ -690,7 +690,7 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
continue;
if (block->mem == NULL) {
block->size = MAX(maxBlockSize, mySize + 2 * MY_PITCH);
block->size = MAX(block->sizeDesired, mySize + 2 * MY_PITCH);
if (myDevID < 0) {
block->mem = new char[block->size];
memset(block->mem, 0, block->size);
......@@ -719,8 +719,9 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
newNode->head.indexNode = newNode;
newNode->p = block->mem;
newNode->pReal = NULL;
newNode->size = (char*)block->mem + mySize -
(char*)GetPitchedAddress(block->mem, MY_PITCH);
//newNode->size = (char*)block->mem + block->size -
// (char*)GetPitchedAddress(block->mem, MY_PITCH);
newNode->size = mySize;
AddFreeIndexNode(newNode);
......@@ -1041,9 +1042,14 @@ void XMem::RebuildIndex()
/* make a new index node */
MPieceNode * newNode = memIndex2 + nodeNumUsed2++;
newNode->p = p;
newNode->size = node->size;
//newNode->size = (char*)p + head->size -
// ( head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);
if(head->state == 1){
newNode->size = (char*)p + head->size -
( head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);
}
else
newNode->size = node->size;
newNode->pre = NULL;
newNode->next = NULL;
......
......@@ -35,6 +35,8 @@ const char * GetOPName(int type)
return "M_EXP";
else if (type == MATH_FLOOR)
return "M_FLOOR";
else if (type == MATH_ISZERO)
return "M_ISZERO";
else if (type == MATH_LOG)
return "M_LOG";
else if (type == MATH_SQRT)
......@@ -107,10 +109,14 @@ const char * GetOPName(int type)
return "S_MERGE_LIST";
else if (type == SHAPE_PERMUTE)
return "S_PERMUTE";
else if (type == SHAPE_RESHAPE)
return "S_RESHAPE";
else if (type == SHAPE_SPLIT)
return "S_SPLIT";
else if (type == SHAPE_SPLIT_LIST)
return "S_SPLIT_LIST";
else if (type == SHAPE_SQUEEZE)
return "S_SQUEEZE";
else if (type == SHAPE_TRANSPOSE)
return "S_TRANSPOSE";
else if (type == SHAPE_UNSQUEEZE)
......
......@@ -35,7 +35,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_CEIL MATH_ABSOLUTE + 1
#define MATH_EXP MATH_CEIL + 1
#define MATH_FLOOR MATH_EXP + 1
#define MATH_LOG MATH_FLOOR + 1
#define MATH_ISZERO MATH_FLOOR + 1
#define MATH_LOG MATH_ISZERO + 1
#define MATH_SQRT MATH_LOG + 1
#define MATH_SQUARE MATH_SQRT + 1
#define MATH_SIN MATH_SQUARE + 1
......@@ -81,9 +82,11 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define SHAPE_MERGE SHAPE_CONCATENATE + 1
#define SHAPE_MERGE_LIST SHAPE_MERGE + 1
#define SHAPE_PERMUTE SHAPE_MERGE_LIST + 1
#define SHAPE_SPLIT SHAPE_PERMUTE + 1
#define SHAPE_RESHAPE SHAPE_PERMUTE + 1
#define SHAPE_SPLIT SHAPE_RESHAPE + 1
#define SHAPE_SPLIT_LIST SHAPE_SPLIT + 1
#define SHAPE_TRANSPOSE SHAPE_SPLIT_LIST + 1
#define SHAPE_SQUEEZE SHAPE_SPLIT_LIST + 1
#define SHAPE_TRANSPOSE SHAPE_SQUEEZE + 1
#define SHAPE_UNSQUEEZE SHAPE_TRANSPOSE + 1
#define SORT SHAPE_UNSQUEEZE + 1
......
......@@ -38,6 +38,7 @@
#include "XMem.h"
#include "XHeap.h"
#include "XBLAS.h"
#include "XName.h"
#include "core/shape/MergeBlockLists.h"
#include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h"
......@@ -45,6 +46,7 @@
#include "core/arithmetic/Sub.h"
#include "core/arithmetic/Div.h"
#include "core/math/ScaleAndShift.h"
#include "function/Identity.h"
#ifdef USE_CUDA
......@@ -202,7 +204,7 @@ XTensor::~XTensor()
dims[0] = -dims[0];
XTensor * newTensor = new XTensor(order, dims, dataType, denseRatio, devID, mem);
newTensor->SetTMP();
newTensor->SetTMPFlag();
newTensor->data = data;
data = NULL;
......@@ -244,6 +246,7 @@ void XTensor::Init()
isInit = false;
isTmp = false;
isGrad = false;
isVar = false;
visitMark = 0;
grad = NULL;
}
......@@ -289,7 +292,8 @@ void XTensor::ShallowCopy(const XTensor &tensor)
/* overloading of the equal-sign */
XTensor& XTensor::operator= (const XTensor& tensor)
{
/* we must make a hard copy of the tensor if it is the input
of another node. */
if(outgo.tailNum > 0){
int dims[MAX_TENSOR_DIM_NUM];
......@@ -297,7 +301,7 @@ XTensor& XTensor::operator= (const XTensor& tensor)
dims[0] = -dims[0];
XTensor * newTensor = new XTensor(order, dims, dataType, denseRatio, devID, mem);
newTensor->SetTMP();
newTensor->SetTMPFlag();
newTensor->data = data;
newTensor->dataHost = dataHost;
newTensor->signature = tensor.signature;
......@@ -311,38 +315,54 @@ XTensor& XTensor::operator= (const XTensor& tensor)
dataHost = NULL;
}
/* hard copy of the data array */
int size = unitNum * unitSize;
if( isInit && !isSparse && !tensor.isSparse &&
size == tensor.unitNum * tensor.unitSize &&
((devID < 0 && tensor.devID < 0) && devID == tensor.devID) &&
data != NULL)
{
XMemCopy(data, devID, tensor.data, tensor.devID, size);
if(dataHost != NULL && tensor.dataHost != NULL)
XMemCopy(dataHost, -1, tensor.dataHost, tensor.devID, size);
if(false && !tensor.isTmp){
/* NOTE: this might lead to additional data copy on Mac machines */
/* we make an identity transformation here */
if(outgo.tailNum > 0)
XLink::ClearOutgoing(this);
XLink::ClearIncoming(this);
if(!IsSameShaped(this, &tensor))
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
_Identity(&tensor, this);
XLink::MakeLink(&tensor, NULL, this, FUNC_IDENTITY);
}
else{
DestroyData();
if(!isInit){
devID = tensor.devID;
mem = tensor.mem;
/* hard copy of the data array */
int size = unitNum * unitSize;
if( isInit && !isSparse && !tensor.isSparse &&
size == tensor.unitNum * tensor.unitSize &&
((devID < 0 && tensor.devID < 0) && devID == tensor.devID) &&
data != NULL)
{
XMemCopy(data, devID, tensor.data, tensor.devID, size);
if(dataHost != NULL && tensor.dataHost != NULL)
XMemCopy(dataHost, -1, tensor.dataHost, tensor.devID, size);
}
else{
DestroyData();
if(!isInit){
devID = tensor.devID;
mem = tensor.mem;
}
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
_CopyValues(&tensor, this);
}
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
_CopyValues(&tensor, this);
}
/* copy member variables */
ShallowCopy(tensor);
isInit = true;
isTmp = false;
CheckNTErrors(outgo.tailNum == 0, "The node has outgoing edge to other nodes!");
/* create tensor links for the new tensor */
XLink::Replace(&tensor, this);
}
return *this;
}
......@@ -353,24 +373,48 @@ XTensor XTensor::operator+ (const XTensor& tensor)
return Sum(*this, tensor);
}
/* overloading of the plus-sign */
XTensor XTensor::operator+ (const DTYPE shift)
{
return ScaleAndShift(*this, 1, shift);
}
/* overloading of the multiply-sign */
XTensor XTensor::operator* (const XTensor& tensor)
{
return Multiply(*this, tensor);
}
/* overloading of the multiply-sign */
XTensor XTensor::operator* (const DTYPE scale)
{
return ScaleAndShift(*this, scale, 0);
}
/* overloading of the minus-sign */
XTensor XTensor::operator- (const XTensor& tensor)
{
return Sub(*this, tensor);
}
/* overloading of the minus-sign */
XTensor XTensor::operator- (const DTYPE shift)
{
return ScaleAndShift(*this, 1, -shift);
}
/* overloading of the division-sign */
XTensor XTensor::operator/ (const XTensor& tensor)
{
return Div(*this, tensor);
}
/* overloading of the division-sign */
XTensor XTensor::operator/ (const DTYPE scale)
{
return ScaleAndShift(*this, (DTYPE)1/scale, 0);
}
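The new scalar overloads all reduce to ScaleAndShift. A short usage sketch (hypothetical caller code):
XTensor y = x * 2.0F;    /* ScaleAndShift(x, 2, 0)   */
XTensor u = x + 1.0F;    /* ScaleAndShift(x, 1, 1)   */
XTensor v = x - 1.0F;    /* ScaleAndShift(x, 1, -1)  */
XTensor z = x / 2.0F;    /* ScaleAndShift(x, 0.5, 0) */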
/*
linear transformation b = a * \scale + \shift
>> scale - the slope
......@@ -419,7 +463,7 @@ judge whether the three matrices are in the same type and size
>> c - a tensor again
<< return - whether the three input tensors are identical
*/
bool XTensor::IsSameShaped(XTensor * a, XTensor * b, XTensor * c)
bool XTensor::IsSameShaped(const XTensor * a, const XTensor * b, const XTensor * c)
{
return IsSameShaped(a, b) && IsSameShaped(a, c);
}
......@@ -440,7 +484,7 @@ void XTensor::SetDim(int * myDimSize)
get the size of a given dimension
>> dim - the given dim we are looking at
*/
int XTensor::GetDim(const int dim)
int XTensor::GetDim(const int dim) const
{
CheckNTErrors(dim < order, "dimenision is out of range!");
......@@ -746,6 +790,20 @@ void XTensor::SetDataPointer()
dataP = &data;
}
/* compare two numbers within absolute and relative tolerance */
bool IsFloatEqual(DTYPE a, DTYPE b, float absError, float relError)
{
if(a == b)
return true;
if(fabs(a - b) < absError)
return true;
if(fabs(a) < fabs(b))
return (fabs(a - b) / b < relError) ? true : false;
else
return (fabs(a - b) / a < relError) ? true : false;
}
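A worked example of the tolerance logic (values hypothetical):
/* IsFloatEqual(1000.01F, 1000.02F, 1e-3F, 1e-4F) returns true:
   the absolute difference 0.01 exceeds absError, but the relative
   error 0.01 / 1000.02 (about 1e-5) is below relError */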
/* check whether the data array is the same as the answer */
bool XTensor::CheckData(const void * d, int num, float tolerance, int beg)
{
if (data == NULL || d == NULL)
......@@ -759,7 +817,7 @@ bool XTensor::CheckData(const void * d, int num, float tolerance, int beg)
DTYPE * answerPrt = (DTYPE*)d;
for (int i = beg; i < num; i++) {
value = ToCPU(devID, valuePrt);
if (fabs(value - *answerPrt) > tolerance)
if(IsFloatEqual(value, *answerPrt, tolerance, 1e-4F) == false)
return false;
valuePrt++;
answerPrt++;
......@@ -1125,7 +1183,7 @@ int XTensor::GetNonzeroSize()
set the tensor as "temporary"
>> myIsTMP - the flag
*/
void XTensor::SetTMP(bool myIsTmp)
void XTensor::SetTMPFlag(bool myIsTmp)
{
isTmp = myIsTmp;
}
......@@ -1134,12 +1192,23 @@ void XTensor::SetTMP(bool myIsTmp)
set the tensor as "keep-gradient"
>> myIsGrad - the flag
*/
void XTensor::SetGrad(bool myIsGrad)
void XTensor::SetGradFlag(bool myIsGrad)
{
isGrad = myIsGrad;
}
/*
set the tensor as "variable"
>> myIsVar - the flag
*/
void XTensor::SetVarFlag(bool myIsVar)
{
isVar = myIsVar;
if(isVar)
SetGradFlag(true);
}
/*
resize a tensor with a specified tensor size
>> myOrder - order of the tensor
>> myDimSize - the size of each dimension
......@@ -1415,9 +1484,18 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int beg,
}
}
else {
ShowNTErrors("TODO!");
else if(dataType == X_INT) {
int end = MIN(n > 0 ? beg + n : beg + unitNum, unitNum);
for(int i = beg; i < end; i++){
int f = ((int*)d)[i];
if(i == beg)
fprintf(file, "%d", f);
else
fprintf(file, " %d", f);
}
}
else
ShowNTErrors("TODO!");
}
else {
int num = this->unitNumNonZero > 0 ? *(int*)d : 0;
......
......@@ -145,6 +145,9 @@ public:
/* indicates whether the tensor keeps the gradient when used as model parameters */
bool isGrad;
/* indicates whether the tensor is used as parameters (or variables) */
bool isVar;
/* mark for traversing the graph */
unsigned int visitMark;
......@@ -201,15 +204,27 @@ public:
/* overloading of the plus-sign */
XTensor operator+ (const XTensor &tensor);
/* overloading of the plus-sign */
XTensor operator+ (const DTYPE shift);
/* overloading of the multiply-sign */
XTensor operator* (const XTensor &tensor);
/* overloading of the multiply-sign */
XTensor operator* (const DTYPE scale);
/* overloading of the minus-sign */
XTensor operator- (const XTensor &tensor);
/* overloading of the minus-sign */
XTensor operator- (const DTYPE shift);
/* overloading of the division-sign */
XTensor operator/ (const XTensor &tensor);
/* overloading of the division-sign */
XTensor operator/ (const DTYPE scale);
/* linear transformation */
XTensor Lin(DTYPE scale, DTYPE shift = 0);
......@@ -220,13 +235,13 @@ public:
/* judge whether the three matrices are in the same type and size */
static
bool IsSameShaped(XTensor * a, XTensor * b, XTensor * c);
bool IsSameShaped(const XTensor * a, const XTensor * b, const XTensor * c);
/* set the size of each dimension */
void SetDim(int * myDimSize);
/* get the size of a given dimension */
int GetDim(const int dim);
int GetDim(const int dim) const;
/* reshape the tensor */
void Reshape(const int order, const int * myDimSize);
......@@ -319,10 +334,13 @@ public:
int GetNonzeroSize();
/* set the tensor as "temporary" */
void SetTMP(bool myIsTmp = true);
void SetTMPFlag(bool myIsTmp = true);
/* set the tensor as "keep-gradient" */
void SetGrad(bool myIsGrad = true);
void SetGradFlag(bool myIsGrad = true);
/* set the tensor as "variable" */
void SetVarFlag(bool myIsVar = true);
/* resize a matrix with a specified matrix size */
bool Resize(const int myOrder, const int * myDimSize,
......
......@@ -63,11 +63,14 @@
#include "movement/CopyIndexed.h"
#include "movement/CopyInGrid.h"
#include "movement/CopyValues.h"
#include "movement/Gather.h"
#include "movement/Spread.h"
#include "reduce/ReduceMax.h"
#include "reduce/ReduceMean.h"
#include "reduce/ReduceStandardVariance.h"
#include "reduce/ReduceSum.h"
#include "reduce/ReduceSumAll.h"
#include "reduce/ReduceSumSquared.h"
#include "reduce/ReduceVariance.h"
......@@ -77,8 +80,10 @@
#include "shape/MakeSplitBlockIndex.h"
#include "shape/Merge.h"
#include "shape/MergeBlockLists.h"
#include "shape/Reshape.h"
#include "shape/Permute.h"
#include "shape/Split.h"
#include "shape/Squeeze.h"
#include "shape/Transpose.h"
#include "shape/Unsqueeze.h"
......
......@@ -147,6 +147,8 @@ int GetDivDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
if(XTensor::IsSameShaped(&a, &b))
return -1;
int hitCount = 0;
int hitDim = -1;
......@@ -181,7 +183,7 @@ where i is the index of the item
XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
{
XTensor c(&a);
c.SetTMP();
c.SetTMPFlag();
int n = GetDivDimIndex(a, b);
......
......@@ -150,7 +150,7 @@ i.e., a is divided with b by broadcasting
XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
{
XTensor c(&a);
c.SetTMP();
c.SetTMPFlag();
/* call _Div function */
_DivDim(&a, &b, &c, n, alpha);
......
......@@ -249,7 +249,7 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
float dr = (!a.isSparse || !b.isSparse) ? 1.0F : MAX(a.denseRatio, b.denseRatio);
XTensor c(order, dimSize, a.dataType, dr, a.devID, a.mem);
c.SetTMP();
c.SetTMPFlag();
/* call _MatrixMul function */
_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
......@@ -299,7 +299,7 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
float dr = (!a.isSparse || !b.isSparse) ? 1.0F : MAX(a.denseRatio, b.denseRatio);
XTensor c(order, dimSize, a.dataType, dr, a.devID, a.mem);
c.SetTMP();
c.SetTMPFlag();
/* call _MatrixMul function */
_MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
......
......@@ -314,7 +314,7 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const
float dr = (!a.isSparse || !b.isSparse) ? 1.0F : MAX(a.denseRatio, b.denseRatio);
XTensor c(order, dimSize, a.dataType, dr, a.devID, a.mem);
c.SetTMP();
c.SetTMPFlag();
/*call _MatrixMulBatched function */
_MatrixMulBatched(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
......@@ -370,7 +370,7 @@ XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
float dr = (!a.isSparse || !b.isSparse) ? 1.0F : MAX(a.denseRatio, b.denseRatio);
XTensor c(order, dimSize, a.dataType, dr, a.devID, a.mem);
c.SetTMP();
c.SetTMPFlag();
/*call _MatrixMulBatched function */
_MatrixMulBatched(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
......
......@@ -148,6 +148,8 @@ int GetMultiplyDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
if(XTensor::IsSameShaped(&a, &b))
return -1;
int hitCount = 0;
int hitDim = -1;
......@@ -182,7 +184,7 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim
{
XTensor c(&a);
c.SetTMP();
c.SetTMPFlag();
int n = GetMultiplyDimIndex(a, b);
......
......@@ -148,7 +148,7 @@ i.e., a is multiplied with b by broadcasting
XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
{
XTensor c(&a);
c.SetTMP();
c.SetTMPFlag();
/* call _Multiply function */
_MultiplyDim(&a, &b, &c, n, alpha);
......
......@@ -68,7 +68,7 @@ make a new tensor to keep the result and return it
XTensor Negate(const XTensor & a)
{
XTensor b(&a);
b.SetTMP();
b.SetTMPFlag();
/* call _Negate function */
_Negate(&a, &b);
......
......@@ -74,7 +74,7 @@ make a new tensor to keep the result and return it
XTensor Sign(const XTensor & a)
{
XTensor b(&a);
b.SetTMP();
b.SetTMPFlag();
/* call _Sign function */
_Sign(&a, &b);
......
......@@ -134,6 +134,8 @@ int GetSubDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
if(XTensor::IsSameShaped(&a, &b))
return -1;
int hitCount = 0;
int hitDim = -1;
......@@ -164,7 +166,7 @@ make a new tensor c to keep the result and return it
XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
c.SetTMPFlag();
int n = GetSubDimIndex(a, b);
......
......@@ -150,7 +150,7 @@ i.e., a is subtracted with b by broadcasting
XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
c.SetTMPFlag();
/* call _Sub function */
_SubDim(&a, &b, &c, n, beta);
......
......@@ -139,6 +139,8 @@ int GetSumDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
if(XTensor::IsSameShaped(&a, &b))
return -1;
int hitCount = 0;
int hitDim = -1;
......@@ -169,7 +171,7 @@ make a new tensor c to keep the result and return it
XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
c.SetTMPFlag();
int n = GetSumDimIndex(a, b);
......
......@@ -150,7 +150,7 @@ i.e., a is summed with b by broadcasting
XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
c.SetTMPFlag();
/* call _Sum function */
_SumDim(&a, &b, &c, n, beta);
......
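Note on the IsSameShaped early returns added to GetDivDimIndex, GetMultiplyDimIndex, GetSubDimIndex and GetSumDimIndex above: when the two operands already have identical shapes, the plain element-wise path applies and no broadcasting dimension needs to be located. A minimal sketch of the resulting dispatch in Sum, assuming a, b, a2, b2 are initialized DTYPE tensors of the stated shapes and the default beta = 1.0:

/* a: (2, 3), b: (2, 3) -> GetSumDimIndex returns -1 and Sum takes the element-wise path (_Sum) */
XTensor c = Sum(a, b);

/* a2: (2, 3), b2: (3) -> GetSumDimIndex returns 1 and Sum dispatches to _SumDim with n = 1,
   i.e., b2 is added to every row of a2 by broadcasting */
XTensor d = Sum(a2, b2);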
......@@ -111,7 +111,7 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
float dr = (!a.isSparse) ? 1.0F : a.denseRatio;
XTensor c(order, dimSize, a.dataType, dr, a.devID, a.mem);
c.SetTMP();
c.SetTMPFlag();
/* call _SelectRange function */
_SelectRange(&a, &c, dim, low, high);
......
......@@ -234,7 +234,7 @@ void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim > 0, "Illegal dimension!");
CheckNTErrors(dim < n && dim >= 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
......@@ -264,11 +264,78 @@ void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
}
/*
modify data items along with a given index and dimension (and keep the remaining items unchanged)
>> source - the tensor whose data array would be modified
>> modify - the tensor whose data array would be used to modify the source tensor
>> dim - the dimension along which we modify the tensor
>> index - index of the given dimension
e.g., given a source tensor (3, 3)
1 2 3
4 5 6
7 8 9
given a modifying tensor (3)
1 2 3
when dim = 0, index = 1, we have
1 2 3
1 2 3
7 8 9
i.e., we set entries of row 1 to {1, 2, 3}
*/
void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index)
{
int order = source->order;
int size = source->GetDim(dim);
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
CheckNTErrors(index >= 0 && index < size, "Illegal index!");
for(int i = 0; i < order - 1; i++){
if(i < dim){
CheckNTErrors(modify->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else if(i >= dim){
CheckNTErrors(modify->GetDim(i) == source->GetDim(i+1), "Illegal dimension!");
}
}
if(source->devID < 0 && modify->devID < 0){
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = order - 1; i > dim; i--){
stride *= source->GetDim(i);
}
blockSize = stride * source->GetDim(dim);
blockNum = source->unitNum / blockSize;
for(int i = 0; i < blockNum; i++){
DTYPE * d = (DTYPE*)source->data + blockSize * i + index * stride;
DTYPE * p = (DTYPE*)modify->data + stride * i;
for(int j = 0; j < stride; j++)
d[j] = p[j];
}
}
else if(source->devID >= 0 && modify->devID >= 0) {
#ifdef USE_CUDA
_CudaSetDataIndexed(source, modify, dim, index);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
else{
ShowNTErrors("TODO!");
}
}
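A short usage sketch of _SetDataIndexed on the (3, 3) example from the comment above, assuming source and modify point to tensors created and filled elsewhere:

/* source: 3 x 3 matrix with rows {1 2 3}, {4 5 6}, {7 8 9}; modify: vector {1, 2, 3} */
_SetDataIndexed(source, modify, 0, 1);   /* dim = 0 selects rows, index = 1 picks the second row */

/* here stride = 3 (the trailing dimension), blockSize = 9 and blockNum = 1, so the inner
   loop overwrites the 3 items starting at offset index * stride = 3 with {1, 2, 3} */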
/*
generate data as lower triangular matrices for the last two dimensions
>> tensor - the tensor whose data is to be set
>> p - the value for each entry of the lower triangular matrices
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
e.g., for a 3 * 3 tensor,
when p = 1 and shift = 0, we have
1 0 0
1 1 0
......@@ -363,7 +430,6 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
}
}
/*
generate data items with a normal distribution with specified mean and standard deviation
>> mean - mean or expectation of the distribution
......
......@@ -231,7 +231,7 @@ void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim > 0, "Illegal dimension!");
CheckNTErrors(dim < n && dim >= 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
......@@ -255,12 +255,95 @@ void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataDim<<<blocks, threads >>>((DTYPE*)tensor->data, beg * stride, len * stride, blockSize, blockNum, p);
KernelSetDataDim<<<blocks, threads >>>((DTYPE*)tensor->data, beg * stride,
len * stride, blockSize, blockNum, p);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
modify data items along with a given index and dimension
(and keep the remaining items unchanged) - kernel version
>> s - the pointer whose data would be modified
>> m - the pointer whose data would be used to modify the data pointed by s
>> blockNum - number of data blocks
>> blockSize - size of a data block
>> stride - stride of a data block
*/
__global__
void KernelSetDataIndexed(DTYPE * s, DTYPE * m, int blockNum, int blockSize, int stride)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* block id */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= stride || j >= blockNum)
return;
int x = blockSize * j + i;
int y = stride * j + i;
s[x] = m[y];
}
/*
modify data items along with a given index and dimension (and keep the remaining items unchanged)
>> source - the tensor whose data array would be modified
>> modify - the tensor whose data array would be used to modify the source tensor
>> dim - the dimension along which we modify the tensor
>> index - index of the given dimension
e.g., given a source tensor (3, 3)
1 2 3
4 5 6
7 8 9
given a modifying tensor (3)
1 2 3
when dim = 0, index = 1, we have
1 2 3
1 2 3
7 8 9
i.e., we set entries of row 1 to {1, 2, 3}
*/
void _CudaSetDataIndexed(XTensor * source, XTensor * modify, int dim, int index)
{
int order = source->order;
int size = source->GetDim(dim);
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
CheckNTErrors(index >= 0 && index < size, "Illegal index!");
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = order - 1; i > dim; i--){
stride *= source->GetDim(i);
}
blockSize = stride * source->GetDim(dim);
blockNum = source->unitNum / blockSize;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(source->devID, stride, blockNum, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
KernelSetDataIndexed<<<blocks, threads >>>((DTYPE*)source->data + index * stride, (DTYPE*)modify->data,
blockNum, blockSize, stride);
BacktoCudaDev(source->devID, devIDBackup);
}
/*
set lower triangular matrices for each block
>> d - pointer to the data array
>> l - row number (or column number) of each block, i.e.,
......
......@@ -40,6 +40,9 @@ void _CudaSetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* modify data items along with a given index and dimension (and keep the remaining items unchanged) */
void _CudaSetDataIndexed(XTensor * source, XTensor * modify, int dim, int index);
/* generate data as lower triangular matrices for last two dimensions (cuda version) */
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
......
......@@ -48,6 +48,9 @@ void _SetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* modify data items along with a given index and dimension (and keep the remaining items unchanged) */
void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index);
/* generate data as lower triangular matrices for last two dimensions */
void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift);
......
......@@ -81,7 +81,7 @@ make a new tensor to keep the result and return it
XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper)
{
XTensor b(&a);
b.SetTMP();
b.SetTMPFlag();
/* call _Clip function */
_Clip(&a, &b, lower, upper);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-03
*/
#ifndef __CLIP_H__
#define __CLIP_H__
......@@ -29,16 +30,12 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its clip value */
void _Clip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper);
/*
set every entry to its clip value (do it on site)
keep the result in the input tensor a and return nothing
*/
/* set every entry to its clip value (do it on site)
keep the result in the input tensor a and return nothing */
void _ClipMe(XTensor * a, DTYPE lower, DTYPE upper);
/*
set every entry to its clip value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
/* set every entry to its clip value (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper);
/*
......
......@@ -132,7 +132,7 @@ where a and b are the scalar and bias respectively, and \epsilon is the adjustme
XTensor Normalize(const XTensor &input, int dim, const XTensor &mean, const XTensor &var, const XTensor &a, const XTensor &b, DTYPE epsilon)
{
XTensor output(&input);
output.SetTMP();
output.SetTMPFlag();
/* call _Normalize function */
_Normalize(&input, &output, dim, &mean, &var, &a, &b, epsilon);
......
......@@ -90,7 +90,7 @@ make a new tensor to keep the result and return it
XTensor Power(const XTensor & a, DTYPE p)
{
XTensor b(&a);
b.SetTMP();
b.SetTMPFlag();
/* call _Power function */
_Power(&a, &b, p);
......
......@@ -105,7 +105,7 @@ b = a * scale + shift
XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
{
XTensor b(&a);
b.SetTMP();
b.SetTMPFlag();
/* call _ScaleAndShift function */
_ScaleAndShift(&a, &b, scale, shift);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include <math.h>
#include "../../XName.h"
......@@ -36,6 +37,11 @@ DTYPE round(DTYPE r)
return (r > 0.0) ? (DTYPE)floor(r + 0.5) : (DTYPE)ceil(r - 0.5);
}
DTYPE iszero(DTYPE r)
{
return (r == 0.0) ? (DTYPE)1.0 : (DTYPE)0.0;
}
#ifdef USE_CUDA
/* define three macros separately, specifying the respective function names (GPU mode) */
#define _SIMPLE_UNARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
......@@ -65,7 +71,7 @@ void _funcNameMe(XTensor * a) \
XTensor funcName(const XTensor &a) \
{ \
XTensor b(&a); \
b.SetTMP(); \
b.SetTMPFlag(); \
_funcName(&a, &b); \
XLink::MakeLink(&a, NULL, &b, operationId); \
return b; \
......@@ -87,6 +93,10 @@ _SIMPLE_UNARY_FUNCTION(_Floor, _CudaFloor, floor)
_SIMPLE_UNARY_FUNCTION_ME(_FloorMe, _Floor)
SIMPLE_UNARY_FUNCTION(Floor, _Floor, MATH_FLOOR)
_SIMPLE_UNARY_FUNCTION(_IsZero, _CudaIsZero, iszero)
_SIMPLE_UNARY_FUNCTION_ME(_IsZeroMe, _IsZero)
SIMPLE_UNARY_FUNCTION(IsZero, _IsZero, MATH_ISZERO)
_SIMPLE_UNARY_FUNCTION(_Log, _CudaLog, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)
......@@ -140,7 +150,7 @@ void _funcNameMe(XTensor * a) \
XTensor funcName(const XTensor &a) \
{ \
XTensor b(&a); \
b.SetTMP(); \
b.SetTMPFlag(); \
_funcName(&a, &b); \
XLink::MakeLink(&a, NULL, &b, operationId); \
return b; \
......@@ -163,6 +173,10 @@ _SIMPLE_UNARY_FUNCTION(_Floor, floor)
_SIMPLE_UNARY_FUNCTION_ME(_FloorMe, _Floor)
SIMPLE_UNARY_FUNCTION(Floor, _Floor, MATH_FLOOR)
_SIMPLE_UNARY_FUNCTION(_IsZero, iszero)
_SIMPLE_UNARY_FUNCTION_ME(_IsZeroMe, _IsZero)
SIMPLE_UNARY_FUNCTION(IsZero, _IsZero, MATH_ISZERO)
_SIMPLE_UNARY_FUNCTION(_Log, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)
......
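For reference, this is roughly what the CPU-side macro pair above expands to for the new IsZero entry (hand-expanded from the macro body, so a sketch rather than compiler output):

XTensor IsZero(const XTensor &a)
{
    XTensor b(&a);
    b.SetTMPFlag();
    _IsZero(&a, &b);                             /* element-wise, via _SIMPLE_UNARY_FUNCTION(_IsZero, iszero) */
    XLink::MakeLink(&a, NULL, &b, MATH_ISZERO);  /* record the operation for the backward pass */
    return b;
}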
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include <math.h>
#include "../../XDevice.h"
......@@ -28,17 +29,23 @@
namespace nts {
__device__
DTYPE CudaSquare(DTYPE x)
DTYPE cudasquare(DTYPE x)
{
return x * x;
}
__device__
DTYPE CudaRound(DTYPE r)
DTYPE cudaround(DTYPE r)
{
return (r > 0.0) ? (DTYPE)floor(r + 0.5) : (DTYPE)ceil(r - 0.5);
}
__device__
DTYPE cudaiszero(DTYPE r)
{
return (r == 0.0) ? (DTYPE)1.0 : (DTYPE)0.0;
}
#define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc) \
__global__ \
void Kernel##funcName(DTYPE * a, DTYPE * b, int size) \
......@@ -89,10 +96,11 @@ SIMPLE_UNARY_FUNCTION_GPU(Absolute, fabs)
SIMPLE_UNARY_FUNCTION_GPU(Ceil, ceil)
SIMPLE_UNARY_FUNCTION_GPU(Exp, exp)
SIMPLE_UNARY_FUNCTION_GPU(Floor, floor)
SIMPLE_UNARY_FUNCTION_GPU(IsZero, cudaiszero)
SIMPLE_UNARY_FUNCTION_GPU(Log, log)
SIMPLE_UNARY_FUNCTION_GPU(Round, CudaRound)
SIMPLE_UNARY_FUNCTION_GPU(Round, cudaround)
SIMPLE_UNARY_FUNCTION_GPU(Sqrt, sqrt)
SIMPLE_UNARY_FUNCTION_GPU(Square, CudaSquare)
SIMPLE_UNARY_FUNCTION_GPU(Square, cudasquare)
SIMPLE_UNARY_FUNCTION_GPU(Sin, sin)
SIMPLE_UNARY_FUNCTION_GPU(Cos, cos)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __UNARY_CUH__
#define __UNARY_CUH__
......@@ -65,6 +66,15 @@ void KernelFloor(__half * a, __half * b, int size);
/* set each entry to its floor value */
void _CudaFloor(const XTensor * a, XTensor * b);
/* if source entry is zero, set target entry to be one, otherwise zero (CUDA Kernel) */
__global__
void KernelIsZero(DTYPE * a, DTYPE * b, int size);
/* if source entry is zero, set target entry to be one, otherwise zero (CUDA Kernel) with float16 data type */
__global__
void KernelIsZero(__half * a, __half * b, int size);
/* if source entry is zero, set target entry to be one, otherwise zero */
void _CudaIsZero(const XTensor * a, XTensor * b);
/* set each entry to its logarithm value (CUDA Kernel) */
__global__
void KernelLog(DTYPE * a, DTYPE * b, int size);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __UNARY_H__
#define __UNARY_H__
......@@ -62,6 +63,15 @@ void _FloorMe(XTensor * a);
make a new tensor to keep the result and return it */
XTensor Floor(const XTensor & a);
/* if source entry is zero, set target entry to be one, otherwise zero */
void _IsZero(const XTensor *a, XTensor *b);
/* if source entry is zero, set target entry to be one, otherwise zero (do it on site)
keep the result in the input tensor a and return nothing */
void _IsZeroMe(XTensor *a);
/* if source entry is zero, set target entry to be one, otherwise zero (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor IsZero(const XTensor &a);
/* set every entry to its logarithm value */
void _Log(const XTensor * a, XTensor * b);
/* set every entry to its logarithm value (do it on site)
......
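A minimal sketch of the three IsZero flavors declared above, assuming a is an initialized DTYPE tensor:

XTensor b(&a);
_IsZero(&a, &b);         /* b_i = 1.0 where a_i == 0.0, and 0.0 elsewhere; a is unchanged */
_IsZeroMe(&a);           /* same computation, overwriting a in place */
XTensor c = IsZero(a);   /* same computation, returning a new tensor linked into the network */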
......@@ -32,7 +32,7 @@ copy indexed sub-tensors
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3,2)
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
>> tgtIndex - index of the target sub-tensors
......@@ -130,22 +130,30 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, in
float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
XTensor t(order, dimSize, s.dataType, dr, s.devID, s.mem);
t.SetTMP();
t.SetTMPFlag();
/* call _CopyIndexed function */
_CopyIndexed(&s, &t, dim, srcIndex, indexSize, tgtIndex, copyNum);
/* care: we must allocate new arrays to save the indices,
because the source indices may be freed. */
int * saveSrcIndex = new int[indexSize];
memcpy(saveSrcIndex, srcIndex, indexSize * sizeof(int));
int * saveTgtIndex = new int[indexSize];
memcpy(saveTgtIndex, tgtIndex, indexSize * sizeof(int));
/* tensor connection */
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED);
XLink::AddParamToHeadInt(&t, dim);
XLink::AddParamToHeadPointer(&t, srcIndex);
XLink::AddParamToHeadPointer(&t, saveSrcIndex);
XLink::AddParamToHeadInt(&t, indexSize);
XLink::AddParamToHeadPointer(&t, tgtIndex);
XLink::AddParamToHeadPointer(&t, saveTgtIndex);
XLink::AddParamToHeadInt(&t, copyNum);
/* destroy variables */
delete[] dimSize;
return t;
}
......
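The saveSrcIndex/saveTgtIndex copies above exist because the XLink parameters outlive the call: if the head kept the caller's pointers, the backward pass could read freed memory. A hypothetical caller illustrating the hazard the copy avoids:

XTensor t;
{
    int srcIndex[2] = {0, 2};
    int tgtIndex[2] = {0, 1};
    t = CopyIndexed(s, 0, srcIndex, 2, tgtIndex, 1);
}   /* srcIndex/tgtIndex die here; the link still owns its saveSrcIndex/saveTgtIndex copies */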
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __COPYINDEXED_H__
#define __COPYINDEXED_H__
......
......@@ -108,7 +108,7 @@ make a new tensor to keep the result and return it
XTensor CopyValues(const XTensor &s, XStream * stream)
{
XTensor t(&s);
t.SetTMP();
t.SetTMPFlag();
/* call _CopyValues function */
_CopyValues(&s, &t, stream);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-18
*/
#include "Gather.h"
#include "CopyIndexed.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
gather indexed sub-tensors
>> s - the source tensor
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
*/
void _Gather(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize)
{
int * tgtIndex = new int[indexSize];
for(int i = 0; i < indexSize; i++)
tgtIndex[i] = i;
_CopyIndexed(s, t, dim, srcIndex, indexSize, tgtIndex, 1);
delete[] tgtIndex;
}
/*
gather indexed sub-tensors (return a XTensor structure)
make a new tensor to keep the result and return it
>> s - the source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
<< return - the result of gathering the indexed sub-tensors
Notice: the index must be on the CPU!!!
*/
XTensor Gather(const XTensor &s, int dim, int * srcIndex, int indexSize)
{
int * tgtIndex = new int[indexSize];
for(int i = 0; i < indexSize; i++)
tgtIndex[i] = i;
/* call CopyIndexed function */
XTensor result;
result = CopyIndexed(s, dim, srcIndex, indexSize, tgtIndex, 1);
delete[] tgtIndex;
return result;
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
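A small usage sketch of the new Gather wrapper, assuming s is an initialized (3, 3) tensor; as the comment above warns, the index array must live in CPU memory:

int srcIndex[2] = {0, 2};

/* pick rows 0 and 2 of s along dimension 0; t has shape (2, 3) */
XTensor t = Gather(s, 0, srcIndex, 2);

/* internally this is CopyIndexed with tgtIndex = {0, 1}, i.e., the gathered
   sub-tensors are written to consecutive positions of the target */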
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-18
*/
#ifndef __GATHER_H__
#define __GATHER_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* gather selected sub-tensors */
void _Gather(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize);
/* gather selected sub-tensors (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Gather(const XTensor &s, int dim, int * srcIndex, int indexSize);
} // namespace nts(NiuTrans.Tensor)
#endif // __GATHER_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#include "Spread.h"
#include "Spread.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
This is the core assignment for the spread function.
>> sData - the data pointer of the source tensor
>> cData - the data pointer of the collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of a source data block
>> blockSizeColl - size of a collection data block
>> stride - stride of a data block
*/
void _Assignment(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride)
{
for (int i = 0; i < blockNum; i++) {
DTYPE * s = sData + blockSizeSrc * i;
DTYPE * c = cData + blockSizeColl * i;
for(int j = 0; j < stride; j++)
s[j] = c[j];
}
}
/*
spread a collection tensor to the source tensor.
This is the inverse operation of gather.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and collIndex)
>> collIndex - index of the gathered sub-tensors
*/
void _Spread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex)
{
int order = source->order;
int size = source->GetDim(dim);
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
for(int i = 0; i < order; i++){
if(i < dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else if(i > dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else{
CheckNTErrors(collection->GetDim(i) == indexSize, "Illegal dimension!");
}
}
#ifdef USE_CUDA
if(source->devID >= 0 && collection->devID >= 0) {
_CudaSpread(source, collection, dim, srcIndex, indexSize, collIndex);
return;
}
#endif
int blockSizeSrc = 1;
int blockSizeColl = 1;
int blockNum = 1;
int stride = 1;
for (int i = dim + 1; i < order; i++) {
stride *= source->GetDim(i);
}
blockSizeSrc = stride * source->GetDim(dim);
blockSizeColl = stride * collection->GetDim(dim);
blockNum = source->unitNum / blockSizeSrc;
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++){
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
_Assignment(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
}
/*
This is the core assignment for the backward computation of the gather function.
Note the operator "+=" instead of "=".
>> sData - the data pointer of the source tensor
>> cData - the data pointer of the collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of a source data block
>> blockSizeColl - size of a collection data block
>> stride - stride of a data block
*/
void _AssignmentForGather(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride)
{
for (int i = 0; i < blockNum; i++) {
DTYPE * s = sData + blockSizeSrc * i;
DTYPE * c = cData + blockSizeColl * i;
for(int j = 0; j < stride; j++)
s[j] += c[j];
}
}
/*
spread a collection tensor to the source tensor.
This is a special spread function for the backward computation of the gather function.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and collIndex)
>> collIndex - index of the gathered sub-tensors
*/
void _SpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex)
{
int order = source->order;
int size = source->GetDim(dim);
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
for(int i = 0; i < order; i++){
if(i < dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else if(i > dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else{
CheckNTErrors(collection->GetDim(i) == indexSize, "Illegal dimension!");
}
}
#ifdef USE_CUDA
if(source->devID >= 0 && collection->devID >= 0) {
_CudaSpreadForGather(source, collection, dim, srcIndex, indexSize, collIndex);
return;
}
#endif
int blockSizeSrc = 1;
int blockSizeColl = 1;
int blockNum = 1;
int stride = 1;
for (int i = dim + 1; i < order; i++) {
stride *= source->GetDim(i);
}
blockSizeSrc = stride * source->GetDim(dim);
blockSizeColl = stride * collection->GetDim(dim);
blockNum = source->unitNum / blockSizeSrc;
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++){
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
_AssignmentForGather(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
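The only difference between _Spread and _SpreadForGather is "=" versus "+=", but it matters for gradients: when the same source index appears twice in srcIndex, the backward pass of gather must accumulate. A worked sketch, assuming source is an initialized (3, 4) tensor and collection an initialized (2, 4) tensor:

int srcIndex[2] = {1, 1};
int collIndex[2] = {0, 1};

_Spread(&source, &collection, 0, srcIndex, 2, collIndex);           /* last write wins: source row 1 = collection row 1 */
_SpreadForGather(&source, &collection, 0, srcIndex, 2, collIndex);  /* accumulates: source row 1 += collection row 0 + collection row 1,
                                                                       which is dE/d(source) when row 1 was gathered twice forward */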
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#include "../../XTensor.h"
#include "../../XDevice.h"
#include "Spread.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
This is the core assignment for the spread function.
>> sData - the data pointer of the source tensor
>> cData - the data pointer of the collection tensor
>> blockNum - the number of data blocks
>> blockSizeSrc - the size of a source data block
>> blockSizeColl - the size of a collection data block
>> stride - the stride of a data block
*/
__global__
void KernelSpread(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride)
{
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* offset in each block */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockNum || j >= stride)
return;
DTYPE * s = sData + blockSizeSrc * i;
DTYPE * c = cData + blockSizeColl * i;
s[j] = c[j];
}
/*
spread a collection tensor to the source tensor (cuda version).
This is the inverse operation of gather.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and collIndex)
>> collIndex - index of the gathered sub-tensors
*/
void _CudaSpread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex)
{
int order = source->order;
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
int blockSizeSrc = 1;
int blockSizeColl = 1;
int blockNum = 1;
int stride = 1;
for (int i = dim + 1; i < order; i++) {
stride *= source->GetDim(i);
}
blockSizeSrc = stride * source->GetDim(dim);
blockSizeColl = stride * collection->GetDim(dim);
blockNum = source->unitNum / blockSizeSrc;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++) {
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
KernelSpread<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
BacktoCudaDev(source->devID, devIDBackup);
}
/*
This is the core assignment for the backward computation of the gather function.
Note the operator "+=" instead of "=".
>> sData - the data pointer of the source tensor
>> cData - the data pointer of the collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of a source data block
>> blockSizeColl - size of a collection data block
>> stride - stride of a data block
*/
__global__
void KernelSpreadForGather(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride)
{
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* offset in each block */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockNum || j >= stride)
return;
DTYPE * s = sData + blockSizeSrc * i;
DTYPE * c = cData + blockSizeColl * i;
s[j] += c[j];
}
/*
spread a collection tensor to the source tensor (cuda version).
This is a special spread function for the backward computation of the gather function.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and collIndex)
>> collIndex - index of the gathered sub-tensors
*/
void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex)
{
int order = source->order;
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
int blockSizeSrc = 1;
int blockSizeColl = 1;
int blockNum = 1;
int stride = 1;
for (int i = dim + 1; i < order; i++) {
stride *= source->GetDim(i);
}
blockSizeSrc = stride * source->GetDim(dim);
blockSizeColl = stride * collection->GetDim(dim);
blockNum = source->unitNum / blockSizeSrc;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++) {
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
KernelSpreadForGather<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
BacktoCudaDev(source->devID, devIDBackup);
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#ifndef __SPREAD_CUH__
#define __SPREAD_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* spread a collection tensor to source tensor (cuda version) */
void _CudaSpread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
/* special spread function for backward computation of gather function (cuda version) */
void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
} // namespace nts(NiuTrans.Tensor)
#endif // __SPREAD_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#ifndef __SPREAD_H__
#define __SPREAD_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* spread a collection tensor to source tensor */
void _Spread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
/* spread a collection tensor to the source tensor (XTensor-style interface);
note: it modifies the source tensor in place rather than returning a new tensor */
void Spread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
/* special spread function for backward computation of gather function */
void _SpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
} // namespace nts(NiuTrans.Tensor)
#endif // __SPREAD_H__
\ No newline at end of file
......@@ -114,7 +114,7 @@ XTensor ReduceMax(const XTensor &input, int dim)
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem);
output.SetTMP();
output.SetTMPFlag();
/* call _ReduceMax function */
_ReduceMax(&input, &output, dim);
......
......@@ -71,7 +71,7 @@ XTensor ReduceMean(const XTensor &input, int dim)
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem);
output.SetTMP();
output.SetTMPFlag();
/* call _ReduceMean function */
_ReduceMean(&input, &output, dim);
......
......@@ -225,7 +225,7 @@ XTensor ReduceSum(const XTensor &input, int dim, const XTensor &shift, DTYPE pow
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem);
output.SetTMP();
output.SetTMPFlag();
/* call _ReduceSum function */
_ReduceSum(&input, &output, dim, &shift, power, isExp);
......@@ -271,7 +271,7 @@ XTensor ReduceSum(const XTensor &input, int dim, DTYPE power, bool isExp)
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem);
output.SetTMP();
output.SetTMPFlag();
/* call _ReduceSum function */
_ReduceSum(&input, &output, dim, NULL, power, isExp);
......
......@@ -33,7 +33,7 @@ sum = \sum_i (a_i - shift) if isExp == false
sum = \sum_i exp(a_i - shift) if isExp == true
*/
void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor * shift = NULL,
DTYPE power = (DTYPE)1.0F, bool isExp = false);
DTYPE power = (DTYPE)1.0F, bool isExp = false);
/*
sum the items along a dimension of the tensor (return a XTensor structure)
......
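The default arguments above let _ReduceSum cover several reductions at once; a short sketch of the XTensor-style interface from the hunk above, assuming a is an initialized (2, 3) tensor:

/* column sums along dimension 0: s_j = \sum_i a_ij */
XTensor s = ReduceSum(a, 0, (DTYPE)1.0F, false);

/* sum of squares along dimension 0: q_j = \sum_i (a_ij)^2 */
XTensor q = ReduceSum(a, 0, (DTYPE)2.0F, false);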
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#include "ReduceSumAll.h"
#include "ReduceSum.h"
#include "../movement/CopyValues.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/* get the dimension sizes of a tensor, with dimension n removed */
int * getDimSize(const XTensor * tensor, int n)
{
int order = tensor->order;
int * dimSize = new int[order - 1];
for (int i = 0; i < order; i++) {
if(i < n)
dimSize[i] = tensor->dimSize[i];
else if(i > n)
dimSize[i - 1] = tensor->dimSize[i];
}
return dimSize;
}
/*
sum all the items of the tensor (It should be optimized!)
>> source - the input tensor
<< return - the total summation
*/
DTYPE _ReduceSumAll(XTensor * source)
{
int order = source->order;
DTYPE summation;
XTensor * big = NewTensor(source);
_CopyValues(source, big);
for(int i = 0; i < order; i++) {
if(i == order - 1)
big->Reshape(big->unitNum, 1);
int * dimSize = getDimSize(big, 0);
XTensor * little = NewTensor(big->order - 1, dimSize, source->dataType,
                             source->denseRatio, source->devID, source->mem);
_ReduceSum(big, little, 0);
delete big;
delete[] dimSize;   /* dimSize is an array, so it needs delete[] */
big = NewTensor(little);
_CopyValues(little, big);
delete little;
}
summation = big->Get1D(0);
delete big;
return summation;
}
/*
sum all the items of the tensor
>> source - the input tensor
<< return - the total summation
*/
DTYPE ReduceSumAll(XTensor & source)
{
return _ReduceSumAll(&source);
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
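_ReduceSumAll works by repeatedly reducing dimension 0 until a single item remains, reshaping the last 1-D intermediate to (unitNum, 1) so the final _ReduceSum leaves one value. A usage sketch, assuming a is an initialized (2, 2) tensor holding {{1, 2}, {3, 4}}:

DTYPE total = ReduceSumAll(a);   /* (2, 2) -> (2) -> (1); total == 10 */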
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#ifndef __REDUCESUMALL_H__
#define __REDUCESUMALL_H__
#include "../../XTensor.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/* sum all the items of the tensor */
DTYPE _ReduceSumAll(XTensor * source);
/* sum all the items of the tensor */
DTYPE ReduceSumAll(XTensor & source);
} // namespace nts(NiuTrans.Tensor)
#endif // __REDUCESUMALL_H__
\ No newline at end of file
......@@ -67,7 +67,7 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift)
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem);
output.SetTMP();
output.SetTMPFlag();
/* call _ReduceSumSquared function */
_ReduceSumSquared(&input, &output, dim, &shift);
......
......@@ -70,7 +70,7 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean)
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem);
output.SetTMP();
output.SetTMPFlag();
/* call _ReduceVariance function */
_ReduceVariance(&input, &output, dim, &mean);
......
......@@ -93,7 +93,7 @@ XTensor Concatenate(const XList &smalls, int dim)
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
big.SetTMP();
big.SetTMPFlag();
/* call _Merge function */
_Merge(&smalls, &big, dim);
......@@ -121,7 +121,7 @@ XTensor Concatenate(const XList &smalls, int dim)
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
big.SetTMP();
big.SetTMPFlag();
/* call _ConcatenateSolely function */
_ConcatenateSolely(&smalls, &big, dim);
......@@ -194,7 +194,7 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
big.SetTMP();
big.SetTMPFlag();
/* call _Merge function */
_Merge(&smalls, &big, dim);
......@@ -222,7 +222,7 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
big.SetTMP();
big.SetTMPFlag();
/* call _ConcatenateSolely function */
_ConcatenateSolely(&smalls, &big, dim);
......
......@@ -183,7 +183,7 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
XTensor t(order, dimSize, s.dataType, dr, s.devID, s.mem);
t.SetTMP();
t.SetTMPFlag();
/* call _Merge function */
_Merge(&s, &t, whereToMerge, leadingDim);
......@@ -334,7 +334,7 @@ XTensor Merge(const XList &smalls, int whereToMerge)
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
XTensor big(order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
big.SetTMP();
big.SetTMPFlag();
/* call _Merge function */
_Merge(&smalls, &big, whereToMerge);
......@@ -371,7 +371,7 @@ XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
float dr = (!smallA.isSparse) ? 1.0F : smallA.denseRatio;
XTensor big(order, dimSize, smallA.dataType, dr, smallA.devID, smallA.mem);
big.SetTMP();
big.SetTMPFlag();
XList smalls(2);
smalls.Add(&smallA);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "../movement/CopyValues.h"
#include "Reshape.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
reshape the tensor
>> s - the input tensor
>> order - order of the tensor
>> dimSize - the size of each dimension
<< return - the output tensor
*/
XTensor Reshape(XTensor &s, int order, int * dimSize)
{
XTensor t(&s);
t.SetTMPFlag();
_CopyValues(&s, &t);
int oriOrder = s.order;
int * oriDimSize = new int[oriOrder];
memcpy(oriDimSize, s.dimSize, sizeof(int) * oriOrder);   /* the original shape has oriOrder (not order) dimensions */
/* call Reshape function */
t.Reshape(order, dimSize);
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
XLink::AddParamToHeadInt(&t, oriOrder);
XLink::AddParamToHeadPointer(&t, oriDimSize);
return t;
}
} // namespace nts(NiuTrans.Tensor)
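A short usage sketch of the new Reshape wrapper, assuming s is an initialized (2, 3) tensor:

int dimSize[2] = {3, 2};

/* t gets its own copy of the data (_CopyValues), and the link records the original
   order/dimSize so the backward pass can reshape gradients back to (2, 3) */
XTensor t = Reshape(s, 2, dimSize);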
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#ifndef __RESHAPE_H__
#define __RESHAPE_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* reshape the tensor */
XTensor Reshape(XTensor &s, int order, int * dimSize);
} // namespace nts(NiuTrans.Tensor)
#endif // __RESHAPE_H__
......@@ -184,7 +184,7 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
XTensor t(order, dimSize, s.dataType, dr, s.devID, s.mem);
t.SetTMP();
t.SetTMPFlag();
/* call _Split function */
_Split(&s, &t, whereToSplit, splitNum);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#include "Squeeze.h"
#include "../movement/CopyValues.h"
#include "../../XName.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
squeeze the tensor along the specified dimension
>> source - the input tensor
>> target - the output tensor
>> leadingDim - the dimension to squeeze
if leadingDim = -1, squeeze all dimensions that are 1
else, squeeze the specified dimension
*/
void _Squeeze(XTensor * source, XTensor * target, int leadingDim)
{
int order = target->order;
CheckNTErrors(XTensor::IsSameShaped(source, target),
"The source and target tensor must be of the same size!");
CheckNTErrors(leadingDim >= -1 && leadingDim < order,
"Wrong leading dimension");
_CopyValues(source, target);
if(leadingDim < 0) {
int * newDimSize = new int[order];
int newOrder = 0;
for(int i = 0; i < order; i++) {
int dim = source->GetDim(i);
if(dim > 1) {
newDimSize[newOrder] = dim;
newOrder += 1;
}
}
target->Reshape(newOrder, newDimSize);
delete[] newDimSize;
}
else {
if(source->GetDim(leadingDim) > 1)
return;
int newOrder = order - 1;
int * newDimSize = new int[newOrder];
for(int i = 0; i < order; i++)
if(i < leadingDim)
newDimSize[i] = source->GetDim(i);
else if(i > leadingDim)
newDimSize[i - 1] = source->GetDim(i);
target->Reshape(newOrder, newDimSize);
delete[] newDimSize;
}
}
/*
squeeze the tensor along the specified dimension (do it in place)
keep the result in the input tensor and return nothing
>> source - the input tensor
>> leadingDim - the dimension to squeeze
if leadingDim = -1, squeeze all dimensions that are 1
else, squeeze the specified dimension
*/
void _SqueezeMe(XTensor * source, int leadingDim)
{
_Squeeze(source, source, leadingDim);
}
/*
squeeze the tensor along the specified dimension (return a XTensor structure)
make a new tensor to keep the result and return it
>> source - the input tensor
>> leadingDim - the dimension to squeeze
if leadingDim = -1, squeeze all dimensions that are 1
else, squeeze the specified dimension
<< return - the output tensor after squeeze operation
*/
XTensor Squeeze(XTensor & source, int leadingDim)
{
XTensor target(&source);
target.SetTMPFlag();
/* call _Squeeze function */
_Squeeze(&source, &target, leadingDim);
/* tensor connections */
XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
return target;
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
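A short usage sketch of the two Squeeze modes implemented above (hedged: InitTensor3D is assumed to exist as in the rest of the library):
XTensor a;
InitTensor3D(&a, 2, 1, 3, X_FLOAT);
a.SetDataRand(-1.0F, 1.0F);
XTensor b = Squeeze(a);       /* leadingDim = -1: drop all size-1 dimensions, (2, 1, 3) -> (2, 3) */
XTensor c = Squeeze(a, 1);    /* squeeze dimension 1 only: (2, 1, 3) -> (2, 3) */
XTensor d = Squeeze(a, 0);    /* dimension 0 has size 2, so the early return keeps (2, 1, 3) */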
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#ifndef __SQUEEZE_H__
#define __SQUEEZE_H__
#include "../../XTensor.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/* squeeze the tensor along the specified dimension */
void _Squeeze(XTensor * source, XTensor * target, int leadingDim = -1);
/* squeeze the tensor along the specified dimension (do it in place)
keep the result in the input tensor and return nothing */
void _SqueezeMe(XTensor * source, int leadingDim = -1);
/* squeeze the tensor along the specified dimension (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Squeeze(XTensor & source, int leadingDim = -1);
} // namespace nts(NiuTrans.Tensor)
#endif // __SQUEEZE_H__
\ No newline at end of file
......@@ -138,7 +138,7 @@ XTensor Transpose(const XTensor &a, const int i, const int j)
float dr = (!a.isSparse) ? 1.0F : a.denseRatio;
XTensor b(order, dimSize, a.dataType, dr, a.devID, a.mem);
b.SetTMP();
b.SetTMPFlag();
/* call _Transpose function */
_Transpose(&a, &b, i, j);
......
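A hedged usage sketch of the Transpose function referenced in the hunk above (shapes illustrative):
/* swap dimensions 0 and 1: (2, 3) -> (3, 2) */
XTensor a;
InitTensor2D(&a, 2, 3, X_FLOAT);
a.SetDataRand(-1.0F, 1.0F);
XTensor b = Transpose(a, 0, 1);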
......@@ -122,7 +122,7 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
float dr = (!a.isSparse) ? 1.0F : a.denseRatio;
XTensor b(order, dimSize, a.dataType, dr, a.devID, a.mem);
b.SetTMP();
b.SetTMPFlag();
/* call _Unsqueeze function */
_Unsqueeze(&a, &b, dim, dSize);
......
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __UNSQUEEZE_H__
#define __UNSQUEEZE_H__
......@@ -26,14 +26,13 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* insert a dimension by copying the blocks for x times (where x is the size of the inserted dimension) */
/* insert a dimension by copying the blocks for x times
(where x is the size of the inserted dimension) */
void _Unsqueeze(const XTensor * a, XTensor * b, int dim, int dSize);
/*
insert a dimension by copying the blocks for x times
(where x is the size of the inserted dimension) (return a XTensor structure)
make a new tensor to keep the result and return it
*/
/* insert a dimension by copying the blocks for x times
(where x is the size of the inserted dimension) (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Unsqueeze(const XTensor &a, int dim, int dSize);
} // namespace nts(NiuTrans.Tensor)
......
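A minimal sketch of Unsqueeze matching the comment above: the new dimension is inserted at position dim and each block is copied dSize times (hedged; shapes illustrative):
/* insert a dimension of size 4 at position 1: (2, 3) -> (2, 4, 3) */
XTensor a;
InitTensor2D(&a, 2, 3, X_FLOAT);
a.SetDataRand(-1.0F, 1.0F);
XTensor b = Unsqueeze(a, 1, 4);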
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-17
*/
#ifndef __CROSSENTROPY_CUH__
#define __CROSSENTROPY_CUH__
#include "../XTensor.h"
#include "CrossEntropy.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/* compute the cross entropy loss (tensor version) */
void _CudaCrossEntropyManual(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* compute the cross entropy loss (scalar version) */
DTYPE _CudaCrossEntropyManual(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* backward computation of cross entropy function */
void _CudaCrossEntropyBackward(XTensor * dedy, const XTensor * output, const XTensor * gold,
const XTensor * weight = NULL, XTensor * padding = NULL,
int leadingDim = -1);
} // namespace nts(NiuTrans.Tensor)
#endif // __CROSSENTROPY_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-17
*/
#ifndef __CROSSENTROPY_H__
#define __CROSSENTROPY_H__
#include "../XTensor.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
enum LOSS_COMPUTE_WAY{
REDUCE_SUM,
REDUCE_MEAN
};
/* compute the cross entropy loss (tensor version) */
void _CrossEntropy(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* compute the cross entropy loss (tensor version) */
void _CrossEntropyManual(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* compute the cross entropy loss (scalar version) */
DTYPE _CrossEntropy(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* compute the cross entropy loss (scalar version) */
DTYPE _CrossEntropyManual(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay = REDUCE_MEAN, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* backward computation of cross entropy function */
void _CrossEntropyBackward(XTensor * dedy, const XTensor * output, const XTensor * gold,
const XTensor * weight = NULL, XTensor * padding = NULL,
int leadingDim = -1);
} // namespace nts(NiuTrans.Tensor)
#endif // __CROSSENTROPY_H__
\ No newline at end of file
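The declarations above follow the standard definition: for one item, loss = -sum_k w_k * gold_k * log(output_k), summed or averaged over items according to LOSS_COMPUTE_WAY, with padded positions skipped. A hedged scalar reference sketch, independent of the XTensor API (all names here are illustrative, not the library's):
#include <math.h>
/* cross entropy of one item over n classes; weight may be NULL */
DTYPE CrossEntropyItem(const DTYPE * output, const DTYPE * gold,
                       const DTYPE * weight, int n)
{
    DTYPE loss = 0;
    for (int k = 0; k < n; k++) {
        DTYPE w = (weight != NULL) ? weight[k] : (DTYPE)1.0;
        loss += -w * gold[k] * (DTYPE)log(output[k]);
    }
    return loss;
}
The matching backward rule is dE/doutput_k = -w_k * gold_k / output_k, which is presumably what _CrossEntropyBackward fills into dedy.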
......@@ -20,7 +20,6 @@
*/
#include "../XName.h"
#include <math.h>
#include <time.h>
#include "Dropout.h"
#include "Dropout.cuh"
......
......@@ -23,7 +23,6 @@
#define __DROPOUT_H__
#include "../XTensor.h"
#include "Loss.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......
......@@ -26,6 +26,7 @@
#include "../XTensor.h"
#include "CrossEntropy.h"
#include "Dropout.h"
#include "HardTanH.h"
#include "Identity.h"
......
......@@ -23,6 +23,7 @@
#include "../XName.h"
#include "HardTanH.h"
#include "HardTanH.cuh"
#include "CrossEntropy.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -72,7 +73,7 @@ y = 1 if x > 1
XTensor HardTanH(const XTensor &x)
{
XTensor y(&x);
y.SetTMP();
y.SetTMPFlag();
/* call _HardTanH function */
_HardTanH(&x, &y);
......@@ -118,7 +119,9 @@ void _HardTanHBackward(XTensor * gold, XTensor * y, XTensor * x,
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
/* calculate dE/dy */
if(lossName != NOLOSS)
if(lossName == CROSSENTROPY)
_CrossEntropyBackward(dedy, y, gold);
else if(lossName != NOLOSS)
_LossBackward(dedy, gold, y, lossName);
DTYPE * dedyp = (DTYPE*)dedy->data;
......
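For context, the element-wise rule behind the backward pass above: HardTanH is the identity on [-1, 1] and saturates outside it, so the gradient passes dE/dy through inside the interval and is zero elsewhere. A hedged scalar sketch (pointer names are illustrative; the exact boundary handling may differ from the real kernel):
/* dE/dx = dE/dy where -1 <= x <= 1, and 0 where HardTanH saturates */
for (int i = 0; i < size; i++)
    dedxp[i] = (xp[i] >= -1.0F && xp[i] <= 1.0F) ? dedyp[i] : 0.0F;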
......@@ -22,6 +22,7 @@
#include "HardTanH.h"
#include "HardTanH.cuh"
#include "Loss.cuh"
#include "CrossEntropy.cuh"
#include "../XDevice.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -136,8 +137,10 @@ void _CudaHardTanHBackward(XTensor * gold, XTensor * y, XTensor * x,
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
/* calculate dE/dy */
if(lossName != NOLOSS)
_LossBackward(dedy, gold, y, lossName);
if(lossName == CROSSENTROPY)
_CudaCrossEntropyBackward(dedy, y, gold);
else if(lossName != NOLOSS)
_CudaLossBackward(dedy, gold, y, lossName);
int gridSize[3], blockSize[3];
......