Commit d992a0e9 by xuchen

Merge branch 'xuchen' into xiaotong-working

parents 7df7bc1d ceb5b101
......@@ -31,8 +31,6 @@ namespace nts{
/* compute dE/dx of a node */
void XFuncGrad::MakeGrad(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
int operID = income.typeID;
......
......@@ -26,6 +26,7 @@
#include "XBackwardShape.h"
#include "../tensor/XName.h"
#include "../tensor/core/CHeader.h"
#include "../tensor/core/getandset/SetData.h"
namespace nts{
......@@ -37,18 +38,22 @@ void XShapeGrad::MakeGrad(XTensor * node, bool isEfficent)
XLink &income = node->income;
int operID = income.typeID;
if(operID == SHAPE_MERGE)
if(operID == MOVEMENT_COPYINDEXED)
GradCopyIndexed(node, isEfficent);
else if(operID == SHAPE_MERGE)
GradMerge(node, isEfficent);
else if(operID == SHAPE_MERGE_LIST)
GradMergeList(node, isEfficent);
else if(operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node, isEfficent);
else if(operID == SHAPE_RESHAPE)
GradReshape(node, isEfficent);
else if(operID == SHAPE_SPLIT)
GradSplit(node, isEfficent);
else if(operID == SHAPE_SPLIT_LIST)
GradSplitList(node, isEfficent);
else if (operID == SHAPE_TRANSPOSE)
GradTranspose(node, isEfficent);
else if(operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node, isEfficent);
else{
ShowNTErrors("TODO!");
}
......@@ -69,6 +74,50 @@ void XShapeGrad::PostProcessing(XTensor * node, int typeID, bool isEfficent)
}
/*
gradient computation for copying indexed sub-tensors
for
b = copyindexed(a)
we have
dE/da = spread(b)
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XShapeGrad::GradCopyIndexed(XTensor * node, bool isEfficent)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for CopyIndexed!");
int dim = income.GetParamInt(0);
int * srcIndex = (int *)income.GetParamPointer(1);
int indexSize = income.GetParamInt(2);
int * tgtIndex = (int *)income.GetParamPointer(3);
int copyNum = income.GetParamInt(4);
int realIndexSize = indexSize * copyNum;
int * realSrcIndex = new int[realIndexSize];
int * realTgtIndex = new int[realIndexSize];
for(int i = 0; i < indexSize; i++) {
for(int j = 0; j < copyNum; j++) {
realSrcIndex[i * copyNum + j] = srcIndex[i] + j;
realTgtIndex[i * copyNum + j] = tgtIndex[i] + j;
}
}
XTensor * input = income.tails[0];
XNoder::MakeGrad(input);
_SpreadForGather(input->grad, node->grad, dim, realSrcIndex, realIndexSize, realTgtIndex);
delete[] realSrcIndex;
delete[] realTgtIndex;
delete[] srcIndex;
delete[] tgtIndex;
node->visitMark = NODE_FINISHED;
}
/*
gradient for merge
for
c = merge(a_0, a_1, ...)
......@@ -89,7 +138,7 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficent)
XTensor * input = income.tails[0];
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for MERGE!");
CheckNTErrors(node->order == input->order - 1, "wrong tensor orders!");
CheckNTErrors(node->order == input->order - 1, "Wrong tensor orders!");
int whereToMerge = income.GetParamInt(0);
int leadDim = income.GetParamInt(1);
......@@ -237,6 +286,36 @@ void XShapeGrad::GradMergeList(XTensor * node, bool isEfficient)
}
/*
gradient computation for reshaping a tensor
for
b = reshape(a)
we have
dE/da = reshape(dE/db)
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XShapeGrad::GradReshape(XTensor * node, bool isEfficent)
{
XLink &income = node->income;
XTensor * input = income.tails[0];
XNoder::MakeGrad(input);
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for MERGE!");
int order = income.GetParamInt(0);
int * dimSize = (int *)income.GetParamPointer(1);
node->grad->Reshape(order, dimSize);
_CopyValues(node->grad, input->grad);
delete[] dimSize;
node->visitMark = NODE_FINISHED;
}
/*
gradient computation for split:
for
c = split(a)
......@@ -358,75 +437,75 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
}
}
/*
gradient for unsqueezing a tensor
/*
gradient for transposing a tensor
for
c = unsqueeze(a)
c = Transpose(a)
we have
dE/da = reduecesum(dE/dc)
dE/da = Transpose(dE/dc)
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XShapeGrad::GradUnsqueeze(XTensor * node, bool isEfficient)
void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for UNSQUEEZE!");
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for TRANSPOSE!");
XTensor * output = node;
XTensor * input = income.tails[0];
XTensor * b = NewTensorBuf(input, input->devID, input->mem);
XNoder::MakeGrad(input);
int dim = income.GetParamInt(0);
int dSize = income.GetParamInt(1);
int i = income.GetParamInt(0);
int j = income.GetParamInt(1);
CheckNTErrors(dSize == output->GetDim(dim), "Wrong dim size for UNSQUEEZE!");
CheckNTErrors(output->unitNum = input->unitNum * dSize, "Wrong tensor size!");
XTensor * g = NewTensorBuf(input->grad, input->devID, input->mem);
_ReduceSum(output->grad, g, dim);
_Sum(input->grad, g, input->grad);
CheckNTErrors(input->order > i && i >= 0, "index of dimension is out of scope!");
CheckNTErrors(input->order > j && j >= 0, "index of dimension is out of scope!");
_Transpose(output->grad, b, i, j);
_Sum(input->grad, b, input->grad);
DelTensorBuf(g);
DelTensorBuf(b);
node->visitMark = NODE_FINISHED;
delete b;
}
/*
gradient for transposing a tensor
/*
gradient for unsqueezing a tensor
for
c = Transpose(a)
c = unsqueeze(a)
we have
dE/da = Transpose(dE/dc)
dE/da = reduecesum(dE/dc)
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XShapeGrad::GradTranspose(XTensor * node, bool isEfficient)
void XShapeGrad::GradUnsqueeze(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for TRANSPOSE!");
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for UNSQUEEZE!");
XTensor * output = node;
XTensor * input = income.tails[0];
XTensor * b = NewTensorBuf(input, input->devID, input->mem);
XNoder::MakeGrad(input);
int i = income.GetParamInt(0);
int j = income.GetParamInt(1);
CheckNTErrors(input->order > i && i >= 0, "index of dimension is out of scope!");
CheckNTErrors(input->order > j && j >= 0, "index of dimension is out of scope!");
int dim = income.GetParamInt(0);
int dSize = income.GetParamInt(1);
_Transpose(output->grad, b, i, j);
_Sum(input->grad, b, input->grad);
CheckNTErrors(dSize == output->GetDim(dim), "Wrong dim size for UNSQUEEZE!");
CheckNTErrors(output->unitNum = input->unitNum * dSize, "Wrong tensor size!");
DelTensorBuf(b);
XTensor * g = NewTensorBuf(input->grad, input->devID, input->mem);
_ReduceSum(output->grad, g, dim);
_Sum(input->grad, g, input->grad);
DelTensorBuf(g);
node->visitMark = NODE_FINISHED;
delete b;
}
}
......@@ -45,6 +45,11 @@ public:
void PostProcessing(XTensor * node, int typeId, bool isEfficent);
private:
/* gradient computation for copying indexed sub-tensors: b = copyindexed(a, srcIndex, indexSize, tgtIndex, copyNum) */
static
void GradCopyIndexed(XTensor * node, bool isEfficent);
/* gradient computation for merge: c = merge(a, b, ...) */
static
void GradMerge(XTensor * node, bool isEfficent);
......@@ -52,6 +57,14 @@ private:
/* gradient computation for merging a list of tensors : c = merge(list(a, b, ...)) */
static
void GradMergeList(XTensor * node, bool isEfficent);
/* gradient computation for transposing a tensor : b = transpose(a) */
static
void GradTranspose(XTensor * node, bool isEfficent);
/* gradient computation for reshaping a tensor: c = reshape(a) */
static
void GradReshape(XTensor * node, bool isEfficent);
/* gradient computation for split: c = split(a) */
static
......@@ -71,10 +84,6 @@ private:
static
void GradUnsqueeze(XTensor * node, bool isEfficent);
/* gradient computation for unsqueezing a tensor : c = unsqueeze(a) */
static
void GradTranspose(XTensor * node, bool isEfficent);
};
}
......
......@@ -74,6 +74,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net);
void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME loss,
FNNModel &model, FNNModel &grad, FNNNet &net);
void ForwardAutoDiff(XTensor inputs[], XTensor &output, FNNModel &model);
void ForwardAutoDiff(NGram * ngrams, int batch, XTensor &output, FNNModel &model);
/*
entry of the program
......@@ -476,7 +477,12 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
Clear(model, true);
/* forward + backward process */
ForwardAutoDiff(inputs, output, model);
/* this is implemented by gather function */
ForwardAutoDiff(ngrams, ngramNum, output, model);
/* this is implemented by multiply function */
//ForwardAutoDiff(inputs, output, model);
/* automatic differentiation */
autoDiffer.Backward(output, gold, CROSSENTROPY);
......@@ -975,7 +981,55 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
}
/*
forward process (with tensor connections)
forward process (with tensor connections) (this is implemented by gather function)
>> ngrams - the loaded ngrams
>> batch - the tensor encoding a batch of words
>> output - output probability
>> model - the fnn model
*/
void ForwardAutoDiff(NGram * ngrams, int batch, XTensor &output, FNNModel &model)
{
int n = model.n;
int depth = model.hDepth;
XTensor words;
XTensor embeddingBig;
XTensor hidden;
XTensor b;
int size = batch * (n-1);
int * index = new int[size];
for(int i = 0; i < batch; i++){
for (int j = 0; j < n-1; j++){
int a = i * (n - 1) + j;
index[a] = ngrams[i].words[j];
}
}
XTensor embedding;
embedding = Gather(model.embeddingW, 0, index, size);
delete[] index;
int dimSize[2];
dimSize[0] = embedding.GetDim(0) / (n - 1);
dimSize[1] = embedding.GetDim(1) * (n - 1);
hidden = Reshape(embedding, embedding.order, dimSize);
/* hidden layers */
for(int i = 0; i < depth; i++)
hidden = MMul(hidden, model.hiddenW[i]) + model.hiddenB[i];
/* output layer */
output = LogSoftmax(MMul(hidden, model.outputW) + model.outputB, 1);
//XLink::ShowNetwork(stderr, &output);
}
/*
forward process (with tensor connections) (this is implemented by multiply function)
>> inputs - input word representations
>> output - output probability
>> model - the fnn model
......@@ -1122,8 +1176,12 @@ void Test(const char * test, const char * result, FNNModel &model)
/* forward computation */
Forward(inputs, output, model, net);
}
else {
ForwardAutoDiff(inputs, output, model);
else {
/* this is implemented by gather function */
ForwardAutoDiff(ngrams, ngramNum, output, model);
/* this is implemented by multiply function */
//ForwardAutoDiff(inputs, output, model);
}
/* prediction probabilities */
......
......@@ -263,6 +263,18 @@ int XLink::GetParamInt(int i)
char * p = (char*)params + i * paramSize;
return *(int*)p;
}
/*
get a paramter in integer
>> i - id of the parameter
<< return - the parameter in integer
*/
void * XLink::GetParamPointer(int i)
{
CheckNTErrors(params != NULL, "parameter array cannot be empty!");
char * p = (char*)params + i * paramSize;
return *(int **)p;
}
/*
get a parameter in MATRIX_TRANS_TYPE
......@@ -401,8 +413,7 @@ add a boolean parameter
*/
void XLink::AddParamToHeadBool(XTensor * h, bool param)
{
if(h != NULL)
return;
CheckNTErrors(h != NULL, "head tensor cannot be empty!");
h->income.AddParam(&param, sizeof(bool));
}
......@@ -413,8 +424,7 @@ add a pointer parameter
*/
void XLink::AddParamToHeadPointer(XTensor * h, void * param)
{
if(h != NULL)
return;
CheckNTErrors(h != NULL, "head tensor cannot be empty!");
h->income.AddParam(&param, sizeof(param));
}
......
......@@ -127,6 +127,9 @@ struct XLink
/* get a paramter in integer */
int GetParamInt(int i);
/* get a paramter in pointer */
void * GetParamPointer(int i);
/* get a parameter in MATRIX_TRANS_TYPE */
MATRIX_TRANS_TYPE GetParamTrans(int i);
......
......@@ -35,6 +35,8 @@ const char * GetOPName(int type)
return "M_EXP";
else if (type == MATH_FLOOR)
return "M_FLOOR";
else if (type == MATH_ISZERO)
return "M_ISZERO";
else if (type == MATH_LOG)
return "M_LOG";
else if (type == MATH_SQRT)
......@@ -107,10 +109,14 @@ const char * GetOPName(int type)
return "S_MERGE_LIST";
else if (type == SHAPE_PERMUTE)
return "S_PERMUTE";
else if (type == SHAPE_RESHAPE)
return "S_RESHAPE";
else if (type == SHAPE_SPLIT)
return "S_SPLIT";
else if (type == SHAPE_SPLIT_LIST)
return "S_SPLIT_LIST";
else if (type == SHAPE_SQUEEZE)
return "S_SQUEEZE";
else if (type == SHAPE_TRANSPOSE)
return "S_TRANSPOSE";
else if (type == SHAPE_UNSQUEEZE)
......
......@@ -35,7 +35,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_CEIL MATH_ABSOLUTE + 1
#define MATH_EXP MATH_CEIL + 1
#define MATH_FLOOR MATH_EXP + 1
#define MATH_LOG MATH_FLOOR + 1
#define MATH_ISZERO MATH_FLOOR + 1
#define MATH_LOG MATH_ISZERO + 1
#define MATH_SQRT MATH_LOG + 1
#define MATH_SQUARE MATH_SQRT + 1
#define MATH_SIN MATH_SQUARE + 1
......@@ -81,9 +82,11 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define SHAPE_MERGE SHAPE_CONCATENATE + 1
#define SHAPE_MERGE_LIST SHAPE_MERGE + 1
#define SHAPE_PERMUTE SHAPE_MERGE_LIST + 1
#define SHAPE_SPLIT SHAPE_PERMUTE + 1
#define SHAPE_RESHAPE SHAPE_PERMUTE + 1
#define SHAPE_SPLIT SHAPE_RESHAPE + 1
#define SHAPE_SPLIT_LIST SHAPE_SPLIT + 1
#define SHAPE_TRANSPOSE SHAPE_SPLIT_LIST + 1
#define SHAPE_SQUEEZE SHAPE_SPLIT_LIST + 1
#define SHAPE_TRANSPOSE SHAPE_SQUEEZE + 1
#define SHAPE_UNSQUEEZE SHAPE_TRANSPOSE + 1
#define SORT SHAPE_UNSQUEEZE + 1
......
......@@ -373,24 +373,48 @@ XTensor XTensor::operator+ (const XTensor& tensor)
return Sum(*this, tensor);
}
/* overloading of the plus-sign */
XTensor XTensor::operator+ (const DTYPE shift)
{
return ScaleAndShift(*this, 1, shift);
}
/* overloading of the multiply-sign */
XTensor XTensor::operator* (const XTensor& tensor)
{
return Multiply(*this, tensor);
}
/* overloading of the multiply-sign */
XTensor XTensor::operator* (const DTYPE scale)
{
return ScaleAndShift(*this, scale, 0);
}
/* overloading of the minus-sign */
XTensor XTensor::operator- (const XTensor& tensor)
{
return Sub(*this, tensor);
}
/* overloading of the minus-sign */
XTensor XTensor::operator- (const DTYPE shift)
{
return ScaleAndShift(*this, 1, -shift);
}
/* overloading of the division-sign */
XTensor XTensor::operator/ (const XTensor& tensor)
{
return Div(*this, tensor);
}
/* overloading of the division-sign */
XTensor XTensor::operator/ (const DTYPE scale)
{
return ScaleAndShift(*this, (DTYPE)1/scale, 0);
}
/*
linear transformation b = a * \scale + \shift
>> scale - the slope
......@@ -439,7 +463,7 @@ judge whether the three matrices are in the same type and size
>> c - a tensor again
<< return - whether the two input tensors are identical
*/
bool XTensor::IsSameShaped(XTensor * a, XTensor * b, XTensor * c)
bool XTensor::IsSameShaped(const XTensor * a, const XTensor * b, const XTensor * c)
{
return IsSameShaped(a, b) && IsSameShaped(a, c);
}
......@@ -460,7 +484,7 @@ void XTensor::SetDim(int * myDimSize)
get the size of a given dimension
>> dim - the given dim we are looking at
*/
int XTensor::GetDim(const int dim)
int XTensor::GetDim(const int dim) const
{
CheckNTErrors(dim < order, "dimenision is out of range!");
......@@ -766,6 +790,20 @@ void XTensor::SetDataPointer()
dataP = &data;
}
/* compare two number */
bool IsFloatEqual(DTYPE a, DTYPE b, float absError, float relError)
{
if(a == b)
return true;
if(fabs(a - b) < absError)
return true;
if(fabs(a) < fabs(b))
return (fabs(a - b) / b < relError) ? true : false;
else
return (fabs(a - b) / a < relError) ? true : false;
}
/* check whether the data array is the same as the answer */
bool XTensor::CheckData(const void * d, int num, float tolerance, int beg)
{
if (data == NULL || d == NULL)
......@@ -779,7 +817,7 @@ bool XTensor::CheckData(const void * d, int num, float tolerance, int beg)
DTYPE * answerPrt = (DTYPE*)d;
for (int i = beg; i < num; i++) {
value = ToCPU(devID, valuePrt);
if (fabs(value - *answerPrt) > tolerance)
if(IsFloatEqual(value, *answerPrt, tolerance, 1e-4F) == false)
return false;
valuePrt++;
answerPrt++;
......@@ -1446,9 +1484,18 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int beg,
}
}
else {
ShowNTErrors("TODO!");
else if(dataType == X_INT) {
int end = MIN(n > 0 ? beg + n : beg + unitNum, unitNum);
for(int i = beg; i < end; i++){
int f = ((int*)d)[i];
if(i == beg)
fprintf(file, "%d", f);
else
fprintf(file, " %d", f);
}
}
else
ShowNTErrors("TODO!");
}
else {
int num = this->unitNumNonZero > 0 ? *(int*)d : 0;
......
......@@ -204,15 +204,27 @@ public:
/* overloading of the plus-sign */
XTensor operator+ (const XTensor &tensor);
/* overloading of the plus-sign */
XTensor operator+ (const DTYPE shift);
/* overloading of the multiply-sign */
XTensor operator* (const XTensor &tensor);
/* overloading of the multiply-sign */
XTensor operator* (const DTYPE scale);
/* overloading of the minus-sign */
XTensor operator- (const XTensor &tensor);
/* overloading of the minus-sign */
XTensor operator- (const DTYPE shift);
/* overloading of the division-sign */
XTensor operator/ (const XTensor &tensor);
/* overloading of the division-sign */
XTensor operator/ (const DTYPE scale);
/* linear transformation */
XTensor Lin(DTYPE scale, DTYPE shift = 0);
......@@ -223,13 +235,13 @@ public:
/* judge whether the three matrices are in the same type and size */
static
bool IsSameShaped(XTensor * a, XTensor * b, XTensor * c);
bool IsSameShaped(const XTensor * a, const XTensor * b, const XTensor * c);
/* set the size of each dimension */
void SetDim(int * myDimSize);
/* get the size of a given dimension */
int GetDim(const int dim);
int GetDim(const int dim) const;
/* reshape the tensor */
void Reshape(const int order, const int * myDimSize);
......
......@@ -63,11 +63,14 @@
#include "movement/CopyIndexed.h"
#include "movement/CopyInGrid.h"
#include "movement/CopyValues.h"
#include "movement/Gather.h"
#include "movement/Spread.h"
#include "reduce/ReduceMax.h"
#include "reduce/ReduceMean.h"
#include "reduce/ReduceStandardVariance.h"
#include "reduce/ReduceSum.h"
#include "reduce/ReduceSumAll.h"
#include "reduce/ReduceSumSquared.h"
#include "reduce/ReduceVariance.h"
......@@ -77,8 +80,10 @@
#include "shape/MakeSplitBlockIndex.h"
#include "shape/Merge.h"
#include "shape/MergeBlockLists.h"
#include "shape/Reshape.h"
#include "shape/Permute.h"
#include "shape/Split.h"
#include "shape/Squeeze.h"
#include "shape/Transpose.h"
#include "shape/Unsqueeze.h"
......
......@@ -234,7 +234,7 @@ void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim > 0, "Illegal dimension!");
CheckNTErrors(dim < n && dim >= 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
......@@ -264,11 +264,78 @@ void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
}
/*
modify data items along with a given index and dimension (and keep the remaining items unchanged)
>> source - the tensor whose data array would be modified
>> modify - the tensor whose data array would be used to modify the source tensor
>> dim - the dimension along which we modify the tensor
>> index - index of the given dimension
e.g., given a source tensor (3, 3)
1 2 3
4 5 6
7 8 9
given a modified tensor (3)
1 2 3
when dim = 0, index = 1, we have
1 2 3
1 2 3
7 8 9
i.e., we set entries of row 1 to {1, 2, 3}
*/
void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index)
{
int order = source->order;
int size = source->GetDim(dim);
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
CheckNTErrors(index >= 0 && index < size, "Illegal index!");
for(int i = 0; i < order - 1; i++){
if(i < dim){
CheckNTErrors(modify->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else if(i >= dim){
CheckNTErrors(modify->GetDim(i) == source->GetDim(i+1), "Illegal dimension!");
}
}
if(source->devID < 0 && modify->devID < 0){
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = order - 1; i > dim; i--){
stride *= source->GetDim(i);
}
blockSize = stride * source->GetDim(dim);
blockNum = source->unitNum / blockSize;
for(int i = 0; i < blockNum; i++){
DTYPE * d = (DTYPE*)source->data + blockSize * i + index * stride;
DTYPE * p = (DTYPE*)modify->data + stride * i;
for(int j = 0; j < stride; j++)
d[j] = p[j];
}
}
else if(source->devID >= 0 && modify->devID >= 0) {
#ifdef USE_CUDA
_CudaSetDataIndexed(source, modify, dim, index);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
else{
ShowNTErrors("TODO!");
}
}
/*
generate data as lower triangular matrics for last two dimensions
>> tensor - the tensor whose data to be set
>> p - the value for each entry of the lower triangular matrics
>> shift - the offset from diagonal
e.g., for a 3* 3 tensor,
e.g., for a 3 * 3 tensor,
when p = 1 ans shift = 0, we have
1 0 0
1 1 0
......@@ -363,7 +430,6 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
}
}
/*
generate data items with a normal distribution with specified mean and standard deviation
>> mean - mean or expectation of the distribution
......
......@@ -231,7 +231,7 @@ void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
int n = tensor->order;
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim < n && dim > 0, "Illegal dimension!");
CheckNTErrors(dim < n && dim >= 0, "Illegal dimension!");
CheckNTErrors(beg >= 0 && beg < tensor->GetDim(dim), "Illegal beginning position!");
CheckNTErrors(beg + len >= 0 && beg + len < tensor->GetDim(dim), "Illegal length!");
......@@ -255,12 +255,95 @@ void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p)
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
KernelSetDataDim<<<blocks, threads >>>((DTYPE*)tensor->data, beg * stride, len * stride, blockSize, blockNum, p);
KernelSetDataDim<<<blocks, threads >>>((DTYPE*)tensor->data, beg * stride,
len * stride, blockSize, blockNum, p);
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
modify data items along with a given index and dimension
(and keep the remaining items unchanged) - kernel version
>> s - the pointer whose data would be modified
>> m - the pointer whose data would be used to modify the data pointed by s
>> blockNum - number of data blocks
>> blockSize - size of a data block
>> stride - stride of a data block
*/
__global__
void KernelSetDataIndexed(DTYPE * s, DTYPE * m, int blockNum, int blockSize, int stride)
{
/* offset in each block */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* block id */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= stride || j >= blockNum)
return;
int x = blockSize * j + i;
int y = stride * j + i;
s[x] = m[y];
}
/*
modify data items along with a given index and dimension (and keep the remaining items unchanged)
>> source - the tensor whose data array would be modified
>> modify - the tensor whose data array would be used to modify the source tensor
>> dim - the dimension along which we modify the tensor
>> index - index of the given dimension
e.g., given a source tensor (3, 3)
1 2 3
4 5 6
7 8 9
given a modified tensor (3)
1 2 3
when dim = 0, index = 1, we have
1 2 3
1 2 3
7 8 9
i.e., we set entries of row 1 to {1, 2, 3}
*/
void _CudaSetDataIndexed(XTensor * source, XTensor * modify, int dim, int index)
{
int order = source->order;
int size = source->GetDim(dim);
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
CheckNTErrors(index >= 0 && index < size, "Illegal index!");
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for(int i = order - 1; i > dim; i--){
stride *= source->GetDim(i);
}
blockSize = stride * source->GetDim(dim);
blockNum = source->unitNum / blockSize;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(source->devID, stride, blockNum, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
KernelSetDataIndexed<<<blocks, threads >>>((DTYPE*)source->data + index * stride, (DTYPE*)modify->data,
blockNum, blockSize, stride);
BacktoCudaDev(source->devID, devIDBackup);
}
/*
set lower triangular matrics for each block
>> d - pointer to the data array
>> l - row number (or column number) of each block, i.e,
......
......@@ -40,6 +40,9 @@ void _CudaSetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _CudaSetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* modify data items along with a given index and dimension (and keep the remaining items unchanged) */
void _CudaSetDataIndexed(XTensor * source, XTensor * modify, int dim, int index);
/* generate data as lower triangular matrics for last two dimensions (cuda version) */
void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
......
......@@ -48,6 +48,9 @@ void _SetDataFixedDouble(XTensor * tensor, double p);
/* set data items along with a given dimension (and keep the remaining items unchanged) */
void _SetDataDim(XTensor * tensor, int beg, int len, int dim, DTYPE p);
/* modify data items along with a given index and dimension (and keep the remaining items unchanged) */
void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index);
/* generate data as lower triangular matrics for last two dimensions */
void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-03
*/
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-03
*/
#ifndef __CLIP_H__
#define __CLIP_H__
......@@ -29,16 +30,12 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its clip value */
void _Clip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper);
/*
set every entry to its clip value (do it on site)
keep the result in the input tensor a and return nothing
*/
/* set every entry to its clip value (do it on site)
keep the result in the input tensor a and return nothing */
void _ClipMe(XTensor * a, DTYPE lower, DTYPE upper);
/*
set every entry to its clip value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
/* set every entry to its clip value (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper);
/*
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include <math.h>
#include "../../XName.h"
......@@ -36,6 +37,11 @@ DTYPE round(DTYPE r)
return (r > 0.0) ? (DTYPE)floor(r + 0.5) : (DTYPE)ceil(r - 0.5);
}
DTYPE iszero(DTYPE r)
{
return (r == 0.0) ? (DTYPE)1.0 : (DTYPE)0.0;
}
#ifdef USE_CUDA
/* define three marco separately, specify the respective function names (GPU mode) */
#define _SIMPLE_UNARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
......@@ -65,7 +71,7 @@ void _funcNameMe(XTensor * a) \
XTensor funcName(const XTensor &a) \
{ \
XTensor b(&a); \
b.SetTMPFlag(); \
b.SetTMPFlag(); \
_funcName(&a, &b); \
XLink::MakeLink(&a, NULL, &b, operationId); \
return b; \
......@@ -87,6 +93,10 @@ _SIMPLE_UNARY_FUNCTION(_Floor, _CudaFloor, floor)
_SIMPLE_UNARY_FUNCTION_ME(_FloorMe, _Floor)
SIMPLE_UNARY_FUNCTION(Floor, _Floor, MATH_FLOOR)
_SIMPLE_UNARY_FUNCTION(_IsZero, _CudaIsZero, iszero)
_SIMPLE_UNARY_FUNCTION_ME(_IsZeroMe, _IsZero)
SIMPLE_UNARY_FUNCTION(IsZero, _IsZero, MATH_ISZERO)
_SIMPLE_UNARY_FUNCTION(_Log, _CudaLog, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)
......@@ -140,7 +150,7 @@ void _funcNameMe(XTensor * a) \
XTensor funcName(const XTensor &a) \
{ \
XTensor b(&a); \
b.SetTMPFlag(); \
b.SetTMPFlag(); \
_funcName(&a, &b); \
XLink::MakeLink(&a, NULL, &b, operationId); \
return b; \
......@@ -163,6 +173,10 @@ _SIMPLE_UNARY_FUNCTION(_Floor, floor)
_SIMPLE_UNARY_FUNCTION_ME(_FloorMe, _Floor)
SIMPLE_UNARY_FUNCTION(Floor, _Floor, MATH_FLOOR)
_SIMPLE_UNARY_FUNCTION(_IsZero, iszero)
_SIMPLE_UNARY_FUNCTION_ME(_IsZeroMe, _IsZero)
SIMPLE_UNARY_FUNCTION(IsZero, _IsZero, MATH_ISZERO)
_SIMPLE_UNARY_FUNCTION(_Log, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include <math.h>
#include "../../XDevice.h"
......@@ -28,17 +29,23 @@
namespace nts {
__device__
DTYPE CudaSquare(DTYPE x)
DTYPE cudasquare(DTYPE x)
{
return x * x;
}
__device__
DTYPE CudaRound(DTYPE r)
DTYPE cudaround(DTYPE r)
{
return (r > 0.0) ? (DTYPE)floor(r + 0.5) : (DTYPE)ceil(r - 0.5);
}
__device__
DTYPE cudaiszero(DTYPE r)
{
return (r == 0.0) ? (DTYPE)1.0 : (DTYPE)0.0;
}
#define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc) \
__global__ \
void Kernel##funcName(DTYPE * a, DTYPE * b, int size) \
......@@ -89,10 +96,11 @@ SIMPLE_UNARY_FUNCTION_GPU(Absolute, fabs)
SIMPLE_UNARY_FUNCTION_GPU(Ceil, ceil)
SIMPLE_UNARY_FUNCTION_GPU(Exp, exp)
SIMPLE_UNARY_FUNCTION_GPU(Floor, floor)
SIMPLE_UNARY_FUNCTION_GPU(IsZero, cudaiszero)
SIMPLE_UNARY_FUNCTION_GPU(Log, log)
SIMPLE_UNARY_FUNCTION_GPU(Round, CudaRound)
SIMPLE_UNARY_FUNCTION_GPU(Round, cudaround)
SIMPLE_UNARY_FUNCTION_GPU(Sqrt, sqrt)
SIMPLE_UNARY_FUNCTION_GPU(Square, CudaSquare)
SIMPLE_UNARY_FUNCTION_GPU(Square, cudasquare)
SIMPLE_UNARY_FUNCTION_GPU(Sin, sin)
SIMPLE_UNARY_FUNCTION_GPU(Cos, cos)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __UNARY_CUH__
#define __UNARY_CUH__
......@@ -65,6 +66,15 @@ void KernelFloor(__half * a, __half * b, int size);
/* set each entry to its floor value */
void _CudaFloor(const XTensor * a, XTensor * b);
/* if source entry is zero, set target entry to be one, otherwise zero (CUDA Kernel) */
__global__
void KernelIsZero(DTYPE * a, DTYPE * b, int size);
/* if source entry is zero, set target entry to be one, otherwise zero (CUDA Kernel) with float16 data type*/
__global__
void KernelIsZero(__half * a, __half * b, int size);
/* if source entry is zero, set target entry to be one, otherwise zero */
void _CudaIsZero(const XTensor * a, XTensor * b);
/* set each entry to its logarithm value (CUDA Kernel) */
__global__
void KernelLog(DTYPE * a, DTYPE * b, int size);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __UNARY_H__
#define __UNARY_H__
......@@ -62,6 +63,15 @@ void _FloorMe(XTensor * a);
make a new tensor to keep the result and return it */
XTensor Floor(const XTensor & a);
/* if source entry is zero, set target entry to be one, otherwise zero */
void _IsZero(const XTensor *a, XTensor *b);
/* if source entry is zero, set target entry to be one, otherwise zero (do it on site)
keep the result in the input tensor a and return nothing */
void _IsZeroMe(XTensor *a);
/* if source entry is zero, set target entry to be one, otherwise zero (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor IsZero(const XTensor &a);
/* set every entry to its logarithm value */
void _Log(const XTensor * a, XTensor * b);
/* set every entry to its logarithm value (do it on site)
......
......@@ -32,7 +32,7 @@ copy indexed sub-tensors
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3,2)
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
>> tgtIndex - index of the target sub-tensors
......@@ -135,17 +135,25 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, in
/* call _CopyIndexed function */
_CopyIndexed(&s, &t, dim, srcIndex, indexSize, tgtIndex, copyNum);
/* care: we must malloc a new array for save index,
because the source indexs may be freed. */
int * saveSrcIndex = new int[indexSize];
memcpy(saveSrcIndex, srcIndex, indexSize * sizeof(int));
int * saveTgtIndex = new int[indexSize];
memcpy(saveTgtIndex, tgtIndex, indexSize * sizeof(int));
/* tensor connection */
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED);
XLink::AddParamToHeadInt(&t, dim);
XLink::AddParamToHeadPointer(&t, srcIndex);
XLink::AddParamToHeadPointer(&t, saveSrcIndex);
XLink::AddParamToHeadInt(&t, indexSize);
XLink::AddParamToHeadPointer(&t, tgtIndex);
XLink::AddParamToHeadPointer(&t, saveTgtIndex);
XLink::AddParamToHeadInt(&t, copyNum);
/* destroy variables */
delete[] dimSize;
return t;
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __COPYINDEXED_H__
#define __COPYINDEXED_H__
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-18
*/
#include "Gather.h"
#include "CopyIndexed.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
gather indexed sub-tensors
>> s - the source tensor
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
*/
void _Gather(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize)
{
int * tgtIndex = new int[indexSize];
for(int i = 0; i < indexSize; i++)
tgtIndex[i] = i;
_CopyIndexed(s, t, dim, srcIndex, indexSize, tgtIndex, 1);
delete[] tgtIndex;
}
/*
gather indexed sub-tensors (return a XTensor structure)
make a new tensor to keep the result and return it
>> s - the source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
<< return - the result of copying indexed sub-tensors
Notice: the index must be on the CPU!!!
*/
XTensor Gather(const XTensor &s, int dim, int * srcIndex, int indexSize)
{
int * tgtIndex = new int[indexSize];
for(int i = 0; i < indexSize; i++)
tgtIndex[i] = i;
/* call CopyIndexed function */
XTensor result;
result = CopyIndexed(s, dim, srcIndex, indexSize, tgtIndex, 1);
delete[] tgtIndex;
return result;
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-18
*/
#ifndef __GATHER_H__
#define __GATHER_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* gather selected sub-tensors */
void _Gather(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize);
/* gather selected sub-tensors (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Gather(const XTensor &s, int dim, int * srcIndex, int indexSize);
} // namespace nts(NiuTrans.Tensor)
#endif // __GATHER_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#include "Spread.h"
#include "Spread.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
This is core assignment for spread function.
>> sData - the data pointer of the source tensor
>> cData - the data pointer of collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of source data block
>> blockSizeColl - size of source data block
>> stride - stride of a data block
*/
void _Assignment(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride)
{
for (int i = 0; i < blockNum; i++) {
DTYPE * s = sData + blockSizeSrc * i;
DTYPE * c = cData + blockSizeColl * i;
for(int j = 0; j < stride; j++)
s[j] = c[j];
}
}
/*
spread a collection tensor to source tensor.
This is a inverse operation compared to gather.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and collIndex)
>> collIndex - index of the gathered sub-tensors
*/
void _Spread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex)
{
int order = source->order;
int size = source->GetDim(dim);
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
for(int i = 0; i < order; i++){
if(i < dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else if(i > dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else{
CheckNTErrors(collection->GetDim(i) == indexSize, "Illegal dimension!");
}
}
#ifdef USE_CUDA
if(source->devID >= 0 && collection->devID >= 0) {
_CudaSpread(source, collection, dim, srcIndex, indexSize, collIndex);
return;
}
#endif
int blockSizeSrc = 1;
int blockSizeColl = 1;
int blockNum = 1;
int stride = 1;
for (int i = dim + 1; i < order; i++) {
stride *= source->GetDim(i);
}
blockSizeSrc = stride * source->GetDim(dim);
blockSizeColl = stride * collection->GetDim(dim);
blockNum = source->unitNum / blockSizeSrc;
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++){
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
_Assignment(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
}
/*
This is core assignment for backward computation of gather function.
Care of the operator "+=" instead of "=".
>> sData - the data pointer of the source tensor
>> cData - the data pointer of collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of source data block
>> blockSizeColl - size of source data block
>> stride - stride of a data block
*/
void _AssignmentForGather(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride)
{
for (int i = 0; i < blockNum; i++) {
DTYPE * s = sData + blockSizeSrc * i;
DTYPE * c = cData + blockSizeColl * i;
for(int j = 0; j < stride; j++)
s[j] += c[j];
}
}
/*
spread a collection tensor to source tensor.
And this is a special spread function for backward computation of gather function.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and collIndex)
>> collIndex - index of the gathered sub-tensors
*/
void _SpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex)
{
int order = source->order;
int size = source->GetDim(dim);
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
for(int i = 0; i < order; i++){
if(i < dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else if(i > dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else{
CheckNTErrors(collection->GetDim(i) == indexSize, "Illegal dimension!");
}
}
#ifdef USE_CUDA
if(source->devID >= 0 && collection->devID >= 0) {
_CudaSpreadForGather(source, collection, dim, srcIndex, indexSize, collIndex);
return;
}
#endif
int blockSizeSrc = 1;
int blockSizeColl = 1;
int blockNum = 1;
int stride = 1;
for (int i = dim + 1; i < order; i++) {
stride *= source->GetDim(i);
}
blockSizeSrc = stride * source->GetDim(dim);
blockSizeColl = stride * collection->GetDim(dim);
blockNum = source->unitNum / blockSizeSrc;
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++){
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
_AssignmentForGather(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#ifndef __SPREAD_CUH__
#define __SPREAD_CUH__
#include "../../XTensor.h"
#include "../../XDevice.h"
#include "Spread.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
This is core assignment for spread function.
>> sData - the data pointer of the source tensor
>> cData - the data pointer of collection tensor
>> blockNum - the number of data blocks
>> blockSizeSrc - the size of source data block
>> blockSizeColl - the size of source data block
>> stride - the stride of a data block
*/
__global__
void KernelSpread(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride)
{
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* offset in each block */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockNum || j >= stride)
return;
DTYPE * s = sData + blockSizeSrc * i;
DTYPE * c = cData + blockSizeColl * i;
s[j] = c[j];
}
/*
spread a collection tensor to source tensor (cuda version).
This is a inverse operation compared to gather.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and collIndex)
>> collIndex - index of the gathered sub-tensors
*/
void _CudaSpread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex)
{
int order = source->order;
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
int blockSizeSrc = 1;
int blockSizeColl = 1;
int blockNum = 1;
int stride = 1;
for (int i = dim + 1; i < order; i++) {
stride *= source->GetDim(i);
}
blockSizeSrc = stride * source->GetDim(dim);
blockSizeColl = stride * collection->GetDim(dim);
blockNum = source->unitNum / blockSizeSrc;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++) {
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
KernelSpread<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
BacktoCudaDev(source->devID, devIDBackup);
}
/*
This is core assignment for backward computation of gather function.
Care of the operator "+=" instead of "=".
>> sData - the data pointer of the source tensor
>> cData - the data pointer of collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of source data block
>> blockSizeColl - size of source data block
>> stride - stride of a data block
*/
__global__
void KernelSpreadForGather(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride)
{
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* offset in each block */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= blockNum || j >= stride)
return;
DTYPE * s = sData + blockSizeSrc * i;
DTYPE * c = cData + blockSizeColl * i;
s[j] += c[j];
}
/*
spread a collection tensor to source tensor (cuda version).
And this is a special spread function for backward computation of gather function.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and collIndex)
>> collIndex - index of the gathered sub-tensors
*/
void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex)
{
int order = source->order;
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(dim >= 0 && dim < order, "Illegal dimension!");
int blockSizeSrc = 1;
int blockSizeColl = 1;
int blockNum = 1;
int stride = 1;
for (int i = dim + 1; i < order; i++) {
stride *= source->GetDim(i);
}
blockSizeSrc = stride * source->GetDim(dim);
blockSizeColl = stride * collection->GetDim(dim);
blockNum = source->unitNum / blockSizeSrc;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++) {
int src = srcIndex[i];
int tgt = collIndex[i];
DTYPE * s = sData + src * stride;
DTYPE * c = cData + tgt * stride;
KernelSpreadForGather<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
BacktoCudaDev(source->devID, devIDBackup);
}
} // namespace nts(NiuTrans.Tensor)
#endif // __SPREAD_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#ifndef __SPREAD_CUH__
#define __SPREAD_CUH__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* spread a collection tensor to source tensor (cuda version) */
void _CudaSpread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
/* special spread function for backward computation of gather function (cuda version) */
void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
} // namespace nts(NiuTrans.Tensor)
#endif // __SPREAD_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#ifndef __SPREAD_H__
#define __SPREAD_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* spread a collection tensor to source tensor */
void _Spread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
/* spread a collection tensor to source tensor (return a XTensor structure)
make a new tensor to keep the result and return it */
void Spread(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
/* special spread function for backward computation of gather function */
void _SpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
} // namespace nts(NiuTrans.Tensor)
#endif // __SPREAD_H__
\ No newline at end of file
......@@ -33,7 +33,7 @@ sum = \sum_i (a_i - shift) if isExp == false
sum = \sum_i exp(a_i - shift) if isExp == true
*/
void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor * shift = NULL,
DTYPE power = (DTYPE)1.0F, bool isExp = false);
DTYPE power = (DTYPE)1.0F, bool isExp = false);
/*
sum the items along a dimension of the tensor (return a XTensor structure)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#include "ReduceSumAll.h"
#include "ReduceSum.h"
#include "../movement/CopyValues.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
int * getDimSize(const XTensor * tensor, int n)
{
int order = tensor->order;
int * dimSize = new int[order - 1];
for (int i = 0; i < order; i++) {
if(i < n)
dimSize[i] = tensor->dimSize[i];
else if(i > n)
dimSize[i - 1] = tensor->dimSize[i];
}
return dimSize;
}
/*
sum all the items of the tensor (It should be optimized!)
>> source - the inpute tensor
<< return - the total summation
*/
DTYPE _ReduceSumAll(XTensor * source)
{
int order = source->order;
DTYPE summation;
XTensor * big = NewTensor(source);
_CopyValues(source, big);
for(int i = 0; i < order; i++) {
if(i == order - 1)
big->Reshape(big->unitNum, 1);
int * dimSize;
dimSize = getDimSize(big, 0);
XTensor * little = NewTensor(big->order - 1, dimSize, source->dataType, source->denseRatio, source->devID, source->mem);
_ReduceSum(big, little, 0);
delete big;
delete dimSize;
big = NewTensor(little);
_CopyValues(little, big);
delete little;
}
summation = big->Get1D(0);
delete big;
return summation;
}
/*
sum all the items of the tensor
>> source - the inpute tensor
<< return - the total summation
*/
DTYPE ReduceSumAll(XTensor & source)
{
return _ReduceSumAll(&source);
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#ifndef __REDUCESUMALL_H__
#define __REDUCESUMALL_H__
#include "../../XTensor.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/* sum all the items of the tensor */
DTYPE _ReduceSumAll(XTensor * source);
/* sum all the items of the tensor */
DTYPE ReduceSumAll(XTensor & source);
} // namespace nts(NiuTrans.Tensor)
#endif // __REDUCESUMALL_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "../movement/CopyValues.h"
#include "Reshape.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
reshape the tensor
>> s - the input tensor
>> order - order of the tensor
>> dimSize - the size of each dimension
<< return - the output tensor
*/
XTensor Reshape(XTensor &s, int order, int * dimSize)
{
XTensor t(&s);
t.SetTMPFlag();
_CopyValues(&s, &t);
int oriOrder = s.order;
int * oriDimSize = new int[order];
memcpy(oriDimSize, s.dimSize, sizeof(int) * order);
/* call Reshape function */
t.Reshape(order, dimSize);
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
XLink::AddParamToHeadInt(&t, oriOrder);
XLink::AddParamToHeadPointer(&t, oriDimSize);
return t;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#ifndef __RESHAPE_H__
#define __RESHAPE_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* reshape the tensor */
XTensor Reshape(XTensor &s, int order, int * dimSize);
} // namespace nts(NiuTrans.Tensor)
#endif // __RESHAPE_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#include "Squeeze.h"
#include "../movement/CopyValues.h"
#include "../../XName.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
squeeze the tensor along the specified dimension
>> source - the input tensor
>> target - the output tensor
>> leadingDim - the dimension that we would squeeze
if leadingDim = -1, squeeze all dimensions that are 1
else, squeeze the specified dimension
*/
void _Squeeze(XTensor * source, XTensor * target, int leadingDim)
{
int order = target->order;
CheckNTErrors(XTensor::IsSameShaped(source, target),
"The source and target tensor must be of the same size!");
CheckNTErrors(leadingDim >= -1 && leadingDim < order,
"Wrong leading dimension");
_CopyValues(source, target);
if(leadingDim < 0) {
int * newDimSize = new int[order];
int newOrder = 0;
for(int i = 0; i < order; i++) {
int dim = source->GetDim(i);
if(dim > 1) {
newDimSize[newOrder] = dim;
newOrder += 1;
}
}
target->Reshape(newOrder, newDimSize);
delete[] newDimSize;
}
else {
if(source->GetDim(leadingDim) > 1)
return;
int newOrder = order - 1;
int * newDimSize = new int[newOrder];
for(int i = 0; i < order; i++)
if(i < leadingDim)
newDimSize[i] = source->GetDim(i);
else if(i > leadingDim)
newDimSize[i - 1] = source->GetDim(i);
target->Reshape(newOrder, newDimSize);
delete[] newDimSize;
}
}
/*
squeeze the tensor along the specified dimension (do it on site)
keep the result in the input tensor a and return nothing
>> source - the input tensor
>> leadingDim - the dimension that we would squeeze
if leadingDim = -1, squeeze all dimensions that are 1
else, squeeze the specified dimension
*/
void _SqueezeMe(XTensor * source, int leadingDim)
{
_Squeeze(source, source, leadingDim);
}
/*
squeeze the tensor along the specified dimension (return a XTensor structure)
make a new tensor to keep the result and return it
>> source - the input tensor
>> leadingDim - the dimension that we would squeeze
if leadingDim = -1, squeeze all dimensions that are 1
else, squeeze the specified dimension
<< return - the output tensor after squeeze operation
*/
XTensor Squeeze(XTensor & source, int leadingDim)
{
XTensor target(&source);
target.SetTMPFlag();
/* call _Squeeze function */
_Squeeze(&source, &target, leadingDim);
/* tensor connections */
XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
return target;
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#ifndef __SQUEEZE_H__
#define __SQUEEZE_H__
#include "../../XTensor.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/* squeeze the tensor along the specified dimension */
void _Squeeze(XTensor * source, XTensor * target, int leadingDim = -1);
/* squeeze the tensor along the specified dimension (do it on site)
keep the result in the input tensor a and return nothing */
void _SqueezeMe(XTensor * source, int leadingDim = -1);
/* squeeze the tensor along the specified dimension (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Squeeze(XTensor & source, int leadingDim = -1);
} // namespace nts(NiuTrans.Tensor)
#endif // __SQUEEZE_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __UNSQUEEZE_H__
#define __UNSQUEEZE_H__
......@@ -26,14 +26,13 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* insert a dimension by copying the blocks for x times (where x is the size of the inerted dimension) */
/* insert a dimension by copying the blocks for x times
(where x is the size of the inerted dimension) */
void _Unsqueeze(const XTensor * a, XTensor * b, int dim, int dSize);
/*
insert a dimension by copying the blocks for x times
(where x is the size of the inerted dimension) (return a XTensor structure)
make a new tensor to keep the result and return it
*/
/* insert a dimension by copying the blocks for x times
(where x is the size of the inerted dimension) (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Unsqueeze(const XTensor &a, int dim, int dSize);
} // namespace nts(NiuTrans.Tensor)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-17
*/
#include <math.h>
#include "CrossEntropy.h"
#include "CrossEntropy.cuh"
#include "../core/arithmetic/MultiplyDim.h"
#include "../core/arithmetic/Multiply.h"
#include "../core/math/Unary.h"
#include "../core/math/ScaleAndShift.h"
#include "../core/arithmetic/Negate.h"
#include "../core/reduce/ReduceSum.h"
#include "../core/reduce/ReduceSumAll.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
compute the cross entropy loss
loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions
>> output - model prediction
>> gold - gold standard
>> loss - compute loss
>> weight - a rescaling weight given to each class
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> leadingDim - the leading dimension for the output
*/
void _CrossEntropy(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight,
const XTensor * padding, int leadingDim)
{
int n = leadingDim < 0 ? output->order - 1 : leadingDim;
CheckNTErrors(n >= 0 && n < output->order, "Wrong leadingDim!");
int unitNum = output->dimSize[n];
CheckNTErrors(XTensor::IsSameShaped(output, gold),
"The output tensor and gold tensor must be of the same size!");
CheckNTErrors(weight == NULL || weight->unitNum == unitNum, "Wrong weight tensor!");
CheckNTErrors(padding == NULL || XTensor::IsSameShaped(padding, loss), "The loss tensor and padding tensor must be same shape!");
CheckNTErrors(loss->order == output->order - 1, "Wrong loss dimension!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, "TODO!");
XTensor * logInter = NewTensorBuf(output, output->devID, output->mem);
XTensor * mulInter = NewTensorBuf(output, output->devID, output->mem);
XTensor * negInter = NewTensorBuf(output, output->devID, output->mem);
XTensor * logBuf = NewTensorBuf(output, output->devID, output->mem);
XTensor * mulBuf = NewTensorBuf(output, output->devID, output->mem);
XTensor * negBuf = NewTensorBuf(output, output->devID, output->mem);
/* l = log(output) */
_Log(output, logBuf);
if(weight != NULL){
XTensor * weightBuf = NewTensorBuf(output, output->devID, output->mem);
/* multiply gold and weight by broadcast wg = mulDim(g * w) */
_MultiplyDim(gold, weight, weightBuf, n, 0);
/* multiply weighted gold and log(output) wgl = mul(wg, l) */
_Multiply(weightBuf, logBuf, mulBuf, 0);
DelTensorBuf(weightBuf);
}
else{
/* multiply gold and log(output) gl = mul(g, l) */
_Multiply(gold, logBuf, mulBuf, 0);
}
/* negate multiply result n = negate(mul) */
_NegateMe(mulBuf);
_ReduceSum(mulBuf, loss, n);
DelTensorBuf(negInter);
DelTensorBuf(mulInter);
DelTensorBuf(logInter);
}
/*
compute the cross entropy loss (implementation manually)
loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions
>> output - model prediction
>> gold - gold standard
>> loss - compute loss
>> weight - a rescaling weight given to each class
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> leadingDim - the leading dimension for the output
*/
void _CrossEntropyManual(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight,
const XTensor * padding, int leadingDim)
{
#ifdef USE_CUDA
if(output->devID >= 0) {
_CudaCrossEntropyManual(output, gold, loss, weight, padding, leadingDim);
return;
}
#endif
int order = output->order;
int n = leadingDim < 0 ? output->order - 1 : leadingDim;
int leadingDimSize = output->GetDim(n);
CheckNTErrors(n >= 0 && n < output->order,
"Wrong leadingDim!");
CheckNTErrors(XTensor::IsSameShaped(output, gold),
"The output tensor and gold tensor must be of the same size!");
CheckNTErrors(weight == NULL || weight->unitNum == leadingDimSize,
"Wrong weight tensor!");
CheckNTErrors(padding == NULL || XTensor::IsSameShaped(padding, loss),
"The loss tensor and padding tensor must be same shape!");
CheckNTErrors(loss->order == output->order - 1,
"Wrong loss dimension!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE,
"TODO!");
int blockNum = 1;
int blockSize = 1;
int stride = 1;
for(int i = n + 1; i < order; i++)
stride *= output->GetDim(i);
blockSize = stride * leadingDimSize;
blockNum = output->unitNum / blockSize;
DTYPE * outputData = (DTYPE*)output->data;
DTYPE * goldData = (DTYPE*)gold->data;
DTYPE * lossData = (DTYPE*)loss->data;
DTYPE tmpLoss;
if(weight == NULL) {
if(padding == NULL) {
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
tmpLoss = 0;
for(int j = 0; j < blockSize; j++)
tmpLoss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j));
*(lossData + i) = tmpLoss;
}
}
else {
DTYPE * paddingData = (DTYPE*)padding->data;
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
if(*(paddingData + i) == 0)
*(lossData + i) = 0;
else{
tmpLoss = 0;
for(int j = 0; j < blockSize; j++)
tmpLoss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j));
*(lossData + i) = tmpLoss;
}
}
}
}
else {
DTYPE * weightData = (DTYPE*)weight->data;
if(padding == NULL) {
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
tmpLoss = 0;
for(int j = 0; j < blockSize; j++)
tmpLoss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j)) *
(*(weightData + j));
*(lossData + i) = tmpLoss;
}
}
else {
DTYPE * paddingData = (DTYPE*)padding->data;
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
if(*(paddingData + i) == 0)
*(lossData + i) = 0;
else{
tmpLoss = 0;
for(int j = 0; j < blockSize; j++)
tmpLoss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j)) *
(*(weightData + j));
*(lossData + i) = tmpLoss;
}
}
}
}
}
/*
get the dimSize after reduce operation
>> tensor - a tensor to be reduced
>> n - the reduce dimension
<< return - the pointer of dimSize
*/
int * reduceDimSize(const XTensor * tensor, int n)
{
int order = tensor->order;
int * dimSize = new int[order - 1];
for (int i = 0; i < order; i++) {
if(i < n)
dimSize[i] = tensor->dimSize[i];
else if(i > n)
dimSize[i - 1] = tensor->dimSize[i];
}
return dimSize;
}
/*
compute the cross entropy loss
loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions
>> output - model prediction
>> gold - gold standard
>> reduce - loss compute way, sum or mean
>> weight - a rescaling weight given to each class
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> leadingDim - the leading dimension for the output
*/
DTYPE _CrossEntropy(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay, const XTensor * weight,
const XTensor * padding, int leadingDim)
{
int n = leadingDim < 0 ? output->order - 1 : leadingDim;
CheckNTErrors(n >= 0 && n < output->order, "Wrong leadingDim!");
int unitNum = output->dimSize[n];
CheckNTErrors(XTensor::IsSameShaped(output, gold),
"The output tensor and gold tensor must be of the same size!");
CheckNTErrors(weight == NULL || weight->unitNum == unitNum, "Wrong weight tensor!");
CheckNTErrors(padding == NULL || padding->order == output->order - 1, "The loss tensor and padding tensor must be same shape!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, "TODO!");
XTensor * logBuf = NewTensorBuf(output, output->devID, output->mem);
XTensor * mulBuf = NewTensorBuf(output, output->devID, output->mem);
XTensor * negBuf = NewTensorBuf(output, output->devID, output->mem);
/* l = log(output) */
_Log(output, logBuf);
if(weight != NULL){
XTensor * weightBuf = NewTensorBuf(output, output->devID, output->mem);
/* multiply gold and weight by broadcast wg = mulDim(g * w) */
_MultiplyDim(gold, weight, weightBuf, n, 0);
/* multiply weighted gold and log(output) wgl = mul(wg, l) */
_Multiply(weightBuf, logBuf, mulBuf, 0);
DelTensorBuf(weightBuf);
}
else{
/* multiply gold and log(output) gl = mul(g, l) */
_Multiply(gold, logBuf, mulBuf, 0);
}
/* negate multiply result n = negate(mul) */
_NegateMe(mulBuf);
int * dimSize;
dimSize = reduceDimSize(output, n);
XTensor * lossInter = NewTensor(output->order - 1, dimSize, output->dataType, output->denseRatio, output->devID, output->mem);
/* reduce sum all classes */
_ReduceSum(mulBuf, lossInter, n);
DelTensorBuf(negBuf);
DelTensorBuf(mulBuf);
DelTensorBuf(logBuf);
DTYPE loss;
/* compute the total loss */
if(padding != NULL) {
XTensor * temp(lossInter);
_Multiply(lossInter, padding, temp);
loss = _ReduceSumAll(temp);
delete temp;
}
else
loss = _ReduceSumAll(lossInter);
if(reduceWay == REDUCE_MEAN) {
if(padding != NULL) {
XTensor * zeroIndicator = NewTensorBuf(padding, padding->devID, padding->mem);
_IsZero(padding, zeroIndicator);
int reduceSize = (int)_ReduceSumAll(zeroIndicator);
loss = loss / (DTYPE)(padding->unitNum - reduceSize);
DelTensorBuf(zeroIndicator);
}
else
loss = loss / (DTYPE)lossInter->unitNum;
}
else if(reduceWay == REDUCE_SUM) {
/* don't need to do anything */
}
else {
ShowNTErrors("TODO");
}
delete[] dimSize;
delete lossInter;
return loss;
}
/*
compute the cross entropy loss (implementation manually)
loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions
>> output - model prediction
>> gold - gold standard
>> reduceWay - loss compute way, sum or mean
>> weight - a rescaling weight given to each class
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> leadingDim - the leading dimension for the output
<< return - the cross entropy loss that is a scalar
*/
DTYPE _CrossEntropyManual(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay, const XTensor * weight,
const XTensor * padding, int leadingDim)
{
#ifdef USE_CUDA
if(output->devID >= 0) {
return _CudaCrossEntropyManual(output, gold, reduceWay, weight, padding, leadingDim);
}
#endif
int order = output->order;
int n = leadingDim < 0 ? output->order - 1 : leadingDim;
int leadingDimSize = output->GetDim(n);
CheckNTErrors(n >= 0 && n < output->order,
"Wrong leadingDim!");
CheckNTErrors(XTensor::IsSameShaped(output, gold),
"The output tensor and gold tensor must be of the same size!");
CheckNTErrors(weight == NULL || weight->unitNum == leadingDimSize,
"Wrong weight tensor!");
CheckNTErrors(padding == NULL || padding->order == output->order - 1,
"Wrong padding tensor!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE,
"TODO!");
int blockNum = 1;
int blockSize = 1;
int stride = 1;
for(int i = n + 1; i < order; i++)
stride *= output->GetDim(i);
blockSize = stride * leadingDimSize;
blockNum = output->unitNum / blockSize;
DTYPE * outputData = (DTYPE*)output->data;
DTYPE * goldData = (DTYPE*)gold->data;
DTYPE loss = 0;
int nonZeroNum = 0;
if(weight == NULL) {
if(padding == NULL) {
nonZeroNum = blockNum;
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
for(int j = 0; j < blockSize; j++)
loss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j));
}
}
else {
DTYPE * paddingData = (DTYPE*)padding->data;
for(int i = 0; i < blockNum; i++) {
if(*(paddingData + i) == 0)
continue;
else{
nonZeroNum += 1;
int beg = i * blockSize;
for(int j = 0; j < blockSize; j++)
loss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j));
}
}
}
}
else {
DTYPE * weightData = (DTYPE*)weight->data;
if(padding == NULL) {
nonZeroNum = blockNum;
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
for(int j = 0; j < blockSize; j++)
loss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j)) *
(*(weightData + j));
}
}
else {
DTYPE * paddingData = (DTYPE*)padding->data;
for(int i = 0; i < blockNum; i++) {
if(*(paddingData + i) == 0)
continue;
else{
nonZeroNum += 1;
int beg = i * blockSize;
for(int j = 0; j < blockSize; j++)
loss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j)) *
(*(weightData + j));
}
}
}
}
if(reduceWay == REDUCE_MEAN) {
loss = loss / (DTYPE)nonZeroNum;
}
else if(reduceWay == REDUCE_SUM) {
/* don't need to do anything */
}
else {
ShowNTErrors("TODO");
}
return loss;
}
/*
backward compuation for cross entropy function (tensor version)
loss = sum_{i} (-t_i * log(y_i))
dE/dy_i = -t_i / y_i
where E is the error(loss) function that measure the errors in y
with respect to gold standard, and y this the model output
>> dedy - dE/dy (for return)
>> output - model prediction
>> gold - gold standard
>> weight - a rescaling weight given to each class
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> leadingDim - the leading dimension for the output
*/
void _CrossEntropyBackward(XTensor * dedy, const XTensor * output, const XTensor * gold,
const XTensor * weight, XTensor * padding,
int leadingDim)
{
#ifdef USE_CUDA
if(output->devID >= 0) {
_CudaCrossEntropyBackward(dedy, output, gold, weight, padding, leadingDim);
return;
}
#endif
int order = output->order;
int n = leadingDim < 0 ? output->order - 1 : leadingDim;
int leadingDimSize = output->GetDim(n);
int unitSize = dedy->unitSize;
CheckNTErrors(n >= 0 && n < output->order,
"Wrong leading dimension!");
CheckNTErrors(XTensor::IsSameShaped(dedy, output, gold),
"The output tensor and gold tensor must be of the same size!");
CheckNTErrors(weight == NULL || weight->unitNum == leadingDimSize,
"Wrong weight tensor!");
CheckNTErrors(padding == NULL || padding->order == output->order - 1,
"Wrong padding tensor!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE,
"TODO!");
int blockNum = 1;
int blockSize = 1;
int stride = 1;
for(int i = n + 1; i < order; i++)
stride *= output->GetDim(i);
blockSize = stride * leadingDimSize;
blockNum = output->unitNum / blockSize;
DTYPE * dedyData = (DTYPE*)dedy->data;
DTYPE * outputData = (DTYPE*)output->data;
DTYPE * goldData = (DTYPE*)gold->data;
if(weight == NULL) {
if(padding == NULL) {
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
for(int j = 0; j < blockSize; j++)
*(dedyData + beg + j) = -(*(goldData + beg + j)) /
(*(outputData + beg + j));
}
}
else {
DTYPE * paddingData = (DTYPE*)padding->data;
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
if(*(paddingData + i) == 0)
memset(dedyData + beg, 0, blockSize * unitSize);
else
for(int j = 0; j < blockSize; j++)
*(dedyData + beg + j) = -(*(goldData + beg + j)) /
(*(outputData + beg + j));
}
}
}
else {
DTYPE * weightData = (DTYPE*)weight->data;
if(padding == NULL) {
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
for(int j = 0; j < blockSize; j++)
*(dedyData + beg + j) = -(*(weightData + j)) *
(*(goldData + beg + j)) /
(*(outputData + beg + j));
}
}
else {
DTYPE * paddingData = (DTYPE*)padding->data;
for(int i = 0; i < blockNum; i++) {
int beg = i * blockSize;
if(*(paddingData + i) == 0)
memset(dedyData + beg, 0, blockSize * unitSize);
else
for(int j = 0; j < blockSize; j++) {
*(dedyData + beg + j) = -(*(weightData + j)) *
(*(goldData + beg + j)) /
(*(outputData + beg + j));
}
}
}
}
if(padding != NULL) {
XTensor * tmp(padding);
_IsZero(padding, tmp);
int nonZeroNum = (int)_ReduceSumAll(tmp);
_ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)nonZeroNum);
delete tmp;
}
else {
_ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)blockNum);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-17
*/
#ifndef __CROSSENTROPY_CUH__
#define __CROSSENTROPY_CUH__
#include "../XTensor.h"
#include "../XDevice.h"
#include "CrossEntropy.cuh"
#include "CrossEntropy.h"
#include "../core/reduce/ReduceSumAll.h"
#include "../core/math/Unary.h"
#include "../core/math/ScaleAndShift.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
compute the cross entropy loss (cuda kernel)
>> outputData - the data pointer of output tensor
>> goldData - the data pointer of gold tensor
>> lossData - the data pointer of loss tensor
>> weightData - the data pointer of weight tensor
>> paddingData - the data pointer of padding tensor
>> blockNum - the number of data blocks
>> stride - the size of a data block
*/
__global__
void KernelCrossEntropy(DTYPE * outputData, DTYPE * goldData,
DTYPE * lossData, DTYPE * weightData,
DTYPE * paddingData, int blockNum, int blockSize)
{
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
if(i >= blockNum)
return;
int beg = i * blockSize;
DTYPE tmpLoss = 0;
if(weightData == NULL) {
if(paddingData == NULL) {
tmpLoss = 0;
for(int j = 0; j < blockSize; j++)
tmpLoss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j));
*(lossData + i) = tmpLoss;
}
else {
if(*(paddingData + i) == 0)
*(lossData + i) = tmpLoss;
else{
for(int j = 0; j < blockSize; j++)
tmpLoss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j));
*(lossData + i) = tmpLoss;
}
}
}
else {
if(paddingData == NULL) {
for(int j = 0; j < blockSize; j++)
tmpLoss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j)) *
(*(weightData + j));
*(lossData + i) = tmpLoss;
}
else {
if(*(paddingData + i) == 0)
*(lossData + i) = tmpLoss;
else{
tmpLoss = 0;
for(int j = 0; j < blockSize; j++)
tmpLoss += -(*(goldData + beg + j)) *
(DTYPE)log(*(outputData + beg + j)) *
(*(weightData + j));
*(lossData + i) = tmpLoss;
}
}
}
}
/*
compute the cross entropy loss (cuda version)
loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions
>> output - model prediction
>> gold - gold standard
>> loss - compute loss
>> weight - a rescaling weight given to each class
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> leadingDim - the leading dimension for the output
*/
void _CudaCrossEntropyManual(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight,
const XTensor * padding, int leadingDim)
{
int order = output->order;
int n = leadingDim < 0 ? output->order - 1 : leadingDim;
int leadingDimSize = output->GetDim(n);
CheckNTErrors(n >= 0 && n < output->order,
"Wrong leadingDim!");
CheckNTErrors(XTensor::IsSameShaped(output, gold),
"The output tensor and gold tensor must be of the same size!");
CheckNTErrors(weight == NULL || weight->unitNum == leadingDimSize,
"Wrong weight tensor!");
CheckNTErrors(padding == NULL || XTensor::IsSameShaped(padding, loss),
"The loss tensor and padding tensor must be same shape!");
CheckNTErrors(loss->order == output->order - 1,
"Wrong loss dimension!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE,
"TODO!");
int blockNum = 1;
int blockSize = 1;
int stride = 1;
for(int i = n + 1; i < order; i++)
stride *= output->GetDim(i);
blockSize = stride * leadingDimSize;
blockNum = output->unitNum / blockSize;
int cudaGrids[3];
int cudaBlocks[3];
//GDevs.GetCudaThread2D(output->devID, blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
GDevs.GetCudaThread(output->devID, blockNum, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(output->devID, devIDBackup);
DTYPE * outputData = (DTYPE*)output->data;
DTYPE * goldData = (DTYPE*)gold->data;
DTYPE * lossData = (DTYPE*)loss->data;
if(weight == NULL) {
if(padding == NULL)
KernelCrossEntropy<<<dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >>>
(outputData, goldData, lossData,
NULL, NULL,
blockNum, blockSize);
else
KernelCrossEntropy<<<dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >>>
(outputData, goldData, lossData,
NULL, (DTYPE*)padding->data,
blockNum, blockSize);
}
else {
if(padding == NULL)
KernelCrossEntropy<<<dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >>>
(outputData, goldData, lossData,
(DTYPE*)weight->data, NULL,
blockNum, blockSize);
else
KernelCrossEntropy<<<dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >>>
(outputData, goldData, lossData,
(DTYPE*)weight->data, (DTYPE*)padding->data,
blockNum, blockSize);
}
BacktoCudaDev(output->devID, devIDBackup);
}
/*
compute the cross entropy loss (scalar version)
loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions
>> output - model prediction
>> gold - gold standard
>> reduceWay - loss compute way, sum or mean
>> weight - a rescaling weight given to each class
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> leadingDim - the leading dimension for the output
<< return - the cross entropy loss that is a scalar
*/
DTYPE _CudaCrossEntropyManual(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay, const XTensor * weight,
const XTensor * padding, int leadingDim)
{
DTYPE loss = 0;
int order = output->order;
int n = leadingDim < 0 ? output->order - 1 : leadingDim;
int leadingDimSize = output->GetDim(n);
CheckNTErrors(n >= 0 && n < output->order,
"Wrong leadingDim!");
CheckNTErrors(XTensor::IsSameShaped(output, gold),
"The output tensor and gold tensor must be of the same size!");
CheckNTErrors(weight == NULL || weight->unitNum == leadingDimSize,
"Wrong weight tensor!");
CheckNTErrors(padding == NULL || padding->order == output->order - 1,
"Wrong padding tensor!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE,
"TODO!");
int * dimSize = new int[output->order - 1];
for (int i = 0; i < order; i++) {
if(i < n)
dimSize[i] = output->dimSize[i];
else if(i > n)
dimSize[i - 1] = output->dimSize[i];
}
XTensor * lossInter = NewTensor(output->order - 1, dimSize, output->dataType, output->denseRatio, output->devID, output->mem);
_CudaCrossEntropyManual(output, gold, lossInter, weight, padding, leadingDim);
loss = _ReduceSumAll(lossInter);
if(reduceWay == REDUCE_MEAN) {
int totalNum;
if(padding == NULL) {
totalNum = lossInter->unitNum;
}
else {
XTensor * zeroIndicator = NewTensorBuf(output, output->devID, output->mem);
_IsZero(padding, zeroIndicator);
totalNum = lossInter->unitNum - (int)_ReduceSumAll(zeroIndicator);
DelTensorBuf(zeroIndicator);
}
loss = loss / (DTYPE)totalNum;
}
return loss;
}
/*
backward computation of cross entropy function (kernel version)
>> dedyData - the data pointer of dedy tensor
>> outputData - the data pointer of output tensor
>> goldData - the data pointer of gold tensor
>> weightData - the data pointer of weight tensor
>> paddingData - the data pointer of padding tensor
>> blockNum - the number of data blocks
>> blockSize - the size of a data block
*/
__global__
void KernelCrossEntropyBackward(DTYPE * dedyData, DTYPE * outputData, DTYPE * goldData,
DTYPE * weightData, DTYPE * paddingData,
int blockNum, int blockSize)
{
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
if(i >= blockNum)
return;
int beg = i * blockSize;
if(weightData == NULL) {
if(paddingData == NULL) {
for(int j = 0; j < blockSize; j++)
*(dedyData + beg + j) = -(*(goldData + beg + j)) /
(*(outputData + beg + j));
}
else {
if(*(paddingData + i) == 0)
memset(dedyData + beg, 0, blockSize * sizeof(DTYPE));
else
for(int j = 0; j < blockSize; j++)
*(dedyData + beg + j) = -(*(goldData + beg + j)) /
(*(outputData + beg + j));
}
}
else {
if(paddingData == NULL) {
for(int j = 0; j < blockSize; j++)
*(dedyData + beg + j) = -(*(weightData + j)) *
(*(goldData + beg + j)) /
(*(outputData + beg + j));
}
else {
if(*(paddingData + i) == 0)
memset(dedyData + beg, 0, blockSize * sizeof(DTYPE));
else
for(int j = 0; j < blockSize; j++) {
*(dedyData + beg + j) = -(*(weightData + j)) *
(*(goldData + beg + j)) /
(*(outputData + beg + j));
}
}
}
}
/*
backward computation of cross entropy function
loss = sum_{i} (-t_i * log(y_i))
dE/dy_i = -t_i / y_i
where E is the error(loss) function that measure the errors in y
with respect to gold standard, and y this the model output
>> dedy - dE/dy (for return)
>> output - model prediction
>> gold - gold standard
>> weight - a rescaling weight given to each class
>> padding - specify a target value that is ignored and does not contribute to the loss computation
>> leadingDim - the leading dimension for the output
*/
void _CudaCrossEntropyBackward(XTensor * dedy, const XTensor * output, const XTensor * gold,
const XTensor * weight, XTensor * padding,
int leadingDim)
{
int order = output->order;
int n = leadingDim < 0 ? output->order - 1 : leadingDim;
int leadingDimSize = output->GetDim(n);
CheckNTErrors(n >= 0 && n < output->order,
"Wrong leading dimension!");
CheckNTErrors(XTensor::IsSameShaped(dedy, output, gold),
"The output tensor and gold tensor must be of the same size!");
CheckNTErrors(weight == NULL || weight->unitNum == leadingDimSize,
"Wrong weight tensor!");
CheckNTErrors(padding == NULL || padding->order == output->order - 1,
"Wrong padding tensor!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE,
"TODO!");
int blockNum = 1;
int blockSize = 1;
int stride = 1;
for(int i = n + 1; i < order; i++)
stride *= output->GetDim(i);
blockSize = stride * leadingDimSize;
blockNum = output->unitNum / blockSize;
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread(output->devID, blockNum, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int devIDBackup;
ProtectCudaDev(output->devID, devIDBackup);
DTYPE * dedyData = (DTYPE*)dedy->data;
DTYPE * outputData = (DTYPE*)output->data;
DTYPE * goldData = (DTYPE*)gold->data;
if(weight == NULL) {
if(padding == NULL)
KernelCrossEntropyBackward<<<dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >>>
(dedyData, outputData, goldData,
NULL, NULL,
blockNum, blockSize);
else
KernelCrossEntropyBackward<<<dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >>>
(dedyData, outputData, goldData,
NULL, (DTYPE*)padding->data,
blockNum, blockSize);
}
else {
if(padding == NULL)
KernelCrossEntropyBackward<<<dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >>>
(dedyData, outputData, goldData,
(DTYPE*)weight->data, NULL,
blockNum, blockSize);
else
KernelCrossEntropyBackward<<<dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >>>
(dedyData, outputData, goldData,
(DTYPE*)weight->data, (DTYPE*)padding->data,
blockNum, blockSize);
}
if(padding != NULL) {
XTensor * tmp(padding);
_IsZero(padding, tmp);
int nonZeroNum = (int)_ReduceSumAll(tmp);
_ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)nonZeroNum);
delete tmp;
}
else {
_ScaleAndShiftMe(dedy, (DTYPE)1.0/(DTYPE)blockNum);
}
BacktoCudaDev(output->devID, devIDBackup);
}
} // namespace nts(NiuTrans.Tensor)
#endif // __CROSSENTROPY_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-17
*/
#ifndef __CROSSENTROPY_CUH__
#define __CROSSENTROPY_CUH__
#include "../XTensor.h"
#include "CrossEntropy.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/* compute the cross entropy loss (tensor version) */
void _CudaCrossEntropyManual(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* compute the cross entropy loss (scalar version) */
DTYPE _CudaCrossEntropyManual(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* backward computation of cross entropy function */
void _CudaCrossEntropyBackward(XTensor * dedy, const XTensor * output, const XTensor * gold,
const XTensor * weight = NULL, XTensor * padding = NULL,
int leadingDim = -1);
} // namespace nts(NiuTrans.Tensor)
#endif // __CROSSENTROPY_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-17
*/
#ifndef __CROSSENTROPY_H__
#define __CROSSENTROPY_H__
#include "../XTensor.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
enum LOSS_COMPUTE_WAY{
REDUCE_SUM,
REDUCE_MEAN
};
/* compute the cross entropy loss (tensor version) */
void _CrossEntropy(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* compute the cross entropy loss (tensor version) */
void _CrossEntropyManual(const XTensor * output, const XTensor * gold,
XTensor * loss, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* compute the cross entropy loss (scalar version) */
DTYPE _CrossEntropy(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* compute the cross entropy loss (scalar version) */
DTYPE _CrossEntropyManual(const XTensor * output, const XTensor * gold,
LOSS_COMPUTE_WAY reduceWay = REDUCE_MEAN, const XTensor * weight = NULL,
const XTensor * padding = NULL, int leadingDim = -1);
/* backward computation of cross entropy function */
void _CrossEntropyBackward(XTensor * dedy, const XTensor * output, const XTensor * gold,
const XTensor * weight = NULL, XTensor * padding = NULL,
int leadingDim = -1);
} // namespace nts(NiuTrans.Tensor)
#endif // __CROSSENTROPY_H__
\ No newline at end of file
......@@ -20,7 +20,6 @@
*/
#include "../XName.h"
#include <math.h>
#include <time.h>
#include "Dropout.h"
#include "Dropout.cuh"
......
......@@ -23,7 +23,6 @@
#define __DROPOUT_H__
#include "../XTensor.h"
#include "Loss.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......
......@@ -26,6 +26,7 @@
#include "../XTensor.h"
#include "CrossEntropy.h"
#include "Dropout.h"
#include "HardTanH.h"
#include "Identity.h"
......
......@@ -23,6 +23,7 @@
#include "../XName.h"
#include "HardTanH.h"
#include "HardTanH.cuh"
#include "CrossEntropy.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -118,7 +119,9 @@ void _HardTanHBackward(XTensor * gold, XTensor * y, XTensor * x,
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
/* calculate dE/dy */
if(lossName != NOLOSS)
if(lossName == CROSSENTROPY)
_CrossEntropyBackward(dedy, y, gold);
else if(lossName != NOLOSS)
_LossBackward(dedy, gold, y, lossName);
DTYPE * dedyp = (DTYPE*)dedy->data;
......
......@@ -22,6 +22,7 @@
#include "HardTanH.h"
#include "HardTanH.cuh"
#include "Loss.cuh"
#include "CrossEntropy.cuh"
#include "../XDevice.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -136,8 +137,10 @@ void _CudaHardTanHBackward(XTensor * gold, XTensor * y, XTensor * x,
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
/* calculate dE/dy */
if(lossName != NOLOSS)
_LossBackward(dedy, gold, y, lossName);
if(lossName == CROSSENTROPY)
_CudaCrossEntropyBackward(dedy, y, gold);
else if(lossName != NOLOSS)
_CudaLossBackward(dedy, gold, y, lossName);
int gridSize[3], blockSize[3];
......
......@@ -21,6 +21,7 @@
#include "../XName.h"
#include "Identity.h"
#include "CrossEntropy.h"
#include "../XUtility.h"
#include "../core/movement/CopyValues.h"
......@@ -78,7 +79,9 @@ void _IdentityBackward(XTensor * gold, XTensor * y, XTensor * x,
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE)
{
/* calculate dE/dy */
if(lossName != NOLOSS)
if(lossName == CROSSENTROPY)
_CrossEntropyBackward(dedy, y, gold);
else if(lossName != NOLOSS)
_LossBackward(dedy, gold, y, lossName);
if(dedy->data != dedx->data)
......
......@@ -16,8 +16,8 @@
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include <math.h>
#include "Loss.h"
......
......@@ -22,6 +22,7 @@
#include "../XName.h"
#include "Rectify.h"
#include "Rectify.cuh"
#include "CrossEntropy.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -116,7 +117,9 @@ void _RectifyBackward(XTensor * gold, XTensor * y, XTensor * x,
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE)
{
/* calculate dE/dy */
if(lossName != NOLOSS)
if(lossName == CROSSENTROPY)
_CrossEntropyBackward(dedy, y, gold);
else if(lossName != NOLOSS)
_LossBackward(dedy, gold, y, lossName);
DTYPE * dedyp = (DTYPE*)dedy->data;
......
......@@ -22,6 +22,7 @@
#include "Rectify.h"
#include "Rectify.cuh"
#include "Loss.cuh"
#include "CrossEntropy.cuh"
#include "../XDevice.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -133,7 +134,9 @@ void _CudaRectifyBackward(XTensor * gold, XTensor * y, XTensor * x,
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
/* calculate dE/dy */
if(lossName != NOLOSS)
if(lossName == CROSSENTROPY)
_CudaCrossEntropyBackward(dedy, y, gold);
else if(lossName != NOLOSS)
_CudaLossBackward(dedy, gold, y, lossName);
int gridSize[3], blockSize[3];
......
......@@ -23,6 +23,7 @@
#include <math.h>
#include "Sigmoid.h"
#include "Sigmoid.cuh"
#include "CrossEntropy.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -107,7 +108,9 @@ void _SigmoidBackward(XTensor * gold, XTensor * y, XTensor * x,
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE)
{
/* calculate dE/dy */
if(lossName != NOLOSS)
if(lossName == CROSSENTROPY)
_CrossEntropyBackward(dedy, y, gold);
else if(lossName != NOLOSS)
_LossBackward(dedy, gold, y, lossName);
DTYPE * dedyp = (DTYPE*)dedy->data;
......
......@@ -22,6 +22,7 @@
#include "Sigmoid.h"
#include "Sigmoid.cuh"
#include "Loss.cuh"
#include "CrossEntropy.cuh"
#include "../XDevice.h"
#ifdef USE_CUDA
......@@ -128,7 +129,9 @@ void _CudaSigmoidBackward(XTensor * gold, XTensor * y, XTensor * x,
{
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
/* calculate dE/dy */
if(lossName != NOLOSS)
if(lossName == CROSSENTROPY)
_CudaCrossEntropyBackward(dedy, y, gold);
else if(lossName != NOLOSS)
_LossBackward(dedy, gold, y, lossName);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-06-27
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-06-27
*/
#include "TCopyIndexed.h"
......@@ -344,6 +344,235 @@ bool TestCopyIndexed3()
#endif // USE_CUDA
}
/*
case 4: copy indexed sub-tensors
In this case, (3, 2, 3) -> (3, 2, 2), dim = 2, indexSize = 2,
srcIndex = [0, 2], tgtIndex = [0, 1], copyNum = 1.
*/
bool TestCopyIndexed4()
{
/* a input tensor of size (3, 2, 3) */
int sOrder = 3;
int * sDimSize = new int[sOrder];
sDimSize[0] = 3;
sDimSize[1] = 2;
sDimSize[2] = 3;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
/* a output tensor of size (3, 2, 2) */
int tOrder = 3;
int * tDimSize = new int[tOrder];
tDimSize[0] = 3;
tDimSize[1] = 2;
tDimSize[2] = 2;
int tUnitNum = 1;
for (int i = 0; i < tOrder; i++)
tUnitNum *= tDimSize[i];
/* a index tensor of size(2) */
int iOrder = 3;
int * iDimSize = new int[iOrder];
iDimSize[0] = 3;
iDimSize[1] = 2;
iDimSize[2] = 2;
int iUnitNum = 1;
for (int i = 0; i < iOrder; i++)
iUnitNum *= iDimSize[i];
DTYPE sData[3][2][3] = { { {0.0F, -1.0F, 2.0F},
{2.0F, 1.0F, 3.0F} },
{ {1.0F, 2.0F, 4.0F},
{3.0F, 1.0F, 2.0F}},
{ {-1.0F, 3.0F, 2.0F},
{1.0F, -1.0F, 0.0F} } };
DTYPE answer[3][2][2] = { { {0.0F, 2.0F},
{2.0F, 3.0F} },
{ {1.0F, 4.0F},
{3.0F, 2.0F}},
{ {-1.0F, 2.0F},
{1.0F, 0.0F} } };
int dim = 2;
int indexSize = 2;
int srcIndex[2] = {0, 2};
int tgtIndex[2] = {0, 1};
int copyNum = 1;
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * t = NewTensor(tOrder, tDimSize);
XTensor * index = NewTensor(tOrder, tDimSize, X_INT);
XTensor tUser;
/* initialize variables */
s->SetData(sData, sUnitNum);
t->SetZeroAll();
index->SetData(srcIndex, iUnitNum);
/* call CopyIndexed function */
_CopyIndexed(s, t, dim, (int*)index->data, indexSize, tgtIndex, copyNum);
tUser = CopyIndexed(*s, dim, (int*)index->data, indexSize, tgtIndex, copyNum);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum) && tUser.CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * tGPU = NewTensor(sOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor tUserGPU;
/* initialize variables */
sGPU->SetData(sData, sUnitNum);
tGPU->SetZeroAll();
/* call CopyIndexed function */
_CopyIndexed(sGPU, tGPU, dim, (int*)index->data, indexSize, tgtIndex, copyNum);
tUserGPU = CopyIndexed(*sGPU, dim, srcIndex, indexSize, tgtIndex, copyNum);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum) && tUserGPU.CheckData(answer, tUnitNum);
/* destroy variables */
delete s;
delete t;
delete index;
delete sGPU;
delete tGPU;
delete[] sDimSize;
delete[] tDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete t;
delete[] sDimSize;
delete[] tDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 5: copy indexed sub-tensors
In this case, (3, 2, 3) -> (3, 2, 2), dim = 2, indexSize = 1,
srcIndex = [0, 1], tgtIndex = [0, 2], copyNum = 2.
*/
bool TestCopyIndexed5()
{
/* a input tensor of size (3, 2, 3) */
int sOrder = 3;
int * sDimSize = new int[sOrder];
sDimSize[0] = 3;
sDimSize[1] = 2;
sDimSize[2] = 3;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
/* a output tensor of size (3, 2, 2) */
int tOrder = 3;
int * tDimSize = new int[tOrder];
tDimSize[0] = 3;
tDimSize[1] = 2;
tDimSize[2] = 4;
int tUnitNum = 1;
for (int i = 0; i < tOrder; i++)
tUnitNum *= tDimSize[i];
DTYPE sData[3][2][3] = { { {0.0F, -1.0F, 2.0F},
{2.0F, 1.0F, 3.0F} },
{ {1.0F, 2.0F, 4.0F},
{3.0F, 1.0F, 2.0F}},
{ {-1.0F, 3.0F, 2.0F},
{1.0F, -1.0F, 0.0F} } };
DTYPE answer[3][2][4] = { { {0.0F, -1.0F, -1.0F, 2.0F},
{2.0F, 1.0F, 1.0F, 3.0F} },
{ {1.0F, 2.0F, 2.0F, 4.0F},
{3.0F, 1.0F, 1.0F, 2.0F}},
{ {-1.0F, 3.0F, 3.0F, 2.0F},
{1.0F, -1.0F, -1.0F, 0.0F} } };
int dim = 2;
int indexSize = 2;
int srcIndex[2] = {0, 1};
int tgtIndex[2] = {0, 2};
int copyNum = 2;
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * t = NewTensor(tOrder, tDimSize);
XTensor tUser;
/* initialize variables */
s->SetData(sData, sUnitNum);
t->SetZeroAll();
/* call CopyIndexed function */
_CopyIndexed(s, t, dim, srcIndex, indexSize, tgtIndex, copyNum);
tUser = CopyIndexed(*s, dim, srcIndex, indexSize, tgtIndex, copyNum);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum) && tUser.CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * tGPU = NewTensor(sOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor tUserGPU;
/* initialize variables */
sGPU->SetData(sData, sUnitNum);
tGPU->SetZeroAll();
/* call CopyIndexed function */
_CopyIndexed(sGPU, tGPU, dim, srcIndex, indexSize, tgtIndex, copyNum);
tUserGPU = CopyIndexed(*sGPU, dim, srcIndex, indexSize, tgtIndex, copyNum);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum) && tUserGPU.CheckData(answer, tUnitNum);
/* destroy variables */
delete s;
delete t;
delete sGPU;
delete tGPU;
delete[] sDimSize;
delete[] tDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete t;
delete[] sDimSize;
delete[] tDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
......@@ -381,6 +610,24 @@ bool TestCopyIndexed()
}
else
XPRINT(0, stdout, ">> case 3 passed!\n");
/* case 4 test */
caseFlag = TestCopyIndexed4();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 4 failed!\n");
}
else
XPRINT(0, stdout, ">> case 4 passed!\n");
/* case 5 test */
caseFlag = TestCopyIndexed5();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 5 failed!\n");
}
else
XPRINT(0, stdout, ">> case 5 passed!\n");
/* other cases test */
/*
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-06-27
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-06-27
*/
#ifndef __TEST_COPYINDEXED_H__
#define __TEST_COPYINDEXED_H__
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-18
*/
#include "../core/math/Unary.h"
#include "TCos.h"
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-17
*/
#include <math.h>
#include "TCrossEntropy.h"
#include "../core/math/ScaleAndShift.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test CrossEntropy function.
loss = sum_{i} (-t_i * log(y_i))
where t_i is the gold standard and y_i is the model output.
*/
bool TestCrossEntropy1()
{
/* a tensor of size (1, 4) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 1;
dimSize[1] = 4;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE outputData[4] = {0.25F, 0.25F, 0.25F, 0.25F};
DTYPE goldData[4] = {0.5F, 0.5F, 0.0F, 0.0F};
DTYPE answer = 1.3863F;
DTYPE error1;
DTYPE error2;
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * output = NewTensor(order, dimSize);
XTensor * gold = NewTensor(order, dimSize);
XTensor * loss = NewTensor1D(1);
/* initialize variables */
output->SetData(outputData, unitNum);
gold->SetData(goldData, unitNum);
/* call CrossEntropy function */
_CrossEntropyManual(output, gold, loss);
error2 = _CrossEntropy(output, gold, REDUCE_SUM);
error1 = loss->Get1D(0);
/* check results */
cpuTest = (fabs(error1 - answer) < 1e-4F &&
fabs(error2 - answer) < 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * outputGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * goldGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * lossGPU = NewTensor1D(1, X_FLOAT, 0);
/* Initialize variables */
outputGPU->SetData(outputData, unitNum);
goldGPU->SetData(goldData, unitNum);
/* call CrossEntropy function */
_CrossEntropyManual(outputGPU, goldGPU, lossGPU);
error1 = lossGPU->Get1D(0);
error2 = _CrossEntropy(outputGPU, goldGPU, REDUCE_SUM);
/* check results */
gpuTest = (fabs(error1 - answer) < 1e-4F &&
fabs(error2 - answer) < 1e-4F);
/* destroy variables */
delete output;
delete gold;
delete loss;
delete outputGPU;
delete goldGPU;
delete lossGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete output;
delete gold;
delete loss;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: test CrossEntropy function.
loss = sum_{i} (-t_i * log(y_i))
where t_i is the gold standard and y_i is the model output.
*/
bool TestCrossEntropy2()
{
/* a tensor of size (4, 10) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 4;
dimSize[1] = 10;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE outputData[4][10] = { {0.5F, 2.6F, 0.3F, 1.7F, 0.6F,
0.1F, 0.7F, 1.3F, 0.4F, 0.6F},
{0.5F, 1.6F, 0.2F, 1.1F, 0.3F,
0.8F, 2.2F, 0.1F, 0.1F, 0.8F},
{0.2F, 0.5F, 1.1F, 1.2F, 0.6F,
0.1F, 0.2F, 0.7F, 0.5F, 0.7F},
{0.2F, 1.7F, 0.6F, 1.5F, 0.8F,
0.1F, 0.8F, 0.1F, 0.6F, 0.2F} };
DTYPE answer1 = 4.3275F;
DTYPE answer2 = 1.0818F;
DTYPE error1;
DTYPE error2;
DTYPE error3;
DTYPE error4;
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * output = NewTensor(order, dimSize);
XTensor * gold = NewTensor(order, dimSize);
/* initialize variables */
output->SetData(outputData, unitNum);
gold->SetZeroAll();
gold->Set2D(1.0F, 0, 9);
gold->Set2D(1.0F, 1, 7);
gold->Set2D(1.0F, 2, 2);
gold->Set2D(1.0F, 3, 9);
/* call CrossEntropy function */
error1 = _CrossEntropy(output, gold, REDUCE_SUM);
error2 = _CrossEntropy(output, gold, REDUCE_MEAN);
error3 = _CrossEntropyManual(output, gold, REDUCE_SUM);
error4 = _CrossEntropyManual(output, gold, REDUCE_MEAN);
/* check results */
cpuTest = (fabs(error1 - answer1) < 1e-4F &&
fabs(error2 - answer2) < 1e-4F &&
fabs(error3 - answer1) < 1e-4F &&
fabs(error4 - answer2) < 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * outputGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * goldGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
/* Initialize variables */
outputGPU->SetData(outputData, unitNum);
goldGPU->SetZeroAll();
goldGPU->Set2D(1.0F, 0, 9);
goldGPU->Set2D(1.0F, 1, 7);
goldGPU->Set2D(1.0F, 2, 2);
goldGPU->Set2D(1.0F, 3, 9);
/* call CrossEntropy function */
error1 = _CrossEntropy(outputGPU, goldGPU, REDUCE_SUM);
error2 = _CrossEntropy(outputGPU, goldGPU, REDUCE_MEAN);
error3 = _CrossEntropyManual(outputGPU, goldGPU, REDUCE_SUM);
error4 = _CrossEntropyManual(outputGPU, goldGPU, REDUCE_MEAN);
/* check results */
gpuTest = (fabs(error1 - answer1) < 1e-4F &&
fabs(error2 - answer2) < 1e-4F &&
fabs(error3 - answer1) < 1e-4F &&
fabs(error4 - answer2) < 1e-4F);
/* destroy variables */
delete output;
delete gold;
delete outputGPU;
delete goldGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete output;
delete gold;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 3: test CrossEntropy function.
loss = sum_{i} (-t_i * log(y_i))
where t_i is the gold standard and y_i is the model output.
In this case, I compute the cross entropy with weight.
*/
bool TestCrossEntropy3()
{
/* a output tensor of size (4, 4) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 4;
dimSize[1] = 4;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
/* a weight tensor of size (4) */
int wOrder = 1;
int * wDimSize = new int[wOrder];
wDimSize[0] = 4;
int wUnitNum = 1;
for (int i = 0; i < wOrder; i++)
wUnitNum *= wDimSize[i];
DTYPE outputData[4][4] = { {0.3F, 0.2F, 0.3F, 0.2F},
{0.1F, 0.4F, 0.2F, 0.3F},
{0.7F, 0.1F, 0.1F, 0.1F},
{0.5F, 0.1F, 0.2F, 0.2F} };
DTYPE weightData[4] = {2.0F, 1.0F, 5.0F, 0.0F};
DTYPE answer[4] = {2.4079F, 0.9163F, 11.5129F, 0.0F};
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * output = NewTensor(order, dimSize);
XTensor * gold = NewTensor(order, dimSize);
XTensor * loss = NewTensor1D(4);
XTensor * weight = NewTensor(wOrder, wDimSize);
/* initialize variables */
output->SetData(outputData, unitNum);
weight->SetData(weightData, wUnitNum);
gold->SetZeroAll();
gold->Set2D(1.0F, 0, 0);
gold->Set2D(1.0F, 1, 1);
gold->Set2D(1.0F, 2, 2);
gold->Set2D(1.0F, 3, 3);
/* call CrossEntropy function */
_CrossEntropyManual(output, gold, loss, weight);
/* check results */
cpuTest = loss->CheckData(answer, 4, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * outputGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * goldGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * lossGPU = NewTensor1D(4, X_FLOAT, 0);
XTensor * weightGPU = NewTensor(wOrder, wDimSize, X_FLOAT, 1.0F, 0);
/* Initialize variables */
outputGPU->SetData(outputData, unitNum);
weightGPU->SetData(weightData, wUnitNum);
goldGPU->SetZeroAll();
goldGPU->Set2D(1.0F, 0, 0);
goldGPU->Set2D(1.0F, 1, 1);
goldGPU->Set2D(1.0F, 2, 2);
goldGPU->Set2D(1.0F, 3, 3);
/* call CrossEntropy function */
_CrossEntropyManual(outputGPU, goldGPU, lossGPU, weightGPU);
/* check results */
gpuTest = lossGPU->CheckData(answer, 4, 1e-4F);
/* destroy variables */
delete output;
delete gold;
delete loss;
delete weight;
delete outputGPU;
delete goldGPU;
delete lossGPU;
delete weightGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete output;
delete gold;
delete loss;
delete weight;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 4: test CrossEntropy function.
loss = sum_{i} (-t_i * log(y_i))
where t_i is the gold standard and y_i is the model output.
*/
bool TestCrossEntropy4()
{
/* a tensor of size (10, 1) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 10;
dimSize[1] = 1;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
/* CPU test */
bool cpuTest = true;
DTYPE answer = 0.0F;
DTYPE error;
/* create tensors */
XTensor * output = NewTensor(order, dimSize);
XTensor * gold = NewTensor(order, dimSize);
/* initialize variables */
output->SetZeroAll();
gold->SetZeroAll();
_ScaleAndShiftMe(output, 1, 1);
_ScaleAndShiftMe(gold, 1, 2);
/* call CrossEntropy function */
error = _CrossEntropyManual(output, gold);
/* check results */
cpuTest = (fabs(error - answer) < 1e-4);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * outputGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * goldGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
/* Initialize variables */
outputGPU->SetZeroAll();
goldGPU->SetZeroAll();
_ScaleAndShiftMe(outputGPU, 1, 1);
_ScaleAndShiftMe(goldGPU, 1, 2);
/* call CrossEntropy function */
error = _CrossEntropyManual(outputGPU, goldGPU);
/* check results */
gpuTest = (fabs(error - answer) < 1e-4);
/* destroy variables */
delete output;
delete gold;
delete outputGPU;
delete goldGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete output;
delete gold;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for CrossEntropy Function */
bool TestCrossEntropy()
{
XPRINT(0, stdout, "[TEST CrossEntropy] compute the cross entropy loss and backward gradient \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestCrossEntropy1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestCrossEntropy2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* case 3 test */
caseFlag = TestCrossEntropy3();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 3 failed!\n");
}
else
XPRINT(0, stdout, ">> case 3 passed!\n");
/* case 4 test */
caseFlag = TestCrossEntropy4();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 4 failed!\n");
}
else
XPRINT(0, stdout, ">> case 4 passed!\n");
///* other cases test */
///*
//TODO!!
//*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-17
*/
#ifndef __TEST_CROSSENTROPY_H__
#define __TEST_CROSSENTROPY_H__
#include "../function/CrossEntropy.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for CrossEntropy Function */
bool TestCrossEntropy();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_CROSSENTROPY_H__
......@@ -259,7 +259,7 @@ bool TestDivDim2()
/* test for DivDim Function */
bool TestDivDim()
{
XPRINT(0, stdout, "[TEST DIVDIM] tensor division c(i) = a/b + \alpha * c by broadcasting\n");
XPRINT(0, stdout, "[TEST DIVDIM] tensor division c(i) = a/b + \\alpha * c by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-18
*/
#include "TGather.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: gather indexed sub-tensors
In this case, (3, 2, 3) -> (3, 2, 2), dim = 2,
srcIndex = [0, 2], indexSize = 2
*/
bool TestGather1()
{
/* a input tensor of size (3, 2, 3) */
int sOrder = 3;
int * sDimSize = new int[sOrder];
sDimSize[0] = 3;
sDimSize[1] = 2;
sDimSize[2] = 3;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
/* a output tensor of size (3, 2, 2) */
int tOrder = 3;
int * tDimSize = new int[tOrder];
tDimSize[0] = 3;
tDimSize[1] = 2;
tDimSize[2] = 2;
int tUnitNum = 1;
for (int i = 0; i < tOrder; i++)
tUnitNum *= tDimSize[i];
DTYPE sData[3][2][3] = { { {0.0F, -1.0F, 2.0F},
{2.0F, 1.0F, 3.0F} },
{ {1.0F, 2.0F, 4.0F},
{3.0F, 1.0F, 2.0F}},
{ {-1.0F, 3.0F, 2.0F},
{1.0F, -1.0F, 0.0F} } };
DTYPE answer[3][2][2] = { { {0.0F, 2.0F},
{2.0F, 3.0F} },
{ {1.0F, 4.0F},
{3.0F, 2.0F}},
{ {-1.0F, 2.0F},
{1.0F, 0.0F} } };
int dim = 2;
int indexSize = 2;
int srcIndex[2] = {0, 2};
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * t = NewTensor(tOrder, tDimSize);
XTensor tUser;
/* initialize variables */
s->SetData(sData, sUnitNum);
t->SetZeroAll();
/* call Gather function */
_Gather(s, t, dim, srcIndex, indexSize);
tUser = Gather(*s, dim, srcIndex, indexSize);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum) && tUser.CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * tGPU = NewTensor(sOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor tUserGPU;
/* initialize variables */
sGPU->SetData(sData, sUnitNum);
tGPU->SetZeroAll();
/* call Gather function */
_Gather(sGPU, tGPU, dim, srcIndex, indexSize);
tUserGPU = Gather(*sGPU, dim, srcIndex, indexSize);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum) && tUserGPU.CheckData(answer, tUnitNum);
/* destroy variables */
delete s;
delete t;
delete sGPU;
delete tGPU;
delete[] sDimSize;
delete[] tDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete t;
delete[] sDimSize;
delete[] tDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: gather indexed sub-tensors
In this case, (3, 2, 3) -> (3, 1, 3), dim = 1,
srcIndex = [0], indexSize = 1
*/
bool TestGather2()
{
/* a input tensor of size (3, 2, 3) */
int sOrder = 3;
int * sDimSize = new int[sOrder];
sDimSize[0] = 3;
sDimSize[1] = 2;
sDimSize[2] = 3;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
/* a output tensor of size (3, 1, 3) */
int tOrder = 3;
int * tDimSize = new int[tOrder];
tDimSize[0] = 3;
tDimSize[1] = 1;
tDimSize[2] = 3;
int tUnitNum = 1;
for (int i = 0; i < tOrder; i++)
tUnitNum *= tDimSize[i];
DTYPE sData[3][2][3] = { { {0.0F, -1.0F, 2.0F},
{2.0F, 1.0F, 3.0F} },
{ {1.0F, 2.0F, 4.0F},
{3.0F, 1.0F, 2.0F}},
{ {-1.0F, 3.0F, 2.0F},
{1.0F, -1.0F, 0.0F} } };
DTYPE answer[3][1][3] = { { {0.0F, -1.0F, 2.0F} },
{ {1.0F, 2.0F, 4.0F} } ,
{ {-1.0F, 3.0F, 2.0F} } };
int dim = 1;
int indexSize = 1;
int srcIndex[2] = {0};
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * t = NewTensor(tOrder, tDimSize);
XTensor tUser;
/* initialize variables */
s->SetData(sData, sUnitNum);
t->SetZeroAll();
/* call Gather function */
_Gather(s, t, dim, srcIndex, indexSize);
tUser = Gather(*s, dim, srcIndex, indexSize);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum) && tUser.CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * tGPU = NewTensor(sOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor tUserGPU;
/* initialize variables */
sGPU->SetData(sData, sUnitNum);
tGPU->SetZeroAll();
/* call Gather function */
_Gather(sGPU, tGPU, dim, srcIndex, indexSize);
tUserGPU = Gather(*sGPU, dim, srcIndex, indexSize);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum) && tUserGPU.CheckData(answer, tUnitNum);
/* destroy variables */
delete s;
delete t;
delete sGPU;
delete tGPU;
delete[] sDimSize;
delete[] tDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete t;
delete[] sDimSize;
delete[] tDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Gather Function */
bool TestGather()
{
XPRINT(0, stdout, "[TEST Gather] gather indexed sub-tensors \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestGather1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestGather2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-18
*/
#ifndef __TEST_GATHER_H__
#define __TEST_GATHER_H__
#include "../core/movement/Gather.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Gather Function */
bool TestGather();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_GATHER_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-04-30
*/
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-04-30
*/
#ifndef __TEST_LOSS_H__
#define __TEST_LOSS_H__
......
......@@ -251,7 +251,7 @@ bool TestMultiplyDim2()
/* test for MultiplyDim Function */
bool TestMultiplyDim()
{
XPRINT(0, stdout, "[TEST MULTIPLYDIM] tensor multiplication c = a * b + \alpha * c by broadcasting\n");
XPRINT(0, stdout, "[TEST MULTIPLYDIM] tensor multiplication c = a * b + \\alpha * c by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#include "TReduceSumAll.h"
#include <math.h>
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test ReduceSumAll function
sum all the items of the tensor
*/
bool TestReduceSumAll1()
{
/* a tensor of size (2, 4) */
int sOrder = 2;
int * sDimSize = new int[sOrder];
sDimSize[0] = 2;
sDimSize[1] = 4;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
DTYPE sData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE summation;
DTYPE answer = 28.0F;
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
/* initialize variables */
s->SetData(sData, sUnitNum);
/* call ReduceSumAll function */
summation = _ReduceSumAll(s);
/* check results */
cpuTest = (fabs(answer - summation) < 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
/* initialize variables */
sGPU->SetData(sData, sUnitNum);
/* call ReduceSumAll function */
summation = _ReduceSumAll(sGPU);
/* check results */
gpuTest = (fabs(answer - summation) < 1e-4F);
/* destroy variables */
delete s;
delete sGPU;
delete[] sDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete[] sDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for ReduceSumAll Function */
bool TestReduceSumAll()
{
XPRINT(0, stdout, "[TEST ReduceSumAll] sum the items along a dimension of the tensor.\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestReduceSumAll1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-27
*/
#ifndef __TEST_REDUCESUMALL_H__
#define __TEST_REDUCESUMALL_H__
#include "../core/reduce/ReduceSumAll.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for ReduceSumAll Function */
bool TestReduceSumAll();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_REDUCESUMALL_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-06
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-06
*/
#include "TSetData.h"
#include "../core/getandset/SetData.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -80,6 +81,327 @@ bool TestSetData1()
#endif // USE_CUDA
}
/*
case 2: test SetDataIndexed function.
modify data items along with a given dimension.
*/
bool TestSetData2()
{
/* a input tensor of size (2, 4) */
int sOrder = 2;
int * sDimSize = new int[sOrder];
sDimSize[0] = 2;
sDimSize[1] = 4;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
/* a data tensor of size (4) for GPU test */
int dataOrder = 1;
int * dataDimSize = new int[dataOrder];
dataDimSize[0] = 4;
int dataUnitNum = 1;
for (int i = 0; i < dataOrder; i++)
dataUnitNum *= dataDimSize[i];
DTYPE data[4] = {0.0F, 1.0F, 2.0F, 3.0F};
DTYPE answer[2][4] = { {1.0F, 1.0F, 1.0F, 1.0F},
{0.0F, 1.0F, 2.0F, 3.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * modify = NewTensor(dataOrder, dataDimSize);
/* Initialize variables */
_SetDataFixedFloat(s, 1.0F);
modify->SetData(data, dataUnitNum);
/* call SetDataIndexed function */
_SetDataIndexed(s, modify, 0, 1);
/* check results */
cpuTest = s->CheckData(answer, sUnitNum, 1e-5F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * modifyGPU = NewTensor(dataOrder, dataDimSize, X_FLOAT, 1.0F, 0);
/* Initialize variables */
_SetDataFixedFloat(sGPU, 1.0F);
modifyGPU->SetData(data, dataUnitNum);
/* call SetDataIndexed function */
_SetDataIndexed(sGPU, modifyGPU, 0, 1);
gpuTest = sGPU->CheckData(answer, sUnitNum, 1e-5F);
/* destroy variables */
delete s;
delete modify;
delete sGPU;
delete modifyGPU;
delete[] sDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete modify;
delete[] sDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 3: test SetDataIndexed function.
modify data items along with a given dimension.
*/
bool TestSetData3()
{
/* a input tensor of size (2, 4, 3) */
int sOrder = 3;
int * sDimSize = new int[sOrder];
sDimSize[0] = 2;
sDimSize[1] = 4;
sDimSize[2] = 3;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
/* a data tensor of size (2, 3) for GPU test */
int dataOrder = 2;
int * dataDimSize = new int[dataOrder];
dataDimSize[0] = 2;
dataDimSize[1] = 3;
int dataUnitNum = 1;
for (int i = 0; i < dataOrder; i++)
dataUnitNum *= dataDimSize[i];
DTYPE data[2][3] = { {0.0F, 1.0F, 2.0F},
{3.0F, 4.0F, 5.0F} };
DTYPE answer[2][4][3] = { { {1.0F, 1.0F, 1.0F},
{0.0F, 1.0F, 2.0F},
{1.0F, 1.0F, 1.0F},
{1.0F, 1.0F, 1.0F} },
{ {1.0F, 1.0F, 1.0F},
{3.0F, 4.0F, 5.0F},
{1.0F, 1.0F, 1.0F},
{1.0F, 1.0F, 1.0F} } };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * modify = NewTensor(dataOrder, dataDimSize);
/* Initialize variables */
_SetDataFixedFloat(s, 1.0F);
modify->SetData(data, dataUnitNum);
/* call SetDataIndexed function */
_SetDataFixedFloat(s, 1.0F);
_SetDataIndexed(s, modify, 1, 1);
/* check results */
cpuTest = s->CheckData(answer, sUnitNum, 1e-5F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * modifyGPU = NewTensor(dataOrder, dataDimSize, X_FLOAT, 1.0F, 0);
/* Initialize variables */
_SetDataFixedFloat(sGPU, 1.0F);
modifyGPU->SetData(data, dataUnitNum);
/* call SetDataIndexed function */
_SetDataIndexed(sGPU, modifyGPU, 1, 1);
gpuTest = sGPU->CheckData(answer, sUnitNum, 1e-5F);
/* destroy variables */
delete s;
delete modify;
delete sGPU;
delete modifyGPU;
delete[] sDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete modify;
delete[] sDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 4: test SetDataDim function.
set data items along with a given dimension (and keep the remaining items unchanged)
*/
bool TestSetData4()
{
/* a input tensor of size (3, 3) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 3;
dimSize[1] = 3;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE sData[3][3] = { {1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F},
{7.0F, 8.0F, 9.0F} };
DTYPE answer[3][3] = { {1.0F, 2.0F, 3.0F},
{0.0F, 0.0F, 0.0F},
{7.0F, 8.0F, 9.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(order, dimSize);
/* initialize variables */
s->SetData(sData, unitNum);
/* call _SetDataDim function */
_SetDataDim(s, 1, 1, 0, 0);
/* check results */
cpuTest = s->CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
/* initialize variables */
sGPU->SetData(sData, unitNum);
/* call _SetDataDim function */
_SetDataDim(sGPU, 1, 1, 0, 0);
gpuTest = sGPU->CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete s;
delete sGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 5: test SetDataDim function.
set data items along with a given dimension (and keep the remaining items unchanged)
*/
bool TestSetData5()
{
/* a input tensor of size (2, 4, 3) */
int order = 3;
int * dimSize = new int[order];
dimSize[0] = 2;
dimSize[1] = 4;
dimSize[2] = 3;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE data[2][4][3] = { { {1.0F, 1.0F, 1.0F},
{0.0F, 1.0F, 2.0F},
{1.0F, 1.0F, 1.0F},
{1.0F, 1.0F, 1.0F} },
{ {1.0F, 1.0F, 1.0F},
{3.0F, 4.0F, 5.0F},
{1.0F, 1.0F, 1.0F},
{1.0F, 1.0F, 1.0F} } };
DTYPE answer[2][4][3] = { { {1.0F, 1.0F, 1.0F},
{0.0F, 1.0F, 2.0F},
{5.0F, 5.0F, 5.0F},
{1.0F, 1.0F, 1.0F} },
{ {1.0F, 1.0F, 1.0F},
{3.0F, 4.0F, 5.0F},
{5.0F, 5.0F, 5.0F},
{1.0F, 1.0F, 1.0F} } };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(order, dimSize);
/* initialize variables */
s->SetData(data, unitNum);
/* call _SetDataDim function */
_SetDataDim(s, 2, 1, 1, 5.0F);
/* check results */
cpuTest = s->CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
/* initialize variables */
sGPU->SetData(data, unitNum);
/* call _SetDataDim function */
_SetDataDim(sGPU, 2, 1, 1, 5.0F);
gpuTest = sGPU->CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete s;
delete sGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
......@@ -100,6 +422,42 @@ bool TestSetData()
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestSetData2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* case 3 test */
caseFlag = TestSetData3();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 3 failed!\n");
}
else
XPRINT(0, stdout, ">> case 3 passed!\n");
/* case 4 test */
caseFlag = TestSetData4();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 4 failed!\n");
}
else
XPRINT(0, stdout, ">> case 4 passed!\n");
/* case 5 test */
caseFlag = TestSetData5();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 5 failed!\n");
}
else
XPRINT(0, stdout, ">> case 5 passed!\n");
/* other cases test */
/*
TODO!!
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#include "TSpread.h"
#include "../core/getandset/SetData.h"
#include "../core/movement/Spread.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test _Spread function.
spread a collection tensor to source tensor.
*/
bool TestSpread1()
{
/* a input tensor of size (2, 4, 3) */
int sOrder = 3;
int * sDimSize = new int[sOrder];
sDimSize[0] = 4;
sDimSize[1] = 4;
sDimSize[2] = 3;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
/* a data tensor of size (2, 4, 3) */
int dataOrder = 3;
int * dataDimSize = new int[dataOrder];
dataDimSize[0] = 2;
dataDimSize[1] = 4;
dataDimSize[2] = 3;
int dataUnitNum = 1;
for (int i = 0; i < dataOrder; i++)
dataUnitNum *= dataDimSize[i];
int srcIndex[2] = {0, 1};
int tgtIndex[2] = {0, 1};
DTYPE data[2][4][3] = { { {1.0F, 1.0F, 1.0F},
{0.0F, 1.0F, 2.0F},
{1.0F, 1.0F, 1.0F},
{1.0F, 1.0F, 1.0F} },
{ {1.0F, 1.0F, 1.0F},
{3.0F, 4.0F, 5.0F},
{1.0F, 1.0F, 1.0F},
{1.0F, 1.0F, 1.0F} } };
DTYPE answer[4][4][3] = { { {1.0F, 1.0F, 1.0F},
{0.0F, 1.0F, 2.0F},
{1.0F, 1.0F, 1.0F},
{1.0F, 1.0F, 1.0F} },
{ {1.0F, 1.0F, 1.0F},
{3.0F, 4.0F, 5.0F},
{1.0F, 1.0F, 1.0F},
{1.0F, 1.0F, 1.0F} },
{ {0.0F, 0.0F, 0.0F},
{0.0F, 0.0F, 0.0F},
{0.0F, 0.0F, 0.0F} },
{ {0.0F, 0.0F, 0.0F},
{0.0F, 0.0F, 0.0F},
{0.0F, 0.0F, 0.0F} },
};
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * modify = NewTensor(dataOrder, dataDimSize);
/* Initialize variables */
_SetDataFixedFloat(s, 0.0F);
modify->SetData(data, dataUnitNum);
/* call _Spread function */
_Spread(s, modify, 0, srcIndex, 2, tgtIndex);
/* check results */
cpuTest = s->CheckData(answer, sUnitNum, 1e-5F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * modifyGPU = NewTensor(dataOrder, dataDimSize, X_FLOAT, 1.0F, 0);
/* Initialize variables */
_SetDataFixedFloat(sGPU, 0.0F);
modifyGPU->SetData(data, dataUnitNum);
/* call _Spread function */
_Spread(sGPU, modifyGPU, 0, srcIndex, 2, tgtIndex);
gpuTest = sGPU->CheckData(answer, sUnitNum, 1e-5F);
/* destroy variables */
delete s;
delete modify;
delete sGPU;
delete modifyGPU;
delete[] sDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete[] sDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Spread Function */
bool TestSpread()
{
XPRINT(0, stdout, "[TEST Spread] spread a collection tensor to source tensor \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestSpread1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-09-25
*/
#ifndef __TEST_SPREAD_H__
#define __TEST_SPREAD_H__
#include "../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Spread Function */
bool TestSpread();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_SPREAD_H__
\ No newline at end of file
......@@ -28,7 +28,7 @@ bool Test()
{
bool wrong = false;
XPRINT(0, stdout, "Testing the XTensor utilites ... \n\n");
wrong = !TestAbsolute() || wrong;
wrong = !TestClip() || wrong;
wrong = !TestConcatenate() || wrong;
......@@ -40,6 +40,7 @@ bool Test()
wrong = !TestDiv() || wrong;
wrong = !TestDivDim() || wrong;
wrong = !TestExp() || wrong;
wrong = !TestGather() || wrong;
wrong = !TestLog() || wrong;
wrong = !TestMatrixMul() || wrong;
wrong = !TestMatrixMul2D() || wrong;
......@@ -54,6 +55,7 @@ bool Test()
wrong = !TestReduceMax() || wrong;
wrong = !TestReduceMean() || wrong;
wrong = !TestReduceSum() || wrong;
wrong = !TestReduceSumAll() || wrong;
wrong = !TestReduceSumSquared() || wrong;
wrong = !TestReduceVariance() || wrong;
wrong = !TestRound() || wrong;
......@@ -75,7 +77,8 @@ bool Test()
//wrong = !TestTopK() || wrong;
wrong = !TestUnsqueeze() || wrong;
wrong = !TestXMem() || wrong;
wrong = !TestCrossEntropy() || wrong;
wrong = !TestDropout() || wrong;
wrong = !TestHardTanH() || wrong;
wrong = !TestIdentity() || wrong;
......
......@@ -33,6 +33,7 @@
#include "TDiv.h"
#include "TDivDim.h"
#include "TExp.h"
#include "TGather.h"
#include "TLog.h"
#include "TMatrixMul.h"
#include "TMatrixMul2D.h"
......@@ -47,6 +48,7 @@
#include "TReduceMax.h"
#include "TReduceMean.h"
#include "TReduceSum.h"
#include "TReduceSumAll.h"
#include "TReduceSumSquared.h"
#include "TReduceVariance.h"
#include "TRound.h"
......@@ -69,6 +71,7 @@
#include "TUnsqueeze.h"
#include "TXMem.h"
#include "TCrossEntropy.h"
#include "TDropout.h"
#include "THardTanH.h"
#include "TIdentity.h"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论