Commit c6f2dbdf by huchi

replace all requireLink with enableGrad that allows gradient computation for a tensor

parent 0a7c2d15
......@@ -42,20 +42,20 @@ using namespace transformer;
int main( int argc, const char ** argv )
{
//_CrtSetBreakAlloc(896);
//BackwardTest();
//return 0;
if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1);
else if(argc > 1 && !strcmp(argv[1], "-t2t"))
TransformerMain(argc - 1, argv + 1);
else{
fprintf(stderr, "Thanks for using NiuTrans.Network! This is a library for building\n");
fprintf(stderr, "neural networks in an easy way. \n\n");
fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
fprintf(stderr, "Or run this program with \"-fnnlm\" for sample FNNLM!\n");
}
//_CrtSetDbgFlag(_CrtSetDbgFlag(_CRTDBG_REPORT_FLAG) | _CRTDBG_LEAK_CHECK_DF);
//_CrtSetBreakAlloc(2708);
//if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
// FNNLMMain(argc - 1, argv + 1);
//else if(argc > 1 && !strcmp(argv[1], "-t2t"))
// TransformerMain(argc - 1, argv + 1);
//else{
// fprintf(stderr, "Thanks for using NiuTrans.Network! This is a library for building\n");
// fprintf(stderr, "neural networks in an easy way. \n\n");
// fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
// fprintf(stderr, "Or run this program with \"-fnnlm\" for sample FNNLM!\n");
//}
BackwardTest();
//_CrtDumpMemoryLeaks();
......@@ -69,6 +69,9 @@ void BackwardTest()
XTensor a;
XTensor b;
XTensor c;
a.enableGrad = true;
b.enableGrad = false;
c.enableGrad = false;
XTensor mean;
XTensor origin;
InitTensor2D(&a, 2, 3);
......@@ -86,14 +89,15 @@ void BackwardTest()
b.Set1D(2.0F, 0);
b.Set1D(1.0F, 1);
c = DivDim(a, b, 0);
DivDim(a, b, c, 0);
c.Dump(stderr, "c:");
auto loss = CrossEntropy(c, a);
//XLink::ShowNetwork(stderr, &c);
net.Backward(c);
net.Backward(loss);
net.Dump(stderr);
a.grad->Dump(stderr);
}
......
......@@ -20,7 +20,9 @@
*/
#include "XBackwardLoss.h"
#include "XNoder.h"
#include "../tensor/XName.h"
#include "../tensor/function/FHeader.h"
#include "../tensor/core/getandset/SetData.h"
#include "../tensor/function/HardTanH.h"
#include "../tensor/function/Identity.h"
......@@ -31,6 +33,60 @@
namespace nts{
/* compute dE/dx of a node */
void XLossGrad::MakeGrad(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
int operID = income.typeID;
CheckNTErrors(income.tailNum >= 1, "Wrong number of tensors for loss computation!");
XTensor * output = income.tails[0];
XTensor * gold = NULL;
XTensor * weight = NULL;
XTensor * padding = NULL;
int leadingDim;
XNoder::MakeGrad(output);
XTensor * dedy = output->grad;
if (income.tailNum == 1) {
if(dedy->dataType == X_FLOAT)
_SetDataFixedFloat(dedy, 1.0F);
else if(dedy->dataType == X_DOUBLE)
_SetDataFixedDouble(dedy, 1.0);
else if(dedy->dataType == X_INT)
_SetDataFixedInt(dedy, 1);
else
ShowNTErrors("TODO");
return;
}
gold = income.tails[1];
if(operID == LOSS_CROSSENTROPY) {
if (income.tailNum == 3)
padding = income.tails[2];
leadingDim = income.GetParamInt(0);
CheckNTErrors(leadingDim >= 0 && leadingDim < output->order, "wrong leading dimension in logsoftmax!");
_CrossEntropyBackward(dedy, output, gold, weight, padding, leadingDim);
}
else{
ShowNTErrors("Wrong activation function type!");
}
node->visitMark = NODE_FINISHED;
}
/* indicates whether the node is for a loss computation */
bool XLossGrad::IsLossOP(XTensor * node)
{
XLink &income = node->income;
return (income.typeID & LOSS_BASE) != 0;
}
/*
compute dE/dx for a given function y = f(x)
>> gold - gold standard to measure error (or loss)
......
......@@ -23,6 +23,7 @@
#include "../tensor/XTensor.h"
#include "../tensor/function/FHeader.h"
#include "../tensor/loss/LHeader.h"
#ifndef __XBACKWARDLOSS_H__
#define __XBACKWARDLOSS_H__
......@@ -34,6 +35,14 @@ namespace nts{
class XLossGrad
{
public:
/* compute dE/dx of a node */
static
void MakeGrad(XTensor * node, bool isEfficient);
/* indicates whether the node is for a Loss computation */
static
bool IsLossOP(XTensor * node);
/* compute dE/dx for a given function y = f(x) */
void Compute(XTensor * gold, XTensor * y, XTensor * x,
XTensor * dedy, XTensor * dedx, XTensor * padding,
......
......@@ -81,6 +81,12 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
GradPower(node, isEfficient);
else if(operID == MATH_SCALEANDSHIFT)
GradScaleAndShift(node, isEfficient);
else if(operID == MATH_SCALE)
GradScale(node, isEfficient);
else if(operID == MATH_DESCALE)
GradDescale(node, isEfficient);
else if(operID == MATH_SHIFT)
GradShift(node, isEfficient);
else if(operID == MATH_SUB)
GradSub(node, isEfficient);
else if(operID == MATH_SUBDIM)
......@@ -719,12 +725,18 @@ void XMathGrad::GradMultiply(XTensor * node, bool isEfficient)
XTensor * a = income.tails[0];
XTensor * b = income.tails[1];
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
CheckNTErrors(XTensor::IsSameShaped(a, b), "Wrong sized input tensors!");
if (!isEfficient || a->isGrad) {
XNoder::MakeGrad(a);
_Multiply(node->grad, b, a->grad, 1.0F);
}
if (!isEfficient || b->isGrad) {
XNoder::MakeGrad(b);
_Multiply(node->grad, a, b->grad, 1.0F);
}
node->visitMark = NODE_FINISHED;
}
......@@ -888,88 +900,8 @@ gradient for normalize
*/
void XMathGrad::GradNormalize(XTensor * node, bool isEfficient)
{
ShowNTErrors("This is really a bad piece of code!!!");
XLink &income = node->income;
CheckNTErrors(income.tailNum == 5, "Wrong input tensor number for NORMALIZE!");
XTensor * input = income.tails[0];
XTensor * mean = income.tails[1];
XTensor * var = income.tails[2];
XTensor * a = income.tails[3];
XTensor * b = income.tails[4];
XTensor * c = NewTensor(var);
XTensor * d = NewTensor(a);
XTensor * e = NewTensor(a);
XTensor * f = NewTensor(a);
XTensor * g = NewTensor(a);
XTensor * h = NewTensor(a);
XTensor * i = NewTensor(a);
XTensor * j = NewTensor(a);
XTensor * k = NewTensor(var);
XTensor * p = NewTensor(var);
XTensor * q = NewTensor(var);
XTensor * r = NewTensor(a);
XTensor * x = NewTensor(mean);
XTensor * y = NewTensor(mean);
XTensor * z = NewTensor(mean);
DTYPE epsilon = income.GetParam(1);
int dim = income.GetParamInt(0);
int n = a->GetDim(dim);
XNoder::MakeGrad(input);
XNoder::MakeGrad(mean);
XNoder::MakeGrad(var);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
/* dEdinput */
_ScaleAndShift(var, c, 1.0F, epsilon);
_Unsqueeze(c, d, dim, n);
_Power(d, e, -0.5F);
_Multiply(a, e, f);
_Multiply(node->grad, f, input->grad, 1.0F);
/* dEdmean */
_ScaleAndShift(f, g, -1.0F);
_ReduceSum(g, x, dim);
_ReduceSum(node->grad, y, dim);
_Multiply(y, x, mean->grad, 1.0F);
/* dEdvar */
_Unsqueeze(mean, h, dim, n);
_Sub(input, h, i);
_Multiply(a, i, j);
_Power(var, k, -1.5F);
_ScaleAndShift(k, p, -0.5F);
_ReduceSum(j, z, dim);
_Multiply(z, p, q);
_Multiply(y, q, var->grad, 1.0F);
/* dEda */
_Multiply(i, e, r);
_Multiply(node->grad, r, a->grad, 1.0F);
/* dEdb */
_Sum(b->grad, node->grad, b->grad);
node->visitMark = NODE_FINISHED;
ShowNTErrors("TODO!");
delete c;
delete d;
delete e;
delete f;
delete g;
delete h;
delete i;
delete j;
delete k;
delete p;
delete q;
delete r;
delete x;
delete y;
delete z;
}
/*
......@@ -1030,6 +962,82 @@ void XMathGrad::GradScaleAndShift(XTensor * node, bool isEfficient)
}
/*
gradient for Scale
for
c = a * scale
we have
dE/da = dE/dc * scale
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XMathGrad::GradScale(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SCALE!");
XTensor * a = income.tails[0];
DTYPE scale = income.GetParam(0);
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad, scale);
node->visitMark = NODE_FINISHED;
}
/*
gradient for Descale
for
c = a / descale
we have
dE/da = dE/dc / descale
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XMathGrad::GradDescale(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for DESCALE!");
XTensor * a = income.tails[0];
DTYPE descale = income.GetParam(0);
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad, 1/descale);
node->visitMark = NODE_FINISHED;
}
/*
gradient for Shift
for
c = a + shift
we have
dE/da = dE/dc
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XMathGrad::GradShift(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SHIFT!");
XTensor * a = income.tails[0];
XNoder::MakeGrad(a);
_Sum(a->grad, node->grad, a->grad);
node->visitMark = NODE_FINISHED;
}
/*
gradient for minus
for
c = a - b * \beta
......
......@@ -130,6 +130,18 @@ private:
static
void GradScaleAndShift(XTensor * node, bool isEfficient);
/* gradient for Scale */
static
void GradScale(XTensor * node, bool isEfficient);
/* gradient for Shift */
static
void GradShift(XTensor * node, bool isEfficient);
/* gradient for Descale */
static
void GradDescale(XTensor * node, bool isEfficient);
/* gradient for Minus */
static
void GradSub(XTensor * node, bool isEfficient);
......
......@@ -43,6 +43,8 @@ void XShapeGrad::MakeGrad(XTensor * node, bool isEfficent)
GradCopyIndexed(node, isEfficent);
else if(operID == MOVEMENT_GATHER)
GradGather(node, isEfficent);
else if (operID == MOVEMENT_DROPOUTWITHINDEX)
GradDropoutWithIndex(node, isEfficent);
else if(operID == SHAPE_MERGE)
GradMerge(node, isEfficent);
else if(operID == SHAPE_MERGE_LIST)
......@@ -115,7 +117,7 @@ dE/da = spreadforgather(b)
void XShapeGrad::GradGather(XTensor * node, bool isEfficent)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for CopyIndexed!");
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for Gather!");
XTensor * input = income.tails[0];
XTensor * index = income.tails[1];
......@@ -127,6 +129,43 @@ void XShapeGrad::GradGather(XTensor * node, bool isEfficent)
}
/*
gradient computation for DropoutWithIndex function
*/
void XShapeGrad::GradDropoutWithIndex(XTensor * node, bool isEfficent)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for DropoutWithIndex!");
XTensor * input = income.tails[0];
XTensor * index = income.tails[1];
DTYPE scale = income.GetParam(0);
XNoder::MakeGrad(input);
//_Identity(node->grad, input->grad);
_CopyValues(node->grad, input->grad);
int order = node->grad->order;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
dimSize[i] = node->grad->dimSize[i];
}
int order1 = 1;
int * dimSize1 = new int[order1];
dimSize1[0] = input->grad->unitNum;
input->grad->Reshape(order1, dimSize1);
_DropoutWithIndex(node->grad, index, input->grad);
_ScaleAndShiftMe(input->grad, scale);
input->grad->Reshape(order, dimSize);
node->visitMark = NODE_FINISHED;
}
/*
gradient for merge
for
c = merge(a_0, a_1, ...)
......
......@@ -54,6 +54,10 @@ private:
static
void GradGather(XTensor * node, bool isEfficent);
/* gradient computation for dropout with index: b = dropoutwithindex(a, index) */
static
void GradDropoutWithIndex(XTensor * node, bool isEfficent);
/* gradient computation for merge: c = merge(a, b, ...) */
static
void GradMerge(XTensor * node, bool isEfficent);
......
......@@ -55,7 +55,7 @@ void XNetClearAll()
XNet::XNet()
{
nodes.Clear();
isGradEfficient = true;
isGradEfficient = false;
}
/* de-constructor */
......@@ -190,18 +190,18 @@ void XNet::Backward(TensorList &roots, TensorList &golds, TensorList &paddings,
XLossGrad lossGrad;
/* we start with the gradient with respect to the loss for output layers */
for(int i = 0; i < roots.count; i++){
/*for(int i = 0; i < roots.count; i++){
XTensor * root = (XTensor*)roots.Get(i);
XTensor * gold = (XTensor*)golds.Get(i);
XTensor * padding = (XTensor*)paddings.Get(i);
XLink &income = root->income;
int funcID = income.typeID;
void * params = income.params;
void * params = income.params;*/
/* we compute dE/dx if the output is generated by an activation function y = f(x).
Note that we do not need to obtain dE/dy here because it is no use in the
folloing process of back-propagation */
if(gold != NULL && income.tailNum == 1 && (funcID & FUNCTION_BASE)){
/*if(gold != NULL && income.tailNum == 1 && (funcID & FUNCTION_BASE)){
if(funcID == FUNC_LOGSOFTMAX || funcID == FUNC_SOFTMAX) {
XTensor * x = income.tails[0];
XNoder::MakeGrad(x);
......@@ -212,13 +212,13 @@ void XNet::Backward(TensorList &roots, TensorList &golds, TensorList &paddings,
XNoder::MakeGrad(root);
lossGrad.Compute(gold, root, root->grad, padding, loss);
}
}
}*/
/* we compuate dE/dy (y is the output) if no predefined activation function is used */
else{
/*else{
XNoder::MakeGrad(root);
lossGrad.Compute(gold, root, root->grad, NULL, loss);
}
}
}*/
/* back-propagation from output to input */
for(int i = nodes.count - 1; i >= 0; i--){
......@@ -266,6 +266,8 @@ void XNet::BackwardNode(XTensor * node, bool isEfficent)
XFuncGrad::MakeGrad(node, isEfficent);
else if(XShapeGrad::IsShapeOP(node))
XShapeGrad::MakeGrad(node, isEfficent);
else if(XLossGrad::IsLossOP(node))
XLossGrad::MakeGrad(node, isEfficent);
else{
ShowNTErrors("Wrong node type!");
}
......@@ -464,9 +466,9 @@ search for a node in a top-down manner by its name
>> top - the top most node
<< return - the node we found
*/
XTensor * XNet::SearchNode(XTensor * top, const char * name)
{
return XLink::SearchNode(top, name);
}
//XTensor * XNet::SearchNode(XTensor * top, const char * name)
//{
//return XLink::SearchNode(top, name);
//}
}
......@@ -23,6 +23,7 @@
#include "../tensor/XTensor.h"
#include "../tensor/function/FHeader.h"
#include "../tensor/loss/LHeader.h"
#ifndef __XNET_H__
#define __XNET_H__
......@@ -113,8 +114,8 @@ struct XNet
void ShowNetwork(FILE * file, XTensor * node);
/* search a node in a top-down manner by its name */
static
XTensor * SearchNode(XTensor * top, const char * name);
//static
//XTensor * SearchNode(XTensor * top, const char * name);
};
/* we make a unique id for every tensor */
......
......@@ -20,7 +20,7 @@
* This is a simple impelementation of the feed-forward network-baesd language
* model (FNNLM). See more details about FNNLM in
* "A Neural Probabilistic Language Model" by Bengio et al.
* Journal of Machine Learning Research 3 (2003) 1137C1155
* Journal of Machine Learning Research 3 (2003) 1137C1155
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-06-22
*/
......@@ -247,13 +247,13 @@ void Check(FNNModel &model)
/* make a hard copy of the fnn model */
void Copy(FNNModel &tgt, FNNModel &src)
{
InitTensor(&tgt.embeddingW, &src.embeddingW);
InitTensorV2(&tgt.embeddingW, &src.embeddingW);
for(int i = 0; i < MAX_HIDDEN_NUM; i++){
InitTensor(&tgt.hiddenW[i], &src.hiddenW[i]);
InitTensor(&tgt.hiddenB[i], &src.hiddenB[i]);
InitTensorV2(&tgt.hiddenW[i], &src.hiddenW[i]);
InitTensorV2(&tgt.hiddenB[i], &src.hiddenB[i]);
}
InitTensor(&tgt.outputW, &src.outputW);
InitTensor(&tgt.outputB, &src.outputB);
InitTensorV2(&tgt.outputW, &src.outputW);
InitTensorV2(&tgt.outputB, &src.outputB);
tgt.n = src.n;
tgt.eSize = src.eSize;
......@@ -310,7 +310,7 @@ initialize a 1d tensor using the fnn model setting
*/
void InitModelTensor1D(XTensor &tensor, int num, FNNModel &model)
{
InitTensor1D(&tensor, num, X_FLOAT, model.devID, model.mem);
InitTensor1DV2(&tensor, num, X_FLOAT, model.devID);
}
/*
......@@ -322,7 +322,7 @@ initialize a 2d tensor using the fnn model setting
*/
void InitModelTensor2D(XTensor &tensor, int rowNum, int colNum, FNNModel &model)
{
InitTensor2D(&tensor, rowNum, colNum, X_FLOAT, model.devID, model.mem);
InitTensor2DV2(&tensor, rowNum, colNum, X_FLOAT, model.devID);
}
......@@ -449,6 +449,9 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
/* the gold standard */
XTensor gold;
/* the loss tensor */
XTensor lossTensor;
/* make the input tensor for position i */
for(int i = 0; i < model.n - 1; i++)
MakeWordBatch(inputs[i], ngrams, ngramNum, i, model.vSize, model.devID, model.mem);
......@@ -466,6 +469,8 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
/* forward computation */
Forward(inputs, output, model, net);
/* backward computation to obtain gradients */
Backward(inputs, output, gold, CROSSENTROPY, model, grad, net);
......@@ -483,9 +488,11 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
/* this is implemented by multiply function */
//ForwardAutoDiff(inputs, output, model);
lossTensor = CrossEntropy(output, gold);
/* automatic differentiation */
autoDiffer.Backward(output, gold, CROSSENTROPY);
autoDiffer.Backward(lossTensor);
//autoDiffer.Backward(output, gold, CROSSENTROPY);
/* update model parameters */
Update(model, grad, learningRate, true);
......@@ -494,7 +501,9 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
/* get probabilities */
float prob = GetProb(output, gold);
loss += -prob;
prob = ReduceSumAll(lossTensor);
loss += prob;
wordCount += ngramNum;
wordCountTotal += ngramNum;
......@@ -595,14 +604,14 @@ get prediction probabilites of the gold words
float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs)
{
XTensor probs;
InitTensor(&probs, &output);
InitTensorV2(&probs, &output);
/* probs[i,j] = output[i,j] * gold[i,j] */
_Multiply(&output, &gold, &probs);
/* probability of each word */
XTensor wprobs;
InitTensor1D(&wprobs, output.GetDim(0), output.dataType, output.devID, output.mem);
InitTensor1DV2(&wprobs, output.GetDim(0), output.dataType, output.devID);
_ReduceSum(&probs, &wprobs, 1);
if(wordProbs != NULL)
_CopyValues(&wprobs, wordProbs);
......@@ -616,7 +625,7 @@ float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs)
/* probability for the batch */
XTensor result;
InitTensor1D(&result, 1, X_FLOAT, output.devID, output.mem);
InitTensor1DV2(&result, 1, X_FLOAT, output.devID);
_ReduceSum(&probs, &result, 1);
return result.Get1D(0);
......@@ -718,7 +727,7 @@ The indexed cell is set to 1, and 0 otherwise.
void InitZeroOneTensor2D(XTensor &tensor, int rowNum, int colNum, int * rows, int * cols,
int itemNum, int devID, XMem * mem)
{
InitTensor2D(&tensor, rowNum, colNum, X_FLOAT, devID, mem);
InitTensor2DV2(&tensor, rowNum, colNum, X_FLOAT, devID);
tensor.SetZeroAll();
......@@ -811,7 +820,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
/* make a 2d tensor for the bias term */
XTensor b2D;
InitTensor(&b2D, &s);
InitTensorV2(&b2D, &s);
_Unsqueeze(&b, &b2D, 0, batchSize);
/* introduce bias term:
......@@ -843,7 +852,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
_MatrixMul(&h_last, X_NOTRANS, &w, X_NOTRANS, &s);
XTensor b2D;
InitTensor(&b2D, &s);
InitTensorV2(&b2D, &s);
_Unsqueeze(&b, &b2D, 0, batchSize);
_Sum(&s, &b2D, &s);
......@@ -908,8 +917,8 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
XTensor dedsHidden;
XTensor dedxBottom;
if (depth > 0)
InitTensor(&dedsHidden, &dedx);
InitTensor(&dedxBottom, &net.embeddingCat);
InitTensorV2(&dedsHidden, &dedx);
InitTensorV2(&dedxBottom, &net.embeddingCat);
/* back-propagation from top to bottom in the stack of hidden layers
for each layer, h = f(s)
......@@ -947,7 +956,7 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
/* back-propagation for the embedding layer */
for (int i = 0; i < n - 1; i++) {
XTensor * dedy = NewTensor2D(batchSize, model.eSize, X_FLOAT, model.devID, model.mem);
XTensor * dedy = NewTensor2DV2(batchSize, model.eSize, X_FLOAT, model.devID);
eList.Add(dedy);
}
......@@ -999,7 +1008,7 @@ void ForwardAutoDiff(NGram * ngrams, int batch, XTensor &output, FNNModel &model
}
}
InitTensor1D(&words, size, X_INT, model.devID, model.mem);
InitTensor1DV2(&words, size, X_INT, model.devID);
words.SetData(index, size);
embeddingBig = Gather(model.embeddingW, words);
......@@ -1017,7 +1026,8 @@ void ForwardAutoDiff(NGram * ngrams, int batch, XTensor &output, FNNModel &model
hidden = HardTanH(MMul(hidden, model.hiddenW[i]) + model.hiddenB[i]);
/* output layer */
output = LogSoftmax(MMul(hidden, model.outputW) + model.outputB, 1);
//output = LogSoftmax(MMul(hidden, model.outputW) + model.outputB, 1);
output = Softmax(MMul(hidden, model.outputW) + model.outputB, 1);
}
/*
......@@ -1177,7 +1187,7 @@ void Test(const char * test, const char * result, FNNModel &model)
/* prediction probabilities */
XTensor probs;
InitTensor1D(&probs, ngramNum);
InitTensor1DV2(&probs, ngramNum);
/* get probabilities */
float prob = GetProb(output, gold, &probs);
......
#include "../../source/tensor/data/DataSet.h"
#include <fstream>
#include <iostream>
#include <string>
#include "../tensor/core/arithmetic/MatrixMul.h"
using namespace nts;
void TestDataManager() {
DataSet dataSet("src.txt", 2, 100);
XTensor src, tgt;
enum FIELD {
srcField = 0,
tgtField = 1,
};
const int indices[] = { 0, 1 };
dataSet.LoadBatch(src, indices, sizeof(indices) / sizeof(*indices), srcField);
dataSet.LoadBatch(tgt, indices, sizeof(indices) / sizeof(*indices), tgtField);
IntList str(10);
for (int i = 9; i > 0; --i) {
str.Add(i);
}
str.Add('\0');
for (int i = 0; i < str.count; ++i)
cout << str.Get(i);
cout << endl;
str.Sort(10);
for (int i = 0; i < str.count; ++i)
cout << str.Get(i);
cout << endl;
}
int main()
{
TestDataManager();
return 0;
}
\ No newline at end of file
......@@ -253,6 +253,15 @@ void T2TBatchLoader::ClearBuf()
}
/*
set the random batch flag
>> flag - as it is
*/
void T2TBatchLoader::SetRandomBatch(bool flag)
{
isRandomBatch = flag;
}
/*
load a batch of sequences
>> file - the handle to the data file
>> isLM - indicates whether the data is used for training lms
......@@ -580,7 +589,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
int * batchEncValues = new int[batchEnc->unitNum];
int * batchDecValues = new int[batchDec->unitNum];
int * labelValues = new int[label->unitNum];
//MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2];
//MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
......@@ -595,17 +604,18 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
for(int w = 0; w < len; w++){
int num = buf[seqOffset[s] + w];
batchEncValues[batchEnc->GetOffset2D(sent, w)] = num;
//paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w);
paddingEncOffsets[wCountEnc] = paddingEnc->GetOffset2D(sent, w);
wCountEnc++;
}
}
ws = wCountEnc;
batchEnc->SetData(batchEncValues, batchEnc->unitNum);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
_ConvertDataType(batchEnc, tmp);
_NotEqual(tmp, paddingEnc, 0);
DelTensorBuf(tmp);
paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
//XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
//_ConvertDataType(batchEnc, tmp);
//tmp->Dump(stderr, "tmp:");
//_NotEqual(tmp, paddingEnc, 0);
//DelTensorBuf(tmp);
/* batch of the target-side sequences */
for(int s = seq + 1; s < seq + sc; s += 2){
......@@ -660,7 +670,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
delete[] batchEncValues;
delete[] batchDecValues;
delete[] labelValues;
//delete[] paddingEncOffsets;
delete[] paddingEncOffsets;
delete[] paddingDecOffsets;
//delete[] goldOffsets;
......
......@@ -120,6 +120,9 @@ public:
/* clear data buffer */
void ClearBuf();
/* set the random batch flag */
void SetRandomBatch(bool flag = true);
/* load a batch of sequences */
int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc,
......
......@@ -31,6 +31,10 @@ namespace transformer
/* constructor */
AttDecoder::AttDecoder()
{
attentions = NULL;
fnns = NULL;
attLayerNorms = NULL;
fnnLayerNorms = NULL;
attentionsEnde = NULL;
attEndeLayerNorms = NULL;
}
......@@ -38,6 +42,10 @@ AttDecoder::AttDecoder()
/* de-constructor */
AttDecoder::~AttDecoder()
{
delete[] attentions;
delete[] fnns;
delete[] attLayerNorms;
delete[] fnnLayerNorms;
delete[] attentionsEnde;
delete[] attEndeLayerNorms;
}
......
......@@ -23,6 +23,7 @@
#include "T2TModel.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/XUtility.h"
namespace transformer
{
......@@ -366,8 +367,13 @@ void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
This matrix can be used to block the attention to current or following words in
a given sequence. */
_SetDataLowTri(&maskDec, 1e9F, 0);
//maskDec.Dump(stderr, "mask: ");
_ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
//maskDec.Dump(stderr, "mask: ");
/* encoder-decoder mask that prevents the attention to padding dummy words */
dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
......@@ -377,9 +383,18 @@ void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
//paddingEnc.Dump(stderr, "paddingenc:");
//maskEncDecTMPEnc->Dump(stderr, "maskencdectmpenc:");
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
//maskEncDecTMPEnc->Dump(stderr, "maskencdectmpenc:");
_Unsqueeze(maskEncDecTMPEnc, &maskEncDec, 0, dims[0]);
//maskEncDecTMPEnc->Dump(stderr, "maskencdectmpenc:");
DelTensorBuf(maskEncDecTMPDec);
DelTensorBuf(maskEncDecTMPEnc);
delete[] dims;
......@@ -445,6 +460,8 @@ dump the parameters
*/
void T2TModel::Dump(const char * fn)
{
double startT = GetClockSec();
FILE * file = fopen(fn, "wb");
CheckNTErrors(file, "Cannot open the model file");
......@@ -459,12 +476,16 @@ void T2TModel::Dump(const char * fn)
fclose(file);
XPRINT(0, stderr, "[INFO] model saved\n");
double elapsed = GetClockSec() - startT;
XPRINT1(0, stderr, "[INFO] model saved (took %.1fs)\n", elapsed);
}
/* read the parameters */
void T2TModel::Read(const char * fn)
{
double startT = GetClockSec();
FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "Cannot open the model file");
......@@ -479,7 +500,9 @@ void T2TModel::Read(const char * fn)
fclose(file);
XPRINT(0, stderr, "[INFO] model loaded\n");
double elapsed = GetClockSec() - startT;
XPRINT1(0, stderr, "[INFO] model loaded (took %.1fs)\n", elapsed);
}
}
......@@ -93,9 +93,8 @@ void T2TOutput::Make(XTensor &input, XTensor &output)
{
XTensor &x = input;
output = LogSoftmax(MMul(x, w), -1);
//output = Softmax(MMul(x, w), -1);
//output = LogSoftmax(MMul(x, w), -1);
output = Softmax(MMul(x, w), -1);
output.SetName(OUTPUT_NAME);
}
......
......@@ -59,6 +59,7 @@ void T2TStateBundle::MakeStates(int num)
states[i].pid = T2T_PID_EMPTY;
states[i].isEnd = false;
states[i].isStart = false;
states[i].isCompleted = false;
states[i].prob = 0;
states[i].probPath = 0;
states[i].modelScore = 0;
......@@ -72,6 +73,7 @@ void T2TStateBundle::MakeStates(int num)
/* constructor */
T2TPredictor::T2TPredictor()
{
startSymbol = -1;
}
/* de-constructor */
......@@ -115,6 +117,15 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input
}
/*
set start symbol
>> symbol - the symbol (in integer)
*/
void T2TPredictor::SetStartSymbol(int symbol)
{
startSymbol = symbol;
}
/*
read a state
>> model - the t2t model that keeps the network created so far
>> state - a set of states. It keeps
......@@ -135,7 +146,8 @@ predict the next state
>> inputEnc - input of the encoder
>> paddingEnc - padding of the encoder
*/
void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc)
void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
XTensor * inputEnc, XTensor * paddingEnc)
{
int dims[MAX_TENSOR_DIM_NUM];
......@@ -148,30 +160,28 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
XTensor * inputLast = (XTensor*)s->layersDec.GetItem(0);
/* word indices of positions up to next state */
XTensor &inputDec = *NewTensor();
XTensor inputDec;
/* a dummy word that used to as a placeholder when we process the next work */
XTensor dummy;
/* the first token */
XTensor first;
CheckNTErrors(inputEnc->order >= 2, "Wrong order of the tensor!");
for(int i = 0; i < inputEnc->order - 1; i++)
dims[i] = inputEnc->GetDim(i);
dims[inputEnc->order - 1] = 1;
InitTensor(&dummy, inputEnc->order, dims, X_INT, 1.0F, inputEnc->devID, inputEnc->mem);
dummy.SetZeroAll();
InitTensor(&first, inputEnc->order, dims, X_INT, 1.0F, inputEnc->devID, inputEnc->mem);
_SetDataFixedInt(&first, startSymbol);
/* add a new word into the input sequence of the decoder side */
if(inputLast == NULL)
inputDec = Identity(dummy);
if (inputLast == NULL) {
inputDec = Identity(first);
}
else{
inputDec = GeneratePaths(s);
for(int i = 0; i < inputEnc->order - 1; i++)
dims[i] = inputEnc->GetDim(i);
dims[inputEnc->order - 1] = inputDec.GetDim(-1);
inputDec.Resize(inputEnc->order, dims, X_INT);
inputDec.SetDevice(inputEnc->devID, inputEnc->mem);
inputDec = Concatenate(inputDec, dummy, inputDec.order - 1);
inputDec = Concatenate(first, inputDec, inputDec.order - 1);
}
/* prediction probabilities */
......
......@@ -50,6 +50,9 @@ public:
/* indicates whether the state is the start */
bool isStart;
/* indicates whether the state is completed */
bool isCompleted;
/* probability of every prediction (last state of the path) */
float prob;
......@@ -132,6 +135,9 @@ private:
/* current state */
T2TStateBundle * s;
/* start symbol */
int startSymbol;
public:
/* constructor */
T2TPredictor();
......@@ -142,6 +148,9 @@ public:
/* create an initial state */
void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state);
/* set the start symbol */
void SetStartSymbol(int symbol);
/* read a state */
void Read(T2TModel * model, T2TStateBundle * state);
......
......@@ -38,6 +38,7 @@ T2TSearch::T2TSearch()
endSymbolNum = 0;
fullHypos = NULL;
endSymbols = new int[32];
startSymbol = -1;
}
/* de-constructor */
......@@ -60,6 +61,7 @@ void T2TSearch::Init(int argc, char ** argv)
LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, -1);
if(endSymbols[0] >= 0)
endSymbolNum = 1;
......@@ -74,30 +76,46 @@ search for the most promising states
*/
void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output)
{
T2TPredictor predictor;
XTensor maskEnc;
XTensor encoding;
T2TPredictor predictor;
XTensor encodingBeam;
XTensor inputBeam;
XTensor paddingBeam;
CheckNTErrors(endSymbolNum > 0, "The search class is not initialized!");
CheckNTErrors(startSymbol >= 0, "The search class is not initialized!");
Prepare(input->unitNum/input->GetDim(-1), beamSize);
/* encoder mask */
model->MakeMTMaskEnc(*input, *padding, maskEnc);
//input->Dump(stderr, "input:");
//maskEnc.Dump(stderr, "maskenc:");
/* make the encoding network */
encoding = model->MakeEncoder(*input, maskEnc, false);
encoding.SetName(ENCODING_NAME);
encodingBeam = Unsqueeze(encoding, encoding.order - 2, beamSize);
inputBeam = Unsqueeze(*input, input->order - 1, beamSize);
paddingBeam = Unsqueeze(*padding, padding->order - 1, beamSize);
encodingBeam.ReshapeMerged(encodingBeam.order - 4);
inputBeam.ReshapeMerged(inputBeam.order - 3);
paddingBeam.ReshapeMerged(paddingBeam.order - 3);
/* max output-length = 2 * source-length */
maxLength = input->GetDim(-2) * 2;
maxLength = input->GetDim(-1) * 2;
CheckNTErrors(maxLength > 0, "no max length specified!");
T2TStateBundle * states = new T2TStateBundle[maxLength + 1];
T2TStateBundle * first = states;
/* create the first state */
predictor.Create(model, &encoding, input, beamSize, first);
predictor.Create(model, &encodingBeam, input, beamSize, first);
predictor.SetStartSymbol(startSymbol);
first->isStart = true;
......@@ -110,7 +128,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
predictor.Read(model, cur);
/* predict the next state */
predictor.Predict(next, &encoding, input, padding);
predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam);
/* compute the model score (given the prediction probability) */
Score(cur, next);
......@@ -161,6 +179,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
{
XTensor &score = beam->modelScore;
XTensor &prob = beam->prob;
XTensor &probPath = beam->probPath;
XTensor &probPathPrev = prev->probPath;
XTensor &lenPrev = prev->nstep;
XTensor &len = beam->nstep;
......@@ -174,13 +193,16 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
dims[i] = prob.GetDim(i);
InitTensor(&score, &prob);
InitTensor(&probPath, &prob);
prob.Reshape(prob.unitNum/outputSize, outputSize);
score.Reshape(score.unitNum/outputSize, outputSize);
probPath.Reshape(score.unitNum / outputSize, outputSize);
probPathPrev.Reshape(probPathPrev.unitNum);
/* the log-scale probability of the entire sequence */
_SumDim(&prob, &probPathPrev, &score, 0);
_SumDim(&prob, &probPathPrev, &probPath, 0);
InitTensor(&len, &lenPrev);
InitTensor(&lp, &lenPrev);
......@@ -193,7 +215,15 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
lp.Reshape(lp.unitNum);
/* score = log-prob/lp */
_DivDim(&score, &lp, &score, 0);
_DivDim(&probPath, &lp, &score, 0);
if (prev->isStart) {
XTensor firstMask = MakeFirstMask(beam);
firstMask.Reshape(firstMask.unitNum);
/* mask the hypotheses in the beam expect the first one */
_SumDim(&score, &firstMask, &score, 0);
}
InitTensor(&mask,
prev->endMark.order, prev->endMark.dimSize, X_FLOAT, 1.0F,
......@@ -208,6 +238,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
prob.Reshape(order, dims);
score.Reshape(order, dims);
probPath.Reshape(order, dims);
probPathPrev.Reshape(order - 1, dims);
lp.Reshape(order - 1, dims);
mask.Reshape(order -1 , dims);
......@@ -228,10 +259,11 @@ void T2TSearch::Generate(T2TStateBundle * beam)
XTensor &index = beam->prediction;
XTensor &preID = beam->preID;
XTensor &probPath = beam->probPath;
XTensor &prob = beam->prob;
int order = score.order;
CheckNTErrors(order >= 2, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 2] % beamSize == 0, "Wrong dimension size!");
CheckNTErrors(order >= 3, "The tensor must be of order 2 or larger.");
CheckNTErrors(dimsBeam[order - 3] % beamSize == 0, "Wrong dimension size!");
for (int i = 0; i < order; i++) {
dims[i] = score.GetDim(i);
......@@ -239,9 +271,12 @@ void T2TSearch::Generate(T2TStateBundle * beam)
dimsTopK[i] = score.GetDim(i);
}
dimsBeam[order - 2] /= beamSize;
int sizeVocab = score.GetDim(-1);
int stride = score.GetDim(-1);
dimsBeam[order - 3] /= beamSize;
dimsBeam[order - 1] *= beamSize;
dimsTopK[order - 2] = dimsBeam[order - 2];
dimsTopK[order - 3] = dimsBeam[order - 3];
dimsTopK[order - 1] = beamSize;
InitTensor(&scoreTopK, order, dimsTopK, score.dataType,
......@@ -258,8 +293,6 @@ void T2TSearch::Generate(T2TStateBundle * beam)
CopyValues(index, preID);
int sizeVocab = score.GetDim(-1);
/* "preID" represents the id (or the offset) of previous state used to make the current
hypothesis. Note that we reshape the "score" tensor into a matrix where each
row means a previous state. The column number is size-of-beam * vocab-size. We,
......@@ -270,7 +303,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
/* Then, we do something similar to "preID". For the top-k predictions, we need
to know their indices in the vocabulary. We compute the offset of each prediction
in the vocabulary by dividing it with vocab-size and computing the remainder. */
Mod(index, sizeVocab);
_ModMe(index, sizeVocab);
score.Reshape(order, dims);
......@@ -283,9 +316,40 @@ void T2TSearch::Generate(T2TStateBundle * beam)
InitTensor(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1);
CopyValues(index, indexCPU);
for (int i = 0; i < indexCPU.unitNum; i++)
indexCPU.SetInt(i * stride + indexCPU.GetInt(i), i);
CheckNTErrors(XTensor::IsSameShaped(&prob, &probPath), "Wrong tensor shape!");
/* sequence probability of top-k candidates */
InitTensor(&probPath, &scoreTopK);
_Gather(&beam->prob, &probPath, probPath.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
XTensor probPathTopK;
InitTensor(&probPathTopK, &scoreTopK);
XTensor probTopK;
InitTensor(&probTopK, &scoreTopK);
for (int i = 0; i < probPath.order; i++) {
dims[i] = probPath.GetDim(i);
dimsTopK[i] = probPathTopK.GetDim(i);
}
order = probPath.order;
probPath.Reshape(1, probPath.unitNum);
probPathTopK.Reshape(1, probPathTopK.unitNum);
prob.Reshape(1, prob.unitNum);
probTopK.Reshape(1, probTopK.unitNum);
_Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
_Gather(&prob, &probTopK, probTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
probPath.Reshape(order, dims);
probPathTopK.Reshape(order, dimsTopK);
prob.Reshape(order, dims);
probTopK.Reshape(order, dimsTopK);
probPath = probPathTopK;
prob = probTopK;
}
/*
......@@ -331,43 +395,52 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
CheckNTErrors(beam->stateNum == id.unitNum, "Errors occur in counting!");
/* Related variables are kept on the states of the graph. All these are
maintained on CPUs to ease the implementation of requent access and
maintained on CPUs to ease the implementation of frequent access and
modification of the states. An alternative is to do this on GPUs but
it needs much more coding work and the speed-up is not obvious. */
for(int i = 0; i < beam->stateNum; i++){
T2TState & state = states[i];
for(int i = 0; i < beam->stateNum; i += beamSize){
for (int j = 0; j < beamSize; j++) {
int k = i + j;
T2TState & state = states[k];
int offset = id.GetInt(i);
T2TState * last = prev->states + offset;
int offset = id.GetInt(k);
int pid = i / beamSize;
T2TState * last = prev->states + pid * beamSize + offset;
CheckNTErrors(offset >= 0, "Wrong state index!");
/* pointer to the previous state */
if (prev->isStart) {
state.last = NULL;
state.pid = offset;
state.pid = pid;
state.nstep = 0;
state.isCompleted = false;
}
else{
else {
state.last = last;
state.pid = state.last->pid;
state.nstep = last->nstep + 1;
state.isCompleted = last->isCompleted;
CheckNTErrors(offset < prev->stateNum, "Wrong state index!");
}
/* scores */
state.modelScore = modelScore.Get(i);
state.prob = prob.Get(i);
state.probPath = probPath.Get(i);
state.modelScore = modelScore.Get(k);
state.prob = prob.Get(k);
state.probPath = probPath.Get(k);
/* prediction */
state.prediction = prediction.GetInt(i);
state.prediction = prediction.GetInt(k);
CheckNTErrors(state.prediction >= 0, "Illegal prediction!");
/* check if it is the end of the sequence */
state.isEnd = IsEnd(state.prediction);
state.isCompleted = (state.isCompleted || state.isEnd);
/* set the ending mark */
endMarkCPU.SetInt(state.isEnd, i);
endMarkCPU.SetInt(state.isEnd, k);
}
}
/* copy the ending mark from CPU to the target device */
......@@ -396,7 +469,7 @@ void T2TSearch::Collect(T2TStateBundle * beam)
}
/*
fill the hypotheis heap with incomplete hypothses
fill the hypotheis heap with incomplete hypotheses
>> beam - the beam that keeps a number of states (final)
*/
void T2TSearch::FillHeap(T2TStateBundle * beam)
......@@ -443,16 +516,22 @@ void T2TSearch::Dump(XTensor * output)
T2TState * state = (T2TState *)heap.Pop().index;
int count = 0;
bool isCompleted = true;
/* we track the state from the end to the beginning */
while(state != NULL){
if (!state->isCompleted)
isCompleted = false;
if (isCompleted)
words[count++] = -1;
else
words[count++] = state->prediction;
state = state->last;
}
/* dump the sentence to the output tensor */
for(int w = 0; w < count; w++)
output->Set3DInt(words[count - w - 1], h, i, w);
output->Set3DInt(words[count - w - 1], h, beamSize - i - 1, w);
}
}
......@@ -495,4 +574,31 @@ void T2TSearch::SetEnd(const int * tokens, const int tokenNum)
endSymbolNum = tokenNum;
}
/*
make a mask to prevent duplicated entries in beam expansion for the first position
>> beam - the beam that keeps the searching states
*/
XTensor T2TSearch::MakeFirstMask(T2TStateBundle * beam)
{
XTensor &prob = beam->prob;
XTensor mask;
int order = prob.order;
int dims[MAX_TENSOR_DIM_NUM];
for (int i = 0; i < order - 1; i++)
dims[i] = prob.GetDim(i);
InitTensor(&mask, order - 1, dims, X_FLOAT);
mask.SetZeroAll();
for (int i = 0; i < mask.unitNum; i++) {
if (i % beamSize != 0)
mask.Set(-1e9, i);
}
mask.SetDevice(prob.devID, prob.mem);
return mask;
}
}
......@@ -59,6 +59,9 @@ private:
/* number of the end symbols */
int endSymbolNum;
/* start symbol */
int startSymbol;
public:
/* constructor */
T2TSearch();
......@@ -98,6 +101,9 @@ public:
/* set end symbols for search */
void SetEnd(const int * tokens, const int tokenNum);
/* make a mask to prevent duplicated entries in beam expansion for the first position */
XTensor MakeFirstMask(T2TStateBundle * beam);
};
}
......
......@@ -100,6 +100,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
/* an array that keeps the sequences */
int * seqs = new int[MILLION];
batchLoader.SetRandomBatch(false);
batchLoader.ClearBuf();
while(batchLoader.LoadBatch(file, model->isLM,
......@@ -114,7 +115,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(ofile, "output:");
Dump(ofile, &output);
float prob = 0;
......@@ -144,4 +145,25 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
elapsed,wordCountTotal, exp(loss/wordCount));
}
/*
dump the result into the file
>> file - data file
>> output - output tensor
*/
void T2TTester::Dump(FILE * file, XTensor * output)
{
int seqLength = output->GetDim(-1);
for (int i = 0; i < output->unitNum; i += seqLength) {
for (int j = 0; j < seqLength; j++) {
int w = output->GetInt(i + j);
fprintf(file, "%d ", w);
if (w < 0)
break;
}
fprintf(file, "\n");
}
}
}
......@@ -57,6 +57,9 @@ public:
/* test the model */
void Test(const char * fn, const char * ofn, T2TModel * model);
/* dump the result into the file */
void Dump(FILE * file, XTensor * output);
};
}
......
......@@ -24,6 +24,7 @@
#include "T2TUtility.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/loss/LHeader.h"
#include "../../network/XNoder.h"
#ifndef WIN32
......@@ -209,13 +210,16 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
labelOnehot = IndexToOnehot(label, vSizeTgt, labelSmoothingP);
/* make paddings for the output */
if (output.GetDim(0) > 0)
PadOutput(&output, &labelOnehot, &paddingDec);
//if (output.GetDim(0) > 0)
//PadOutput(&output, &labelOnehot, &paddingDec);
/* get probabilities */
float prob = GetProb(&output, &labelOnehot, NULL);
//float prob = GetProb(&output, &labelOnehot, NULL);
XTensor lossTensor;
lossTensor = CrossEntropy(output, labelOnehot, paddingDec);
float prob = ReduceSumAll(lossTensor);
DTYPE lossLocal = -prob / wc;
DTYPE lossLocal = prob / wc;
bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F);
//XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
......@@ -223,14 +227,15 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (doUpdate) {
/* recale the output for normalized loss */
RescaleOutput(&output, &labelOnehot, &paddingDec);
//RescaleOutput(&output, &labelOnehot, &paddingDec);
/* back-propagation */
net.Backward(output, labelOnehot, paddingDec, CROSSENTROPY);
net.Backward(lossTensor);
//net.Backward(output, labelOnehot, paddingDec, CROSSENTROPY);
//net.Backward(output, label, labelSmoothingP, CROSSENTROPY);
gradStep += 1;
loss += -prob;
loss += prob;
wordCount += wc;
wordCountTotal += wc;
......@@ -260,7 +265,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (step % 100 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT8(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, tword=%d, sword=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
elapsed, step, epoch, wordCountTotal, wordCountBatch, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
elapsed, step, epoch, wordCountTotal, wordCountBatch, loss/wordCount, exp(loss/wordCount), exp(prob/wc));
if (!doUpdate)
XPRINT(0, stderr, " (no update)");
XPRINT(0, stderr, "\n");
......
......@@ -30,6 +30,7 @@
#include "XDevice.h"
#include "./test/Test.h"
#include "./core/CHeader.h"
#include "./loss/CrossEntropy.h"
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
......
......@@ -266,10 +266,6 @@ XDevManager::XDevManager()
{
Clear();
Init();
#ifndef USE_CPP11
fprintf(stderr, "Warning!!! c++ 11 is RECOMMENDED for compilation.\n");
#endif
}
/* de-constructor */
......
......@@ -43,17 +43,13 @@
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
#if (__cplusplus >= 201103L || _MSC_VER >= 1700)
#define USE_CPP11
#endif
#define _XINLINE_
//#define DOUBELPRICSION
//#define DOUBELPRICSION
#ifdef DOUBELPRICSION
#define DTYPE double
#define DTYPE_MIN (DTYPE)1.79E+308
#define DTYPE_MIN (DTYPE)-1.79E+308
#else
#define DTYPE float
#define DTYPE_MIN (DTYPE)-3.40E+38
......
......@@ -102,10 +102,24 @@ _XINLINE_ HeapNode<T> XHeap<hType, T>::End()
template<HeapType hType, typename T>
_XINLINE_ void XHeap<hType, T>::Push(HeapNode<T> node)
{
//CheckNTErrors((count < size), "Heap is full!");
if (count < size) {
items[count] = node;
Up(count);
count++;
}
else if(count == size){
HeapNode<T> & item0 = items[0];
if (hType == MIN_HEAP && item0.value >= node.value)
return;
else if (hType == MAX_HEAP && item0.value <= node.value)
return;
items[0] = node;
Down(0);
}
else {
ShowNTErrors("Overflow of the heap!");
}
}
/* replace the top-most item and update the heap */
......
......@@ -528,10 +528,90 @@ void XLink::Replace(const XTensor * oldOne, XTensor * newOne)
CheckNTErrors(hit, "No proper node found in parent.income edge!");
}
}
strcpy(newOne->name, oldOne->name);
}
/*
copy a node with another, i.e., we add the links to the new node
>> src - the node to be copied
>> tgt - the new node
*/
void XLink::Copy(const XTensor * reference, XTensor * target)
{
if (reference == NULL || target == NULL)
return;
XLink &newIncome = target->income;
XLink &newOutgo = target->outgo;
XLink::ClearOutgoing(target);
XLink::ClearIncoming(target);
/* incoming nodes */
if (reference->income.typeID != 0) {
if (newIncome.tailNum < reference->income.tailNum) {
delete[] newIncome.tails;
newIncome.tails = new XTensor*[reference->income.tailNum];
}
newIncome.SetType(reference->income.typeID);
newIncome.head = target;
newIncome.tailNum = reference->income.tailNum;
memcpy(newIncome.tails, reference->income.tails, sizeof(XTensor*) * newIncome.tailNum);
int paraArraySize = reference->income.paramNum * reference->income.paramSize;
newIncome.params = new char[paraArraySize];
memcpy(newIncome.params, reference->income.params, paraArraySize);
newIncome.paramNum = reference->income.paramNum;
/* update the link to each child node */
for (int i = 0; i < newIncome.tailNum; i++) {
XTensor * child = newIncome.tails[i];
XLink &childOutgo = child->outgo;
bool hit = false;
for (int j = 0; j < childOutgo.tailNum; j++) {
if (childOutgo.tails[j] == reference) {
//childOutgo.tails[j] = target;
childOutgo.AddTail(target);
hit = true;
break;
}
}
if (childOutgo.tailNum > 0) {
CheckNTErrors(hit, "No proper node found in child.outgo edge!");
}
}
}
if (newOutgo.tailNum < reference->outgo.tailNum) {
delete[] newOutgo.tails;
newOutgo.tails = new XTensor*[reference->outgo.tailNum];
}
/* outgoing nodes */
newOutgo.head = target;
newOutgo.tailNum = reference->outgo.tailNum;
memcpy(newOutgo.tails, reference->outgo.tails, sizeof(XTensor*) * newOutgo.tailNum);
/* update the link to each parent node */
for (int i = 0; i < newOutgo.tailNum; i++) {
XTensor * parent = newOutgo.tails[i];
XLink &parentIncome = parent->income;
bool hit = false;
for (int j = 0; j < parentIncome.tailNum; j++) {
if (parentIncome.tails[j] == reference) {
//parentIncome.tails[j] = target;
parentIncome.AddTail(target);
hit = true;
}
}
if (parentIncome.tailNum > 0) {
CheckNTErrors(hit, "No proper node found in parent.income edge!");
}
}
}
/*
copy incoming edges of a given node
>> reference - the node we copy from
......
......@@ -33,7 +33,7 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
/* cross reference */
struct XTensor;
#define MAX_OP_NAME_LENGTH 16
#define MAX_OP_NAME_LENGTH 64
#define PARAM_UNTI_SIZE 64
/*
......@@ -174,6 +174,10 @@ struct XLink
static
void Replace(const XTensor * oldOne, XTensor * newOne);
/* copy a node with another, i.e., we add the links to the new node */
static
void Copy(const XTensor * reference, XTensor * target);
/* copy links of a given node */
static
void CopyIncoming(const XTensor * reference, XTensor * target);
......
......@@ -34,6 +34,11 @@ namespace nts{
int testxmemid = 0;
void * recordp = NULL;
/*
for managing the memories
*/
XMemManager GMems;
XMem * GMem;
/* constructor */
......@@ -1488,4 +1493,158 @@ cublasHandle_t * XMem::GetCublasHandle()
#endif
/* constructor */
XMemManager::XMemManager()
{
Initialize();
}
/* de-constructor */
XMemManager::~XMemManager()
{
}
/* get memory size */
MTYPE XMemManager::GetAvailableMemory()
{
unsigned long freeMem = 0;
#ifndef WIN32
long pages = sysconf(_SC_AVPHYS_PAGES);
long page_size = sysconf(_SC_PAGE_SIZE);
freeMem = pages * page_size;
#else
MEMORYSTATUSEX memoryStatus;
memoryStatus.dwLength = sizeof(memoryStatus);
if (GlobalMemoryStatusEx(&memoryStatus)){
freeMem = memoryStatus.ullAvailPhys;
}
#endif
return (MTYPE)freeMem;
}
/* get GPU memory size */
MTYPE XMemManager::GetAvailableGPUMemory(int devID)
{
size_t freeMem = 0;
size_t totalMem = 0;
#ifdef USE_CUDA
cudaSetDevice(devID);
if (cudaMemGetInfo(&freeMem, &totalMem) != cudaSuccess){
XPRINT(0, stderr, "cannot get GPU memory information.");
exit(1);
}
#endif
return (MTYPE)freeMem;
}
/* get buffer size */
void XMemManager::GetBufferSize(MTYPE freeMem, MTYPE * myBufSize)
{
*myBufSize = 0;
if (freeMem >= MILLION * 128){
*myBufSize = MILLION * 32;
if (freeMem >= MILLION * 256){
*myBufSize = MILLION * 64;
if (freeMem >= MILLION * 512){
*myBufSize = MILLION * 128;
if (freeMem >= MILLION * 1024) {
*myBufSize = MILLION * 256;
if (freeMem >= MILLION * 2048)
*myBufSize = MILLION * 512;
}
}
}
}
}
/* initialize it and set the global memory information */
void XMemManager::Initialize()
{
srand((unsigned int)time(NULL));
Free();
/* CPUs (we actually do not care about how many CPUs are using) */
nCPUMem = 1;
MTYPE freeMem = GetAvailableMemory();
MTYPE myBufSize = 0;
GetBufferSize(freeMem, &myBufSize);
CPUMems[0].Initialize(-1, UNI_FREE, MIN_BLOCK_SIZE_FOR_MEMPOOL, MIN_BLOCK_NUM_FOR_MEMPOOL, myBufSize);
/* GPUs */
nGPUMem = 0;
#ifdef USE_CUDA
if (cudaGetDeviceCount(&nGPUMem) != cudaSuccess) {
XPRINT(0, stderr, "cannot get GPU information.");
exit(1);
}
for (int i = 0; i < nGPUMem; i++) {
MTYPE freeMem = GetAvailableGPUMemory(i);
MTYPE myBufSize = 0;
GetBufferSize(freeMem, &myBufSize);
GPUMems[i].Initialize(i, UNI_FREE, MIN_BLOCK_SIZE_FOR_MEMPOOL, MIN_BLOCK_NUM_FOR_MEMPOOL, myBufSize);
}
#endif
}
/* free it */
void XMemManager::Free()
{
for (int i = 0; i < MAX_CPU_NUM; i++)
CPUMems[i].Free();
for (int i = 0; i < MAX_GPU_NUM; i++)
GPUMems[i].Free();
}
/* get global memory pool */
XMem * XMemManager::GetMem(const int devID)
{
XMem * mem = NULL;
if (devID < 0)
mem = CPUMems;
else{
if (devID < nGPUMem)
mem = GPUMems + devID;
else
XPRINT1(0, stderr, "Cannot get the memory (%d). Please check your device id!", devID);
}
return mem;
}
/* get global memory size */
int XMemManager::GetMemSize(const int devID, MTYPE * myBlockSize, int * myBlockNum, MTYPE * myBufSize)
{
XMem * mem = GetMem(devID);
int result = 0;
if (mem != NULL){
*myBlockSize = mem->maxBlockSize;
*myBlockNum = mem->blockNum;
*myBufSize = mem->bufSize;
result = 1;
}
return result;
}
/* show memory information */
void XMemManager::ShowMemInfo()
{
XPRINT(1, stderr, "Memory Information:\n");
MTYPE myBlockSize, myBufSize;
int myBlockNum;
for(int i = 0; i < nCPUMem; i++){
GetMemSize(-1, &myBlockSize, &myBlockNum, &myBufSize);
XPRINT3(1, stderr, " - id:-1 CPU, blockSize:%d, blockNum:%d, bufSize:%d\n", myBlockSize, myBlockNum, myBufSize);
}
for(int i = 0; i < nGPUMem; i++){
GetMemSize(i, &myBlockSize, &myBlockNum, &myBufSize);
XPRINT4(1, stderr, " - id:%2d GPU, blockSize:%d, blockNum:%d, bufSize:%d\n", i, myBlockSize, myBlockNum, myBufSize);
}
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -39,6 +39,12 @@
#include <curand.h>
#endif
#ifndef WIN32
#include <unistd.h>
#else
#include <windows.h>
#endif
/* the nts (NiuTrans.Tensor) namespace */
namespace nts{
......@@ -53,6 +59,8 @@ typedef long long INT_64;
#define BUF_PITCH 256
#define MIN_BLOCK_SIZE_FOR_MEMPOOL 128 * 1024 * 1024
#define MIN_BLOCK_NUM_FOR_MEMPOOL 1024
#define MAX_CPU_NUM 16
#define MAX_GPU_NUM 16
/*
mode of runnig a memory pool
......@@ -413,6 +421,61 @@ public:
};
/*
a class for the management of memory
*/
class XMemManager
{
public:
/* cpu memory pool information */
XMem CPUMems[MAX_CPU_NUM];
/* number of cpu memory pools */
int nCPUMem;
/* gpu memory pool information */
XMem GPUMems[MAX_GPU_NUM];
/* number of gpu memory pools */
int nGPUMem;
public:
/* constructor */
XMemManager();
/* de-constructor */
~XMemManager();
/* get memory size */
MTYPE GetAvailableMemory();
/* get GPU memory size */
MTYPE GetAvailableGPUMemory(int devID);
/* get buffer size */
void GetBufferSize(MTYPE freeMem, MTYPE * myBufSize);
/* initialize it and set the global memory information */
void Initialize();
/* free it */
void Free();
/* get global memory pool */
XMem * GetMem(const int devID);
/* get global memory size */
int GetMemSize(const int devID, MTYPE * myBlockSize, int * myBlockNum, MTYPE * myBufSize);
/* show memory information */
void ShowMemInfo();
};
/* managing the memories */
extern XMemManager GMems;
extern XMem * GMem;
extern int testxmemid;
......
......@@ -77,6 +77,12 @@ const char * GetOPName(int type)
return "M_POWER";
else if (type == MATH_SCALEANDSHIFT)
return "M_SCALEANDSHIFT";
else if (type == MATH_SCALE)
return "M_SCALE";
else if (type == MATH_DESCALE)
return "M_DESCALE";
else if (type == MATH_SHIFT)
return "M_SHIFT";
else if (type == MATH_MULANDSHIFT)
return "M_OPERATION";
else if (type == MATH_SIGN)
......@@ -111,6 +117,8 @@ const char * GetOPName(int type)
return "M_COPYVALUES";
else if (type == MOVEMENT_GATHER)
return "M_GATHER";
else if (type == MOVEMENT_DROPOUTWITHINDEX)
return "M_DROPOUTWITHINDEX";
else if (type == SHAPE_CONCATENATE)
return "S_CONCATENATE";
else if (type == SHAPE_MERGE)
......@@ -152,6 +160,10 @@ const char * GetOPName(int type)
else if (type == FUNC_SOFTMAX)
return "F_SOFTMAX";
}
else if ((type & LOSS_BASE) != 0) {
if (type == LOSS_CROSSENTROPY)
return "L_CROSSENTROPY";
}
return "NULL";
}
......
......@@ -58,7 +58,11 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_MULANDSHIFT MATH_SCALEANDSHIFT + 1
#define MATH_SIGN MATH_MULANDSHIFT + 1
#define MATH_SCALE MATH_MULANDSHIFT + 1
#define MATH_DESCALE MATH_SCALE + 1
#define MATH_SHIFT MATH_DESCALE + 1
#define MATH_MOD MATH_SHIFT + 1
#define MATH_SIGN MATH_MOD + 1
#define MATH_SUB MATH_SIGN + 1
#define MATH_SUBDIM MATH_SUB + 1
#define MATH_SUM MATH_SUBDIM + 1
......@@ -81,8 +85,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MOVEMENT_COPYINDEXED MOVEMENT + 1
#define MOVEMENT_COPYVALUES MOVEMENT_COPYINDEXED + 1
#define MOVEMENT_GATHER MOVEMENT_COPYVALUES + 1
#define MOVEMENT_DROPOUTWITHINDEX MOVEMENT_GATHER + 1
#define SHAPE MOVEMENT_GATHER + 1
#define SHAPE MOVEMENT_DROPOUTWITHINDEX + 1
#define SHAPE_CONCATENATE SHAPE + 1
#define SHAPE_MERGE SHAPE_CONCATENATE + 1
#define SHAPE_MERGE_LIST SHAPE_MERGE + 1
......@@ -108,6 +113,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define FUNC_SIGMOID FUNC_RECTIFY + 1
#define FUNC_SOFTMAX FUNC_SIGMOID + 1
#define LOSS_BASE FUNCTION_BASE * 2
#define LOSS_CROSSENTROPY LOSS_BASE + 1
/* get operator name */
const char * GetOPName(int type);
......
/* NiuTrans.Tensor - an open-source tensor library
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
*
* This is an implementation of queue. Actually we intend to use it to maintain
* a priority job list
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2017-04-05
*
*/
#include <stdio.h>
#include <stdlib.h>
#include "XQueue.h"
#include "XDevice.h"
#include "XList.h"
#include "XUtility.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace nts{
/**************************************
job item used in queues
*/
/* constructor */
JobQueueNode::JobQueueNode()
{
job = NULL;
args = new TensorList(1);
}
/* de-constructor */
JobQueueNode::~JobQueueNode()
{
delete args;
}
/**************************************
This class provides standard utilities of Queue.
*/
/* constuctor */
XQueue::XQueue(int mySize)
{
queue = new void*[mySize];
memset(queue, 0, sizeof(void*) * mySize);
size = mySize;
itemCount = 0;
head = 0;
tail = 0;
isJobQueue = false;
jobDequeuerArgs = new TensorList(1);
jobDequeuerBreak = false;
runningJobCount = 0;
jobStream = NULL;
jobStream1 = NULL;
jobStream2 = NULL;
MUTEX_INIT(enqueueMutex);
MUTEX_INIT(dequeueMutex);
COND_INIT(queueCond);
MUTEX_INIT(jobQueueMutex);
}
/* deconstructor */
XQueue::~XQueue()
{
delete[] queue;
delete jobDequeuerArgs;
delete jobStream;
delete jobStream1;
delete jobStream2;
//if(isJobQueue)
// StopJobConsumer();
MUTEX_DELE(enqueueMutex);
MUTEX_DELE(dequeueMutex);
COND_DELE(queueCond);
MUTEX_DELE(jobQueueMutex);
}
/*
put an item in the tail of the queue
>> item - the item we intend to add into the queue
*/
void XQueue::Enqueue(void * item)
{
MUTEX_LOCK(enqueueMutex);
MUTEX_LOCK(dequeueMutex);
CheckNTErrors((itemCount < size), "Put too many items into the queue!");
queue[tail] = item;
tail = (tail + 1) % size;
itemCount++;
COND_SIGNAL(queueCond);
MUTEX_UNLOCK(dequeueMutex);
MUTEX_UNLOCK(enqueueMutex);
}
/*
fetch an item from head of the queue
<< return - the head item of the queue
*/
void * XQueue::Dequeue()
{
MUTEX_LOCK(dequeueMutex);
while(itemCount == 0)
{
#ifdef WIN32
MUTEX_UNLOCK(dequeueMutex);
#endif
COND_WAIT(queueCond, dequeueMutex);
#ifdef WIN32
MUTEX_LOCK(dequeueMutex);
#endif
}
void * r = queue[head];
head = (head + 1) % size;
itemCount--;
MUTEX_UNLOCK(dequeueMutex);
return r;
}
/* return if the queue is empty */
bool XQueue::IsEmpty()
{
return itemCount == 0;
}
/* wait until the queue is empty */
void XQueue::WaitForEmptyJobQueue()
{
while(runningJobCount > 0){
XSleep(10);
}
if(jobStream != NULL){
CheckNTErrors((jobStream->IsFinished()), "None fineished jobs remain");
jobStream->Clear();
}
if(jobStream1 != NULL){
CheckNTErrors((jobStream1->IsFinished()), "None fineished jobs remain");
jobStream1->Clear();
}
if(jobStream2 != NULL){
CheckNTErrors((jobStream2->IsFinished()), "None fineished jobs remain");
jobStream2->Clear();
}
}
int devids[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
int cpuid = -1;
/*
run job consumer (in another thread)
>> jobDevID - id of the device for running the jobs
*/
void XQueue::RunJobConsumer(int jobDevID)
{
CheckNTErrors((jobDevID < 16), "device id is out of scope!");
isJobQueue = true;
jobDequeuerArgs->Clear();
jobDequeuerArgs->Add(this);
jobDequeuerArgs->Add(jobDevID >= 0 ? devids + jobDevID : &cpuid);
jobDequeuer.function = (TFunction)DequeueJobs;
jobDequeuer.argv = jobDequeuerArgs;
jobDequeuer.Start();
jobDequeuer.LetItGo();
}
/* stop the job consumer */
void XQueue::StopJobConsumer()
{
jobDequeuerBreak = true;
XSleep(10);
EnqueueJob(NULL, NULL);
jobDequeuer.End();
isJobQueue = false;
}
/* add a job item to process */
void XQueue::EnqueueJob(void * job, TensorList * jobArgs)
{
MUTEX_LOCK(jobQueueMutex);
runningJobCount++;
MUTEX_UNLOCK(jobQueueMutex);
JobQueueNode * node = new JobQueueNode();
node->job = job;
if(jobArgs != NULL)
node->args->AddList(jobArgs);
Enqueue(node);
}
/* job item consumer */
void XQueue::DequeueJobs(TensorList * args)
{
CheckNTErrors((args->count == 2), "Illegal arguments!");
XQueue * q = (XQueue*)args->GetItem(0);
int devID = *(int*)args->GetItem(1);
int devIDBackup = XDevice::GetGPUDevice();
if(devID >= 0)
XDevice::SetGPUDevice(devID);
while(1){
JobQueueNode * node = (JobQueueNode*)q->Dequeue();
if(q->GetJobBreak())
break;
CheckNTErrors((node != NULL), "Illegal job!");
/* process a job */
((TFunction)node->job)(node->args);
delete node;
MUTEX_LOCK(q->jobQueueMutex);
q->runningJobCount--;
MUTEX_UNLOCK(q->jobQueueMutex);
}
if(devID >= 0)
XDevice::SetGPUDevice(devIDBackup);
}
/* get the break flag */
bool XQueue::GetJobBreak()
{
return jobDequeuerBreak;
}
/* get job stream */
XStream * XQueue::GetJobStream(int n)
{
if(n == 0)
return jobStream;
else if(n == 1)
return jobStream1;
else if(n == 2)
return jobStream2;
else{
ShowNTErrors("invalid stream id!");
}
return NULL;
}
/* make job streams */
void XQueue::MakeJobStreams(int devID, int devID1, int devID2)
{
if(devID != INVALID_DEVICE_ID)
jobStream = new XStream(0, devID);
if(devID1 != INVALID_DEVICE_ID)
jobStream1 = new XStream(0, devID1);
if(devID2 != INVALID_DEVICE_ID)
jobStream2 = new XStream(0, devID2);
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -171,7 +171,7 @@ XTensor::XTensor(const XTensor &reference)
As "reference" is constant, we cannot reset reference.data
here. So we save the ADDRESS of reference.data in
reference.dataP, and do this work by updating "*reference.dataP".
This is VERY trick and might not be the best solution :) */
This is VERY tricky and might not be the best solution :) */
*reference.dataP = NULL;
}
else{
......@@ -190,10 +190,10 @@ XTensor::XTensor(const XTensor &reference)
isInit = true;
isTmp = reference.isTmp;
enableGrad = reference.enableGrad;
}
/* copy constructor (with right value reference) */
#ifdef USE_CPP11
XTensor::XTensor(const XTensor &&reference)
{
Init();
......@@ -212,15 +212,15 @@ XTensor::XTensor(const XTensor &&reference)
As "reference" is constant, we cannot reset reference.data
here. So we save the ADDRESS of reference.data in
reference.dataP, and do this work by updating "*reference.dataP".
This is VERY trick and might not be the best solution :) */
This is VERY tricky and might not be the best solution :) */
*reference.dataP = NULL;
XLink::Replace(&reference, this);
isInit = true;
isTmp = reference.isTmp;
enableGrad = reference.enableGrad;
}
#endif
/* de-constructor */
XTensor::~XTensor()
......@@ -285,6 +285,7 @@ void XTensor::Init()
isTmp = false;
isGrad = false;
isVar = false;
enableGrad = false;
visitMark = 0;
grad = NULL;
}
......@@ -402,7 +403,7 @@ XTensor& XTensor::operator= (const XTensor& tensor)
/* create tensor links for the new tensor */
XLink::Replace(&tensor, this);
}
enableGrad = tensor.enableGrad;
return *this;
}
......@@ -445,11 +446,11 @@ XTensor& XTensor::operator= (const XTensor&& tensor)
As "reference" is constant, we cannot reset reference.data
here. So we save the ADDRESS of reference.data in
reference.dataP, and do this work by updating "*reference.dataP".
This is VERY trick and might not be the best solution :) */
This is VERY tricky and might not be the best solution :) */
*tensor.dataP = NULL;
XLink::Replace(&tensor, this);
enableGrad = tensor.enableGrad;
return *this;
}
......@@ -520,6 +521,7 @@ void XTensor::SetDevice(int myDevId, XMem * myMem)
{
if (myMem != NULL) {
FlushToMem(myMem);
isInGlobalMem = false;
}
else {
ShowNTErrors("TODO!");
......@@ -557,6 +559,37 @@ bool XTensor::IsSameShaped(const XTensor * a, const XTensor * b)
return true;
}
bool XTensor::IsReduceShaped(const XTensor * a, const XTensor * b, int dim)
{
if (a == NULL || b == NULL)
return false;
if ((a->order - 1) != b->order)
return false;
for (int i = 0; i < b->order; i++) {
if (i < dim) {
if (a->dimSize[i] != b->dimSize[i])
return false;
}
else if (i >= dim) {
if (a->dimSize[i+1] != b->dimSize[i])
return false;
}
}
if(a->dataType != b->dataType)
return false;
if(a->denseRatio != b->denseRatio)
return false;
if(a->isSparse != b->isSparse)
return false;
return true;
}
/*
judge whether the three matrices are in the same type and size
>> a - input tensor
......@@ -642,6 +675,33 @@ void XTensor::Reshape(const int rowNum, const int colNum)
Reshape(2, dims);
}
/*
reshape the tensor by merging two consecutive dimensions
>> i - dimension i
>> j - i + 1
*/
void XTensor::ReshapeMerged(const int i, const int j)
{
if (i < 0)
return;
int di = i;
int dj = j < 0 ? i + 1 : j;
CheckNTErrors(di < order, "Wrong dimension index!");
int dims[MAX_TENSOR_DIM_NUM];
for (int k = 0; k < di; k++)
dims[k] = dimSize[k];
dims[di] = dimSize[di] * dimSize[dj];
for (int k = dj + 1; k < order; k++)
dims[k - 1] = dimSize[k];
Reshape(order - 1, dims);
}
/* get the number of items in the data array */
int XTensor::GetSize() const
{
......@@ -1268,6 +1328,21 @@ bool XTensor::Set(DTYPE value, int index[], int size)
}
/*
set the value of a cell with its offset in the array
>> value - the value we intend to set
>> offset - the offset in the array
*/
bool XTensor::Set(DTYPE value, int offset)
{
CheckNTErrors(offset >= 0 && offset < unitNum, "Invalid index!");
CheckNTErrors(data != NULL, "Cannot use an uninitialized tensor!");
DTYPE * d = (DTYPE*)data + offset;
return SetToDevice(devID, d, value);
}
/*
set the value of a cell in a 1d tensor
>> value - value we tend to set
>> i - item offset
......@@ -2086,6 +2161,48 @@ void InitTensor(XTensor * tensor,
}
/*
initialize a dense tensor V2
>> tensor - the tensor we intend to initialize
>> myOrder - order of the tensor
>> myDimSize - the size of each dimension
>> myDataType - unit size (e.g., int, float, and double)
>> myDenseRatio - how often an element has non-zero value
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site
*/
void InitTensorV2(XTensor * tensor,
const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType,
const int myDevID)
{
if(tensor->mem != NULL){
tensor->Resize(myOrder, myDimSize, myDataType, 1.0F);
}
else{
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, myDimSize, sizeof(int) * myOrder);
bool allocated = true;
for (int i = 0; i < myOrder; i++) {
if (dims[i] < 0)
allocated = false;
}
dims[0] = -abs(dims[0]);
if (myDevID == CURRENT_GPU)
tensor->devID = XDevice::GetGPUDevice();
else
tensor->devID = myDevID;
tensor->Resize(myOrder, dims, myDataType, 1.0F);
if(allocated)
XTensor::AllocateData(tensor);
}
}
/*
initialize a dense tensor
>> tensor - the tensor we intend to initialize
>> num - number of elements
......@@ -2107,6 +2224,24 @@ void InitTensor1D(XTensor * tensor, const int num,
}
/*
initialize a dense tensor V2
>> tensor - the tensor we intend to initialize
>> num - number of elements
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site
*/
void InitTensor1DV2(XTensor * tensor, const int num,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[1];
dims[0] = num;
InitTensorV2(tensor, 1, dims, myDataType, myDevID);
}
/*
initialize a dense matrix
>> tensor - the tensor we intend to initialize
>> rowNum - number of rows
......@@ -2130,6 +2265,26 @@ void InitTensor2D(XTensor * tensor, const int rowNum, const int colNum,
}
/*
initialize a dense matrix V2
>> tensor - the tensor we intend to initialize
>> rowNum - number of rows
>> colNum - number of columns
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site
*/
void InitTensor2DV2(XTensor * tensor, const int rowNum, const int colNum,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[2];
dims[0] = rowNum;
dims[1] = colNum;
InitTensorV2(tensor, 2, dims, myDataType, myDevID);
}
/*
initialize a dense 3d tensor
>> tensor - the tensor we intend to initialize
>> d0 - size of dimension 0
......@@ -2155,6 +2310,28 @@ void InitTensor3D(XTensor * tensor, const int d0, const int d1, const int d2,
}
/*
initialize a dense 3d tensor V2
>> tensor - the tensor we intend to initialize
>> d0 - size of dimension 0
>> d1 - size of dimension 1
>> d2 - size of dimension 2
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site
*/
void InitTensor3DV2(XTensor * tensor, const int d0, const int d1, const int d2,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[3];
dims[0] = d0;
dims[1] = d1;
dims[2] = d2;
InitTensorV2(tensor, 3, dims, myDataType, myDevID);
}
/*
initialize a dense 4d tensor
>> tensor - the tensor we intend to initialize
>> d0 - size of dimension 0
......@@ -2182,6 +2359,30 @@ void InitTensor4D(XTensor * tensor, const int d0, const int d1, const int d2, co
}
/*
initialize a dense 4d tensor V2
>> tensor - the tensor we intend to initialize
>> d0 - size of dimension 0
>> d1 - size of dimension 1
>> d2 - size of dimension 2
>> d3 - size of dimension 3
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site
*/
void InitTensor4DV2(XTensor * tensor, const int d0, const int d1, const int d2, const int d3,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[4];
dims[0] = d0;
dims[1] = d1;
dims[2] = d2;
dims[3] = d3;
InitTensorV2(tensor, 4, dims, myDataType, myDevID);
}
/*
initialize a dense 5d tensor
>> tensor - the tensor we intend to initialize
>> d0 - size of dimension 0
......@@ -2211,6 +2412,32 @@ void InitTensor5D(XTensor * tensor, const int d0, const int d1, const int d2, co
}
/*
initialize a dense 5d tensor V2
>> tensor - the tensor we intend to initialize
>> d0 - size of dimension 0
>> d1 - size of dimension 1
>> d2 - size of dimension 2
>> d3 - size of dimension 3
>> d4 - size of dimension 4
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site
*/
void InitTensor5DV2(XTensor * tensor, const int d0, const int d1, const int d2, const int d3, const int d4,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[5];
dims[0] = d0;
dims[1] = d1;
dims[2] = d2;
dims[3] = d3;
dims[4] = d4;
InitTensorV2(tensor, 5, dims, myDataType, myDevID);
}
/*
initialize a tensor with a reference tensor
>> tensor - the tensor we intend to initialize
>> reference - the reference tensor
......@@ -2220,12 +2447,28 @@ void InitTensor(XTensor * tensor, const XTensor * reference)
if(reference->order < 0)
return;
tensor->enableGrad = reference->enableGrad;
InitTensor(tensor, reference->order, reference->dimSize,
reference->dataType, reference->denseRatio,
reference->devID, reference->mem);
}
/*
initialize a tensor with a reference tensor V2
>> tensor - the tensor we intend to initialize
>> reference - the reference tensor
*/
void InitTensorV2(XTensor * tensor, const XTensor * reference)
{
if(reference->order < 0)
return;
tensor->enableGrad = reference->enableGrad;
InitTensorV2(tensor, reference->order, reference->dimSize,
reference->dataType, reference->devID);
}
/*
initialize a tensor on the CPU with a reference tensor
>> tensor - the tensor we intend to initialize
>> reference - the reference tensor
......@@ -2235,6 +2478,7 @@ void InitTensorOnCPU(XTensor * tensor, const XTensor * reference)
if(reference->order < 0)
return;
tensor->enableGrad = reference->enableGrad;
InitTensor(tensor, reference->order, reference->dimSize,
reference->dataType, reference->denseRatio,
-1);
......@@ -2273,6 +2517,23 @@ XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_
}
/*
generate a dense XTensor V2
>> myOrder - order of the tensor
>> myDimSize - the size of each dimension
>> myDataType - unit size (e.g., int, float, and double)
>> myDenseRatio - how often an element has non-zero value
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site.
*/
XTensor * NewTensorV2(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType,
const int myDevID)
{
XMem * myMem = GMems.GetMem(myDevID);
return new XTensor(myOrder, myDimSize, myDataType, 1.0F, myDevID, myMem);
}
/*
generate a XTensor which allocates data on the buffer
>> myOrder - order of the tensor
>> myDimSize - the size of each dimension
......@@ -2307,6 +2568,35 @@ XTensor * NewTensorBuf(const int myOrder, const int * myDimSize,
}
/*
generate a dense XTensor which allocates data on the buffer V2
>> myOrder - order of the tensor
>> myDimSize - the size of each dimension
>> devID - device id
>> myDataType - unit size (e.g., int, float, and double)
>> myDenseRatio - how often an element has non-zero value
*/
XTensor * NewTensorBufV2(const int myOrder, const int * myDimSize,
const TENSOR_DATA_TYPE myDataType, const int devID)
{
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, myDimSize, sizeof(int) * myOrder);
dims[0] = -abs(dims[0]);
XTensor * tensor = NewTensor(myOrder, dims, myDataType, 1.0F, devID);
if (tensor->unitNum * tensor->unitSize == 176657664) {
tensor->Dump(stderr, "", 200);
}
XMem * myMem = GMems.GetMem(devID);
tensor->data = myMem->AllocBuf(myMem->devID, tensor->unitNum * tensor->unitSize);
return tensor;
}
/*
generate a XTensor which allocates data on the buffer
>> reference - reference tensor
>> devID - device id
......@@ -2322,6 +2612,17 @@ XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem)
}
/*
generate a XTensor which allocates data on the buffer V2
>> reference - reference tensor
>> devID - device id
*/
XTensor * NewTensorBufV2(const XTensor * reference, int devID)
{
return NewTensorBufV2(reference->order, reference->dimSize,
reference->dataType, devID);
}
/*
generate a dense vector
>> num - number of entries
>> myDataType - unit size (e.g., int, float, and double)
......@@ -2342,6 +2643,23 @@ XTensor * NewTensor1D(const int num,
}
/*
generate a dense vector V2
>> num - number of entries
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site.
*/
XTensor * NewTensor1DV2(const int num,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[1];
dims[0] = num;
return NewTensorV2(1, dims, myDataType, myDevID);
}
/*
generate a dense matrix
>> rowNum - number of rows
>> colNum - number of colums
......@@ -2364,6 +2682,25 @@ XTensor * NewTensor2D(const int rowNum, const int colNum,
}
/*
generate a dense matrix V2
>> rowNum - number of rows
>> colNum - number of colums
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site.
*/
XTensor * NewTensor2DV2(const int rowNum, const int colNum,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[2];
dims[0] = rowNum;
dims[1] = colNum;
return NewTensorV2(2, dims, myDataType, myDevID);
}
/*
generate a dense 3d tensor
>> d0 - size of dimension 0
>> d1 - size of dimension 1
......@@ -2388,6 +2725,27 @@ XTensor * NewTensor3D(const int d0, const int d1, const int d2,
}
/*
generate a dense 3d tensor V2
>> d0 - size of dimension 0
>> d1 - size of dimension 1
>> d2 - size of dimension 2
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site.
*/
XTensor * NewTensor3DV2(const int d0, const int d1, const int d2,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[3];
dims[0] = d0;
dims[1] = d1;
dims[2] = d2;
return NewTensorV2(3, dims, myDataType, myDevID);
}
/*
generate a dense 4d tensor
>> d0 - size of dimension 0
>> d1 - size of dimension 1
......@@ -2414,6 +2772,29 @@ XTensor * NewTensor4D(const int d0, const int d1, const int d2, const int d3,
}
/*
generate a dense 4d tensor V2
>> d0 - size of dimension 0
>> d1 - size of dimension 1
>> d2 - size of dimension 2
>> d3 - size of dimension 3
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site.
*/
XTensor * NewTensor4DV2(const int d0, const int d1, const int d2, const int d3,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[4];
dims[0] = d0;
dims[1] = d1;
dims[2] = d2;
dims[3] = d3;
return NewTensorV2(4, dims, myDataType, myDevID);
}
/*
generate a dense 5d tensor
>> d0 - size of dimension 0
>> d1 - size of dimension 1
......@@ -2442,6 +2823,31 @@ XTensor * NewTensor5D(const int d0, const int d1, const int d2, const int d3, co
}
/*
generate a dense 5d tensor V2
>> d0 - size of dimension 0
>> d1 - size of dimension 1
>> d2 - size of dimension 2
>> d3 - size of dimension 3
>> d4 - size of dimension 4
>> myDataType - unit size (e.g., int, float, and double)
>> myDevID - when myMem is NULL, myDevID specifies the device
on which we allocate the data on site.
*/
XTensor * NewTensor5DV2(const int d0, const int d1, const int d2, const int d3, const int d4,
const TENSOR_DATA_TYPE myDataType, const int myDevID)
{
int dims[5];
dims[0] = d0;
dims[1] = d1;
dims[2] = d2;
dims[3] = d3;
dims[4] = d4;
return NewTensorV2(5, dims, myDataType, myDevID);
}
/*
generate a copy of XTensor
>> a - the tensor we copy from
>> isFilledData - indicates whether we allocate the data for
......
......@@ -151,6 +151,9 @@ public:
/* indicates whether the tensor keeps the gradient when used as model parameters */
bool isGrad;
/* indicates whether the gradient of the tensor should be computed */
bool enableGrad;
/* indicates whether the tensor is used as paramters (or variables) */
bool isVar;
......@@ -194,9 +197,7 @@ public:
XTensor(const XTensor &reference);
/* copy constructor (with right value reference) */
#ifdef USE_CPP11
XTensor(const XTensor &&reference);
#endif
/* de-constructor */
~XTensor();
......@@ -217,9 +218,7 @@ public:
XTensor& operator= (const XTensor &tensor);
/* overloading of the equal-sign (with right value reference) */
#ifdef USE_CPP11
XTensor& operator= (const XTensor &&tensor);
#endif
/* overloading of the plus-sign */
XTensor operator+ (const XTensor &tensor) const;
......@@ -259,6 +258,10 @@ public:
static
bool IsSameShaped(const XTensor * a, const XTensor * b, const XTensor * c);
/* judge whether b is the reduced shape of a ?? */
static
bool IsReduceShaped(const XTensor * a, const XTensor * b, int dim);
/* set the size of each dimension */
void SetDim(int * myDimSize);
......@@ -274,6 +277,9 @@ public:
/* reshape the tensor to a matrix */
void Reshape(const int rowNum, const int colNum);
/* reshape the tensor by merging two consecutive dimensions */
void ReshapeMerged(const int i, const int j = -1);
/* get the number of items in the data array */
int GetSize() const;
......@@ -358,6 +364,9 @@ public:
/* set the value of a cell */
bool Set(DTYPE value, int index[], int size = -1);
/* set the value of a cell with its offset in the array */
bool Set(DTYPE value, int offset);
/* set the value of a cell in a 1d tensor */
bool Set1D(DTYPE value, int i);
......@@ -445,29 +454,57 @@ void InitTensor(XTensor * tensor,
const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const float myDenseRatio = 1.0F, const int myDevID = -1, XMem * myMem = NULL);
/* initialize a dense XTensor V2 */
void InitTensorV2(XTensor * tensor,
const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1);
/* initialize a dense vector */
void InitTensor1D(XTensor * tensor, const int num,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1, XMem * myMem = NULL);
/* initialize a dense vector V2 */
void InitTensor1DV2(XTensor * tensor, const int num,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1);
/* initialize a dense matrix */
void InitTensor2D(XTensor * tensor, const int rowNum, const int colNum,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1, XMem * myMem = NULL);
/* initialize a dense matrix V2 */
void InitTensor2DV2(XTensor * tensor, const int rowNum, const int colNum,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1);
/* initialize a dense 3d tensor */
void InitTensor3D(XTensor * tensor, const int d0, const int d1, const int d2,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1, XMem * myMem = NULL);
/* initialize a dense 3d tensor V2 */
void InitTensor3DV2(XTensor * tensor, const int d0, const int d1, const int d2,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1);
/* initialize a dense 4d tensor */
void InitTensor4D(XTensor * tensor, const int d0, const int d1, const int d2, const int d3,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1, XMem * myMem = NULL);
/* initialize a dense 4d tensor V2 */
void InitTensor4DV2(XTensor * tensor, const int d0, const int d1, const int d2, const int d3,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1);
/* initialize a dense 5d tensor */
void InitTensor5D(XTensor * tensor, const int d0, const int d1, const int d2, const int d3, const int d4,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1, XMem * myMem = NULL);
/* initialize a dense 5d tensor V2 */
void InitTensor5DV2(XTensor * tensor, const int d0, const int d1, const int d2, const int d3, const int d4,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1);
/* initialize a tensor with a reference tensor */
void InitTensor(XTensor * tensor, const XTensor * reference);
/* initialize a tensor with a reference tensor */
void InitTensorV2(XTensor * tensor, const XTensor * reference);
/* initialize a tensor on the CPU with a reference tensor */
void InitTensorOnCPU(XTensor * tensor, const XTensor * reference);
......@@ -478,38 +515,72 @@ XTensor * NewTensor();
XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const float myDenseRatio = 1.0F, const int myDevID = -1, XMem * myMem = NULL);
/* generate a dense XTensor V2 */
XTensor * NewTensorV2(const int myOrder, const int * myDimSize, const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1);
/* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const float myDenseRatio = 1.0F,
const int myDevID = -1, XMem * myMem = NULL);
/* generate a dense XTensor which allocates data on the buffer V2 */
XTensor * NewTensorBufV2(const int myOrder, const int * myDimSize,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1);
/* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem);
/* generate a XTensor which allocates data on the buffer V2 */
XTensor * NewTensorBufV2(const XTensor * reference, int devID);
/* generate a dense vector */
XTensor * NewTensor1D(const int num, const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1,
XMem * myMem = NULL);
/* generate a dense vector V2 */
XTensor * NewTensor1DV2(const int num, const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1);
/* generate a dense matrix */
XTensor * NewTensor2D(const int rowNum, const int colNum,
const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1, XMem * myMem = NULL);
/* generate a dense matrix V2 */
XTensor * NewTensor2DV2(const int rowNum, const int colNum,
const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1);
/* generate a dense 3d tensor */
XTensor * NewTensor3D(const int d0, const int d1, const int d2,
const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1, XMem * myMem = NULL);
/* generate a dense 3d tensor V2 */
XTensor * NewTensor3DV2(const int d0, const int d1, const int d2,
const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1);
/* generate a dense 4d tensor */
XTensor * NewTensor4D(const int d0, const int d1, const int d2, const int d3,
const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1, XMem * myMem = NULL);
/* generate a dense 4d tensor V2 */
XTensor * NewTensor4DV2(const int d0, const int d1, const int d2, const int d3,
const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1);
/* generate a dense 5d tensor */
XTensor * NewTensor5D(const int d0, const int d1, const int d2, const int d3, const int d4,
const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1, XMem * myMem = NULL);
/* generate a dense 5d tensor V2 */
XTensor * NewTensor5DV2(const int d0, const int d1, const int d2, const int d3, const int d4,
const TENSOR_DATA_TYPE myDataType = X_FLOAT,
const int myDevID = -1);
/* generate a copy of XTensor (with a reference to a given tensor) */
XTensor * NewTensor(const XTensor * a, bool isFilledData = true);
......
......@@ -97,4 +97,5 @@
#include "utilities/XMatrixSegment.h"
#include "utilities/FlushToMem.h"
#include "../function/DropoutWithIndex.h"
#endif // __CHEADER_H__
......@@ -218,4 +218,55 @@ XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
return c;
}
/*
element-wise division of two tensors
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
>> requireLink - if add operation to network
*/
void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadingDim, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
int n = GetDivDimIndex(a, b);
if (n == -1) {
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Div function */
_Div(&a, &b, &c, 0, leadingDim);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
}
else if (n >= 0 && n < a.order) {
/* call _DivDim function */
_DivDim(&a, &b, &c, n, alpha);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
}
else {
ShowNTErrors("Something is wrong!");
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -49,6 +49,13 @@ where i is the index of the element
*/
XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise division of two tensors:
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the element
*/
void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha = 0.0, int leadingDim = 0, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __DIV_H__
\ No newline at end of file
......@@ -171,4 +171,36 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
return c;
}
/*
tensor division
c = a / b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put result. we save it in a if c is NULL
>> n - the dimension index
>> alpha - the scaling factor
>> requireLink - if add operation to network
*/
void DivDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE alpha, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
/* call _Div function */
_DivDim(&a, &b, &c, n, alpha);
if (c.enableGrad == true) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
}
}
......@@ -53,6 +53,14 @@ we make a new tensor c to keep the result and return it
*/
XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha = (DTYPE)0.0);
/*
tensor division of two tensors:
c(i) = a/b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is divided with b by broadcasting
*/
void DivDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE alpha = (DTYPE)0.0, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __DIVDIM_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2019-04-24
* I'll attend several conferences and workshops in the following weeks -
* busy days :(
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "Mask.h"
#include "Mask.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
mask entries of a given tensor:
c(i) = a(i) if mask(i) is non-zero
c(i) = alpha if mask(i) = 0
where i is the index of the element
*/
void _Mask(const XTensor * a, const XTensor * mask, XTensor * c, DTYPE alpha)
{
CheckNTErrors(a && mask && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == mask->unitNum && a->unitNum == c->unitNum,
"Unmatched tensors in addition!");
CheckNTErrors(mask->dataType == X_INT, "The mask tensor must be in X_INT!")
//CheckNTErrors(a->dataType == mask->dataType && a->dataType == c->dataType,
// "Unmatched tensors in addition!");
if (a->devID >= 0 || mask->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
if (a == c) {
int P2PAccesible = 0;
#ifdef CUDA_UVA
cudaDeviceCanAccessPeer(&P2PAccesible, a->devID, b->devID);
#endif
if ((a->devID < 0 && mask->devID >= 0) ||
(a->devID >= 0 && mask->devID < 0) ||
(a->devID >= 0 && mask->devID >= 0 && a->devID != mask->devID && !P2PAccesible))
{
ShowNTErrors("Cannot run this method on multiple devices simultaneously!");
}
else
_CudaMask(a, mask, c, alpha);
}
else
_CudaMask(a, mask, c, alpha);
#endif
}
else {
if (!a->isSparse && !mask->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse tensor in addition!");
if (a->dataType == DEFAULT_DTYPE &&
mask->dataType == X_INT &&
c->dataType == DEFAULT_DTYPE)
{
DTYPE * ap = (DTYPE*)a->data;
int * maskp = (int*)mask->data;
DTYPE * cp = (DTYPE*)c->data;
/* unrolling */
int num = a->unitNum;
if (num % 2 == 0) {
for (int i = 0; i < num; i += 2) {
if (maskp[i] == 0) {
cp[i] = alpha;
}
else {
cp[i] = ap[i];
}
if (maskp[i + 1] == 0) {
cp[i + 1] = alpha;
}
else {
cp[i + 1] = ap[i + 1];
}
}
}
else {
for (int i = 0; i < num; i++) {
if (maskp[i] == 0) {
cp[i] = alpha;
}
else {
cp[i] = ap[i];
}
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
}
/*
mask entries of a given tensor (on site):
a(i) = a(i) if mask(i) is non-zero
a(i) = alpha if mask(i) = 0
where i is the index of the element
*/
void _MaskMe(XTensor * a, const XTensor * mask, DTYPE alpha)
{
_Mask(a, mask, a, alpha);
}
/*
mask entries of a given tensor (return an XTensor structure):
a(i) = a(i) if mask(i) is non-zero
a(i) = alpha if mask(i) = 0
where i is the index of the element
*/
XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha)
{
XTensor c(&a);
c.SetTMPFlag();
/* call _Sum function */
_Mask(&a, &mask, &c, alpha);
/* tensor connections */
//XLink::MakeLink(&a, &mask, &c, MATH_SUM);
//XLink::AddParamToHead(&c, alpha);
// TODO!!
ShowNTErrors("TODO!");
return c;
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2019-04-24
* I'll attend several conferences and workshops in the following weeks -
* busy days :(
*/
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "Sub.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
mask entries of a given tensor (CUDA Kernel)
c = a - b * \beta
>> a - A matrix
>> mask - mask matrix
>> c - where we put masked a
>> size - the size of a/b/c
>> alpha - value
*/
__global__
void KernelMASK(DTYPE * a, int * mask, DTYPE * c, int size, DTYPE alpha)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
if (mask[i] == 0) {
c[i] = alpha;
}
else {
c[i] = a[i];
}
}
}
/*
mask entries of a given tensor (cuda version)
>> a - a tensor
>> mask - mask tensor
>> c - where we put masked a
>> alpha - value
*/
void _CudaMask(const XTensor * a, const XTensor * mask, XTensor * c, DTYPE alpha)
{
CheckNTErrors(a && mask && c, "Empty tensor input!");
CheckNTErrors((a->unitNum == mask->unitNum && a->unitNum == c->unitNum),
"Unmatched tensors in addition!");
CheckNTErrors(mask->dataType == X_INT, "The mask tensor must be in X_INT!")
//CheckNTErrors((a->dataType == mask->dataType && a->dataType == c->dataType),
// "Unmatched tensors in addition!");
CheckNTErrors((a->devID == mask->devID && a->devID == c->devID),
"The tensors must be on the same!");
int devIDBackup = XDevice::GetGPUDevice();
XDevice::SetGPUDevice(a->devID);
if (!a->isSparse && !mask->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse matrix in addition!");
if (a->dataType == DEFAULT_DTYPE &&
mask->dataType == X_INT &&
c->dataType == DEFAULT_DTYPE)
{
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
KernelMASK << <blocks, threads >> >((DTYPE*)a->data, (int *)mask->data, (DTYPE*)c->data, a->unitNum, alpha);
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
XDevice::SetGPUDevice(devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2019-04-24
* I'll attend several conferences and workshops in the following weeks -
* busy days :(
*/
#ifndef __MASK_CUH__
#define __MASK_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* mask entries of a given tensor (cuda version) */
void _CudaMask(const XTensor * a, const XTensor * mask, XTensor * c = NULL, DTYPE alpha = (DTYPE)1.0);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __MASK_CUH__
\ No newline at end of file
......@@ -202,6 +202,42 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
delete cList;
}
bool CheckMMulShape(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c)
{
if (!(a && b && c))
return false;
if(!(a->dataType == b->dataType && a->dataType == c->dataType))
return false;
if (!(a->order >= 2 && b->order >= 2 && c->order >= 2))
return false;
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
int bn = transposedB == X_TRANS ? b->dimSizeRDI[0] : b->dimSizeRDI[1];
int bm = transposedB == X_TRANS ? b->dimSizeRDI[1] : b->dimSizeRDI[0];
CheckNTErrors(am == bn, "Unmatched tensors in multiplication!");
int order = a->order + b->order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < a->order; i++)
dimSize[sub++] = a->dimSizeRDI[a->order + 1 - i];
for (int i = 2; i < b->order; i++)
dimSize[sub++] = b->dimSizeRDI[b->order + 1 - i];
dimSize[sub++] = an;
dimSize[sub++] = bm;
for (int i = 0; i < order; i++) {
if (dimSize[i] != c->dimSize[i])
return false;
}
return true;
}
/*
matrix multiplication (return an XTensor structure) c = trans(a) * trans(b) * alpha
make a new tensor to keep the result and return it
......@@ -266,6 +302,53 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
return c;
}
void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
const XTensor &b, MATRIX_TRANS_TYPE transposedB, XTensor &c,
DTYPE alpha, XPRunner * parallelRunner, bool requireLink)
{
CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
if (!c.isInit || !CheckMMulShape(&a, transposedA, &b, transposedB, &c)) {
int an = transposedA == X_TRANS ? a.dimSizeRDI[0] : a.dimSizeRDI[1];
int am = transposedA == X_TRANS ? a.dimSizeRDI[1] : a.dimSizeRDI[0];
int bn = transposedB == X_TRANS ? b.dimSizeRDI[0] : b.dimSizeRDI[1];
int bm = transposedB == X_TRANS ? b.dimSizeRDI[1] : b.dimSizeRDI[0];
CheckNTErrors(am == bn, "Unmatched tensors in multiplication!");
int order = a.order + b.order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < a.order; i++)
dimSize[sub++] = a.dimSizeRDI[a.order + 1 - i];
for (int i = 2; i < b.order; i++)
dimSize[sub++] = b.dimSizeRDI[b.order + 1 - i];
dimSize[sub++] = an;
dimSize[sub++] = bm;
float dr = (!a.isSparse || !b.isSparse) ? 1.0F : MAX(a.denseRatio, b.denseRatio);
InitTensor(&c, order, dimSize, a.dataType, dr, a.devID, a.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _MatrixMul function */
_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);
}
}
/*
matrix multiplication with no transposition c = a * b * alpha
>> a - tensor a
......@@ -316,6 +399,52 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
return c;
}
void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
DTYPE alpha, XPRunner * parallelRunner, bool requireLink)
{
CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
if (!c.isInit || !CheckMMulShape(&a, X_NOTRANS, &b, X_NOTRANS, &c)) {
int an = a.dimSizeRDI[1];
int am = a.dimSizeRDI[0];
int bn = b.dimSizeRDI[1];
int bm = b.dimSizeRDI[0];
CheckNTErrors(am == bn, "Unmatched tensors in multiplication!");
int order = a.order + b.order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < a.order; i++)
dimSize[sub++] = a.dimSizeRDI[a.order + 1 - i];
for (int i = 2; i < b.order; i++)
dimSize[sub++] = b.dimSizeRDI[b.order + 1 - i];
dimSize[sub++] = an;
dimSize[sub++] = bm;
float dr = (!a.isSparse || !b.isSparse) ? 1.0F : MAX(a.denseRatio, b.denseRatio);
InitTensor(&c, order, dimSize, a.dataType, dr, a.devID, a.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _MatrixMul function */
_MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha);
}
}
} // namespace nts(NiuTrans.Tensor)
......
......@@ -59,10 +59,16 @@ Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x
XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
XTensor &c, DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL, bool requireLink = false);
/* matrix multiplication with no transposition c = a * b * alpha*/
XTensor MatrixMul(const XTensor &a, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -117,7 +117,6 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
ShowNTErrors("Something is wrong!");
}
/* tensor connections */
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
......
......@@ -219,4 +219,55 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim
return c;
}
/*
element-wise product of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
>> requireLink - if add operation to network
*/
void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadingDim, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
int n = GetMultiplyDimIndex(a, b);
if (n == -1) {
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
/* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
}
else if (n >= 0 && n < a.order) {
/* call _MultiplyDim function */
_MultiplyDim(&a, &b, &c, n, alpha);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
}
else {
ShowNTErrors("Something is wrong!");
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -49,6 +49,13 @@ where i is the index of the element
*/
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha = 0.0, int leadingDim = 0);
/*
element-wise product of two tensors:
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the element
*/
void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha = 0.0, int leadingDim = 0, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __MULTIPLY_H__
\ No newline at end of file
......@@ -170,6 +170,36 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n)
}
/*
tensor multiplication
c = a * b + \alpha * c
where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a * b + \alpha * c. we save it in a if c is NULL
>> n - the dimension index
>> requireLink - if add operation to network
*/
void MultiplyDim(const XTensor &a, const XTensor &b, XTensor &c, int n, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
/* call _Multiply function */
_MultiplyDim(&a, &b, &c, n, 0);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, 0);
}
}
/*
tensor broadcast multiplication
c = a * b + c * \beta
where some of dimensions of b can be of size 1
......@@ -309,4 +339,30 @@ XTensor MultiplyBroadcast(const XTensor &a, const XTensor &b)
return c;
}
/*
tensor broadcast multiplication
c = a * b + c * \beta
where some of dimensions of b can be of size 1
>> a - a tensor
>> b - another tensor that would be broadcasted
>> c - the resulting tensor
>> requireLink - if add operation to network
*/
void MultiplyBroadcast(const XTensor &a, const XTensor &b, XTensor &c, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
/* call _SumBroadcast function */
_MultiplyBroadcast(&a, &b, &c, 0);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST);
XLink::AddParamToHead(&c, 0);
}
}
}
......@@ -38,6 +38,10 @@ void _MultiplyDimMe(XTensor * a, const XTensor * b, int n, DTYPE alpha = 0.0);
i.e., a is multiplied with b by broadcasting. We make a new tensor c to keep the result and return it */
XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n);
/* tensor multiplication c = a * b + \alpha * c where the size of b is equal to the n-th dimension of a,
i.e., a is multiplied with b by broadcasting */
void MultiplyDim(const XTensor &a, const XTensor &b, XTensor &c, int n, bool requireLink = false);
/* tensor multiplication summation c = a * b + c * \beta where some of dimensions of b can be of size 1 */
void _MultiplyBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
......@@ -45,6 +49,9 @@ void _MultiplyBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE
we return the resulting tensor here */
XTensor MultiplyBroadcast(const XTensor &a, const XTensor &b);
/* tensor multiplication summation c = a * b + c * \beta where some of dimensions of b can be of size 1 */
void MultiplyBroadcast(const XTensor &a, const XTensor &b, XTensor &c, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __MULTIPLYDIM_H__
......@@ -79,4 +79,25 @@ XTensor Negate(const XTensor & a)
return b;
}
/*
set every entry to its minus value
>> a - input tensor we are processing
>> b - output tensor we are processing
>> requireLink - if add operation to network
*/
void Negate(const XTensor & a, XTensor & b, bool requireLink)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _Negate function */
_Negate(&a, &b);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_NEGATE);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -41,6 +41,9 @@ make a new tensor to keep the result and return it
*/
XTensor Negate(const XTensor & a);
/* set every entry to its minus value */
void Negate(const XTensor & a, XTensor & b, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __NEGATE_H__
......@@ -84,4 +84,25 @@ XTensor Sign(const XTensor & a)
return b;
}
/*
set every entry to its sign value
>> a - input tensor we are processing
>> b - output tensor we are processing
>> requireLink - if add operation to network
*/
void Sign(const XTensor & a, XTensor & b, bool requireLink)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _Sign function */
_Sign(&a, &b);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SIGN);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -41,6 +41,9 @@ make a new tensor to keep the result and return it
*/
XTensor Sign(const XTensor & a);
/* set every entry to its sign value */
void Sign(const XTensor & a, XTensor & b, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __SIGN_H__
......@@ -196,4 +196,47 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
return c;
}
/*
tensor subtraction c = a - b * \beta
>> a - a tensor
>> b - another tensor
>> c - where we put a-b*\beta. we save it in a if c is NULL
>> beta - the scaling factor
>> requireLink - if add operation to network
*/
void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
int n = GetSubDimIndex(a, b);
if (n == -1) {
/* call _Sub function */
_Sub(&a, &b, &c, beta);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
}
}
else if (n >= 0 && n < a.order) {
/* call _SubDim function */
_SubDim(&a, &b, &c, n, beta);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
}
else {
ShowNTErrors("Something is wrong!");
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -42,6 +42,9 @@ make a new tensor c to keep the result and return it
*/
XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta */
void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta = (DTYPE)1.0, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUB_H__
......@@ -171,4 +171,35 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
return c;
}
/*
tensor subtraction
c = a - b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a-b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
>> requireLink - if add operation to network
*/
void SubDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
/* call _Sub function */
_SubDim(&a, &b, &c, n, beta);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
}
}
......@@ -38,6 +38,10 @@ void _SubDim(XTensor * a, const XTensor * b, int n, DTYPE beta = (DTYPE)1.0);
i.e., a is subtracted with b by broadcasting. We make a new tensor c to keep the result and return it */
XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is subtracted with b by broadcasting*/
void SubDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta = (DTYPE)1.0, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUBDIM_H__
......@@ -201,4 +201,46 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
return c;
}
/*
tensor summation c = a + b * \beta
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
>> requireLink - if add operation to network
*/
void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
int n = GetSumDimIndex(a, b);
if (n == -1) {
/* call _Sum function */
_Sum(&a, &b, &c, beta);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta);
}
}
else if (n >= 0 && n < a.order) {
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
}
else {
ShowNTErrors("Something is wrong!");
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -41,6 +41,9 @@ make a new tensor c to keep the result and return it
*/
XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta */
void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta = (DTYPE)1.0, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUM_H__
......@@ -189,6 +189,37 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
}
/*
tensor summation
c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
>> requireLink - if add operation to network
*/
void SumDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
}
/*
tensor broadcast summation c = a + b * \beta where some of dimensions of b can be of size 1
c = a + b * \beta
......@@ -329,4 +360,30 @@ XTensor SumBroadcast(const XTensor &a, const XTensor &b, DTYPE beta)
return c;
}
/*
tensor broadcast summation c = a + b * \beta where some of dimensions of b can be of size 1
c = a + b * \beta
>> a - a tensor
>> b - another tensor that would be broadcasted
>> c - the resulting tensor
>> beta - the scaling factor
>> requireLink - if add operation to network
*/
void SumBroadcast(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta, bool requireLink)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
/* call _SumBroadcast function */
_SumBroadcast(&a, &b, &c, beta);
if (c.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST);
XLink::AddParamToHead(&c, beta);
}
}
}
......@@ -42,6 +42,10 @@ void _SumDim(XTensor * a, const XTensor * b, int n, DTYPE beta = (DTYPE)1.0);
i.e., a is summed with b by broadcasting. We make a new tensor c to keep the result and return it */
XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting */
void SumDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta = (DTYPE)1.0, bool requireLink = false);
/* tensor broadcast summation c = a + b * \beta where some of dimensions of b can be of size 1 */
void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
......@@ -49,6 +53,9 @@ void _SumBroadcast(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta
we return the resulting tensor here */
XTensor SumBroadcast(const XTensor &a, const XTensor &b, DTYPE beta = (DTYPE)1.0);
/* tensor broadcast summation c = a + b * \beta where some of dimensions of b can be of size 1 */
void SumBroadcast(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta = (DTYPE)1.0, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUMDIM_H__
......@@ -111,9 +111,10 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size, float labelSmoo
onehot->SetZeroAll();
#ifdef USE_CUDA
float confidence = 1 - labelSmoothingP;
float lowconfidence = labelSmoothingP / size;
#ifdef USE_CUDA
if(onehot->devID >= 0 && index->devID >= 0) {
_CudaIndexToOnehot(index, onehot, size, confidence, lowconfidence);
return;
......@@ -129,8 +130,7 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size, float labelSmoo
for (int i = 0; i < blockNum; i++) {
int id = indexData[i];
DTYPE * od = onehotData + i * stride;
od[id] = 2;
//onehotData[i * stride + id] = 1;
od[id] = 1;
}
}
......
......@@ -31,16 +31,31 @@ int scale(int x, int scale)
return x * scale;
}
float scale(float x, float scale)
{
return x * scale;
}
int descale(int x, int descale)
{
return x / descale;
}
float descale(float x, float descale)
{
return x / descale;
}
int shift(int x, int shift)
{
return x + shift;
}
float shift(float x, float shift)
{
return x + shift;
}
int mod(int x, int mod)
{
return x % mod;
......@@ -48,7 +63,7 @@ int mod(int x, int mod)
#ifdef USE_CUDA
/* define three marco separately, specify the respective function names (GPU mode) */
#define _SIMPLE_BINARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
#define _SIMPLE_BINARY_FUNCTION_INT(_funcName, _cudaFuncName, origFunc) \
void _funcName(const XTensor * a, XTensor * b, int num) \
{ \
/* run it on GPUs */ \
......@@ -58,82 +73,188 @@ void _funcName(const XTensor * a, XTensor * b, int num) \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same data type!"); \
CheckNTErrors((a->dataType == X_INT), "TODO!"); \
CheckNTErrors((a->dataType == X_INT&&b->dataType == X_INT), "TODO!"); \
int * d = (int*)a->data; \
int * db = (int*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (int)origFunc(d[i], num); \
} \
#define _SIMPLE_BINARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
void _funcName(const XTensor * a, XTensor * b, float num) \
{ \
/* run it on GPUs */ \
if (a->devID >= 0) { \
_cudaFuncName(a, b, num); \
return; \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same data type!"); \
CheckNTErrors((a->dataType == X_FLOAT&&b->dataType == X_FLOAT), "TODO!");\
float * d = (float*)a->data; \
float * db = (float*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (float)origFunc(d[i], num); \
}
#define SIMPLE_BINARY_FUNCTION_ME(funcName, _funcName) \
#define SIMPLE_BINARY_FUNCTION_ME_INT(funcName, _funcName) \
void funcName(XTensor &a, int num) \
{ \
_funcName(&a, &a, num); \
}
} \
#define SIMPLE_BINARY_FUNCTION(funcName, _funcName) \
#define SIMPLE_BINARY_FUNCTION_ME(funcName, _funcName) \
void funcName(XTensor &a, float num) \
{ \
_funcName(&a, &a, num); \
} \
#define SIMPLE_BINARY_FUNCTION_INT(funcName, _funcName) \
void funcName(const XTensor &a, XTensor &b, int num) \
{ \
_funcName(&a, &b, num); \
}
} \
_SIMPLE_BINARY_FUNCTION(_Scale, _CudaScale, scale)
SIMPLE_BINARY_FUNCTION_ME(Scale, _Scale)
SIMPLE_BINARY_FUNCTION(Scale, _Scale)
#define SIMPLE_BINARY_FUNCTION(funcName, _funcName, operationId) \
XTensor funcName(const XTensor &a, float num) \
{ \
XTensor b(&a); \
b.SetTMPFlag(); \
_funcName(&a, &b, num); \
XLink::MakeLink(&a, NULL, &b, operationId); \
return b; \
} \
_SIMPLE_BINARY_FUNCTION(_Descale, _CudaDescale, descale)
SIMPLE_BINARY_FUNCTION_ME(Descale, _Descale)
SIMPLE_BINARY_FUNCTION(Descale, _Descale)
#define SIMPLE_BINARY_FUNCTION_VOID(funcName, _funcName, operationId) \
void funcName(const XTensor &a, XTensor &b, float num, bool requireLink) \
{ \
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) { \
InitTensor(&b, &a); \
} \
_funcName(&a, &b, num); \
if (b.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
} \
_SIMPLE_BINARY_FUNCTION(_Shift, _CudaShift, shift)
SIMPLE_BINARY_FUNCTION_ME(Shift, _Shift)
SIMPLE_BINARY_FUNCTION(Shift, _Shift)
_SIMPLE_BINARY_FUNCTION_INT(_Scale, _CudaScale, scale)
SIMPLE_BINARY_FUNCTION_ME_INT(_ScaleMe, _Scale)
SIMPLE_BINARY_FUNCTION_INT(Scale, _Scale)
_SIMPLE_BINARY_FUNCTION(_Scale, _CudaScaleFloat, scale)
SIMPLE_BINARY_FUNCTION_ME(_ScaleMe, _Scale)
SIMPLE_BINARY_FUNCTION(Scale, _Scale, MATH_SCALE)
SIMPLE_BINARY_FUNCTION_VOID(Scale, _Scale, MATH_SCALE)
_SIMPLE_BINARY_FUNCTION_INT(_Descale, _CudaDescale, descale)
SIMPLE_BINARY_FUNCTION_ME_INT(_DescaleMe, _Descale)
SIMPLE_BINARY_FUNCTION_INT(Descale, _Descale)
_SIMPLE_BINARY_FUNCTION(_Descale, _CudaDescaleFloat, descale)
SIMPLE_BINARY_FUNCTION_ME(_DescaleMe, _Descale)
SIMPLE_BINARY_FUNCTION(Descale, _Descale, MATH_DESCALE)
SIMPLE_BINARY_FUNCTION_VOID(Descale, _Descale, MATH_DESCALE)
_SIMPLE_BINARY_FUNCTION_INT(_Shift, _CudaShift, shift)
SIMPLE_BINARY_FUNCTION_ME_INT(_ShiftMe, _Shift)
SIMPLE_BINARY_FUNCTION_INT(Shift, _Shift)
_SIMPLE_BINARY_FUNCTION(_Shift, _CudaShiftFloat, shift)
SIMPLE_BINARY_FUNCTION_ME(_ShiftMe, _Shift)
SIMPLE_BINARY_FUNCTION(Shift, _Shift, MATH_SHIFT)
SIMPLE_BINARY_FUNCTION_VOID(Shift, _Shift, MATH_SHIFT)
_SIMPLE_BINARY_FUNCTION(_Mod, _CudaMod, mod)
SIMPLE_BINARY_FUNCTION_ME(Mod, _Mod)
SIMPLE_BINARY_FUNCTION(Mod, _Mod)
_SIMPLE_BINARY_FUNCTION_INT(_Mod, _CudaMod, mod)
SIMPLE_BINARY_FUNCTION_ME_INT(_ModMe, _Mod)
SIMPLE_BINARY_FUNCTION_INT(Mod, _Mod)
#else
/* define three marco separately, specify the respective function names (CPU mode) */
#define _SIMPLE_BINARY_FUNCTION(_funcName, origFunc) \
#define _SIMPLE_BINARY_FUNCTION_INT(_funcName, _cudaFuncName, origFunc) \
void _funcName(const XTensor * a, XTensor * b, int num) \
{ \
/* run it on GPUs */ \
if (a->devID >= 0) { \
_cudaFuncName(a, b, num); \
return; \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same data type!"); \
CheckNTErrors((a->dataType == X_INT), "TODO!"); \
CheckNTErrors((a->dataType == X_INT&&b->dataType == X_INT), "TODO!"); \
int * d = (int*)a->data; \
int * db = (int*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (int)origFunc(d[i], num); \
} \
#define _SIMPLE_BINARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
void _funcName(const XTensor * a, XTensor * b, float num) \
{ \
/* run it on GPUs */ \
if (a->devID >= 0) { \
_cudaFuncName(a, b, num); \
return; \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same data type!"); \
CheckNTErrors((a->dataType == X_FLOAT&&b->dataType == X_FLOAT), "TODO!");\
float * d = (float*)a->data; \
float * db = (float*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (float)origFunc(d[i], num); \
}
#define SIMPLE_BINARY_FUNCTION_ME_INT(funcName, _funcName) \
void funcName(XTensor &a, int num) \
{ \
_funcName(&a, &a, num); \
} \
#define SIMPLE_BINARY_FUNCTION_ME(funcName, _funcName) \
void funcName(XTensor & a, int num) \
void funcName(XTensor &a, float num) \
{ \
_funcName(&a, &a, num); \
}
} \
#define SIMPLE_BINARY_FUNCTION_INT(funcName, _funcName) \
void funcName(const XTensor &a, XTensor &b, int num) \
{ \
_funcName(&a, &b, num); \
} \
#define SIMPLE_BINARY_FUNCTION(funcName, _funcName) \
void funcName(const XTensor & a, XTensor &b, int num) \
void funcName(const XTensor &a, XTensor &b, float num) \
{ \
_funcName(&a, &b, num); \
}
} \
_SIMPLE_BINARY_FUNCTION_INT(_Scale, _CudaScale, scale)
SIMPLE_BINARY_FUNCTION_ME_INT(Scale, _Scale)
SIMPLE_BINARY_FUNCTION_INT(Scale, _Scale)
_SIMPLE_BINARY_FUNCTION(_Scale, scale)
_SIMPLE_BINARY_FUNCTION(_Scale, _CudaScaleFloat, scale)
SIMPLE_BINARY_FUNCTION_ME(Scale, _Scale)
SIMPLE_BINARY_FUNCTION(Scale, _Scale)
_SIMPLE_BINARY_FUNCTION(_Descale, descale)
_SIMPLE_BINARY_FUNCTION_INT(_Descale, _CudaDescale, descale)
SIMPLE_BINARY_FUNCTION_ME_INT(Descale, _Descale)
SIMPLE_BINARY_FUNCTION_INT(Descale, _Descale)
_SIMPLE_BINARY_FUNCTION(_Descale, _CudaDescaleFloat, descale)
SIMPLE_BINARY_FUNCTION_ME(Descale, _Descale)
SIMPLE_BINARY_FUNCTION(Descale, _Descale)
_SIMPLE_BINARY_FUNCTION(_Shift, shift)
_SIMPLE_BINARY_FUNCTION_INT(_Shift, _CudaShift, shift)
SIMPLE_BINARY_FUNCTION_ME_INT(Shift, _Shift)
SIMPLE_BINARY_FUNCTION_INT(Shift, _Shift)
_SIMPLE_BINARY_FUNCTION(_Shift, _CudaShiftFloat, shift)
SIMPLE_BINARY_FUNCTION_ME(Shift, _Shift)
SIMPLE_BINARY_FUNCTION(Shift, _Shift)
_SIMPLE_BINARY_FUNCTION(_Mod, mod)
SIMPLE_BINARY_FUNCTION_ME(Mod, _Mod)
SIMPLE_BINARY_FUNCTION(Mod, _Mod)
_SIMPLE_BINARY_FUNCTION_INT(_Mod, _CudaMod, mod)
SIMPLE_BINARY_FUNCTION_ME_INT(Mod, _Mod)
SIMPLE_BINARY_FUNCTION_INT(Mod, _Mod)
#endif
......
......@@ -36,18 +36,36 @@ int cudascale(int x, int scale)
}
__device__
float cudascale(float x, float scale)
{
return x * scale;
}
__device__
int cudadescale(int x, int descale)
{
return x / descale;
}
__device__
float cudadescale(float x, float descale)
{
return x / descale;
}
__device__
int cudashift(int x, int shift)
{
return x + shift;
}
__device__
float cudashift(float x, float descale)
{
return x + descale;
}
__device__
int cudamod(int x, int mod)
{
return x % mod;
......@@ -92,9 +110,51 @@ void _Cuda##funcName(const XTensor * a, XTensor * b, int num) \
BacktoCudaDev(a->devID, devIDBackup); \
} \
#define SIMPLE_BINARY_FUNCTION_FLOAT_GPU(funcName, origFunc) \
__global__ \
void Kernel##funcName(float * a, float * b, int size, float num) \
{ \
int i = blockDim.x * blockIdx.x + threadIdx.x; \
\
if (i < size) \
b[i] = (float)origFunc(a[i], num); \
} \
\
\
void _Cuda##funcName(const XTensor * a, XTensor * b, float num) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->isSparse == false), "TODO!"); \
\
int gridSize[3]; \
int blockSize[3]; \
\
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize); \
\
dim3 blocks(gridSize[0]); \
dim3 threads(blockSize[0]); \
\
int devIDBackup; \
ProtectCudaDev(a->devID, devIDBackup); \
\
if (a->dataType == X_FLOAT) { \
Kernel##funcName<<<blocks, threads>>> \
((float*)a->data, (float*)b->data, a->unitNum, num);\
} \
else { \
ShowNTErrors("TODO!"); \
} \
\
BacktoCudaDev(a->devID, devIDBackup); \
}
SIMPLE_BINARY_FUNCTION_GPU(Scale, cudascale)
SIMPLE_BINARY_FUNCTION_FLOAT_GPU(ScaleFloat, cudascale)
SIMPLE_BINARY_FUNCTION_GPU(Descale, cudadescale)
SIMPLE_BINARY_FUNCTION_FLOAT_GPU(DescaleFloat, cudadescale)
SIMPLE_BINARY_FUNCTION_GPU(Shift, cudashift)
SIMPLE_BINARY_FUNCTION_FLOAT_GPU(ShiftFloat, cudashift)
SIMPLE_BINARY_FUNCTION_GPU(Mod, cudamod)
#endif // USE_CUDA
......
......@@ -32,20 +32,29 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* scale each entry (CUDA Kernel) */
__global__
void KernelScale(int * a, int * b, int size, int scale);
__global__
void KernelScale(int * a, int * b, int size, float scale);
/* scale each entry */
void _CudaScale(const XTensor * a, XTensor * b, int scale);
void _CudaScaleFloat(const XTensor * a, XTensor * b, float scale);
/* descale each entry (CUDA Kernel) */
__global__
void KernelDescale(int * a, int * b, int size, int scale);
__global__
void KernelDescale(int * a, int * b, int size, float scale);
/* descale each entry */
void _CudaDescale(const XTensor * a, XTensor * b, int scale);
void _CudaDescaleFloat(const XTensor * a, XTensor * b, float scale);
/* shift each entry (CUDA Kernel) */
__global__
void KernelShift(int * a, int * b, int size, int shift);
__global__
void KernelShift(int * a, int * b, int size, float shift);
/* shift each entry */
void _CudaShift(const XTensor * a, XTensor * b, int shift);
void _CudaShiftFloat(const XTensor * a, XTensor * b, float shift);
/* mod each entry (CUDA Kernel) */
__global__
......
......@@ -37,51 +37,76 @@ void _Scale(const XTensor * a, XTensor * b, float scale);
scale up tensor entires (on site)
b = a * scale
*/
void Scale(XTensor & a, int scale);
void Scale(XTensor & a, float scale);
void _ScaleMe(XTensor & a, int scale);
void _ScaleMe(XTensor & a, float scale);
/*
scale up tensor entires
b = a * scale
*/
void Scale(const XTensor & a, XTensor &b, int scale);
void Scale(const XTensor & a, XTensor &b, float scale);
void Scale(const XTensor & a, XTensor &b, float scale, bool requireLink = false);
/*
scale up tensor entires (return an XTensor structure)
b = a * scale
*/
XTensor Scale(const XTensor & a, float scale);
/*
descale tensor entires
b = a / scale
*/
void _Descale(const XTensor * a, XTensor * b, int scale);
void _Descale(const XTensor * a, XTensor * b, float scale);
/*
descale tensor entires (on site)
b = a / scale
*/
void Descale(XTensor & a, int scale);
void _DescaleMe(XTensor & a, int scale);
void _DescaleMe(XTensor & a, float scale);
/*
descale tensor entires
b = a / scale
*/
void Descale(const XTensor & a, XTensor & b, int scale);
void Descale(const XTensor & a, XTensor & b, float scale, bool requireLink = false);
/*
descale tensor entires (return an XTensor structure)
b = a / scale
*/
XTensor Descale(const XTensor & a, float scale);
/*
shift tensor entires
b = a + shift
*/
void _Shift(const XTensor * a, XTensor * b, int shift);
void _Shift(const XTensor * a, XTensor * b, float shift);
/*
shift tensor entires (on site)
b = a + shift
*/
void Shift(XTensor & a, int shift);
void _ShiftMe(XTensor & a, int shift);
void _ShiftMe(XTensor & a, float shift);
/*
shift tensor entires
b = a + shift
*/
void Shift(const XTensor & a, XTensor & b, int shift);
void Shift(const XTensor & a, XTensor & b, float shift, bool requireLink = false);
/*
shift tensor entires (return an XTensor structure)
b = a + shift
*/
XTensor Shift(const XTensor & a, float shift);
/*
mod tensor entires
......@@ -93,7 +118,7 @@ void _Mod(const XTensor * a, XTensor * b, int base);
mod tensor entires (on site)
b = a % mod
*/
void Mod(XTensor & a, int base);
void _ModMe(XTensor & a, int base);
/*
mod tensor entires
......
......@@ -94,6 +94,23 @@ XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper)
return b;
}
void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper, bool requireLink)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _Clip function */
_Clip(&a, &b, lower, upper);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
XLink::AddParamToHead(&b, lower);
XLink::AddParamToHead(&b, upper);
}
}
/*
backward computation
......
......@@ -37,6 +37,8 @@ void _ClipMe(XTensor * a, DTYPE lower, DTYPE upper);
make a new tensor to keep the result and return it */
XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper);
void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper, bool requireLink = false);
/*
backward of Clip function
*/
......
......@@ -102,4 +102,27 @@ XTensor Power(const XTensor & a, DTYPE p)
return b;
}
/*
get the power(a, p)
>> a - input tensor
>> b - output tensor
>> p - parameter
>> requireLink - if add operation to network
*/
void Power(const XTensor & a, XTensor & b, DTYPE p, bool requireLink)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _Power function */
_Power(&a, &b, p);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_POWER);
XLink::AddParamToHead(&b, p);
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -41,6 +41,9 @@ make a new tensor to keep the result and return it
*/
XTensor Power(const XTensor & a, DTYPE p);
/* get the power(x, y) */
void Power(const XTensor & a, XTensor & b, DTYPE p, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __POWER_H__
......@@ -118,4 +118,33 @@ XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
return b;
}
/*
scale and shift all tensor entires
b = a * scale + shift
>> a - the input tensor
>> b - the output tensor
>> scale - the scaler factor
>> shift - the shift factor
>> requireLink - if add operation to network
*/
void ScaleAndShift(const XTensor & a, XTensor & b, DTYPE scale, DTYPE shift, bool requireLink)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _ScaleAndShift function */
_ScaleAndShift(&a, &b, scale, shift);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
XLink::AddParamToHead(&b, scale);
XLink::AddParamToHead(&b, shift);
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -50,6 +50,12 @@ b = a * scale + shift
*/
XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift = 0);
/*
scale and shift all tensor entires
b = a * scale + shift
*/
void ScaleAndShift(const XTensor &a, XTensor &b, DTYPE scale, DTYPE shift = 0, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __SCALEANDSHIFT_H__
\ No newline at end of file
......@@ -82,58 +82,82 @@ XTensor funcName(const XTensor &a) \
return b; \
}
#define SIMPLE_UNARY_FUNCTION_VOID(funcName, _funcName, operationId) \
void funcName(const XTensor &a, XTensor &b, bool requireLink) \
{ \
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) { \
InitTensor(&b, &a); \
} \
_funcName(&a, &b); \
if (b.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
}
_SIMPLE_UNARY_FUNCTION(_Absolute, _CudaAbsolute, fabs)
_SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute)
SIMPLE_UNARY_FUNCTION(Absolute, _Absolute, MATH_ABSOLUTE)
SIMPLE_UNARY_FUNCTION_VOID(Absolute, _Absolute, MATH_ABSOLUTE)
_SIMPLE_UNARY_FUNCTION(_Ceil, _CudaCeil, ceil)
_SIMPLE_UNARY_FUNCTION_ME(_CeilMe, _Ceil)
SIMPLE_UNARY_FUNCTION(Ceil, _Ceil, MATH_CEIL)
SIMPLE_UNARY_FUNCTION_VOID(Ceil, _Ceil, MATH_CEIL)
_SIMPLE_UNARY_FUNCTION(_Exp, _CudaExp, exp)
_SIMPLE_UNARY_FUNCTION_ME(_ExpMe, _Exp)
SIMPLE_UNARY_FUNCTION(Exp, _Exp, MATH_EXP)
SIMPLE_UNARY_FUNCTION_VOID(Exp, _Exp, MATH_EXP)
_SIMPLE_UNARY_FUNCTION(_Floor, _CudaFloor, floor)
_SIMPLE_UNARY_FUNCTION_ME(_FloorMe, _Floor)
SIMPLE_UNARY_FUNCTION(Floor, _Floor, MATH_FLOOR)
SIMPLE_UNARY_FUNCTION_VOID(Floor, _Floor, MATH_FLOOR)
_SIMPLE_UNARY_FUNCTION(_IsNonZero, _CudaIsNonZero, isnonzero)
_SIMPLE_UNARY_FUNCTION_ME(_IsNonZeroMe, _IsNonZero)
SIMPLE_UNARY_FUNCTION(IsNonZero, _IsNonZero, MATH_ISNONZERO)
SIMPLE_UNARY_FUNCTION_VOID(IsNonZero, _IsNonZero, MATH_ISNONZERO)
_SIMPLE_UNARY_FUNCTION(_IsZero, _CudaIsZero, iszero)
_SIMPLE_UNARY_FUNCTION_ME(_IsZeroMe, _IsZero)
SIMPLE_UNARY_FUNCTION(IsZero, _IsZero, MATH_ISZERO)
SIMPLE_UNARY_FUNCTION_VOID(IsZero, _IsZero, MATH_ISZERO)
_SIMPLE_UNARY_FUNCTION(_Log, _CudaLog, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)
SIMPLE_UNARY_FUNCTION_VOID(Log, _Log, MATH_LOG)
_SIMPLE_UNARY_FUNCTION(_Round, _CudaRound, round)
_SIMPLE_UNARY_FUNCTION_ME(_RoundMe, _Round)
SIMPLE_UNARY_FUNCTION(Round, _Round, MATH_ROUND)
SIMPLE_UNARY_FUNCTION_VOID(Round, _Round, MATH_ROUND)
_SIMPLE_UNARY_FUNCTION(_Sqrt, _CudaSqrt, sqrt)
_SIMPLE_UNARY_FUNCTION_ME(_SqrtMe, _Sqrt)
SIMPLE_UNARY_FUNCTION(Sqrt, _Sqrt, MATH_SQRT)
SIMPLE_UNARY_FUNCTION_VOID(Sqrt, _Sqrt, MATH_SQRT)
_SIMPLE_UNARY_FUNCTION(_Square, _CudaSquare, square)
_SIMPLE_UNARY_FUNCTION_ME(_SquareMe, _Square)
SIMPLE_UNARY_FUNCTION(Square, _Square, MATH_SQUARE)
SIMPLE_UNARY_FUNCTION_VOID(Square, _Square, MATH_SQUARE)
_SIMPLE_UNARY_FUNCTION(_Sin, _CudaSin, sin)
_SIMPLE_UNARY_FUNCTION_ME(_SinMe, _Sin)
SIMPLE_UNARY_FUNCTION(Sin, _Sin, MATH_SIN)
SIMPLE_UNARY_FUNCTION_VOID(Sin, _Sin, MATH_SIN)
_SIMPLE_UNARY_FUNCTION(_Cos, _CudaCos, cos)
_SIMPLE_UNARY_FUNCTION_ME(_CosMe, _Cos)
SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)
SIMPLE_UNARY_FUNCTION_VOID(Cos, _Cos, MATH_COS)
_SIMPLE_UNARY_FUNCTION(_Tan, _CudaTan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)
SIMPLE_UNARY_FUNCTION_VOID(Tan, _Tan, MATH_TAN)
#else
/* define three marco separately, specify the respective function names (CPU mode) */
......@@ -164,59 +188,82 @@ XTensor funcName(const XTensor &a) \
XLink::MakeLink(&a, NULL, &b, operationId); \
return b; \
}
#define SIMPLE_UNARY_FUNCTION_VOID(funcName, _funcName, operationId) \
void funcName(const XTensor &a, XTensor &b, bool requireLink) \
{ \
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) { \
InitTensor(&b, &a); \
} \
_funcName(&a, &b); \
if (b.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
}
_SIMPLE_UNARY_FUNCTION(_Absolute, fabs)
_SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute)
SIMPLE_UNARY_FUNCTION(Absolute, _Absolute, MATH_ABSOLUTE)
SIMPLE_UNARY_FUNCTION_VOID(Absolute, _Absolute, MATH_ABSOLUTE)
_SIMPLE_UNARY_FUNCTION(_Ceil, ceil)
_SIMPLE_UNARY_FUNCTION_ME(_CeilMe, _Ceil)
SIMPLE_UNARY_FUNCTION(Ceil, _Ceil, MATH_CEIL)
SIMPLE_UNARY_FUNCTION_VOID(Ceil, _Ceil, MATH_CEIL)
_SIMPLE_UNARY_FUNCTION(_Exp, exp)
_SIMPLE_UNARY_FUNCTION_ME(_ExpMe, _Exp)
SIMPLE_UNARY_FUNCTION(Exp, _Exp, MATH_EXP)
SIMPLE_UNARY_FUNCTION_VOID(Exp, _Exp, MATH_EXP)
_SIMPLE_UNARY_FUNCTION(_Floor, floor)
_SIMPLE_UNARY_FUNCTION_ME(_FloorMe, _Floor)
SIMPLE_UNARY_FUNCTION(Floor, _Floor, MATH_FLOOR)
SIMPLE_UNARY_FUNCTION_VOID(Floor, _Floor, MATH_FLOOR)
_SIMPLE_UNARY_FUNCTION(_IsNonZero, isnonzero)
_SIMPLE_UNARY_FUNCTION_ME(_IsNonZeroMe, _IsNonZero)
SIMPLE_UNARY_FUNCTION(IsNonZero, _IsNonZero, MATH_ISNONZERO)
SIMPLE_UNARY_FUNCTION_VOID(IsNonZero, _IsNonZero, MATH_ISNONZERO)
_SIMPLE_UNARY_FUNCTION(_IsZero, iszero)
_SIMPLE_UNARY_FUNCTION_ME(_IsZeroMe, _IsZero)
SIMPLE_UNARY_FUNCTION(IsZero, _IsZero, MATH_ISZERO)
SIMPLE_UNARY_FUNCTION_VOID(IsZero, _IsZero, MATH_ISZERO)
_SIMPLE_UNARY_FUNCTION(_Log, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)
SIMPLE_UNARY_FUNCTION_VOID(Log, _Log, MATH_LOG)
_SIMPLE_UNARY_FUNCTION(_Round, round)
_SIMPLE_UNARY_FUNCTION_ME(_RoundMe, _Round)
SIMPLE_UNARY_FUNCTION(Round, _Round, MATH_ROUND)
SIMPLE_UNARY_FUNCTION_VOID(Round, _Round, MATH_ROUND)
_SIMPLE_UNARY_FUNCTION(_Sqrt, sqrt)
_SIMPLE_UNARY_FUNCTION_ME(_SqrtMe, _Sqrt)
SIMPLE_UNARY_FUNCTION(Sqrt, _Sqrt, MATH_SQRT)
SIMPLE_UNARY_FUNCTION_VOID(Sqrt, _Sqrt, MATH_SQRT)
_SIMPLE_UNARY_FUNCTION(_Square, square)
_SIMPLE_UNARY_FUNCTION_ME(_SquareMe, _Square)
SIMPLE_UNARY_FUNCTION(Square, _Square, MATH_SQUARE)
SIMPLE_UNARY_FUNCTION_VOID(Square, _Square, MATH_SQUARE)
_SIMPLE_UNARY_FUNCTION(_Sin, sin)
_SIMPLE_UNARY_FUNCTION_ME(_SinMe, _Sin)
SIMPLE_UNARY_FUNCTION(Sin, _Sin, MATH_SIN)
SIMPLE_UNARY_FUNCTION_VOID(Sin, _Sin, MATH_SIN)
_SIMPLE_UNARY_FUNCTION(_Cos, cos)
_SIMPLE_UNARY_FUNCTION_ME(_CosMe, _Cos)
SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)
SIMPLE_UNARY_FUNCTION_VOID(Cos, _Cos, MATH_COS)
_SIMPLE_UNARY_FUNCTION(_Tan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)
SIMPLE_UNARY_FUNCTION_VOID(Tan, _Tan, MATH_TAN)
/*_SIMPLE_UNARY_FUNCTION(_Round, round)
_SIMPLE_UNARY_FUNCTION_ME(_RoundMe, _Round)
......
......@@ -34,6 +34,8 @@ void _AbsoluteMe(XTensor * a);
/* set every entry to its absolute value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Absolute(const XTensor & a);
/* set every entry to its absolute value */
void Absolute(const XTensor & a, XTensor & b, bool requireLink = false);
/* set every entry to its ceil value */
void _Ceil(const XTensor * a, XTensor * b);
......@@ -43,6 +45,8 @@ void _CeilMe(XTensor * a);
/* set every entry to its ceil value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Ceil(const XTensor & a);
/* set every entry to its ceil value */
void Ceil(const XTensor & a, XTensor & b, bool requireLink = false);
/* set every entry to its exponent value */
void _Exp(const XTensor * a, XTensor * b);
......@@ -52,6 +56,8 @@ void _ExpMe(XTensor * a);
/* set every entry to its exponent value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Exp(const XTensor & a);
/* set every entry to its exponent value */
void Exp(const XTensor & a, XTensor & b, bool requireLink = false);
/* set every entry to its floor value */
void _Floor(const XTensor * a, XTensor * b);
......@@ -61,6 +67,8 @@ void _FloorMe(XTensor * a);
/* set every entry to its floor value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Floor(const XTensor & a);
/* set every entry to its floor value */
void Floor(const XTensor & a, XTensor & b, bool requireLink = false);
/* if source entry is non-zero, set target entry to be one, otherwise zero */
void _IsNonZero(const XTensor *a, XTensor *b);
......@@ -70,6 +78,8 @@ void _IsNonZeroMe(XTensor *a);
/* if source entry is non-zero, set target entry to be one, otherwise zero (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor IsNonZero(const XTensor &a);
/* if source entry is non-zero, set target entry to be one, otherwise zero */
void IsNonZero(const XTensor &a, XTensor & b, bool requireLink = false);
/* if source entry is zero, set target entry to be one, otherwise zero */
void _IsZero(const XTensor *a, XTensor *b);
......@@ -79,6 +89,8 @@ void _IsZeroMe(XTensor *a);
/* if source entry is zero, set target entry to be one, otherwise zero (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor IsZero(const XTensor &a);
/* if source entry is zero, set target entry to be one, otherwise zero */
void IsZero(const XTensor &a, XTensor & b, bool requireLink = false);
/* set every entry to its logarithm value */
void _Log(const XTensor * a, XTensor * b);
......@@ -88,6 +100,8 @@ void _LogMe(XTensor * a);
/* set every entry to its logarithm value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Log(const XTensor & a);
/* set every entry to its logarithm value */
void Log(const XTensor & a, XTensor & b, bool requireLink = false);
/* set every entry to its round value */
void _Round(const XTensor * a, XTensor * b);
......@@ -97,6 +111,8 @@ void _RoundMe(XTensor * a);
/* set every entry to its round value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Round(const XTensor & a);
/* set every entry to its round value */
void Round(const XTensor & a, XTensor & b, bool requireLink = false);
/* set every entry to its sqrt value */
void _Sqrt(const XTensor * a, XTensor * b);
......@@ -106,6 +122,8 @@ void _SqrtMe(XTensor * a);
/* set every entry to its sqrt value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Sqrt(const XTensor & a);
/* set every entry to its sqrt value */
void Sqrt(const XTensor & a, XTensor & b, bool requireLink = false);
/* set every entry to its square value */
void _Square(const XTensor * a, XTensor * b);
......@@ -115,6 +133,8 @@ void _SquareMe(XTensor * a);
/* set every entry to its square value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Square(const XTensor & a);
/* set every entry to its square value */
void Square(const XTensor & a, XTensor & b, bool requireLink = false);
/* set every entry to its sine value */
......@@ -125,6 +145,8 @@ void _SinMe(XTensor * a);
/* set every entry to its sine value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Sin(const XTensor & a);
/* set every entry to its sine value */
void Sin(const XTensor & a, XTensor & b, bool requireLink = false);
/* set every entry to its cosine value */
void _Cos(const XTensor * a, XTensor * b);
......@@ -134,6 +156,8 @@ void _CosMe(XTensor * a);
/* set every entry to its cosine value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Cos(const XTensor & a);
/* set every entry to its cosine value */
void Cos(const XTensor & a, XTensor & b, bool requireLink = false);
/* set every entry to its tangent value */
void _Tan(const XTensor * a, XTensor * b);
......@@ -143,6 +167,8 @@ void _TanMe(XTensor * a);
/* set every entry to its tangent value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Tan(const XTensor & a);
/* set every entry to its tangent value */
void Tan(const XTensor & a, XTensor & b, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -70,7 +70,7 @@ void _CopyIndexed(const XTensor * s, XTensor * t, int dim,
for (int i = dimRDI; i < t->order; i++)
blockNumTgt *= t->dimSizeRDI[i];
CheckNTErrors((blockSizeSrc == blockSizeTgt), "Unmatched tensors!");
CheckNTErrors(blockSizeSrc == blockSizeTgt, "Unmatched tensors!");
indexOffsetNum = blockNumSrc / s->dimSizeRDI[dimRDI];
int realIndexSize = indexOffsetNum * indexSize * copyNum;
......@@ -87,13 +87,15 @@ void _CopyIndexed(const XTensor * s, XTensor * t, int dim,
for (int k = 0; k < copyNum; k++) {
rsi[k] = baseSrc + srcIndex[j] + k;
rti[k] = baseTgt + tgtIndex[j] + k;
CheckNTErrors(rsi[k] < s->unitNum, "Wrong index!");
CheckNTErrors(rti[k] < t->unitNum, "Wrong index!");
}
}
}
for (int i = 0; i < indexSize; i++) {
CheckNTErrors((srcIndex[i] < blockNumSrc), "Index is out of scope!");
CheckNTErrors((tgtIndex[i] < blockNumTgt), "Index is out of scope!");
CheckNTErrors(srcIndex[i] < blockNumSrc, "Index is out of scope!");
CheckNTErrors(tgtIndex[i] < blockNumTgt, "Index is out of scope!");
}
_CopyBlocks(s->data, blockSizeSrc * s->unitSize, realSrcIndex, realIndexSize, t->data, realTgtIndex, s->mem, s->devID);
......
......@@ -131,4 +131,43 @@ XTensor ReduceMax(const XTensor &input, int dim)
return output;
}
/*
get the max value of the items along a dimension of the tensor
>> input - the input tensor
>> output - the output tensor
>> dim - the dimension where the reduction is performed on
>> requireLink - if add operation to network
*/
void ReduceMax(const XTensor &input, XTensor &output, int dim, bool requireLink)
{
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
if (!output.isInit || !XTensor::IsReduceShaped(&input, &output, dim)) {
int order = input.order - 1;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
if (i < dim)
dimSize[i] = input.dimSize[i];
else if (i >= dim)
dimSize[i] = input.dimSize[i + 1];
}
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
InitTensor(&output, order, dimSize, input.dataType, dr, input.devID, input.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _ReduceMax function */
_ReduceMax(&input, &output, dim);
if (output.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
XLink::AddParamToHeadInt(&output, dim);
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -35,6 +35,9 @@ make a new tensor to keep the result and return it
*/
XTensor ReduceMax(const XTensor &input, int dim);
/* get the max value of the items along a dimension of the tensor. */
void ReduceMax(const XTensor &input, XTensor &output, int dim, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __REDUCEMAX_H__
......@@ -86,4 +86,45 @@ XTensor ReduceMean(const XTensor &input, int dim)
return output;
}
/*
get the mean value along a dimension of the tensor
For a 1-dimensional data array a, mean = (1/n) * sum_i input_i
>> input - the input tensor
>> output - the output tensor
>> dim - the dimension where the reduction is performed on
>> requireLink - if add operation to network
*/
void ReduceMean(const XTensor &input, XTensor &output, int dim, bool requireLink)
{
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
if (!output.isInit || !XTensor::IsReduceShaped(&input, &output, dim)) {
int order = input.order - 1;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
if (i < dim)
dimSize[i] = input.dimSize[i];
else if (i >= dim)
dimSize[i] = input.dimSize[i + 1];
}
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
InitTensor(&output, order, dimSize, input.dataType, dr, input.devID, input.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _ReduceMean function */
_ReduceMean(&input, &output, dim);
if (output.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN);
XLink::AddParamToHeadInt(&output, dim);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -39,6 +39,12 @@ For a 1-dimensional data array a, mean = (1/n) * sum_i input_i
*/
XTensor ReduceMean(const XTensor &input, int dim);
/*
get the mean value along a dimension of the tensor
For a 1-dimensional data array a, mean = (1/n) * sum_i input_i
*/
void ReduceMean(const XTensor &input, XTensor &output, int dim, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __REDUCEMEAN_H__
......@@ -244,6 +244,39 @@ XTensor ReduceSum(const XTensor &input, int dim, const XTensor &shift, DTYPE pow
return output;
}
void ReduceSum(const XTensor &input, XTensor &output, int dim, const XTensor &shift, DTYPE power, bool isExp, bool requireLink)
{
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
if (!output.isInit || !XTensor::IsReduceShaped(&input, &output, dim)) {
int order = input.order - 1;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
if (i < dim)
dimSize[i] = input.dimSize[i];
else if (i >= dim)
dimSize[i] = input.dimSize[i + 1];
}
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
InitTensor(&output, order, dimSize, input.dataType, dr, input.devID, input.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _ReduceSum function */
_ReduceSum(&input, &output, dim, &shift, power, isExp);
if (output.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUM);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, power);
XLink::AddParamToHeadBool(&output, isExp);
}
}
/*
sum the items along a dimension of the tensor (return an XTensor structure)
make a new tensor to keep the result and return it
......@@ -290,4 +323,52 @@ XTensor ReduceSum(const XTensor &input, int dim, DTYPE power, bool isExp)
return output;
}
/*
sum the items along a dimension of the tensor
For a 1-dimensional data array a,
sum = \sum_i (a_i - shift)^power if isExp == false
sum = \sum_i exp((a_i - shift)^power) if isExp == true
>> input - the input tensor
>> output - the output tensor
>> dim - the dimension where the reduction is performed on
>> shift - shift the input
>> ieExp - specify if the exp() is performed
>> power - we perform pow(item_i, power) on each item in the array
>> requireLink - if add operation to network
*/
void ReduceSum(const XTensor &input, XTensor &output, int dim, DTYPE power, bool isExp, bool requireLink)
{
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
if (!output.isInit || !XTensor::IsReduceShaped(&input, &output, dim)) {
int order = input.order - 1;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
if (i < dim)
dimSize[i] = input.dimSize[i];
else if (i >= dim)
dimSize[i] = input.dimSize[i + 1];
}
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
InitTensor(&output, order, dimSize, input.dataType, dr, input.devID, input.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _ReduceSum function */
_ReduceSum(&input, &output, dim, NULL, power, isExp);
if (output.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCESUM);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, power);
XLink::AddParamToHeadBool(&output, isExp);
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -44,6 +44,8 @@ sum = \sum_i exp(a_i - shift) if isExp == true
*/
XTensor ReduceSum(const XTensor &input, int dim, const XTensor &shift, DTYPE power = (DTYPE)1.0F, bool isExp = false);
void ReduceSum(const XTensor &input, XTensor &output, int dim, const XTensor &shift, DTYPE power = (DTYPE)1.0F, bool isExp = false, bool requireLink = false);
/*
sum the items along a dimension of the tensor (return an XTensor structure)
make a new tensor to keep the result and return it
......@@ -53,6 +55,14 @@ sum = \sum_i exp(a_i) if isExp == true
*/
XTensor ReduceSum(const XTensor &input, int dim, DTYPE power = (DTYPE)1.0F, bool isExp = false);
/*
sum the items along a dimension of the tensor
For a 1-dimensional data array a,
sum = \sum_i (a_i - shift) if isExp == false
sum = \sum_i exp(a_i - shift) if isExp == true
*/
void ReduceSum(const XTensor &input, XTensor &output, int dim, DTYPE power = (DTYPE)1.0F, bool isExp = false, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __REDUCESUM_H__
......@@ -82,4 +82,46 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift)
return output;
}
/*
squared sum of the items along a dimension of the tensor
For a 1-dimensional data array a, sum = \sum_i (a_i - shift)^2
>> input - the input tensor
>> output - the output tensor
>> dim - the dimension where the reduction is performed on
>> shift - bias on the input
>> requireLink - if add operation to network
*/
void ReduceSumSquared(const XTensor &input, XTensor &output, int dim, const XTensor &shift, bool requireLink)
{
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
if (!output.isInit || !XTensor::IsReduceShaped(&input, &output, dim)) {
int order = input.order - 1;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
if (i < dim)
dimSize[i] = input.dimSize[i];
else if (i >= dim)
dimSize[i] = input.dimSize[i + 1];
}
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
InitTensor(&output, order, dimSize, input.dataType, dr, input.devID, input.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _ReduceSumSquared function */
_ReduceSumSquared(&input, &output, dim, &shift);
if (output.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED);
XLink::AddParamToHeadInt(&output, dim);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -40,6 +40,13 @@ For a 1-dimensional data array a, sum = \sum_i (a_i - shift)^2
*/
XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift);
/*
squared sum of the items along a dimension of the tensor
For a 1-dimensional data array a,
sum = \sum_i (a_i - shift)^2
*/
void ReduceSumSquared(const XTensor &input, XTensor &output, int dim, const XTensor &shift, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __REDUCESUMSQUARED_H__
......
......@@ -84,4 +84,47 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean)
return output;
}
/*
variance of the items along a dimension of the tensor
For a 1-dimensional data array a, variance = 1/n * \sum_i (a_i - mean)^2
>> input - the input tensor
>> output - the output tensor
>> dim - the dimension where the reduction is performed on
>> mean - the mean value
>> requireLink - if add operation to network
*/
void ReduceVariance(const XTensor &input, XTensor &output, int dim, const XTensor &mean, bool requireLink)
{
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
if (!output.isInit || !XTensor::IsReduceShaped(&input, &output, dim)) {
int order = input.order - 1;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
if (i < dim)
dimSize[i] = input.dimSize[i];
else if (i >= dim)
dimSize[i] = input.dimSize[i + 1];
}
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
InitTensor(&output, order, dimSize, input.dataType, dr, input.devID, input.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _ReduceVariance function */
_ReduceVariance(&input, &output, dim, &mean);
if (output.enableGrad) {
/* tensor connection */
XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE);
XLink::AddParamToHeadInt(&output, dim);
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -39,6 +39,12 @@ For a 1-dimensional data array a, variance = 1/n * \sum_i (a_i - mean)^2
*/
XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean);
/*
variance of the items along a dimension of the tensor
For a 1-dimensional data array a, variance = 1/n * \sum_i (a_i - mean)^2
*/
void ReduceVariance(const XTensor &input, XTensor &output, int dim, const XTensor &mean, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __REDUCEVARIANCE_H__
......@@ -148,6 +148,39 @@ void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim)
}
}
bool CheckMergeSize(const XTensor * s, const XTensor * t, int whereToMerge, int leadingDim)
{
if (!(s && t))
return false;
if (!(s->dataType == t->dataType))
return false;
if (leadingDim < 0)
leadingDim = 0;
int order = s->order - 1;
int * dimSize = new int[order];
for (int i = 0; i < s->order; i++) {
if (i < leadingDim)
dimSize[i] = s->dimSize[i];
else if (i > leadingDim) {
if (i != whereToMerge)
dimSize[i - 1] = s->dimSize[i];
else
dimSize[i - 1] = s->dimSize[i] * s->dimSize[leadingDim];
}
}
for (int i = 0; i < order; i++) {
if (dimSize[i] != t->dimSize[i])
return false;
}
return true;
}
/*
transform a tensor by merging it along with a dimension (return an XTensor structure)
make a new tensor to keep the result and return it
......@@ -199,6 +232,43 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
return t;
}
void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim, bool requireLink)
{
if (!t.isInit || !CheckMergeSize(&s, &t, whereToMerge, leadingDim)) {
if (leadingDim < 0)
leadingDim = 0;
int order = s.order - 1;
int * dimSize = new int[order];
for (int i = 0; i < s.order; i++) {
if (i < leadingDim)
dimSize[i] = s.dimSize[i];
else if (i > leadingDim) {
if (i != whereToMerge)
dimSize[i - 1] = s.dimSize[i];
else
dimSize[i - 1] = s.dimSize[i] * s.dimSize[leadingDim];
}
}
float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
InitTensor(&t, order, dimSize, s.dataType, dr, s.devID, s.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _Merge function */
_Merge(&s, &t, whereToMerge, leadingDim);
if (t.enableGrad) {
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
XLink::AddParamToHeadInt(&t, whereToMerge);
XLink::AddParamToHeadInt(&t, leadingDim);
}
}
/*
merge small tensors into a big tensor
......
......@@ -33,15 +33,21 @@ void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim = -
e.g., (M, N/3, 3) -> (M, N) */
XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim = -1);
void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim = -1, bool requireLink = false);
/* merge small tensors into a big tensor */
void _Merge(const TensorList * smalls, XTensor * big, int whereToMerge);
/* merge small tensors into a big tensor (return an XTensor structure) */
XTensor Merge(const TensorList &smalls, int whereToMerge);
void Merge(const TensorList &smalls, XTensor &t, int whereToMerge);
/* merge two tensors into a big tensor (return an XTensor structure) */
XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge);
void Merge(const XTensor &smallA, const XTensor &smallB, XTensor &t, int whereToMerge);
} // namespace nts(NiuTrans.Tensor)
#endif // __MERGE_H__
\ No newline at end of file
......@@ -30,7 +30,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy a number of blocks (of different sizes) to target positions */
__global__
void KernelCopyBlockLists(DTYPE * sourceList[], int * sourceBlockSizes, int sourceBlockNum, DTYPE * targetList[]);
void KernelCopyBlockLists(DTYPE ** sourceList, int * sourceBlockSizes, int sourceBlockNum, DTYPE ** targetList);
/* merge data by blocks (cuda version) */
void _CudaMergeBlockLists(const StrList* sourceList, int * blockSizes, int blockNum, void * target, XMem * myMem);
......
......@@ -48,4 +48,19 @@ XTensor Reshape(XTensor &s, int order, int * dimSize)
return t;
}
void Reshape(XTensor &s, XTensor &t, int order, int * dimSize, bool requireLink)
{
if (!t.isInit || !XTensor::IsSameShaped(&t, &s)) {
InitTensor(&t, &s);
}
/* call Reshape function */
t.Reshape(order, dimSize);
if (t.enableGrad) {
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -29,5 +29,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* reshape the tensor */
XTensor Reshape(XTensor &s, int order, int * dimSize);
void Reshape(XTensor &s, XTensor &t, int order, int * dimSize, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __RESHAPE_H__
......@@ -156,6 +156,33 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
}
}
bool CheckSplitSize(const XTensor * s, const XTensor * t, int whereToSplit, int splitNum)
{
if (!(s && t))
return false;
if (!(s->dataType == t->dataType))
return false;
int order = s->order + 1;
int * dimSize = new int[order];
dimSize[0] = splitNum;
for (int i = 0; i < s->order; i++) {
if (i == whereToSplit)
dimSize[i + 1] = s->dimSize[i] / splitNum;
else
dimSize[i + 1] = s->dimSize[i];
}
for (int i = 0; i < order; i++) {
if (dimSize[i] != t->dimSize[i])
return false;
}
return true;
}
/*
transform a tensor by splitting it, e.g., (N, M) -> (N/3, M, 3) (return an XTensor structure)
make a new tensor to keep the result and return it
......@@ -200,6 +227,38 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
return t;
}
void Split(const XTensor &s, XTensor &t, int whereToSplit, int splitNum, bool requireLink)
{
if (!t.isInit || !CheckSplitSize(&s, &t, whereToSplit, splitNum)) {
int order = s.order + 1;
int * dimSize = new int[order];
dimSize[0] = splitNum;
for (int i = 0; i < s.order; i++) {
if (i == whereToSplit)
dimSize[i + 1] = s.dimSize[i] / splitNum;
else
dimSize[i + 1] = s.dimSize[i];
}
float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
InitTensor(&t, order, dimSize, s.dataType, dr, s.devID, s.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _Split function */
_Split(&s, &t, whereToSplit, splitNum);
if (t.enableGrad) {
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
XLink::AddParamToHeadInt(&t, whereToSplit);
XLink::AddParamToHeadInt(&t, splitNum);
}
}
/*
split a big tensor into small tensors
......
......@@ -41,6 +41,8 @@ e.g., (M, N) -> (M, N/3, 3)
*/
XTensor Split(const XTensor &s, int whereToSplit, int splitNum);
void Split(const XTensor &s, XTensor &t, int whereToSplit, int splitNum, bool requireLink = false);
/* split a big tensor into small tensors */
void _Split(const XTensor * big, TensorList * smalls, int whereToSplit, int splitNum);
......
......@@ -112,4 +112,19 @@ XTensor Squeeze(XTensor & source, int leadingDim)
return target;
}
void Squeeze(XTensor & source, XTensor & target, int leadingDim, bool requireLink)
{
if (!target.isInit || !XTensor::IsSameShaped(&source, &target)) {
InitTensor(&target, &source);
}
/* call _Squeeze function */
_Squeeze(&source, &target, leadingDim);
if (target.enableGrad) {
/* tensor connections */
XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -37,6 +37,8 @@ void _SqueezeMe(XTensor * source, int leadingDim = -1);
make a new tensor to keep the result and return it */
XTensor Squeeze(XTensor & source, int leadingDim = -1);
void Squeeze(XTensor & source, XTensor & target, int leadingDim = -1, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __SQUEEZE_H__
\ No newline at end of file
......@@ -96,6 +96,34 @@ void _Unsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
}
}
bool CheckUnsqueezeSize(const XTensor * a, const XTensor * b, int dim, int dSize)
{
if (!(a && b))
return false;
if (!(a->dataType == b->dataType))
return false;
int order = a->order + 1;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
if (i < dim)
dimSize[i] = a->dimSize[i];
else if (i == dim)
dimSize[i] = dSize;
else
dimSize[i] = a->dimSize[i - 1];
}
for (int i = 0; i < order; i++) {
if (dimSize[i] != b->dimSize[i])
return false;
}
return true;
}
/*
insert a dimension by copying the blocks for x times
(where x is the size of the inerted dimension) (returna a XTensor structure)
......@@ -138,4 +166,37 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
return b;
}
void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize, bool requireLink)
{
if (!b.isInit || !CheckUnsqueezeSize(&a, &b, dim, dSize)) {
int order = a.order + 1;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
if (i < dim)
dimSize[i] = a.dimSize[i];
else if (i == dim)
dimSize[i] = dSize;
else
dimSize[i] = a.dimSize[i - 1];
}
float dr = (!a.isSparse) ? 1.0F : a.denseRatio;
InitTensor(&b, order, dimSize, a.dataType, dr, a.devID, a.mem);
/* destroy variables */
delete[] dimSize;
}
/* call _Unsqueeze function */
_Unsqueeze(&a, &b, dim, dSize);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
XLink::AddParamToHeadInt(&b, dim);
XLink::AddParamToHeadInt(&b, dSize);
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -35,6 +35,8 @@ void _Unsqueeze(const XTensor * a, XTensor * b, int dim, int dSize);
make a new tensor to keep the result and return it */
XTensor Unsqueeze(const XTensor &a, int dim, int dSize);
void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize, bool requireLink = false);
} // namespace nts(NiuTrans.Tensor)
#endif // __UNSQUEEZE_H__
......@@ -114,12 +114,12 @@ void Sort(XTensor & a, XTensor & b, XTensor & index, int dim)
_Sort(&a, &b, &index, dim);
/* tensor connections */
TensorList list(2);
list.Add(&b);
list.Add(&index);
XLink::MakeLink(&a, &list, SORT_SORT);
XLink::AddParamToHeadInt(&b, dim);
XLink::AddParamToHeadInt(&index, dim);
//TensorList list(2);
//list.Add(&b);
//list.Add(&index);
// XLink::MakeLink(&a, &list, SORT_SORT);
// XLink::AddParamToHeadInt(&b, dim);
// XLink::AddParamToHeadInt(&index, dim);
}
} // namespace nts(NiuTrans.Tensor)
......@@ -128,14 +128,14 @@ void TopK(XTensor &a, XTensor &b, XTensor &index, int dim, int k)
_TopK(&a, &b, &index, dim, k);
/* tensor connection */
TensorList list(2);
list.Add(&b);
list.Add(&index);
XLink::MakeLink(&a, &list, SORT_TOPK);
XLink::AddParamToHeadInt(&b, dim);
XLink::AddParamToHeadInt(&index, k);
XLink::AddParamToHeadInt(&b, dim);
XLink::AddParamToHeadInt(&index, k);
//TensorList list(2);
//list.Add(&b);
//list.Add(&index);
//XLink::MakeLink(&a, &list, SORT_TOPK);
//XLink::AddParamToHeadInt(&b, dim);
//XLink::AddParamToHeadInt(&index, k);
//XLink::AddParamToHeadInt(&b, dim);
//XLink::AddParamToHeadInt(&index, k);
}
......
/* NiuTrans.Tensor - an open-source tensor library
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include <stdarg.h>
#include <math.h>
#include "XMatrixSegment.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
segment a 2d tensor (i.e., matrix) into blocks and run jobs in parallel
>> parallelRunner - parallel runner
>> job - the function to run
>> opNum - number of operations
>> rowNum - number of rows
>> colNum - number of columns
>> argNum - number of arguments of the jobs
>> ... - arguments of the jobs
*/
void RunParallel2D(XPRunner * parallelRunner, void * job,
int opNum, int rowNum, int colNum, int argNum, ...)
{
if (rowNum == 0 || colNum == 0)
return;
int jobNum = 1;
if (parallelRunner != NULL && (parallelRunner->method == PRUNNER_SINGLE || parallelRunner->method == PRUNNER_MULTIPLE)) {
if (opNum >= parallelRunner->minimumOPNum * parallelRunner->threadNum)
jobNum = parallelRunner->GetJobNum(rowNum * colNum);
}
CheckNTErrors(jobNum != 0, "TODO!");
/* argument list of the jobs */
XList * jobArgList = new XList(argNum);
va_list ap;
va_start(ap, argNum);
for (int i = 0; i < argNum; i++) {
XTensor* p = va_arg(ap, XTensor*);
jobArgList->Add(p);
}
va_end(ap);
/* prepare the neccesary argument list for parallel processing */
XList * jobs = new XList(jobNum);
XList * args = new XList(jobNum);
int * indexList = new int[jobNum * 4 * 4];
/* segment the matrix into blocks */
int nblock = SegmentTensor2D(rowNum, colNum, jobNum, indexList);
/*
assign jobs
argument rules:
1. block information
2. other arguments
*/
for (int i = 0; i < jobNum; i++) {
XList * blockArgs = new XList(argNum + 4);
int * blockIndex = indexList + i * 4;
blockArgs->Add((XTensor*)blockIndex);
blockArgs->Add((XTensor*)blockIndex + 1);
blockArgs->Add((XTensor*)blockIndex + 2);
blockArgs->Add((XTensor*)blockIndex + 3);
for (int j = 0; j < argNum; j++)
blockArgs->Add(jobArgList->GetItem(j));
args->Add((XTensor*)blockArgs);
jobs->Add((XTensor*)job);
}
args->count = nblock;
jobs->count = nblock;
/* single job */
if (jobNum == 1)
((TFunction)job)((XList*)args->GetItem(0));
/* multiple jobs */
else
parallelRunner->Run(jobs, args);
/* free the memory */
delete[] indexList;
for (int i = 0; i < args->count; i++) {
XList * blockArgs = (XList*)args->GetItem(i);
delete blockArgs;
}
delete args;
delete jobs;
delete jobArgList;
}
/*
segment a block into sub-blocks
>> rowNum - number of rows
>> colNum - number of columns
>> blockNum - number of sub-blocks
>> blockIndex - upper-left and bottom-right corners of each sub-block
<< return - the number of resulting sub-blocks
*/
int SegmentTensor2D(int rowNum, int colNum, int blockNum, int * blockIndex)
{
int total = rowNum * colNum;
int rowSize = (int)ceil(sqrt((float)total / blockNum));
int colSize = rowSize;
/* a narrow matrix */
if (rowSize > colNum * 0.9) {
rowSize = colNum;
colSize = (int)ceil((float)rowNum / blockNum);
}
/* a narrow matrix */
if (colSize > rowNum * 0.9) {
colSize = rowNum;
rowSize = (int)ceil((float)colNum / blockNum);
}
if (blockNum == 1) {
colSize = rowNum;
rowSize = colNum;
}
CheckNTErrors((colSize <= rowNum && rowSize <= colNum),
"Too large block!");
int x1, y1, x2, y2;
int xMax = rowNum - 1;
int yMax = colNum - 1;
int nblock = 0, nitem = 0;
int * indexList = blockIndex;
int xSegNum = int((float)rowNum / colSize);
int ySegNum = int((float)colNum / rowSize);
int marginBlockNum = blockNum - xSegNum * ySegNum;
/*
To maximize the number of resulting sub-block, we have to
make use of the margin block
*/
if (blockNum > 1 && marginBlockNum > 0) {
int margin = 0;
int step = 0;
if (rowNum < colNum) {
margin = int(((float)marginBlockNum / blockNum) * colNum);
step = (int)ceil((float)rowNum / marginBlockNum);
x1 = 0;
y1 = yMax - margin + 1;
x2 = step - 1;
y2 = yMax;
while (x2 <= xMax) {
int * blockIndex = indexList + nblock * 4;
blockIndex[0] = x1; blockIndex[1] = y1;
blockIndex[2] = x2; blockIndex[3] = y2;
nblock++;
nitem += (y2 - y1 + 1) * (x2 - x1 + 1);
if (x2 == xMax)
break;
x1 = x2 + 1;
x2 = x1 + step - 1;
if (x2 > xMax)
x2 = xMax;
}
yMax -= margin;
}
else {
margin = int(((float)marginBlockNum / blockNum) * rowNum);
step = (int)ceil((float)colNum / marginBlockNum);
x1 = xMax - margin + 1;
y1 = 0;
x2 = xMax;
y2 = step - 1;
while (y2 <= yMax) {
int * blockIndex = indexList + nblock * 4;
blockIndex[0] = x1; blockIndex[1] = y1;
blockIndex[2] = x2; blockIndex[3] = y2;
nblock++;
nitem += (y2 - y1 + 1) * (x2 - x1 + 1);
if (y2 == yMax)
break;
y1 = y2 + 1;
y2 = y1 + step - 1;
if (y2 > yMax)
y2 = yMax;
}
xMax -= margin;
}
colSize = (int)ceil((float)(xMax + 1) / xSegNum);
rowSize = (int)ceil((float)(yMax + 1) / ySegNum);
}
x1 = 0;
y1 = 0; // upper-left corner
x2 = colSize - 1;
y2 = rowSize - 1; // bottom-right corner
/* the main body of the matrix (after removing the margin block) */
while (x1 <= xMax) {
y1 = 0;
x2 = x1 + colSize - 1;
y2 = y1 + rowSize - 1;
if (x2 > xMax) {
x2 = xMax;
}
while (y2 <= yMax) {
int * blockIndex = indexList + nblock * 4;
blockIndex[0] = x1; blockIndex[1] = y1;
blockIndex[2] = x2; blockIndex[3] = y2;
nblock++;
nitem += (y2 - y1 + 1) * (x2 - x1 + 1);
if (y2 == yMax)
break;
y1 = y2 + 1;
y2 = y1 + rowSize - 1;
if (y2 > yMax)
y2 = yMax;
CheckNTErrors((nblock <= blockNum),
"Fail to segment the matrix!");
}
x1 = x2 + 1;
}
CheckNTErrors(nitem == rowNum * colNum,
"Fail to segment the matrix!");
return nblock;
}
/*
segment a block into sub-blocks (each block consists of a number of rows)
>> rowNum - number of rows
>> colNum - number of columns
>> blockNum - number of sub-blocks
>> blockIndex - upper-left and bottom-right corners of each sub-block
<< return - the number of resulting sub-blocks
*/
int SegmentTensor2DInRows(int rowNum, int colNum, int blockNum, int * blockIndex)
{
if (rowNum < blockNum) {
blockIndex[0] = 0;
blockIndex[1] = 0;
blockIndex[2] = rowNum - 1;
blockIndex[3] = colNum - 1;
return 1;
}
int segSize = (int)ceil((float)rowNum / blockNum);
int x1 = 0;
int x2 = x1 + segSize - 1;
int y1 = 0;
int y2 = colNum - 1;
int last = rowNum - 1;
int nblock = 0;
while (x1 <= last) {
x2 = x1 + segSize - 1;
if (x2 > last) {
x2 = last;
}
int * blockInfo = blockIndex + 4 * nblock;
blockInfo[0] = x1;
blockInfo[1] = y1;
blockInfo[2] = x2;
blockInfo[3] = y2;
nblock++;
if (x2 == last)
break;
x1 += segSize;
}
return nblock;
}
} // namespace nts(NiuTrans.Tensor)
......@@ -65,10 +65,10 @@ void RunParallel2D(XPRunner * parallelRunner, void * job,
TensorList * jobs = new TensorList(jobNum);
TensorList * args = new TensorList(jobNum);
int * indeTensorList = new int[jobNum * 4 * 4];
int * indexList = new int[jobNum * 4 * 4];
/* segment the matrix into blocks */
int nblock = SegmentTensor2D(rowNum, colNum, jobNum, indeTensorList);
int nblock = SegmentTensor2D(rowNum, colNum, jobNum, indexList);
/*
assign jobs
......@@ -79,7 +79,7 @@ void RunParallel2D(XPRunner * parallelRunner, void * job,
for (int i = 0; i < jobNum; i++) {
IntList* indexArgs = new IntList(4);
TensorList * blockArgs = new TensorList(argNum);
int * blockIndex = indeTensorList + i * 4;
int * blockIndex = indexList + i * 4;
indexArgs->Add(blockIndex[0]);
indexArgs->Add(blockIndex[1]);
......@@ -95,7 +95,7 @@ void RunParallel2D(XPRunner * parallelRunner, void * job,
jobs->Add((XTensor*)job);
}
args->count = jobNum * 2;
jobs->count = nblock;
/* single job */
......@@ -106,7 +106,7 @@ void RunParallel2D(XPRunner * parallelRunner, void * job,
parallelRunner->Run(jobs, args);
/* free the memory */
delete[] indeTensorList;
delete[] indexList;
for (int i = 0; i < args->count; i++) {
TensorList * blockArgs = (TensorList*)args->GetItem(i);
delete blockArgs;
......@@ -154,7 +154,7 @@ int SegmentTensor2D(int rowNum, int colNum, int blockNum, int * blockIndex)
int xMax = rowNum - 1;
int yMax = colNum - 1;
int nblock = 0, nitem = 0;
int * indeTensorList = blockIndex;
int * indexList = blockIndex;
int xSegNum = int((float)rowNum / colSize);
int ySegNum = int((float)colNum / rowSize);
......@@ -175,7 +175,7 @@ int SegmentTensor2D(int rowNum, int colNum, int blockNum, int * blockIndex)
x2 = step - 1;
y2 = yMax;
while (x2 <= xMax) {
int * blockIndex = indeTensorList + nblock * 4;
int * blockIndex = indexList + nblock * 4;
blockIndex[0] = x1; blockIndex[1] = y1;
blockIndex[2] = x2; blockIndex[3] = y2;
nblock++;
......@@ -201,7 +201,7 @@ int SegmentTensor2D(int rowNum, int colNum, int blockNum, int * blockIndex)
x2 = xMax;
y2 = step - 1;
while (y2 <= yMax) {
int * blockIndex = indeTensorList + nblock * 4;
int * blockIndex = indexList + nblock * 4;
blockIndex[0] = x1; blockIndex[1] = y1;
blockIndex[2] = x2; blockIndex[3] = y2;
nblock++;
......@@ -241,7 +241,7 @@ int SegmentTensor2D(int rowNum, int colNum, int blockNum, int * blockIndex)
}
while (y2 <= yMax) {
int * blockIndex = indeTensorList + nblock * 4;
int * blockIndex = indexList + nblock * 4;
blockIndex[0] = x1; blockIndex[1] = y1;
blockIndex[2] = x2; blockIndex[3] = y2;
nblock++;
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-04-05
*/
#include "DataSet.h"
#include "StringUtil.h"
#include <string>
#include <vector>
#include <fstream>
#include <algorithm>
using namespace nts;
const int PAD = 0;
/*
load data from the file to the buffer
the data format:
each line: field1 ||| field2 ||| ... ||| fieldn
i.e. fields are separated by `|||` and tokens are separated by space
this will sort the data in ascending order if `sortBuffer` is True
*/
void DataSet::LoadDataToBuffer()
{
string line;
buffer.clear();
bufferUsed = 0;
const string tokenDelimiter = " ";
const string fieldsDelimiter = "|||";
int counter = bufferSize;
while (getline(*fp, line) && counter) {
auto fields = Split<string>(line, fieldsDelimiter);
for (const auto& field : fields) {
auto elements = Split<int>(field, tokenDelimiter);
buffer.emplace_back(elements);
}
counter--;
}
if (fp->eof()) {
fp->seekg(fp->beg);
// LOG("start a new epoch");
}
//if (sortBuffer) {
// auto myCompare = [](const auto& a, const auto& b) {
// return a.first.size() < b.first.size();
// };
// sort(buffer.begin(), buffer.end(), myCompare);
//}
}
/*
select a field and generate a mini-batch by indices
>>> t - a tensor to store the batch
>>> indices - the indices of data
>>> offset - the number of indices
>>> field - indicates which field is selected
this will combine selected elements into a padded batch
*/
void DataSet::LoadBatch(XTensor& t, const int* indices, size_t batchSize, size_t field)
{
if (bufferUsed == bufferSize || bufferUsed == 0) {
LoadDataToBuffer();
}
/* get the maximum length in a mini-batch */
size_t maxLength = 0;
for (size_t i = 0; i < batchSize; ++i) {
int idx = *(indices + i);
maxLength = max(maxLength, buffer[idx * fieldNum + field].size());
}
int* data = new int[maxLength * batchSize];
memset(data, 0, maxLength * batchSize);
size_t cur = 0;
for (size_t i = 0; i < batchSize; ++i) {
size_t next = cur + maxLength;
int idx = *(indices + i);
for (int v : buffer[idx * fieldNum + field]) {
data[cur++] = v;
}
/* pad zeros */
while (cur < next) {
data[cur++] = 0;
}
}
InitTensor2D(&t, batchSize, maxLength, X_INT);
t.SetData(data, maxLength * batchSize);
bufferUsed += batchSize;
delete[] data;
}
/*
the constructor of DataSet
>>> fname - path of the data file
>>> paraFieldNum - the number of different fields
>>> paraBufferSize - size of each field in the buffer
the real size of buffer is `bufferSize * fieldNum`
*/
DataSet::DataSet(const char* fname, size_t paraFieldNum, size_t paraBufferSize)
{
fp = new ifstream(fname);
fieldNum = paraFieldNum;
bufferSize = paraBufferSize;
bufferUsed = 0;
buffer.reserve(bufferSize * fieldNum);
CheckNTErrors(fp, "unable to open the file: %s", fname);
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-04-03
*/
#ifndef __DATASET_H__
#define __DATASET_H__
#include "../XGlobal.h"
#include "../XTensor.h"
#include <cstdio>
#include <fstream>
#include <unordered_map>
#include <vector>
using namespace std;
using BatchType = vector<vector<int>>;
using BucketType = vector<vector<int>>;
using BufferType = vector<vector<int>>;
namespace nts { // namespace nts(NiuTrans.Tensor)
/* A `DataSet` is associated with a file which contains variable length data.*/
class DataSet {
private:
/* the data buffer */
BufferType buffer;
/* the pointer to file stream */
ifstream* fp;
/* number of different fields */
size_t fieldNum;
/* size of the data buffer */
size_t bufferSize;
/* size of used data in buffer */
size_t bufferUsed;
/* load data from a file to the buffer */
void LoadDataToBuffer();
public:
/* sort data in the buffer */
virtual void Sort(){};
/* modify data in the buffer */
virtual void Process(){};
/* group data to buckets */
virtual BucketType Bucketing() { return BucketType(); };
/* generate a mini-batch */
void LoadBatch(XTensor& t, const int* indices, size_t batchSize, size_t field);
/* constructor */
explicit DataSet(const char* fname, size_t paraFieldNum, size_t paraBufferSize);
};
} // namespace nts(NiuTrans.Tensor)
#endif // __DATASET_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-04-02
*/
#ifndef __EMBEDDING_H__
#define __EMBEDDING_H__
#include "../XGlobal.h"
#include "../XTensor.h"
#include <vector>
#include <cstdio>
#include <fstream>
#include <unordered_map>
using namespace std;
using namespace nts;
/*
* A `Embedding` is associated with an embeddings table(embedNum * embedDim).
* The table can be loaded from a file or constructed from an instance.
*/
class Embedding {
public:
XTensor embeddingTable;
public:
/* save the embeddings table to a file */
void Dump(const char* fname);
/* looks up ids in a list of embedding tensors */
XTensor EmbeddingLookup(const XTensor& ids);
/* load an embeddings table from a file */
explicit Embedding(const char* fname);
/* construct table from an XTensor instance */
explicit Embedding(XTensor& table);
/* construct table from a array instance */
explicit Embedding(const float* p, size_t embedNum, size_t embedDim);
};
#endif // __EMBEDDING_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-03-18
*/
#ifndef __STRING_UTIL_H__
#define __STRING_UTIL_H__
#include <cstdlib>
#include <string>
#include <utility>
#include <vector>
using namespace std;
namespace nts {
/* Splits a string based on the given delimiter string. Each pair in the
* returned vector has the start and past-the-end positions for each of the
* parts of the original string. Empty fields are not represented in the output.
*/
vector<pair<int, int>> SplitToPos(const string& s, const string& delimiter);
/* Splits the given string and converts each part to the given T. */
template <typename T>
vector<T> Split(const string& s, const string& delimiter);
template <>
inline vector<string> Split(const string& s, const string& delimiter)
{
vector<string> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
fields.emplace_back(s.substr(p.first, p.second - p.first));
}
return fields;
}
template <>
inline vector<int> Split(const string& s, const string& delimiter)
{
vector<int> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
fields.emplace_back(strtol(s.data() + p.first, nullptr, 10));
}
return fields;
}
template <>
inline vector<int64_t> Split(const string& s, const string& delimiter)
{
vector<int64_t> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
fields.emplace_back(strtoll(s.data() + p.first, nullptr, 10));
}
return fields;
}
template <>
inline vector<float> Split(const string& s, const string& delimiter)
{
vector<float> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
fields.emplace_back(strtod(s.data() + p.first, nullptr));
}
return fields;
}
template <>
inline vector<uint8_t> Split(const string& s, const string& delimiter)
{
vector<uint8_t> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
fields.emplace_back(strtol(s.data() + p.first, nullptr, 10));
}
return fields;
}
template <>
inline vector<bool> Split(const string& s, const string& delimiter)
{
vector<bool> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
fields.emplace_back(
static_cast<bool>(strtol(s.data() + p.first, nullptr, 10)));
}
return fields;
}
} // namespace nts
#endif // __STRING_UTIL_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-03-18
*/
#include "Vocabulary.h"
#include "StringUtil.h"
#include <algorithm>
#include <fstream>
#include <iterator>
#include <map>
#include <string>
#include <vector>
using namespace nts;
/* initialize the mappings with some special tokens */
void nts::Vocabulary::Init()
{
const std::map<std::string, int> init{
{ "<UNK>", 0 },
{ "<PAD>", 1 },
{ "<SOS>", 2 },
{ "<EOS>", 3 },
};
for (const auto& v : init) {
token2ID[v.first] = v.second;
id2Token[v.second] = v.first;
}
}
/*
load vocabulary from disk
format(each line): token<\n>
*/
nts::Vocabulary::Vocabulary(const char* vocabPath)
{
Init();
std::ifstream file(vocabPath);
string line;
int index = int(token2ID.size());
while (getline(file, line)) {
size_t pos = line.find("\n");
string token = line.substr(0, pos);
token2ID[token] = index;
id2Token[index++] = token;
}
currentSize = int(token2ID.size());
}
/*
build vocabulary from the source file
>> srcPath - the path of source file
>> minFreq - the user specified minimum count for a single token
>> maxSize - the user specified maxmum size of a vocabulary
this will sort tokens by frequency and map the most uncommon tokens to <UNK>
*/
nts::Vocabulary::Vocabulary(const char* srcPath, int minFreq, int maxSize)
{
Init();
using Dict = vector<pair<string, int>>;
/* count word and their frequency */
std::ifstream file(srcPath);
map<string, int> dict;
string token;
while (file >> token) {
dict[token] += 1;
}
/* sort tokens by frequency */
Dict dictFlatten;
copy(dict.begin(), dict.end(), back_inserter<Dict>(dictFlatten));
auto myCompare = [](const auto& a, const auto& b) {
return a.second < b.second;
};
sort(dictFlatten.begin(), dictFlatten.end(), myCompare);
/* skip token whose frequency is out of top-maxSize or less than minFreq */
size_t index = token2ID.size();
size_t validSize = maxSize > dict.size() ? maxSize : dict.size();
for (const auto& tokenID : dictFlatten) {
if (validSize && tokenID.second > minFreq &&
token2ID.find(tokenID.first) == token2ID.end()) {
id2Token[int(index)] = tokenID.first;
token2ID[tokenID.first] = int(index);
index++;
validSize--;
}
}
currentSize = int(token2ID.size());
}
/* insert a token to the vocabulary */
void nts::Vocabulary::Insert(string token)
{
if (token2ID.find(token) != token2ID.end()) {
token2ID[token] = currentSize;
id2Token[currentSize++] = token;
}
}
/* save the vocabulary to a file */
void nts::Vocabulary::Dump(const char* vocabPath)
{
std::ofstream file(vocabPath);
for (const auto& tokenID : token2ID)
file << tokenID.first << "\t" << tokenID.second << "\n";
}
/*
maps tokens to integers
>>> tokens - a list of tokens
<<< indices - a list of corresponding indices
notices that this will map OOV to UNK
*/
vector<int> nts::Vocabulary::Token2ID(vector<string> tokens)
{
vector<int> indices;
indices.reserve(tokens.size());
for (auto str : tokens) {
if (token2ID.find(str) == token2ID.end())
indices.emplace_back(token2ID["<UNK>"]);
else
indices.emplace_back(token2ID[str]);
}
return indices;
}
/*
maps integers to tokens
>>> indices - a list of indices
<<< tokens - a list of corresponding tokens
notices that this will throw a error if the id is not found
*/
vector<string> nts::Vocabulary::ID2Token(vector<int> ids)
{
vector<string> tokens;
tokens.reserve(ids.size());
for (auto id : ids) {
CheckNTErrors(id2Token.find(id) != id2Token.end(), "id not found!");
tokens.emplace_back(id2Token[id]);
}
return tokens;
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-04-03
*/
#ifndef __VOCABULARY_H__
#define __VOCABULARY_H__
#include "../XGlobal.h"
#include "../XTensor.h"
#include <cstdio>
#include <fstream>
#include <unordered_map>
#define LOG(...) \
do { \
fprintf(stderr, "[INFO] "); \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n"); \
fflush(stdout); \
} while (0)
#define ERR(...) \
do { \
fprintf(stderr, "[ERROR] "); \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n"); \
fflush(stdout); \
} while (0)
using namespace std;
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
* A Vocabulary maps strings to integers, allowing for strings to be mapped to an out-of-vocabulary token.
* Vocabularies are fit to a particular dataset, which we use to decide which tokens are in-vocabulary.
* For convinience, a vocabulary can be built from a source file or be loaded from a vocabulary file.
*/
class Vocabulary {
private:
/* current size of the vocabulary */
int currentSize = 0;
/* initialize mappings with some special tokens */
void Init();
public:
/* a dict maps tokens to indices */
unordered_map<string, int> token2ID;
/* a dict maps indices to tokens */
unordered_map<int, string> id2Token;
/* save the vocabulary to a file */
void Dump(const char* vocabPath);
/* append token to the vocabulary */
void Insert(string token);
/* maps integers to tokens */
vector<string> ID2Token(vector<int> ids);
/* maps tokens to integers */
vector<int> Token2ID(vector<string> tokens);
/* constructor */
/* load the vocabulary from a file */
explicit Vocabulary(const char* vocabPath);
/* built the vocabulary from the source file */
explicit Vocabulary(const char* srcPath, int minFreq, int maxSize);
};
} // namespace nts(NiuTrans.Tensor)
#endif // __VOCABULARY_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-03-18
*/
#include "DataReader.h"
#include "StringUtil.h"
#include <algorithm>
#include <fstream>
#include <iterator>
#include <map>
#include <string>
#include <vector>
using namespace nts;
/* initialize the mappings with some special tokens */
void nts::Vocabulary::Init()
{
const std::map<std::string, int> init{
{ "<UNK>", 0 },
{ "<PAD>", 1 },
{ "<SOS>", 2 },
{ "<EOS>", 3 },
};
for (const auto& v : init) {
token2ID[v.first] = v.second;
id2Token[v.second] = v.first;
}
}
/*
load vocabulary from disk
format(each line): token<\n>
notice that there are some pre-defined special tokens
*/
nts::Vocabulary::Vocabulary(const char* vocabPath)
{
Init();
std::ifstream file(vocabPath);
string line;
int index = int(token2ID.size());
while (getline(file, line)) {
size_t pos = line.find("\n");
string token = line.substr(0, pos);
token2ID[token] = index;
id2Token[index++] = token;
}
currentSize = int(token2ID.size());
}
/*
build vocabulary from the source file
>> srcPath - the path of source file
>> minFreq - the user specified minimum count for a single token
>> maxSize - the user specified maxmum size of a vocabulary
we sort tokens by their frequency, the most uncommon tokens will be mapped to <UNK>
if the amount of different tokens is larger than paraMaxSize.
*/
nts::Vocabulary::Vocabulary(const char* srcPath, int minFreq, int maxSize)
{
Init();
/* count word and their frequency */
std::ifstream file(srcPath);
map<string, int> dict;
string token;
while (file >> token) {
dict[token] += 1;
}
/* sort word by their frequency */
vector<pair<string, int>> dictFlatten;
copy(dict.begin(), dict.end(), back_inserter<vector<pair<string, int>>>(dictFlatten));
auto myCompare = [](const auto& a, const auto& b) {
return a.second < b.second;
};
sort(dictFlatten.begin(), dictFlatten.end(), myCompare);
/* skip token whose frequency is out of top-maxSize or less than minFreq */
size_t index = token2ID.size();
size_t validSize = maxSize > dict.size() ? maxSize : dict.size();
for (const auto& tokenID : dictFlatten) {
if (validSize && tokenID.second > minFreq && token2ID.find(tokenID.first) == token2ID.end()) {
id2Token[int(index)] = tokenID.first;
token2ID[tokenID.first] = int(index);
index++;
validSize--;
}
}
currentSize = int(token2ID.size());
}
/* insert a token to the vocabulary */
void nts::Vocabulary::Insert(string token)
{
if (token2ID.find(token) != token2ID.end()) {
token2ID[token] = currentSize;
id2Token[currentSize++] = token;
}
}
/* save the vocabulary to a file */
void nts::Vocabulary::Dump(const char* vocabPath)
{
std::ofstream file(vocabPath);
for (const auto& tokenID : token2ID)
file << tokenID.first << "\t" << tokenID.second << "\n";
}
/*
maps tokens to integers
>>> tokens - a list of tokens
<<< indices - a list of corresponding indices
notices that this will map OOV to UNK
*/
vector<int> nts::Vocabulary::Token2ID(vector<string> tokens)
{
vector<int> indices;
indices.reserve(tokens.size());
for (auto str : tokens) {
if (token2ID.find(str) == token2ID.end())
indices.emplace_back(token2ID["<UNK>"]);
else
indices.emplace_back(token2ID[str]);
}
return indices;
}
/*
maps integers to tokens
>>> indices - a list of indices
<<< tokens - a list of corresponding tokens
notices that this will throw a error if the id is not found
*/
vector<string> nts::Vocabulary::ID2Token(vector<int> ids)
{
vector<string> tokens;
tokens.reserve(ids.size());
for (auto id : ids) {
CheckNTErrors(id2Token.find(id) != id2Token.end(), "id not found!");
tokens.emplace_back(id2Token[id]);
}
return tokens;
}
/*
shuffle lines of the file
>> srcFile - the source file to shuffle
this will create a file "<srcFile>.random"
*/
void nts::Shuffle(const char* srcFile)
{
string command = "shuf ";
string tgtFile = srcFile;
tgtFile += ".random";
command += srcFile;
command += " > ";
command += tgtFile;
#ifndef WIN32
system(command.c_str());
#else
ShowNTErrors("Cannot shuffle the file on WINDOWS systems!");
#endif
}
/*
load data to the buffer
data format(a single line): "s1 s2 ... sn|||t1 t2 ... tm"
where s is the source token and t is the target token
this will transfer raw tokens to indices according to the vocabulary
>>> bufferSize - the number of sentences in the buffer
*/
void nts::DataSet::LoadDataToBuf()
{
dataBuf.clear();
string line;
const string srcTgtDelimiter = "|||";
const string sentDelimiter = " ";
int counter = bufferSize;
while (getline(*fp, line) && counter) {
auto subStrings = Split<string>(line, srcTgtDelimiter);
auto src = Split<string>(subStrings[0], sentDelimiter);
auto tgt = Split<string>(subStrings[1], sentDelimiter);
auto srcIndices = srcVocab->Token2ID(src);
auto tgtIndices = tgtVocab->Token2ID(tgt);
dataBuf.emplace_back(srcIndices, tgtIndices);
counter--;
}
}
/*
splits the data buffer to buckets
elements of the buffer are grouped together by length and batched
some samples may never be selected due to their unsuitable length
>>> batchSize - the maximum number of tokens
<<< buckets - a list of tokens indices in a mini-batch
*/
void nts::DataManager::BucketSeqByLen(int batchSize)
{
dataSet->LoadDataToBuf();
auto myCompare = [](const auto& a, const auto& b) {
return a.first.size() < b.first.size();
};
sort(dataSet->dataBuf.begin(), dataSet->dataBuf.end(), myCompare);
int realBatchSize = 0;
for (int i = 0; i < dataSet->dataBuf.size(); ++i) {
/* the first allocation */
if (!buckets.size()) {
buckets.emplace_back(1);
buckets.back().pop_back();
}
/* append the current sentence to bucket */
size_t seqLen = dataSet->dataBuf[i].first.size();
size_t seqNum = buckets.back().size();
if ((seqNum + 1) * seqLen <= batchSize &&
seqLen <= maxSeqLen && seqLen >= minSeqLen) {
buckets.back().push_back(i);
}
/* allocate a new bucket */
else {
buckets.emplace_back(1);
buckets.back().pop_back();
}
}
/* drop the last bucket if its size is too small */
if (buckets.back().size() * dataSet->dataBuf.back().first.size() < batchSize && buckets.size()>1)
buckets.pop_back();
}
/*
load a batch of data
>> batchEnc - the batch of the input sequences
>> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> devID - device id
>> mem - memory pool
*/
void nts::DataManager::LoadBatch(XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec,
XTensor* label, int devID, XMem* mem)
{
if (bucketIter == buckets.size() - 1 || buckets.size() == 0) {
BucketSeqByLen(batchSize);
bucketIter = 0;
}
auto bucket = buckets[bucketIter];
bucketIter++;
size_t sentenceCount = bucket.size();
size_t maxEnc = dataSet->dataBuf[bucket.back()].first.size();
size_t maxDec = dataSet->dataBuf[bucket.back()].second.size();
InitTensor2D(batchEnc, int(sentenceCount), int(maxEnc), X_INT, devID, mem);
InitTensor2D(paddingEnc, int(sentenceCount), int(maxEnc), X_FLOAT, devID, mem);
InitTensor2D(batchDec, int(sentenceCount), int(maxDec), X_INT, devID, mem);
InitTensor2D(paddingDec, int(sentenceCount), int(maxDec), X_FLOAT, devID, mem);
InitTensor2D(label, int(sentenceCount), int(maxDec), X_INT, devID, mem);
int* labelValues = new int[batchDec->unitNum];
int* batchEncValues = new int[batchEnc->unitNum];
int* batchDecValues = new int[batchDec->unitNum];
float* paddingEncOffsets = new float[batchEnc->unitNum];
float* paddingDecOffsets = new float[batchDec->unitNum];
memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum);
memset(batchDecValues, 0, sizeof(int) * batchDec->unitNum);
memset(paddingEncOffsets, 0, sizeof(float) * batchEnc->unitNum);
memset(paddingDecOffsets, 0, sizeof(float) * batchDec->unitNum);
memset(labelValues, 0, sizeof(float) * batchDec->unitNum);
size_t srcIter = 0;
size_t tgtIter = 0;
for (const auto& index : bucket) {
size_t srcOffset = srcIter + dataSet->dataBuf[bucket.back()].first.size();
size_t tgtOffset = tgtIter + dataSet->dataBuf[bucket.back()].second.size();
for (const auto& src : dataSet->dataBuf[index].first)
batchEncValues[srcIter++] = src;
for (size_t i = 0; i < dataSet->dataBuf[index].second.size(); ++i) {
if (i != 0)
labelValues[tgtIter - 1] = dataSet->dataBuf[index].second[i];
batchDecValues[tgtIter++] = dataSet->dataBuf[index].second[i];
}
while (srcIter < srcOffset)
paddingEncOffsets[srcIter++] = 1.0F;
while (tgtIter < tgtOffset)
paddingDecOffsets[tgtIter++] = 1.0F;
}
batchEnc->SetData(batchEncValues, batchEnc->unitNum);
batchDec->SetData(batchDecValues, batchDec->unitNum);
paddingEnc->SetData(paddingEncOffsets, batchEnc->unitNum);
paddingDec->SetData(paddingDecOffsets, batchDec->unitNum);
label->SetData(labelValues, label->unitNum);
delete[] batchEncValues;
delete[] batchDecValues;
delete[] labelValues;
delete[] paddingEncOffsets;
delete[] paddingDecOffsets;
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-03-18
*/
#ifndef _DATA_READER_H_
#define _DATA_READER_H_
#include "../XGlobal.h"
#include "../XTensor.h"
#include <cstdio>
#include <fstream>
#include <unordered_map>
#define LOG(...) \
do { \
fprintf(stderr, "[INFO] "); \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n"); \
fflush(stdout); \
} while (0)
#define ERR(...) \
do { \
fprintf(stderr, "[ERROR] "); \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n"); \
fflush(stdout); \
} while (0)
using namespace std;
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
* A Vocabulary maps strings to integers, allowing for strings to be mapped to an out-of-vocabulary token.
* Vocabularies are fit to a particular dataset, which we use to decide which tokens are in-vocabulary.
* For convinience, a vocabulary can be built from a source file or be loaded from a vocabulary file.
*/
class Vocabulary {
private:
/* current size of the vocabulary */
int currentSize = 0;
/* initialize mappings with some special tokens */
void Init();
public:
/* a dict maps tokens to indices */
unordered_map<string, int> token2ID;
/* a dict maps indices to tokens */
unordered_map<int, string> id2Token;
Vocabulary() = delete;
Vocabulary(Vocabulary&&) = delete;
/* load the vocabulary from a file */
explicit Vocabulary(const char* vocabPath);
/* built the vocabulary from the source file */
explicit Vocabulary(const char* srcPath, int minFreq, int maxSize);
/* append token to the vocabulary */
void Insert(string token);
/* save the vocabulary to a file */
void Dump(const char* vocabPath);
/* maps tokens to integers */
vector<int> Token2ID(vector<string> tokens);
/* maps integers to tokens */
vector<string> ID2Token(vector<int> ids);
};
/*
* A DataSet maintains a sequences buffer.
* An instance is tied to a specified file and buffer size.
*/
class DataSet {
public:
/* number of buffered sequences */
int bufferSize;
/* the pointer to file stream */
ifstream* fp;
/* source-side vocabulary */
Vocabulary* srcVocab;
/* target-side vocabulary */
Vocabulary* tgtVocab;
/* sequences buffer */
vector<pair<vector<int>, vector<int>>> dataBuf;
DataSet(const DataSet&) = delete;
DataSet(DataSet&&) = delete;
/* constructor */
explicit DataSet(const char* fname, const char* srcVocabPath, const char* tgtVocabPath, int paraBufferSize, bool isShuffled)
{
bufferSize = paraBufferSize;
srcVocab = new Vocabulary(srcVocabPath);
tgtVocab = new Vocabulary(tgtVocabPath);
string trainFN = fname;
#ifndef WIN32
if (isShuffled) {
Shuffle(trainFN);
trainFN += ".random";
}
#endif
fp = new ifstream();
fp->open(trainFN);
CheckNTErrors(fp, "unable to open %s", fname);
}
/* destructor */
~DataSet()
{
fp->close();
}
/* loading data to the buffer */
void LoadDataToBuf();
};
/*
* A DataManager splits buffer to buckets and generates batches.
* It must be associated with a dataset.
*/
class DataManager {
private:
/* size of a mini-batch */
int batchSize;
/* the maximum length of a sequence */
int maxSeqLen;
/* the minimum length of a sequence */
int minSeqLen;
/* index of current bucket */
int bucketIter = 0;
/* the associated data set */
DataSet* dataSet;
/* buckets positon in buffered sequences */
vector<vector<int>> buckets;
DataManager() = delete;
DataManager(DataManager&&) = delete;
DataManager(const DataManager&) = delete;
/* splits sequences to buckets */
void BucketSeqByLen(int batchSize);
public:
/* constructor */
explicit DataManager(const char* fname,
const char* srcVocab, const char* tgtVocab,
int paraBufferSize, bool isShuffled,
int paraMaxSeqLen, int paraMinSeqLen, int paraBatchSize)
{
batchSize = paraBatchSize;
maxSeqLen = paraMaxSeqLen;
minSeqLen = paraMinSeqLen;
dataSet = new DataSet(fname, srcVocab, tgtVocab, paraBufferSize, isShuffled);
}
/* destructor */
~DataManager()
{
delete dataSet;
}
/* loads a mini-batch */
void LoadBatch(XTensor* batchEnc,
XTensor* paddingEnc, XTensor* batchDec,
XTensor* paddingDec, XTensor* label,
int devID, XMem* mem);
};
/* shuffle a file and pipe it to an output file */
void Shuffle(const char* srcFile);
} // namespace nts(NiuTrans.Tensor)
#endif // _DATA_READER_H_
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-03-18
*/
#include "T2TDataReader.h"
#include "StringUtil.h"
#include <algorithm>
#include <cstdio>
#include <fstream>
void nts::T2TDataSet::Process()
{
}
/* split buffer to buckets */
void nts::T2TDataManager::BucketSeqByLen(int batchSize)
{
t2tDataSet->LoadDataToBuf();
auto myCompare = [](const auto& a, const auto& b) {
return a.size() < b.size();
};
sort(dataSet->dataBuf.begin(), dataSet->dataBuf.end(), myCompare);
int realBatchSize = 0;
for (int i = 0; i < dataSet->dataBuf.size(); ++i) {
/* the first allocation */
if (!bucketBoundary.size()) {
bucketBoundary.emplace_back(1);
}
/* append the current sentence to bucket if the total size is suitable */
int seqLen = dataSet->dataBuf[i].first.size();
int seqNum = bucketBoundary.back().size();
if ((seqNum + 1) * seqLen < batchSize && seqLen <= maxSeqLen && seqLen >= minSeqLen) {
bucketBoundary.back().push_back(i);
}
/* allocate a new bucket */
else {
bucketBoundary.emplace_back(1);
}
}
/* drop the last bucket if its size is too small */
if (bucketBoundary.back().size() * dataSet->dataBuf.back().first.size() < batchSize / 2)
bucketBoundary.pop_back();
}
/*
load a batch of sequences (for MT)
>> batchEnc - the batch of the input sequences
>> paddingEnc - padding of the input sequences
>> batchDec - the batch of the output sequences
>> paddingDec - padding of the output sequences
>> devID - device id
>> mem - memory pool
>> isTraining - indicates whether we are training the model
*/
void nts::T2TDataManager::LoadBatchMT(XTensor* batchEnc,
XTensor* paddingEnc, XTensor* batchDec,
XTensor* paddingDec, XTensor* label,
int devID, XMem* mem, bool isTraining)
{
if (bucketIter == bucketBoundary.size() - 1) {
BucketSeqByLen(batchSize);
bucketIter = 0;
}
auto bucket = bucketBoundary[bucketIter++];
int sentenceCount = bucket.size();
int maxEnc = dataSet->dataBuf[bucket.back()].first.size();
int maxDec = dataSet->dataBuf[bucket.back()].second.size();
InitTensor2D(batchEnc, sentenceCount, maxEnc, X_INT, devID, mem);
InitTensor2D(paddingEnc, sentenceCount, maxEnc, X_FLOAT, devID, mem);
InitTensor2D(batchDec, sentenceCount, maxDec, X_INT, devID, mem);
InitTensor2D(paddingDec, sentenceCount, maxDec, X_FLOAT, devID, mem);
InitTensor2D(label, sentenceCount, maxDec, X_INT, devID, mem);
batchEnc->SetZeroAll();
paddingEnc->SetZeroAll();
batchDec->SetZeroAll();
paddingDec->SetZeroAll();
label->SetZeroAll();
int* labelValues = new int[batchDec->unitNum];
int* batchEncValues = new int[batchEnc->unitNum];
int* batchDecValues = new int[batchDec->unitNum];
MTYPE* paddingEncOffsets = new MTYPE[batchEnc->unitNum];
MTYPE* paddingDecOffsets = new MTYPE[batchDec->unitNum];
memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum);
memset(batchDecValues, 0, sizeof(int) * batchDec->unitNum);
memset(paddingEncOffsets, 0, sizeof(int) * batchEnc->unitNum);
memset(paddingDecOffsets, 0, sizeof(int) * batchDec->unitNum);
memset(labelValues, 0, sizeof(int) * batchDec->unitNum);
int srcIter = 0;
int tgtIter = 0;
for (const auto& index : bucket) {
int srcOffset = srcIter + dataSet->dataBuf[bucket.back()].first.size();
int tgtOffset = tgtIter + dataSet->dataBuf[bucket.back()].second.size();
for (const auto& src : dataSet->dataBuf[index].first)
batchEncValues[srcIter++] = src;
for (const auto& src : dataSet->dataBuf[index].second)
batchDecValues[tgtIter++] = src;
while (srcIter < srcOffset)
paddingEncOffsets[srcIter++] = 1.0F;
while (tgtIter < tgtOffset)
paddingDecOffsets[tgtIter++] = 1.0F;
}
batchEnc->SetData(batchDecValues, batchEnc->unitNum);
batchDec->SetData(batchDecValues, batchDec->unitNum);
paddingEnc->SetData(paddingEncOffsets, batchEnc->unitNum);
paddingDec->SetData(paddingDecOffsets, batchDec->unitNum);
label->SetData(labelValues, label->unitNum);
delete[] batchEncValues;
delete[] batchDecValues;
delete[] labelValues;
delete[] paddingEncOffsets;
delete[] paddingDecOffsets;
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-03-18
*/
#ifndef _T2T_DATA_READER_H_
#define _T2T_DATA_READER_H_
#include "../XTensor.h"
#include "DataReader.h"
#define LOG(...) \
do { \
fprintf(stderr, "[INFO] "); \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n"); \
fflush(stdout); \
} while (0)
#define ERR(...) \
do { \
fprintf(stderr, "[ERROR] "); \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n"); \
fflush(stdout); \
} while (0)
using namespace std;
namespace nts { // namespace nts(NiuTrans.Tensor)
class T2TDataSet : public DataSet {
public:
/* source-side vocabulary */
Vocabulary* srcVocab;
/* target-side vocabulary */
Vocabulary* tgtVocab;
/* sequences buffer */
vector<pair<vector<int>, vector<int>>> t2tDataBuf;
/* constructor */
explicit T2TDataSet(const char* fname, int paraBufferSize,
const char* srcVocabPath, const char* tgtVocabPath, bool isShuffled)
: DataSet(fname, paraBufferSize, isShuffled)
{
srcVocab = new Vocabulary(srcVocabPath);
tgtVocab = new Vocabulary(tgtVocabPath);
}
/* destructor */
~T2TDataSet()
{
delete srcVocab;
delete tgtVocab;
}
/* loading setences to the buffer */
void LoadDataToBuf() override;
/* post processing of sequences after loading */
void Process() override;
};
class T2TDataManager : public DataManager {
public:
/* the associated data set */
T2TDataSet* t2tDataSet;
/* constructor */
explicit T2TDataManager(bool isShuffled, const char* fname,
int paraBatchSize, int paraMaxSeqLen,
int paraMinSeqLen, int paraBufferSize,
const char* srcVocabPath, const char* tgtVocabPath)
: DataManager(fname, paraBufferSize, isShuffled, paraMaxSeqLen, paraMinSeqLen, paraBatchSize)
{
maxSeqLen = paraMaxSeqLen;
minSeqLen = paraMinSeqLen;
t2tDataSet = new T2TDataSet(fname, paraBufferSize, srcVocabPath, tgtVocabPath, isShuffled);
}
/* */
void LoadBatchMT(XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec,
XTensor* label, int devID, XMem* mem,
bool isTraining);
};
} // namespace nts(NiuTrans.Tensor)
#endif // _T2T_DATA_READER_H_
我 你 他 <EOS> ||| <SOS> I You He <EOS>
我 你 <EOS> ||| <SOS> I You <EOS>
\ No newline at end of file
\ No newline at end of file
I
You
He
\ No newline at end of file
......@@ -21,12 +21,14 @@
#include "../XName.h"
#include <time.h>
#include <math.h>
#include "Dropout.h"
#include "Dropout.cuh"
#include "../core/arithmetic/Multiply.h"
#include "../core/arithmetic/MultiplyDim.h"
#include "../core/math/ScaleAndShift.h"
#include "../core/getandset/SetData.h"
#include "DropoutWithIndex.h"
namespace nts{ // namespace nts(NiuTrans.Tensor
......@@ -147,6 +149,7 @@ XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim, int leadingDim
CheckNTErrors(dropProb >= 0.0 && dropProb <= 1.0, "The probability must be 0-1!");
XTensor mask;
int * maskArrayInt = NULL;
DTYPE * maskArray = NULL;
DTYPE scaleFactor = (DTYPE)1.0 / ((DTYPE)1.0 - dropProb);
......@@ -157,6 +160,23 @@ XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim, int leadingDim
_SetDataRandP(&mask, 0, 1.0F, dropProb, scaleFactor);
return Multiply(x, mask);
/* dropout with index */
/*int unitNum = floor(x.unitNum*dropProb);
maskArrayInt = new int[unitNum];
for (int i = 0; i < unitNum; i++)
maskArrayInt[i] = rand() % x.unitNum;
XTensor maskindex;
InitTensor1D(&maskindex, unitNum, X_INT, x.devID, x.mem);
maskindex.SetData(maskArrayInt, unitNum);
delete[] maskArrayInt;
return DropoutWithIndex(x, maskindex, scaleFactor);*/
}
else if(leadingDim2 < 0){
int n = leadingDim;
......@@ -209,7 +229,6 @@ XTensor Dropout(const XTensor &x, DTYPE dropProb, int leadingDim, int leadingDim
return MultiplyBroadcast(x, mask);
}
}
/*
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Jiang Yufan (email: jiangyufan2018@outlook.com) 2019-03-20
*/
#include "DropoutWithIndex.h"
#include "DropoutWithIndex.cuh"
#include "../core/CHeader.h"
#include "../XName.h"
#include "Identity.h"
namespace nts {
/*
This is a special implementation of "dropout" to reduce memory with maskIndex.
>> x - input tensor
>> maskIndex - mask index tensor
>> c - output tensor
*/
void _DropoutWithIndex(const XTensor * x, XTensor * maskIndex, XTensor * c)
{
CheckNTErrors(maskIndex->order == 1, "Illegal tensor order!");
#ifdef USE_CUDA
if (maskIndex->devID >= 0 || x->devID >= 0 || c->devID >= 0) {
_CudaDropoutWithIndex(x, maskIndex, c);
return;
}
#endif
// TODO!!
ShowNTErrors("TODO!");
}
/*
This is a special implementation of "dropout" to reduce memory with maskIndex.
>> x - input tensor
>> maskIndex - mask index tensor
>> c - output tensor
>> scale - scale factor
*/
XTensor DropoutWithIndex(const XTensor &x, XTensor &maskIndex, DTYPE scale)
{
XTensor c;
int order = x.order;
int * dimSize = new int[order];
for (int i = 0; i < order; i++) {
dimSize[i] = x.dimSize[i];
}
InitTensor1D(&c, x.unitNum, x.dataType, x.devID, x.mem);
_SetDataFixedFloat(&c, 1.0F);
_DropoutWithIndex(&x, &maskIndex, &c);
c.Reshape(order, dimSize);
_MultiplyMe(&c, &x);
_ScaleAndShiftMe(&c, scale);
/* tensor connections */
XLink::MakeLink(&x, &maskIndex, &c, MOVEMENT_DROPOUTWITHINDEX);
XLink::AddParamToHead(&c, scale);
return c;
}
}// namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Jiang Yufan (email: jiangyufan2018@outlook.com) 2019-03-20
*/
#include "DropoutWithIndex.cuh"
#include "../XDevice.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
__global__
/*
This is a special implementation of "dropout" to reduce memory with maskIndex.
>> tData - the data pointer of the target tensor
>> sIndex - mask index
>> size - the size of the sIndex
*/
void KernelDropoutWithIndex1D(DTYPE * tData, int * sIndex, int size)
{
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
DTYPE * t = tData;
if (i < size) {
int id = sIndex[i];
t[id] = DTYPE(0.0F);
}
}
/*
This is a special implementation of "dropout" to reduce memory with maskIndex.
>> x - input tensor
>> maskIndex - mask index tensor
>> c - output tensor
*/
void _CudaDropoutWithIndex(const XTensor * x, XTensor * maskIndex, XTensor * c)
{
int devID = c->devID;
int blockNum = maskIndex->unitNum;
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup;
ProtectCudaDev(devID, devIDBackup);
GDevs.GetCudaThread(devID, blockNum, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0]);
dim3 threads(cudaBlocks[0]);
DTYPE * tData = (DTYPE*)c->data;
int * sIndex = NULL;
sIndex = (int *)maskIndex->data;
KernelDropoutWithIndex1D <<<blocks, threads >>>(tData, sIndex, blockNum);
BacktoCudaDev(devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Jiang Yufan (email: jiangyufan2018@outlook.com) 2019-03-20
*/
#ifndef __DROPOUTWITHINDEX_CUH__
#define __DROPOUTWITHINDEX_CUH__
#include "../XTensor.h"
#include "DropoutWithIndex.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* dropout with index (cuda version) */
void _CudaDropoutWithIndex(const XTensor * x, XTensor * maskIndex, XTensor * c);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __DROPOUTWITHINDEX_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-03-18
* $Created by: Jiang Yufan (email: jiangyufan2018@outlook.com) 2019-03-20
*/
#include "StringUtil.h"
#ifndef __DROPOUTWITHINDEX_H__
#define __DROPOUTWITHINDEX_H__
#include "../XTensor.h"
namespace nts {
/* split string by delimiter, this will return indices of all sub-strings */
vector<pair<int, int>> SplitToPos(const string& s, const string& delimiter)
{
vector<pair<int, int>> fields;
if (delimiter.length() == 0) {
fields.emplace_back(0, s.length());
return fields;
}
int pos = 0;
int start = 0;
while ((pos = s.find(delimiter, start)) != string::npos) {
if (pos != start) {
fields.emplace_back(start, pos);
}
start = pos + delimiter.length();
}
if (start != s.length()) {
fields.emplace_back(start, s.length());
}
return fields;
}
}
\ No newline at end of file
void _DropoutWithIndex(const XTensor * x, XTensor * maskIndex, XTensor * c);
XTensor DropoutWithIndex(const XTensor &x, XTensor &mask, DTYPE scale);
} // namespace nts(NiuTrans.Tensor)
#endif // !__DROPOUTWITHINDEX_H__
......@@ -26,7 +26,6 @@
#include "../XTensor.h"
#include "CrossEntropy.h"
#include "Dropout.h"
#include "HardTanH.h"
#include "Identity.h"
......
......@@ -23,7 +23,7 @@
#include "../XName.h"
#include "HardTanH.h"
#include "HardTanH.cuh"
#include "CrossEntropy.h"
#include "../loss/LHeader.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -84,6 +84,21 @@ XTensor HardTanH(const XTensor &x)
return y;
}
void HardTanH(const XTensor &x, XTensor &y, bool requireLink)
{
if (!y.isInit || !XTensor::IsSameShaped(&y, &x)) {
InitTensor(&y, &x);
}
/* call _HardTanH function */
_HardTanH(&x, &y);
if (y.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH);
}
}
/*
backward computation
......
......@@ -22,7 +22,7 @@
#include "HardTanH.h"
#include "HardTanH.cuh"
#include "Loss.cuh"
#include "CrossEntropy.cuh"
#include "../loss/CrossEntropy.cuh"
#include "../XDevice.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......
......@@ -40,6 +40,8 @@ void _HardTanH(const XTensor * x, XTensor * y);
/* hard tanh function (return an XTensor structure) */
XTensor HardTanH(const XTensor &x);
void HardTanH(const XTensor &x, XTensor &y, bool requireLink = false);
/* de/dx */
void _HardTanHBackward(XTensor * gold, XTensor * y, XTensor * x,
XTensor * dedy, XTensor * dedx,
......
......@@ -21,7 +21,7 @@
#include "../XName.h"
#include "Identity.h"
#include "CrossEntropy.h"
#include "../loss/LHeader.h"
#include "../XUtility.h"
#include "../core/movement/CopyValues.h"
......@@ -57,6 +57,22 @@ XTensor Identity(const XTensor &x)
return y;
}
void Identity(const XTensor &x, XTensor &y, bool requireLink)
{
if (!y.isInit || !y.IsSameShaped(&y, &x)) {
InitTensor(&y, &x);
}
/* call _Identity function */
_Identity(&x, &y);
if (y.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
}
}
/*
backward computation for identity function y = x
......
......@@ -33,6 +33,8 @@ void _Identity(const XTensor * x, XTensor * y);
/* identity function y = x (return an XTensor structure) */
XTensor Identity(const XTensor &x);
void Identity(const XTensor &x, XTensor &y, bool requireLink = false);
/* de/dx */
void _IdentityBackward(XTensor * gold, XTensor * y, XTensor * x,
XTensor * dedy, XTensor * dedx,
......
......@@ -194,6 +194,25 @@ XTensor LogSoftmax(const XTensor &x, int leadDim)
return y;
}
void LogSoftmax(const XTensor &x, XTensor &y, int leadDim, bool requireLink)
{
int ld = leadDim;
if (ld < 0)
ld = x.order - 1;
if (!y.isInit || !XTensor::IsSameShaped(&y, &x)) {
InitTensor(&y, &x);
}
/* call _LogSoftmax function */
_LogSoftmax(&x, &y, ld);
if (y.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
}
}
/*
log scale softmax y = log(e^x / \sum_{i} e^{x_i})
make a new tensor to keep the result and return it
......
......@@ -33,6 +33,8 @@ void _LogSoftmax(const XTensor * x, XTensor * y, int leadDim);
/* log scale softmax y = log(e^x / \sum_{i} e^{x_i}) (return an XTensor structure) */
XTensor LogSoftmax(const XTensor &x, int leadDim);
void LogSoftmax(const XTensor &x, XTensor &y, int leadDim, bool requireLink = false);
/* log scale softmax y = log(e^x / \sum_{i} e^{x_i}) (with both argument of x and y) */
void LogSoftmax(const XTensor &x, XTensor &y, int leadDim);
......
......@@ -22,7 +22,7 @@
#include "../XName.h"
#include "Rectify.h"
#include "Rectify.cuh"
#include "CrossEntropy.h"
#include "../loss/LHeader.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -77,6 +77,20 @@ XTensor Rectify(const XTensor &x)
return y;
}
void Rectify(const XTensor &x, XTensor &y, bool requireLink)
{
if (!y.isInit || !XTensor::IsSameShaped(&y, &x)) {
InitTensor(&y, &x);
}
/* call _Rectify function */
_Rectify(&x, &y);
if (y.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY);
}
}
/*
backward computation
......
......@@ -22,7 +22,7 @@
#include "Rectify.h"
#include "Rectify.cuh"
#include "Loss.cuh"
#include "CrossEntropy.cuh"
#include "../loss/CrossEntropy.cuh"
#include "../XDevice.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......
......@@ -33,6 +33,8 @@ void _Rectify(const XTensor * x, XTensor * y);
/* rectify function y = max(0, x) (return an XTensor structure) */
XTensor Rectify(const XTensor &x);
void Rectify(const XTensor &x, XTensor &y, bool requireLink = false);
/* de/dx */
void _RectifyBackward(XTensor * gold, XTensor * y, XTensor * x,
XTensor * dedy, XTensor * dedx,
......
......@@ -23,7 +23,7 @@
#include <math.h>
#include "Sigmoid.h"
#include "Sigmoid.cuh"
#include "CrossEntropy.h"
#include "../loss/LHeader.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -75,6 +75,21 @@ XTensor Sigmoid(const XTensor &x)
return y;
}
void Sigmoid(const XTensor &x, XTensor &y, bool requireLink)
{
if (!y.isInit || !XTensor::IsSameShaped(&y, &x)) {
InitTensor(&y, &x);
}
/* call _Sigmoid function */
_Sigmoid(&x, &y);
if (y.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID);
}
}
/*
backward computation
......
......@@ -22,7 +22,7 @@
#include "Sigmoid.h"
#include "Sigmoid.cuh"
#include "Loss.cuh"
#include "CrossEntropy.cuh"
#include "../loss/CrossEntropy.cuh"
#include "../XDevice.h"
#ifdef USE_CUDA
......
......@@ -33,6 +33,8 @@ void _Sigmoid(const XTensor * x, XTensor * y);
/* sigmoid function y = 1/(1+exp(-x)) (return an XTensor structure) */
XTensor Sigmoid(const XTensor &x);
void Sigmoid(const XTensor &x, XTensor &y, bool requireLink = false);
/* de/dx */
void _SigmoidBackward(XTensor * gold, XTensor * y, XTensor * x,
XTensor * dedy, XTensor * dedx,
......
......@@ -148,6 +148,26 @@ XTensor Softmax(const XTensor &x, int leadDim)
return y;
}
void Softmax(const XTensor &x, XTensor &y, int leadDim, bool requireLink)
{
int ld = leadDim;
if (ld < 0)
ld = x.order - 1;
if (!y.isInit || !XTensor::IsSameShaped(&y, &x)) {
InitTensor(&y, &x);
}
/* call _Softmax function */
_Softmax(&x, &y, ld);
if (y.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
}
}
/*
backward computation for dense tensors
......
......@@ -372,27 +372,16 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
int * dimSize = new int[y->order];
for(int i = 0; i < y->order; i++){
if(i < leadDim)
dimSize[i] = -y->dimSize[i];
dimSize[i] = y->dimSize[i];
else if(i > leadDim)
dimSize[i - 1] = -y->dimSize[i];
dimSize[i - 1] = y->dimSize[i];
}
XMem * mem = y->mem;
/* make a matrix of the same size as the y (i.e., y) */
XTensor * ytmp = NewTensor(y, false);
XTensor * ytmp = NewTensor(y);
/* make a matrix to keep \beta */
XTensor * beta = new XTensor(y->order - 1, dimSize, y->dataType, y->denseRatio, y->devID, mem);
if(mem != NULL){
ytmp->data = mem->AllocBuf(mem->devID, y->unitNum * y->unitSize);
beta->data = mem->AllocBuf(mem->devID, beta->unitNum * beta->unitSize);
}
else{
ytmp->data = XMemAlloc(y->devID, y->unitNum * y->unitSize);
beta->data = XMemAlloc(y->devID, beta->unitNum * beta->unitSize);
}
XTensor * beta = NewTensor(y->order - 1, dimSize, y->dataType, y->denseRatio, y->devID, y->mem);
/* \beta = \sum_i (dE/dy_i * y_i) */
_Multiply(dedy, y, ytmp, 0, 0);
......@@ -405,19 +394,6 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
/* dE/ds_j = y_j * ytmp = y_j * (dE/dy_j - \beta) */
_Multiply(y, ytmp, dedx, 0, 0);
if(mem != NULL){
mem->ReleaseBuf(mem->devID, y->unitNum * y->unitSize);
mem->ReleaseBuf(mem->devID, beta->unitNum * beta->unitSize);
}
else{
XMemFree(y->devID, ytmp->data);
XMemFree(y->devID, beta->data);
}
ytmp->data = NULL;
beta->data = NULL;
delete[] dimSize;
delete ytmp;
delete beta;
......
......@@ -33,6 +33,8 @@ void _Softmax(const XTensor * x, XTensor * y, int leadDim);
/* softmax y = e^x / \sum_{i} e^{x_i} (return an XTensor structure) */
XTensor Softmax(const XTensor &x, int leadDim);
void Softmax(const XTensor &x, XTensor &y, int leadDim, bool requireLink = false);
/* de/dx */
void _SoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
XTensor * dedy, XTensor * dedx,
......
......@@ -22,6 +22,8 @@
#include <math.h>
#include "CrossEntropy.h"
#include "CrossEntropy.cuh"
#include "../XTensor.h"
#include "../XName.h"
#include "../core/arithmetic/MultiplyDim.h"
#include "../core/arithmetic/Multiply.h"
#include "../core/math/Unary.h"
......@@ -61,7 +63,7 @@ void _CrossEntropy(const XTensor * output, const XTensor * gold,
CheckNTErrors(loss->order == output->order - 1, "Wrong loss dimension!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, "TODO!");
XTensor * interBuf1 = NewTensorBuf(output, output->devID, output->mem);
/*XTensor * interBuf1 = NewTensorBuf(output, output->devID, output->mem);
XTensor * interBuf2 = NewTensorBuf(output, output->devID, output->mem);
_Log(output, interBuf1);
......@@ -76,7 +78,23 @@ void _CrossEntropy(const XTensor * output, const XTensor * gold,
_MultiplyMe(loss, padding);
DelTensorBuf(interBuf2);
DelTensorBuf(interBuf1);
DelTensorBuf(interBuf1);*/
XTensor * inter = NewTensor(output);
_Log(output, inter);
_MultiplyMe(inter, gold);
if(weight != NULL)
_MultiplyDimMe(inter, weight, n);
_NegateMe(inter);
_ReduceSum(inter, loss, n);
if(padding != NULL)
_MultiplyMe(loss, padding);
DelTensor(inter);
}
/*
......@@ -223,6 +241,93 @@ void _CrossEntropyFast(const XTensor * output, const XTensor * gold,
}
/*
*/
XTensor GetReduceTensor(const XTensor & input, int dim)
{
CheckNTErrors(dim >= 0 && dim < input.order, "Illegal dimension to reduce!");
int order = input.order - 1;
int * dimSize = new int[order];
for(int i = 0; i < order; i++){
if(i < dim)
dimSize[i] = input.dimSize[i];
else if(i >= dim)
dimSize[i] = input.dimSize[i + 1];
}
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, dimSize, input.dataType, dr, input.devID, input.mem);
output.SetTMPFlag();
return output;
}
/*
compute the cross entropy loss (return an XTensor structure)
make a new tensor to keep the result and return it
loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions
>> output - model prediction
>> gold - gold standard
>> loss - compute loss
>> weight - a rescaling weight given to each class
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> leadingDim - the leading dimension for the output
*/
XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
int leadingDim)
{
int dim = leadingDim < 0 ? output.order - 1 : leadingDim;
XTensor loss;
loss = GetReduceTensor(output, dim);
XTensor * weight = NULL;
XTensor * padding = NULL;
/* call _CrossEntropy function */
_CrossEntropy(&output, &gold, &loss, weight, padding, dim);
/* tensor connection */
TensorList tails(4);
tails.Add((XTensor*)&output);
tails.Add((XTensor*)&gold);
tails.Add(weight);
tails.Add(padding);
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim);
return loss;
}
XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
const XTensor & padding,
int leadingDim)
{
int dim = leadingDim < 0 ? output.order - 1 : leadingDim;
XTensor loss;
loss = GetReduceTensor(output, dim);
XTensor * weight = NULL;
/* call _CrossEntropy function */
_CrossEntropy(&output, &gold, &loss, weight, &padding, dim);
/* tensor connection */
TensorList tails(4);
tails.Add((XTensor*)&output);
tails.Add((XTensor*)&gold);
tails.Add(weight);
tails.Add((XTensor*)&padding);
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim);
return loss;
}
/*
compute the cross entropy loss
loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions
......@@ -579,16 +684,16 @@ void _CrossEntropyBackward(XTensor * dedy, const XTensor * output,
}
}
//if(padding != NULL) {
// XTensor * tmp = NewTensor(padding);
// _IsNonZero(padding, tmp);
// int nonZeroNum = (int)_ReduceSumAll(tmp);
// _ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)nonZeroNum);
// delete tmp;
//}
//else {
// _ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)blockNum);
//}
if(padding != NULL) {
XTensor * tmp = NewTensor(padding);
_IsNonZero(padding, tmp);
int nonZeroNum = (int)_ReduceSumAll(tmp);
_ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)nonZeroNum);
delete tmp;
}
else {
_ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)blockNum);
}
}
} // namespace nts(NiuTrans.Tensor)
......@@ -196,16 +196,17 @@ void _CudaCrossEntropyBackward(XTensor * dedy, const XTensor * output,
delete[] dims;
}
//if(padding != NULL) {
// XTensor * tmp = NewTensor(padding);
// _IsNonZero(padding, tmp);
// int nonZeroNum = (int)_ReduceSumAll(tmp);
// _ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)nonZeroNum);
// delete tmp;
//}
//else {
// _ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)blockNum);
//}
if(padding != NULL) {
XTensor * tmp = NewTensor(padding);
_IsNonZero(padding, tmp);
int nonZeroNum = (int)_ReduceSumAll(tmp);
_ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)nonZeroNum);
delete tmp;
}
else {
int num = dedy->unitNum / dedy->GetDim(n);
_ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)num);
}
}
......
......@@ -41,6 +41,25 @@ void _CrossEntropyFast(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* compute the cross entropy loss */
XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
int leadingDim = -1);
/* compute the cross entropy loss with padding */
XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
const XTensor & padding,
int leadingDim = -1);
/* compute the cross entropy loss with weight */
XTensor CrossEntropyWeight(const XTensor & output, const XTensor & gold,
const XTensor & weight,
int leadingDim = -1);
/* compute the cross entropy loss with weight and padding */
XTensor CrossEntropyWeight(const XTensor & output, const XTensor & gold,
const XTensor & padding, const XTensor & weight,
int leadingDim = -1);
/* compute the cross entropy loss (return the loss) */
DTYPE _CrossEntropy(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay, const XTensor * weight = NULL,
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2019-4-22
*/
/* this is a header to include all loss computations in the "loss" workspace */
#ifndef __LHEADER_H__
#define __LHEADER_H__
#include "CrossEntropy.h"
#endif // __LHEADER_H__
\ No newline at end of file
......@@ -22,7 +22,7 @@
#ifndef __TEST_CROSSENTROPY_H__
#define __TEST_CROSSENTROPY_H__
#include "../function/CrossEntropy.h"
#include "../loss/CrossEntropy.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......
......@@ -125,10 +125,10 @@ bool TestRectify2()
{1.0F, 1.0F, 1.0F} };
DTYPE yAnswer[2][3] = { {1.0F, 1.0F, 2.0F},
{2.0F, 4.0F, 5.0F} };
DTYPE dedyAnswer[2][3] = { {-1.0F, -1.0F, -0.5F},
{-0.5F, -0.25F, -0.2F} };
DTYPE dedxAnswer[2][3] = { {-1.0F, -1.0F, -0.5F},
{-0.5F, -0.25F, -0.2F} };
DTYPE dedyAnswer[2][3] = { {-0.5F, -0.5F, -0.25F},
{-0.25F, -0.125F, -0.1F} };
DTYPE dedxAnswer[2][3] = { {-0.5F, -0.5F, -0.25F},
{-0.25F, -0.125F, -0.1F} };
/* CPU test */
bool cpuTest = true;
......
......@@ -35,7 +35,7 @@ bool Test()
wrong = !TestConcatenate() || wrong;
wrong = !TestConcatenateSolely() || wrong;
wrong = !TestCos() || wrong;
wrong = !TestConvertDataType() || wrong;
//wrong = !TestConvertDataType() || wrong;
wrong = !TestCopyIndexed() || wrong;
wrong = !TestCopyValues() || wrong;
wrong = !TestDiv() || wrong;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论