Commit 06e95a0a by 张裕浩

update code and fix reduceSum bug

parent fa2ed07c
......@@ -25,23 +25,55 @@
#include "../tensor/core/CHeader.h"
#include "../sample/fnnlm/FNNLM.h"
#include "../tensor/test/Test.h"
#include <cuda_runtime.h>
#include <time.h>
#include <windows.h>
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
//#include <crtdbg.h>
using namespace nts;
using namespace samplefnnlm;
using namespace fnnlm;
void SetDataTest()
{
int * dimSize = new int[2];
dimSize[0] = 10000;
dimSize[1] = 1000;
XTensor b1(2, dimSize, X_FLOAT, 1.0F, 0, NULL);
XTensor b2(2, dimSize, X_FLOAT, 1.0F, 0, NULL);
XTensor b3(2, dimSize, X_FLOAT, 1.0F, -1, NULL);
DWORD m_start_time;
DWORD m_end_time;
double time_diff = 0.0;
m_start_time = GetTickCount();
_SetDataRand(&b1, -2.0F, 2.0F);
cudaThreadSynchronize();
m_end_time = GetTickCount();
time_diff = m_end_time - m_start_time;
printf("time %f ms\n", time_diff);
m_start_time = GetTickCount();
_SetDataRand(&b3, -2.0F,2.0F);
cudaThreadSynchronize();
m_end_time = GetTickCount();
time_diff = m_end_time - m_start_time;
printf("time %f ms\n", time_diff);
delete[] dimSize;
}
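The timing harness above is Windows-only (DWORD, GetTickCount, windows.h). A portable sketch of the same test, assuming only the XTensor constructor and _SetDataRand shown above (hypothetical helper name, not part of the commit):

#include <chrono>

void SetDataTestPortable()
{
    int dimSize[2] = {10000, 1000};
    XTensor a(2, dimSize, X_FLOAT, 1.0F, 0, NULL);   /* tensor on GPU 0 */
    XTensor b(2, dimSize, X_FLOAT, 1.0F, -1, NULL);  /* tensor on the CPU */

    auto start = std::chrono::steady_clock::now();
    _SetDataRand(&a, -2.0F, 2.0F);
    cudaDeviceSynchronize();                         /* wait for the GPU kernel before stopping the clock */
    auto end = std::chrono::steady_clock::now();
    printf("GPU time %f ms\n", std::chrono::duration<double, std::milli>(end - start).count());

    start = std::chrono::steady_clock::now();
    _SetDataRand(&b, -2.0F, 2.0F);
    end = std::chrono::steady_clock::now();
    printf("CPU time %f ms\n", std::chrono::duration<double, std::milli>(end - start).count());
}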
int main( int argc, const char ** argv )
{
//if(argc > 1 && !strcmp(argv[1], "-test"))
if(argc > 1 && !strcmp(argv[1], "-test"))
Test();
/*else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
//SetDataTest();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1);
else{
/*else{
fprintf(stderr, "Thanks for using NiuTrans.Network! This is a library for building\n");
fprintf(stderr, "neural networks in an easy way. \n\n");
fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
......
......@@ -50,6 +50,7 @@ void XFuncGrad::MakeGrad(XTensor * node)
_IdentityBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
else if(operID == FUNC_LOGSOFTMAX){
int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in logsoftmax!");
_LogSoftmaxBackward(NULL, output, input, output->grad, input->grad, leadDim, NOLOSS);
}
else if(operID == FUNC_RECTIFY)
......@@ -58,11 +59,14 @@ void XFuncGrad::MakeGrad(XTensor * node)
_SigmoidBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
else if(operID == FUNC_SOFTMAX){
int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in softmax!");
_SoftmaxBackward(NULL, output, input, output->grad, input->grad, leadDim, NOLOSS);
}
else{
ShowNTErrors("Wrong activation function type!");
}
node->visitMark = NODE_FINISHED;
}
/* indicates whether the node is for an activation function */
......
......@@ -44,13 +44,105 @@ private:
static
void GradSum(XTensor * node);
/* gradient for multiply (dot production): c = a * b */
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSumDim(XTensor * node);
/* gradient for multiply (dot product): c = a * b * \alpha */
static
void GradMultiply(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) */
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
XTensor * dedc, DTYPE alpha);
/* gradient for matrix multiply in batch mode.
for each batch: c_i = matmul(a_i, b_i) * \alpha */
static
void GradMatrixMulBatched(XTensor * node);
/* gradient for log: c = log(a) */
static
void GradLog(XTensor * node);
/* gradient for power */
static
void GradPower(XTensor * node);
/* gradient for negate */
static
void GradNegate(XTensor * node);
/* gradient for ScaleAndShift */
static
void GradScaleAndShift(XTensor * node);
/* gradient for Minus */
static
void GradSub(XTensor * node);
/* gradient for Divide */
static
void GradDiv(XTensor * node);
/* gradient for reduceMean */
static
void GradReduceMean(XTensor * node);
/* gradient for reduceSum */
static
void GradReduceSum(XTensor * node);
/* gradient for reduceSumSquared */
static
void GradReduceSumSquared(XTensor * node);
/* gradient for reduceVariance */
static
void GradReduceVariance(XTensor * node);
/* gradient for sin */
static
void GradSin(XTensor * node);
/* gradient for cos */
static
void GradCos(XTensor * node);
/* gradient for tan */
static
void GradTan(XTensor * node);
/* gradient for exp */
static
void GradExp(XTensor * node);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
/* gradient for absolute */
static
void GradAbsolute(XTensor * node);
/* gradient for sign */
static
void GradSign(XTensor * node);
/* gradient for clip */
static
void GradClip(XTensor * node);
/* gradient for round */
static
void GradRound(XTensor * node);
};
}
......
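For the new GradSumDim declared above (c = a + b * \beta, where b matches one dimension of a and is broadcast over the rest), the expected gradients follow from the chain rule; this is a sketch of the math, not a statement about the implementation:

\[
\frac{\partial E}{\partial a} = \frac{\partial E}{\partial c}, \qquad
\frac{\partial E}{\partial b} = \beta \sum_{\text{dims of } a \text{ not covered by } b} \frac{\partial E}{\partial c}
\]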
......@@ -43,6 +43,12 @@ void XShapeGrad::MakeGrad(XTensor * node)
GradMergeList(node);
else if(operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node);
else if(operID == SHAPE_SPLIT)
GradSplit(node);
else if(operID == SHAPE_SPLIT_LIST)
GradSplitList(node);
else if (operID == SHAPE_TRANSPOSE)
GradTranspose(node);
else{
ShowNTErrors("TODO!");
}
......@@ -55,6 +61,13 @@ bool XShapeGrad::IsShapeOP(XTensor * node)
return (income.typeID & DATA_BASE) != 0;
}
/* post processing of a node */
void XShapeGrad::PostProcessing(XTensor * node, int typeID)
{
if(typeID == SHAPE_SPLIT_LIST)
GradSplitListPost(node);
}
/*
gradient for merge
for
......@@ -134,6 +147,8 @@ void XShapeGrad::GradMerge(XTensor * node)
gradInputSmall.data = NULL;
delete[] dims;
node->visitMark = NODE_FINISHED;
}
/*
......@@ -213,6 +228,120 @@ void XShapeGrad::GradMergeList(XTensor * node)
gradSmall.data = NULL;
delete[] dims;
}
node->visitMark = NODE_FINISHED;
}
/*
gradient computation for split:
for
c = split(a)
we have
dE/da = merge(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradSplit(XTensor * node)
{
XLink &income = node->income;
XTensor * input = income.tails[0];
int whereToSplit = income.GetParamInt(0);
int splitNum = income.GetParamInt(1);
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
CheckNTErrors(splitNum == node->dimSize[0], "Wrong split number!");
XNoder::MakeGrad(input);
/* we can simply merge the gradient tensor
if the input is used in splitting only */
if(input->outgo.tailNum == 1)
_Merge(node->grad, input->grad, whereToSplit + 1, 0);
/* if the tensor is used somewhere else, we need another SUM
for gradient accumulation */
else{
XTensor inputGradTMP(input);
_Merge(node->grad, &inputGradTMP, whereToSplit + 1, 0);
_Sum(input->grad, &inputGradTMP, input->grad);
}
node->visitMark = NODE_FINISHED;
}
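A concrete instance of the rule above: if a is 4 * 6 and is split along dimension 1 into 3 parts, then c is 3 * 4 * 2 and

\[
\frac{\partial E}{\partial a} = \mathrm{merge}\!\left(\frac{\partial E}{\partial c}\right),
\]

which is why the code merges along whereToSplit + 1 (the split dimension shifted by the new leading axis) and only adds a SUM when a has other consumers.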
/*
gradient computation for splitting
where we return the list of the splits
for
list(c_1, ...) = split(a)
we have
dE/da = merge(dE/dc_1, ...)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradSplitList(XTensor * node)
{
XLink &income = node->income;
XTensor * input = income.tails[0];
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
node->visitMark = NODE_DOING;
}
/*
gradient computation for splitting. We return
the list of the splits: list(c_1, ...) = split(a).
This method is called only when all nodes of splitting
have been processed. We do this in a post-processing
manner because we can fuse multiple memory copy jobs
at one time. This is good for speeding up the system.
>> node - the node (a) for backward computation
*/
void XShapeGrad::GradSplitListPost(XTensor * node)
{
/* we compute the gradient for the current node, rather than for
its child nodes, i.e., we use the outgoing edges here */
XLink &outgo = node->outgo;
XList splits(outgo.tailNum);
int whereToSplit = -1;
int splitNum = 0;
for(int i = 0; i < outgo.tailNum; i++){
XTensor * parent = (XTensor*)outgo.tails[i];
XLink &income = parent->income;
if(income.typeID == SHAPE_SPLIT_LIST){
int w = income.GetParamInt(0);
int splitID = income.GetParamInt(1);
if(whereToSplit < 0)
whereToSplit = w;
splitNum++;
CheckNTErrors(whereToSplit == w, "Wrong dimension for splitting!");
CheckNTErrors(income.tailNum == 1, "Something wrong with outgoing edge!");
CheckNTErrors(splitNum - 1 == splitID, "Wrong split id!");
splits.Add(parent);
}
}
/* we can simply merge the gradient tensor
if the node is used in splitting only */
if(outgo.tailNum == splitNum){
_Merge(&splits, node->grad, whereToSplit + 1);
}
/* if the tensor is used as input to other nodes
somewhere else, we need another SUM for gradient
accumulation */
else{
XTensor nodeGradTMP(node);
_Merge(&splits, &nodeGradTMP, whereToSplit + 1);
_Sum(node->grad, &nodeGradTMP, node->grad);
}
}
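In formula form, for list(c_1, ..., c_n) = split(a) the post-processing computes

\[
\frac{\partial E}{\partial a} = \mathrm{merge}\!\left(\frac{\partial E}{\partial c_1}, \dots, \frac{\partial E}{\partial c_n}\right)
\]

(accumulated with a SUM when a also feeds other operations), and it does so once, after all c_i have been visited, so the n gradient copies are fused into a single merge job.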
/*
......@@ -239,6 +368,40 @@ void XShapeGrad::GradUnsqueeze(XTensor * node)
CheckNTErrors(output->unitNum == input->unitNum * dSize, "Wrong tensor size!");
_ReduceSum(output->grad, input->grad, dim);
node->visitMark = NODE_FINISHED;
}
/*
gradient for transposing a tensor
for
c = Transpose(a)
we have
dE/da = Transpose(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradTranspose(XTensor * node)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for TRANSPOSE!");
XTensor * output = node;
XTensor * input = income.tails[0];
XTensor * b = NewTensor(input);
XNoder::MakeGrad(input);
int i = income.GetParamInt(0);
int j = income.GetParamInt(1);
CheckNTErrors(input->order > i && i >= 0, "index of dimension is out of scope!");
CheckNTErrors(input->order > j && j >= 0, "index of dimension is out of scope!");
_Transpose(output->grad, b, i, j);
_Sum(input->grad, b, input->grad);
node->visitMark = NODE_FINISHED;
delete b;
}
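For the two swapped dimensions i and j, the rule above reads

\[
\left(\frac{\partial E}{\partial a}\right)_{\dots p \dots q \dots} = \left(\frac{\partial E}{\partial c}\right)_{\dots q \dots p \dots},
\]

i.e. the incoming gradient is transposed back over the same pair of dimensions and then accumulated into input->grad.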
}
\ No newline at end of file
......@@ -40,18 +40,41 @@ public:
static
bool IsShapeOP(XTensor * node);
/* post processing of a node */
static
void PostProcessing(XTensor * node, int typeId);
private:
/* gradient for merge: c = merge(a, b, ...) */
/* gradient computation for merge: c = merge(a, b, ...) */
static
void GradMerge(XTensor * node);
/* gradient for merging a list of tensors : c = merge(list(a, b, ...)) */
/* gradient computation for merging a list of tensors : c = merge(list(a, b, ...)) */
static
void GradMergeList(XTensor * node);
/* gradient for unsqueezing a tensor : c = unsqueeze(a) */
/* gradient computation for split: c = split(a) */
static
void GradSplit(XTensor * node);
/* gradient computation for splitting. We return the list of the splits: list(c_1, ...) = split(a) */
static
void GradSplitList(XTensor * node);
/* gradient computation for splitting. We return the list of the splits: list(c_1, ...) = split(a).
This method is called only when all nodes of splitting have been processed. We do this in a post-processing
manner because we can fuse multiple memory copy jobs at one time. This is good for speeding up the system. */
static
void GradSplitListPost(XTensor * node);
/* gradient computation for unsqueezing a tensor : c = unsqueeze(a) */
static
void GradUnsqueeze(XTensor * node);
/* gradient computation for transposing a tensor : c = transpose(a) */
static
void GradTranspose(XTensor * node);
};
}
......
......@@ -46,6 +46,11 @@ unsigned int MakeNetID()
return id;
}
void XNetClearAll()
{
MUTEX_DELE(netMutex);
}
/* constructor */
XNet::XNet()
{
......@@ -143,7 +148,7 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
/* back-propagation from output to input */
for(int i = nodes.count - 1; i >= 0; i--){
XTensor * node = (XTensor*)nodes.Get(i);
if(node->visitMark == NODE_FINISHED)
continue;
......@@ -176,6 +181,10 @@ void XNet::BackwardNode(XTensor * node)
return;
if(!XNoder::IsLeaf(node)){
/* post processing for parent nodes */
BackwardNodePost(node);
/* process the current node */
if(XMathGrad::IsMathOP(node))
XMathGrad::MakeGrad(node);
else if(XFuncGrad::IsFunc(node))
......@@ -186,8 +195,24 @@ void XNet::BackwardNode(XTensor * node)
ShowNTErrors("Wrong node type!");
}
}
}
/*
backward computation (in post processing) for a given node
>> node - the node whose parent nodes are not processed yet. So
we do the job at the child node.
*/
void XNet::BackwardNodePost(XTensor * node)
{
bool isSplitList = false;
XLink &outgo = node->outgo;
for(int i = 0; i < outgo.tailNum; i++){
if(outgo.tails[i]->income.typeID == SHAPE_SPLIT_LIST)
isSplitList = true;
}
node->visitMark = NODE_FINISHED;
if(isSplitList)
XShapeGrad::PostProcessing(node, SHAPE_SPLIT_LIST);
}
/*
......@@ -238,10 +263,11 @@ void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
if(node == NULL)
return;
//fprintf(stderr, "%d\n", node->id);
if(node->visitMark == code + 1){
ShowNTErrors("There is a circle in the network\n");
}
else if(node->visitMark <= code || node->visitMark >= code + 2){
else if(node->visitMark <= code){
node->visitMark = code + 1;
XLink &income = node->income;
for(int i = 0; i < income.tailNum; i++){
......
......@@ -73,6 +73,9 @@ struct XNet
/* backward computation for a given node */
void BackwardNode(XTensor * node);
/* backward computation (in post processing) for a given node */
void BackwardNodePost(XTensor * node);
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XTensor &root);
......@@ -92,6 +95,7 @@ struct XNet
extern unsigned int netIDGlobal;
extern MUTEX_HANDLE netMutex;
extern unsigned int MakeNetID();
extern void XNetClearAll();
}
......
......@@ -36,7 +36,7 @@
using namespace nts;
namespace samplefnnlm
namespace fnnlm
{
#define _EXIT_(x)// exit(x)
......@@ -126,7 +126,7 @@ struct FNNNet
XTensor output;
};
/* entry of the program */
/* entrance of the program */
int FNNLMMain(int argc, const char ** argv);
};
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TAttention.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TAttention::T2TAttention()
{
nhead = -1;
dk = -1;
dv = -1;
d = -1;
}
/* deconstructor */
T2TAttention::~T2TAttention()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamInt(argc, argv, "d", &dk, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &dv, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "attminmax", &minmax, 0.1F);
InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
InitTensor2D(&wq, d, dk, X_FLOAT, devID, mem);
InitTensor2D(&wv, d, dv, X_FLOAT, devID, mem);
float scale = 1.0F;
float finfoutk = (float)sqrt(6.0F * scale/(d + dk));
float finfoutv = (float)sqrt(6.0F * scale/(d + dv));
wk.SetDataRand(-finfoutk, finfoutk);
wq.SetDataRand(-finfoutk, finfoutk);
wv.SetDataRand(-finfoutv, finfoutv);
}
/*
make the network
>> k - keys. It might be of size B * L * H
where B = batch size, L = sequence length,
and H = vector size of each position
>> q - queries
>> v - values
<< return - multi-attention result
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
{
XTensor k2;
XTensor q2;
XTensor v2;
/* linear transformation before self-attention */
k2 = MMul(k, wk);
q2 = MMul(q, wq);
v2 = MMul(v, wv);
XTensor kheads;
XTensor qheads;
XTensor vheads;
/* multi head */
kheads = Split(k2, k2.order - 1, nhead);
qheads = Split(q2, q2.order - 1, nhead);
vheads = Split(v2, v2.order - 1, nhead);
XTensor att;
XTensor scalar;
/* scalar = softmax(Q * K^T / sqrt(dk)); att = scalar * V */
scalar = Softmax(Linear(BMMul(qheads, X_NOTRANS, kheads, X_TRANS), 1/(float)sqrt((float)dk)), -1);
att = BMMul(scalar, vheads);
/* concatenate the heads */
return Merge(att, att.order - 1);
}
}
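Summarizing what Make computes (matching the comment in T2TAttention.h): after the linear projections and the head-wise split, each head applies scaled dot-product attention and the heads are concatenated back:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V, \qquad
y = \mathrm{concat}(\mathrm{head}_1, \dots, \mathrm{head}_{nhead})
\]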
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TATTENTION_H__
#define __T2TATTENTION_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
/*
multi-head attention
y(Q, K, V) = cat(head_1, head_2, ..., head_n)
where head_i = Attention(Q * w_i^Q, K * w_i^K, V * w_i^V)
attention(Q, K, V) = softmax(Q * K^T/d_k^0.5) V
d_k = dimension size of K
*/
class T2TAttention
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* head number */
int nhead;
/* transformation matrix for K */
XTensor wk;
/* transformation matrix for Q */
XTensor wq;
/* transformation matrix for V */
XTensor wv;
/* size of transformed Q and K */
int dk;
/* size of transformed V */
int dv;
/* size of input Q, K and V */
int d;
public:
/* constructor */
T2TAttention();
/* de-constructor */
~T2TAttention();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v);
};
}
#endif
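A minimal, hypothetical usage sketch of this class (not part of the commit), assuming x is a B * L * d tensor coming from the embedder:

/* self-attention: keys, queries and values all come from x */
XTensor SelfAttentionExample(XTensor &x, int argc, const char ** argv)
{
    T2TAttention att;
    att.InitModel(argc, argv);    /* reads "-nhead", "-d" and "-attminmax" from the arguments */
    return att.Make(x, x, x);
}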
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TDECODER_H__
#define __T2TDECODER_H__
namespace transformer
{
class T2TDecoder
{
};
class AttDecoder : T2TDecoder
{
public:
/* initialize the model */
void InitModel(int argc, const char ** argv);
};
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-01
*/
#include <math.h>
#include "T2TEmbedding.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TEmbedder::T2TEmbedder()
{
devID = -1;
mem = NULL;
vSize = -1;
maxLength = -1;
}
/* deconstructor */
T2TEmbedder::~T2TEmbedder()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TEmbedder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "maxlen", &maxLength, 512);
LoadParamInt(argc, argv, "d", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID, mem);
w.SetDataRandn(0, 1.0F/(float)sqrt((float)eSize));
/* create the positional embedding matrix */
MakePosEmbedding(eSize, d, maxLength);
}
/*
make positional embeddings (of size eSize * length)
eSize - embedding size
length - length of the sequence
*/
void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
{
InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT, devID, mem);
float * data = new float[posEmbeddingBase.unitNum];
for(int pos = 0; pos < length; pos++){
float * dp = data + pos * eSize;
for(int k = 0; k < eSize; k++){
if(k % 2 == 0){
int i = k/2;
dp[k] = (float)sin(pos/pow(10000.0F, 2.0F*i/d));
}
else{
int i = (k - 1)/2;
dp[k] = (float)cos(pos/pow(10000.0F, 2.0F*i/d));
}
}
}
posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
delete[] data;
}
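The loop above fills the table with the usual sinusoidal encodings:

\[
PE(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad
PE(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)
\]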
/*
make the network
*/
XTensor T2TEmbedder::Make(XTensor &input)
{
CheckNTErrors(input.GetDim(-1) == vSize, "Wrong vocabulary size!");
CheckNTErrors(input.order > 1, "Wrong input tensor size!");
CheckNTErrors(input.dimSize[input.order - 2] < maxLength, "The sequence is too long!");
CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, input.dimSize, input.order * sizeof(int));
dims[input.order - 1] = eSize;
bool match = (posEmbedding.order == input.order);
if(match){
for(int i = 0; i < input.order; i++){
if(dims[i] != posEmbedding.GetDim(i))
match = false;
}
}
/* we make positional embeddings first */
if(!match){
InitTensor(&posEmbedding, input.order, dims, X_FLOAT, 1.0F, devID, mem);
XTensor * posTMP = NewTensorBuf(2, dims + 1, X_FLOAT, 1.0F, devID, mem);
_CopyValues(&posEmbeddingBase, 0, posTMP->unitNum, posTMP, 0);
_Unsqueeze(posTMP, &posEmbedding, 0, dims[0]);
DelTensorBuf(posTMP);
}
XTensor wordEmbedding;
/* then we make word embeddings */
wordEmbedding = Linear(MMul(input, w), (float)sqrt((float)d));
/* we sum the two embeddings */
return wordEmbedding + posEmbedding;
}
}
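Putting the two steps of Make together, for an input X (a one-hot representation, or a distribution, over the vocabulary) the embedder returns

\[
E = \sqrt{d}\,(X W) + PE,
\]

where W is the vSize * eSize word embedding matrix and PE is the positional table unsqueezed over the batch dimension.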
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-01
*/
#ifndef __T2TEMBEDDING_H__
#define __T2TEMBEDDING_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
#define DEFAULT_EMBEDDING_SIZE 512
/*
embedding (of word at position i):
word embedding + positional embedding
*/
class T2TEmbedder
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* vocabulary size */
int vSize;
/* embedding size */
int eSize;
/* maximum length of the sequence */
int maxLength;
/* dimension size of the hidden layers in the t2t model */
int d;
/* word embedding matrix */
XTensor w;
/* predefined positional embeddings. Re-loading them speeds up
the embedding processing. */
XTensor posEmbeddingBase;
/* positional embeddings */
XTensor posEmbedding;
public:
/* constructor */
T2TEmbedder();
/* de-constructor */
~T2TEmbedder();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make positional embeddings */
void MakePosEmbedding(int eSize, int d, int length);
/* make the network */
XTensor Make(XTensor &input);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TEncoder.h"
#include "T2TLayerNormal.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
AttEncoder::AttEncoder()
{
}
/* de-constructor */
AttEncoder::~AttEncoder()
{
delete[] attentions;
delete[] fnns;
delete[] attLayerNorms;
delete[] fnnLayerNorms;
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "vsize", &vSize, -1);
CheckNTErrors(nlayer >= 1, "We need at least one encoding layer!");
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\"");
/* embedding model */
embedder.InitModel(argc, argv, devID, mem);
attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer];
attLayerNorms = new T2TLN[nlayer];
fnnLayerNorms = new T2TLN[nlayer];
/* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
}
}
/*
make the encoding network
>> input - the input tensor of the encoder
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input)
{
XTensor x;
x = embedder.Make(input);
for(int i = 0; i < nlayer; i++){
XTensor att;
XTensor ln;
XTensor fnn;
XTensor res;
/* self attention */
att = attentions[i].Make(x, x, x);
/* residual connection */
res = Sum(att, x);
/* TODO: dropout */
/* layer normalization */
x = attLayerNorms[i].Make(res);
/* fnn */
fnn = fnns[i].Make(x);
/* residual connection */
res = Sum(fnn, x);
/* TODO: dropout */
/* layer normalization */
x = fnnLayerNorms[i].Make(res);
}
return x;
}
}
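Each pass through the loop in AttEncoder::Make is one encoder layer (dropout is still a TODO):

\[
x' = \mathrm{LN}\big(x + \mathrm{SelfAtt}(x)\big), \qquad
x'' = \mathrm{LN}\big(x' + \mathrm{FNN}(x')\big)
\]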
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TENCODER_H__
#define __T2TENCODER_H__
#include "T2TFNN.h"
#include "T2TAttention.h"
#include "T2TEmbedding.h"
#include "T2TLayerNormal.h"
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
/*
base class of the encoder
*/
class T2TEncoder
{
public:
virtual
XTensor Make(XTensor &input) = 0;
};
/*
the encoder based on RNN
*/
class RNNEncoder : T2TEncoder
{
public:
XTensor Make(XTensor &input);
};
/*
the encoder based on self-attention
*/
class AttEncoder : T2TEncoder
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* layer number */
int nlayer;
/* hidden layer size of the FNN layer */
int hSize;
/* embedding size */
int eSize;
/* vocabulary size */
int vSize;
/* embedding of word at each position */
T2TEmbedder embedder;
/* FNN model of each layer */
T2TFNN * fnns;
/* attention model of each layer */
T2TAttention * attentions;
/* layer normalization for fnn */
T2TLN * fnnLayerNorms;
/* layer normalization for attention */
T2TLN * attLayerNorms;
/* input tensor of the encoder */
XTensor * input;
/* output tensor of the encoder */
XTensor * output;
public:
/* constructor */
AttEncoder();
/* de-constructor */
~AttEncoder();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */
XTensor Make(XTensor &input);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TFNN.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/function/FHeader.h"
namespace transformer
{
/* constructor */
T2TFNN::T2TFNN()
{
inSize = -1;
outSize = -1;
hSize = -1;
}
/* deconstructor */
T2TFNN::~T2TFNN()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);
InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID, mem);
InitTensor1D(&b1, hSize, X_FLOAT, devID, mem);
InitTensor2D(&w2, hSize, outSize, X_FLOAT, devID, mem);
InitTensor1D(&b2, outSize, X_FLOAT, devID, mem);
float scale = 1.0F;
float finfout1 = (float)sqrt(6.0F * scale/(inSize + hSize));
float finfout2 = (float)sqrt(6.0F * scale/(hSize + outSize));
w1.SetDataRand(-finfout1, finfout1);
b1.SetZeroAll();
w2.SetDataRand(-finfout2, finfout2);
b2.SetZeroAll();
}
/*
make the network
y = max(0, x * w1 + b1) * w2 + b2
>> input - the input tensor
<< return - the output tensor
*/
XTensor T2TFNN::Make(XTensor &input)
{
XTensor t1;
/* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MMul(input, w1) + b1);
/* result = t1 * w2 + b2 */
return MMul(t1, w2) + b2;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TFNN_H__
#define __T2TFNN_H__
#include "../../tensor/XTensor.h"
using namespace nts;
namespace transformer
{
/* a fnn: y = max(0, x * w1 + b1) * w2 + b2 */
class T2TFNN
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* size of input vector */
int inSize;
/* size of output vector */
int outSize;
/* size of hidden layers */
int hSize;
/* matrix of transformation 1 */
XTensor w1;
/* bias of transformation 1 */
XTensor b1;
/* matrix of transformation 2 */
XTensor w2;
/* bias of transformation 2 */
XTensor b2;
public:
/* constructor */
T2TFNN();
/* deconstructor */
~T2TFNN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "T2TLayerNormal.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TLN::T2TLN()
{
devID = -1;
mem = NULL;
}
/* de-constructor */
T2TLN::~T2TLN()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TLN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
int d = 0;
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
InitTensor2D(&w, d, d, X_FLOAT, devID, mem);
InitTensor1D(&b, d, X_FLOAT, devID, mem);
float scale = 1.0F;
float finfout = (float)sqrt(6.0F * scale / (d + d));
w.SetDataRand(-finfout, finfout);
b.SetZeroAll();
}
/*
make the network
for each layer representation x, we have
y = norm(x) * w + b
where norm(x) = (x - mean)/standard deviation
>> input - the input tensor
<< return - layer normalization output
*/
XTensor T2TLN::Make(XTensor &input)
{
XTensor &x = input;
XTensor xn;
XTensor mean;
XTensor variance;
XTensor standard;
XTensor meanFilled;
XTensor standardFilled;
/* \mu = (sum_i x_i)/m */
mean = ReduceMean(x, x.order - 1);
/* \sigma^2 = (sum_i (x_i - \mu)^2)/m */
variance = ReduceVariance(x, x.order - 1, mean);
/* standard = sqrt(variance) */
standard = Power(variance, 0.5F);
/* unsqueeze mean and standard deviation to fit them into
the same shape as x */
meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1));
standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
/* x' = (x - \mu)/standard */
xn = (x - meanFilled)/standardFilled;
/* result = x' * w + b */
return MMul(xn, w) + b;
}
}
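In formula form, T2TLN::Make computes (note that w here is a d * d matrix rather than the usual element-wise gain):

\[
\mu = \frac{1}{d}\sum_{i} x_i, \qquad
\sigma^2 = \frac{1}{d}\sum_{i} (x_i - \mu)^2, \qquad
y = \frac{x - \mu}{\sigma}\, W + b
\]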
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TLAYERNORMAL_H__
#define __T2TLAYERNORMAL_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
/* layer normalization: y = norm(x) * w + b
where norm(x) = (x - mean)/standardDeviation */
class T2TLN
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* the transformation matrix w */
XTensor w;
/* the bias term b */
XTensor b;
public:
/* constructor */
T2TLN();
/* de-constructor */
~T2TLN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "T2TModel.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TModel::T2TModel()
{
devID = -1;
mem = NULL;
isLM = false;
isMT = false;
}
/* de-constructor */
T2TModel::~T2TModel()
{
delete mem;
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TModel::InitModel(int argc, const char ** argv)
{
bool useMem = false;
LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false);
if(useMem){
delete mem;
mem = new XMem(devID);
}
encoder.InitModel(argc, argv, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
}
/*
make the encoding network
>> input - input tensor
<< return - encoding result
*/
XTensor T2TModel::MakeEncoding(XTensor &input)
{
return encoder.Make(input);
}
/*
make the entire network (with the output softmax layer)
>> input - input tensor
>> output - output tensor (distribution)
*/
void T2TModel::Make(XTensor &input, XTensor &output)
{
XTensor encoding;
if(isLM){
encoding = MakeEncoding(input);
outputLayer.Make(encoding, output);
}
else{
ShowNTErrors("TODO!");
}
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TMODEL_H__
#define __T2TMODEL_H__
#include "T2TFNN.h"
#include "T2TAttention.h"
#include "T2TEncoder.h"
#include "T2TDecoder.h"
#include "T2TOutput.h"
namespace transformer
{
class T2TModel
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* the encoder */
AttEncoder encoder;
/* the decoder */
AttDecoder decoder;
/* output layer */
T2TOutput outputLayer;
/* indicates whether the model is running for language modeling */
bool isLM;
/* indicates whether the model is running for machine translation */
bool isMT;
public:
/* constructor */
T2TModel();
/* de-constructor */
~T2TModel();
/* initialize the model */
void InitModel(int argc, const char ** argv);
/* make the encoding network */
XTensor MakeEncoding(XTensor &input);
/* make the entire network (with the output softmax layer) */
void Make(XTensor &input, XTensor &output);
};
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TOutput.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TOutput::T2TOutput()
{
devID = -1;
mem = NULL;
vSize = -1;
inSize = -1;
hSize = -1;
}
/* de-constructor */
T2TOutput::~T2TOutput()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TOutput::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "outputminmax", &minmax, 0.08F);
InitTensor2D(&w, hSize, vSize, X_FLOAT, devID, mem);
float scale = 1.0F;
float finfout = (float)sqrt(6.0F * scale/(hSize + vSize));
w.SetDataRand(-finfout, finfout);
}
/*
make the network
y = logsoftmax(x * w)
>> input - input tensor
<< return - output tensor
*/
XTensor T2TOutput::Make(XTensor &input)
{
XTensor &x = input;
return LogSoftmax(MMul(x, w), -1);
}
/*
make the network (redefined output tensor)
>> input - input tensor
>> output - output tensor
*/
void T2TOutput::Make(XTensor &input, XTensor &output)
{
XTensor &x = input;
output = LogSoftmax(MMul(x, w), -1);
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TOUTPUT_H__
#define __T2TOUTPUT_H__
#include "../../tensor/function/FHeader.h"
using namespace nts;
namespace transformer
{
/* output layer */
class T2TOutput
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* vocabulary size */
int vSize;
/* input vector size */
int inSize;
/* vector size of the linear transformation */
int hSize;
/* transformation matrix */
XTensor w;
public:
/* constructor */
T2TOutput();
/* de-constructor */
~T2TOutput();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
/* make the network (redefined output tensor) */
void Make(XTensor &input, XTensor &output);
};
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
*/
#ifndef __T2TTRAINER_H__
#define __T2TTRAINER_H__
#include "T2TModel.h"
#include "../../tensor/function/FHeader.h"
#define MAX_SEQUENCE_LENGTH (1024 * 4)
using namespace nts;
namespace transformer
{
/* trainer of the T2T model */
class T2TTrainer
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* buffer for loading words */
int * buf;
/* buffer size */
int bufSize;
/* length of each sequence */
int * seqLen;
/* offset of the first word for each sequence */
int * seqOffset;
/* number of sequences in the buffer */
int nseqBuf;
/* offset for next sequence in the buffer */
int nextSeq;
/* indicates whether the sequence is sorted by length */
bool isLenSorted;
/* dimension size of each inner layer */
int d;
/* step number of warm-up for training */
int nwarmup;
/* vocabulary size of the source side */
int vSize;
/* learning rate */
float lrate;
/* sentence batch size */
int sBatchSize;
/* word batch size */
int wBatchSize;
/* training epoch number */
int nepoch;
/* training step number */
int nstep;
public:
/* constructor */
T2TTrainer();
/* de-constructor */
~T2TTrainer();
/* initialize the trainer */
void Init(int argc, const char ** argv);
/* train the model */
void Train(const char * fn, T2TModel * model);
/* load data to buffer */
int LoadBuf(FILE * file);
/* load a batch of sequences */
int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount);
/* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
/* update the model by delta rule */
void Update(T2TModel * model, const float lr);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
namespace transformer
{
FILE * tmpFILE;
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
strcpy(p, argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
strcpy(p, defaultP);
}
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
*(int*)p = atoi(argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
*p = defaultP;
}
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname)){
*(bool*)p = true;
//fprintf(stderr, " %s=%s\n", name, "true");
hit = true;
}
}
if(!hit)
*p = defaultP;
}
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
*p = (float)atof(argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
*p = defaultP;
}
void ShowParams(int argc, const char ** argv)
{
fprintf(stderr, "args:\n");
for(int i = 0; i < argc; i++){
if(argv[i][0] == '-'){
if(i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else
fprintf(stderr, " %s=yes\n", argv[i]);
}
}
fprintf(stderr, "\n");
}
}
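A hypothetical usage sketch of the loaders above (not part of the commit), showing how "-name value" style arguments are read:

void ShowArgExample(int argc, const char ** argv)
{
    int nhead = 0;
    float lr = 0.0F;
    bool useMem = false;
    LoadParamInt(argc, argv, "nhead", &nhead, 8);      /* "-nhead 4" -> 4; 8 if the flag is absent */
    LoadParamFloat(argc, argv, "lrate", &lr, 0.001F);  /* "-lrate 0.01" -> 0.01F ("lrate" is an assumed flag name) */
    LoadParamBool(argc, argv, "mem", &useMem, false);  /* the mere presence of "-mem" turns it on */
    fprintf(stderr, "nhead=%d lrate=%g mem=%s\n", nhead, lr, useMem ? "yes" : "no");
}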
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TUTILITY_H__
#define __T2TUTILITY_H__
#include <stdio.h>
namespace transformer
{
extern FILE * tmpFILE;
/* load arguments */
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP);
/* show arguments */
void ShowParams(int argc, const char ** argv);
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "Transformer.h"
#include "T2TModel.h"
#include "T2TUtility.h"
#include "T2TTrainer.h"
#include "../../tensor/XDevice.h"
namespace transformer
{
int TransformerMain(int argc, const char ** argv)
{
if(argc == 0)
return 1;
tmpFILE = fopen("tmp.txt", "wb");
ShowParams(argc, argv);
char * trainFN = new char[MAX_LINE_LENGTH];
LoadParamString(argc, argv, "train", trainFN, "");
T2TModel model;
model.InitModel(argc, argv);
if(strcmp(trainFN, "")){
T2TTrainer trainer;
trainer.Init(argc, argv);
trainer.Train(trainFN, &model);
}
delete[] trainFN;
fclose(tmpFILE);
return 0;
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
*
* An implementation of the transformer system. See more details
* about it in
* "Attention Is All You Need" by Vaswani et al.
* https://arxiv.org/pdf/1706.03762.pdf
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* I started writing the code related to NMT - it has been a long time since my
* last coding work on MT
*/
#ifndef __TRANSFORMER_H__
#define __TRANSFORMER_H__
#include "../../tensor/XGlobal.h"
#include "../../tensor/XTensor.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* entrance of the program */
int TransformerMain(int argc, const char ** argv);
}
#endif
\ No newline at end of file
......@@ -29,6 +29,7 @@
#include "XTensor.h"
#include "XDevice.h"
#include "./test/Test.h"
#include "./core/CHeader.h"
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
......@@ -37,6 +38,7 @@
using namespace nts;
void SmallTest();
void TransposeTest();
int main( int argc, const char ** argv )
{
......@@ -92,3 +94,35 @@ void SmallTest()
c.Dump(stderr, "c:");
d.Dump(stderr, "d:");
}
void TransposeTest()
{
XTensor a;
XTensor b;
int I = 2;
int J = 3;
InitTensor4D(&a, 2, 3, 4, 5);
int * dims = new int[a.order];
memcpy(dims, a.dimSize, sizeof(int) * a.order);
dims[I] = a.dimSize[J];
dims[J] = a.dimSize[I];
InitTensor(&b, 4, dims);
a.SetZeroAll();
b.SetZeroAll();
float * data = new float[a.unitNum];
for(int i = 0; i < a.unitNum; i++)
data[i] = (float)i;
a.SetData(data, a.unitNum, 0);
_Transpose(&a, &b, I, J);
b.Dump(stderr, "b:");
delete[] data;
delete[] dims;
}
......@@ -40,6 +40,7 @@ XDevManager GDevs;
/* constructor */
XDevice::XDevice()
{
stream = NULL;
Clear();
#ifdef USE_CUDA
......@@ -55,6 +56,8 @@ XDevice::~XDevice()
MUTEX_DELE(cublasMutex);
if(isHandleReady)
cublasDestroy(cublasHandle);
if(stream != NULL)
delete stream;
#endif
}
......@@ -118,6 +121,8 @@ void XDevice::Init(int myDevID)
}
else
sprintf(name2, "GPU-%d %s", devID, name);
stream = new XStream(0, devID);
#endif
}
......@@ -161,6 +166,14 @@ cublasHandle_t * XDevice::GetCublasHandle()
return &cublasHandle;
}
/* get the stream of cuda */
cudaStream_t * XDevice::GetCudaStream()
{
CheckNTErrors(stream != NULL, "the stream is not initialized!");
return &stream->stream;
}
#endif // USE_CUDA
/* switch to a device */
......@@ -311,11 +324,19 @@ void XDevManager::Clear()
/* get the handle of GPU */
cublasHandle_t * XDevManager::GetCudaHandle(const int devID)
{
CheckNTErrors((devID < nGPU), "index of GPU is out of range.");
CheckNTErrors(devID < nGPU, "index of GPU is out of range.");
return GPUs[devID].GetCublasHandle();
}
/* get the stream of cuda */
cudaStream_t * XDevManager::GetCudaStream(const int devID)
{
CheckNTErrors(devID < nGPU, "index of GPU is out of range.");
return GPUs[devID].GetCudaStream();
}
#endif
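A hypothetical usage sketch of the new stream accessor (not part of the commit): fetch the per-device default stream created in XDevice::Init and wait for the work queued on it.

#ifdef USE_CUDA
void SyncDeviceStream(int devID)
{
    cudaStream_t * s = GDevs.GetCudaStream(devID);   /* asserts that the stream exists */
    /* kernels or asynchronous copies would be enqueued on *s here */
    cudaStreamSynchronize(*s);                       /* block until the stream is drained */
}
#endif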
/*
......@@ -384,13 +405,10 @@ int XDevManager::GetCudaThread2D(const int devID, const int n, const int m, int
memset(gridSize, 0, sizeof(int) * 3);
memset(blockSize, 0, sizeof(int) * 3);
if(n <= 0 || m <= 0 || devID >= nGPU)
if(n <= 0 || m <= 0)
return 1;
if(devID < 0){
XPRINT(0, stderr, "WARNING! You are calling the grid and block size computation function on a CPU!");
return 0;
}
CheckNTErrors(devID >= 0 && devID < nGPU, "Invalid GPU device id!");
#ifdef USE_CUDA
......
......@@ -25,6 +25,7 @@
#define __XDEVICE_H__
#include "XThread.h"
#include "XStream.h"
#ifdef USE_CUDA
......@@ -92,6 +93,9 @@ public:
/* specify whether Unified Virtual Address Space (UVA) is supported */
bool isUVASupported;
/* default stream for the device */
XStream * stream;
#ifdef USE_CUDA
/* mutex for handle (GPU cublas) */
......@@ -121,6 +125,9 @@ public:
#ifdef USE_CUDA
/* get cublas handle */
cublasHandle_t * GetCublasHandle();
/* get the stream of cuda */
cudaStream_t * GetCudaStream();
#endif
/* switch to a device */
......@@ -178,6 +185,9 @@ public:
#ifdef USE_CUDA
/* get the handle of GPU */
cublasHandle_t * GetCudaHandle(const int devID);
/* get the stream of cuda */
cudaStream_t * GetCudaStream(const int devID);
#endif
/* get grid and block sizes that max potential */
......
......@@ -167,7 +167,9 @@ void XLink::SetType(int id)
type[0] = 0;
strcpy(type, GetOPName(id));
typeID = id;
CheckNTErrors(strcmp(type, "NULL"), "illegal edge type name!");
if(id != 0){
CheckNTErrors(strcmp(type, "NULL"), "illegal edge type name!");
}
}
/*
......@@ -515,7 +517,7 @@ void XLink::CopyIncoming(const XTensor * reference, XTensor * target)
tails.Add(tail);
}
MakeLink(&tails, target, reference->id);
MakeLink(&tails, target, reference->income.typeID);
int paraNum = reference->income.paramNum;
target->income.paramNum = paraNum;
......
......@@ -208,22 +208,16 @@ void XList::Insert(int pos, void * item)
/* get the item at position i */
void * XList::GetItem(int i) const
{
if( i >= 0 && i < count )
return items[i];
else
return NULL;
CheckNTErrors(i >= 0 && i < count, "Index of a list item is out of scope!");
return items[i];
}
/* get the integer-typed item at position i */
int XList::GetItemInt(int i)
{
CheckNTErrors(isIntList, "An int list is required!");
if( i >= 0 && i < count ){
return *(int*)(items[i]);
}
else
return 0;
CheckNTErrors(i >= 0 && i < count, "Index of a list item is out of scope!");
return *(int*)(items[i]);
}
/* set the item at position i */
......
......@@ -181,7 +181,10 @@ void XMem::Free(int myDevID, void * mem)
else{
#ifdef USE_CUDA
SetDevice(myDevID);
CheckNTErrors(cudaFree((char*)mem) == cudaSuccess, "Cannot free the memory.");
cudaError_t error = cudaFree((char*)mem);
if(error != cudaSuccess){
ShowNTErrors("Cannot free the memory.");
}
#else
ShowNTErrors("Please specify USE_CUDA for compiling this program.");
#endif
......
......@@ -29,6 +29,22 @@ const char * GetOPName(int type)
if ((type & MATH_BASE) != 0){
if (type == MATH_ABSOLUTE)
return "M_ABSOLUTE";
else if (type == MATH_EXP)
return "M_EXP";
else if (type == MATH_LOG)
return "M_LOG";
else if (type == MATH_SIN)
return "M_SIN";
else if (type == MATH_COS)
return "M_COS";
else if (type == MATH_TAN)
return "M_TAN";
else if (type == MATH_ROUND)
return "M_ROUND";
else if (type == MATH_CLIP)
return "M_CLIP";
else if (type == MATH_DIV)
return "M_DIV";
else if (type == MATH_MATRIXMUL)
return "M_MATRIXMUL";
else if (type == MATH_MATRIXMULBATCHED)
......@@ -37,18 +53,20 @@ const char * GetOPName(int type)
return "M_MULTIPLY";
else if (type == MATH_NEGATE)
return "M_NEGATE";
else if (type == MATH_SIGN)
return "M_SIGN";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_LOG)
return "M_LOG";
else if (type == MATH_NORMALIZE)
return "M_NORMALIZE";
else if (type == MATH_POWER)
return "M_POWER";
else if (type == MATH_SCALEANDSHIFT)
return "M_SCALEANDSHIFT";
else if (type == MATH_SIGN)
return "M_SIGN";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_SUB)
return "M_SUB";
else if (type == MATH_SUMDIM)
return "M_SUMDIM";
else if (type == REDUCE_REDUCEMAX)
return "R_REDUCEMAX";
else if (type == REDUCE_REDUCEMEAN)
......
......@@ -30,20 +30,30 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* math operations */
#define MATH_BASE 0x00001000
#define MATH_ABSOLUTE MATH_BASE + 1
#define MATH_MATRIXMUL MATH_ABSOLUTE + 1
#define MATH_EXP MATH_ABSOLUTE + 1
#define MATH_LOG MATH_EXP + 1
#define MATH_SIN MATH_LOG + 1
#define MATH_COS MATH_SIN + 1
#define MATH_TAN MATH_COS + 1
#define MATH_ROUND MATH_TAN + 1
#define MATH_CLIP MATH_ROUND + 1
#define MATH_DIV MATH_CLIP + 1
#define MATH_MATRIXMUL MATH_DIV + 1
#define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
#define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1
#define MATH_NEGATE MATH_MULTIPLY + 1
#define MATH_SIGN MATH_NEGATE + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_LOG MATH_SUM + 1
#define MATH_NORMALIZE MATH_LOG + 1
#define MATH_NORMALIZE MATH_NEGATE + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_SIGN MATH_SCALEANDSHIFT + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_SUB MATH_SUM + 1
#define MATH_SUMDIM MATH_SUB + 1
#define REDUCE MATH_SCALEANDSHIFT + 1
#define REDUCE MATH_SUMDIM + 1
#define REDUCE_REDUCEMAX REDUCE + 1
#define REDUCE_REDUCEMEAN REDUCE_REDUCEMAX + 1
#define REDUCE_REDUCESUM REDUCE_REDUCEMEAN + 1
......
......@@ -84,7 +84,7 @@ void XStream::Create(int priority, int myDevID)
XDevice::SetGPUDevice(myDevID);
//cudaStreamCreateWithPriority(&stream, cudaStreamDefault, priority);
CheckNTErrors((cudaStreamCreate(&stream) == cudaSuccess),
"cannot create the cuda stream!");
"cannot create the cuda stream!");
XDevice::SetGPUDevice(backupDevID);
#endif
devID = myDevID;
......
......@@ -42,6 +42,8 @@
#include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h"
#include "core/arithmetic/Multiply.h"
#include "core/arithmetic/Sub.h"
#include "core/arithmetic/Div.h"
#include "core/math/ScaleAndShift.h"
#ifdef USE_CUDA
......@@ -354,6 +356,18 @@ XTensor XTensor::operator* (const XTensor& tensor)
return Multiply(*this, tensor);
}
/* overloading of the minus-sign */
XTensor XTensor::operator- (const XTensor& tensor)
{
return Sub(*this, tensor);
}
/* overloading of the division-sign */
XTensor XTensor::operator/ (const XTensor& tensor)
{
return Div(*this, tensor);
}
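/*
a minimal usage sketch of the new operator overloads; InitTensor2D is assumed
here as the 2-d counterpart of the InitTensor helpers, and x and y are two
dense tensors of the same shape
*/
XTensor x;
XTensor y;
InitTensor2D(&x, 2, 3);
InitTensor2D(&y, 2, 3);
x.SetDataRand(-1.0F, 1.0F);
y.SetDataRand(1.0F, 2.0F);
XTensor diff = x - y;     /* calls Sub(x, y) */
XTensor ratio = x / y;    /* calls Div(x, y) */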
/*
linear transformation b = a * \scale + \shift
>> scale - the slope
......@@ -426,8 +440,12 @@ get the size of a given dimension
int XTensor::GetDim(const int dim)
{
CheckNTErrors(dim < order, "dimension is out of range!");
int d = dim;
if(dim < 0)
d = order - 1;
return dimSize[dim];
return dimSize[d];
}
/*
......@@ -454,6 +472,27 @@ void XTensor::Reshape(const int myOrder, const int * myDimSize)
memcpy(dimSizeRDI, dimsRDI, sizeof(int) * order);
}
/*
reshape the tensor to a vector
>> num - number of elements
*/
void XTensor::Reshape(const int num)
{
int dim = num;
Reshape(1, &dim);
}
/*
reshape the tensor to a matrix
>> rowNum - number of rows
>> colNum - number of columns
*/
void XTensor::Reshape(const int rowNum, const int colNum)
{
int dims[2] = {rowNum, colNum};
Reshape(2, dims);
}
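/*
a small sketch of the two new Reshape overloads, assuming a dense tensor with
2 * 3 * 4 = 24 elements (InitTensor3D is assumed as the 3-d init helper);
the number of elements must stay unchanged
*/
XTensor t;
InitTensor3D(&t, 2, 3, 4);
t.Reshape(24);       /* view the tensor as a vector of 24 elements */
t.Reshape(6, 4);     /* view the tensor as a 6 x 4 matrix */
t.Reshape(2, 12);    /* view the tensor as a 2 x 12 matrix */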
/* get the number of items in the data array */
int XTensor::GetSize() const
{
......@@ -560,25 +599,24 @@ set the tensor items by a uniform distribution in range [lower, upper]
void XTensor::SetDataRand(DTYPE lower, DTYPE upper)
{
// TODO: cuda code!!!!!!!
// TODO: replace float with DTYPE
if (data == NULL)
return;
// srand((unsigned)time(0));
DTYPE variance = upper - lower;
void * d = NULL;
if (dataType == X_FLOAT) {
d = new float[unitNum];
for (int i = 0; i < unitNum; i++) {
DTYPE value = lower + (upper - lower) * (float)rand() / RAND_MAX;
DTYPE value = lower + variance * (float)rand() / RAND_MAX;
*((float*)d + i) = value;
}
}
else if (dataType == X_DOUBLE) {
d = new double[unitNum];
for (int i = 0; i < unitNum; i++) {
*((double*)d + i) = lower + (upper - lower) * rand() / RAND_MAX;
*((double*)d + i) = lower + variance * rand() / RAND_MAX;
}
}
else {
......@@ -588,15 +626,15 @@ void XTensor::SetDataRand(DTYPE lower, DTYPE upper)
SetData(d, unitNum);
if (dataType == X_FLOAT) {
delete[](float*)d;
delete[] (float*)d;
}
else {
delete[](double*)d;
delete[] (double*)d;
}
}
/* a gauss distribution */
double GaussRand()
/* a gauss distribution (Box-Muller method) */
double GaussRand(DTYPE mean, DTYPE standardDeviation)
{
// TODO: cuda code!!!!!!!
......@@ -606,8 +644,8 @@ double GaussRand()
double pi = 3.141592654;
if (phase == 0){
u = rand() / (RAND_MAX + 1.0);
v = rand() / (RAND_MAX + 1.0);
u = (rand() + 1.0) / (RAND_MAX + 1.0);
v = (rand() + 1.0) / (RAND_MAX + 1.0);
z = sqrt(-2.0 * log(u))* sin(2.0 * pi * v);
}
else{
......@@ -615,7 +653,7 @@ double GaussRand()
}
phase = 1 - phase;
return z;
return mean + (z * standardDeviation);
}
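/*
for reference, the Box-Muller step above draws u and v uniformly from (0, 1]
(shifting rand() by 1 avoids log(0)) and computes
z = sqrt(-2 * ln(u)) * sin(2 * pi * v)
so the returned sample is mean + standardDeviation * z
*/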
/*
......@@ -626,7 +664,6 @@ set the tensor items by a normal distribution
void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
{
// TODO: cuda code!!!!!!!
// TODO: replace float with DTYPE
if (data == NULL)
return;
......@@ -636,13 +673,13 @@ void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
if (dataType == X_FLOAT) {
d = new float[unitNum];
for (int i = 0; i < unitNum; i++) {
*((float*)d + i) = (float)GaussRand();
*((float*)d + i) = (float)GaussRand(mean, standardDeviation);
}
}
else if (dataType == X_DOUBLE) {
d = new double[unitNum];
for (int i = 0; i < unitNum; i++) {
*((double*)d + i) = GaussRand();
*((double*)d + i) = GaussRand(mean, standardDeviation);
}
}
else {
......@@ -652,10 +689,10 @@ void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
SetData(d, unitNum);
if (dataType == X_FLOAT) {
delete[](float*)d;
delete[] (float*)d;
}
else {
delete[](double*)d;
delete[] (double*)d;
}
}
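/*
a short sketch of the two random initializers; both generate the data on the
CPU side for now (see the TODO above) and then copy it into the tensor
(InitTensor2D is assumed as the 2-d init helper)
*/
XTensor w;
InitTensor2D(&w, 4, 4);
w.SetDataRand(-0.1F, 0.1F);     /* uniform samples in [-0.1, 0.1] */
w.SetDataRandn(0.0F, 0.02F);    /* normal samples with mean 0 and stddev 0.02 */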
......@@ -1003,13 +1040,13 @@ set the value of a cell in a 3d tensor in default type
*/
bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
{
CheckNTErrors((order == 3), "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((d0 >= 0 && d1 < dimSize[0]), "dimension 0 is out of range!");
CheckNTErrors((d2 >= 0 && d2 < dimSize[1]), "dimension 1 is out of range!");
CheckNTErrors((d2 >= 0 && d2 < dimSize[2]), "dimension 1 is out of range!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type.");
CheckNTErrors(order == 3, "Cannot set a 3d cell for a tensor whose order is not 3!");
CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 2 is out of range!");
CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
int dims[3] = {d0, d1, d1};
int dims[3] = {d0, d1, d2};
return SetToDevice(devID, GetCell(dims, 3), value);
}
......@@ -1439,6 +1476,21 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int verbo
}
/*
dump data to a file
>> tensor - tensor whose data is dumped
>> file - where to dump the data
>> label - label of the tensor
>> n - number of items to dump
>> verbose - verbose level
*/
void XTensor::Dump(const XTensor * tensor, FILE * file, const char * label, const int n, const int verbose)
{
XTensor a(tensor->order, tensor->dimSize, tensor->dataType, tensor->denseRatio, tensor->devID, tensor->mem);
_CopyValues(tensor, &a);
a.Dump(file, label, n, verbose);
}
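/*
a hedged usage sketch of the static Dump: it copies the tensor into a
temporary of the same shape first and dumps the copy, so the original tensor
is left untouched (NewTensor2D with default device and memory pool is assumed)
*/
XTensor * t = NewTensor2D(4, 4, X_FLOAT);
t->SetDataRand(-1.0F, 1.0F);
XTensor::Dump(t, stderr, "t:");
DelTensor(t);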
/*
read data from a file
>> file - where to load the data
>> label - label of the tensor
......@@ -1687,13 +1739,13 @@ void InitTensor(XTensor * tensor,
dims[0] = -abs(dims[0]);
tensor->Resize(myOrder, dims, myDataType, myDenseRatio);
if(myDevID == CURRENT_GPU)
if (myDevID == CURRENT_GPU)
tensor->devID = XDevice::GetGPUDevice();
else
tensor->devID = myDevID;
tensor->Resize(myOrder, dims, myDataType, myDenseRatio);
if(allocated)
XTensor::AllocateData(tensor);
}
......@@ -1870,28 +1922,47 @@ generate a XTensor which allocates data on the buffer
>> myDimSize - the size of each dimension
>> myMem - memory pool used to allocating the data array.
we actually allocate the data on the buffer associated with
the memory pool.
the memory pool
>> devID - device id
>> myDataType - data type of each unit (e.g., int, float, and double)
>> myDenseRatio - how often an element has non-zero value
*/
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem,
const TENSOR_DATA_TYPE myDataType, const float myDenseRatio)
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize,
const TENSOR_DATA_TYPE myDataType, const float myDenseRatio,
const int devID, XMem * myMem)
{
CheckNTErrors(myMem != NULL, "No memory pool specified!");
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, myDimSize, sizeof(int) * myOrder);
dims[0] = -abs(dims[0]);
XTensor * tensor = NewTensor(myOrder, dims, myDataType, myDenseRatio, -1, myMem);
tensor->data = myMem->AllocBuf(myMem->devID, tensor->unitNum * tensor->unitSize);
XTensor * tensor = NewTensor(myOrder, dims, myDataType, myDenseRatio, devID, myMem);
if(myMem != NULL)
tensor->data = myMem->AllocBuf(myMem->devID, tensor->unitNum * tensor->unitSize);
else
tensor->data = XMemAlloc(devID, tensor->unitNum * tensor->unitSize);
return tensor;
}
/*
generate a XTensor which allocates data on the buffer
>> reference - reference tensor
>> devID - device id
>> myMem - memory pool used to allocating the data array.
we actually allocate the data on the buffer associated with
the memory pool
*/
XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem)
{
return NewTensorBuf(reference->order, reference->dimSize,
reference->dataType, reference->denseRatio,
devID, myMem);
}
/*
generate a dense vector
>> num - number of entries
>> myDataType - data type of each unit (e.g., int, float, and double)
......@@ -2041,7 +2112,7 @@ XTensor * NewTensor(XTensor * a, bool isFilledData)
free the data space of a given tensor
>> tensor - pointer to the tensor
*/
void DelTensor(const XTensor * tensor)
void DelTensor(XTensor * tensor)
{
delete tensor;
}
......@@ -2050,10 +2121,13 @@ void DelTensor(const XTensor * tensor)
free the data space of a given tensor (on the buffer)
>> tensor - pointer to the tensor
*/
void DelTensorBuf(const XTensor * tensor)
void DelTensorBuf(XTensor * tensor)
{
CheckNTErrors(tensor->mem != NULL, "No memory pool found!");
tensor->mem->ReleaseBuf(tensor->devID, tensor->unitNum * tensor->unitSize);
if(tensor->mem != NULL)
tensor->mem->ReleaseBuf(tensor->devID, tensor->unitNum * tensor->unitSize);
else
XMemFree(tensor->devID, tensor->data);
tensor->data = NULL;
delete tensor;
}
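/*
a paired-usage sketch for the revised buffer routines: with myMem == NULL the
data is now taken from XMemAlloc and released by XMemFree, so scratch tensors
can be created even when no memory pool is attached (a sketch under that
assumption)
*/
int bufDims[2] = {128, 256};
XTensor * tmp = NewTensorBuf(2, bufDims, X_FLOAT, 1.0F, -1, NULL);
/* ... use tmp as scratch space ... */
DelTensorBuf(tmp);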
......
......@@ -45,12 +45,13 @@ namespace nts{
struct XLink;
/* define the maximum number of dimensions in a tensor */
#define MAX_TENSOR_DIM_NUM 6
#define MAX_TENSOR_DIM_NUM 8
#define USE_BATCHED_STRIDED_MAT_MUL
#define MIN_TENSOR_SPLIT_NUM 10
#define MIN_TENSOR_SPLIT_NUM 0
#define MIN_TENSOR_SPLIT_LIST_NUM 1024
#define MIN_TENSOR_CAT_NUM 8
/* computation flags */
#define UNSAFE_BUT_FAST_MEM
#define FAST_MATRIX
......@@ -202,6 +203,12 @@ public:
/* overloading of the multiply-sign */
XTensor operator* (const XTensor &tensor);
/* overloading of the minus-sign */
XTensor operator- (const XTensor &tensor);
/* overloading of the division-sign */
XTensor operator/ (const XTensor &tensor);
/* linear transformation */
XTensor Lin(DTYPE scale, DTYPE shift = 0);
......@@ -222,6 +229,12 @@ public:
/* reshape the tensor */
void Reshape(const int order, const int * myDimSize);
/* reshape the tensor to a vector */
void Reshape(const int num);
/* reshape the tensor to a matrix */
void Reshape(const int rowNum, const int colNum);
/* get the number of items in the data array */
int GetSize() const;
......@@ -328,6 +341,10 @@ public:
/* dump data to a file */
void Dump(FILE * file, const char * label = NULL, const int n = -1, const int verbose = 0);
/* dump data to a file */
static
void Dump(const XTensor * tensor, FILE * file, const char * label = NULL, const int n = -1, const int verbose = 0);
/* read data from a file */
void Read(FILE * file, const char * label = NULL);
......@@ -386,8 +403,12 @@ XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_
const float myDenseRatio = 1.0F, const int myDevID = -1, XMem * myMem = NULL);
/* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const float myDenseRatio = 1.0F);
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const float myDenseRatio = 1.0F,
const int myDevID = -1, XMem * myMem = NULL);
/* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem);
/* generate a dense vector */
XTensor * NewTensor1D(const int num, const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1,
......@@ -417,10 +438,10 @@ XTensor * NewTensor5D(const int d0, const int d1, const int d2, const int d3, co
XTensor * NewTensor(XTensor * a, bool isFilledData = true);
/* free the data space of a given tensor */
void DelTensor(const XTensor * tensor);
void DelTensor(XTensor * tensor);
/* free the data space of a given tensor (on the buffer) */
void DelTensorBuf(const XTensor * tensor);
void DelTensorBuf(XTensor * tensor);
} /* end of the nts (NiuTrans.Tensor) namespace */
......
......@@ -175,29 +175,38 @@ void XMemCopy(void * t, int devIDT, const void * s, int devIDS, size_t size)
return;
}
#ifdef USE_CUDA
else if(devIDT >= 0 && devIDS < 0){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyHostToDevice);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyHostToDevice)");
}
}
else if(devIDT < 0 && devIDS >= 0){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyDeviceToHost);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToHost)");
}
}
else{
//if(devIDT == devIDS){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyDeviceToDevice);
int devID = devIDT < 0 ? devIDS : devIDT;
int devIDBackup = 0;
cudaGetDevice(&devIDBackup);
cudaSetDevice(devID);
if(devIDT >= 0 && devIDS < 0){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyHostToDevice);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
ShowNTErrors("cudaMemcpy error (cudaMemcpyHostToDevice)");
}
/*}
}
else if(devIDT < 0 && devIDS >= 0){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyDeviceToHost);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToHost)");
}
}
else{
CheckNTErrors((cudaMemcpyPeer(t, devIDT, s, devIDS, size) == cudaSuccess),
"cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}*/
//if(devIDT == devIDS){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyDeviceToDevice);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}
/*}
else{
CheckNTErrors((cudaMemcpyPeer(t, devIDT, s, devIDS, size) == cudaSuccess),
"cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}*/
}
cudaSetDevice(devIDBackup);
}
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
......@@ -208,6 +217,9 @@ void XMemCopy(void * t, int devIDT, const void * s, int devIDS, size_t size)
#ifdef USE_CUDA
void XMemCopyAsync(void * t, int devIDT, const void * s, int devIDS, size_t size, cudaStream_t stream, int streamDevID)
{
if(t == s)
return;
int devIDBackup = -1;
if(streamDevID >= 0 && (devIDT >= 0 || devIDS >= 0)){
CheckNTErrors((cudaGetDevice(&devIDBackup) == cudaSuccess), "Cannot get GPU device id!");
......@@ -220,17 +232,23 @@ void XMemCopyAsync(void * t, int devIDT, const void * s, int devIDS, size_t size
return;
}
else if(devIDT >= 0 && devIDS < 0){
CheckNTErrors((cudaMemcpyAsync(t, s, size, cudaMemcpyHostToDevice, stream) == cudaSuccess),
"cudaMemcpyAsync error (cudaMemcpyHostToDevice)");
cudaError_t error = cudaMemcpyAsync(t, s, size, cudaMemcpyHostToDevice, stream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpyAsync error (cudaMemcpyHostToDevice)");
}
}
else if(devIDT < 0 && devIDS >= 0){
CheckNTErrors((cudaMemcpyAsync(t, s, size, cudaMemcpyDeviceToHost, stream) == cudaSuccess),
"cudaMemcpyAsync error (cudaMemcpyDeviceToHost)");
cudaError_t error = cudaMemcpyAsync(t, s, size, cudaMemcpyDeviceToHost, stream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpyAsync error (cudaMemcpyDeviceToHost)");
}
}
else{
//if(devIDT == devIDS){
CheckNTErrors((cudaMemcpyAsync(t, s, size, cudaMemcpyDeviceToDevice, stream) == cudaSuccess),
"cudaMemcpyAsync error (cudaMemcpyDeviceToDevice)");
cudaError_t error = cudaMemcpyAsync(t, s, size, cudaMemcpyDeviceToDevice, stream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpyAsync error (cudaMemcpyDeviceToDevice)");
}
//}
/*else{
CheckNTErrors((cudaMemcpyPeerAsync(t, devIDT, s, devIDS, size, stream) == cudaSuccess),
......@@ -261,18 +279,69 @@ void XMemCopy2D(void * t, size_t tPitch, int devIDT, const void * s, size_t sPit
return;
}
#ifdef USE_CUDA
else if (devIDT >= 0 && devIDS < 0) {
CheckNTErrors((cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyHostToDevice) == cudaSuccess),
"cudaMemcpy2D error (cudaMemcpyHostToDevice)");
else{
int devID = devIDT < 0 ? devIDS : devIDT;
int devIDBackup = 0;
cudaGetDevice(&devIDBackup);
cudaSetDevice(devID);
if (devIDT >= 0 && devIDS < 0) {
cudaError_t error = cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyHostToDevice);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy2D error (cudaMemcpyHostToDevice)");
}
}
else if (devIDT < 0 && devIDS >= 0) {
cudaError_t error = cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToHost);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToHost)");
}
}
else {
cudaError_t error = cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToDevice);
if (error != cudaSuccess) {
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}
}
cudaSetDevice(devIDBackup);
}
else if (devIDT < 0 && devIDS >= 0) {
CheckNTErrors((cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToHost) == cudaSuccess),
"cudaMemcpy error (cudaMemcpyDeviceToHost)");
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
void XMemCopy2DAsync(void * t, size_t tPitch, int devIDT, const void * s, size_t sPitch, int devIDS, size_t mSize, int n, XStream * stream)
{
if (t == s)
return;
if (devIDT < 0 && devIDS < 0) {
for(int i = 0; i < n; i++)
memcpy((char*)t + tPitch * i, (char*)s + sPitch * i, mSize);
return;
}
else {
cudaError_t error = cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToDevice);
if (error != cudaSuccess) {
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
#ifdef USE_CUDA
else{
CheckNTErrors(stream != NULL, "No stream found!");
cudaStream_t &cstream = stream->stream;
if (devIDT >= 0 && devIDS < 0) {
cudaError_t error = cudaMemcpy2DAsync(t, tPitch, s, sPitch, mSize, n, cudaMemcpyHostToDevice, cstream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy2D error (cudaMemcpyHostToDevice)");
}
}
else if (devIDT < 0 && devIDS >= 0) {
cudaError_t error = cudaMemcpy2DAsync(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToHost, cstream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToHost)");
}
}
else {
cudaError_t error = cudaMemcpy2DAsync(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToDevice, cstream);
if (error != cudaSuccess) {
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}
}
}
#else
......
......@@ -23,6 +23,7 @@
#include <stdio.h>
#include "XGlobal.h"
#include "XDevice.h"
#ifndef __XUTILITY_H__
#define __XUTILITY_H__
......@@ -41,6 +42,7 @@ extern void XMemSet(void * p, int value, size_t size);
extern void XMemSet(int devID, void * p, int value, size_t size);
extern void XMemCopy(void * t, int devIDT, const void * s, int devIDS, size_t size);
extern void XMemCopy2D(void * t, size_t tPitch, int devIDT, const void * s, size_t sPitch, int devIDS, size_t mSize, int n);
extern void XMemCopy2DAsync(void * t, size_t tPitch, int devIDT, const void * s, size_t sPitch, int devIDS, size_t mSize, int n, XStream * stream);
extern void * XMemAlloc(int devID, size_t size);
extern void * XMemAllocOnDev(int devID, size_t size);
extern void XMemFree(int devID, void * p);
......
......@@ -26,49 +26,63 @@
#include "../XTensor.h"
#include "shape/Concatenate.h"
#include "shape/ConcatenateSolely.h"
#include "movement/CopyBlocks.h"
#include "movement/CopyBlocksInGrid.h"
#include "movement/CopyBlocksOnSite.h"
#include "movement/CopyData2D.h"
#include "movement/CopyIndexed.h"
#include "movement/CopyInGrid.h"
#include "movement/CopyValues.h"
#include "utilities/FlushToMem.h"
#include "shape/MakeMergeBlockIndex.h"
#include "shape/MakeSplitBlockIndex.h"
#include "arithmetic/Div.h"
#include "arithmetic/MatrixMul.h"
#include "arithmetic/MatrixMul2D.h"
#include "arithmetic/MatrixMul2DMultiTheading.h"
#include "arithmetic/MatrixMul2DParallel.h"
#include "arithmetic/MatrixMulBatched.h"
#include "arithmetic/MatrixMULBatchedCPU.h"
#include "shape/Merge.h"
#include "shape/MergeBlockLists.h"
#include "arithmetic/Multiply.h"
#include "arithmetic/Negate.h"
#include "arithmetic/Sign.h"
#include "arithmetic/Sub.h"
#include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumDim.h"
#include "arithmetic/XTensorBLAS.h"
#include "getandset/ConvertDataType.h"
#include "getandset/Select.h"
#include "getandset/SetData.h"
#include "math/Clip.h"
#include "math/Normalize.h"
#include "shape/Permute.h"
#include "math/Power.h"
#include "math/ScaleAndShift.h"
#include "math/Unary.h"
#include "movement/CopyBlocks.h"
#include "movement/CopyBlocksInGrid.h"
#include "movement/CopyBlocksOnSite.h"
#include "movement/CopyData2D.h"
#include "movement/CopyIndexed.h"
#include "movement/CopyInGrid.h"
#include "movement/CopyValues.h"
#include "reduce/ReduceMax.h"
#include "reduce/ReduceMean.h"
#include "reduce/ReduceStandardVariance.h"
#include "reduce/ReduceSum.h"
#include "reduce/ReduceSumSquared.h"
#include "reduce/ReduceVariance.h"
#include "math/ScaleAndShift.h"
#include "getandset/Select.h"
#include "getandset/SetData.h"
#include "sort/Sort.h"
#include "shape/Concatenate.h"
#include "shape/ConcatenateSolely.h"
#include "shape/MakeMergeBlockIndex.h"
#include "shape/MakeSplitBlockIndex.h"
#include "shape/Merge.h"
#include "shape/MergeBlockLists.h"
#include "shape/Permute.h"
#include "shape/Split.h"
#include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h"
#include "sort/TopK.h"
#include "shape/Transpose.h"
#include "shape/Unsqueeze.h"
#include "sort/Sort.h"
#include "sort/TopK.h"
#include "utilities/XMatrixSegment.h"
#include "arithmetic/XTensorBLAS.h"
#include "utilities/FlushToMem.h"
#endif // __CHEADER_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "Div.h"
#include "Div.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
element-wise division of two tensors
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
int leadingDimRDI = a->order - leadingDim - 1;
CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
"Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order),
"Unmatched tensors!");
#ifdef USE_CUDA
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
_CudaDiv(a, b, c, alpha, leadingDim);
return;
}
#endif
int stride = 1;
int blockSizeA = 1;
int blockSizeB = 1;
int blockSizeC = 1;
int blockNum = 1;
int dimensionSizeA = a->dimSizeRDI[leadingDimRDI];
int dimensionSizeB = b->dimSizeRDI[leadingDimRDI];
int dimensionSizeC = c->dimSizeRDI[leadingDimRDI];
for (int i = 0; i < a->order; i++) {
if (i != leadingDimRDI) {
CheckNTErrors((a->dimSizeRDI[i] == b->dimSizeRDI[i] && a->dimSizeRDI[i] == c->dimSizeRDI[i]),
"Unmatched tensors!");
}
if (i < leadingDimRDI)
stride *= a->dimSizeRDI[i];
}
blockSizeA = stride * dimensionSizeA;
blockSizeB = stride * dimensionSizeB;
blockSizeC = stride * dimensionSizeC;
blockNum = a->unitNum / blockSizeA;
if (!a->isSparse && !b->isSparse) {
if (a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE) {
if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
int size = a->unitNum;
DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data;
if (alpha == 0) {
for (int i = 0; i < size; i++)
cp[i] = ap[i] / bp[i];
}
else {
for (int i = 0; i < size; i++)
cp[i] = ap[i] / bp[i] + alpha * cp[i];
}
}
else {
for (int k = 0; k < blockNum; k++) {
for (int ci = 0, ai = 0, bi = 0; ci < dimensionSizeC; ci++, ai++, bi++) {
if (ai >= dimensionSizeA)
ai = 0;
if (bi >= dimensionSizeB)
bi = 0;
DTYPE * ap = (DTYPE*)a->data + k * blockSizeA + ai * stride;
DTYPE * bp = (DTYPE*)b->data + k * blockSizeB + bi * stride;
DTYPE * cp = (DTYPE*)c->data + k * blockSizeC + ci * stride;
for (int j = 0; j < stride; j++)
cp[j] = ap[j] / bp[j] + cp[j] * alpha;
}
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
/*
element-wise division of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)/b(i) + \alpha * a(i)
where i is the index of the item
>> a - tensor a (where keep the result)
>> b - tensor b
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
{
_Div(a, b, a, alpha, leadingDim);
}
/*
element-wise division of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it
c(i) = a(i)/b(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim)
{
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a);
c.SetTMP();
/* call _Div function */
_Div(&a, &b, &c, 0, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHeadInt(&c, leadingDim);
return c;
}
} // namespace nts(NiuTrans.Tensor)
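/*
a minimal usage sketch for the new division routines, assuming two dense
tensors of the same shape (InitTensor2D is assumed as the 2-d init helper)
*/
XTensor a;
XTensor b;
InitTensor2D(&a, 4, 8);
InitTensor2D(&b, 4, 8);
a.SetDataRand(1.0F, 2.0F);
b.SetDataRand(1.0F, 2.0F);
XTensor c = Div(a, b);    /* c(i) = a(i) / b(i) */
_DivMe(&a, &b);           /* in place: a(i) = a(i) / b(i) */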
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Div.h"
#include "Div.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
division of data arrays in an element-wise manner c(i) = a(i)/b(i)
>> a - data array a
>> b - data array b
>> c - result data array
>> size - size of c
*/
__global__
void KernelDivElementWise(DTYPE * a, DTYPE * b, DTYPE * c, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] / b[i];
}
/*
division of data arrays in an element-wise manner c(i) = a(i)/b(i) + \alpha*c(i)
>> a - data array a
>> b - data array b
>> c - result data array
>> size - size of c
>> alpha - the coefficient
*/
__global__
void KernelDivElementWiseV2(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE alpha)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] / b[i] + alpha * c[i];
}
/*
division of two tensors in an element-wise manner c(i) = a(i)/b(i).
Note that a and b can be of different sizes here, i.e.,
|a_lead| <= |c_lead| and |b_lead| <= |c_lead|
where |a_lead| means the size of the leading dimension of a
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> stride - the number of items we go over when move next along the leading dimension in a block
>> ldSizeA - size of the leading dimension of a
>> ldSizeB - size of the leading dimension of b
>> ldSizeC - size of the leading dimension of c
>> blockNum - number of blocks
*/
template<int nonZeroAlpha> __global__
void KernelDivElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE alpha,
int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum)
{
__shared__ DTYPE* ap[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE* bp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE* cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int i = blockDim.x * blockIdx.x + threadIdx.x;
int j = blockDim.y * blockIdx.y + threadIdx.y;
if (i >= blockNum * stride || j >= ldSizeC)
return;
if (threadIdx.y == 0) {
int block = i / stride;
int size = block * stride;
ap[threadIdx.x] = a + size * ldSizeA;
bp[threadIdx.x] = b + size * ldSizeB;
cp[threadIdx.x] = c + size * ldSizeC;
}
__syncthreads();
int aj = j >= ldSizeA ? j % ldSizeA : j;
int bj = j >= ldSizeB ? j % ldSizeB : j;
int offseti = i % stride;
if (nonZeroAlpha == 0)
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] / bp[threadIdx.x][bj * ldSizeB + offseti];
else
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] / bp[threadIdx.x][bj * ldSizeB + offseti]
+ alpha * cp[threadIdx.x][j * ldSizeC + offseti];
}
/*
element-wise division of two tensors
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the item index
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - dimension along which we perform broadcasting
*/
void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
int leadingDimRDI = a->order - leadingDim - 1;
CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
"Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!");
int stride = 1;
int blockSizeA = 1;
int blockNum = 1;
int dimensionSizeA = a->dimSizeRDI[leadingDimRDI];
int dimensionSizeB = b->dimSizeRDI[leadingDimRDI];
int dimensionSizeC = c->dimSizeRDI[leadingDimRDI];
for (int i = 0; i < a->order; i++) {
if (i != leadingDimRDI) {
CheckNTErrors((a->dimSizeRDI[i] == b->dimSizeRDI[i] &&
a->dimSizeRDI[i] == c->dimSizeRDI[i]),
"Unmatched tensors!");
}
if (i < leadingDimRDI)
stride *= a->dimSizeRDI[i];
}
blockSizeA = stride * dimensionSizeA;
blockNum = a->unitNum / blockSizeA;
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (!a->isSparse && !b->isSparse) {
if (a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE) {
int cudaGridSize[3];
int cudaBlockSize[3];
if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
GDevs.GetCudaThread(a->devID, c->unitNum, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[0]), threads(cudaBlockSize[0]);
if (alpha == 0)
KernelDivElementWise << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum);
else
KernelDivElementWiseV2 << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum, alpha);
}
else {
GDevs.GetCudaThread2D(c->devID, stride * blockNum, dimensionSizeC, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[0], cudaGridSize[1]), threads(cudaBlockSize[0], cudaBlockSize[1]);
if (alpha == 0) {
KernelDivElementWiseTensorDynamic<0> << <blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, 0,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
else {
KernelDivElementWiseTensorDynamic<1> << <blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, alpha,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __DIV_CUH__
#define __DIV_CUH__
#include "Div.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* division of two tensors in an element-wise manner c(i) = a(i)/b(i) */
__global__
void KernelDivElementWise(DTYPE * a, DTYPE * b, DTYPE * c, int size);
/* division of two tensors in an element-wise manner c(i) = a(i)/b(i) + \alpha*c(i) */
__global__
void KernelDivElementWiseV2(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE alpha);
/* division of two tensors in an element-wise manner c(i) = a(i)/b(i) + \alpha*c(i) */
template<int nonZeroAlpha>__global__
void KernelDivElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE alpha, int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum);
/* element-wise division of two tensors */
void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __DIV_CUH__
......@@ -16,31 +16,39 @@
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __LOG_H__
#define __LOG_H__
#ifndef __DIV_H__
#define __DIV_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its log value */
void _Log(const XTensor * a, XTensor * b);
/*
element-wise division of two tensors:
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the element
*/
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
/*
set every entry to its log value (do it on site)
element-wise division of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)/b(i) + \alpha * a(i)
where i is the index of the element
*/
void _LogMe(XTensor * a);
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha = 0, int leadingDim = 0);
/*
set every entry to its log value (return a XTensor structure)
element-wise division of two tensors (return a XTensor structure)
make a new tensor to keep the result and return it
c(i) = a(i)/b(i)
where i is the index of the element
*/
XTensor Log(const XTensor & a);
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor)
#endif // __LOG_H__
#endif // __DIV_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XTensor.h"
#include "MatrixMULBatchedCPU.h"
#include "MatrixMul2D.h"
#include "XTensorBLAS.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
matrix multiplication in batch mode (BLAS)
c_i = trans(a_i) * trans(b_i) * \alpha + c_i * \beta for each i in [0,count-1]
>> a - list of input matrices (2d tensors)
>> transposedA - indicate whether the matrix a is transposed
>> b - another list of input matrices (2d tensors)
>> transposedB - indicate whether the matrix b is transposed
>> c - output matrix (2d tensor)
>> alpha - scalar
>> beta - scalar
*/
void _MatrixMULBatchedCPU(const XList * a, MATRIX_TRANS_TYPE transposedA,
const XList * b, MATRIX_TRANS_TYPE transposedB,
XList * c, DTYPE alpha, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty input lists!");
CheckNTErrors(a->count == b->count && a->count == c->count, "Input lists must be of the same size!");
if (a->count == 0)
return;
bool isUniform = true;
for (int i = 1; i < a->count; i++) {
XTensor * aim = (XTensor*)a->GetItem(i - 1);
XTensor * bim = (XTensor*)b->GetItem(i - 1);
XTensor * cim = (XTensor*)c->GetItem(i - 1);
XTensor * ai = (XTensor*)a->GetItem(i);
XTensor * bi = (XTensor*)b->GetItem(i);
XTensor * ci = (XTensor*)c->GetItem(i);
if (!XTensor::IsSameShaped(aim, ai) ||
!XTensor::IsSameShaped(bim, bi) ||
!XTensor::IsSameShaped(cim, ci))
{
isUniform = false;
break;
}
}
for (int i = 0; i < a->count; i++) {
XTensor * ai = (XTensor*)a->GetItem(i);
XTensor * bi = (XTensor*)b->GetItem(i);
XTensor * ci = (XTensor*)c->GetItem(i);
CheckNTErrors((ai->order == 2), "2d tensor (i.e., matrix) is required!");
CheckNTErrors((bi->order == 2), "2d tensor (i.e., matrix) is required!");
CheckNTErrors((ci->order == 2), "2d tensor (i.e., matrix) is required!");
#ifdef USE_BLAS
if (useBLAS)
_MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta);
else
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#else
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#endif
}
//}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -24,8 +24,8 @@
#include "../../XName.h"
#include "MatrixMul.h"
#include "MatrixMul2D.h"
#include "MatrixMULBatchedCPU.h"
#include "XTensorBLAS.h"
#include "MatrixMulBatched.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -53,11 +53,29 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
CheckNTErrors(a && b && c, "Empty input tensors!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Input tensors should have the same data type!");
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
"Input tensors must have a order >= 2!");
CheckNTErrors(c->order == a->order + b->order - 2, "wrong tensor order")
/* we transform a higher order tensor to a matrix to kill the number
of calls of matrix multiplication */
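/* e.g., for a of shape (batch, n, k) and a 2-d b of shape (k, m), a is viewed
as a (batch * n) x k matrix, so a single 2-d multiplication produces c viewed
as a (batch * n) x m matrix */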
if(transposedA == X_NOTRANS && a->order > 2 && b->order == 2){
int ncolA = a->dimSize[a->order - 1];
int ncolC = c->dimSize[c->order - 1];
XTensor * a2 = NewTensor2D(a->unitNum/ncolA, -ncolA, a->dataType, a->devID, a->mem);
XTensor * c2 = NewTensor2D(c->unitNum/ncolC, -ncolC, c->dataType, c->devID, c->mem);
a2->data = a->data;
c2->data = c->data;
_MatrixMul2D(a2, transposedA, b, transposedB, c2, alpha, beta, parallelRunner);
a2->data = NULL;
c2->data = NULL;
delete a2;
delete c2;
return;
}
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
......@@ -144,10 +162,10 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID);
_CudaBLASMatrixMULList(handle,
aList, transposedA,
bList, transposedB,
cList, aList->count,
alpha, beta);
aList, transposedA,
bList, transposedB,
cList, aList->count,
alpha, beta);
BacktoCudaDev(a->devID, devIDBackup);
#else
......@@ -156,9 +174,9 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
}
else {
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
_MatrixMULBatchedCPU(aList, transposedA,
bList, transposedB,
cList, alpha, beta);
_MatrixMulBatchedCPU(aList, transposedA,
bList, transposedB,
cList, alpha, beta);
}
for (int i = 0; i < aList->count; i++) {
......@@ -251,9 +269,7 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
/*
matrix multiplication with no transposition c = a * b * alpha
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
>> transposedB - indicates whether teh matrices in b are transposed
>> alpha - a coefficient
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
......
......@@ -26,6 +26,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#define BMMul MatrixMulBatched
/*
matrix multiplication of the two tensors c = trans(a) * trans(b) * alpha + c * beta
......@@ -37,6 +39,28 @@ where trans() returns the transposed matrix if the flag is fired
void _MatrixMulBatched(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL);
/*
matrix multiplication of the two tensors c = trans(a) * trans(b) * alpha + c * beta
optimized for GPU
*/
void _MatrixMulBatchedGPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0);
/*
matrix multiplication of the two tensors c = trans(a) * trans(b) * alpha + c * beta
optimized for CPU
*/
void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0);
/*
matrix multiplication of the two tensors c = trans(a) * trans(b) * alpha + c * beta (for list inputs)
optimized for CPU
*/
void _MatrixMulBatchedCPU(const XList * a, MATRIX_TRANS_TYPE transposedA, const XList * b, MATRIX_TRANS_TYPE transposedB,
XList * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0);
/*
matrix multiplication of the two tensors (return a XTensor structure) c = trans(a) * trans(b) * alpha
make a new tensor to keep the result and return it
......@@ -49,6 +73,17 @@ where trans() returns the transposed matrix if the flag is fired
XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
/*
matrix multiplication of the two tensors (return a XTensor structure) c = a * b * alpha
make a new tensor to keep the result and return it
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
ci = ai * bi * alpha
*/
XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
#endif // __MATRIXMULBATCHED_H__
\ No newline at end of file
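/*
a hedged usage sketch for the batched interface, assuming 3-d inputs where the
leading dimension is the batch and the last two dimensions are the matrices to
multiply (InitTensor3D is assumed as the 3-d init helper)
*/
XTensor a;
XTensor b;
InitTensor3D(&a, 16, 4, 8);    /* 16 matrices of size 4 x 8 */
InitTensor3D(&b, 16, 8, 5);    /* 16 matrices of size 8 x 5 */
a.SetDataRand(-1.0F, 1.0F);
b.SetDataRand(-1.0F, 1.0F);
XTensor c = MatrixMulBatched(a, X_NOTRANS, b, X_NOTRANS);    /* c: 16 x 4 x 5 */
XTensor d = BMMul(a, b);    /* the new no-transposition form */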
......@@ -32,9 +32,9 @@ element-wise product of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the item
>> a - matrix a
>> b - matrix b
>> c - result matrix
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
......
......@@ -104,9 +104,9 @@ void KernelMulElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE al
int offseti = i % stride;
if (nonZeroAlpha == 0)
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj* ldSizeA + offseti] * bp[threadIdx.x][bj* ldSizeB + offseti];
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] * bp[threadIdx.x][bj * ldSizeB + offseti];
else
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj* ldSizeA + offseti] * bp[threadIdx.x][bj* ldSizeB + offseti] +
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] * bp[threadIdx.x][bj * ldSizeB + offseti] +
alpha * cp[threadIdx.x][j * ldSizeC + offseti];
}
......
......@@ -76,7 +76,7 @@ XTensor Sign(const XTensor & a)
XTensor b(&a);
b.SetTMP();
/* call _ScaleAndShift function */
/* call _Sign function */
_Sign(&a, &b);
/* tensor connections */
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "Sub.h"
#include "Sub.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor subtraction c = a - b * \beta
>> a - a tensor
>> b - another tensor
>> c - where we put a-b*\beta. we save it in a if c is NULL
>> beta - the scaling factor
*/
void _Sub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == b->unitNum && a->unitNum == c->unitNum,
"Unmatched tensors in subtraction!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched tensors in subtraction!");
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
if (a == c) {
int P2PAccesible = 0;
#ifdef CUDA_UVA
cudaDeviceCanAccessPeer(&P2PAccesible, a->devID, b->devID);
#endif
if ((a->devID < 0 && b->devID >= 0) ||
(a->devID >= 0 && b->devID < 0) ||
(a->devID >= 0 && b->devID >= 0 && a->devID != b->devID && !P2PAccesible))
{
ShowNTErrors("Cannot run this method on multiple devices simultaneously!");
}
else
_CudaSub(a, b, c, beta);
}
else
_CudaSub(a, b, c, beta);
#endif
}
else {
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse tensor in addition!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE)
{
DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data;
/* unrolling */
int num = a->unitNum;
if (num % 4 == 0) {
for (int i = 0; i < num; i += 4) {
cp[i] = ap[i] - bp[i] * beta;
cp[i + 1] = ap[i + 1] - bp[i + 1] * beta;
cp[i + 2] = ap[i + 2] - bp[i + 2] * beta;
cp[i + 3] = ap[i + 3] - bp[i + 3] * beta;
}
}
else if (num % 2 == 0) {
for (int i = 0; i < num; i += 2) {
cp[i] = ap[i] - bp[i] * beta;
cp[i + 1] = ap[i + 1] - bp[i + 1] * beta;
}
}
else {
for (int i = 0; i < num; i++) {
cp[i] = ap[i] - bp[i] * beta;
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
}
/*
tensor subtraction a = a - b * \beta (do it on site)
keep the result in the tensor a and return nothing
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
*/
void _SubMe(XTensor * a, const XTensor * b, DTYPE beta)
{
_Sub(a, b, a, beta);
}
/*
tensor subtraction c = a - b * \beta (return a XTensor structure)
make a new tensor c to keep the result and return it
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
<< return - the result of tensor subtraction
*/
XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
/* call _Sub function */
_Sub(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
return c;
}
} // namespace nts(NiuTrans.Tensor)
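/*
a short usage sketch for the new subtraction, assuming two dense tensors of
the same shape (InitTensor1D is assumed as the 1-d init helper); beta scales
the subtrahend
*/
XTensor a;
XTensor b;
InitTensor1D(&a, 8);
InitTensor1D(&b, 8);
a.SetDataRand(0.0F, 1.0F);
b.SetDataRand(0.0F, 1.0F);
XTensor c = Sub(a, b, 0.5F);    /* c = a - 0.5 * b */
_SubMe(&a, &b);                 /* a = a - b */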
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "Sub.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
subtraction of data arrays (CUDA Kernel)
c = a - b * \beta
>> a - A matrix
>> b - another matrix
>> c - where we put a-b
>> size - the size of a/b/c
>> beta - the coefficient
*/
__global__
void KernelSUB(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] - b[i] * beta;
}
/*
tensor subtraction c = a - b * \beta (cuda version)
>> a - a tensor
>> b - another tensor
>> c - where we put a-b*\beta.
>> beta - the scaling factor
*/
void _CudaSub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors((a->unitNum == b->unitNum && a->unitNum == c->unitNum),
"Unmatched tensors in subtraction!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
"Unmatched tensors in subtraction!");
CheckNTErrors((a->devID == b->devID && a->devID == c->devID),
"The tensors must be on the same device!");
int devIDBackup = XDevice::GetGPUDevice();
XDevice::SetGPUDevice(a->devID);
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse matrix in addition!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE)
{
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
KernelSUB << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, a->unitNum, beta);
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
XDevice::SetGPUDevice(devIDBackup);
}
/* subtraction over arrays
tensor subtraction c = a - b * \beta (cuda version) with an input handle
>> devID - device ID (MUST >= 0)
>> handle - cuda handle
>> a - an array
>> b - another array
>> c - where we put a-b
>> size - size of the array
>> beta - the coefficient
*/
void _CudaSubWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
{
if (size == 0)
return;
if (c == NULL)
c = a;
CheckNTErrors((a && b && c), "Empty arrays in addition!");
int devIDBackup;
ProtectCudaDev(devID, devIDBackup);
if (c == a) {
/* axpy computes a = coefficient * b + a, so the coefficient must be -beta for subtraction */
DTYPE negBeta = -beta;
#ifdef DOUBELPRICSION
cublasDaxpy(*handle, size, &negBeta, b, 1, a, 1);
#else
cublasSaxpy(*handle, size, &negBeta, b, 1, a, 1);
#endif
}
else {
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(devID, size, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
KernelSUB<<<blocks, threads>>>((DTYPE*)a, (DTYPE*)b, (DTYPE*)c, size, beta);
}
BacktoCudaDev(devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __SUB_CUH__
#define __SUB_CUH__
#include "Sub.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* subtraction of data arrays (CUDA Kernel) */
__global__
void KernelSUB(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta (cuda version) */
void _CudaSub(const XTensor * a, const XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta (cuda version) with an input handle */
void _CudaSubWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.0);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __SUB_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
* Today is the first day of August. It's still very hot.
*/
#ifndef __ABSOLUTE_H__
#define __ABSOLUTE_H__
#ifndef __SUB_H__
#define __SUB_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its absolute value */
void _Absolute(const XTensor * a, XTensor * b);
/* tensor subtraction c = a - b * \beta */
void _Sub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
/*
set every entry to its absolute value (do it on site)
/*
tensor subtraction a = a - b * \beta
keep the result in the input tensor a and return nothing
*/
void _AbsoluteMe(XTensor * a);
/*
set every entry to its absolute value (return a XTensor structure)
make a new tensor to keep the result and return it
void _SubMe(XTensor * a, const XTensor * b, DTYPE beta = (DTYPE)1.0);
/*
tensor subtraction c = a - b * \beta
make a new tensor c to keep the result and return it
*/
XTensor Absolute(const XTensor & a);
XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __ABSOLUTE_H__
#endif // __SUB_H__
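/*
A usage sketch of the interface declared above (illustrative only: the shapes and the
name SubUsageExample are assumptions, and devID = -1 keeps the tensors on the host).
*/
void SubUsageExample()
{
    int dimSize[2] = {2, 3};
    XTensor a(2, dimSize, X_FLOAT, 1.0F, -1, NULL);
    XTensor b(2, dimSize, X_FLOAT, 1.0F, -1, NULL);

    /* ... fill a and b with data here ... */

    XTensor c = Sub(a, b, 2.0F);    /* c = a - 2 * b */
    _SubMe(&a, &b);                 /* a = a - b (in place) */
}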
......@@ -22,8 +22,10 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../movement/CopyValues.h"
#include "Sum.h"
#include "Sum.cuh"
#include "SumDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -43,8 +45,12 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched tensors in addition!");
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
if(beta == 0){
_CopyValues(a, c);
return;
}
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
if (a == c) {
int P2PAccesible = 0;
......@@ -67,7 +73,7 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
}
else {
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse matrix in addition!");
CheckNTErrors(!c->isSparse, "Illegal use of sparse tensor in addition!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
......@@ -123,6 +129,33 @@ void _SumMe(XTensor * a, const XTensor * b, DTYPE beta)
{
_Sum(a, b, a, beta);
}
/*
return the dimension index if the sum can be performed as SumDim (see SumDim.h for more details)
>> a - a tensor
>> b - another tensor for sum
<< return - the dimension index for SumDim, or -1 if the normal sum should be used
*/
int GetSumDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
/* the dimension of a that b matches, counting from the last dimension */
hitDim = a.order - 1 - i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
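/*
A sketch of the expected behaviour of GetSumDimIndex (shapes and the name
GetSumDimIndexExample are illustrative, not from the test suite): when b matches
exactly one trailing dimension of a, that dimension index is returned and Sum can
fall back to _SumDim; otherwise -1 is returned and the ordinary _Sum path is taken.
*/
void GetSumDimIndexExample()
{
    int aSize[3] = {8, 16, 32};
    int bSize[1] = {32};
    XTensor a(3, aSize, X_FLOAT, 1.0F, -1, NULL);
    XTensor b(1, bSize, X_FLOAT, 1.0F, -1, NULL);

    int n = GetSumDimIndex(a, b);   /* expected: n = 2, the last dimension of a */
}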
/*
tensor summation c = a + b * \beta (return a XTensor structure)
......@@ -137,13 +170,29 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
int n = GetSumDimIndex(a, b);
if(n == -1){
/* call _Sum function */
_Sum(&a, &b, &c, beta);
/* call _Sum function */
_Sum(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta);
}
else if(n >= 0 && n < a.order){
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
else{
ShowNTErrors("Something is wrong!");
}
return c;
}
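/*
A usage sketch of the broadcasting path above (illustrative shapes and names; this is
not library documentation): adding a bias vector whose size equals the last dimension
of x is now routed through _SumDim automatically.
*/
void SumBroadcastExample()
{
    int xSize[2] = {64, 128};
    int bSize[1] = {128};
    XTensor x(2, xSize, X_FLOAT, 1.0F, -1, NULL);
    XTensor bias(1, bSize, X_FLOAT, 1.0F, -1, NULL);

    /* ... fill x and bias with data here ... */

    XTensor y = Sum(x, bias, 1.0F);    /* y[i][j] = x[i][j] + bias[j] */
}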
......
......@@ -20,6 +20,7 @@
*/
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "Sum.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-29
*/
#include "Sum.h"
#include "SumDim.h"
#include "SumDim.cuh"
#include "../../XName.h"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor summation
c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
*/
void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in addition!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in addition!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in addition!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
if(beta == 0){
_CopyValues(a, c);
return;
}
if(XTensor::IsSameShaped(a, b)){
_Sum(a, b, c, beta);
return;
}
if(a->devID >= 0 || b->devID >= 0 || c->devID >= 0){
#ifdef USE_CUDA
_CudaSumDim(a, b, c, n, beta);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
else{
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for(int i = a->order - 1; i >= 0; i--){
if(i > n)
stride *= a->dimSize[i];
else if(i < n)
blockNum *= a->dimSize[i];
}
if (a->dataType == DEFAULT_DTYPE){
int num = a->unitNum;
if(stride > 1){
for(int i = 0, j = 0; i < num; i += stride, j++){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE bv = *((DTYPE*)b->data + j % blockSize) * beta;
DTYPE * cp = (DTYPE*)c->data + i;
for(int k = 0; k < stride; k++)
cp[k] = ap[k] + bv;
}
}
else if(stride == 1){
DTYPE * bp = (DTYPE*)b->data;
for(int i = 0; i < num; i += blockSize){
DTYPE * ap = (DTYPE*)a->data + i;
DTYPE * cp = (DTYPE*)c->data + i;
if(beta == 1.0F){
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] + bp[j];
}
else{
for(int j = 0; j < blockSize; j++)
cp[j] = ap[j] + bp[j] * beta;
}
}
}
else{
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
}
}
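/*
A plain reference of the stride/blockSize/blockNum decomposition used above (a sketch
for clarity; SumDimReference is an illustrative name). It mirrors the stride > 1
branch: consecutive runs of `stride` elements share one value of b.
*/
static void SumDimReference(const float * a, const float * b, float * c,
                            int blockNum, int blockSize, int stride, float beta)
{
    for (int block = 0; block < blockNum; block++) {
        for (int j = 0; j < blockSize; j++) {
            const float * ap = a + (block * blockSize + j) * stride;
            float * cp = c + (block * blockSize + j) * stride;
            float bv = b[j] * beta;
            for (int k = 0; k < stride; k++)
                cp[k] = ap[k] + bv;    /* broadcast b[j] over the inner stride */
        }
    }
}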
/*
tensor summation (do it on site)
keep the result in the input tensor and return nothing
a = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> beta - the scaling factor
*/
void _SumDim(XTensor * a, const XTensor * b, int n, DTYPE beta)
{
_SumDim(a, b, a, n, beta);
}
/*
tensor summation (return a XTensor structure and make tensor connections)
make a new tensor to keep the result and return it
c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> n - the dimension index
>> beta - the scaling factor
<< return - the result tensor by tensor summation
*/
XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
return c;
}
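/*
A sketch of calling SumDim directly when the broadcast dimension is known in advance
(illustrative shapes and names): here a per-row offset is added along dimension 0,
rather than the per-column case that Sum detects automatically.
*/
void SumDimUsageExample()
{
    int xSize[2] = {64, 128};
    int rSize[1] = {64};
    XTensor x(2, xSize, X_FLOAT, 1.0F, -1, NULL);
    XTensor r(1, rSize, X_FLOAT, 1.0F, -1, NULL);

    /* ... fill x and r with data here ... */

    XTensor y = SumDim(x, r, 0, 1.0F);    /* y[i][j] = x[i][j] + r[i] */
}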
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-29
*/
#include "SumDim.cuh"
#include "../../XDevice.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
tensor summation of a tensor and a row vector
c = a + b * \beta
where a is a tensor and b is a row vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c
>> colNum - number of columns of a and c (i.e., the size of b)
>> beta - the scaling factor
*/
template <class T, bool betaFired>
__global__
void KernelAddWithRow(T * a, T * b, T * c, int rowNum, int colNum, T beta)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int col = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
if(col >= colNum || row >= rowNum)
return;
if(threadIdx.y == 0)
bv[threadIdx.x] = b[col];
__syncthreads();
int offset = colNum * row + col;
if(betaFired)
c[offset] = a[offset] + bv[threadIdx.x] * beta;
else
c[offset] = a[offset] + bv[threadIdx.x];
}
/*
tensor summation of a tensor and a column vector
c = a + b * \beta
where a is a tensor and b is a column vector
>> a - pointer to the data array of a
>> b - pointer to the data array of b
>> c - pointer to the data array of c
>> rowNum - number of rows of a and c (i.e., the size of b)
>> colNum - number of columns of a and c
>> blockSize - size of a block (matrix), i.e., rowNum * colNum
>> blockNum - number of matrices
>> beta - the scaling factor
*/
template <class T, bool betaFired>
__global__
void KernelAddWithCol(T * a, T * b, T * c, int rowNum, int colNum, int blockSize, int blockNum, T beta)
{
__shared__ T bv[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int colIndex = blockDim.x * blockIdx.x + threadIdx.x;
int row = blockDim.y * blockIdx.y + threadIdx.y;
int col = colIndex % colNum;
int block = colIndex / colNum;
if(row >= rowNum || block >= blockNum)
return;
if(threadIdx.x == 0)
bv[threadIdx.y] = b[row];
__syncthreads();
int offset = block * blockSize + row * colNum + col;
if(betaFired)
c[offset] = a[offset] + bv[threadIdx.y] * beta;
else
c[offset] = a[offset] + bv[threadIdx.y];
}
/*
tensor summation (cuda version)
c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
*/
void _CudaSumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == c->unitNum, "Unmatched tensors in addition!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched data types in addition!");
CheckNTErrors(a->order == c->order, "The input tensors do not have the same order in addition!");
CheckNTErrors(!a->isSparse && !b->isSparse && !c->isSparse, "Dense tensors are required!");
CheckNTErrors(a->dimSize[n] == b->unitNum, "Wrong tensor size!");
int stride = 1;
int blockSize = a->dimSize[n];
int blockNum = 1;
for(int i = a->order - 1; i >= 0; i--){
if(i > n)
stride *= a->dimSize[i];
else if(i < n)
blockNum *= a->dimSize[i];
}
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE){
if(stride > 1){
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
if(beta == (DTYPE)1.0F)
KernelAddWithCol<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, beta);
else
KernelAddWithCol<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockSize, stride, blockSize * stride, blockNum, beta);
}
else if(stride == 1){
GDevs.GetCudaThread2D(a->devID, blockSize, blockNum, MAX_INT, cudaGrids, cudaBlocks);
if(beta == (DTYPE)1.0F)
KernelAddWithRow<DTYPE, false> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, beta);
else
KernelAddWithRow<DTYPE, true> <<<dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1])>>>
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data,
blockNum, blockSize, beta);
}
else{
ShowNTErrors("Something is wrong!");
}
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif
} // namespace nts(NiuTrans.Tensor)