Commit d294ac15 by xuchen

NiuTrans.Tensor - version 0.1.0!

parent d664c0a0
NiuTrans.Tensor张量计算库
\ No newline at end of file
# NiuTrans.Tensor张量计算库
## NiuTrans.Tensor
NiuTrans.Tensor是小牛开源项目所开发的一个工具包,提供了完整的张量定义及计算功能,可以被用于深度学习相关研究及工业系统的开发。NiuTrans.Tensor具有以下特点:
* 简单小巧,易于修改
* c语言编写,代码高度优化
* 同时支持CPU和GPU设备
* 丰富的张量计算接口
* 支持C/C++、Python等调用方式
## 安装方法
在开始创建您的项目并使用NiuTrans.Tensor工具包时,需要注意的是:
* 所创建项目如在CPU上运行,我们的系统支持高性能的数学运算库,推荐安装[MKL](https://software.intel.com/en-us/mkl)[OpenBLAS](http://www.openblas.net/)
* 所创建项目如需在GPU上运行,需安装 [CUDA](https://developer.nvidia.com/cuda-downloads),CUDA版本需求为9.0及以上,CUDA工具为创建高性能GPU加速应用程序提供了开发环境。
小牛开源项目所开发的NiuTrans.Tensor工具包采用源程序编译方法,在Windows和Linux环境下的安装方法如下所示。
### Windows
若在Windows上使用NiuTrans.Tensor工具包:
* 首先需要将NiuTrans.Tensor代码包含在所创建的项目中
* 在所创建项目中需要引用XTensor.h、core里的CHeader.h和function里的FHeader.h这三个头文件:
* 通过XTensor.h可以获取我们需要操作的XTensor类
* 通过core里的CHeader.h可以对Tensor进行一些张量运算
* 通过function里的FHeader.h可以调用一些激活函数
* 在所创建项目中使用命名空间nts
此外,一些必须的环境配置方法请参考 [NiuTrans.Tensor环境配置](http://47.105.50.196/NiuTrans/NiuTrans.Tensor/blob/linye/doc/Configuration.md)
### Linux
若在Linux上使用NiuTrans.Tensor工具包,直接执行make.sh即可在同级目录下生成tensorCPU和tensorGPU,分别对应于NiuTrans.Tensor的CPU以及GPU的可执行文件。以前馈神经网络语言模型为例,输入以下命令即可在GPU上执行提供的测试用例:
>./tensorGPU -test
更多详细使用方法请见[NiuTrans.Tensor开发文档](http://47.104.97.237/niutrans/site/niutensor/index.html)
## 开发团队
NiuTrans.Tensor张量计算库由东北大学自然语言处理实验室、小牛翻译、小牛雅智合作开发,致力于为深度学习相关研究及工业系统的开发提供完整的张量定义及计算功能。
## 更新版本
NiuTrans.Tensor version 0.1.0 - 2018年8月3日
\ No newline at end of file
# NiuTrans.Tensor环境配置
## 注意事项
CUDA最新版本9.2尚且不支持VS2017最新版本,因此建议使用CUDA版本为9.0或9.1,建议使用VS版本为VS2015,或使用VS2017时安装v140工具集。
## CUDA配置
在已安装好VS、CUDA并配置好环境变量后,一些关键的CUDA配置选项如下所示,以下配置选项在 **项目 -> 属性** 中可以找到。
>$(CUDA_PATH)\include
加入到 **VC++目录 -> 包含** 中。
>$(CUDA_PATH)\lib\Win32
加入到 **VC++目录 -> 库** 中。
>cuda.lib;cudadevrt.lib;cudart.lib;cudart_static.lib;nvcuvid.lib;OpenCL.lib;cublas.lib;curand.lib;
加入到 **链接器->输入->附加依赖项** 中。
配置完成后,右键 **工程->项目依赖性** ,选择CUDA9。
在.cu文件上右键属性,在项类型中选择"CUDA C/C++"(最好搜索.cu文件,然后全选设置)。
## 其他配置
**C/C++->常规->SDL检查**,设为否。
**C/C++->预处理器->预处理器定义** 中,添加
>USE_CUDA;USE_BLAS;WIN32;MKL;DEBUG;CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS_
CONSOLE;
**链接器->系统->子系统**,设置为控制台。
**常规->字符集**,使用Unicode字符集。
**调试->命令参数**中设置可执行文件所需要的参数。
......@@ -25,6 +25,7 @@
#include "../tensor/function/FHeader.h"
#include "../tensor/core/CHeader.h"
#include "../sample/fnnlm/FNNLM.h"
#include "../sample/transformer/Transformer.h"
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
......@@ -35,19 +36,16 @@ void SumDimTest();
using namespace nts;
using namespace fnnlm;
using namespace transformer;
int main( int argc, const char ** argv )
{
//TransposeTest();
//return 0;
//_CrtSetBreakAlloc(896);
//SumDimTest();
//return 0;
if(argc > 1 && !strcmp(argv[1], "-test"))
1;//Test();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1);
else if(argc > 1 && !strcmp(argv[1], "-t2t"))
TransformerMain(argc - 1, argv + 1);
else{
fprintf(stderr, "Thanks for using NiuTrans.Network! This is a library for building\n");
fprintf(stderr, "neural networks in an easy way. \n\n");
......@@ -55,37 +53,6 @@ int main( int argc, const char ** argv )
fprintf(stderr, "Or run this program with \"-fnnlm\" for sample FNNLM!\n");
}
return 0;
XNet net;
XTensor a;
XTensor b;
XTensor c;
InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 4);
InitTensor2D(&c, 2, 4);
a.SetZeroAll();
b.SetZeroAll();
c.SetZeroAll();
SetDataFixed(a, 0.1F);
a.Set2D(0.3F, 1, 0);
a.Set2D(0.4F, 1, 1);
b = Merge(a, a, 1);
c = HTanH(MMul(a, b));
a.Dump(stderr, "a:");
b.Dump(stderr, "b:");
c.Dump(stderr, "c:");
XLink::ShowNetwork(stderr, &c);
net.Backward(c);
net.Dump(stderr);
//_CrtDumpMemoryLeaks();
return 0;
......
......@@ -49,13 +49,24 @@ private:
static
void GradSumDim(XTensor * node);
/* gradient for multiply (dot production): c = a * b */
/* gradient for multiply (dot production): c = a * b * \alpha */
static
void GradMultiply(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) */
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * node);
/* gradient for matrix multiply: c = matmul(a, b) * \alpha */
static
void GradMatrixMul(XTensor * a, XTensor * deda, MATRIX_TRANS_TYPE transA,
XTensor * b, XTensor * dedb, MATRIX_TRANS_TYPE transB,
XTensor * dedc, DTYPE alpha);
/* gradient for matrix multiply in batch mode.
for each batch: c_i = matmul(a_i, b_i) * \alpha */
static
void GradMatrixMulBatched(XTensor * node);
/* gradient for log: c = log(a) */
static
......@@ -124,6 +135,14 @@ private:
/* gradient for sign */
static
void GradSign(XTensor * node);
/* gradient for clip */
static
void GradClip(XTensor * node);
/* gradient for round */
static
void GradRound(XTensor * node);
};
}
......
......@@ -46,6 +46,11 @@ unsigned int MakeNetID()
return id;
}
void XNetClearAll()
{
MUTEX_DELE(netMutex);
}
/* constructor */
XNet::XNet()
{
......@@ -258,10 +263,11 @@ void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
if(node == NULL)
return;
//fprintf(stderr, "%d\n", node->id);
if(node->visitMark == code + 1){
ShowNTErrors("There is a circle in the network\n");
}
else if(node->visitMark <= code || node->visitMark >= code + 2){
else if(node->visitMark <= code){
node->visitMark = code + 1;
XLink &income = node->income;
for(int i = 0; i < income.tailNum; i++){
......
......@@ -95,6 +95,7 @@ struct XNet
extern unsigned int netIDGlobal;
extern MUTEX_HANDLE netMutex;
extern unsigned int MakeNetID();
extern void XNetClearAll();
}
......
......@@ -240,6 +240,7 @@ void Check(FNNModel &model)
{
CheckErrors(model.n > 0 && model.n <= MAX_N_GRAM, "The LM order is out of range (use -n)!");
CheckErrors(model.vSize > 0, "no vocabulary size found (use -vsize)!");
CheckErrors(model.eSize > 0, "no embedding size found (use -esize)!");
}
/* make a hard copy of the fnn model */
......@@ -580,7 +581,7 @@ void Update(FNNModel &model, FNNModel &grad, float epsilon, bool isNodeGrad)
get prediction probabilites of the gold words
>> output - output probabilities
>> gold - gold standard
>>
>> wordPobs - probability of each word
<< return - probability of the batch
*/
float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs)
......@@ -632,8 +633,10 @@ int LoadNGrams(FILE * file, int n, NGram * ngrams, int sentNum, int wordNum)
if(pin <= 0){
int len = (int)strlen(lineBuf);
if(lineBuf[len - 1] == '\r')
while(lineBuf[len - 1] == '\r' || lineBuf[len - 1] == '\n'){
lineBuf[len - 1] = 0;
len--;
}
len = (int)strlen(lineBuf);
if(len == 0)
......@@ -644,10 +647,11 @@ int LoadNGrams(FILE * file, int n, NGram * ngrams, int sentNum, int wordNum)
/* how many words are in the sentence */
int wNum = 0;
int i = 0;
for(int i = pin; i < len; i++){
for(i = pin; i < len; i++){
/* load word (id) seperated by space or tab */
if((lineBuf[i] == ' ' || lineBuf[i] == '\t' || i == len - 1) && wSize > 0){
if((lineBuf[i] == ' ' || lineBuf[i] == '\t') && wSize > 0){
lineBuf[i] = 0;
wordBuf[wNum++] = atoi(lineBuf + i - wSize);
wSize = 0;
......@@ -656,6 +660,9 @@ int LoadNGrams(FILE * file, int n, NGram * ngrams, int sentNum, int wordNum)
wSize++;
}
if(wSize > 0)
wordBuf[wNum++] = atoi(lineBuf + i - wSize);
wordBufCount = wNum;
lineNum++;
}
......
......@@ -22,6 +22,7 @@
#include <math.h>
#include "T2TAttention.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
......@@ -56,9 +57,9 @@ void T2TAttention::InitModel(int argc, const char ** argv, int myDevID, XMem * m
float minmax = 0;
LoadParamInt(argc, argv, "nhead", &nhead, 8);
LoadParamInt(argc, argv, "dk", &dk, 512);
LoadParamInt(argc, argv, "dv", &dv, 512);
LoadParamInt(argc, argv, "d", &d, 512);
LoadParamInt(argc, argv, "d", &dk, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &dv, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_BEDDING_SIZE);
LoadParamFloat(argc, argv, "attminmax", &minmax, 0.08F);
InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
......@@ -79,16 +80,16 @@ make the network
>> v - values
<< return - multi-attention result
*/
XTensor * T2TAttention::Make(XTensor * k, XTensor * q, XTensor * v)
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v)
{
XTensor k2;
XTensor q2;
XTensor v2;
/* linear transofmration before self-attention */
k2 = MMul(*k, wk);
q2 = MMul(*q, wq);
v2 = MMul(*v, wv);
k2 = MMul(k, wk);
q2 = MMul(q, wq);
v2 = MMul(v, wv);
XTensor kheads;
XTensor qheads;
......@@ -104,14 +105,10 @@ XTensor * T2TAttention::Make(XTensor * k, XTensor * q, XTensor * v)
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */
scalar = Softmax(Linear(BMMul(qheads, X_NOTRANS, kheads, X_TRANS), 1/sqrt((float)dk)), -1);
att = MMul(scalar, vheads);
XTensor * result = new XTensor();
att = BMMul(scalar, vheads);
/* concatenate the heads */
*result = Merge(att, -1);
return result;
return Merge(att, att.order - 1);
}
}
......@@ -77,7 +77,7 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * k, XTensor * q, XTensor * v);
XTensor Make(XTensor &k, XTensor &q, XTensor &v);
};
}
......
......@@ -57,7 +57,8 @@ void T2TEmbedder::InitModel(int argc, const char ** argv, int myDevID, XMem * my
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "maxlen", &maxLength, 256);
LoadParamInt(argc, argv, "d", &d, 256);
LoadParamInt(argc, argv, "d", &eSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_BEDDING_SIZE);
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID, mem);
......@@ -74,9 +75,9 @@ length - length of the sequenc
*/
void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
{
InitTensor2D(&posEmbedding, length, eSize, X_FLOAT, devID, mem);
InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT, devID, mem);
float * data = new float[posEmbedding.unitNum];
float * data = new float[posEmbeddingBase.unitNum];
for(int pos = 0; pos < length; pos++){
float * dp = data + pos * eSize;
......@@ -92,7 +93,7 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
}
}
posEmbedding.SetData(data, posEmbedding.unitNum);
posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
delete[] data;
}
......@@ -100,20 +101,21 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
/*
make the network
*/
XTensor * T2TEmbedder::Make(XTensor * input)
XTensor T2TEmbedder::Make(XTensor &input)
{
CheckNTErrors(input->GetDim(-1) == vSize, "Wrong vocabulary size!");
CheckNTErrors(input->order > 1, "Wrong input tensor size!");
CheckNTErrors(input->dimSize[input->order - 2] < maxLength, "The sequence is too long!");
CheckNTErrors(input.GetDim(-1) == vSize, "Wrong vocabulary size!");
CheckNTErrors(input.order > 1, "Wrong input tensor size!");
CheckNTErrors(input.dimSize[input.order - 2] < maxLength, "The sequence is too long!");
CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, input->dimSize, input->order);
dims[0] = eSize;
memcpy(dims, input.dimSize, input.order * sizeof(int));
dims[input.order - 1] = eSize;
bool match = (posEmbedding.order == input->order);
bool match = (posEmbedding.order == input.order);
if(match){
for(int i = 0; i < input->order; i++){
for(int i = 0; i < input.order; i++){
if(dims[i] != posEmbedding.GetDim(i))
match = false;
}
......@@ -121,18 +123,11 @@ XTensor * T2TEmbedder::Make(XTensor * input)
/* we make positional embeddings first */
if(!match){
InitTensor(&posEmbedding, input->order, dims, X_FLOAT, 1.0F, devID, mem);
XTensor * posTMP = NewTensorBuf(2, dims, X_FLOAT, 1.0F, devID, mem);
_CopyValues(&posEmbeddingBase, 0, posTMP->unitNum, posTMP, 0);
int dims2[MAX_TENSOR_DIM_NUM];
dims2[0] = dims[0];
dims2[1] = dims[1];
dims2[2] = posEmbedding.unitNum / (dims[0] * dims[1]);
posEmbedding.Reshape(3, dims2);
InitTensor(&posEmbedding, input.order, dims, X_FLOAT, 1.0F, devID, mem);
XTensor * posTMP = NewTensorBuf(2, dims + 1, X_FLOAT, 1.0F, devID, mem);
_Unsqueeze(posTMP, &posEmbedding, 0, dims2[2]);
posEmbedding.Reshape(input->order, dims);
_CopyValues(&posEmbeddingBase, 0, posTMP->unitNum, posTMP, 0);
_Unsqueeze(posTMP, &posEmbedding, 0, dims[0]);
DelTensorBuf(posTMP);
}
......@@ -140,14 +135,10 @@ XTensor * T2TEmbedder::Make(XTensor * input)
XTensor wordEmbedding;
/* then we make word embeddings */
wordEmbedding = MMul(*input, w);
XTensor * result = new XTensor();
wordEmbedding = MMul(&input, w);
/* we sum over the two embeddings */
*result = wordEmbedding + posEmbedding;
return result;
return wordEmbedding + posEmbedding;
}
}
......@@ -29,6 +29,8 @@ using namespace nts;
namespace transformer
{
#define DEFAULT_BEDDING_SIZE 512
/*
embedding (of word at position i):
word embedding + positional embedding
......@@ -75,7 +77,7 @@ public:
void MakePosEmbedding(int eSize, int d, int length);
/* make the network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
}
......
......@@ -82,26 +82,28 @@ make the encoding network
>> input - the input tensor of the encoder
<< return - the output tensor of the encoder
*/
XTensor * AttEncoder::Make(XTensor * input)
XTensor AttEncoder::Make(XTensor &input)
{
XTensor * x = embedder.Make(input);
XTensor x;
x = embedder.Make(input);
for(int i = 0; i < nlayer; i++){
XTensor * att;
XTensor * ln;
XTensor * fnn;
XTensor att;
XTensor ln;
XTensor fnn;
XTensor res;
/* self attention */
att = attentions[i].Make(x, x, x);
/* residual connection */
res = Sum(*att, *x);
res = Sum(att, x);
/* TODO: dropout */
/* layer normalization */
ln = layerNorms[i].Make(&res);
ln = layerNorms[i].Make(res);
/* input of next layer */
x = ln;
......@@ -110,12 +112,12 @@ XTensor * AttEncoder::Make(XTensor * input)
fnn = fnns[i].Make(x);
/* residual connection */
res = Sum(*fnn, *x);
res = Sum(fnn, x);
/* TODO: dropout */
/* layer normalization */
ln = layerNorms[i].Make(&res);
ln = layerNorms[i].Make(res);
/* input of next layer */
x = ln;
......
......@@ -40,7 +40,7 @@ class T2TEncoder
{
public:
virtual
XTensor * Make(XTensor * input) = 0;
XTensor Make(XTensor &input) = 0;
};
/*
......@@ -49,7 +49,7 @@ the encoder based on RNN
class RNNEncoder : T2TEncoder
{
public:
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
......@@ -106,7 +106,7 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
......
......@@ -21,6 +21,7 @@
#include "T2TFNN.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/function/FHeader.h"
......@@ -54,9 +55,9 @@ void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
float minmax = 0;
LoadParamInt(argc, argv, "d", &inSize, 512);
LoadParamInt(argc, argv, "d", &outSize, 512);
LoadParamInt(argc, argv, "fnnh", &hSize, 512);
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_BEDDING_SIZE);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.08F);
InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID, mem);
......@@ -77,18 +78,15 @@ y = max(0, x * w1 + b1) * w2 + b2
>> input - the input tensor
>> return - the output tensor
*/
XTensor * T2TFNN::Make(XTensor * input)
XTensor T2TFNN::Make(XTensor &input)
{
XTensor t1;
XTensor * result = new XTensor();
/* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MMul(*input, X_NOTRANS, w1, X_NOTRANS) + b1);
t1 = Rectify(MMul(input, X_NOTRANS, w1, X_NOTRANS) + b1);
/* result = t1 * w2 + b2 */
*result = MMul(t1, X_NOTRANS, w2, X_NOTRANS) + b2;
return result;
return MMul(t1, X_NOTRANS, w2, X_NOTRANS) + b2;
}
......
......@@ -72,7 +72,7 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
......
......@@ -20,6 +20,7 @@
*/
#include "T2TLayerNormal.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
......@@ -56,9 +57,31 @@ y =
>> input - the input tensor
>> return - layer normalization output
*/
XTensor * T2TLN::Make(XTensor * input)
XTensor T2TLN::Make(XTensor &input)
{
return NULL;
XTensor &x = input;
XTensor mean;
XTensor variance;
XTensor standard;
XTensor meanFilled;
XTensor standardFilled;
/* \mu = (sum_i x_i)/m */
mean = ReduceSum(x, x.order - 1);
/* \sigma = (sum_i (x_i - \mu)^2)/m */
variance = ReduceVariance(x, x.order - 1, mean);
/* standard = sqrt(variance) */
standard = Power(variance, 0.5F);
/* unsqueeze mean and standard deviation to fit them into
the same size of x */
meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1));
standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));
/* x' = (x - \mu)/standard */
return (x - meanFilled)/standardFilled;
}
}
......@@ -49,7 +49,7 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
};
}
......
......@@ -69,7 +69,7 @@ make the encoding network
>> input - input tensor
<< return - encoding result
*/
XTensor * T2TModel::MakeEncoding(XTensor * input)
XTensor T2TModel::MakeEncoding(XTensor &input)
{
return encoder.Make(input);
}
......@@ -79,10 +79,12 @@ make the entire network (with the output softmax layer)
>> input - input tensor
>> output - output tensor (distribution)
*/
void T2TModel::Make(XTensor * input, XTensor * output)
void T2TModel::Make(XTensor &input, XTensor &output)
{
if(isLM){
XTensor * encoding = MakeEncoding(input);
XTensor encoding;
encoding = MakeEncoding(input);
outputLayer.Make(encoding, output);
}
else{
......
......@@ -66,10 +66,10 @@ public:
void InitModel(int argc, const char ** argv);
/* make the encoding network */
XTensor * MakeEncoding(XTensor * input);
XTensor MakeEncoding(XTensor &input);
/* make the entire network (with the output softmax layer) */
void Make(XTensor * input, XTensor * output);
void Make(XTensor &input, XTensor &output);
};
}
......
......@@ -21,6 +21,7 @@
#include "T2TOutput.h"
#include "T2TUtility.h"
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
......@@ -52,11 +53,16 @@ void T2TOutput::InitModel(int argc, const char ** argv, int myDevID, XMem * myMe
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "hsize", &inSize, 512);
LoadParamInt(argc, argv, "hsize", &hSize, 512);
}
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_BEDDING_SIZE);
LoadParamInt(argc, argv, "d", &hSize, DEFAULT_BEDDING_SIZE);
LoadParamFloat(argc, argv, "outputminmax", &minmax, 0.08F);
InitTensor2D(&w, hSize, vSize, X_FLOAT, devID, mem);
w.SetDataRand(-minmax, minmax);
}
/*
make the network
......@@ -64,14 +70,11 @@ y = softmax(x * w)
>> input - input tensor
<< return - output tensor
*/
XTensor * T2TOutput::Make(XTensor * input)
XTensor T2TOutput::Make(XTensor &input)
{
XTensor &x = *input;
XTensor * result = new XTensor();
*result = LogSoftmax(MMul(x, w), -1);
XTensor &x = input;
return result;
return LogSoftmax(MMul(x, w), -1);
}
/*
......@@ -79,11 +82,11 @@ make the network (redefined output tensor)
>> input - input tensor
>> output - output tensor
*/
void T2TOutput::Make(XTensor * input, XTensor * output)
void T2TOutput::Make(XTensor &input, XTensor &output)
{
XTensor &x = *input;
XTensor &x = input;
*output = LogSoftmax(MMul(x, w), -1);
output = LogSoftmax(MMul(x, w), -1);
}
}
\ No newline at end of file
......@@ -62,10 +62,10 @@ public:
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * input);
XTensor Make(XTensor &input);
/* make the network (redefined output tensor) */
void Make(XTensor * input, XTensor * output);
void Make(XTensor &input, XTensor &output);
};
......
......@@ -19,8 +19,11 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
*/
#include <math.h>
#include "T2TTrainer.h"
#include "T2TUtility.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
......@@ -28,6 +31,8 @@ namespace transformer
/* constructor */
T2TTrainer::T2TTrainer()
{
devID = -1;
mem = NULL;
seqLen = NULL;
nseqBuf = 0;
nextSeq = -1;
......@@ -38,6 +43,7 @@ T2TTrainer::~T2TTrainer()
{
delete[] buf;
delete[] seqLen;
delete[] seqOffset;
}
/*
......@@ -47,17 +53,19 @@ initialization
*/
void T2TTrainer::Init(int argc, const char ** argv)
{
LoadParamInt(argc, argv, "dev", &devID, -1);
LoadParamFloat(argc, argv, "lrate", &lrate, 0.001F);
LoadParamInt(argc, argv, "sbatch", &sBatchSize, 1);
LoadParamInt(argc, argv, "wbatch", &wBatchSize, 1);
LoadParamInt(argc, argv, "nepoch", &nepoch, 1);
LoadParamInt(argc, argv, "nstep", &nstep, 1);
LoadParamInt(argc, argv, "vsize", &vSize, 1);
LoadParamBool(argc, argv, "sorted", &isLenSorted, false);
LoadParamInt(argc, argv, "bufsize", &bufSize, 50000);
int maxUnitInBuf;
LoadParamInt(argc, argv, "bufsize", &maxUnitInBuf, 20000);
buf = new int[maxUnitInBuf];
seqLen = new int[maxUnitInBuf];
seqOffset = new int[maxUnitInBuf];
buf = new int[bufSize];
seqLen = new int[bufSize];
seqOffset = new int[bufSize];
}
/*
......@@ -67,6 +75,70 @@ train the model
*/
void T2TTrainer::Train(const char * fn, T2TModel * model)
{
int epoch = 0;
int step = 0;
int wc = 0;
int wordCount = 0;
int wordCountTotal = 0;
bool isEnd = false;
float loss = 0;
XNet net;
double startT = GetClockSec();
for(epoch = 0; epoch < nepoch; epoch++){
FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "cannot open training file!");
wordCount = 0;
/* batch of input sequences */
XTensor batch;
while(LoadBatch(file, &batch, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc)){
/* output probabilities */
XTensor output;
/* make the network */
model->Make(batch, output);
/* back-propagation for obtaining gradients */
net.Backward(output, batch, CROSSENTROPY);
/* update the parameters */
Update(model);
/* get probabilities */
float prob = GetProb(&output, &batch, NULL);
loss += -prob;
wordCount += wc;
wordCountTotal += wc;
if(++step >= nstep){
isEnd = true;
break;
}
if (step % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
elapsed, step, epoch + 1, wordCountTotal, exp(loss / wordCount));
}
}
fclose(file);
}
double elapsed = GetClockSec() - startT;
XPRINT5(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, ngram=%d, ppl=%.3f\n",
elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
elapsed, step, epoch);
}
char line[MAX_SEQUENCE_LENGTH];
......@@ -83,8 +155,10 @@ int T2TTrainer::LoadBuf(FILE * file)
while(fgets(line, MAX_SEQUENCE_LENGTH - 1, file)){
int len = (int)strlen(line);
if(line[len - 1] == '\r')
while(line[len - 1] == '\r' || line[len - 1] == '\n'){
line[len - 1] = 0;
len--;
}
len = (int)strlen(line);
if(len == 0)
......@@ -96,10 +170,11 @@ int T2TTrainer::LoadBuf(FILE * file)
/* how many words are in the sentence */
int wNum = 0;
int wNumLocal = 0;
int i = 0;
for(int i = 0; i < len; i++){
for(i = 0; i < len; i++){
/* load word (id) seperated by space or tab */
if((line[i] == ' ' || line[i] == '\t' || i == len - 1) && wSize > 0){
if((line[i] == ' ' || line[i] == '\t') && wSize > 0){
line[i] = 0;
if(wSize == 3 && line[i - 1] == '|' && line[i - 2] == '|' && line[i - 3] == '|'){
......@@ -109,7 +184,7 @@ int T2TTrainer::LoadBuf(FILE * file)
wNumLocal = 0;
}
else{
buf[wNum++] = atoi(line + i - wSize);
buf[wordCount + wNum++] = atoi(line + i - wSize);
wNumLocal++;
}
......@@ -119,6 +194,11 @@ int T2TTrainer::LoadBuf(FILE * file)
wSize++;
}
if(wSize > 0){
buf[wordCount + wNum++] = atoi(line + i - wSize);
wNumLocal++;
}
seqLen[seqCount] = wNumLocal;
seqOffset[seqCount] = wordCount + wNum - wNumLocal;
seqCount++;
......@@ -126,10 +206,7 @@ int T2TTrainer::LoadBuf(FILE * file)
wordCount += wNum;
lineCount++;
if(wordCount >= wBatchSize)
break;
if(lineCount >= sBatchSize)
if(wordCount >= bufSize - MAX_SEQUENCE_LENGTH)
break;
}
......@@ -148,27 +225,32 @@ load a batch of sequences
>> sBatch - batch size of sequences
>> wBatch - batch size of words
>> isSorted - indicates whether the sequences are sorted by length
>> wCount - word count
*/
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted)
int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount)
{
if(nextSeq >= nseqBuf)
if(nextSeq < 0 || nextSeq >= nseqBuf)
LoadBuf(file);
int seq = nextSeq;
int seq = MAX(nextSeq, 0);
int wc = 0;
int wn = 0;
int sc = 0;
int max = 0;
while(seq < nseqBuf){
wc += seqLen[seq];
while(seq + sc < nseqBuf){
wn = seqLen[seq + sc];
wc += wn;
sc += 1;
if(max < wc)
max = wc;
if(max < wn)
max = wn;
if(sc >= sBatch && wc >= wBatch)
break;
}
nextSeq = seq + sc;
if(sc > 0){
int dims[MAX_TENSOR_DIM_NUM];
dims[0] = sc;
......@@ -182,14 +264,86 @@ int T2TTrainer::LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sB
batch->SetZeroAll();
/* this might be slow on GPUs :( */
for(int s = seq; s < seq + sc; s++){
for(int w = 0; w < seqLen[s]; w++){
batch->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
wCount++;
}
}
}
return sc;
}
/*
get word probabilities for a batch of sequences
>> output - word distribution for each position
>> gold - gold standard
>> wordProbs - word probability for gold prediction
*/
float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
{
XTensor probs;
InitTensor(&probs, output);
/* probs[i,j] = output[i,j] * gold[i,j] */
_Multiply(output, gold, &probs);
/* probability of each word */
XTensor wprobs;
InitTensor1D(&wprobs, output->unitNum/output->GetDim(-1), X_FLOAT, output->devID, output->mem);
int dims[2] = {output->unitNum/output->GetDim(-1), output->GetDim(-1)};
probs.Reshape(2, dims);
_ReduceSum(&probs, &wprobs, 1);
if(wordProbs != NULL)
_CopyValues(&wprobs, wordProbs);
/* reshape the tensor to fit it into the reduce procedure
TODO: XTensor supports scalars */
dims[0] = 1;
dims[1] = probs.unitNum;
probs.Reshape(2, dims);
/* probability for the batch */
XTensor result;
InitTensor1D(&result, 1, X_FLOAT, output->devID, output->mem);
_ReduceSum(&probs, &result, 1);
return result.Get1D(0);
}
/*
update the model by delta rule
>> model - the t2t model
*/
void T2TTrainer::Update(T2TModel * model)
{
XList ws(100);
ws.Add(&model->outputLayer.w);
for(int i = 0; i < model->encoder.nlayer; i++){
ws.Add(&model->encoder.fnns[i].w1);
ws.Add(&model->encoder.fnns[i].b1);
ws.Add(&model->encoder.fnns[i].w2);
ws.Add(&model->encoder.fnns[i].b2);
}
ws.Add(&model->encoder.embedder.w);
for(int i = 0; i < ws.count; i++){
XTensor * para = (XTensor*)ws.Get(i);
XTensor * paraGrad = para->grad;
CheckNTErrors(para != NULL, "NULL parameter tensor!");
CheckNTErrors(paraGrad != NULL, "NULL gradient tensor!");
/* the delta rule */
_Sum(para, paraGrad, para, -lrate);
}
}
}
\ No newline at end of file
}
......@@ -26,7 +26,7 @@
#include "../../tensor/function/FHeader.h"
#define MAX_SEQUENCE_LENGTH 1024 * 64
#define MAX_SEQUENCE_LENGTH 1024 * 4
using namespace nts;
......@@ -46,6 +46,9 @@ public:
/* buffer for loading words */
int * buf;
/* buffer size */
int bufSize;
/* length of each sequence */
int * seqLen;
......@@ -57,6 +60,9 @@ public:
/* offset for next sequence in the buffer */
int nextSeq;
/* indicates whether the sequence is sorted by length */
bool isLenSorted;
/* vocabulary size of the source side */
int vSize;
......@@ -93,10 +99,16 @@ public:
int LoadBuf(FILE * file);
/* load a batch of sequences */
int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted);
int LoadBatch(FILE * file, XTensor * batch, int step, int vs, int sBatch, int wBatch, bool isSorted, int &wCount);
/* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
/* update the model by delta rule */
void Update(T2TModel * model);
};
}
#endif
\ No newline at end of file
#endif
......@@ -26,7 +26,7 @@
namespace transformer
{
void LoadParamString(int argc, const char ** argv, const char * name, char * p, char * defaultP)
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP)
{
char vname[128];
vname[0] = '-';
......@@ -34,8 +34,8 @@ void LoadParamString(int argc, const char ** argv, const char * name, char * p,
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
*(int*)p = atoi(argv[i + 1]);
fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
strcpy(p, argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
......@@ -52,7 +52,7 @@ void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
*(int*)p = atoi(argv[i + 1]);
fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
......@@ -69,7 +69,8 @@ void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bo
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname)){
*(bool*)p = true;
fprintf(stderr, " %s=%s\n", name, "true");
//fprintf(stderr, " %s=%s\n", name, "true");
hit = true;
}
}
if(!hit)
......@@ -84,12 +85,27 @@ void LoadParamFloat(int argc, const char ** argv, const char * name, float * p,
bool hit = false;
for(int i = 0; i < argc; i++){
if(!strcmp(argv[i], vname) && i + 1 < argc){
strcpy((char*)p, argv[i + 1]);
fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
*p = (float)atof(argv[i + 1]);
//fprintf(stderr, " %s=%s\n", name, argv[i + 1]);
hit = true;
}
}
if(!hit)
*p = defaultP;
}
}
\ No newline at end of file
void ShowParams(int argc, const char ** argv)
{
fprintf(stderr, "args:\n");
for(int i = 0; i < argc; i++){
if(argv[i][0] == '-'){
if(i + 1 < argc && argv[i + 1][0] != '-')
fprintf(stderr, " %s=%s\n", argv[i], argv[i + 1]);
else
fprintf(stderr, " %s=yes\n", argv[i]);
}
}
fprintf(stderr, "\n");
}
}
......@@ -27,12 +27,15 @@
namespace transformer
{
/* load model parameters */
void LoadParamString(int argc, const char ** argv, const char * name, char * p, char * defaultP);
/* load arguments */
void LoadParamString(int argc, const char ** argv, const char * name, char * p, const char * defaultP);
void LoadParamInt(int argc, const char ** argv, const char * name, int * p, int defaultP);
void LoadParamBool(int argc, const char ** argv, const char * name, bool * p, bool defaultP);
void LoadParamFloat(int argc, const char ** argv, const char * name, float * p, float defaultP);
/* show arguments */
void ShowParams(int argc, const char ** argv);
}
#endif
\ No newline at end of file
#endif
......@@ -20,12 +20,37 @@
*/
#include "Transformer.h"
#include "T2TModel.h"
#include "T2TUtility.h"
#include "T2TTrainer.h"
#include "../../tensor/XDevice.h"
namespace transformer
{
int TransformerMain(int argc, const char ** argv)
{
if(argc == 0)
return 1;
ShowParams(argc, argv);
char * trainFN = new char[MAX_LINE_LENGTH];
LoadParamString(argc, argv, "train", trainFN, "");
T2TModel model;
model.InitModel(argc, argv);
if(strcmp(trainFN, "")){
T2TTrainer trainer;
trainer.Init(argc, argv);
trainer.Train(trainFN, &model);
}
delete[] trainFN;
return 0;
}
......
......@@ -38,7 +38,7 @@ namespace transformer
{
/* entrance of the program */
int TransformerMMain(int argc, const char ** argv);
int TransformerMain(int argc, const char ** argv);
}
......
......@@ -37,7 +37,6 @@
using namespace nts;
void SetDataTest();
void SmallTest();
void TransposeTest();
......
......@@ -39,16 +39,26 @@ const char * GetOPName(int type)
return "M_COS";
else if (type == MATH_TAN)
return "M_TAN";
else if (type == MATH_ROUND)
return "M_ROUND";
else if (type == MATH_CLIP)
return "M_CLIP";
else if (type == MATH_DIV)
return "M_DIV";
else if (type == MATH_MATRIXMUL)
return "M_MATRIXMUL";
else if (type == MATH_MATRIXMULBATCHED)
return "M_MATRIXMULBATCHED";
else if (type == MATH_MULTIPLY)
return "M_MULTIPLY";
else if (type == MATH_DIV)
return "M_DIV";
else if (type == MATH_NEGATE)
return "M_NEGATE";
else if (type == MATH_NORMALIZE)
return "M_NORMALIZE";
else if (type == MATH_POWER)
return "M_POWER";
else if (type == MATH_SCALEANDSHIFT)
return "M_SCALEANDSHIFT";
else if (type == MATH_SIGN)
return "M_SIGN";
else if (type == MATH_SUM)
......@@ -57,12 +67,6 @@ const char * GetOPName(int type)
return "M_SUB";
else if (type == MATH_SUMDIM)
return "M_SUMDIM";
else if (type == MATH_NORMALIZE)
return "M_NORMALIZE";
else if (type == MATH_POWER)
return "M_POWER";
else if (type == MATH_SCALEANDSHIFT)
return "M_SCALEANDSHIFT";
else if (type == REDUCE_REDUCEMAX)
return "R_REDUCEMAX";
else if (type == REDUCE_REDUCEMEAN)
......
......@@ -30,28 +30,30 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* math operations */
#define MATH_BASE 0x00001000
#define MATH_ABSOLUTE MATH_BASE + 1
#define MATH_EXP MATH_ABSOLUTE + 1
#define MATH_LOG MATH_EXP + 1
#define MATH_SIN MATH_LOG + 1
#define MATH_COS MATH_SIN + 1
#define MATH_TAN MATH_COS + 1
#define MATH_ROUND MATH_TAN + 1
#define MATH_NEGATE MATH_TAN + 1
#define MATH_MATRIXMUL MATH_TAN + 1
#define MATH_CLIP MATH_ROUND + 1
#define MATH_DIV MATH_CLIP + 1
#define MATH_MATRIXMUL MATH_DIV + 1
#define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
#define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1
#define MATH_DIV MATH_MULTIPLY + 1
#define MATH_SIGN MATH_DIV + 1
#define MATH_NEGATE MATH_MULTIPLY + 1
#define MATH_NORMALIZE MATH_NEGATE + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_SIGN MATH_SCALEANDSHIFT + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_SUB MATH_SUM + 1
#define MATH_SUMDIM MATH_SUB + 1
#define MATH_NORMALIZE MATH_SUMDIM + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define REDUCE MATH_SCALEANDSHIFT + 1
#define REDUCE MATH_SUMDIM + 1
#define REDUCE_REDUCEMAX REDUCE + 1
#define REDUCE_REDUCEMEAN REDUCE_REDUCEMAX + 1
#define REDUCE_REDUCESUM REDUCE_REDUCEMEAN + 1
......
......@@ -42,6 +42,8 @@
#include "core/movement/CopyValues.h"
#include "core/arithmetic/Sum.h"
#include "core/arithmetic/Multiply.h"
#include "core/arithmetic/Sub.h"
#include "core/arithmetic/Div.h"
#include "core/math/ScaleAndShift.h"
#ifdef USE_CUDA
......@@ -354,6 +356,18 @@ XTensor XTensor::operator* (const XTensor& tensor)
return Multiply(*this, tensor);
}
/* overloading of the minus-sign */
XTensor XTensor::operator- (const XTensor& tensor)
{
return Sub(*this, tensor);
}
/* overloading of the division-sign */
XTensor XTensor::operator/ (const XTensor& tensor)
{
return Div(*this, tensor);
}
/*
linear transformation b = a * \scale + \shift
>> scale - the slope
......@@ -458,6 +472,27 @@ void XTensor::Reshape(const int myOrder, const int * myDimSize)
memcpy(dimSizeRDI, dimsRDI, sizeof(int) * order);
}
/*
reshape the tensor to a vector
>> num - number of elements
*/
void XTensor::Reshape(const int num)
{
int dim = num;
Reshape(1, &dim);
}
/*
reshape the tensor to a matrix
>> rowNum - number of rows
>> colNum - number of columns
*/
void XTensor::Reshape(const int rowNum, const int colNum)
{
int dims[2] = {rowNum, colNum};
Reshape(2, dims);
}
/* get the number of items in the data array */
int XTensor::GetSize() const
{
......@@ -564,25 +599,24 @@ set the tensor items by a uniform distribution in range [lower, upper]
void XTensor::SetDataRand(DTYPE lower, DTYPE upper)
{
// TODO: cuda code!!!!!!!
// TODO: replace float with DTYPE
if (data == NULL)
return;
// srand((unsigned)time(0));
DTYPE variance = upper - lower;
void * d = NULL;
if (dataType == X_FLOAT) {
d = new float[unitNum];
for (int i = 0; i < unitNum; i++) {
DTYPE value = lower + (upper - lower) * (float)rand() / RAND_MAX;
DTYPE value = lower + variance * (float)rand() / RAND_MAX;
*((float*)d + i) = value;
}
}
else if (dataType == X_DOUBLE) {
d = new double[unitNum];
for (int i = 0; i < unitNum; i++) {
*((double*)d + i) = lower + (upper - lower) * rand() / RAND_MAX;
*((double*)d + i) = lower + variance * rand() / RAND_MAX;
}
}
else {
......@@ -592,15 +626,15 @@ void XTensor::SetDataRand(DTYPE lower, DTYPE upper)
SetData(d, unitNum);
if (dataType == X_FLOAT) {
delete[](float*)d;
delete[] (float*)d;
}
else {
delete[](double*)d;
delete[] (double*)d;
}
}
/* a gauss distribution */
double GaussRand()
/* a gauss distribution (Box-Muller method) */
double GaussRand(DTYPE mean, DTYPE standardDeviation)
{
// TODO: cuda code!!!!!!!
......@@ -610,8 +644,8 @@ double GaussRand()
double pi = 3.141592654;
if (phase == 0){
u = rand() / (RAND_MAX + 1.0);
v = rand() / (RAND_MAX + 1.0);
u = (rand() + 1.0) / (RAND_MAX + 1.0);
v = (rand() + 1.0) / (RAND_MAX + 1.0);
z = sqrt(-2.0 * log(u))* sin(2.0 * pi * v);
}
else{
......@@ -619,7 +653,7 @@ double GaussRand()
}
phase = 1 - phase;
return z;
return mean + (z * standardDeviation);
}
/*
......@@ -630,7 +664,6 @@ set the tensor items by a normal distribution
void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
{
// TODO: cuda code!!!!!!!
// TODO: replace float with DTYPE
if (data == NULL)
return;
......@@ -640,13 +673,13 @@ void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
if (dataType == X_FLOAT) {
d = new float[unitNum];
for (int i = 0; i < unitNum; i++) {
*((float*)d + i) = (float)GaussRand();
*((float*)d + i) = (float)GaussRand(mean, standardDeviation);
}
}
else if (dataType == X_DOUBLE) {
d = new double[unitNum];
for (int i = 0; i < unitNum; i++) {
*((double*)d + i) = GaussRand();
*((double*)d + i) = GaussRand(mean, standardDeviation);
}
}
else {
......@@ -656,10 +689,10 @@ void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
SetData(d, unitNum);
if (dataType == X_FLOAT) {
delete[](float*)d;
delete[] (float*)d;
}
else {
delete[](double*)d;
delete[] (double*)d;
}
}
......@@ -1007,11 +1040,11 @@ set the value of a cell in a 3d tensor in default type
*/
bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
{
CheckNTErrors((order == 3), "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((d0 >= 0 && d1 < dimSize[0]), "dimension 0 is out of range!");
CheckNTErrors((d2 >= 0 && d2 < dimSize[1]), "dimension 1 is out of range!");
CheckNTErrors((d2 >= 0 && d2 < dimSize[2]), "dimension 1 is out of range!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type.");
CheckNTErrors(order == 3, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 1 is out of range!");
CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");
int dims[3] = {d0, d1, d1};
......
......@@ -203,6 +203,12 @@ public:
/* overloading of the multiply-sign */
XTensor operator* (const XTensor &tensor);
/* overloading of the minus-sign */
XTensor operator- (const XTensor &tensor);
/* overloading of the division-sign */
XTensor operator/ (const XTensor &tensor);
/* linear transformation */
XTensor Lin(DTYPE scale, DTYPE shift = 0);
......@@ -223,6 +229,12 @@ public:
/* reshape the tensor */
void Reshape(const int order, const int * myDimSize);
/* reshape the tensor to a vector */
void Reshape(const int num);
/* reshape the tensor to a matrix */
void Reshape(const int rowNum, const int colNum);
/* get the number of items in the data array */
int GetSize() const;
......
......@@ -46,6 +46,7 @@
#include "getandset/Select.h"
#include "getandset/SetData.h"
#include "math/Clip.h"
#include "math/Normalize.h"
#include "math/Power.h"
#include "math/ScaleAndShift.h"
......
......@@ -53,11 +53,29 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
CheckNTErrors(a && b && c, "Empty input tensors!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Input tensors should have the same data type!");
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
"Input tensors must have a order >= 2!");
CheckNTErrors(c->order == a->order + b->order - 2, "wrong tensor order")
/* we transform a higher order tensor to a matrix to kill the number
of calls of matrix multiplication */
if(transposedA == X_NOTRANS && a->order > 2 && b->order == 2){
int ncolA = a->dimSize[a->order - 1];
int ncolC = c->dimSize[c->order - 1];
XTensor * a2 = NewTensor2D(a->unitNum/ncolA, -ncolA, a->dataType, a->devID, a->mem);
XTensor * c2 = NewTensor2D(c->unitNum/ncolC, -ncolC, c->dataType, c->devID, c->mem);
a2->data = a->data;
c2->data = c->data;
_MatrixMul2D(a2, transposedA, b, transposedB, c2, alpha, beta, parallelRunner);
a2->data = NULL;
c2->data = NULL;
delete a2;
delete c2;
return;
}
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
......@@ -144,10 +162,10 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID);
_CudaBLASMatrixMULList(handle,
aList, transposedA,
bList, transposedB,
cList, aList->count,
alpha, beta);
aList, transposedA,
bList, transposedB,
cList, aList->count,
alpha, beta);
BacktoCudaDev(a->devID, devIDBackup);
#else
......@@ -251,9 +269,7 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
/*
matrix multiplication with no transposition c = a * b * alpha
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
>> transposedB - indicates whether teh matrices in b are transposed
>> alpha - a coefficient
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
......
......@@ -117,14 +117,19 @@ void _MatrixMulBatchedGPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
blockNum *= a->dimSizeRDI[i];
}
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID);
_CudaBLASMatrixMULBatchedStrided(handle,
a->data, transposedA, a->dataType, aBlockSize,
b->data, transposedB, b->dataType, bBlockSize,
c->data, c->dataType, cBlockSize, blockNum,
a->dimSizeRDI[1], a->dimSizeRDI[0],
b->dimSizeRDI[1], b->dimSizeRDI[0],
c->dimSizeRDI[1], c->dimSizeRDI[0], alpha, beta);
a->data, transposedA, a->dataType, aBlockSize,
b->data, transposedB, b->dataType, bBlockSize,
c->data, c->dataType, cBlockSize, blockNum,
a->dimSizeRDI[1], a->dimSizeRDI[0],
b->dimSizeRDI[1], b->dimSizeRDI[0],
c->dimSizeRDI[1], c->dimSizeRDI[0], alpha, beta);
BacktoCudaDev(a->devID, devIDBackup);
#endif
}
......@@ -150,12 +155,12 @@ void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * c, DTYPE alpha, DTYPE beta)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
"Input tensors should have the same data type!");
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
"Input tensors must have a order >= 2!");
CheckNTErrors((a->order == b->order && a->order == c->order),
"Input tensor and output tensor must have same order!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Input tensors should have the same data type!");
CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
"Input tensors must have a order >= 2!");
CheckNTErrors(a->order == b->order && a->order == c->order,
"Input tensor and output tensor must have same order!");
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
......@@ -165,7 +170,7 @@ CheckNTErrors((a && b && c), "Empty input tensors!");
int cn = c->dimSizeRDI[1];
int cm = c->dimSizeRDI[0];
CheckNTErrors((am == bn && an == cn && bm == cm), "Unmatched tensors in multiplication!");
CheckNTErrors(am == bn && an == cn && bm == cm, "Unmatched tensors in multiplication!");
int aBlockSize = a->dimSizeRDI[0] * a->dimSizeRDI[1];
int bBlockSize = b->dimSizeRDI[0] * b->dimSizeRDI[1];
......@@ -326,4 +331,60 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const
return c;
}
/*
matrix multiplication of the two tensors (do it on site)
c = a * b * alpha
make a new tensor to keep the result and return it
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
ci = ai * bi * alpha + cm * beta
>> a - tensor a
>> b - tensor b
>> alpha - a coefficient
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication of the two tensors
*/
XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
DTYPE alpha, XPRunner * parallelRunner)
{
CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
CheckNTErrors(a.order == b.order, "Input tensor and output tensor must have same order!");
int an = a.dimSizeRDI[1];
int am = a.dimSizeRDI[0];
int bn = b.dimSizeRDI[1];
int bm = b.dimSizeRDI[0];
CheckNTErrors(am == bn, "Unmatched tensors in multiplication!");
int order = a.order;
int sub = 0;
int * dimSize = new int[order];
for (int i = 0; i < a.order - 2; i++)
dimSize[sub++] = a.dimSize[i];
dimSize[sub++] = an;
dimSize[sub++] = bm;
float dr = (!a.isSparse || !b.isSparse) ? 1.0F : MAX(a.denseRatio, b.denseRatio);
XTensor c(order, dimSize, a.dataType, dr, a.devID, a.mem);
c.SetTMP();
/*call _MatrixMulBatched function */
_MatrixMulBatched(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha);
/* destroy variables */
delete[] dimSize;
return c;
}
} // namespace nts(NiuTrans.Tensor)
......@@ -73,6 +73,17 @@ where trans() returns the transposed matrix if the flag is fired
XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
/*
matrix multiplication of the two tensors (return a XTensor structure) c = a * b * alpha
make a new tensor to keep the result and return it
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
ci = ai * bi * alpha + cm * beta
*/
XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
#endif // __MATRIXMULBATCHED_H__
\ No newline at end of file
......@@ -76,7 +76,7 @@ XTensor Sign(const XTensor & a)
XTensor b(&a);
b.SetTMP();
/* call _ScaleAndShift function */
/* call _Sign function */
_Sign(&a, &b);
/* tensor connections */
......
......@@ -22,6 +22,7 @@
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "../movement/CopyValues.h"
#include "Sum.h"
#include "Sum.cuh"
#include "SumDim.h"
......@@ -44,8 +45,12 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched tensors in addition!");
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
if(beta == 0){
_CopyValues(a, c);
return;
}
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
if (a == c) {
int P2PAccesible = 0;
......
......@@ -214,34 +214,32 @@ void _SetDataFixedDouble(XTensor * tensor, double p)
}
/*
generate data items with a uniform distribution in [low,high]
generate data items with a uniform distribution in [lower, upper]
>> tensor - the tensor whose data array would be initialized
>> low - lower value of the range
>> high - higher value of the range
>> lower - lower value of the range
>> upper - upper value of the range
*/
void _SetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
{
CheckNTErrors(high > low, "the high value must be greater than low value!");
CheckNTErrors(upper > lower, "the high value must be greater than low value!");
if(tensor == NULL)
return;
/* GPU code */
if(tensor->devID < 0){
DTYPE variance = high - low;
srand((unsigned)time(NULL));
DTYPE variance = upper - lower;
if(tensor->dataType == X_FLOAT){
float * d = (float*)tensor->data;
for(int i = 0; i < tensor->unitNum; i++){
d[i] = variance * ((float)rand()/RAND_MAX) + low;
d[i] = variance * ((float)rand()/RAND_MAX) + lower;
}
}
else if(tensor->dataType == X_DOUBLE){
double * d = (double*)tensor->data;
for(int i = 0; i < tensor->unitNum; i++){
d[i] = variance * ((double)rand()/RAND_MAX) + low;
d[i] = variance * ((double)rand()/RAND_MAX) + lower;
}
}
else{
......@@ -256,7 +254,7 @@ void _SetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
*/
else{
#ifdef USE_CUDA
_CudaSetDataRand(tensor, low, high);
_CudaSetDataRand(tensor, lower, upper);
#endif
//XTensor * t2 = NewTensor(tensor->order, tensor->dimSize, tensor->dataType, tensor->denseRatio, -1);
//_SetDataRand(t2, low, high);
......@@ -265,5 +263,17 @@ void _SetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
}
}
/*
generate data items with a normal distribution with specified mean and standard deviation
>> mean - mean or expectation of the distribution
>> standardDeviation - standard deviation of the distribution
*/
void _SetDataRandN(XTensor * tensor, DTYPE mean, DTYPE standardDeviation)
{
// TODO: rewrite it and add cuda code!!!!!!!
tensor->SetDataRandn(mean, standardDeviation);
}
} // namespace nts(NiuTrans.Tensor)
......@@ -150,61 +150,20 @@ void _CudaSetDataFixedDouble(XTensor * tensor, double p)
}
/*
call curand_init function on each kernel with the same random seed
and init the rng states
*/
__global__
void KernelInitializeCurand(curandState * state, unsigned long seed)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
curand_init(seed, i, 0, &state[i]);
}
/* */
__device__
float GenerateFloat(curandState* globalState, int i)
{
//copy state to local mem
curandState localState = globalState[i];
//apply uniform distribution with calculated random
float randNum = curand_uniform(&localState);
//update state
globalState[i] = localState;
//return value
return randNum;
}
/**/
__device__
double GenerateDouble(curandState* globalState, int i)
{
//copy state to local mem
curandState localState = globalState[i];
//apply uniform distribution with calculated random
double randNum = curand_uniform_double(&localState);
//update state
globalState[i] = localState;
//return value
return randNum;
}
/*
set data array with a uniform distribution in [low, high]
>> deviceStates - the state of curand
>> d - float datatype pointer to the data array
>> size - size of the array
>> low - low value of the range
>> high - high value of the range
>> lower - low value of the range
>> variance - the variance of the range
*/
__global__
void KernelSetDataRandFloat(curandState* deviceStates, float * d, int size, DTYPE low, DTYPE variance)
void KernelSetDataRandFloat(float * d, int size, DTYPE lower, DTYPE variance)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
float randNum = GenerateFloat(deviceStates, i);
d[i] = randNum * variance + low;
d[i] = d[i] * variance + lower;
}
}
/*
......@@ -212,29 +171,28 @@ set data array with a uniform distribution in [low, high]
>> deviceStates - the state of curand
>> d - double datatype pointer to the data array
>> size - size of the array
>> low - low value of the range
>> high - high value of the range
>> lower - low value of the range
>> variance - the variance of the range
*/
__global__
void KernelSetDataRandDouble(curandState* deviceStates, double * d, int size, DTYPE low, DTYPE variance)
void KernelSetDataRandDouble(double * d, int size, DTYPE lower, DTYPE variance)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size){
double randNum = GenerateDouble(deviceStates, i);
d[i] = randNum * variance + low;
d[i] = d[i] * variance + lower;
}
}
/*
generate data items with a uniform distribution in [low,high]
generate data items with a uniform distribution in [lower, upper]
>> tensor - the tensor whose data array would be initialized
>> low - lower value of the range
>> high - higher value of the range
>> lower - lower value of the range
>> upper - upper value of the range
*/
void _CudaSetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
{
CheckNTErrors(high > low, "the high value must be greater than low value!");
CheckNTErrors(upper > lower, "the high value must be greater than low value!");
int gridSize[3];
int blockSize[3];
......@@ -247,15 +205,17 @@ void _CudaSetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
curandState *deviceStates;
cudaMalloc(&deviceStates, sizeof(curandState));
DTYPE variance = high - low;
curandGenerator_t gen;
curandCreateGenerator (&gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(gen, time(NULL));
curandGenerateUniform(gen , (float*)tensor->data , tensor->unitNum);
curandDestroyGenerator(gen);
DTYPE variance = upper - lower;
KernelInitializeCurand<<<blocks, threads>>>(deviceStates, unsigned(time(NULL)));
if (tensor->dataType == X_FLOAT)
KernelSetDataRandFloat <<<blocks, threads >>>(deviceStates, (float*)tensor->data, tensor->unitNum, low, variance);
KernelSetDataRandFloat <<<blocks, threads >>>((float*)tensor->data, tensor->unitNum, lower, variance);
else if (tensor->dataType == X_DOUBLE)
KernelSetDataRandDouble <<<blocks, threads >>>(deviceStates, (double*)tensor->data, tensor->unitNum, low, variance);
KernelSetDataRandDouble <<<blocks, threads >>>((double*)tensor->data, tensor->unitNum, lower, variance);
BacktoCudaDev(tensor->devID, devIDBackup);
}
......
......@@ -37,8 +37,8 @@ void _CudaSetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void _CudaSetDataFixedDouble(XTensor * tensor, double p);
/* generate data items with a uniform distribution in [low,high] */
void _CudaSetDataRand(XTensor * tensor, DTYPE low, DTYPE high);
/* generate data items with a uniform distribution in [lower, upper] */
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -45,8 +45,8 @@ void _SetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void _SetDataFixedDouble(XTensor * tensor, double p);
/* generate data items with a uniform distribution in [low,high] */
void _SetDataRand(XTensor * tensor, DTYPE low, DTYPE high);
/* generate data items with a uniform distribution in [lower, upper] */
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a normal distribution with specified mean and standard deviation */
void _SetDataRandN(XTensor * tensor, DTYPE mean, DTYPE standardDeviation);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-03
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "Clip.h"
#include "Clip.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
set every entry to its clip value
>> a - input tensor we are processing
>> b - output tensor we are processing
>> lower - the lower border
>> upper - the upper border
*/
void _Clip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
_CudaClip(a, b, lower, upper);
return;
}
#endif
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * d = (DTYPE*)a->data;
DTYPE * db = (DTYPE*)b->data;
for (int i = 0; i < a->unitNum; i++) {
if (d[i] > upper)
db[i] = upper;
else if (d[i] < lower)
db[i] = lower;
else
db[i] = d[i];
}
}
/*
set every entry to its clip value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
>> lower - the lower border
>> upper - the upper border
*/
void _ClipMe(XTensor * a, DTYPE lower, DTYPE upper)
{
_Clip(a, a, lower, upper);
}
/*
set every entry to its clip value (return a XTensor structure)
make a new tensor to keep the result and return it
>> a - input tensor we are processing
>> lower - the lower border
>> upper - the upper border
<< return - the clip value of the input tensor
*/
XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper)
{
XTensor b(&a);
b.SetTMP();
/* call _Clip function */
_Clip(&a, &b, lower, upper);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
XLink::AddParamToHead(&b, lower);
XLink::AddParamToHead(&b, upper);
return b;
}
/*
backward computation
dE/dx = dE/dy * dy/dx
hard tanh: y = upper if x > upper
x if lower <= x <= upper
lower if x< lower
and dy/dx = 1 if lower <= x <= upper
0 otherwise
>> gold - gold standard to measure error (or loss)
>> y - output of the function
>> x - input of the function
>> dedy - dE/dy
>> dedx - dE/dx
>> lossName - type of loss function, e.g., cross entropy
*/
void _ClipBackward(XTensor * y, XTensor * x, XTensor * dedy, XTensor * dedx, DTYPE lower, DTYPE upper)
{
#ifdef USE_CUDA
if (x->devID >= 0) {
_CudaClipBackward(y, x, dedy, dedx, lower, upper);
return;
}
#endif
if (x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE) {
DTYPE * dedyp = (DTYPE*)dedy->data;
DTYPE * dedxp = (DTYPE*)dedx->data;
DTYPE * ip = (DTYPE*)x->data;
int size = y->unitNum;
/* dE/dx = dE/dy * dy/dx */
for (int i = 0; i < size; i++) {
DTYPE s = ip[i];
if (s > upper || s < lower)
dedxp[i] = 0;
else
dedxp[i] = dedyp[i];
}
}
else
ShowNTErrors("TODO!");
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-03
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Clip.h"
#include "Clip.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
set each entry to its clip value (CUDA Kernel)
>> a - pointer to input data array
>> b - pointer to output data array
>> lower - the lower border
>> upper - the upper border
>> size - size of the data array
*/
__global__
void KernelClip(DTYPE * a, DTYPE * b, DTYPE lower, DTYPE upper, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
if (a[i] > upper)
b[i] = upper;
else if (a[i] < lower)
b[i] = lower;
else
b[i] = a[i];
}
}
/*
set each entry to its clip value with float16 data type value (CUDA Kernel)
This is for float16 computation
>> a - pointer to input data array
>> b - pointer to output data array
>> lower - the lower border
>> upper - the upper border
>> size - size of the data array
*/
__global__
void KernelClip(__half * a, __half * b, DTYPE lower, DTYPE upper, int size)
{
return;
}
/*
set each entry to its clip value
>> a - input tensor we are processing
>> b - output tensor we are processing
>> lower - the lower border
>> upper - the upper border
*/
void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper)
{
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->isSparse == false), "TODO!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
KernelClip << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, lower, upper, a->unitNum);
}
else if (a->dataType == X_FLOAT16) {
KernelClip << <blocks, threads >> >((__half*)a->data, (__half*)b->data, lower, upper, a->unitNum);
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
/*
clip backward computation of dE/dx (Cuda kernel)
dy/dx = 1 if lower <= x <= upper
0 otherwise
>> dedy - dE/dy
>> dedx - dE/dx
>> y - y of the function
>> x - x of the function
>> lower
>> upper
*/
__global__
void KernelClipBackward(DTYPE * dedy, DTYPE * dedx, DTYPE * y, DTYPE * x, DTYPE lower, DTYPE upper, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
DTYPE s = x[i];
if (s > upper || s < lower)
dedx[i] = 0;
else
dedx[i] = dedy[i];
}
}
/*
backward computation (Cuda version)
dE/dx = dE/dy * dy/dx
hard tanh: y = upper if x > upper
x if lower <= x <= upper
lower if x< lower
and dy/dx = 1 if lower <= x <= upper
0 otherwise
>> gold - gold standard to measure error (or loss)
>> y - output of the function
>> x - input of the function
>> dedy - dE/dy
>> dedx - dE/dx
>> lossName - type of loss function, e.g., cross entropy
*/
void _CudaClipBackward(XTensor * y, XTensor * x, XTensor * dedy, XTensor * dedx, DTYPE lower, DTYPE upper)
{
if (x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE) {
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(x->devID, x->unitNum, gridSize, blockSize);
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
/* dE/dx = dE/dy * dy/dx */
KernelClipBackward <<<dim3(gridSize[0]), dim3(blockSize[0])>>>
((DTYPE*)dedy->data,
(DTYPE*)dedx->data,
(DTYPE*)y->data, (DTYPE*)x->data,
lower, upper,
x->unitNum);
BacktoCudaDev(x->devID, devIDBackup);
}
else
ShowNTErrors("TODO!");
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-03
*/
#ifndef __CLIP_CUH__
#define __CLIP_CUH__
#include "Clip.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its clip value (CUDA Kernel) */
__global__
void KernelClip(DTYPE * a, DTYPE * b, DTYPE lower, DTYPE upper, int size);
/* set each entry to its clip value (CUDA Kernel) with float16 data type*/
__global__
void KernelClip(__half * a, __half * b, DTYPE lower, DTYPE upper, int size);
/* set each entry to its clip value */
void _CudaClip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper);
/* backward of Clip function (CUDA Kernel) */
__global__
void KernelClipBackward(DTYPE * dedy, DTYPE * dedx, DTYPE * y, DTYPE * x, DTYPE lower, DTYPE upper, int size);
/* backward of Clip function */
void _CudaClipBackward(XTensor * y, XTensor * x, XTensor * dedy, XTensor * dedx, DTYPE lower, DTYPE upper);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __CLIP_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-03
*/
#ifndef __CLIP_H__
#define __CLIP_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its clip value */
void _Clip(const XTensor * a, XTensor * b, DTYPE lower, DTYPE upper);
/*
set every entry to its clip value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _ClipMe(XTensor * a, DTYPE lower, DTYPE upper);
/*
set every entry to its clip value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper);
/*
backward of Clip function
*/
void _ClipBackward(XTensor * y, XTensor * x, XTensor * dedy, XTensor * dedx, DTYPE lower, DTYPE upper);
} // namespace nts(NiuTrans.Tensor)
#endif // __CLIP_H__
......@@ -110,7 +110,7 @@ void _CudaNormalize(const XTensor * input, XTensor * output, int dim,
int cudaBlockSize[3];
GDevs.GetCudaThread2D(input->devID, strideNum, stride * blockNum,
MAX_INT, cudaGridSize, cudaBlockSize);
MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[1], cudaGridSize[0]);
dim3 threads(cudaBlockSize[1], cudaBlockSize[0]);
......@@ -119,9 +119,9 @@ void _CudaNormalize(const XTensor * input, XTensor * output, int dim,
ProtectCudaDev(a->devID, devIDBackup);
KernelNormalize << <blocks, threads >> >((DTYPE*)input->data, (DTYPE*)output->data,
(DTYPE*)mean->data, (DTYPE*)var->data,
(DTYPE*)a->data, (DTYPE*)b->data, epsilon,
stride, strideNum, blockNum);
(DTYPE*)mean->data, (DTYPE*)var->data,
(DTYPE*)a->data, (DTYPE*)b->data, epsilon,
stride, strideNum, blockNum);
BacktoCudaDev(a->devID, devIDBackup);
}
......
......@@ -64,6 +64,10 @@ SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)
_SIMPLE_UNARY_FUNCTION(_Tan, _CudaTan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)
_SIMPLE_UNARY_FUNCTION(_Round, _CudaRound, round)
_SIMPLE_UNARY_FUNCTION_ME(_RoundMe, _Round)
SIMPLE_UNARY_FUNCTION(Round, _Round, MATH_ROUND)
#else
/* define three marco separately, specify the respective function names */
#define _SIMPLE_UNARY_FUNCTION(_funcName, origFunc) \
......@@ -117,6 +121,10 @@ SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)
_SIMPLE_UNARY_FUNCTION(_Tan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)
_SIMPLE_UNARY_FUNCTION(_Round, round)
_SIMPLE_UNARY_FUNCTION_ME(_RoundMe, _Round)
SIMPLE_UNARY_FUNCTION(Round, _Round, MATH_ROUND)
#endif
}
\ No newline at end of file
......@@ -5,51 +5,51 @@
namespace nts {
#define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc) \
__global__ \
void Kernel##funcName(DTYPE * a, DTYPE * b, int size) \
{ \
int i = blockDim.x * blockIdx.x + threadIdx.x; \
\
if (i < size) \
b[i] = (DTYPE)origFunc(a[i]); \
} \
__global__ \
void Kernel##funcName(__half * a, __half * b, int size) \
{ \
return; \
} \
void _Cuda##funcName(const XTensor * a, XTensor * b) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->isSparse == false), "TODO!"); \
\
int gridSize[3]; \
int blockSize[3]; \
\
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize); \
\
dim3 blocks(gridSize[0]); \
dim3 threads(blockSize[0]); \
\
int devIDBackup; \
ProtectCudaDev(a->devID, devIDBackup); \
\
if (a->dataType == DEFAULT_DTYPE) { \
Kernel##funcName << <blocks, threads >> > \
((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum); \
} \
else if (a->dataType == X_FLOAT16) { \
Kernel##funcName << <blocks, threads >> > \
((__half*)a->data, (__half*)b->data, a->unitNum); \
} \
else { \
ShowNTErrors("TODO!"); \
} \
\
BacktoCudaDev(a->devID, devIDBackup); \
} \
#define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc) \
__global__ \
void Kernel##funcName(DTYPE * a, DTYPE * b, int size) \
{ \
int i = blockDim.x * blockIdx.x + threadIdx.x; \
\
if (i < size) \
b[i] = (DTYPE)origFunc(a[i]); \
} \
__global__ \
void Kernel##funcName(__half * a, __half * b, int size) \
{ \
return; \
} \
void _Cuda##funcName(const XTensor * a, XTensor * b) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->isSparse == false), "TODO!"); \
\
int gridSize[3]; \
int blockSize[3]; \
\
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize); \
\
dim3 blocks(gridSize[0]); \
dim3 threads(blockSize[0]); \
\
int devIDBackup; \
ProtectCudaDev(a->devID, devIDBackup); \
\
if (a->dataType == DEFAULT_DTYPE) { \
Kernel##funcName << <blocks, threads >> > \
((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum); \
} \
else if (a->dataType == X_FLOAT16) { \
Kernel##funcName << <blocks, threads >> > \
((__half*)a->data, (__half*)b->data, a->unitNum); \
} \
else { \
ShowNTErrors("TODO!"); \
} \
\
BacktoCudaDev(a->devID, devIDBackup); \
} \
SIMPLE_UNARY_FUNCTION_GPU(Absolute, fabs)
SIMPLE_UNARY_FUNCTION_GPU(Exp, exp)
......@@ -57,5 +57,6 @@ SIMPLE_UNARY_FUNCTION_GPU(Log, log)
SIMPLE_UNARY_FUNCTION_GPU(Sin, sin)
SIMPLE_UNARY_FUNCTION_GPU(Cos, cos)
SIMPLE_UNARY_FUNCTION_GPU(Tan, tan)
SIMPLE_UNARY_FUNCTION_GPU(Round, round)
}
\ No newline at end of file
......@@ -83,6 +83,15 @@ void KernelTan(__half * a, __half * b, int size);
/* set each entry to its tangent value */
void _CudaTan(const XTensor * a, XTensor * b);
/* set each entry to its round value (CUDA Kernel) */
__global__
void KernelRound(DTYPE * a, DTYPE * b, int size);
/* set each entry to its round value (CUDA Kernel) with float16 data type*/
__global__
void KernelRound(__half * a, __half * b, int size);
/* set each entry to its round value */
void _CudaRound(const XTensor * a, XTensor * b);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
......
......@@ -104,5 +104,19 @@ make a new tensor to keep the result and return it
*/
XTensor Tan(const XTensor & a);
/* set every entry to its round value */
void _Round(const XTensor * a, XTensor * b);
/*
set every entry to its round value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _RoundMe(XTensor * a);
/*
set every entry to its round value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Round(const XTensor & a);
}
#endif //end __UNARY_H__
\ No newline at end of file
......@@ -36,7 +36,7 @@ copy s to t
void _CopyValues(const XTensor * s, XTensor * t, XStream * stream)
{
CheckNTErrors((s != NULL && t != NULL), "The input tensor and output tensor must be nonempty!");
CheckNTErrors((s->data != NULL), "Cannot copy from an empty data array!");
CheckNTErrors((s->data != NULL), "Cannot copy an empty data array!");
CheckNTErrors((t->data != NULL), "Cannot copy to an empty data array!");
CheckNTErrors((s->unitNum == t->unitNum), "Unmatched data item number!");
......@@ -82,7 +82,7 @@ copy s to t
void _CopyValues(const XTensor * s, const int sBeg, const int sLen, XTensor * t, const int tBeg, XStream * stream)
{
CheckNTErrors(s != NULL && t != NULL, "The input tensor and output tensor must be nonempty!");
CheckNTErrors(s->data != NULL && t->data != NULL, "Cannot copy from an empty data array!");
CheckNTErrors(s->data != NULL && t->data != NULL, "Cannot copy an empty data array!");
CheckNTErrors(s->unitSize == t->unitSize, "The input tensors must be of the same unit size!");
CheckNTErrors(s->order > sBeg && sBeg >= 0 && sLen <= s->unitNum, "Wrong segment on the source side");
CheckNTErrors(t->order > tBeg && tBeg >= 0, "Wrong segment on the target side");
......
......@@ -109,6 +109,9 @@ void _CudaMergeBlockLists(const XList * sourceList, int * blockSizes, int blockN
CheckNTErrors((maxBlockSize % sizeof(DTYPE) == 0), "Unsupported block size!");
realMaxBlockSize = maxBlockSize / sizeof(DTYPE);
int devIDBackup;
ProtectCudaDev(myMem->devID, devIDBackup);
int cudaGridSizes[3];
int cudaBlockSizes[3];
......@@ -135,6 +138,8 @@ void _CudaMergeBlockLists(const XList * sourceList, int * blockSizes, int blockN
delete[] targetArrays;
delete[] sizes;
delete[] offsets;
BacktoCudaDev(myMem->devID, devIDBackup);
}
#endif // USE_CUDA
......
......@@ -168,6 +168,8 @@ make a new tensor to keep the result and return it
XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
{
CheckNTErrors(&s, "Invalid tensors!");
CheckNTErrors(s.dimSize[whereToSplit] % splitNum == 0,
"The dimension cannot be splitted due to the inproper split number");
int order = s.order + 1;
int * dimSize = new int[order];
......
......@@ -116,8 +116,7 @@ void _HardTanHBackward(XTensor * gold, XTensor * y, XTensor * x,
}
#endif
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE)
{
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
/* calculate dE/dy */
if(lossName != NOLOSS)
_LossBackward(dedy, gold, y, lossName);
......
......@@ -150,11 +150,10 @@ void _LogSoftmax(const XTensor * x, XTensor * y, int leadDim)
}
}
if (x->devID < 0) {
DelTensorBuf(max);
DelTensorBuf(sum);
}
else {
DelTensorBuf(max);
DelTensorBuf(sum);
if (x->devID >= 0) {
delete blockx;
delete blocky;
delete blockMax;
......@@ -282,6 +281,9 @@ void _LogSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
CheckNTErrors((!dedx->isSparse), "The gradient matrix must be dense!");
CheckNTErrors((gold != NULL), "The gold standard cannot be empty!");
if(leadDim < 0)
leadDim = y->order - 1;
int leadDimRDI = y->order - leadDim - 1;
#ifdef USE_CUDA
if (gold->devID >= 0) {
......
......@@ -185,10 +185,14 @@ void _SoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
int leadDim,
LOSS_FUNCTION_NAME lossName)
{
CheckNTErrors((dedx->isSparse == false), "The gradient tensor must be dense!");
CheckNTErrors((gold != NULL), "Incorrect x gold standard tensor!");
CheckNTErrors(dedx->isSparse == false, "The gradient tensor must be dense!");
CheckNTErrors(gold != NULL || lossName == NOLOSS, "Gold standard is required for computing loss!");
if(leadDim < 0)
leadDim = y->order - 1;
int leadDimRDI = y->order - leadDim - 1;
#ifdef USE_CUDA
if(y->devID >= 0){
_CudaSoftmaxBackward(gold, y, x, dedy, dedx, leadDim, lossName);
......
......@@ -156,6 +156,50 @@ void KernelSoftmaxComputeTensor(__half * x, __half * max, __half * sum, __half *
}
/*
use PTX code to broadcast float data
*/
__device__ __forceinline__
float broadcast(float input)
{
float output;
asm(
"{"
"shfl.idx.b32 %0,%1,0x0,0x1f;"
"}"
:"=f"(output) : "f"(input)
);
return output;
}
/*
use warp broadcast to optimize softmax computing
*/
__global__
void KernelSoftmaxComputeTensorUseBroadcast(DTYPE * input, DTYPE * max, DTYPE * sum, DTYPE * output,
int stride, int strideNum, int blockNum)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
int j = blockDim.y * blockIdx.y + threadIdx.y;
int i2 = j % stride;
int blockSize = stride * strideNum;
if (j < stride * blockNum) {
DTYPE sumData, maxData;
if (i % 32 == 0) {
sumData = sum[j];
maxData = max[j];
}
sumData = broadcast(sumData);
maxData = broadcast(maxData);
if (i < strideNum){
int offset = int(j / stride) * blockSize + i * stride + i2;
output[offset] = exp(input[offset] - maxData) / sumData;
}
}
}
/*
softmax y = e^x / \sum_{i} e^{x_i} (Cuda version)
>> x - x vector
>> y - result
......@@ -183,20 +227,42 @@ void _CudaSoftmaxSumMax(const XTensor * x, XTensor * y, int leadDim, XTensor * s
int cudaGridSize[3];
int cudaBlockSize[3];
GDevs.GetCudaThread2D(x->devID, stride * blockNum, dimensionSize, MAX_INT, cudaGridSize, cudaBlockSize);
if (leadDim != 0 || dimensionSize <= 10){
/* allocate thread num for old function */
GDevs.GetCudaThread2D(x->devID, stride * blockNum, dimensionSize, MAX_INT, cudaGridSize, cudaBlockSize);
}
else {
/* allocate thread num for new function */
GDevs.GetCudaThread2D(x->devID, dimensionSize, stride * blockNum, MAX_INT, cudaGridSize, cudaBlockSize);
if (cudaBlockSize[0] < 32) {
/* use at least a warp */
cudaBlockSize[0] = 32;
if (cudaBlockSize[1] > 32) {
cudaGridSize[1] = int(ceil(float(stride * blockNum) / 32));
cudaBlockSize[1] = 32;
}
}
}
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
KernelSoftmaxComputeTensor<<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data,
stride, dimensionSize, stride * dimensionSize, blockNum, stride * blockNum);
if (leadDim != 0 || dimensionSize <= 10) {
KernelSoftmaxComputeTensor <<< dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1]) >>>
((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data,
stride, dimensionSize, stride * dimensionSize, blockNum, stride * blockNum);
}
else {
KernelSoftmaxComputeTensorUseBroadcast <<< dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1]) >>>
((DTYPE*)x->data, (DTYPE*)max->data, (DTYPE*)sum->data, (DTYPE*)y->data,
stride, dimensionSize, blockNum);
}
}
else if(x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16){
KernelSoftmaxComputeTensor<<<dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1])>>>
((__half*)x->data, (__half*)max->data, (__half*)sum->data, (__half*)y->data,
stride, dimensionSize, blockNum);
KernelSoftmaxComputeTensor <<< dim3(cudaGridSize[0], cudaGridSize[1]), dim3(cudaBlockSize[0], cudaBlockSize[1]) >>>
((__half*)x->data, (__half*)max->data, (__half*)sum->data, (__half*)y->data,
stride, dimensionSize, blockNum);
}
else{
ShowNTErrors("TODO!");
......@@ -239,6 +305,9 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
CheckNTErrors((x->devID == y->devID), "Matrices used in log softmax are not on the same GPU.");
CheckNTErrors((y->order >= 1), "Empty tensor!");
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
CheckNTErrors((lossName == CROSSENTROPY ||
......@@ -284,8 +353,14 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
/* make a matrix to keep \beta */
XTensor * beta = new XTensor(y->order - 1, dimSize, y->dataType, y->denseRatio, y->devID, mem);
ytmp->data = mem->AllocBuf(mem->devID, y->unitNum * y->unitSize);
beta->data = mem->AllocBuf(mem->devID, beta->unitNum * beta->unitSize);
if(mem != NULL){
ytmp->data = mem->AllocBuf(mem->devID, y->unitNum * y->unitSize);
beta->data = mem->AllocBuf(mem->devID, beta->unitNum * beta->unitSize);
}
else{
ytmp->data = XMemAlloc(y->devID, y->unitNum * y->unitSize);
beta->data = XMemAlloc(y->devID, beta->unitNum * beta->unitSize);
}
/* \beta = \sum_i (dE/dy_i * y_i) */
_Multiply(dedy, y, ytmp, 0, 0);
......@@ -298,8 +373,18 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
/* dE/ds_j = y_j * ytmp = y_j * (dE/dy_j - \beta) */
_Multiply(y, ytmp, dedx, 0, 0);
mem->ReleaseBuf(mem->devID, y->unitNum * y->unitSize);
mem->ReleaseBuf(mem->devID, beta->unitNum * beta->unitSize);
if(mem != NULL){
mem->ReleaseBuf(mem->devID, y->unitNum * y->unitSize);
mem->ReleaseBuf(mem->devID, beta->unitNum * beta->unitSize);
}
else{
XMemFree(y->devID, ytmp->data);
XMemFree(y->devID, beta->data);
}
ytmp->data = NULL;
beta->data = NULL;
delete[] dimSize;
delete ytmp;
......@@ -311,6 +396,8 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
}
else
ShowNTErrors("TODO!");
BacktoCudaDev(x->devID, devIDBackup);
}
#endif
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-03
*/
#include "../XTensor.h"
#include "TClip.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test Clip function.
Set every entry to its clip value.
*/
bool TestClip1()
{
/* a tensor of size (3, 2) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 3;
aDimSize[1] = 2;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
DTYPE aData[3][2] = { {1.0F, -2.0F},
{0.0F, 4.0F},
{5.0F, -6.0F} };
DTYPE answer[3][2] = { {1.0F, -1.0F},
{0.0F, 1.0F},
{1.0F, -1.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(aOrder, aDimSize);
XTensor * aMe = NewTensor(aOrder, aDimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
aMe->SetData(aData, aUnitNum);
/* call Clip function */
_Clip(a, b, -1.0, 1.0);
_ClipMe(aMe, -1.0, 1.0);
bUser = Clip(*a, -1.0, 1.0);
/* check results */
cpuTest = b->CheckData(answer, aUnitNum, 1e-4F) &&
aMe->CheckData(answer, aUnitNum, 1e-4F) &&
bUser.CheckData(answer, aUnitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
aMeGPU->SetData(aData, aUnitNum);
/* call Clip function */
_Clip(aGPU, bGPU, -1.0, 1.0);
_ClipMe(aMeGPU, -1.0, 1.0);
bUserGPU = Clip(*aGPU, -1.0, 1.0);
/* check results */
gpuTest = bGPU->CheckData(answer, aUnitNum, 1e-4F) &&
aMeGPU->CheckData(answer, aUnitNum, 1e-4F) &&
bUserGPU.CheckData(answer, aUnitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete aMe;
delete aGPU;
delete bGPU;
delete aMeGPU;
delete[] aDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete aMe;
delete[] aDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Clip Function */
bool TestClip()
{
XPRINT(0, stdout, "[TEST Clip] set every entry to its clip value \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestClip1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-08-03
*/
#ifndef __TEST_CLIP_H__
#define __TEST_CLIP_H__
#include "../core/math/Clip.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Clip Function */
extern "C"
bool TestClip();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_CLIP_H__
......@@ -66,7 +66,9 @@ bool TestExp1()
bUser = Exp(*a);
/* check results */
cpuTest = b->CheckData(answer, unitNum, 1e-4F) && aMe->CheckData(answer, unitNum, 1e-4F) && bUser.CheckData(answer, unitNum, 1e-4F);
cpuTest = b->CheckData(answer, unitNum, 1e-4F) &&
aMe->CheckData(answer, unitNum, 1e-4F) &&
bUser.CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
......@@ -88,7 +90,9 @@ bool TestExp1()
bUserGPU = Exp(*aGPU);
/* check results */
gpuTest = bGPU->CheckData(answer, unitNum, 1e-4F) && aMeGPU->CheckData(answer, unitNum, 1e-4F) && bUserGPU.CheckData(answer, unitNum, 1e-4F);
gpuTest = bGPU->CheckData(answer, unitNum, 1e-4F) &&
aMeGPU->CheckData(answer, unitNum, 1e-4F) && \
bUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include "../core/math/Unary.h"
#include "TRound.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test Round function.
Set every entry to its round value.
*/
bool TestRound1()
{
/* a tensor of size (3, 2) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 3;
dimSize[1] = 2;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE aData[3][2] = { {1.3F, 2.7F},
{-1.3F, -2.7F},
{0.0F, 0.5F} };
DTYPE answer[3][2] = { {1.0F, 3.0F},
{-1.0F, -3.0F},
{0.0F, 1.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(order, dimSize);
XTensor * b = NewTensor(order, dimSize);
XTensor * aMe = NewTensor(order, dimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, unitNum);
aMe->SetData(aData, unitNum);
/* call Round function */
_Round(a, b);
_RoundMe(aMe);
bUser = Round(*a);
/* check results */
cpuTest = b->CheckData(answer, unitNum, 1e-4F) &&
aMe->CheckData(answer, unitNum, 1e-4F) &&
bUser.CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, unitNum);
aMeGPU->SetData(aData, unitNum);
/* call Round function */
_Round(aGPU, bGPU);
_RoundMe(aMeGPU);
bUserGPU = Round(*aGPU);
/* check results */
gpuTest = bGPU->CheckData(answer, unitNum, 1e-4F) &&
aMeGPU->CheckData(answer, unitNum, 1e-4F) &&
bUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete aMe;
delete aGPU;
delete bGPU;
delete aMeGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete aMe;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Round Function */
bool TestRound()
{
XPRINT(0, stdout, "[TEST Round] set every entry to its round value \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestRound1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-03
*/
#ifndef __TEST_ROUND_H__
#define __TEST_ROUND_H__
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Round Function */
bool TestRound();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_ROUND_H__
......@@ -30,6 +30,7 @@ bool Test()
XPRINT(0, stdout, "Testing the XTensor utilites ... \n\n");
wrong = !TestAbsolute() || wrong;
wrong = !TestClip() || wrong;
wrong = !TestConcatenate() || wrong;
wrong = !TestConcatenateSolely() || wrong;
wrong = !TestCos() || wrong;
......@@ -53,6 +54,7 @@ bool Test()
wrong = !TestReduceSum() || wrong;
wrong = !TestReduceSumSquared() || wrong;
wrong = !TestReduceVariance() || wrong;
wrong = !TestRound() || wrong;
wrong = !TestScaleAndShift() || wrong;
wrong = !TestSelect() || wrong;
wrong = !TestSetAscendingOrder() || wrong;
......@@ -68,7 +70,7 @@ bool Test()
wrong = !TestSumDim() || wrong;
wrong = !TestTan() || wrong;
wrong = !TestTranspose() || wrong;
wrong = !TestTopK() || wrong;
//wrong = !TestTopK() || wrong;
wrong = !TestUnsqueeze() || wrong;
wrong = !TestXMem() || wrong;
......
......@@ -23,6 +23,7 @@
#define __TEST_H__
#include "TAbsolute.h"
#include "TClip.h"
#include "TConcatenate.h"
#include "TConcatenateSolely.h"
#include "TCos.h"
......@@ -46,6 +47,7 @@
#include "TReduceSum.h"
#include "TReduceSumSquared.h"
#include "TReduceVariance.h"
#include "TRound.h"
#include "TScaleAndShift.h"
#include "TSelect.h"
#include "TSetAscendingOrder.h"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论