Commit 4388c982 by xuchen

Merge branch 'xiaotong-working'

parents dd6646ed 9d33e210
NiuTrans.Tensor.vcxproj
NiuTrans.Tensor.vcxproj.filters
x64/
vc140.pdb
NiuTrans.Tensor.vcxproj.user
NiuTrans.Tensor.aps
NiuTrans.Tensor张量计算库
\ No newline at end of file
# NiuTrans.Tensor张量计算库
## NiuTrans.Tensor
NiuTrans.Tensor是小牛开源项目所开发的一个工具包,提供了完整的张量定义及计算功能,可以被用于深度学习相关研究及工业系统的开发。NiuTrans.Tensor具有以下特点:
* 简单小巧,易于修改
* c语言编写,代码高度优化
* 同时支持CPU和GPU设备
* 丰富的张量计算接口
* 支持C/C++、Python等调用方式
## 安装方法
在开始创建您的项目并使用NiuTrans.Tensor工具包时,需要注意的是:
* 所创建项目如在CPU上运行,我们的系统支持高性能的数学运算库,推荐安装[MKL](https://software.intel.com/en-us/mkl)[OpenBLAS](http://www.openblas.net/)
* 所创建项目如需在GPU上运行,需安装 [CUDA](https://developer.nvidia.com/cuda-downloads),CUDA版本需求为9.0及以上,CUDA工具为创建高性能GPU加速应用程序提供了开发环境。
小牛开源项目所开发的NiuTrans.Tensor工具包采用源程序编译方法,在Windows和Linux环境下的安装方法如下所示。
### Windows
若在Windows上使用NiuTrans.Tensor工具包:
* 首先需要将NiuTrans.Tensor代码包含在所创建的项目中
* 在所创建项目中需要引用XTensor.h、core里的CHeader.h和function里的FHeader.h这三个头文件:
* 通过XTensor.h可以获取我们需要操作的XTensor类
* 通过core里的CHeader.h可以对Tensor进行一些张量运算
* 通过function里的FHeader.h可以调用一些激活函数
* 在所创建项目中使用命名空间nts
此外,一些必须的环境配置方法请参考 [NiuTrans.Tensor环境配置](http://47.105.50.196/NiuTrans/NiuTrans.Tensor/blob/linye/doc/Configuration.md)
### Linux
若在Linux上使用NiuTrans.Tensor工具包,直接执行make.sh即可在同级目录下生成tensorCPU和tensorGPU,分别对应于NiuTrans.Tensor的CPU以及GPU的可执行文件。以前馈神经网络语言模型为例,输入以下命令即可在GPU上执行提供的测试用例:
>./tensorGPU -test
更多详细使用方法请见[NiuTrans.Tensor开发文档](http://47.104.97.237/niutrans/site/niutensor/index.html)
## 开发团队
NiuTrans.Tensor张量计算库由东北大学自然语言处理实验室、小牛翻译、小牛雅智合作开发,致力于为深度学习相关研究及工业系统的开发提供完整的张量定义及计算功能。
## 更新版本
NiuTrans.Tensor version 0.1.0 - 2018年8月3日
\ No newline at end of file
# NiuTrans.Tensor环境配置
## 注意事项
CUDA最新版本9.2尚且不支持VS2017最新版本,因此建议使用CUDA版本为9.0或9.1,建议使用VS版本为VS2015,或使用VS2017时安装v140工具集。
## CUDA配置
在已安装好VS、CUDA并配置好环境变量后,一些关键的CUDA配置选项如下所示,以下配置选项在 **项目 -> 属性** 中可以找到。
>$(CUDA_PATH)\include
加入到 **VC++目录 -> 包含** 中。
>$(CUDA_PATH)\lib\Win32
加入到 **VC++目录 -> 库** 中。
>cuda.lib;cudadevrt.lib;cudart.lib;cudart_static.lib;nvcuvid.lib;OpenCL.lib;cublas.lib;curand.lib;
加入到 **链接器->输入->附加依赖项** 中。
配置完成后,右键 **工程->项目依赖性** ,选择CUDA9。
在.cu文件上右键属性,在项类型中选择"CUDA C/C++"(最好搜索.cu文件,然后全选设置)。
## 其他配置
**C/C++->常规->SDL检查**,设为否。
**C/C++->预处理器->预处理器定义** 中,添加
>USE_CUDA;USE_BLAS;WIN32;MKL;DEBUG;CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS_
CONSOLE;
**链接器->系统->子系统**,设置为控制台。
**常规->字符集**,使用Unicode字符集。
**调试->命令参数**中设置可执行文件所需要的参数。
......@@ -21,25 +21,31 @@
#include <stdio.h>
#include "XNet.h"
#include "../tensor/XUtility.h"
#include "../tensor/function/FHeader.h"
#include "../tensor/core/CHeader.h"
#include "../sample/fnnlm/FNNLM.h"
#include "../sample/transformer/Transformer.h"
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
//#include <crtdbg.h>
using namespace nts;
using namespace samplefnnlm;
void TransposeTest();
void SumDimTest();
using namespace nts;
using namespace fnnlm;
using namespace transformer;
int main( int argc, const char ** argv )
{
if(argc > 1 && !strcmp(argv[1], "-test"))
1;//Test();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
//_CrtSetBreakAlloc(896);
if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1);
else if(argc > 1 && !strcmp(argv[1], "-t2t"))
TransformerMain(argc - 1, argv + 1);
else{
fprintf(stderr, "Thanks for using NiuTrans.Network! This is a library for building\n");
fprintf(stderr, "neural networks in an easy way. \n\n");
......@@ -47,36 +53,120 @@ int main( int argc, const char ** argv )
fprintf(stderr, "Or run this program with \"-fnnlm\" for sample FNNLM!\n");
}
XNet net;
XTensor a;
XTensor b;
XTensor c;
//_CrtDumpMemoryLeaks();
return 0;
}
InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 4);
InitTensor2D(&c, 2, 4);
void TransposeTest()
{
#ifdef USE_CUDA
XMem mem0(0, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
//XMem mem1(1, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
XTensor x;
XTensor y;
XTensor z;
a.SetZeroAll();
b.SetZeroAll();
c.SetZeroAll();
int loops = 2000;
SetDataFixed(a, 0.1F);
a.Set2D(0.3F, 1, 0);
a.Set2D(0.4F, 1, 1);
int B = 3 * 2 * 4;
int K = 8 * 1;
int N = 50;
int H = 512 * 4;
b = Merge(a, a, 1);
c = HTanH(MMul(a, b));
int nnn = GDevs.nGPU;
a.Dump(stderr, "a:");
b.Dump(stderr, "b:");
c.Dump(stderr, "c:");
XLink::ShowNetwork(stderr, &c);
InitTensor3D(&x, B, N, H, X_FLOAT, 0);
InitTensor4D(&y, K, B, N, H/K, X_FLOAT, 0);
InitTensor3D(&z, B, N, H, X_FLOAT, 0);
net.Backward(c);
cudaEvent_t ctime0;
cudaEvent_t ctime1;
cudaEvent_t ctime2;
cudaEvent_t ctime3;
cudaEvent_t ctime4;
cudaEvent_t ctime5;
net.Dump(stderr);
//_CrtDumpMemoryLeaks();
float elapsedSplit = 0.0;
float elapsedMerge = 0.0;
float elapsedSum = 0.0;
cudaEventCreate(&ctime0);
cudaEventCreate(&ctime1);
cudaEventCreate(&ctime2);
cudaEventCreate(&ctime3);
cudaEventCreate(&ctime4);
cudaEventCreate(&ctime5);
cudaEventRecord(ctime0, 0);
double time0 = GetClock();
for(int i = 0; i < loops; i++)
_Split(&x, &y, 2, K);
double time1 = GetClock();
return 0;
cudaEventRecord(ctime1, 0);
cudaEventSynchronize(ctime1);
cudaEventElapsedTime(&elapsedSplit, ctime0, ctime1);
cudaEventRecord(ctime2, 0);
double time2 = GetClock();
for(int i = 0; i < loops; i++)
_Merge(&y, &x, 3);
double time3 = GetClock();
cudaEventRecord(ctime3, 0);
cudaEventSynchronize(ctime3);
cudaEventElapsedTime(&elapsedMerge, ctime2, ctime3);
cudaEventRecord(ctime4, 0);
double time4 = GetClock();
for(int i = 0; i < loops; i++)
_Sum(&x, &z, &x);
double time5 = GetClock();
cudaEventRecord(ctime5, 0);
cudaEventSynchronize(ctime5);
cudaEventElapsedTime(&elapsedSum, ctime4, ctime5);
fprintf(stderr, "split:%f merge:%f sum:%f\n", time1 - time0, time3 - time2, time5 - time4);
fprintf(stderr, "split:%f merge:%f sum:%f\n", elapsedSplit, elapsedMerge, elapsedSum);
#endif
}
void SumDimTest()
{
XTensor x;
XTensor y;
XTensor z;
int a = 5;
int b = 7;
int c = 3;
InitTensor3D(&x, a, b, c, X_FLOAT, -1);
InitTensor1D(&y, c, X_FLOAT, -1);
InitTensor3D(&z, a, b, c, X_FLOAT, -1);
x.SetZeroAll();
y.SetZeroAll();
z.SetZeroAll();
float * data = new float[x.unitNum];
for(int i = 0; i < x.unitNum; i++)
data[i] = (DTYPE)i;
x.SetData(data, x.unitNum);
for(int i = 0; i < y.unitNum; i++)
data[i] = -(DTYPE)i;
y.SetData(data, y.unitNum);
_SumDim(&x, &y, &z, 2);
z.Dump(stderr, "z:");
delete[] data;
}
......@@ -63,6 +63,8 @@ void XFuncGrad::MakeGrad(XTensor * node)
else{
ShowNTErrors("Wrong activation function type!");
}
node->visitMark = NODE_FINISHED;
}
/* indicates whether the node is for an activation function */
......
......@@ -44,15 +44,107 @@ private:
static
void GradSum(XTensor * node);
/* gradient for multiply (dot production): c = a * b */
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSumDim(XTensor * node);
/* gradient for multiply (dot production): c = a * b * \alpha */
static
void GradMultiply(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) */
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
XTensor * dedc, DTYPE alpha);
/* gradient for matrix multiply in batch mode.
for each batch: c_i = matmul(a_i, b_i) * \alpha */
static
void GradMatrixMulBatched(XTensor * node);
/* gradient for log: c = log(a) */
static
void GradLog(XTensor * node);
/* gradient for power */
static
void GradPower(XTensor * node);
/* gradient for negate */
static
void GradNegate(XTensor * node);
/* gradient for ScaleAndShift */
static
void GradScaleAndShift(XTensor * node);
/* gradient for Minus */
static
void GradSub(XTensor * node);
/* gradient for Divide */
static
void GradDiv(XTensor * node);
/* gradient for reduceMean */
static
void GradReduceMean(XTensor * node);
/* gradient for reduceSum */
static
void GradReduceSum(XTensor * node);
/* gradient for reduceSumSquared */
static
void GradReduceSumSquared(XTensor * node);
/* gradient for reduceVariance */
static
void GradReduceVariance(XTensor * node);
/* gradient for sin */
static
void GradSin(XTensor * node);
/* gradient for cos */
static
void GradCos(XTensor * node);
/* gradient for tan */
static
void GradTan(XTensor * node);
/* gradient for exp */
static
void GradExp(XTensor * node);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
/* gradient for absolute */
static
void GradAbsolute(XTensor * node);
/* gradient for sign */
static
void GradSign(XTensor * node);
/* gradient for clip */
static
void GradClip(XTensor * node);
/* gradient for round */
static
void GradRound(XTensor * node);
};
}
#endif
\ No newline at end of file
#endif
......@@ -43,6 +43,12 @@ void XShapeGrad::MakeGrad(XTensor * node)
GradMergeList(node);
else if(operID == SHAPE_UNSQUEEZE)
GradUnsqueeze(node);
else if(operID == SHAPE_SPLIT)
GradSplit(node);
else if(operID == SHAPE_SPLIT_LIST)
GradSplitList(node);
else if (operID == SHAPE_TRANSPOSE)
GradTranspose(node);
else{
ShowNTErrors("TODO!");
}
......@@ -55,6 +61,13 @@ bool XShapeGrad::IsShapeOP(XTensor * node)
return (income.typeID & DATA_BASE) != 0;
}
/* post processing of a node */
void XShapeGrad::PostProcessing(XTensor * node, int typeID)
{
if(typeID == SHAPE_SPLIT_LIST)
GradSplitListPost(node);
}
/*
gradient for merge
for
......@@ -134,6 +147,8 @@ void XShapeGrad::GradMerge(XTensor * node)
gradInputSmall.data = NULL;
delete[] dims;
node->visitMark = NODE_FINISHED;
}
/*
......@@ -213,6 +228,120 @@ void XShapeGrad::GradMergeList(XTensor * node)
gradSmall.data = NULL;
delete[] dims;
}
node->visitMark = NODE_FINISHED;
}
/*
gradient computation for split:
for
c = split(a)
we have
dE/da = merge(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradSplit(XTensor * node)
{
XLink &income = node->income;
XTensor * input = income.tails[0];
int whereToSplit = income.GetParamInt(0);
int splitNum = income.GetParamInt(1);
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
CheckNTErrors(splitNum == node->dimSize[0], "Wrong split number!");
XNoder::MakeGrad(input);
/* we can simply merge the gradient tensor
if the input is used in spliting only */
if(input->outgo.tailNum == 1)
_Merge(node->grad, input->grad, whereToSplit + 1, 0);
/* if the tensor is used somewhere else, we need another SUM
for gradient accumulation */
else{
XTensor inputGradTMP(input);
_Merge(node->grad, &inputGradTMP, whereToSplit + 1, 0);
_Sum(input->grad, &inputGradTMP, input->grad);
}
node->visitMark = NODE_FINISHED;
}
/*
gradient computation for spliting
where we return the list of the splits
for
list(c_1, ...) = split(a)
we have
dE/da = merge(dE/c_1, ...)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradSplitList(XTensor * node)
{
XLink &income = node->income;
XTensor * input = income.tails[0];
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
node->visitMark = NODE_DOING;
}
/*
gradient computation for spliting. We return
the list of the splits : list(c_1, ...) = split(a).
this method is called only when all nodes of spliting
have been processed. We do this in a post-processing
manner because we can fuze multiple memory copy jobs
one time. This is good for system speed up.
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradSplitListPost(XTensor * node)
{
/* we compute the gradient for current node, rather than for
child node, i.e., we use the outgoing edge here */
XLink &outgo = node->outgo;
XList splits(outgo.tailNum);
int whereToSplit = -1;
int splitNum = 0;
for(int i = 0; i < outgo.tailNum; i++){
XTensor * parent = (XTensor*)outgo.tails[i];
XLink &income = parent->income;
if(income.typeID == SHAPE_SPLIT_LIST){
int w = income.GetParamInt(0);
int splitID = income.GetParamInt(1);
if(whereToSplit < 0)
whereToSplit = w;
splitNum++;
CheckNTErrors(whereToSplit == w, "Wrong dimension for spliting");
CheckNTErrors(income.tailNum == 1, "Something wrong with outgoing edge!");
CheckNTErrors(splitNum - 1 == splitID, "Wrong split id!");
splits.Add(parent);
}
}
/* we can simply merge the gradient tensor
if the node is used in spliting only */
if(outgo.tailNum == splitNum){
_Merge(&splits, node->grad, whereToSplit + 1);
}
/* if the tensor is used as input to other nodes
somewhere else, we need another SUM for gradient
accumulation */
else{
XTensor nodeGradTMP(node);
_Merge(&splits, &nodeGradTMP, whereToSplit + 1);
_Sum(node->grad, &nodeGradTMP, node->grad);
}
}
/*
......@@ -239,6 +368,40 @@ void XShapeGrad::GradUnsqueeze(XTensor * node)
CheckNTErrors(output->unitNum = input->unitNum * dSize, "Wrong tensor size!");
_ReduceSum(output->grad, input->grad, dim);
node->visitMark = NODE_FINISHED;
}
/*
gradient for transposing a tensor
for
c = Transpose(a)
we have
dE/da = Transpose(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradTranspose(XTensor * node)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for TRANSPOSE!");
XTensor * output = node;
XTensor * input = income.tails[0];
XTensor * b = NewTensor(input);
XNoder::MakeGrad(input);
int i = income.GetParamInt(0);
int j = income.GetParamInt(1);
CheckNTErrors(input->order > i && i >= 0, "index of dimension is out of scope!");
CheckNTErrors(input->order > j && j >= 0, "index of dimension is out of scope!");
_Transpose(output->grad, b, i, j);
_Sum(input->grad, b, input->grad);
node->visitMark = NODE_FINISHED;
delete b;
}
}
\ No newline at end of file
......@@ -40,18 +40,41 @@ public:
static
bool IsShapeOP(XTensor * node);
/* post processing of a node */
static
void PostProcessing(XTensor * node, int typeId);
private:
/* gradient for merge: c = merge(a, b, ...) */
/* gradient computation for merge: c = merge(a, b, ...) */
static
void GradMerge(XTensor * node);
/* gradient for merging a list of tensors : c = merge(list(a, b, ...)) */
/* gradient computation for merging a list of tensors : c = merge(list(a, b, ...)) */
static
void GradMergeList(XTensor * node);
/* gradient for unsqueezing a tensor : c = unsqueeze(a) */
/* gradient computation for split: c = split(a) */
static
void GradSplit(XTensor * node);
/* gradient computation for spliting. we return the list of the splits : list(c_1, ...) = split(a) */
static
void GradSplitList(XTensor * node);
/* gradient computation for spliting. we return the list of the splits : list(c_1, ...) = split(a).
this method is called only when all nodes of spliting have been processed. We do this in a post-processing
manner because we can fuze multiple memory copy jobs one time. This is good for system speed up. */
static
void GradSplitListPost(XTensor * node);
/* gradient computation for unsqueezing a tensor : c = unsqueeze(a) */
static
void GradUnsqueeze(XTensor * node);
/* gradient computation for unsqueezing a tensor : c = unsqueeze(a) */
static
void GradTranspose(XTensor * node);
};
}
......
......@@ -46,6 +46,11 @@ unsigned int MakeNetID()
return id;
}
void XNetClearAll()
{
MUTEX_DELE(netMutex);
}
/* constructor */
XNet::XNet()
{
......@@ -143,7 +148,7 @@ void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
/* back-propagation from output to input */
for(int i = nodes.count - 1; i >= 0; i--){
XTensor * node = (XTensor*)nodes.Get(i);
XTensor * node = (XTensor*)nodes.Get(i);;
if(node->visitMark == NODE_FINISHED)
continue;
......@@ -176,6 +181,10 @@ void XNet::BackwardNode(XTensor * node)
return;
if(!XNoder::IsLeaf(node)){
/* post processing for parent nodes */
BackwardNodePost(node);
/* process the current node */
if(XMathGrad::IsMathOP(node))
XMathGrad::MakeGrad(node);
else if(XFuncGrad::IsFunc(node))
......@@ -186,8 +195,24 @@ void XNet::BackwardNode(XTensor * node)
ShowNTErrors("Wrong node type!");
}
}
}
/*
backward computation (in post processing) for a given node
>> node - the node whose parent nodes are not processed yet. So
we do the job at the child node.
*/
void XNet::BackwardNodePost(XTensor * node)
{
bool isSplitList = false;
XLink &outgo = node->outgo;
for(int i = 0; i < outgo.tailNum; i++){
if(outgo.tails[i]->income.typeID == SHAPE_SPLIT_LIST)
isSplitList = true;
}
node->visitMark = NODE_FINISHED;
if(isSplitList)
XShapeGrad::PostProcessing(node, SHAPE_SPLIT_LIST);
}
/*
......@@ -238,10 +263,11 @@ void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
if(node == NULL)
return;
//fprintf(stderr, "%d\n", node->id);
if(node->visitMark == code + 1){
ShowNTErrors("There is a circle in the network\n");
}
else if(node->visitMark <= code || node->visitMark >= code + 2){
else if(node->visitMark <= code){
node->visitMark = code + 1;
XLink &income = node->income;
for(int i = 0; i < income.tailNum; i++){
......
......@@ -73,6 +73,9 @@ struct XNet
/* backward computation for a given node */
void BackwardNode(XTensor * node);
/* backward computation (in post processing) for a given node */
void BackwardNodePost(XTensor * node);
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XTensor &root);
......@@ -92,6 +95,7 @@ struct XNet
extern unsigned int netIDGlobal;
extern MUTEX_HANDLE netMutex;
extern unsigned int MakeNetID();
extern void XNetClearAll();
}
......
......@@ -36,7 +36,7 @@
using namespace nts;
namespace samplefnnlm
namespace fnnlm
{
#define _EXIT_(x)// exit(x)
......@@ -126,7 +126,7 @@ struct FNNNet
XTensor output;
};
/* entry of the program */
/* entrance of the program */
int FNNLMMain(int argc, const char ** argv);
};
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TAttention.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TAttention::T2TAttention()
{
nhead = -1;
dk = -1;
dv = -1;
d = -1;
}
/* deconstructor */
T2TAttention::~T2TAttention()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamInt(argc, argv, "d", &dk, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &dv, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_BEDDING_SIZE);
LoadParamFloat(argc, argv, "attminmax", &minmax, 0.08F);
InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
InitTensor2D(&wq, d, dk, X_FLOAT, devID, mem);
InitTensor2D(&wv, d, dv, X_FLOAT, devID, mem);
wk.SetDataRand(-minmax, minmax);
wq.SetDataRand(-minmax, minmax);
wv.SetDataRand(-minmax, minmax);
}
/*
make the network
>> k - keys. It might be of size B * L * H
where B = batch size, L = sequence length,
and H = vector size of each position
>> q - queries
>> v - values
<< return - multi-attention result
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
{
XTensor k2;
XTensor q2;
XTensor v2;
/* linear transofmration before self-attention */
k2 = MMul(k, wk);
q2 = MMul(q, wq);
v2 = MMul(v, wv);
XTensor kheads;
XTensor qheads;
XTensor vheads;
/* multi head */
kheads = Split(k2, k2.order - 1, nhead);
qheads = Split(q2, q2.order - 1, nhead);
vheads = Split(v2, v2.order - 1, nhead);
XTensor att;
XTensor scalar;
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */
scalar = Softmax(Linear(BMMul(qheads, X_NOTRANS, kheads, X_TRANS), 1/sqrt((float)dk)), -1);
att = BMMul(scalar, vheads);
/* concatenate the heads */
return Merge(att, att.order - 1);
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TATTENTION_H__
#define __T2TATTENTION_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
/*
multi-head attention
y(Q, K, V) = cat(head_1, head_2, ..., head_n)
where head_i = Attention(Q * w_i^Q, K * w_i^K, V * w_i^V)
attention(Q, K, V) = softmax(Q * K^T/d_k^0.5) V
d_k = dimension size of K
*/
class T2TAttention
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* head number */
int nhead;
/* transformation matrix for K */
XTensor wk;
/* transformation matrix for Q */
XTensor wq;
/* transformation matrix for V */
XTensor wv;
/* size of transformed Q and K */
int dk;
/* size of transformed V */
int dv;
/* size of input Q, K and V */
int d;
public:
/* constructor */
T2TAttention();
/* de-constructor */
~T2TAttention();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TDECODER_H__
#define __T2TDECODER_H__
namespace transformer
{
class T2TDecoder
{
};
class AttDecoder : T2TDecoder
{
public:
/* initialize the model */
void InitModel(int argc, const char ** argv);
};
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-01
*/
#include <math.h>
#include "T2TEmbedding.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TEmbedder::T2TEmbedder()
{
devID = -1;
mem = NULL;
vSize = -1;
maxLength = -1;
}
/* deconstructor */
T2TEmbedder::~T2TEmbedder()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TEmbedder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
int d = 0;
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "maxlen", &maxLength, 256);
LoadParamInt(argc, argv, "d", &eSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_BEDDING_SIZE);
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID, mem);
w.SetDataRandn(0, sqrt((float)eSize));
/* create the positional embedding matrix */
MakePosEmbedding(eSize, d, maxLength);
}
/*
make positional embeddings (of size eSize * length
eSize - embedding size
length - length of the sequenc
*/
void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
{
InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT, devID, mem);
float * data = new float[posEmbeddingBase.unitNum];
for(int pos = 0; pos < length; pos++){
float * dp = data + pos * eSize;
for(int k = 0; k < eSize; k++){
if(k % 2 == 0){
int i = k/2;
dp[k] = sin(pos/pow(10000.0F, 2.0F*i/d));
}
else{
int i = (k - 1)/2;
dp[k] = cos(pos/pow(10000.0F, 2.0F*i/d));
}
}
}
posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
delete[] data;
}
/*
make the network
*/
XTensor T2TEmbedder::Make(XTensor &input)
{
CheckNTErrors(input.GetDim(-1) == vSize, "Wrong vocabulary size!");
CheckNTErrors(input.order > 1, "Wrong input tensor size!");
CheckNTErrors(input.dimSize[input.order - 2] < maxLength, "The sequence is too long!");
CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, input.dimSize, input.order * sizeof(int));
dims[input.order - 1] = eSize;
bool match = (posEmbedding.order == input.order);
if(match){
for(int i = 0; i < input.order; i++){
if(dims[i] != posEmbedding.GetDim(i))
match = false;
}
}
/* we make positional embeddings first */
if(!match){
InitTensor(&posEmbedding, input.order, dims, X_FLOAT, 1.0F, devID, mem);
XTensor * posTMP = NewTensorBuf(2, dims + 1, X_FLOAT, 1.0F, devID, mem);
_CopyValues(&posEmbeddingBase, 0, posTMP->unitNum, posTMP, 0);
_Unsqueeze(posTMP, &posEmbedding, 0, dims[0]);
DelTensorBuf(posTMP);
}
XTensor wordEmbedding;
/* then we make word embeddings */
wordEmbedding = MMul(&input, w);
/* we sum over the two embeddings */
return wordEmbedding + posEmbedding;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-01
*/
#ifndef __T2TEMBEDDING_H__
#define __T2TEMBEDDING_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
#define DEFAULT_BEDDING_SIZE 512
/*
embedding (of word at position i):
word embedding + positional embedding
*/
class T2TEmbedder
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* vocabulary size */
int vSize;
/* embedding size */
int eSize;
/* maximum length of the sequence */
int maxLength;
/* word embedding matrix */
XTensor w;
/* predefined positional embeddings. It can speeds up
the embedding processing by re-loading. */
XTensor posEmbeddingBase;
/* positional embeddings */
XTensor posEmbedding;
public:
/* constructor */
T2TEmbedder();
/* de-constructor */
~T2TEmbedder();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make positional embeddings */
void MakePosEmbedding(int eSize, int d, int length);
/* make the network */
XTensor Make(XTensor &input);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TEncoder.h"
#include "T2TLayerNormal.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
AttEncoder::AttEncoder()
{
}
/* de-constructor */
AttEncoder::~AttEncoder()
{
delete[] attentions;
delete[] fnns;
delete[] layerNorms;
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
LoadParamInt(argc, argv, "nstack", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, 512);
LoadParamInt(argc, argv, "esize", &eSize, 512);
LoadParamInt(argc, argv, "vsize", &vSize, -1);
CheckNTErrors(nlayer > 1, "We have one encoding layer at least!");
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\"");
/* embedding model */
embedder.InitModel(argc, argv, devID, mem);
attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer];
layerNorms = new T2TLN[nlayer];
/* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem);
layerNorms[i].InitModel(argc, argv, myDevID, myMem);
}
}
/*
make the encoding network
>> input - the input tensor of the encoder
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input)
{
XTensor x;
x = embedder.Make(input);
for(int i = 0; i < nlayer; i++){
XTensor att;
XTensor ln;
XTensor fnn;
XTensor res;
/* self attention */
att = attentions[i].Make(x, x, x);
/* residual connection */
res = Sum(att, x);
/* TODO: dropout */
/* layer normalization */
ln = layerNorms[i].Make(res);
/* input of next layer */
x = ln;
/* fnn */
fnn = fnns[i].Make(x);
/* residual connection */
res = Sum(fnn, x);
/* TODO: dropout */
/* layer normalization */
ln = layerNorms[i].Make(res);
/* input of next layer */
x = ln;
}
return x;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TENCODER_H__
#define __T2TENCODER_H__
#include "T2TFNN.h"
#include "T2TAttention.h"
#include "T2TEmbedding.h"
#include "T2TLayerNormal.h"
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
/*
base class of the encoder
*/
class T2TEncoder
{
public:
virtual
XTensor Make(XTensor &input) = 0;
};
/*
the encoder based on RNN
*/
class RNNEncoder : T2TEncoder
{
public:
XTensor Make(XTensor &input);
};
/*
the encoder based on self-attention
*/
class AttEncoder : T2TEncoder
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* layer number */
int nlayer;
/* hidden layer size of the FNN layer */
int hSize;
/* embedding size */
int eSize;
/* vocabulary size */
int vSize;
/* embedding of word at each position */
T2TEmbedder embedder;
/* FNN model of each layer */
T2TFNN * fnns;
/* attention model of each layer */
T2TAttention * attentions;
/* layer normalization */
T2TLN * layerNorms;
/* input tensor of the encoder */
XTensor * input;
/* output tensor of the encoder */
XTensor * output;
public:
/* constructor */
AttEncoder();
/* de-constructor */
~AttEncoder();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */
XTensor Make(XTensor &input);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "T2TFNN.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/function/FHeader.h"
namespace transformer
{
/* constructor */
T2TFNN::T2TFNN()
{
inSize = -1;
outSize = -1;
hSize = -1;
}
/* deconstructor */
T2TFNN::~T2TFNN()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_BEDDING_SIZE);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.08F);
InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID, mem);
InitTensor1D(&b1, hSize, X_FLOAT, devID, mem);
InitTensor2D(&w2, hSize, outSize, X_FLOAT, devID, mem);
InitTensor1D(&b2, outSize, X_FLOAT, devID, mem);
w1.SetDataRand(-minmax, minmax);
b1.SetDataRand(-minmax, minmax);
w2.SetDataRand(-minmax, minmax);
b2.SetDataRand(-minmax, minmax);
}
/*
make the network
y = max(0, x * w1 + b1) * w2 + b2
>> input - the input tensor
>> return - the output tensor
*/
XTensor T2TFNN::Make(XTensor &input)
{
XTensor t1;
/* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MMul(input, X_NOTRANS, w1, X_NOTRANS) + b1);
/* result = t1 * w2 + b2 */
return MMul(t1, X_NOTRANS, w2, X_NOTRANS) + b2;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TFNN_H__
#define __T2TFNN_H__
#include "../../tensor/XTensor.h"
using namespace nts;
namespace transformer
{
/* a fnn: y = max(0, x * w1 + b1) * w2 + b2 */
class T2TFNN
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* size of input vector */
int inSize;
/* size of output vector */
int outSize;
/* size of hidden layers */
int hSize;
/* matrix of transformation 1 */
XTensor w1;
/* bias of transformation 1 */
XTensor b1;
/* matrix of transformation 2 */
XTensor w2;
/* bias of transformation 2 */
XTensor b2;
public:
/* constructor */
T2TFNN();
/* deconstructor */
~T2TFNN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "T2TLayerNormal.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TLN::T2TLN()
{
devID = -1;
mem = NULL;
}
/* de-constructor */
T2TLN::~T2TLN()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TLN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
}
/*
make the network
for each layer representation x, we have
y =
>> input - the input tensor
>> return - layer normalization output
*/
XTensor T2TLN::Make(XTensor &input)
{
XTensor &x = input;
XTensor mean;
XTensor variance;
XTensor standard;
XTensor meanFilled;
XTensor standardFilled;
/* \mu = (sum_i x_i)/m */
mean = ReduceSum(x, x.order - 1);
/* \sigma = (sum_i (x_i - \mu)^2)/m */
variance = ReduceVariance(x, x.order - 1, mean);
/* standard = sqrt(variance) */
standard = Power(variance, 0.5F);
/* unsqueeze mean and standard deviation to fit them into
the same size of x */
meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1));
standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
/* x' = (x - \mu)/standard */
return (x - meanFilled)/standardFilled;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TLAYERNORMAL_H__
#define __T2TLAYERNORMAL_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
class T2TLN
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
public:
/* constructor */
T2TLN();
/* de-constructor */
~T2TLN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "T2TModel.h"
#include "T2TUtility.h"
namespace transformer
{
/* constructor */
T2TModel::T2TModel()
{
devID = -1;
mem = NULL;
isLM = false;
isMT = false;
}
/* de-constructor */
T2TModel::~T2TModel()
{
delete mem;
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TModel::InitModel(int argc, const char ** argv)
{
bool useMem = false;
LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamBool(argc, argv, "mem", &useMem, useMem);
LoadParamBool(argc, argv, "lm", &isLM, true);
LoadParamBool(argc, argv, "mt", &isMT, false);
if(useMem){
delete mem;
mem = new XMem(devID);
}
encoder.InitModel(argc, argv, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
}
/*
make the encoding network
>> input - input tensor
<< return - encoding result
*/
XTensor T2TModel::MakeEncoding(XTensor &input)
{
return encoder.Make(input);
}
/*
make the entire network (with the output softmax layer)
>> input - input tensor
>> output - output tensor (distribution)
*/
void T2TModel::Make(XTensor &input, XTensor &output)
{
if(isLM){
XTensor encoding;
encoding = MakeEncoding(input);
outputLayer.Make(encoding, output);
}
else{
ShowNTErrors("TODO!");
}
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TMODEL_H__
#define __T2TMODEL_H__
#include "T2TFNN.h"
#include "T2TAttention.h"
#include "T2TEncoder.h"
#include "T2TDecoder.h"
#include "T2TOutput.h"
namespace transformer
{
class T2TModel
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* the encoder */
AttEncoder encoder;
/* the decoder */
AttDecoder decoder;
/* output layer */
T2TOutput outputLayer;
/* indicates whether the model is running for language modeling */
bool isLM;
/* indicates whether the model is running for machine translation */
bool isMT;
public:
/* constructor */
T2TModel();
/* de-constructor */
~T2TModel();
/* initialize the model */
void InitModel(int argc, const char ** argv);
/* make the encoding network */
XTensor MakeEncoding(XTensor &input);
/* make the entire network (with the output softmax layer) */
void Make(XTensor &input, XTensor &output);
};
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "T2TOutput.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TOutput::T2TOutput()
{
devID = -1;
mem = NULL;
vSize = -1;
inSize = -1;
hSize = -1;
}
/* de-constructor */
T2TOutput::~T2TOutput()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TOutput::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &hSize, DEFAULT_BEDDING_SIZE);
LoadParamFloat(argc, argv, "outputminmax", &minmax, 0.08F);
InitTensor2D(&w, hSize, vSize, X_FLOAT, devID, mem);
w.SetDataRand(-minmax, minmax);
}
/*
make the network
y = softmax(x * w)
>> input - input tensor
<< return - output tensor
*/
XTensor T2TOutput::Make(XTensor &input)
{
XTensor &x = input;
return LogSoftmax(MMul(x, w), -1);
}
/*
make the network (redefined output tensor)
>> input - input tensor
>> output - output tensor
*/
void T2TOutput::Make(XTensor &input, XTensor &output)
{
XTensor &x = input;
output = LogSoftmax(MMul(x, w), -1);
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TOUTPUT_H__
#define __T2TOUTPUT_H__
#include "../../tensor/function/FHeader.h"
using namespace nts;
namespace transformer
{
/* output layer */
class T2TOutput
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* vocabulary size */
int vSize;
/* input vector size */
int inSize;
/* vector size of the linear transformation */
int hSize;
/* transformation matrix */
XTensor w;
public:
/* constructor */
T2TOutput();
/* de-constructor */
~T2TOutput();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
/* make the network (redefined output tensor) */
void Make(XTensor &input, XTensor &output);
};
}
#endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
*/
#include <math.h>
#include "T2TTrainer.h"
#include "T2TUtility.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
T2TTrainer::T2TTrainer()
{
devID = -1;
mem = NULL;
seqLen = NULL;
nseqBuf = 0;
nextSeq = -1;
}
/* de-constructor */
T2TTrainer::~T2TTrainer()
{
delete[] buf;
delete[] seqLen;
delete[] seqOffset;
}
/*
initialization
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
void T2TTrainer::Init(int argc, const char ** argv)
{
LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamFloat(argc, argv, "lrate", &lrate, 0.001F);
LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
LoadParamInt(argc, argv, "nepoch", &nepoch, 1);
LoadParamInt(argc, argv, "nstep", &nstep, 1);
LoadParamInt(argc, argv, "vsize", &vSize, 1);
LoadParamBool(argc, argv, "sorted", &isLenSorted, false);
LoadParamInt(argc, argv, "bufsize", &bufSize, 50000);
buf = new int[bufSize];
seqLen = new int[bufSize];
seqOffset = new int[bufSize];
}
/*
train the model
>> fn - training data file
>> model - model to train
*/
void T2TTrainer::Train(const char * fn, T2TModel * model)
{
int epoch = 0;
int step = 0;
int wc = 0;
int wordCount = 0;
int wordCountTotal = 0;
bool isEnd = false;
float loss = 0;
XNet net;
double startT = GetClockSec();
for(epoch = 0; epoch < nepoch; epoch++){
FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "cannot open training file!");
wordCount = 0;
/* batch of input sequences */
XTensor batch;
while(LoadBatch(file, &batch, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){
/* output probabilities */
XTensor output;
/* make the network */
model->Make(batch, output);
/* back-propagation for obtaining gradients */
net.Backward(output, batch, CROSSENTROPY);
/* update the parameters */
Update(model);
/* get probabilities */
float prob = GetProb(&output, &batch, NULL);
loss += -prob;
wordCount += wc;
wordCountTotal += wc;
if(++step >= nstep){
isEnd = true;
break;
}
if (step % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
}
}
fclose(file);
}
double elapsed = GetClockSec() - startT;
XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
elapsed, step, epoch);
}
char line[MAX_SEQUENCE_LENGTH];
/*
load data to buffer
>> file - where to load data
*/
int T2TTrainer::LoadBuf(FILE * file)
{
int lineCount = 0;
int seqCount = 0;
int wordCount = 0;
while(fgets(line, MAX_SEQUENCE_LENGTH - 1, file)){
int len = (int)strlen(line);
while(line[len - 1] == '\r' || line[len - 1] == '\n'){
line[len - 1] = 0;
len--;
}
len = (int)strlen(line);
if(len == 0)
continue;
/* how many characters are in a word */
int wSize = 0;
/* how many words are in the sentence */
int wNum = 0;
int wNumLocal = 0;
int i = 0;
for(i = 0; i < len; i++){
/* load word (id) seperated by space or tab */
if((line[i] == ' ' || line[i] == '\t') && wSize > 0){
line[i] = 0;
if(wSize == 3 && line[i - 1] == '|' && line[i - 2] == '|' && line[i - 3] == '|'){
seqLen[seqCount] = wNumLocal;
seqOffset[seqCount] = wordCount + wNum - wNumLocal;
seqCount++;
wNumLocal = 0;
}
else{
buf[wordCount + wNum++] = atoi(line + i - wSize);
wNumLocal++;
}
wSize = 0;
}
else
wSize++;
}
if(wSize > 0){
buf[wordCount + wNum++] = atoi(line + i - wSize);
wNumLocal++;
}
seqLen[seqCount] = wNumLocal;
seqOffset[seqCount] = wordCount + wNum - wNumLocal;
seqCount++;
wordCount += wNum;
lineCount++;
if(wordCount >= bufSize - MAX_SEQUENCE_LENGTH)
break;
}
nseqBuf = seqCount;
nextSeq = 0;
return lineCount;
}
/*
load a batch of sequences
>> file - the handle to the data file
>> batch - the batch
>> step - the step we go over when move to the next sequence
>> vs - vocabulary size
>> sBatch - batch size of sequences
>> wBatch - batch size of words
>> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count
*/
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount)
{
if(nextSeq < 0 || nextSeq >= nseqBuf)
LoadBuf(file);
int seq = MAX(nextSeq, 0);
int wc = 0;
int wn = 0;
int sc = 0;
int max = 0;
while(seq + sc < nseqBuf){
wn = seqLen[seq + sc];
wc += wn;
sc += 1;
if(max < wn)
max = wn;
if(sc >= sBatch && wc >= wBatch)
break;
}
nextSeq = seq + sc;
if(sc > 0){
int dims[MAX_TENSOR_DIM_NUM];
dims[0] = sc;
dims[1] = max;
dims[2] = vs;
if(batch->order != 3 || batch->GetDim(0) != dims[0] ||
batch->GetDim(1) != dims[1] || batch->GetDim(2) != dims[2]){
InitTensor(batch, 3, dims, X_FLOAT, 1.0F, devID, mem);
}
batch->SetZeroAll();
/* this might be slow on GPUs :( */
for(int s = seq; s < seq + sc; s++){
for(int w = 0; w < seqLen[s]; w++){
batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
wCount++;
}
}
}
return sc;
}
/*
get word probabilities for a batch of sequences
>> output - word distribution for each position
>> gold - gold standard
>> wordProbs - word probability for gold prediction
*/
float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
{
XTensor probs;
InitTensor(&probs, output);
/* probs[i,j] = output[i,j] * gold[i,j] */
_Multiply(output, gold, &probs);
/* probability of each word */
XTensor wprobs;
InitTensor1D(&wprobs, output->unitNum/output->GetDim(-1), X_FLOAT, output->devID, output->mem);
int dims[2] = {output->unitNum/output->GetDim(-1), output->GetDim(-1)};
probs.Reshape(2, dims);
_ReduceSum(&probs, &wprobs, 1);
if(wordProbs != NULL)
_CopyValues(&wprobs, wordProbs);
/* reshape the tensor to fit it into the reduce procedure
TODO: XTensor supports scalars */
dims[0] = 1;
dims[1] = probs.unitNum;
probs.Reshape(2, dims);
/* probability for the batch */
XTensor result;
InitTensor1D(&result, 1, X_FLOAT, output->devID, output->mem);
_ReduceSum(&probs, &result, 1);
return result.Get1D(0);
}
/*
update the model by delta rule
>> model - the t2t model
*/
void T2TTrainer::Update(T2TModel * model)
{
XList ws(100);
ws.Add(&model->outputLayer.w);
for(int i = 0; i < model->encoder.nlayer; i++){
ws.Add(&model->encoder.fnns[i].w1);
ws.Add(&model->encoder.fnns[i].b1);
ws.Add(&model->encoder.fnns[i].w2);
ws.Add(&model->encoder.fnns[i].b2);
}
ws.Add(&model->encoder.embedder.w);
for(int i = 0; i < ws.count; i++){
XTensor * para = (XTensor*)ws.Get(i);
XTensor * paraGrad = para->grad;
CheckNTErrors(para != NULL, "NULL parameter tensor!");
CheckNTErrors(paraGrad != NULL, "NULL gradient tensor!");
/* the delta rule */
_Sum(para, paraGrad, para, -lrate);
}
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
*/
#ifndef __T2TTRAINER_H__
#define __T2TTRAINER_H__
#include "T2TModel.h"
#include "../../tensor/function/FHeader.h"
#define MAX_SEQUENCE_LENGTH 1024 * 4
using namespace nts;
namespace transformer
{
/* trainer of the T2T model */
class T2TTrainer
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* buffer for loading words */
int * buf;
/* buffer size */
int bufSize;
/* length of each sequence */
int * seqLen;
/* offset of the first word for each sequence */
int * seqOffset;
/* number of sequences in the buffer */
int nseqBuf;
/* offset for next sequence in the buffer */
int nextSeq;
/* indicates whether the sequence is sorted by length */
bool isLenSorted;
/* vocabulary size of the source side */
int vSize;
/* learning rate */
float lrate;
/* sentence batch size */
int sBatchSize;
/* word batch size */
int wBatchSize;
/* training epoch number */
int nepoch;
/* traing step number */
int nstep;
public:
/* constructor */
T2TTrainer();
/* de-constructor */
~T2TTrainer();
/* initialize the trainer */
void Init(int argc, const char ** argv);
/* train the model */
void Train(const char * fn, T2TModel * model);
/* load data to buffer */
int LoadBuf(FILE * file);
/* load a batch of sequences */
int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount);
/* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
/* update the model by delta rule */
void Update(T2TModel * model);
};
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
namespace transformer
{
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
strcpy(p, argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
strcpy(p, defaultP);
}
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
*(int*)p = atoi(argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
*p = defaultP;
}
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname)){
*(bool*)p = true;
//fprintf(stderr, " %s=%s\n", name, "true");
hit = true;
}
}
if(!hit)
*p = defaultP;
}
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP)
{
char vname[128];
vname[0] = '-';
strcpy(vname + 1, name);
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
*p = (float)atof(argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
*p = defaultP;
}
void ShowParams(int argc, const char ** argv)
{
fprintf(stderr, "args:\n");
for(int i = 0; i < argc; i++){
if(argv[i][0] == '-'){
if(i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else
fprintf(stderr, " %s=yes\n", argv[i]);
}
}
fprintf(stderr, "\n");
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#ifndef __T2TUTILITY_H__
#define __T2TUTILITY_H__
#include <stdio.h>
namespace transformer
{
/* load arguments */
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP);
/* show arguments */
void ShowParams(int argc, const char ** argv);
}
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "Transformer.h"
#include "T2TModel.h"
#include "T2TUtility.h"
#include "T2TTrainer.h"
#include "../../tensor/XDevice.h"
namespace transformer
{
int TransformerMain(int argc, const char ** argv)
{
if(argc == 0)
return 1;
ShowParams(argc, argv);
char * trainFN = new char[MAX_LINE_LENGTH];
LoadParamString(argc, argv, "train", trainFN, "");
T2TModel model;
model.InitModel(argc, argv);
if(strcmp(trainFN, "")){
T2TTrainer trainer;
trainer.Init(argc, argv);
trainer.Train(trainFN, &model);
}
delete[] trainFN;
return 0;
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
*
* An impelementation of the transformer system. See more details
* about FNNLM in
* "Attention Is All You Need" by Vaswani et al.
* https://arxiv.org/pdf/1706.03762.pdf
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
* I start writing the code related to NMT - a long time since my last coding
* work on MT
*/
#ifndef __TRANSFORMER_H__
#define __TRANSFORMER_H__
#include "../../tensor/XGlobal.h"
#include "../../tensor/XTensor.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* entrance of the program */
int TransformerMain(int argc, const char ** argv);
}
#endif
\ No newline at end of file
......@@ -29,6 +29,7 @@
#include "XTensor.h"
#include "XDevice.h"
#include "./test/Test.h"
#include "./core/CHeader.h"
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
......@@ -37,6 +38,7 @@
using namespace nts;
void SmallTest();
void TransposeTest();
int main( int argc, const char ** argv )
{
......@@ -92,3 +94,35 @@ void SmallTest()
c.Dump(stderr, "c:");
d.Dump(stderr, "d:");
}
void TransposeTest()
{
XTensor a;
XTensor b;
int I = 2;
int J = 3;
InitTensor4D(&a, 2, 3, 4, 5);
int * dims = new int[a.order];
memcpy(dims, a.dimSize, sizeof(int) * a.order);
dims[I] = a.dimSize[J];
dims[J] = a.dimSize[I];
InitTensor(&b, 4, dims);
a.SetZeroAll();
b.SetZeroAll();
float * data = new float[a.unitNum];
for(int i = 0; i < a.unitNum; i++)
data[i] = (float)i;
a.SetData(data, a.unitNum, 0);
_Transpose(&a, &b, I, J);
b.Dump(stderr, "b:");
delete[] data;
}
......@@ -40,6 +40,7 @@ XDevManager GDevs;
/* constructor */
XDevice::XDevice()
{
stream = NULL;
Clear();
#ifdef USE_CUDA
......@@ -55,6 +56,8 @@ XDevice::~XDevice()
MUTEX_DELE(cublasMutex);
if(isHandleReady)
cublasDestroy(cublasHandle);
if(stream != NULL)
delete stream;
#endif
}
......@@ -118,6 +121,8 @@ void XDevice::Init(int myDevID)
}
else
sprintf(name2, "GPU-%d %s", devID, name);
stream = new XStream(0, devID);
#endif
}
......@@ -161,6 +166,14 @@ cublasHandle_t * XDevice::GetCublasHandle()
return &cublasHandle;
}
/* get the stream of cuda */
cudaStream_t * XDevice::GetCudaStream()
{
CheckNTErrors(stream != NULL, "the stream is not initialized!");
return &stream->stream;
}
#endif // USE_CUDA
/* switch to a device */
......@@ -311,11 +324,19 @@ void XDevManager::Clear()
/* get the handle of GPU */
cublasHandle_t * XDevManager::GetCudaHandle(const int devID)
{
CheckNTErrors((devID < nGPU), "index of GPU is out of range.");
CheckNTErrors(devID < nGPU, "index of GPU is out of range.");
return GPUs[devID].GetCublasHandle();
}
/* get the stream of cuda */
cudaStream_t * XDevManager::GetCudaStream(const int devID)
{
CheckNTErrors(devID < nGPU, "index of GPU is out of range.");
return GPUs[devID].GetCudaStream();
}
#endif
/*
......@@ -384,13 +405,10 @@ int XDevManager::GetCudaThread2D(const int devID, const int n, const int m, int
memset(gridSize, 0, sizeof(int) * 3);
memset(blockSize, 0, sizeof(int) * 3);
if(n <= 0 || m <= 0 || devID >= nGPU)
if(n <= 0 || m <= 0)
return 1;
if(devID < 0){
XPRINT(0, stderr, "WARNING! You are calling the grid and block size computation function on a CPU!");
return 0;
}
CheckNTErrors(devID >= 0 && devID < nGPU, "Invalid GPU device id!");
#ifdef USE_CUDA
......
......@@ -25,6 +25,7 @@
#define __XDEVICE_H__
#include "XThread.h"
#include "XStream.h"
#ifdef USE_CUDA
......@@ -92,6 +93,9 @@ public:
/* specify whether Unified Virtual Address Space (UVA) is supported */
bool isUVASupported;
/* default stream for the device */
XStream * stream;
#ifdef USE_CUDA
/* mutex for handle (GPU cublas) */
......@@ -121,6 +125,9 @@ public:
#ifdef USE_CUDA
/* get cublas handle */
cublasHandle_t * GetCublasHandle();
/* get the stream of cuda */
cudaStream_t * GetCudaStream();
#endif
/* switch to a device */
......@@ -178,6 +185,9 @@ public:
#ifdef USE_CUDA
/* get the handle of GPU */
cublasHandle_t * GetCudaHandle(const int devID);
/* get the stream of cuda */
cudaStream_t * GetCudaStream(const int devID);
#endif
/* get grid and block sizes that max potential */
......
......@@ -167,7 +167,9 @@ void XLink::SetType(int id)
type[0] = 0;
strcpy(type, GetOPName(id));
typeID = id;
CheckNTErrors(strcmp(type, "NULL"), "illegal edge type name!");
if(id != 0){
CheckNTErrors(strcmp(type, "NULL"), "illegal edge type name!");
}
}
/*
......@@ -515,7 +517,7 @@ void XLink::CopyIncoming(const XTensor * reference, XTensor * target)
tails.Add(tail);
}
MakeLink(&tails, target, reference->id);
MakeLink(&tails, target, reference->income.typeID);
int paraNum = reference->income.paramNum;
target->income.paramNum = paraNum;
......
......@@ -208,22 +208,16 @@ void XList::Insert(int pos, void * item)
/* get the item at position i */
void * XList::GetItem(int i) const
{
if( i >= 0 && i < count )
return items[i];
else
return NULL;
CheckNTErrors(i >= 0 && i < count, "Index of a list item is out of scope!");
return items[i];
}
/* get the integer-typed item at position i */
int XList::GetItemInt(int i)
{
CheckNTErrors(isIntList, "An int list is required!");
if( i >= 0 && i < count ){
return *(int*)(items[i]);
}
else
return 0;
CheckNTErrors(i >= 0 && i < count, "Index of a list item is out of scope!");
return *(int*)(items[i]);
}
/* set the item at position i */
......
......@@ -181,7 +181,10 @@ void XMem::Free(int myDevID, void * mem)
else{
#ifdef USE_CUDA
SetDevice(myDevID);
CheckNTErrors(cudaFree((char*)mem) == cudaSuccess, "Cannot free the memory.");
cudaError_t error = cudaFree((char*)mem);
if(error != cudaSuccess){
ShowNTErrors("Cannot free the memory.");
}
#else
ShowNTErrors("Please specify USE_CUDA for compiling this program.");
#endif
......
......@@ -29,6 +29,22 @@ const char * GetOPName(int type)
if ((type & MATH_BASE) != 0){
if (type == MATH_ABSOLUTE)
return "M_ABSOLUTE";
else if (type == MATH_EXP)
return "M_EXP";
else if (type == MATH_LOG)
return "M_LOG";
else if (type == MATH_SIN)
return "M_SIN";
else if (type == MATH_COS)
return "M_COS";
else if (type == MATH_TAN)
return "M_TAN";
else if (type == MATH_ROUND)
return "M_ROUND";
else if (type == MATH_CLIP)
return "M_CLIP";
else if (type == MATH_DIV)
return "M_DIV";
else if (type == MATH_MATRIXMUL)
return "M_MATRIXMUL";
else if (type == MATH_MATRIXMULBATCHED)
......@@ -37,18 +53,20 @@ const char * GetOPName(int type)
return "M_MULTIPLY";
else if (type == MATH_NEGATE)
return "M_NEGATE";
else if (type == MATH_SIGN)
return "M_SIGN";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_LOG)
return "M_LOG";
else if (type == MATH_NORMALIZE)
return "M_NORMALIZE";
else if (type == MATH_POWER)
return "M_POWER";
else if (type == MATH_SCALEANDSHIFT)
return "M_SCALEANDSHIFT";
else if (type == MATH_SIGN)
return "M_SIGN";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_SUB)
return "M_SUB";
else if (type == MATH_SUMDIM)
return "M_SUMDIM";
else if (type == REDUCE_REDUCEMAX)
return "R_REDUCEMAX";
else if (type == REDUCE_REDUCEMEAN)
......
......@@ -30,20 +30,30 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* math operations */
#define MATH_BASE 0x00001000
#define MATH_ABSOLUTE MATH_BASE + 1
#define MATH_MATRIXMUL MATH_ABSOLUTE + 1
#define MATH_EXP MATH_ABSOLUTE + 1
#define MATH_LOG MATH_EXP + 1
#define MATH_SIN MATH_LOG + 1
#define MATH_COS MATH_SIN + 1
#define MATH_TAN MATH_COS + 1
#define MATH_ROUND MATH_TAN + 1
#define MATH_CLIP MATH_ROUND + 1
#define MATH_DIV MATH_CLIP + 1
#define MATH_MATRIXMUL MATH_DIV + 1
#define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
#define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1
#define MATH_NEGATE MATH_MULTIPLY + 1
#define MATH_SIGN MATH_NEGATE + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_LOG MATH_SUM + 1
#define MATH_NORMALIZE MATH_LOG + 1
#define MATH_NORMALIZE MATH_NEGATE + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_SIGN MATH_SCALEANDSHIFT + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_SUB MATH_SUM + 1
#define MATH_SUMDIM MATH_SUB + 1
#define REDUCE MATH_SCALEANDSHIFT + 1
#define REDUCE MATH_SUMDIM + 1
#define REDUCE_REDUCEMAX REDUCE + 1
#define REDUCE_REDUCEMEAN REDUCE_REDUCEMAX + 1
#define REDUCE_REDUCESUM REDUCE_REDUCEMEAN + 1
......
......@@ -84,7 +84,7 @@ void XStream::Create(int priority, int myDevID)
XDevice::SetGPUDevice(myDevID);
//cudaStreamCreateWithPriority(&stream, cudaStreamDefault, priority);
CheckNTErrors((cudaStreamCreate(&stream) == cudaSuccess),
"cannot create the cuda stream!");
"cannot create the cuda stream!");
XDevice::SetGPUDevice(backupDevID);
#endif
devID = myDevID;
......
......@@ -42,6 +42,8 @@
#include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h"
#include "core/arithmetic/Multiply.h"
#include "core/arithmetic/Sub.h"
#include "core/arithmetic/Div.h"
#include "core/math/ScaleAndShift.h"
#ifdef USE_CUDA
......@@ -354,6 +356,18 @@ XTensor XTensor::operator* (const XTensor& tensor)
return Multiply(*this, tensor);
}
/* overloading of the minus-sign */
XTensor XTensor::operator- (const XTensor& tensor)
{
return Sub(*this, tensor);
}
/* overloading of the division-sign */
XTensor XTensor::operator/ (const XTensor& tensor)
{
return Div(*this, tensor);
}
/*
linear transformation b = a * \scale + \shift
>> scale - the slope
......@@ -426,8 +440,12 @@ get the size of a given dimension
int XTensor::GetDim(const int dim)
{
CheckNTErrors(dim < order, "dimenision is out of range!");
int d = dim;
if(dim < 0)
d = order - 1;
return dimSize[dim];
return dimSize[d];
}
/*
......@@ -454,6 +472,27 @@ void XTensor::Reshape(const int myOrder, const int * myDimSize)
memcpy(dimSizeRDI, dimsRDI, sizeof(int) * order);
}
/*
reshape the tensor to a vector
>> num - number of elements
*/
void XTensor::Reshape(const int num)
{
int dim = num;
Reshape(1, &dim);
}
/*
reshape the tensor to a matrix
>> rowNum - number of rows
>> colNum - number of columns
*/
void XTensor::Reshape(const int rowNum, const int colNum)
{
int dims[2] = {rowNum, colNum};
Reshape(2, dims);
}
/* get the number of items in the data array */
int XTensor::GetSize() const
{
......@@ -560,25 +599,24 @@ set the tensor items by a uniform distribution in range [lower, upper]
void XTensor::SetDataRand(DTYPE lower, DTYPE upper)
{
// TODO: cuda code!!!!!!!
// TODO: replace float with DTYPE
if (data == NULL)
return;
// srand((unsigned)time(0));
DTYPE variance = upper - lower;
void * d = NULL;
if (dataType == X_FLOAT) {
d = new float[unitNum];
for (int i = 0; i < unitNum; i++) {
DTYPE value = lower + (upper - lower) * (float)rand() / RAND_MAX;
DTYPE value = lower + variance * (float)rand() / RAND_MAX;
*((float*)d + i) = value;
}
}
else if (dataType == X_DOUBLE) {
d = new double[unitNum];
for (int i = 0; i < unitNum; i++) {
*((double*)d + i) = lower + (upper - lower) * rand() / RAND_MAX;
*((double*)d + i) = lower + variance * rand() / RAND_MAX;
}
}
else {
......@@ -588,15 +626,15 @@ void XTensor::SetDataRand(DTYPE lower, DTYPE upper)
SetData(d, unitNum);
if (dataType == X_FLOAT) {
delete[](float*)d;
delete[] (float*)d;
}
else {
delete[](double*)d;
delete[] (double*)d;
}
}
/* a gauss distribution */
double GaussRand()
/* a gauss distribution (Box-Muller method) */
double GaussRand(DTYPE mean, DTYPE standardDeviation)
{
// TODO: cuda code!!!!!!!
......@@ -606,8 +644,8 @@ double GaussRand()
double pi = 3.141592654;
if (phase == 0){
u = rand() / (RAND_MAX + 1.0);
v = rand() / (RAND_MAX + 1.0);
u = (rand() + 1.0) / (RAND_MAX + 1.0);
v = (rand() + 1.0) / (RAND_MAX + 1.0);
z = sqrt(-2.0 * log(u))* sin(2.0 * pi * v);
}
else{
......@@ -615,7 +653,7 @@ double GaussRand()
}
phase = 1 - phase;
return z;
return mean + (z * standardDeviation);
}
/*
......@@ -626,7 +664,6 @@ set the tensor items by a normal distribution
void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
{
// TODO: cuda code!!!!!!!
// TODO: replace float with DTYPE
if (data == NULL)
return;
......@@ -636,13 +673,13 @@ void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
if (dataType == X_FLOAT) {
d = new float[unitNum];
for (int i = 0; i < unitNum; i++) {
*((float*)d + i) = (float)GaussRand();
*((float*)d + i) = (float)GaussRand(mean, standardDeviation);
}
}
else if (dataType == X_DOUBLE) {
d = new double[unitNum];
for (int i = 0; i < unitNum; i++) {
*((double*)d + i) = GaussRand();
*((double*)d + i) = GaussRand(mean, standardDeviation);
}
}
else {
......@@ -652,10 +689,10 @@ void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
SetData(d, unitNum);
if (dataType == X_FLOAT) {
delete[](float*)d;
delete[] (float*)d;
}
else {
delete[](double*)d;
delete[] (double*)d;
}
}
......@@ -1003,11 +1040,11 @@ set the value of a cell in a 3d tensor in default type
*/
bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
{
CheckNTErrors((order == 3), "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((d0 >= 0 && d1 < dimSize[0]), "dimension 0 is out of range!");
CheckNTErrors((d2 >= 0 && d2 < dimSize[1]), "dimension 1 is out of range!");
CheckNTErrors((d2 >= 0 && d2 < dimSize[2]), "dimension 1 is out of range!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type.");
CheckNTErrors(order == 3, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 1 is out of range!");
CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
int dims[3] = {d0, d1, d1};
......@@ -1439,6 +1476,21 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int verbo
}
/*
dump data to a file
>> tensor - tensor whose data is dumped
>> file - where to domp the data
>> label - label of the tensor
>> n - number of items to dump
>> verbose - verbose level
*/
void XTensor::Dump(const XTensor * tensor, FILE * file, const char * label, const int n, const int verbose)
{
XTensor a(tensor->order, tensor->dimSize, tensor->dataType, tensor->denseRatio, tensor->devID, tensor->mem);
_CopyValues(tensor, &a);
a.Dump(file, label, n, verbose);
}
/*
read data from a file
>> file - where to load the data
>> label - label of the tensor
......@@ -1687,13 +1739,13 @@ void InitTensor(XTensor * tensor,
dims[0] = -abs(dims[0]);
tensor->Resize(myOrder, dims, myDataType, myDenseRatio);
if(myDevID == CURRENT_GPU)
if (myDevID == CURRENT_GPU)
tensor->devID = XDevice::GetGPUDevice();
else
tensor->devID = myDevID;
tensor->Resize(myOrder, dims, myDataType, myDenseRatio);
if(allocated)
XTensor::AllocateData(tensor);
}
......@@ -1870,28 +1922,47 @@ generate a XTensor which allocates data on the buffer
>> myDimSize - the size of each dimension
>> myMem - memory pool used to allocating the data array.
we actually allocate the data on the buffer associated with
the memory pool.
the memory pool
>> devID - device id
>> myDataType - unit size (e.g., int, float, and double)
>> myDenseRatio - how often an element has non-zero value
*/
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem,
const TENSOR_DATA_TYPE myDataType, const float myDenseRatio)
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize,
const TENSOR_DATA_TYPE myDataType, const float myDenseRatio,
const int devID, XMem * myMem)
{
CheckNTErrors(myMem != NULL, "No memory pool specified!");
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, myDimSize, sizeof(int) * myOrder);
dims[0] = -abs(dims[0]);
XTensor * tensor = NewTensor(myOrder, dims, myDataType, myDenseRatio, -1, myMem);
tensor->data = myMem->AllocBuf(myMem->devID, tensor->unitNum * tensor->unitSize);
XTensor * tensor = NewTensor(myOrder, dims, myDataType, myDenseRatio, devID, myMem);
if(myMem != NULL)
tensor->data = myMem->AllocBuf(myMem->devID, tensor->unitNum * tensor->unitSize);
else
tensor->data = XMemAlloc(devID, tensor->unitNum * tensor->unitSize);
return tensor;
}
/*
generate a XTensor which allocates data on the buffer
>> reference - reference tensor
>> devID - device id
>> myMem - memory pool used to allocating the data array.
we actually allocate the data on the buffer associated with
the memory pool
*/
XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem)
{
return NewTensorBuf(reference->order, reference->dimSize,
reference->dataType, reference->denseRatio,
devID, myMem);
}
/*
generate a dense vector
>> num - number of entries
>> myDataType - unit size (e.g., int, float, and double)
......@@ -2041,7 +2112,7 @@ XTensor * NewTensor(XTensor * a, bool isFilledData)
free the data space of a given tensor
>> tensor - pointer to the tensor
*/
void DelTensor(const XTensor * tensor)
void DelTensor(XTensor * tensor)
{
delete tensor;
}
......@@ -2050,10 +2121,13 @@ void DelTensor(const XTensor * tensor)
free the data space of a given tensor (on the buffer)
>> tensor - pointer to the tensor
*/
void DelTensorBuf(const XTensor * tensor)
void DelTensorBuf(XTensor * tensor)
{
CheckNTErrors(tensor->mem != NULL, "No memory pool found!");
tensor->mem->ReleaseBuf(tensor->devID, tensor->unitNum * tensor->unitSize);
if(tensor->mem != NULL)
tensor->mem->ReleaseBuf(tensor->devID, tensor->unitNum * tensor->unitSize);
else
XMemFree(tensor->devID, tensor->data);
tensor->data = NULL;
delete tensor;
}
......
......@@ -45,12 +45,13 @@ namespace nts{
struct XLink;
/* define the maximum number of dimensions in a tensor */
#define MAX_TENSOR_DIM_NUM 6
#define MAX_TENSOR_DIM_NUM 8
#define USE_BATCHED_STRIDED_MAT_MUL
#define MIN_TENSOR_SPLIT_NUM 10
#define MIN_TENSOR_SPLIT_NUM 0
#define MIN_TENSOR_SPLIT_LIST_NUM 1024
#define MIN_TENSOR_CAT_NUM 8
/* computation flags */
#define UNSAFE_BUT_FAST_MEM
#define FAST_MATRIX
......@@ -202,6 +203,12 @@ public:
/* overloading of the multiply-sign */
XTensor operator* (const XTensor &tensor);
/* overloading of the minus-sign */
XTensor operator- (const XTensor &tensor);
/* overloading of the division-sign */
XTensor operator/ (const XTensor &tensor);
/* linear transformation */
XTensor Lin(DTYPE scale, DTYPE shift = 0);
......@@ -222,6 +229,12 @@ public:
/* reshape the tensor */
void Reshape(const int order, const int * myDimSize);
/* reshape the tensor to a vector */
void Reshape(const int num);
/* reshape the tensor to a matrix */
void Reshape(const int rowNum, const int colNum);
/* get the number of items in the data array */
int GetSize() const;
......@@ -328,6 +341,10 @@ public:
/* dump data to a file */
void Dump(FILE * file, const char * label = NULL, const int n = -1, const int verbose = 0);
/* dump data to a file */
static
void Dump(const XTensor * tensor, FILE * file, const char * label = NULL, const int n = -1, const int verbose = 0);
/* read data from a file */
void Read(FILE * file, const char * label = NULL);
......@@ -386,8 +403,12 @@ XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_
const float myDenseRatio = 1.0F, const int myDevID = -1, XMem * myMem = NULL);
/* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const float myDenseRatio = 1.0F);
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const float myDenseRatio = 1.0F,
const int myDevID = -1, XMem * myMem = NULL);
/* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem);
/* generate a dense vector */
XTensor * NewTensor1D(const int num, const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1,
......@@ -417,10 +438,10 @@ XTensor * NewTensor5D(const int d0, const int d1, const int d2, const int d3, co
XTensor * NewTensor(XTensor * a, bool isFilledData = true);
/* free the data space of a given tensor */
void DelTensor(const XTensor * tensor);
void DelTensor(XTensor * tensor);
/* free the data space of a given tensor (on the buffer) */
void DelTensorBuf(const XTensor * tensor);
void DelTensorBuf(XTensor * tensor);
} /* end of the nts (NiuTrans.Tensor) namespace */
......
......@@ -175,29 +175,38 @@ void XMemCopy(void * t, int devIDT, const void * s, int devIDS, size_t size)
return;
}
#ifdef USE_CUDA
else if(devIDT >= 0 && devIDS < 0){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyHostToDevice);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyHostToDevice)");
}
}
else if(devIDT < 0 && devIDS >= 0){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyDeviceToHost);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToHost)");
}
}
else{
//if(devIDT == devIDS){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyDeviceToDevice);
int devID = devIDT < 0 ? devIDS : devIDT;
int devIDBackup = 0;
cudaGetDevice(&devIDBackup);
cudaSetDevice(devID);
if(devIDT >= 0 && devIDS < 0){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyHostToDevice);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
ShowNTErrors("cudaMemcpy error (cudaMemcpyHostToDevice)");
}
/*}
}
else if(devIDT < 0 && devIDS >= 0){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyDeviceToHost);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToHost)");
}
}
else{
CheckNTErrors((cudaMemcpyPeer(t, devIDT, s, devIDS, size) == cudaSuccess),
"cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}*/
//if(devIDT == devIDS){
cudaError_t error = cudaMemcpy(t, s, size, cudaMemcpyDeviceToDevice);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}
/*}
else{
CheckNTErrors((cudaMemcpyPeer(t, devIDT, s, devIDS, size) == cudaSuccess),
"cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}*/
}
cudaSetDevice(devIDBackup);
}
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
......@@ -208,6 +217,9 @@ void XMemCopy(void * t, int devIDT, const void * s, int devIDS, size_t size)
#ifdef USE_CUDA
void XMemCopyAsync(void * t, int devIDT, const void * s, int devIDS, size_t size, cudaStream_t stream, int streamDevID)
{
if(t == s)
return;
int devIDBackup = -1;
if(streamDevID >= 0 && (devIDT >= 0 || devIDS >= 0)){
CheckNTErrors((cudaGetDevice(&devIDBackup) == cudaSuccess), "Cannot get GPU device id!");
......@@ -220,17 +232,23 @@ void XMemCopyAsync(void * t, int devIDT, const void * s, int devIDS, size_t size
return;
}
else if(devIDT >= 0 && devIDS < 0){
CheckNTErrors((cudaMemcpyAsync(t, s, size, cudaMemcpyHostToDevice, stream) == cudaSuccess),
"cudaMemcpyAsync error (cudaMemcpyHostToDevice)");
cudaError_t error = cudaMemcpyAsync(t, s, size, cudaMemcpyHostToDevice, stream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpyAsync error (cudaMemcpyHostToDevice)");
}
}
else if(devIDT < 0 && devIDS >= 0){
CheckNTErrors((cudaMemcpyAsync(t, s, size, cudaMemcpyDeviceToHost, stream) == cudaSuccess),
"cudaMemcpyAsync error (cudaMemcpyDeviceToHost)");
cudaError_t error = cudaMemcpyAsync(t, s, size, cudaMemcpyDeviceToHost, stream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpyAsync error (cudaMemcpyDeviceToHost)");
}
}
else{
//if(devIDT == devIDS){
CheckNTErrors((cudaMemcpyAsync(t, s, size, cudaMemcpyDeviceToDevice, stream) == cudaSuccess),
"cudaMemcpyAsync error (cudaMemcpyDeviceToDevice)");
cudaError_t error = cudaMemcpyAsync(t, s, size, cudaMemcpyDeviceToDevice, stream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpyAsync error (cudaMemcpyDeviceToDevice)");
}
//}
/*else{
CheckNTErrors((cudaMemcpyPeerAsync(t, devIDT, s, devIDS, size, stream) == cudaSuccess),
......@@ -261,18 +279,69 @@ void XMemCopy2D(void * t, size_t tPitch, int devIDT, const void * s, size_t sPit
return;
}
#ifdef USE_CUDA
else if (devIDT >= 0 && devIDS < 0) {
CheckNTErrors((cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyHostToDevice) == cudaSuccess),
"cudaMemcpy2D error (cudaMemcpyHostToDevice)");
else{
int devID = devIDT < 0 ? devIDS : devIDT;
int devIDBackup = 0;
cudaGetDevice(&devIDBackup);
cudaSetDevice(devID);
if (devIDT >= 0 && devIDS < 0) {
cudaError_t error = cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyHostToDevice);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy2D error (cudaMemcpyHostToDevice)");
}
}
else if (devIDT < 0 && devIDS >= 0) {
cudaError_t error = cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToHost);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToHost)");
}
}
else {
cudaError_t error = cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToDevice);
if (error != cudaSuccess) {
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}
}
cudaSetDevice(devIDBackup);
}
else if (devIDT < 0 && devIDS >= 0) {
CheckNTErrors((cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToHost) == cudaSuccess),
"cudaMemcpy error (cudaMemcpyDeviceToHost)");
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
}
void XMemCopy2DAsync(void * t, size_t tPitch, int devIDT, const void * s, size_t sPitch, int devIDS, size_t mSize, int n, XStream * stream)
{
if (t == s)
return;
if (devIDT < 0 && devIDS < 0) {
for(int i = 0; i < n; i++)
memcpy((char*)t + tPitch * i, (char*)s + sPitch * i, mSize);
return;
}
else {
cudaError_t error = cudaMemcpy2D(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToDevice);
if (error != cudaSuccess) {
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
#ifdef USE_CUDA
else{
CheckNTErrors(stream != NULL, "No stream found!");
cudaStream_t &cstream = stream->stream;
if (devIDT >= 0 && devIDS < 0) {
cudaError_t error = cudaMemcpy2DAsync(t, tPitch, s, sPitch, mSize, n, cudaMemcpyHostToDevice, cstream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy2D error (cudaMemcpyHostToDevice)");
}
}
else if (devIDT < 0 && devIDS >= 0) {
cudaError_t error = cudaMemcpy2DAsync(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToHost, cstream);
if(error != cudaSuccess){
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToHost)");
}
}
else {
cudaError_t error = cudaMemcpy2DAsync(t, tPitch, s, sPitch, mSize, n, cudaMemcpyDeviceToDevice, cstream);
if (error != cudaSuccess) {
ShowNTErrors("cudaMemcpy error (cudaMemcpyDeviceToDevice)");
}
}
}
#else
......
......@@ -23,6 +23,7 @@
#include <stdio.h>
#include "XGlobal.h"
#include "XDevice.h"
#ifndef __XUTILITY_H__
#define __XUTILITY_H__
......@@ -41,6 +42,7 @@ extern void XMemSet(void * p, int value, size_t size);
extern void XMemSet(int devID, void * p, int value, size_t size);
extern void XMemCopy(void * t, int devIDT, const void * s, int devIDS, size_t size);
extern void XMemCopy2D(void * t, size_t tPitch, int devIDT, const void * s, size_t sPitch, int devIDS, size_t mSize, int n);
extern void XMemCopy2DAsync(void * t, size_t tPitch, int devIDT, const void * s, size_t sPitch, int devIDS, size_t mSize, int n, XStream * stream);
extern void * XMemAlloc(int devID, size_t size);
extern void * XMemAllocOnDev(int devID, size_t size);
extern void XMemFree(int devID, void * p);
......
......@@ -26,49 +26,63 @@
#include "../XTensor.h"
#include "shape/Concatenate.h"
#include "shape/ConcatenateSolely.h"
#include "movement/CopyBlocks.h"
#include "movement/CopyBlocksInGrid.h"
#include "movement/CopyBlocksOnSite.h"
#include "movement/CopyData2D.h"
#include "movement/CopyIndexed.h"
#include "movement/CopyInGrid.h"
#include "movement/CopyValues.h"
#include "utilities/FlushToMem.h"
#include "shape/MakeMergeBlockIndex.h"
#include "shape/MakeSplitBlockIndex.h"
#include "arithmetic/Div.h"
#include "arithmetic/MatrixMul.h"
#include "arithmetic/MatrixMul2D.h"
#include "arithmetic/MatrixMul2DMultiTheading.h"
#include "arithmetic/MatrixMul2DParallel.h"
#include "arithmetic/MatrixMulBatched.h"
#include "arithmetic/MatrixMULBatchedCPU.h"
#include "shape/Merge.h"
#include "shape/MergeBlockLists.h"
#include "arithmetic/Multiply.h"
#include "arithmetic/Negate.h"
#include "arithmetic/Sign.h"
#include "arithmetic/Sub.h"
#include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumDim.h"
#include "arithmetic/XTensorBLAS.h"
#include "getandset/ConvertDataType.h"
#include "getandset/Select.h"
#include "getandset/SetData.h"
#include "math/Clip.h"
#include "math/Normalize.h"
#include "shape/Permute.h"
#include "math/Power.h"
#include "math/ScaleAndShift.h"
#include "math/Unary.h"
#include "movement/CopyBlocks.h"
#include "movement/CopyBlocksInGrid.h"
#include "movement/CopyBlocksOnSite.h"
#include "movement/CopyData2D.h"
#include "movement/CopyIndexed.h"
#include "movement/CopyInGrid.h"
#include "movement/CopyValues.h"
#include "reduce/ReduceMax.h"
#include "reduce/ReduceMean.h"
#include "reduce/ReduceStandardVariance.h"
#include "reduce/ReduceSum.h"
#include "reduce/ReduceSumSquared.h"
#include "reduce/ReduceVariance.h"
#include "math/ScaleAndShift.h"
#include "getandset/Select.h"
#include "getandset/SetData.h"
#include "sort/Sort.h"
#include "shape/Concatenate.h"
#include "shape/ConcatenateSolely.h"
#include "shape/MakeMergeBlockIndex.h"
#include "shape/MakeSplitBlockIndex.h"
#include "shape/Merge.h"
#include "shape/MergeBlockLists.h"
#include "shape/Permute.h"
#include "shape/Split.h"
#include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h"
#include "sort/TopK.h"
#include "shape/Transpose.h"
#include "shape/Unsqueeze.h"
#include "sort/Sort.h"
#include "sort/TopK.h"
#include "utilities/XMatrixSegment.h"
#include "arithmetic/XTensorBLAS.h"
#include "utilities/FlushToMem.h"
#endif // __CHEADER_H__
\ No newline at end of file
#endif // __CHEADER_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "Div.h"
#include "Div.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
element-wise division of two tensors
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
int leadingDimRDI = a->order - leadingDim - 1;
CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
"Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order),
"Unmatched tensors!");
#ifdef USE_CUDA
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
_CudaDiv(a, b, c, alpha, leadingDim);
return;
}
#endif
int stride = 1;
int blockSizeA = 1;
int blockSizeB = 1;
int blockSizeC = 1;
int blockNum = 1;
int dimensionSizeA = a->dimSizeRDI[leadingDimRDI];
int dimensionSizeB = b->dimSizeRDI[leadingDimRDI];
int dimensionSizeC = c->dimSizeRDI[leadingDimRDI];
for (int i = 0; i < a->order; i++) {
if (i != leadingDimRDI) {
CheckNTErrors((a->dimSizeRDI[i] == b->dimSizeRDI[i] && a->dimSizeRDI[i] == c->dimSizeRDI[i]),
"Unmatched tensors!");
}
if (i < leadingDimRDI)
stride *= a->dimSizeRDI[i];
}
blockSizeA = stride * dimensionSizeA;
blockSizeB = stride * dimensionSizeB;
blockSizeC = stride * dimensionSizeC;
blockNum = a->unitNum / blockSizeA;
if (!a->isSparse && !b->isSparse) {
if (a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE) {
if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
int size = a->unitNum;
DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data;
if (alpha == 0) {
for (int i = 0; i < size; i++)
cp[i] = ap[i] / bp[i];
}
else {
for (int i = 0; i < size; i++)
cp[i] = ap[i] / bp[i] + alpha * cp[i];
}
}
else {
for (int k = 0; k < blockNum; k++) {
for (int ci = 0, ai = 0, bi = 0; ci < dimensionSizeC; ci++, ai++, bi++) {
if (ai >= dimensionSizeA)
ai = 0;
if (bi >= dimensionSizeB)
bi = 0;
DTYPE * ap = (DTYPE*)a->data + k * blockSizeA + ai * stride;
DTYPE * bp = (DTYPE*)b->data + k * blockSizeB + bi * stride;
DTYPE * cp = (DTYPE*)c->data + k * blockSizeC + ci * stride;
for (int j = 0; j < stride; j++)
cp[j] = ap[j] / bp[j] + cp[j] * alpha;
}
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
/*
element-wise division of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)*b(i) + \alpha * a(i)
where i is the index of the item
>> a - tensor a (where keep the result)
>> b - tensor b
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
{
_Div(a, b, a, alpha, leadingDim);
}
/*
element-wise division of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it
c(i) = a(i)*b(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim)
{
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a);
c.SetTMP();
/* call _Multiply function */
_Div(&a, &b, &c, 0, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHeadInt(&c, leadingDim);
return c;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Div.h"
#include "Div.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
division of data arrays in a element-wise manner c(i) = a(i)/b(i)
>> a - data array a
>> b - data array b
>> c - result data array
>> size - size of c
*/
__global__
void KernelDivElementWise(DTYPE * a, DTYPE * b, DTYPE * c, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] / b[i];
}
/*
division of data arrays in a element-wise manner c(i) = a(i)/b(i) + \alpha*c(i)
>> a - data array a
>> b - data array b
>> c - result data array
>> size - size of c
>> alpha - the coefficient
*/
__global__
void KernelDivElementWiseV2(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE alpha)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] / b[i] + alpha * c[i];
}
/*
division of two tensors in a element-wise manner c(i) = a(i)/b(i).
Note that a and b can be of different sizes here, i.e.,
|a_lead| <= |c_lead| and |b_lead| <= |c_lead|
where |a_lead| means the size of the leading dimension of a
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> stride - the number of items we go over when move next along the leading dimension in a block
>> ldSizeA - size of the leading dimension of a
>> ldSizeB - size of the leading dimension of b
>> ldSizeC - size of the leading dimension of c
>> blockNum - number of blocks
*/
template<int nonZeroAlpha> __global__
void KernelDivElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE alpha,
int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum)
{
__shared__ DTYPE* ap[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE* bp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE* cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int i = blockDim.x * blockIdx.x + threadIdx.x;
int j = blockDim.y * blockIdx.y + threadIdx.y;
if (i >= blockNum * stride || j >= ldSizeC)
return;
if (threadIdx.y == 0) {
int block = i / stride;
int size = block * stride;
ap[threadIdx.x] = a + size * ldSizeA;
bp[threadIdx.x] = b + size * ldSizeB;
cp[threadIdx.x] = c + size * ldSizeC;
}
__syncthreads();
int aj = j >= ldSizeA ? j % ldSizeA : j;
int bj = j >= ldSizeB ? j % ldSizeB : j;
int offseti = i % stride;
if (nonZeroAlpha == 0)
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] / bp[threadIdx.x][bj * ldSizeB + offseti];
else
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] / bp[threadIdx.x][bj * ldSizeB + offseti]
+ alpha * cp[threadIdx.x][j * ldSizeC + offseti];
}
/*
element-wise division of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the item index
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - dimension along which we perform broadcasting
*/
void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
int leadingDimRDI = a->order - leadingDim - 1;
CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
"Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!");
int stride = 1;
int blockSizeA = 1;
int blockNum = 1;
int dimensionSizeA = a->dimSizeRDI[leadingDimRDI];
int dimensionSizeB = b->dimSizeRDI[leadingDimRDI];
int dimensionSizeC = c->dimSizeRDI[leadingDimRDI];
for (int i = 0; i < a->order; i++) {
if (i != leadingDimRDI) {
CheckNTErrors((a->dimSizeRDI[i] == b->dimSizeRDI[i] &&
a->dimSizeRDI[i] == c->dimSizeRDI[i]),
"Unmatched tensors!");
}
if (i < leadingDimRDI)
stride *= a->dimSizeRDI[i];
}
blockSizeA = stride * dimensionSizeA;
blockNum = a->unitNum / blockSizeA;
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (!a->isSparse && !b->isSparse) {
if (a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE) {
int cudaGridSize[3];
int cudaBlockSize[3];
if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
GDevs.GetCudaThread(a->devID, c->unitNum, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[0]), threads(cudaBlockSize[0]);
if (alpha == 0)
KernelDivElementWise << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum);
else
KernelDivElementWiseV2 << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum, alpha);
}
else {
GDevs.GetCudaThread2D(c->devID, stride * blockNum, dimensionSizeC, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[0], cudaGridSize[1]), threads(cudaBlockSize[0], cudaBlockSize[1]);
if (alpha == 0) {
KernelDivElementWiseTensorDynamic<0> << <blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, 0,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
else {
KernelDivElementWiseTensorDynamic<1> << <blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, alpha,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __DIV_CUH__
#define __DIV_CUH__
#include "Div.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* division of two tensors in a element-wise manner c(i) = a(i)/b(i) */
__global__
void KernelDivElementWise(DTYPE * a, DTYPE * b, DTYPE * c, int size);
/* division of two tensors in a element-wise manner c(i) = a(i)/b(i) + \alpha*c(i) */
__global__
void KernelDivElementWiseV2(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE alpha);
/* division of two tensors in a element-wise manner c(i) = a(i)/b(i)+ \alpha*c(i) */
template<int nonZeroAlpha>__global__
void KernelDivElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE alpha, int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum);
/* element-wise division of two tensors */
void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __DIV_CUH__
......@@ -16,31 +16,39 @@
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __LOG_H__
#define __LOG_H__
#ifndef __DIV_H__
#define __DIV_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its log value */
void _Log(const XTensor * a, XTensor * b);
/*
element-wise division of two tensors:
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the element
*/
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
/*
set every entry to its log value (do it on site)
element-wise division of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)/b(i) + \alpha * a(i)
where i is the index of the element
*/
void _LogMe(XTensor * a);
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha = 0, int leadingDim = 0);
/*
set every entry to its log value (return a XTensor structure)
element-wise division of two tensors (return a XTensor structure)
make a new tensor to keep the result and return it
c(i) = a(i)/b(i)
where i is the index of the element
*/
XTensor Log(const XTensor & a);
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor)
#endif // __LOG_H__
#endif // __DIV_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XTensor.h"
#include "MatrixMULBatchedCPU.h"
#include "MatrixMul2D.h"
#include "XTensorBLAS.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
matrix multiplication in batch mode (BLAS)
c_i = trans(a_i) * trans(b_i) * \alpha + c_i * \beta for each i in [0,count-1]
>> a - list of input matrices (2d tensors)
>> transposedA - indicate whether the matrix a is transposed
>> b - another list of input matrices (2d tensors)
>> transposedB - indicate whether the matrix b is transposed
>> c - output matrix (2d tensor)
>> alpha - scalar
>> beta - scalar
*/
void _MatrixMULBatchedCPU(const XList * a, MATRIX_TRANS_TYPE transposedA,
const XList * b, MATRIX_TRANS_TYPE transposedB,
XList * c, DTYPE alpha, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty input lists!");
CheckNTErrors(a->count == b->count && a->count == c->count, "Input lists must be of the same size!");
if (a->count == 0)
return;
bool isUniform = true;
for (int i = 1; i < a->count; i++) {
XTensor * aim = (XTensor*)a->GetItem(i - 1);
XTensor * bim = (XTensor*)b->GetItem(i - 1);
XTensor * cim = (XTensor*)c->GetItem(i - 1);
XTensor * ai = (XTensor*)a->GetItem(i);
XTensor * bi = (XTensor*)b->GetItem(i);
XTensor * ci = (XTensor*)c->GetItem(i);
if (!XTensor::IsSameShaped(aim, ai) ||
!XTensor::IsSameShaped(bim, bi) ||
!XTensor::IsSameShaped(cim, ci))
{
isUniform = false;
break;
}
}
for (int i = 0; i < a->count; i++) {
XTensor * ai = (XTensor*)a->GetItem(i);
XTensor * bi = (XTensor*)b->GetItem(i);
XTensor * ci = (XTensor*)c->GetItem(i);
CheckNTErrors((ai->order == 2), "2d tensor (i.e., matrix) is required!");
CheckNTErrors((bi->order == 2), "2d tensor (i.e., matrix) is required!");
CheckNTErrors((ci->order == 2), "2d tensor (i.e., matrix) is required!");
#ifdef USE_BLAS
if (useBLAS)
_MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta);
else
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#else
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#endif
}
//}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -24,8 +24,8 @@
#include "../../XName.h"
#include "MatrixMul.h"
#include "MatrixMul2D.h"
#include "MatrixMULBatchedCPU.h"
#include "XTensorBLAS.h"
#include "MatrixMulBatched.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -53,11 +53,29 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
CheckNTErrors(a && b && c, "Empty input tensors!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Input tensors should have the same data type!");
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
"Input tensors must have a order >= 2!");
CheckNTErrors(c->order == a->order + b->order - 2, "wrong tensor order")
/* we transform a higher order tensor to a matrix to kill the number
of calls of matrix multiplication */
if(transposedA == X_NOTRANS && a->order > 2 && b->order == 2){
int ncolA = a->dimSize[a->order - 1];
int ncolC = c->dimSize[c->order - 1];
XTensor * a2 = NewTensor2D(a->unitNum/ncolA, -ncolA, a->dataType, a->devID, a->mem);
XTensor * c2 = NewTensor2D(c->unitNum/ncolC, -ncolC, c->dataType, c->devID, c->mem);
a2->data = a->data;
c2->data = c->data;
_MatrixMul2D(a2, transposedA, b, transposedB, c2, alpha, beta, parallelRunner);
a2->data = NULL;
c2->data = NULL;
delete a2;
delete c2;
return;
}
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
......@@ -144,10 +162,10 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID);
_CudaBLASMatrixMULList(handle,
aList, transposedA,
bList, transposedB,
cList, aList->count,
alpha, beta);
aList, transposedA,
bList, transposedB,
cList, aList->count,
alpha, beta);
BacktoCudaDev(a->devID, devIDBackup);
#else
......@@ -156,9 +174,9 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
}
else {
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
_MatrixMULBatchedCPU(aList, transposedA,
bList, transposedB,
cList, alpha, beta);
_MatrixMulBatchedCPU(aList, transposedA,
bList, transposedB,
cList, alpha, beta);
}
for (int i = 0; i < aList->count; i++) {
......@@ -251,9 +269,7 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
/*
matrix multiplication with no transposition c = a * b * alpha
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
>> transposedB - indicates whether teh matrices in b are transposed
>> alpha - a coefficient
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
......
......@@ -26,6 +26,8 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#define BMMul MatrixMulBatched
/*
matrix multiplication of the two tensors c = trans(a) * trans(b) * alpha + c * beta
......@@ -37,6 +39,28 @@ where trans() returns the transposed matrix if the flag is fired
void _MatrixMulBatched(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL);
/*
matrix multiplication of the two tensors c = trans(a) * trans(b) * alpha + c * beta
optimized for GPU
*/
void _MatrixMulBatchedGPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0);
/*
matrix multiplication of the two tensors c = trans(a) * trans(b) * alpha + c * beta
optimized for GPU
*/
void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0);
/*
matrix multiplication of the two tensors c = trans(a) * trans(b) * alpha + c * beta (for list inputs)
optimized for GPU
*/
void _MatrixMulBatchedCPU(const XList * a, MATRIX_TRANS_TYPE transposedA, const XList * b, MATRIX_TRANS_TYPE transposedB,
XList * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0);
/*
matrix multiplication of the two tensors (return a XTensor structure) c = trans(a) * trans(b) * alpha
make a new tensor to keep the result and return it
......@@ -49,6 +73,17 @@ where trans() returns the transposed matrix if the flag is fired
XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
/*
matrix multiplication of the two tensors (return a XTensor structure) c = a * b * alpha
make a new tensor to keep the result and return it
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
ci = ai * bi * alpha + cm * beta
*/
XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
#endif // __MATRIXMULBATCHED_H__
\ No newline at end of file
......@@ -32,9 +32,9 @@ element-wise product of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the item
>> a - matrix a
>> b - matrix b
>> c - result matrix
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
......
......@@ -104,9 +104,9 @@ void KernelMulElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE al
int offseti = i % stride;
if (nonZeroAlpha == 0)
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj* ldSizeA + offseti] * bp[threadIdx.x][bj* ldSizeB + offseti];
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] * bp[threadIdx.x][bj * ldSizeB + offseti];
else
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj* ldSizeA + offseti] * bp[threadIdx.x][bj* ldSizeB + offseti] +
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] * bp[threadIdx.x][bj * ldSizeB + offseti] +
alpha * cp[threadIdx.x][j * ldSizeC + offseti];
}
......
......@@ -76,7 +76,7 @@ XTensor Sign(const XTensor & a)
XTensor b(&a);
b.SetTMP();
/* call _ScaleAndShift function */
/* call _Sign function */
_Sign(&a, &b);
/* tensor connections */
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "Sub.h"
#include "Sub.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor subtraction c = a - b * \beta
>> a - a tensor
>> b - another tensor
>> c - where we put a-b*\beta. we save it in a if c is NULL
>> beta - the scaling factor
*/
void _Sub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == b->unitNum && a->unitNum == c->unitNum,
"Unmatched tensors in addition!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched tensors in addition!");
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
if (a == c) {
int P2PAccesible = 0;
#ifdef CUDA_UVA
cudaDeviceCanAccessPeer(&P2PAccesible, a->devID, b->devID);
#endif
if ((a->devID < 0 && b->devID >= 0) ||
(a->devID >= 0 && b->devID < 0) ||
(a->devID >= 0 && b->devID >= 0 && a->devID != b->devID && !P2PAccesible))
{
ShowNTErrors("Cannot run this method on multiple devices simultaneously!");
}
else
_CudaSub(a, b, c, beta);
}
else
_CudaSub(a, b, c, beta);
#endif
}
else {
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse tensor in addition!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE)
{
DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data;
/* unrolling */
int num = a->unitNum;
if (num % 4 == 0) {
for (int i = 0; i < num; i += 4) {
cp[i] = ap[i] - bp[i] * beta;
cp[i + 1] = ap[i + 1] - bp[i + 1] * beta;
cp[i + 2] = ap[i + 2] - bp[i + 2] * beta;
cp[i + 3] = ap[i + 3] - bp[i + 3] * beta;
}
}
else if (num % 2 == 0) {
for (int i = 0; i < num; i += 2) {
cp[i] = ap[i] - bp[i] * beta;
cp[i + 1] = ap[i + 1] - bp[i + 1] * beta;
}
}
else {
for (int i = 0; i < num; i++) {
cp[i] = ap[i] - bp[i] * beta;
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
}
/*
tensor subtraction a = a - b * \beta (do it on site)
keep the result in the tensor a and return nothing
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
*/
void _SubMe(XTensor * a, const XTensor * b, DTYPE beta)
{
_Sub(a, b, a, beta);
}
/*
tensor subtraction c = a - b * \beta (return a XTensor structure)
make a new tensor c to keep the result and return it
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
<< return - the result of tensor subtraction
*/
XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
/* call _Sub function */
_Sub(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
return c;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "Sub.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
subtraction of data arrays (CUDA Kernel)
c = a - b * \beta
>> a - A matrix
>> b - another matrix
>> c - where we put a-b
>> size - the size of a/b/c
>> beta - the coefficient
*/
__global__
void KernelSUB(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] - b[i] * beta;
}
/*
tensor subtraction c = a - b * \beta (cuda version)
>> a - a tensor
>> b - another tensor
>> c - where we put a-b*\beta.
>> beta - the scaling factor
*/
void _CudaSub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors((a->unitNum == b->unitNum && a->unitNum == c->unitNum),
"Unmatched tensors in addition!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
"Unmatched tensors in addition!");
CheckNTErrors((a->devID == b->devID && a->devID == c->devID),
"The tensors must be on the same!");
int devIDBackup = XDevice::GetGPUDevice();
XDevice::SetGPUDevice(a->devID);
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse matrix in addition!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE)
{
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
KernelSUB << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, a->unitNum, beta);
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
XDevice::SetGPUDevice(devIDBackup);
}
/* subtraction over arrays
tensor subtraction c = a - b * \beta (cuda version) with an input handle
>> devID - device ID (MUST >= 0)
>> handle - cuda handle
>> a - an array
>> b - another array
>> c - where we put a-b
>> size - size of the array
>> beta - the coefficient
*/
void _CudaSubWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
{
if (size == 0)
return;
if (c == NULL)
c = a;
CheckNTErrors((a && b && c), "Empty arrays in addition!");
int devIDBackup;
ProtectCudaDev(devID, devIDBackup);
if (c == a) {
#ifdef DOUBELPRICSION
cublasDaxpy(*handle, size, &beta, b, 1, a, 1);
#else
cublasSaxpy(*handle, size, &beta, b, 1, a, 1);
#endif
}
else {
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(devID, size, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
KernelSUB<<<blocks, threads>>>((DTYPE*)a, (DTYPE*)b, (DTYPE*)c, size, beta);
}
BacktoCudaDev(devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论