Commit 855a2803 by huchi

add new binary format for data and model

parent 771643c6
#include "Model.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
/* register a parameter with a unique name */
void Model::Register(const char* name, Dim dims, TENSOR_DATA_TYPE dataType, int devID)
{
parameters.AddParameter(name, dims, dataType, devID);
}
/* get a parameter by its name */
XTensor* Model::operator[](const char* name)
{
return parameters.GetParameter(name);
}
/* load a model from a binary file */
void Model::Load(const char* fn)
{
CheckNTErrors(parameters.list.Size() > 0, "empty tensor list");
FILE* file = fopen(fn, "rb");
CheckNTErrors(file, "unable to open the model file");
LongList offset(parameters.list.Size());
/* check the number of parameters */
unsigned long int number;
fread(&number, sizeof(number), 1, file);
CheckNTErrors(number == parameters.list.Size(), "parameter number does not match");
/* read the offsets from the file */
fread(offset.items, sizeof(long), offset.Size(), file);
/* read parameters from the file */
for (int i = 0; i < offset.Size(); i++) {
parameters.list[i]->BinaryRead(file, offset[i]);
}
fclose(file);
}
/* dump a model to a binary file */
void Model::Dump(const char* fn)
{
FILE* file = fopen(fn, "wb");
/* dump number of parameter */
unsigned long int number = parameters.list.Size();
fwrite(&number, sizeof(number), 1, file);
/* dump the offsets of the parameters */
unsigned long int offset = sizeof(number);
for (int i = 0; i < parameters.list.Size(); i++) {
if (i > 0) {
offset += parameters.list[i - 1]->unitNum;
}
fwrite(&offset, sizeof(offset), 1, file);
}
/* dump parameters to the file */
for (int i = 0; i < parameters.list.Size(); i++) {
parameters.list[i]->BinaryDump(file);
}
fclose(file);
}
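/* the on-disk layout implied by Load/Dump above (a sketch; BinaryDump and
   BinaryRead are assumed to stream the raw tensor payloads in list order):

     unsigned long   count           - number of parameters
     long            offset[count]   - per-parameter offsets (counted in units)
     <payload of parameter 0>
     ...
     <payload of parameter count-1>
*/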
/* get a parameter by its name */
XTensor* Model::Get(const char* name)
{
return parameters.GetParameter(name);
}
/* add a parameter to the list */
void Parameter::AddParameter(const char* name, Dim dims, TENSOR_DATA_TYPE dataType, int devID)
{
CheckNTErrors(GetParameter(name) == NULL, "the name must be unique");
IntList dim;
for (int i : dims) {
dim.Add(i);
}
XTensor* p = NewTensorV2(dims.size(), dim.items, dataType, devID);
strcpy(p->name, name);
list.Add(p);
}
/* get a parameter by its name */
XTensor* Parameter::GetParameter(const char* name)
{
for (int i = 0; i < list.Size(); i++) {
if (strcmp(list[i]->name, name) == 0)
return list[i];
}
/* if miss, return a null pointer */
return NULL;
}
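/* a minimal usage sketch of the new API (the names and shapes here are
   hypothetical, not part of the library) */
struct ToyModel : Model {
    ToyModel(int devID) {
        Register("embed", { 1000, 512 }, X_FLOAT, devID);
        Register("proj", { 512, 512 }, X_FLOAT, devID);
    }
};

void Demo()
{
    ToyModel model(0);
    XTensor* w = model["embed"];  /* look up a parameter by name */
    model.Dump("toy.model");      /* write the new binary format */
    model.Load("toy.model");      /* read it back into the registered tensors */
    (void)w;
}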
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
*
* the model class
*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-09-12
*
*/
#ifndef __MODEL_H__
#define __MODEL_H__
#include <utility>
#include "../tensor/XGlobal.h"
#include "../tensor/XTensor.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
using Dim = std::initializer_list<int>;
/* Parameter is a base class for parameters */
struct Parameter {
public:
/* the parameter list */
TensorList list;
public:
/* add a parameter to the list */
void AddParameter(const char* name, Dim dims, TENSOR_DATA_TYPE dataType, int devID);
/* get a parameter by its name */
XTensor* GetParameter(const char* name);
};
/* Model is a base class for neural networks */
struct Model {
public:
Parameter parameters;
public:
/* load a model from a binary file */
void Load(const char* fn);
/* dump the model to a binary file */
void Dump(const char* fn);
/* get a parameter by its name */
XTensor* Get(const char* name);
/* get a parameter by its name */
XTensor* operator[] (const char* name);
/* register a parameter with a unique name */
void Register(const char* name, Dim dims, TENSOR_DATA_TYPE dataType, int devID);
};
}
#endif // __MODEL_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
......@@ -15,201 +15,117 @@
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-10
*/
#include <stdio.h>
#include "XNet.h"
#include "../tensor/XUtility.h"
#include "../tensor/function/FHeader.h"
#include "../tensor/core/CHeader.h"
#include "../tensor/test/Test.h"
#include "../sample/fnnlm/FNNLM.h"
#include "../sample/transformer/Transformer.h"
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
//#include <crtdbg.h>
void BackwardTest();
void TransposeTest();
void SumDimTest();
#include <fstream>
#include <string>
using namespace nts;
using namespace fnnlm;
using namespace transformer;
int main( int argc, const char ** argv )
{
//_CrtSetDbgFlag(_CrtSetDbgFlag(_CRTDBG_REPORT_FLAG) | _CRTDBG_LEAK_CHECK_DF);
//_CrtSetBreakAlloc(2708);
//if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
// FNNLMMain(argc - 1, argv + 1);
//else if(argc > 1 && !strcmp(argv[1], "-t2t"))
// TransformerMain(argc - 1, argv + 1);
//else{
// fprintf(stderr, "Thanks for using NiuTrans.Network! This is a library for building\n");
// fprintf(stderr, "neural networks in an easy way. \n\n");
// fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
// fprintf(stderr, "Or run this program with \"-fnnlm\" for sample FNNLM!\n");
//}
BackwardTest();
//_CrtDumpMemoryLeaks();
return 0;
void test() {
XTensor posEmbeddingBase;
int length = 5;
int eSize = 4;
int d = 4;
InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT);
float* data = new float[posEmbeddingBase.unitNum];
for (int pos = 0; pos < length; pos++) {
float* dp = data + pos * eSize;
//int channelSize = eSize / 2;
//int offset = 0;
//for(int i = 0; i < channelSize; i++){
// dp[offset++] = (float)sin(pos/pow(10000.0F, 2.0F*i/(d - 2)));
//}
//for(int i = 0; i < channelSize; i++){
// dp[offset++] = (float)cos(pos/pow(10000.0F, 2.0F*i/(d - 2)));
//}
for (int k = 0; k < eSize; k++) {
if (k % 2 == 0) {
int i = k / 2;
dp[k] = (float)sin(pos / pow(10000.0F, 2.0F * i / d));
}
else {
int i = (k - 1) / 2;
dp[k] = (float)cos(pos / pow(10000.0F, 2.0F * i / d));
}
}
}
posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
posEmbeddingBase.Dump(stderr);
delete[] data;
}
void BackwardTest()
{
XNet net;
struct A {
XTensor a;
XTensor b;
XTensor c;
a.enableGrad = true;
b.enableGrad = false;
c.enableGrad = false;
XTensor mean;
XTensor origin;
InitTensor2D(&a, 2, 3);
InitTensor1D(&b, 2);
a.SetZeroAll();
b.SetZeroAll();
a.Set2D(1.0F, 0, 0);
a.Set2D(2.0F, 0, 1);
a.Set2D(3.0F, 0, 2);
a.Set2D(4.0F, 1, 0);
a.Set2D(5.0F, 1, 1);
a.Set2D(6.0F, 1, 2);
b.Set1D(2.0F, 0);
b.Set1D(1.0F, 1);
DivDim(a, b, c, 0);
c.Dump(stderr, "c:");
auto loss = CrossEntropy(c, a);
//XLink::ShowNetwork(stderr, &c);
net.Backward(loss);
a.grad->Dump(stderr);
}
void update(XTensor b) {
a = b;
}
};
void TransposeTest()
{
#ifdef USE_CUDA
XMem mem0(0, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
//XMem mem1(1, UNI_FREE, MILLION * 64, 1024, MILLION * 64);
void test2(A *a) {
XTensor x;
InitTensor2D(&x, 2, 3);
XTensor y;
InitTensor2D(&y, 3, 2);
float data[]{ 1,1,1,1,1,1 };
x.SetData(data, 6);
y.SetData(data, 6);
XTensor z;
z = MatrixMul(x, y);
a->update(z);
}
int loops = 2000;
int B = 3 * 2 * 4;
int K = 8 * 1;
int N = 50;
int H = 512 * 4;
int nnn = GDevs.nGPU;
InitTensor3D(&x, B, N, H, X_FLOAT, 0);
InitTensor4D(&y, K, B, N, H/K, X_FLOAT, 0);
InitTensor3D(&z, B, N, H, X_FLOAT, 0);
cudaEvent_t ctime0;
cudaEvent_t ctime1;
cudaEvent_t ctime2;
cudaEvent_t ctime3;
cudaEvent_t ctime4;
cudaEvent_t ctime5;
float elapsedSplit = 0.0;
float elapsedMerge = 0.0;
float elapsedSum = 0.0;
cudaEventCreate(&ctime0);
cudaEventCreate(&ctime1);
cudaEventCreate(&ctime2);
cudaEventCreate(&ctime3);
cudaEventCreate(&ctime4);
cudaEventCreate(&ctime5);
cudaEventRecord(ctime0, 0);
double time0 = GetClock();
for(int i = 0; i < loops; i++)
_Split(&x, &y, 2, K);
double time1 = GetClock();
cudaEventRecord(ctime1, 0);
cudaEventSynchronize(ctime1);
cudaEventElapsedTime(&elapsedSplit, ctime0, ctime1);
cudaEventRecord(ctime2, 0);
double time2 = GetClock();
for(int i = 0; i < loops; i++)
_Merge(&y, &x, 3);
double time3 = GetClock();
cudaEventRecord(ctime3, 0);
cudaEventSynchronize(ctime3);
cudaEventElapsedTime(&elapsedMerge, ctime2, ctime3);
cudaEventRecord(ctime4, 0);
double time4 = GetClock();
for(int i = 0; i < loops; i++)
_Sum(&x, &z, &x);
double time5 = GetClock();
cudaEventRecord(ctime5, 0);
cudaEventSynchronize(ctime5);
cudaEventElapsedTime(&elapsedSum, ctime4, ctime5);
fprintf(stderr, "split:%f merge:%f sum:%f\n", time1 - time0, time3 - time2, time5 - time4);
fprintf(stderr, "split:%f merge:%f sum:%f\n", elapsedSplit, elapsedMerge, elapsedSum);
#endif
void TestMemory() {
int devID = 0;
int memSize = 1024;
XMem *mem = new XMem(devID, FREE_ON_THE_FLY, (MTYPE)MILLION * 256, 1024, MILLION * 128);
mem->SetDesiredSize(devID, 0, (MTYPE)memSize * MILLION);
XTensor a;
InitTensor2D(&a, 5, 5, X_FLOAT, 0, mem);
float d[25]{ 0 };
for (int i = 0; i < 25; i++)
d[i] = float(i);
a.SetData(d, 25);
int index[]{ 0,1,2,3,4 };
for (int i = 0; i < 4; i++) {
XTensor srcIdx, tgtIdx;
InitTensor1D(&srcIdx, 4 - i, X_INT, a.devID, a.mem);
InitTensor1D(&tgtIdx, 4 - i, X_INT, a.devID, a.mem);
srcIdx.SetData(index, srcIdx.unitNum);
tgtIdx.SetAscendingOrder(0);
a = CopyIndexed(a, 0, srcIdx, tgtIdx);
printf("\nround %d\n", i);
a.Dump(stderr);
}
delete mem;
}
void SumDimTest()
int main(int argc, const char** argv)
{
XTensor x;
XTensor y;
XTensor z;
int a = 5;
int b = 7;
int c = 3;
InitTensor3D(&x, a, b, c, X_FLOAT, -1);
InitTensor1D(&y, c, X_FLOAT, -1);
InitTensor3D(&z, a, b, c, X_FLOAT, -1);
x.SetZeroAll();
y.SetZeroAll();
z.SetZeroAll();
DTYPE * data = new DTYPE[x.unitNum];
for(int i = 0; i < x.unitNum; i++)
data[i] = (DTYPE)i;
x.SetData(data, x.unitNum);
TransformerMain(argc - 1, argv + 1);
for(int i = 0; i < y.unitNum; i++)
data[i] = -(DTYPE)i;
y.SetData(data, y.unitNum);
_SumDim(&x, &y, &z, 2);
z.Dump(stderr, "z:");
delete[] data;
return 0;
}
......@@ -43,18 +43,18 @@ void XFuncGrad::MakeGrad(XTensor * node, bool isEfficient)
XNoder::MakeGrad(input);
if(operID == FUNC_HARDTANH)
_HardTanHBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
_HardTanHBackward(output, input, output->grad, input->grad);
else if(operID == FUNC_IDENTITY)
_IdentityBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
_IdentityBackward(output, input, output->grad, input->grad);
else if(operID == FUNC_LOGSOFTMAX){
int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in logsoftmax!");
_LogSoftmaxBackward(NULL, output, input, output->grad, input->grad, NULL, leadDim, NOLOSS);
}
else if(operID == FUNC_RECTIFY)
_RectifyBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
_RectifyBackward(output, input, output->grad, input->grad);
else if(operID == FUNC_SIGMOID)
_SigmoidBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
_SigmoidBackward(output, input, output->grad, input->grad);
else if(operID == FUNC_SOFTMAX){
int leadDim = income.GetParamInt(0);
CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in softmax!");
......
......@@ -69,7 +69,7 @@ void XLossGrad::MakeGrad(XTensor * node, bool isEfficient)
if(operID == LOSS_CROSSENTROPY) {
if (income.tailNum == 3)
padding = income.tails[2];
leadingDim = income.GetParamInt(0);
CheckNTErrors(leadingDim >= 0 && leadingDim < output->order, "wrong leading dimension in logsoftmax!");
_CrossEntropyBackward(dedy, output, gold, weight, padding, leadingDim);
}
......@@ -98,39 +98,39 @@ compute dE/dx for a given function y = f(x)
>> params - parameters of the function
>> lossName - name of the loss, e.g., cross entropy
*/
void XLossGrad::Compute(XTensor * gold, XTensor * y, XTensor * x,
XTensor * dedy, XTensor * dedx, XTensor * padding,
int funcID, void * params,
LOSS_FUNCTION_NAME lossName)
{
CheckNTErrors(gold && y && x, "Empty input tensors!");
CheckNTErrors(dedx, "Empty gradient tensors!");
CheckNTErrors((funcID & FUNCTION_BASE) != 0, "Illegal function id");
if(funcID == FUNC_HARDTANH){
_HardTanHBackward(gold, y, x, dedy, dedx, lossName);
}
else if(funcID == FUNC_IDENTITY){
_IdentityBackward(gold, y, x, dedy, dedx, lossName);
}
else if(funcID == FUNC_LOGSOFTMAX){
int leadDim = *(int*)params;
_LogSoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
}
else if(funcID == FUNC_RECTIFY){
_RectifyBackward(gold, y, x, dedy, dedx, lossName);
}
else if(funcID == FUNC_SIGMOID){
_SigmoidBackward(gold, y, x, dedy, dedx, lossName);
}else if(funcID == FUNC_SOFTMAX){
int leadDim = *(int*)params;
_SoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
}
else{
ShowNTErrors("wrong function found when call the backward process!");
}
}
//void XLossGrad::Compute(XTensor * gold, XTensor * y, XTensor * x,
// XTensor * dedy, XTensor * dedx, XTensor * padding,
// int funcID, void * params,
// LOSS_FUNCTION_NAME lossName)
//{
// CheckNTErrors(gold && y && x, "Empty input tensors!");
// CheckNTErrors(dedx, "Empty gradient tensors!");
// CheckNTErrors((funcID & FUNCTION_BASE) != 0, "Illegal function id");
//
// if(funcID == FUNC_HARDTANH){
// _HardTanHBackward(gold, y, x, dedy, dedx, lossName);
// }
// else if(funcID == FUNC_IDENTITY){
// _IdentityBackward(gold, y, x, dedy, dedx, lossName);
// }
// else if(funcID == FUNC_LOGSOFTMAX){
// int leadDim = *(int*)params;
// _LogSoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
// }
// else if(funcID == FUNC_RECTIFY){
// _RectifyBackward(gold, y, x, dedy, dedx, lossName);
// }
// else if(funcID == FUNC_SIGMOID){
// _SigmoidBackward(gold, y, x, dedy, dedx, lossName);
// }else if(funcID == FUNC_SOFTMAX){
// int leadDim = *(int*)params;
// _SoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
// }
// else{
// ShowNTErrors("wrong function found when call the backward process!");
// }
//
//}
/*
compute dE/dy for variable y and error(loss) function E
......@@ -139,27 +139,27 @@ compute dE/dy for variable y and error(loss) function E
>> dedy - dE/dy
>> lossName - name of the loss, e.g., cross entropy
*/
void XLossGrad::Compute(XTensor * gold, XTensor * y,
XTensor * dedy, XTensor * padding,
LOSS_FUNCTION_NAME lossName)
{
if(gold == NULL){
if(dedy->dataType == X_FLOAT)
_SetDataFixedFloat(dedy, 1.0F);
else if(dedy->dataType == X_DOUBLE)
_SetDataFixedDouble(dedy, 1.0);
else if(dedy->dataType == X_INT)
_SetDataFixedInt(dedy, 1);
else{
ShowNTErrors("TODO");
}
return;
}
//_LossBackward(dedy, gold, y, lossName);
if(lossName == CROSSENTROPY)
_CrossEntropyBackward(dedy, y, gold, NULL, padding);
}
//void XLossGrad::Compute(XTensor * gold, XTensor * y,
// XTensor * dedy, XTensor * padding,
// LOSS_FUNCTION_NAME lossName)
//{
// if(gold == NULL){
// if(dedy->dataType == X_FLOAT)
// _SetDataFixedFloat(dedy, 1.0F);
// else if(dedy->dataType == X_DOUBLE)
// _SetDataFixedDouble(dedy, 1.0);
// else if(dedy->dataType == X_INT)
// _SetDataFixedInt(dedy, 1);
// else{
// ShowNTErrors("TODO");
// }
// return;
// }
//
// //_LossBackward(dedy, gold, y, lossName);
// if(lossName == CROSSENTROPY)
// _CrossEntropyBackward(dedy, y, gold, NULL, padding);
//
//}
}
\ No newline at end of file
......@@ -43,11 +43,11 @@ public:
static
bool IsLossOP(XTensor * node);
/* compute dE/dx for a given function y = f(x) */
void Compute(XTensor * gold, XTensor * y, XTensor * x,
XTensor * dedy, XTensor * dedx, XTensor * padding,
int funcID, void * params,
LOSS_FUNCTION_NAME lossName);
///* compute dE/dx for a given function y = f(x) */
//void Compute(XTensor * gold, XTensor * y, XTensor * x,
// XTensor * dedy, XTensor * dedx, XTensor * padding,
// int funcID, void * params,
// LOSS_FUNCTION_NAME lossName);
/* compute dE/dy for variable y and error(loss) function E */
void Compute(XTensor * gold, XTensor * y,
......
......@@ -530,7 +530,7 @@ void XMathGrad::GradMatrixMul(XTensor * node, bool isEfficient)
XTensor * dedc = node->grad;
XTensor * deda = a->grad;
XTensor * dedb = b->grad;
if(a->order == 2 && b->order == 2)
GradMatrixMul(a, deda, transA, b, dedb, transB, dedc, alpha, isEfficient);
else if(transA == X_NOTRANS && a->order > 2 && b->order == 2){
......
......@@ -55,7 +55,7 @@ void XNetClearAll()
XNet::XNet()
{
nodes.Clear();
isGradEfficient = false;
isGradEfficient = true;
}
/* de-constructor */
......@@ -187,7 +187,7 @@ void XNet::Backward(TensorList &roots, TensorList &golds, TensorList &paddings,
node->visitMark = NODE_UNFINISHED;
}
XLossGrad lossGrad;
//XLossGrad lossGrad;
/* we start with the gradient with respect to the loss for output layers */
/*for(int i = 0; i < roots.count; i++){
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
*
* This is a simple implementation of the feed-forward network-based language
* model (FNNLM). See more details about FNNLM in
* "A Neural Probabilistic Language Model" by Bengio et al.
* Journal of Machine Learning Research 3 (2003) 1137-1155
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-06-22
* Today I was named the most popular teacher in our college.
* It was a great honour for me!!!
*/
#ifndef __FNNLM_H__
#define __FNNLM_H__
#include "../../tensor/XGlobal.h"
#include "../../tensor/XTensor.h"
#include "../../tensor/core/CHeader.h"
using namespace nts;
namespace fnnlm
{
#define _EXIT_(x)// exit(x)
#define CheckErrors(x, msg) { if(!(x)) { fprintf(stderr, "Error! calling '%s' (%s line %d): %s\n", #x, __FILENAME__, __LINE__, msg); _EXIT_(1); } }
#define ShowErrors(msg) { { fprintf(stderr, "Error! (%s line %d): %s\n", __FILENAME__, __LINE__, msg); _EXIT_(1); } }
#define MAX_N_GRAM 8
#define MAX_HIDDEN_NUM 8
/* an n-gram = a sequence of n words
words[0..n-2] is the history, and
words[n-1] is the word for prediction. */
struct NGram
{
int words[MAX_N_GRAM];
};
/* fnn model */
struct FNNModel
{
/* word embedding */
XTensor embeddingW;
/* parameter matrix of each hidden layer
hidden layer: y = f(x * w + b)
where x is the input, y is the output, w is
the transformation (parameter) matrix, b is
the bias and f() is the activation function. */
XTensor hiddenW[MAX_HIDDEN_NUM];
/* bias of each hidden layer */
XTensor hiddenB[MAX_HIDDEN_NUM];
/* parameter matrix of the output layer */
XTensor outputW;
/* bias of the output layer */
XTensor outputB;
/* order of the language model */
int n;
/* embedding size */
int eSize;
/* number of hidden layers */
int hDepth;
/* hidden layer size */
int hSize;
/* vocabulary size */
int vSize;
/* id of the device for running the model */
int devID;
/* indicates whether we use memory pool */
bool useMemPool;
/* memory pool */
XMem * mem;
FNNModel(){ n = -1; vSize = -1; hDepth = 0; devID = -1; mem = NULL; };
~FNNModel(){delete mem;};
};
/* the network built on the fly */
struct FNNNet
{
/* embedding result of the previous n - 1 words */
XTensor embeddings[MAX_N_GRAM];
/* concatenation of embeddings */
XTensor embeddingCat;
/* output of the hidden layers */
XTensor hiddens[MAX_HIDDEN_NUM];
/* state of the hidden layers (before activation function) */
XTensor hiddenStates[MAX_HIDDEN_NUM];
/* state before softmax */
XTensor stateLast;
/* output of the net */
XTensor output;
};
/* entrance of the program */
int FNNLMMain(int argc, const char ** argv);
};
#endif
......@@ -29,6 +29,48 @@ using namespace nts;
namespace transformer
{
/* layer cache for key and value */
class Cache {
public:
/* cache for key */
XTensor* k{ NULL };
/* cache for value */
XTensor* v{ NULL };
public:
bool IsEmpty(){
return (k == NULL) && (v == NULL);
}
void Clear() {
if (k && v && k->id > 0 && v->id > 0) {
DelTensor(k);
DelTensor(v);
}
k = NULL;
v = NULL;
}
void Update(XTensor* newK, XTensor* newV) {
if (!newK || (k == newK) || !newV || (v == newV))
return;
Clear();
k = newK;
v = newV;
}
XTensor* GetK() {
return k;
}
XTensor* GetV() {
return v;
}
};
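/* intended protocol (inferred from this commit): during incremental decoding
   the attention layer grows the cached keys/values step by step, stores the
   grown tensors back via Update(), and reads them with GetK()/GetV(); Clear()
   releases them between sentences (see T2TModel::ResetCache) */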
/*
multi-head attention
y(Q, K, V) = cat(head_1, head_2, ..., head_n)
......@@ -48,20 +90,33 @@ public:
/* head number */
int nhead;
/* transformation matrix for K */
/* transformation matrix for query */
XTensor wq;
/* bias for query */
XTensor bq;
/* transformation matrix for key */
XTensor wk;
/* transformation matrix for Q */
XTensor wq;
/* bias for key */
XTensor bk;
/* transformation matrix for V */
/* transformation matrix for value */
XTensor wv;
/* bias for value */
XTensor bv;
/* relative position embeddings for keys */
XTensor rp_embedding_k;
/* transformation after dot-product attention */
XTensor wa;
XTensor wbig;
/* bias after dot-product attention */
XTensor ba;
/* size of transformed Q and K */
int dk;
......@@ -84,6 +139,10 @@ public:
/* dropout probability */
DTYPE dropoutP;
/* max relative window size */
int max_relative_position;
public:
/* constructor */
T2TAttention();
......@@ -97,13 +156,18 @@ public:
int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
/* make the network given a big tensor that keeps keys, queries and values */
XTensor MakeBig(XTensor &kqv, XTensor &mask, bool isTraining);
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor *mask,
bool isTraining, Cache* cache, int cacheType);
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor MakeAttention(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
XTensor MakeAttention(XTensor *k, XTensor *q, XTensor *v, const XTensor *mask, bool isTraining, bool is_encoder);
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTensor *mask, bool isTraining, bool is_encoder);
void GetRPEmbedding(XTensor* emb_matrix, const int len_q, const int len_kv, const int max_relative_length, const int device_id, const bool is_encoder);
void RPDotProduct(XTensor* x, XTensor* y, XTensor* z, XTensor* attention, const bool is_key);
};
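/* a sketch of the clipped relative-position index that GetRPEmbedding is
   expected to build (Shaw et al., 2018); this helper is illustrative and is
   not part of the actual class */
inline int ClippedRelativePosition(int queryPos, int keyPos, int maxRelativeLength)
{
    int dist = keyPos - queryPos;
    if (dist < -maxRelativeLength)
        dist = -maxRelativeLength;
    if (dist > maxRelativeLength)
        dist = maxRelativeLength;

    /* shift into [0, 2 * maxRelativeLength] for table lookup */
    return dist + maxRelativeLength;
}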
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-04-25
* it is cold today but i'll move to a warm place tomorrow :)
*/
#ifndef __T2TBATCHLOADER_H__
#define __T2TBATCHLOADER_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
#define MAX_SEQUENCE_LENGTH 1024 * 4
/* node to keep batch information */
struct BatchNode
{
/* beginning position */
int beg;
/* end position */
int end;
/* maximum word number on the encoder side */
int maxEnc;
/* maximum word number on the decoder side */
int maxDec;
/* a key for sorting */
int key;
};
class T2TBatchLoader
{
public:
/* buffer for loading words */
int * buf;
/* another buffer */
int * buf2;
/* batch buffer */
BatchNode * bufBatch;
/* buffer size */
int bufSize;
/* size of batch buffer */
int bufBatchSize;
/* length of each sequence */
int * seqLen;
/* another array */
int * seqLen2;
/* offset of the first word for each sequence */
int * seqOffset;
/* number of sequences in the buffer */
int nseqBuf;
/* offset for next sequence in the buffer */
int nextSeq;
/* offset for next batch */
int nextBatch;
/* indicates whether we double the </s> symbol for the output of lms */
bool isDoubledEnd;
/* indicates whether we use batchsize = max * sc
rather than batchsize = word-number, where max is the maximum
length and sc is the sentence number */
bool isSmallBatch;
/* counterpart of "isSmallBatch" */
bool isBigBatch;
/* randomize batches */
bool isRandomBatch;
/* bucket size */
int bucketSize;
public:
/* constructor */
T2TBatchLoader();
/* de-constructor */
~T2TBatchLoader();
/* initialization */
void Init(int argc, char ** argv);
/* load data to buffer */
int LoadBuf(FILE * file, bool isSorted, int step);
/* clear data buffer */
void ClearBuf();
/* set the random batch flag */
void SetRandomBatch(bool flag = true);
/* load a batch of sequences */
int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * label,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining);
/* load a batch of sequences (for language modeling) */
int LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * label,
int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem,
bool isTraining);
/* load a batch of sequences (for machine translation) */
int LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * label,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining);
/* shuffle the data file */
void Shuffle(const char * srcFile, const char * tgtFile);
};
}
#endif
\ No newline at end of file
......@@ -34,20 +34,24 @@ AttDecoder::AttDecoder()
attentions = NULL;
fnns = NULL;
attLayerNorms = NULL;
fnnLayerNorms = NULL;
attentionsEnde = NULL;
attEndeLayerNorms = NULL;
decodeLayerNorm = NULL;
selfCache = NULL;
contextCache = NULL;
}
/* de-constructor */
AttDecoder::~AttDecoder()
{
delete[] selfCache;
delete[] contextCache;
delete[] attentions;
delete[] fnns;
delete[] attLayerNorms;
delete[] fnnLayerNorms;
delete[] attentionsEnde;
delete[] attEndeLayerNorms;
delete decodeLayerNorm;
}
/*
......@@ -69,7 +73,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
mem = myMem;
ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
LoadParamInt(argc, argv, "nlayer", &nlayer, 3);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "vsizetgt", &vSize, -1);
......@@ -84,19 +88,21 @@ void AttDecoder::InitModel(int argc, char ** argv,
attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer];
attLayerNorms = new T2TLN[nlayer];
fnnLayerNorms = new T2TLN[nlayer];
attentionsEnde = new T2TAttention[nlayer];
attEndeLayerNorms = new T2TLN[nlayer];
decodeLayerNorm = new T2TLN;
selfCache = new Cache[nlayer];
contextCache = new Cache[nlayer];
/* initialize the stacked layers */
for (int i = 0; i < nlayer; i++) {
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID, myMem);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
}
decodeLayerNorm->InitModel(argc, argv, myDevID);
}
/*
......@@ -108,11 +114,11 @@ make the decoding network
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining)
XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, const XTensor *mask, XTensor &maskEncDec, bool isTraining)
{
XTensor x;
x = embedder.Make(inputDec);
x = embedder.Make(inputDec, inputDec.GetDim(1));
/* dropout */
if(isTraining && dropoutP > 0)
......@@ -123,50 +129,50 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, X
XTensor ende;
XTensor ln;
XTensor fnn;
XTensor res;
XTensor inputNorm;
XTensor attNorm;
/* layer normalization */
inputNorm = attLayerNorms[i].Make(x);
//inputNorm.Dump(stderr, "inputNorm", 10);
/******************/
/* self attention */
att = attentions[i].MakeBig(x, mask, isTraining);
att = attentions[i].Make(inputNorm, inputNorm, inputNorm, NULL, isTraining, &selfCache[i], 1);
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP);
/* residual connection */
res = Sum(att, x);
_SumMe(&att, &x);
//att.Dump(stderr, "Sum(att, x)", 10);
/* layer normalization */
x = attLayerNorms[i].Make(res);
attNorm = attEndeLayerNorms[i].Make(att);
//attNorm.Dump(stderr, "attNorm", 10);
/*****************************/
/* encoder-decoder attention */
ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, maskEncDec, isTraining);
ende = attentionsEnde[i].Make(outputEnc, attNorm, outputEnc, &maskEncDec, isTraining, &contextCache[i], 2);
//ende.Dump(stderr, "ende atten", 10);
/* dropout */
if(isTraining && dropoutP > 0)
ende = Dropout(ende, dropoutP);
/* residual connection */
res = Sum(ende, x);
_SumMe(&ende, &att);
//res.Dump(stderr, "Sum(ende, att)", 10);
/* layer normalization */
x = attEndeLayerNorms[i].Make(res);
/*******/
/* fnn */
fnn = fnns[i].Make(x, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP);
/* residual connection */
res = Sum(fnn, x);
x = fnns[i].Make(ende, isTraining);
//x.Dump(stderr, "fnns[i]", 10);
/* layer normalization */
x = fnnLayerNorms[i].Make(res);
}
x = decodeLayerNorm->Make(x);
//x.Dump(stderr, "decodeLayerNorm", 10);
x.SetName(DECODING_NAME);
......
......@@ -22,6 +22,7 @@
#ifndef __T2TDECODER_H__
#define __T2TDECODER_H__
#include <array>
#include "T2TEncoder.h"
namespace transformer
......@@ -56,7 +57,7 @@ public:
DTYPE dropoutP;
/* some positions can be ignored in attention. this is useful in lm where the first position needs
* special design for the attention model. */
int ignored;
/* embedding of word at each position */
......@@ -68,12 +69,12 @@ public:
/* attention model of each layer */
T2TAttention * attentions;
/* layer normalization for fnn */
T2TLN * fnnLayerNorms;
/* layer normalization for attention */
T2TLN * attLayerNorms;
/* layer normalization for decoder */
T2TLN * decodeLayerNorm;
/* input tensor of the encoder */
XTensor * input;
......@@ -85,6 +86,13 @@ public:
/* layer normalization for encoder-decoder attention */
T2TLN * attEndeLayerNorms;
/* layer cache list */
Cache* selfCache;
/* layer cache list */
Cache* contextCache;
public:
/* constructor */
AttDecoder();
......@@ -98,7 +106,7 @@ public:
int myDevID = -1, XMem * myMem = NULL);
/* make the decoding network */
XTensor Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining);
XTensor Make(XTensor &inputDec, XTensor &outputEnc, const XTensor *mask, XTensor &maskEncDec, bool isTraining);
};
}
......
......@@ -60,17 +60,19 @@ void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, XMem * myMem, b
LoadParamInt(argc, argv, "vsizetgt", &vSize, -1);
}
//LoadParamInt(argc, argv, "vsize", &vSize, -1);
LoadParamInt(argc, argv, "maxlen", &maxLength, 512);
LoadParamInt(argc, argv, "maxlen", &maxLength, 1024);
LoadParamInt(argc, argv, "d", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "pad", &padIdx, 1);
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID, mem);
InitTensor2DV2(&w, vSize, eSize, X_FLOAT, devID);
maxLength = maxLength + 1 + 1;
DTYPE v = 1.0F/(float)sqrt((float)eSize);
w.SetDataRandn(0, v);
/* create the positional embedding matrix */
MakePosEmbedding(eSize, d, maxLength);
MakePosEmbedding(eSize, d, maxLength, padIdx);
}
/*
......@@ -79,9 +81,9 @@ make positional embeddings (of size eSize * length)
>> d - dimension size of the hidden layers
>> length - length of the sequence
*/
void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length, int padIdx)
{
InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT, devID, mem);
InitTensor2DV2(&posEmbeddingBase, length, eSize, X_FLOAT, devID);
float * data = new float[posEmbeddingBase.unitNum];
......@@ -91,76 +93,77 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
int channelSize = eSize / 2;
int offset = 0;
for(int i = 0; i < channelSize; i++){
dp[offset++] = (float)sin(pos/pow(10000.0F, 2.0F*i/(d - 2)));
dp[offset++] = (float)sin(pos * exp(-i * log(10000.0F) / (channelSize - 1)));
}
for(int i = 0; i < channelSize; i++){
dp[offset++] = (float)cos(pos/pow(10000.0F, 2.0F*i/(d - 2)));
dp[offset++] = (float)cos(pos * exp(-i * log(10000.0F) / (channelSize - 1)));
}
/*
for(int k = 0; k < eSize; k++){
if(k % 2 == 0){
int i = k/2;
dp[k] = (float)sin(pos/pow(10000.0F, 2.0F*i/d));
}
else{
int i = (k - 1)/2;
dp[k] = (float)cos(pos/pow(10000.0F, 2.0F*i/d));
}
}
*/
}
/* zero pad */
int padStart = padIdx * eSize;
for (int i = padStart; i < padStart + eSize; i++)
data[i] = 0.F;
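/* note: pos * exp(-i * log(10000) / (channelSize - 1)) is algebraically
   pos / 10000^(i / (channelSize - 1)), i.e., the standard sinusoidal
   embedding with the frequencies spread over channelSize bins (the fairseq
   convention); the row at padIdx is zeroed above so that padding tokens
   receive a zero positional embedding */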
posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
delete[] data;
}
/*
make the network
*/
XTensor T2TEmbedder::Make(XTensor &input)
XTensor T2TEmbedder::Make(XTensor &input, int prevLen)
{
//CheckNTErrors(input.GetDim(-1) == vSize, "Wrong vocabulary size!");
/* assert padding index is 1 */
CheckNTErrors(input.order > 1, "Wrong input tensor size!");
CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!");
CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
int dims[MAX_TENSOR_DIM_NUM];
memcpy(dims, input.dimSize, input.order * sizeof(int));
dims[input.order] = eSize;
XTensor wordEmbedding, position, posEmbedding;
InitTensor(&position, &input);
int* posData = new int[input.unitNum];
XTensor wordEmbedding;
XTensor posEmbedding;
XTensor inputCPU;
InitTensorOnCPU(&inputCPU, &input);
_CopyValues(&input, &inputCPU);
bool match = (posEmbedding.order == input.order);
if(match){
for(int i = 0; i < input.order; i++){
if(dims[i] != posEmbedding.GetDim(i))
match = false;
for (int i = 0; i < inputCPU.GetDim(0); i++) {
int startNoPad = 2 + prevLen - 1;
int* p = ((int*)inputCPU.data) + i * inputCPU.GetDim(1);
for (int j = 0; j < inputCPU.GetDim(1); j++) {
if (p[j] == 1) {
posData[i * inputCPU.GetDim(1) + j] = 1;
}
else {
posData[i * inputCPU.GetDim(1) + j] = startNoPad++;
}
}
}
position.SetData(posData, position.unitNum);
delete[] posData;
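/* position numbering (as computed above): padding tokens (id 1) are mapped
   to the zeroed row at padIdx, while real tokens are numbered consecutively
   starting from 2 + prevLen - 1, so incremental decoding steps continue the
   position count instead of restarting at the beginning */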
/* we make positional embeddings first */
//if(!match){
if(true){
InitTensor(&posEmbedding, input.order + 1, dims, X_FLOAT, 1.0F, devID, mem);
XTensor * posTMP = NewTensorBuf(2, dims + 1, X_FLOAT, 1.0F, devID, mem);
_CopyValues(&posEmbeddingBase, 0, posTMP->unitNum, posTMP, 0);
_Unsqueeze(posTMP, &posEmbedding, 0, dims[0]);
DelTensorBuf(posTMP);
posEmbedding = Gather(posEmbeddingBase, position);
}
/* then we make word embeddings */
wordEmbedding = Gather(w, input);
wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize));
/* we sum over the two embeddings */
return wordEmbedding + posEmbedding;
return wordEmbedding;
}
}
......@@ -56,6 +56,9 @@ public:
/* dimension size of the hidden layers in the t2t model */
int d;
/* padding index */
int padIdx;
/* word embedding matrix */
XTensor w;
......@@ -74,10 +77,10 @@ public:
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL, bool isEnc = true);
/* make positional embeddings */
void MakePosEmbedding(int eSize, int d, int length);
void MakePosEmbedding(int eSize, int d, int length, int padIdx);
/* make the network */
XTensor Make(XTensor &input);
XTensor Make(XTensor &input, int prevLen=0);
};
}
......
......@@ -34,7 +34,7 @@ AttEncoder::AttEncoder()
attentions = NULL;
fnns = NULL;
attLayerNorms = NULL;
fnnLayerNorms = NULL;
encodeLayerNorm = NULL;
}
/* de-constructor */
......@@ -43,7 +43,7 @@ AttEncoder::~AttEncoder()
delete[] attentions;
delete[] fnns;
delete[] attLayerNorms;
delete[] fnnLayerNorms;
delete encodeLayerNorm;
}
/*
......@@ -63,7 +63,7 @@ void AttEncoder::InitModel(int argc, char ** argv,
mem = myMem;
ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 6);
LoadParamInt(argc, argv, "nlayer", &nlayer, 35);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "vsize", &vSize, -1);
......@@ -73,20 +73,21 @@ void AttEncoder::InitModel(int argc, char ** argv,
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\"");
/* embedding model */
embedder.InitModel(argc, argv, devID, mem);
embedder.InitModel(argc, argv, devID);
attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer];
attLayerNorms = new T2TLN[nlayer];
fnnLayerNorms = new T2TLN[nlayer];
encodeLayerNorm = new T2TLN;
/* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
fnns[i].InitModel(argc, argv, myDevID, myMem);
attLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
}
encodeLayerNorm->InitModel(argc, argv, myDevID, myMem);
}
/*
......@@ -97,49 +98,34 @@ make the encoding network
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining)
XTensor AttEncoder::Make(XTensor &input, XTensor *mask, XTensor &maskEncDec, bool isTraining)
{
XTensor x;
x = embedder.Make(input);
/* dropout */
if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
x = embedder.Make(input, 0);
for(int i = 0; i < nlayer; i++){
XTensor att;
XTensor ln;
XTensor fnn;
XTensor res;
XTensor inputNorm;
/* layer normalization */
inputNorm = attLayerNorms[i].Make(x);
/* self attention */
att = attentions[i].MakeBig(x, mask, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP);
att = attentions[i].Make(inputNorm, inputNorm, inputNorm, mask, isTraining, NULL, 0);
/* residual connection */
res = Sum(att, x);
/* layer normalization */
x = attLayerNorms[i].Make(res);
/* fnn */
fnn = fnns[i].Make(x, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP);
x = fnns[i].Make(res, isTraining);
}
/* residual connection */
res = Sum(fnn, x);
x = encodeLayerNorm->Make(x);
/* layer normalization */
x = fnnLayerNorms[i].Make(res);
}
x.SetName(ENCODING_NAME);
input.SetName(ENCODING_INPUT_NAME);
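/* note: the encoder is now a pre-norm network: attLayerNorms is applied
   before self-attention, normalization moves inside the FNN block, and a
   single encodeLayerNorm is applied after the last layer, replacing the
   previous post-norm arrangement */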
......@@ -153,7 +139,7 @@ make the encoding network (wrapper)
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
XTensor AttEncoder::Make(XTensor &input, XTensor *mask, bool isTraining)
{
XTensor nothing;
......
......@@ -43,7 +43,7 @@ class T2TEncoder
{
public:
virtual
XTensor Make(XTensor &input, XTensor &mask, XTensor &mask2, bool isTraining) = 0;
XTensor Make(XTensor &input, XTensor *mask, XTensor &mask2, bool isTraining) = 0;
};
/*
......@@ -52,7 +52,7 @@ the encoder based on RNN
class RNNEncoder : T2TEncoder
{
public:
XTensor Make(XTensor &input, XTensor &mask, XTensor &mask2, bool isTraining);
XTensor Make(XTensor &input, XTensor *mask, XTensor &mask2, bool isTraining);
};
......@@ -96,12 +96,12 @@ public:
/* attention model of each layer */
T2TAttention * attentions;
/* layer normalization for fnn */
T2TLN * fnnLayerNorms;
/* layer normalization for attention */
T2TLN * attLayerNorms;
/* layer normalization for encoder */
T2TLN * encodeLayerNorm;
/* input tensor of the encoder */
XTensor * input;
......@@ -121,10 +121,10 @@ public:
int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */
XTensor Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining);
XTensor Make(XTensor &input, XTensor *mask, XTensor &maskEncDec, bool isTraining);
/* make the encoding network (wrapper) */
XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
XTensor Make(XTensor &input, XTensor *mask, bool isTraining);
};
......
......@@ -62,20 +62,22 @@ void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0);
InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID, mem);
InitTensor1D(&b1, hSize, X_FLOAT, devID, mem);
InitTensor2DV2(&w1, hSize, inSize, X_FLOAT, devID);
InitTensor1DV2(&b1, hSize, X_FLOAT, devID);
InitTensor2D(&w2, hSize, outSize, X_FLOAT, devID, mem);
InitTensor1D(&b2, outSize, X_FLOAT, devID, mem);
InitTensor2DV2(&w2, outSize, hSize, X_FLOAT, devID);
InitTensor1DV2(&b2, outSize, X_FLOAT, devID);
float scale = 1.0F;
float finfout1 = (float)sqrt(6.0F * scale/(inSize + hSize));
float finfout2 = (float)sqrt(6.0F * scale/(hSize + outSize));
w1.SetDataRand(-finfout1, finfout1);
b1.SetZeroAll();
w2.SetDataRand(-finfout2, finfout2);
b2.SetZeroAll();
fnnLayerNorm.InitModel(argc, argv, myDevID, myMem);
//float scale = 1.0F;
//float finfout1 = (float)sqrt(6.0F * scale/(inSize + hSize));
//float finfout2 = (float)sqrt(6.0F * scale/(hSize + outSize));
//
//w1.SetDataRand(-finfout1, finfout1);
//b1.SetZeroAll();
//w2.SetDataRand(-finfout2, finfout2);
//b2.SetZeroAll();
}
/*
......@@ -89,15 +91,16 @@ XTensor T2TFNN::Make(XTensor &input, bool isTraining)
XTensor t1;
/* t1 = max(0, x * w1 + b1) */
//t1 = Rectify(MMul(input, w1) + b1);
t1 = Rectify(MulAndShift(input, w1, b1));
t1 = Rectify(MulAndShift(fnnLayerNorm.Make(input), X_NOTRANS, w1, X_TRANS, b1));
if(isTraining && dropoutP > 0)
t1 = Dropout(t1, dropoutP);
/* result = t1 * w2 + b2 */
//return MMul(t1, w2) + b2;
return MulAndShift(t1, w2, b2);
XTensor res;
res = MulAndShift(t1, X_NOTRANS, w2, X_TRANS, b2);
_SumMe(&res, &input);
return res;
}
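/* note: the FNN is now a pre-norm residual block,
   y = x + (max(0, LN(x) * w1^T + b1) * w2^T + b2),
   so layer normalization and the residual sum happen inside Make() rather
   than in the caller */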
......
......@@ -22,6 +22,7 @@
#ifndef __T2TFNN_H__
#define __T2TFNN_H__
#include "T2TLayerNormal.h"
#include "../../tensor/XTensor.h"
using namespace nts;
......@@ -60,6 +61,9 @@ public:
/* bias of transformation 2 */
XTensor b2;
/* layer normalization for fnn */
T2TLN fnnLayerNorm;
/* dropout probability */
DTYPE dropoutP;
......
......@@ -56,11 +56,11 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
d = 0;
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
InitTensor1D(&w, d, X_FLOAT, devID, mem);
InitTensor1D(&b, d, X_FLOAT, devID, mem);
InitTensor1DV2(&w, d, X_FLOAT, devID);
InitTensor1DV2(&b, d, X_FLOAT, devID);
w.SetDataRand(1.0F, 1.0F);
b.SetZeroAll();
//w.SetDataRand(1.0F, 1.0F);
//b.SetZeroAll();
}
/*
......
......@@ -35,9 +35,7 @@ XTensor T2TLengthPenalizer::GNMT(const XTensor & length, float alpha)
XTensor base;
XTensor lp;
//base = ScaleAndShift(ScaleAndShift(length, 0, 5.0F), 1.0F/(5 + 1));
base = (length + 5)/(1 + 5);
base = (length + 5)/(1.0F + 5.0F);
lp = Power(base, alpha);
return lp;
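/* i.e. the GNMT length penalty: lp = ((5 + length) / 6)^alpha */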
......
......@@ -71,11 +71,14 @@ public:
/* initialize the model */
void InitModel(int argc, char ** argv);
/* reset cache for decoder */
void ResetCache();
/* make the encoding network */
XTensor MakeEncoder(XTensor &input, XTensor &mask, bool isTraining);
XTensor MakeEncoder(XTensor &input, XTensor *mask, bool isTraining);
/* make the encoding network */
XTensor MakeDecoder(XTensor &inputEnc, XTensor &inputDec, XTensor &mask, XTensor &MaskEncDec, bool isTraining);
XTensor MakeDecoder(XTensor &inputEnc, XTensor &inputDec, XTensor *mask, XTensor &MaskEncDec, bool isTraining);
/* make the network for language modeling (with the output softmax layer) */
void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
......@@ -95,7 +98,7 @@ public:
/* make the mask of the decoder */
void MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
XTensor &paddingEnc, XTensor &paddingDec,
XTensor &maskDec, XTensor &maskEncDec);
XTensor &maskDec, XTensor &maskEncDec, int incDim);
/* get parameter matrices */
void GetParams(TensorList &list);
......@@ -107,6 +110,9 @@ public:
void Read(const char * fn);
};
void FastRead(XTensor* x, FILE* f);
void FastDump(XTensor* x, FILE* f);
void ConvertModelFile(const TensorList* params, const char* src, const char* tgt);
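/* helpers for the new binary format (roles inferred from the names and this
   commit): FastRead/FastDump stream a single tensor, and ConvertModelFile
   rewrites a model file from the old format at src to the new one at tgt */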
}
#endif
......@@ -25,6 +25,7 @@
#include "T2TEmbedding.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
......@@ -61,14 +62,14 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
LoadParamInt(argc, argv, "d", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "outputminmax", &minmax, 0.08F);
InitTensor2D(&w, hSize, vSize, X_FLOAT, devID, mem);
InitTensor2DV2(&w, hSize, vSize, X_FLOAT, devID);
float scale = 1.0F;
float finfout = (float)sqrt(6.0F * scale/(hSize + vSize));
w.SetDataRand(-finfout, finfout);
//float scale = 1.0F;
//float finfout = (float)sqrt(6.0F * scale/(hSize + vSize));
//w.SetDataRand(-finfout, finfout);
DTYPE v = 1.0F/(float)sqrt((float)hSize);
w.SetDataRandn(0, v);
//DTYPE v = 1.0F/(float)sqrt((float)hSize);
//w.SetDataRandn(0, v);
}
/*
......@@ -81,7 +82,8 @@ XTensor T2TOutput::Make(XTensor &input)
{
XTensor &x = input;
return LogSoftmax(MMul(x, w), -1);
return Softmax(MMul(x, X_NOTRANS, w, X_TRANS), -1);
//return MulAndShift(x, X_NOTRANS, w, X_TRANS, b);
}
/*
......@@ -93,8 +95,8 @@ void T2TOutput::Make(XTensor &input, XTensor &output)
{
XTensor &x = input;
//output = LogSoftmax(MMul(x, w), -1);
output = Softmax(MMul(x, w), -1);
output = LogSoftmax(MMul(x, X_NOTRANS, w, X_NOTRANS), -1);
output.SetName(OUTPUT_NAME);
}
......
......@@ -146,7 +146,7 @@ public:
~T2TPredictor();
/* create an initial state */
void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state);
void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state, XTensor * encoding);
/* set the start symbol */
void SetStartSymbol(int symbol);
......@@ -155,7 +155,9 @@ public:
void Read(T2TModel * model, T2TStateBundle * state);
/* predict the next state */
void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc);
void Predict(T2TStateBundle * next, XTensor & encoding,
XTensor & inputEnc, XTensor & paddingEnc,
XTensor& nonFinished, bool updateFinished);
/* generate paths up to the states of the current step */
XTensor GeneratePaths(T2TStateBundle * state);
......
......@@ -62,6 +62,12 @@ private:
/* start symbol */
int startSymbol;
/* scalar of the input sequence (for max number of search steps) */
float scalarMaxLength;
/* indicate whether the early stop strategy is used */
bool isEarlyStop;
public:
/* constructor */
T2TSearch();
......@@ -73,7 +79,8 @@ public:
void Init(int argc, char ** argv);
/* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output);
void Search(T2TModel * model, XTensor * input, XTensor * padding,
XTensor * output, XTensor * score);
/* preparation */
void Prepare(int myBatchSize,int myBeamSize);
......@@ -94,7 +101,7 @@ public:
void FillHeap(T2TStateBundle * beam);
/* save the output sequences in a tensor */
void Dump(XTensor * output);
void Dump(XTensor * output, XTensor * score);
/* check if the token is an end symbol */
bool IsEnd(int token);
......@@ -102,6 +109,17 @@ public:
/* set end symbols for search */
void SetEnd(const int * tokens, const int tokenNum);
/* penalize beams that completed */
int UpdateCompleted(T2TStateBundle * beam, XTensor & encoding,
XTensor& inputEnc, XTensor& paddingEnc,
IntList completedStates, XTensor &nonFinished);
/* check whether all hypotheses are completed */
bool IsAllCompleted(T2TStateBundle * beam);
/* check if any hypotheses are completed */
IntList IsAnyCompleted(T2TStateBundle * beam);
/* make a mask to prevent duplicated entries in beam expansion for the first position */
XTensor MakeFirstMask(T2TStateBundle * beam);
};
......
......@@ -15,17 +15,17 @@
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
*/
#include <math.h>
#include "T2TUtility.h"
#include "T2TTester.h"
#include "T2TSearch.h"
#include "T2TUtility.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/core/CHeader.h"
#include "../../network/XNoder.h"
using namespace nts;
......@@ -35,6 +35,7 @@ namespace transformer
/* constructor */
T2TTester::T2TTester()
{
}
/* de-constructor */
......@@ -43,127 +44,120 @@ T2TTester::~T2TTester()
}
/* initialize the model */
void T2TTester::Init(int argc, char ** argv)
void T2TTester::Init(int argc, char** argv)
{
LoadParamInt(argc, argv, "vsize", &vSize, 1);
LoadParamInt(argc, argv, "vsizetgt", &vSizeTgt, vSize);
LoadParamInt(argc, argv, "sentBatch", &sentBatch, 1);
LoadParamBool(argc, argv, "sort", &batchLoader.sortBuffer, false);
batchLoader.Init(argc, argv);
seacher.Init(argc, argv);
}
/*
Result ExtractRes(XTensor& output, IntList& indices, int i) {
Result res;
XTensor sent, srcIdx, tgtIdx;
InitTensor1D(&srcIdx, 1, X_INT, output.devID);
int idx[]{ i };
srcIdx.SetData(idx, 1);
InitTensor(&tgtIdx, &srcIdx);
tgtIdx.SetAscendingOrder(0);
sent = CopyIndexed(output, 0, srcIdx, tgtIdx);
res.data.Add((int*)sent.data, sent.unitNum);
res.id = indices[i];
return res;
}
/*
test the model
>> fn - test data file
>> ofn - output data file
>> model - model that is trained
*/
void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
{
int wc = 0;
int ws = 0;
int wordCount = 0;
int wordCountTotal = 0;
int sentCount = 0;
int batchCount = 0;
float loss = 0;
/* data files */
FILE * file = fopen(fn, "rb");
CheckNTErrors(file, "Cannot read the test file");
FILE * ofile = fopen(ofn, "wb");
FILE* ofile = fopen(ofn, "w");
CheckNTErrors(ofile, "Cannot open the output file");
int devID = model->devID;
XMem * mem = model->mem;
XMem* mem = model->mem;
XNet net;
double startT = GetClockSec();
wordCount = 0;
/* batch of input sequences */
XTensor batchEnc;
XTensor batchDec;
/* label */
XTensor label;
/* padding */
XTensor paddingEnc;
XTensor paddingDec;
/* gold standard */
XTensor gold;
/* an array that keeps the sequences */
int * seqs = new int[MILLION];
batchLoader.SetRandomBatch(false);
batchLoader.ClearBuf();
while(batchLoader.LoadBatch(file, model->isLM,
&batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt,
1, 1, false, ws, wc, devID, mem, false))
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch!");
CheckNTErrors(!model->isLM, "Only MT model is supported!");
int* seqs = new int[MILLION];
batchLoader.Init(fn, 100, true);
int count = 0;
while (!batchLoader.IsEmpty()) {
count++;
printf("sent: %d\n", count);
wordCount = 0;
/* reset cache for decoder */
model->ResetCache();
XTensor output;
IntList indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID);
seacher.Search(model, &batchEnc, &paddingEnc, &output);
XTensor output, score;
Dump(ofile, &output);
seacher.Search(model, &batchEnc, &paddingEnc, &output, &score);
for (int i = 0; i < indices.Size(); i++)
batchLoader.resBuffer.Add(ExtractRes(output, indices, i));
float prob = 0;
loss += -prob;
wc = batchEnc.GetDim(-1);
wordCount += wc;
wordCountTotal += wc;
sentCount += batchEnc.GetDim(-2);
batchCount += 1;
if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr,
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount);
}
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] elapsed=%.1fs, sent=%d, sword=%d\n", elapsed, sentCount, wordCount);
}
fclose(file);
fclose(ofile);
batchLoader.SortRes();
for (int i = 0; i < batchLoader.resBuffer.Size(); i++)
Dump(ofile, batchLoader.resBuffer[i].data);
fclose(ofile);
delete[] seqs;
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)\n",
elapsed,wordCountTotal, exp(loss/wordCount));
double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, sent=%d)\n", elapsed, wordCountTotal, sentCount);
}
/*
dump the result into the file
>> file - data file
>> output - output tensor
>> output - output list
*/
void T2TTester::Dump(FILE * file, XTensor * output)
void T2TTester::Dump(FILE* file, IntList& output)
{
int seqLength = output->GetDim(-1);
for (int i = 0; i < output->unitNum; i += seqLength) {
for (int j = 0; j < seqLength; j++) {
int w = output->GetInt(i + j);
fprintf(file, "%d ", w);
if (w < 0)
break;
}
fprintf(file, "\n");
for (int i = 0; i < output.Size(); i++) {
int w = output[i];
if (w < 0)
break;
fprintf(file, "%d ", w);
}
fprintf(file, "\n");
}
}
......@@ -24,7 +24,7 @@
#define __T2TTESTER_H__
#include "T2TSearch.h"
#include "T2TBatchLoader.h"
#include "t2tdata/DataSet.h"
namespace transformer
{
......@@ -38,9 +38,12 @@ public:
/* vocabulary size of the target side */
int vSizeTgt;
/* batch size for sentences */
int sentBatch;
/* for batching */
T2TBatchLoader batchLoader;
DataSet batchLoader;
/* decoder for inference */
T2TSearch seacher;
......@@ -59,7 +62,7 @@ public:
void Test(const char * fn, const char * ofn, T2TModel * model);
/* dump the result into the file */
void Dump(FILE * file, XTensor * output);
void Dump(FILE * file, IntList& output);
};
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-08-02
*/
#ifndef __T2TTRAINER_H__
#define __T2TTRAINER_H__
#include "T2TModel.h"
#include "T2TBatchLoader.h"
#include "../../tensor/function/FHeader.h"
using namespace nts;
namespace transformer
{
/* trainer of the T2T model */
class T2TTrainer
{
public:
/* parameter number */
int argNum;
/* parameter array */
char ** argArray;
/* dimension size of each inner layer */
int d;
/* step number of warm-up for training */
int nwarmup;
/* vocabulary size of the source side */
int vSize;
/* vocabulary size of the target side */
int vSizeTgt;
/* learning rate */
float lrate;
/* the parameter that controls the maximum learning rate in training */
float lrbias;
/* sentence batch size */
int sBatchSize;
/* word batch size */
int wBatchSize;
/* training epoch number */
int nepoch;
/* training step number */
int nstep;
/* indicates whether we use adam */
bool useAdam;
/* hyper parameters of adam */
float adamBeta1;
float adamBeta2;
float adamDelta;
float adamBeta1T;
float adamBeta2T;
/* list of the moment of the parameter matrices */
TensorList moments;
/* list of the 2nd order moment of the parameter matrices */
TensorList moments2nd;
/* indicates whether the data file is shuffled for training */
bool isShuffled;
/* the factor of label smoothing */
DTYPE labelSmoothingP;
/* number of steps after which we make a checkpoint */
int nStepCheckpoint;
/* indicates whether we make a checkpoint after each training epoch */
bool useEpochCheckpoint;
/* number of batches on which we do model update */
int updateStep;
/* indicates whether we intend to debug the net */
bool isDebugged;
/* indicates whether the sequence is sorted by length */
bool isLenSorted;
/* for batching */
T2TBatchLoader batchLoader;
public:
/* constructor */
T2TTrainer();
/* de-constructor */
~T2TTrainer();
/* initialize the trainer */
void Init(int argc, char ** argv);
/* train the model */
void Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
/* test the model */
void Test(const char * fn, const char * ofn, T2TModel * model);
/* make a checkpoint */
void MakeCheckpoint(T2TModel * model, const char * validFN, const char * modelFN, const char * label, int id);
/* get word probabilities for a batch of sequences */
float GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs);
/* update the model by delta rule */
void Update(T2TModel * model, const float lr);
/* prepare model for training */
void PrepareModel(T2TModel * model);
/* do padding on the output */
void PadOutput(XTensor * output, XTensor * gold, XTensor * padding);
/* rescale the output and gold tensors for normalized loss */
void RescaleOutput(XTensor * output, XTensor * gold, XTensor * padding);
/* perform label smoothing */
void LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE p);
};
}
#endif
......@@ -22,6 +22,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "T2TUtility.h"
namespace transformer
{
......@@ -114,4 +115,11 @@ void ShowParams(int argc, char ** argv)
fprintf(stderr, "\n");
}
/* dump tensors */
void DumpTensors(std::initializer_list<nts::XTensor*> list) {
int i(0);
for (auto& x : list)
x->Dump(stderr, std::to_string(++i).c_str());
}
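/*
a minimal usage sketch (editor's illustration, with hypothetical tensors):
each tensor passed to DumpTensors is dumped to stderr with an
auto-numbered label "1", "2", ...
*/
static void DumpTensorsExample(nts::XTensor* x, nts::XTensor* y)
{
    DumpTensors({ x, y });
}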
}
......@@ -23,6 +23,9 @@
#define __T2TUTILITY_H__
#include <stdio.h>
#include <string>
#include "..//..//tensor/XTensor.h"
#include <initializer_list>
namespace transformer
{
......@@ -38,6 +41,10 @@ void LoadParamFloat(int argc, char ** argv, const char * name, float * p, float
/* show arguments */
void ShowParams(int argc, char ** argv);
/* dump tensors */
void DumpTensors(std::initializer_list<nts::XTensor*> list);
extern int llnum;
extern FILE * tf;
......@@ -24,92 +24,113 @@
#include "Transformer.h"
#include "T2TModel.h"
#include "T2TUtility.h"
#include "T2TTrainer.h"
#include "T2TPredictor.h"
#include "T2TTester.h"
#include "../../tensor/XDevice.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/XGlobal.h"
#include "..//..//model/Model.h"
namespace transformer
{
struct AttModel : Model {

    AttModel(int devID) {
        Register("w1", {2,3,4}, X_FLOAT, devID);
        Register("b1", {2,3,4}, X_FLOAT, devID);
        Register("3", {2,3,4}, X_FLOAT, devID);
    }
};

struct Transformer {

    AttModel *att;

    Transformer(int devID) {
        att = new AttModel(devID);
    }

    ~Transformer() {
        delete att;
    }
};

void test() {
    Transformer model(0);
    model.att->Get("w1")->SetZeroAll();
    model.att->Get("w1")->Dump(stderr);
}
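/*
editor's sketch (not part of the commit): a round trip through the new
binary model format, using the Dump/Load interface declared in
model/Model.h; the file name is hypothetical
*/
void testBinaryRoundTrip() {
    Transformer model(0);
    model.att->Get("w1")->SetZeroAll();
    model.att->Dump("model.bin");
    model.att->Load("model.bin");
    model.att->Get("w1")->Dump(stderr);
}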
int TransformerMain(int argc, const char ** argv)
{
test();
return 0;
//if(argc == 0)
// return 1;
//char ** args = new char*[argc];
//for(int i = 0; i < argc; i++){
// args[i] = new char[strlen(argv[i]) + 1];
// strcpy(args[i], argv[i]);
//}
//ShowParams(argc, args);
//bool convertFile = false;
//bool isBeamSearch = false;
//bool convertModel = false;
//
//char * modelFN = new char[MAX_LINE_LENGTH];
//char * rawFN = new char[MAX_LINE_LENGTH];
//char * testFN = new char[MAX_LINE_LENGTH];
//char * outputFN = new char[MAX_LINE_LENGTH];
//char * rawModel = new char[MAX_LINE_LENGTH];
//LoadParamString(argc, args, "model", modelFN, "");
//LoadParamString(argc, args, "rawModel", rawModel, "");
//LoadParamString(argc, args, "test", testFN, "");
//LoadParamString(argc, args, "rawFile", rawFN, "");
//LoadParamString(argc, args, "output", outputFN, "");
//LoadParamBool(argc, args, "beamsearch", &isBeamSearch, false);
//LoadParamBool(argc, args, "convertFile", &convertFile, false);
//LoadParamBool(argc, args, "convertModel", &convertModel, false);
//
//srand((unsigned int)time(NULL));
//T2TModel model;
//model.InitModel(argc, args);
///* convert test file from text to binary */
//if (convertFile) {
// DataSet::ConvertFile(rawFN, testFN);
//}
//
///* convert parameters from text to binary */
//if (convertModel) {
// TensorList params(100);
// model.GetParams(params);
// ConvertModelFile(&params, rawModel, modelFN);
//}
///* load the model if necessary */
//if(strcmp(modelFN, ""))
// model.Read(modelFN);
///* test the model on the new data */
//if(strcmp(testFN, "") && strcmp(outputFN, "")){
// T2TTester searcher;
// searcher.Init(argc, args);
// searcher.Test(testFN, outputFN, &model);
//}
//delete[] modelFN;
//delete[] testFN;
//delete[] outputFN;
//delete[] rawModel;
//for(int i = 0; i < argc; i++)
// delete[] args[i];
//delete[] args;
//return 0;
}
}
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-04-05
*/
#include <string>
#include <fstream>
#include <algorithm>
#include "DataSet.h"
#include "StringUtil.h"
#include "../../../tensor/XUtility.h"
using namespace nts;
using namespace std;
/* sort results by their ids */
void DataSet::SortRes()
{
auto cmp = [](Result& a, Result& b) {
return a.id < b.id;
};
std::sort(resBuffer.items, resBuffer.items + resBuffer.count, cmp);
}
/*
load data from the file to the buffer
*/
void DataSet::LoadDataToBuffer()
{
bufferUsed = 0;
srcBuffer.Clear();
    bufferSize = min(bufferSize, (size_t)exampleNumber);
for (int i = 0; i < bufferSize; i++) {
long off = offset[index++];
IntList data(off);
data.count = off;
fread(data.items, sizeof(int), off, fp);
Example example;
example.id = id++;
example.data = data;
srcBuffer.Add(example);
}
if (sortBuffer) {
auto cmp = [](Example& a, Example& b) {
return a.data.Size() > b.data.Size();
};
std::sort(srcBuffer.items, srcBuffer.items + srcBuffer.count, cmp);
}
}
/*
select a field and generate a mini-batch by indices
>>> batchEnc - a tensor to store the batch of input
>>> paddingEnc - a tensor to store the batch of paddings
>>> batchSize - batch size
>>> devID - device id, -1 for the CPU
<<< return - the ids of the selected examples
*/
IntList DataSet::LoadBatch(XTensor * batchEnc, XTensor * paddingEnc, size_t batchSize, int devID)
{
if(srcBuffer.count == 0)
LoadDataToBuffer();
size_t realBatchSize = batchSize;
/* real batch size */
if ((srcBuffer.Size() - bufferUsed) < batchSize) {
realBatchSize = srcBuffer.Size() - bufferUsed;
}
    /* get the maximum sentence length in the mini-batch */
    size_t maxLen = 0;
    for (size_t i = 0; i < realBatchSize; i++) {
        maxLen = max(maxLen, srcBuffer[bufferUsed + i].data.Size());
    }
    CheckNTErrors(maxLen != 0, "wrong length detected");

    int* batchValues = new int[maxLen * realBatchSize];
    float* paddingValues = new float[maxLen * realBatchSize];

    memset(batchValues, 0, sizeof(int) * maxLen * realBatchSize);
    memset(paddingValues, 0, sizeof(float) * maxLen * realBatchSize);
size_t cur = 0;
/* left padding */
IntList indices;
indices.Reserve(realBatchSize);
for (size_t i = 0; i < realBatchSize; i++) {
indices.Add(srcBuffer[bufferUsed + i].id);
IntList& data = srcBuffer[bufferUsed + i].data;
cur = maxLen * (i + 1) - data.Size();
for (int j = 0; j < data.Size(); j++) {
batchValues[cur] = data[j];
paddingValues[cur++] = 1.0F;
}
cur = maxLen * (i + 1);
}
InitTensor2DV2(batchEnc, realBatchSize, maxLen, X_INT, devID);
InitTensor2DV2(paddingEnc, realBatchSize, maxLen, X_FLOAT, devID);
bufferUsed += realBatchSize;
batchEnc->SetData(batchValues, batchEnc->unitNum);
paddingEnc->SetData(paddingValues, paddingEnc->unitNum);
delete[] batchValues;
delete[] paddingValues;
return indices;
}
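/*
worked example (editor's illustration, with hypothetical data): for
realBatchSize = 2 with examples {5, 8, 2} and {7, 3}, maxLen = 3 and the
left-padded batch is
    batchValues   = [5 8 2 | 0 7 3]
    paddingValues = [1 1 1 | 0 1 1]
i.e. each sentence is right-aligned within its maxLen-sized row
*/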
/*
convert a text file to a binary file
format of the text file:
one sentence per line, words separated by blanks
format of the binary file:
part 1: number of all examples
part 2: length (number of words) of each example
part 3: the raw data
>>> src - the path of the source text file
>>> tgt - the path of the target binary file
*/
void nts::DataSet::ConvertFile(const char* src, const char* tgt)
{
ifstream ifile(src, ios::in);
FILE* ofile = fopen(tgt, "wb");
string line;
    long idx = 0;
    const int maxExample = 10240;
    IntList dataList[maxExample];
    while (getline(ifile, line)) {
        CheckNTErrors(idx < maxExample, "too many examples for the buffer");
        SplitInt(line, " ", dataList[idx++]);
    }
/* part 1: number of examples */
fwrite(&idx, sizeof(idx), 1, ofile);
/* part 2: offset of all examples */
for (int i = 0; i < idx; i++) {
int size = (dataList[i].Size());
fwrite(&size, sizeof(size), 1, ofile);
}
/* part 3: value of examples */
for (int i = 0; i < idx; i++) {
fwrite(dataList[i].items, sizeof(int), dataList[i].Size(), ofile);
}
ifile.close();
fclose(ofile);
}
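/*
worked example (editor's illustration, with hypothetical data): for a
text file with the two lines "5 8 2" and "7 3", ConvertFile writes
    part 1: (long) 2          - number of examples
    part 2: (int)  3, (int) 2 - length of each example
    part 3: (int)  5 8 2 7 3  - the word indices themselves
so the "offset" list read back by DataSet::Init holds lengths, which
LoadDataToBuffer uses as the number of ints to read per example
*/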
/*
initialization of DataSet
the binary data consists of three parts
part 1: number of all examples
part 2: length of each example
part 3: the raw data
>>> fname - path of the data file
>>> myBufferSize - size of the data buffer
>>> mySortBuffer - whether to sort the buffered examples by length
*/
void DataSet::Init(const char* fname, size_t myBufferSize, bool mySortBuffer)
{
id = 0;
index = 0;
bufferUsed = 0;
bufferSize = myBufferSize;
sortBuffer = mySortBuffer;
fp = fopen(fname, "rb");
CheckNTErrors(fp, "can not open the file");
/* read offsets */
exampleNumber = 0;
fread(&exampleNumber, sizeof(exampleNumber), 1, fp);
CheckNTErrors(exampleNumber > 0, "invalid example numbers");
offset.Reserve(exampleNumber);
for (int i = 0; i < exampleNumber; i++) {
int off;
fread(&off, sizeof(off), 1, fp);
offset.Add(off);
}
    /* reset the buffer size if it is too big */
    bufferSize = min(bufferSize, (size_t)exampleNumber);
srcBuffer.Reserve(bufferSize);
}
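/*
usage sketch (editor's illustration): load a converted binary file and
iterate over mini-batches; the file name and sizes are hypothetical
*/
void IterateBatches()
{
    DataSet dataSet;
    dataSet.Init("test.bin", 512, true);
    XTensor batchEnc;
    XTensor paddingEnc;
    while (!dataSet.IsEmpty()) {
        /* ids maps each row of the batch back to its sentence id */
        IntList ids = dataSet.LoadBatch(&batchEnc, &paddingEnc, 32, -1);
    }
}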
/* check if the buffer is empty */
bool nts::DataSet::IsEmpty()
{
return (index >= offset.count) && (bufferUsed >= bufferSize);
}
/* de-constructor */
nts::DataSet::~DataSet()
{
if (fp) {
fclose(fp);
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-04-03
*/
#ifndef __DATASET_H__
#define __DATASET_H__
#include <cstdio>
#include "../../..//tensor/XTensor.h"
#include "../../..//tensor/XGlobal.h"
namespace nts {
/* `DataSet` maintains data buffers for the inference stage. */
struct DataSet {
public:
/* the data buffer */
ExampleList srcBuffer;
/* the result buffer */
ResultList resBuffer;
/* the offset of all examples in the data */
LongList offset;
/* whether to sort the examples in the buffer */
bool sortBuffer;
/* id for each example */
size_t id;
/* size of the data buffer */
size_t bufferSize;
/* size of used data in buffer */
size_t bufferUsed;
/* number of examples in the source file */
long exampleNumber;
/* current index of the offset */
size_t index;
/* the pointer of the src file stream */
FILE * fp;
public:
/* check if the buffer is empty */
bool IsEmpty();
/* load data from a file to the buffer */
void LoadDataToBuffer();
/* initialization function */
void Init(const char* fname, size_t myBufferSize, bool mySortBuffer);
/* generate a mini-batch */
IntList LoadBatch(XTensor * batchEnc, XTensor * paddingEnc, size_t batchSize, int devID);
/* sort results by their ids */
void SortRes();
/* transform text file to binary file */
static void ConvertFile(const char* src, const char* tgt);
/* de-constructor */
~DataSet();
};
} // namespace nts(NiuTrans.Tensor)
#endif // __DATASET_H__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-03-18
*/
#include "StringUtil.h"
/*
split a string by a delimiter; this returns the start index of each
non-empty sub-string in the original string
>>> s - the original string
>>> delimiter - as it is
>>> indices - the start indices of all sub-strings
*/
void SplitToPos(const string& s, const string& delimiter, LongList& indices)
{
if (delimiter.length() == 0) {
indices.Add(0);
}
int pos = 0;
int start = 0;
while ((pos = s.find(delimiter, start)) != string::npos) {
if (pos != start) {
indices.Add(start);
}
start = pos + delimiter.length();
}
if (start != s.length()) {
indices.Add(start);
}
}
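/*
example (editor's illustration): for s = "12 7  9" and delimiter " ",
indices becomes {0, 3, 6} - the start of each non-empty field; the
empty field produced by the double blank is skipped
*/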
IntList SplitInt(const string& s, const string& delimiter)
{
IntList fields;
LongList indices;
SplitToPos(s, delimiter, indices);
for (int i = 0; i < indices.Size(); i++) {
fields.Add(strtol(s.data() + indices[i], nullptr, 10));
}
return fields;
}
void SplitInt(const string& s, const string& delimiter, IntList& fields)
{
LongList indices;
SplitToPos(s, delimiter, indices);
for (int i = 0; i < indices.Size(); i++) {
fields.Add(strtol(s.data() + indices[i], nullptr, 10));
}
}
FloatList SplitFloat(const string& s, const string& delimiter)
{
FloatList fields;
LongList indices;
SplitToPos(s, delimiter, indices);
for (int i = 0; i < indices.Size(); i++) {
fields.Add(strtof(s.data() + indices[i], nullptr));
}
return fields;
}
void SplitInt(const string& s, const string& delimiter, FloatList& fields)
{
LongList indices;
SplitToPos(s, delimiter, indices);
for (int i = 0; i < indices.Size(); i++) {
fields.Add(strtof(s.data() + indices[i], nullptr));
}
}
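/*
usage sketch (editor's illustration): parse a line of blank-separated
word ids, as DataSet::ConvertFile does
*/
void ParseLineExample()
{
    IntList fields;
    SplitInt("12 7 9", " ", fields);   /* fields = {12, 7, 9} */
}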
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -16,17 +16,31 @@
*/
/*
* $Created by: HU Chi (huchinlp@foxmail.com) 2019-03-18
*/
#ifndef __STRING_UTIL_H__
#define __STRING_UTIL_H__

#include <string>
#include "../../../tensor/XList.h"

using namespace std;
using namespace nts;

/* split a string by a delimiter; the returned list holds the start index
 * of each non-empty field of the original string (empty fields are not
 * represented in the output) */
void SplitToPos(const string& s, const string& delimiter, LongList& indices);

IntList SplitInt(const string& s, const string& delimiter);
void SplitInt(const string& s, const string& delimiter, IntList& fields);
FloatList SplitFloat(const string& s, const string& delimiter);
void SplitFloat(const string& s, const string& delimiter, FloatList& fields);

#endif // __STRING_UTIL_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
*
* This is the entrance of the low-level tensor library : NiuTrans.Tensor
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2015-12-14
*
*/
#include <stdio.h>
#include <math.h>
#include <time.h>
#include "XTensor.h"
#include "XDevice.h"
#include "./test/Test.h"
#include "./core/CHeader.h"
#include "./loss/CrossEntropy.h"
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
//#include <crtdbg.h>
using namespace nts;
void SmallTest();
void TransposeTest();
void LittleTest();
void T2TTest();
void T2TTest2();
void PowerTest();
int main( int argc, const char ** argv )
{
//PowerTest();
//LittleTest();
//T2TTest();
//T2TTest2();
//return 0;
//_CrtSetBreakAlloc(123);
/* a tiny test */
//SmallTest();
//_CrtDumpMemoryLeaks();
//return 0;
if(argc > 1 && !strcmp(argv[1], "-test"))
Test();
else{
fprintf(stderr, "Thanks for using NiuTrans.Tensor! This is a library that eases the\n");
fprintf(stderr, "use of tensors. All you need is to ... \n\n");
fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
}
//_CrtDumpMemoryLeaks();
return 0;
}
void myRead(XTensor * tensor, const char * filename, const char * label)
{
    FILE * file = fopen(filename, "rb");
    if(file == NULL){
        fprintf(stderr, "cannot open %s\n", filename);
        exit(1);
    }
    tensor->Read(file, label);
    fclose(file);
}
void myDump(XTensor * tensor, const char * filename, const char * label)
{
    FILE * file = fopen(filename, "wb");
    if(file == NULL){
        fprintf(stderr, "cannot open %s\n", filename);
        exit(1);
    }
    tensor->Dump(file, label);
    fclose(file);
}
void PowerTest()
{
XTensor input;
XTensor output;
InitTensor2D(&input, 256, 10000, X_FLOAT, 0);
InitTensor2D(&output, 256, 10000, X_FLOAT, 0);
myRead(&input, "1.txt", "");
_Power(&input, &output, 2);
output.Dump(stderr, "", 200);
}
void SmallTest()
{
XTensor a;
XTensor b;
XTensor c;
XTensor d;
InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 2);
a.SetZeroAll();
b.SetZeroAll();
a.Set2D(1.0F, 0, 0);
a.Set2D(2.0F, 1, 1);
b = Sum(a, Multiply(a, a));
/* this is prohibited !!!!!!!!!!!!! */
//XTensor c = a * b + a;
//XTensor d = a + b + c.Lin(0.5F);
c = a * b + a;
d = a + b + c.Lin(0.5F);
XLink::CheckNetwork(&d);
//XLink::ShowNetwork(stderr, &d);
a.Dump(stderr, "a:");
b.Dump(stderr, "b:");
c.Dump(stderr, "c:");
d.Dump(stderr, "d:");
}
void TransposeTest()
{
XTensor a;
XTensor b;
int I = 2;
int J = 3;
InitTensor4D(&a, 2, 3, 4, 5);
int * dims = new int[a.order];
memcpy(dims, a.dimSize, sizeof(int) * a.order);
dims[I] = a.dimSize[J];
dims[J] = a.dimSize[I];
InitTensor(&b, 4, dims);
a.SetZeroAll();
b.SetZeroAll();
float * data = new float[a.unitNum];
for(int i = 0; i < a.unitNum; i++)
data[i] = (float)i;
a.SetData(data, a.unitNum, 0);
_Transpose(&a, &b, I, J);
b.Dump(stderr, "b:");
delete[] data;
}
void LittleTest()
{
int a = 5000;
int b = 100000;
int c = a*b;
printf("%d\n", c);
exit(1);
}
void T2TTest()
{
XTensor * input;
XTensor * weight;
XTensor * output;
XTensor * gold;
XTensor * dedy;
XTensor * dedx;
XTensor * dedxTmp;
XTensor * dedw;
XTensor * padding;
DTYPE loss;
int * dimSize = new int[2];
dimSize[0] = 256;
dimSize[1] = 10001;
int * dimSize2 = new int[3];
dimSize2[0] = 2;
dimSize2[1] = 31;
dimSize2[2] = 256;
int * dimSize3 = new int[3];
dimSize3[0] = 2;
dimSize3[1] = 31;
dimSize3[2] = 10001;
int * dimSize4 = new int[2];
dimSize4[0] = 2;
dimSize4[1] = 31;
input = NewTensor(3, dimSize2, X_FLOAT, 1.0F, 0);
weight = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
dedw = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
gold = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
output = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
dedy = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
dedx = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
dedxTmp = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
padding = NewTensor(2, dimSize4, X_FLOAT, 1.0F, 0);
//weight = NewTensor(2, dimSize);
//dedw = NewTensor(2, dimSize);
//input = NewTensor(3, dimSize2);
//gold = NewTensor(3, dimSize3);
//output = NewTensor(3, dimSize3);
//dedy = NewTensor(3, dimSize3);
//dedx = NewTensor(3, dimSize3);
//dedxTmp = NewTensor(3, dimSize3);
//padding = NewTensor(2, dimSize4);
myRead(input, "x.txt", "x");
myRead(weight, "w.txt", "w");
myRead(gold, "gold.txt", "gold");
myRead(padding, "padding.txt", "padding");
XTensor inter;
inter = MMul(*input, *weight);
_Softmax(&inter, output, 2);
//_LogMe(output);
loss = _CrossEntropyFast(output, gold, REDUCE_MEAN, NULL, padding);
printf("loss: %f\n", loss);
_CrossEntropyBackward(dedy, output, gold, NULL);
//_CrossEntropyBackward(dedy, output, gold, NULL, padding);
myDump(dedy, "dedy.txt", "dedy");
_SoftmaxBackward(NULL, output, input, dedy, dedx, NULL, -1, NOLOSS);
_Sub(output, gold, dedxTmp);
myDump(dedx, "dedx.txt", "dedx");
dedx->Dump(stderr, "dedx", 200);
dedxTmp->Dump(stderr, "dedxTmp", 200);
input->Reshape(input->unitNum/input->GetDim(-1), input->GetDim(-1));
dedx->Reshape(dedx->unitNum/dedx->GetDim(-1), dedx->GetDim(-1));
_MatrixMulBatched(input, X_TRANS, dedx, X_NOTRANS, dedw);
myDump(dedw, "dedw.txt", "dedw");
}
void T2TTest2()
{
int dimSize[3];
dimSize[0] = 161;
dimSize[1] = 47;
dimSize[2] = 10001;
XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, 0);
//XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, -1);
//myRead(probs, "probs.txt", " ");
_SetDataFixedFloat(probs, 1.0F);
probs->Reshape(1, probs->unitNum);
DTYPE sum = _ReduceSumAll(probs);
printf("%e\n", sum);
//XTensor tmp;
//tmp = IsNonZero(*probs);
//DTYPE nonZeroNum = ReduceSumAll(tmp);
//printf("%f\n", nonZeroNum);
//
//DTYPE gpu = ReduceSum(*probs, 1).Get2D(0, 0);
//printf("%e\n", gpu);
}
......@@ -60,7 +60,7 @@ TENSOR_DATA_TYPE GetDataType(const char * typeName)
}
}
/*
Below is for calling CPU BLAS for fast matrix operations
I'm not sure how fast it is. But it seems that other
guys are crazy about this. So I decided to have a try.
......@@ -81,35 +81,4 @@ _XINLINE_ float Float16ToFloat(unsigned short h)
return f;
}
/*
data type conversion
>> devID - device id
>> s - source data array
>> typeS - source data type
>> t - target data array
>> typeT - target data type
>> size - number of the items in s (and t)
*/
void ConvertDataType(int devID, void * s, TENSOR_DATA_TYPE typeS, void * t, TENSOR_DATA_TYPE typeT, int size)
{
CheckNTErrors((devID < 0), "This code must be run on CPUs!");
if(typeS == typeT)
return;
if(typeS == X_FLOAT && typeT == X_FLOAT16){
for(int i = 0; i < size; i++){
((unsigned short*)t)[i] = FloatToFloat16(((float*)s)[i]);
}
}
else if(typeS == X_FLOAT16 && typeT == X_FLOAT){
for(int i = 0; i < size; i++){
((float*)t)[i] = Float16ToFloat(((unsigned short*)s)[i]);
}
}
else{
ShowNTErrors("Unsupported data types for conversion!");
}
}
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -49,15 +49,6 @@ extern TENSOR_DATA_TYPE GetDataType(const char * typeName);
/* data conversion (for lower precision computation) */
unsigned short FloatToFloat16(float f);
float Float16ToFloat(unsigned short h);
void ConvertDataType(int devID,
void * s, TENSOR_DATA_TYPE typeS,
void * t, TENSOR_DATA_TYPE typeT, int size);
#ifdef USE_CUDA
void CudaConvertDataType(int devID,
void * s, TENSOR_DATA_TYPE typeS,
void * t, TENSOR_DATA_TYPE typeT, int size);
#endif
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -201,7 +201,8 @@ void XDevice::SetGPUDevice(int devID)
cudaError_t error = cudaSetDevice(devID);
if (error != cudaSuccess){
fprintf(stderr, "Error! Calling cudaSetDevice(%d) fails(%d:%s)\n", devID, error, cudaGetErrorString(error));
fprintf(stderr, "Error! Calling cudaSetDevice(%d) fails(%d:%s)\n",
devID, error, cudaGetErrorString(error));
exit(1);
}
#else
......@@ -216,7 +217,7 @@ void XDevice::SetGPUDeviceFast(int devID)
SetFastFlags();
}
/* get the id of the current GPU device */
int XDevice::GetGPUDevice()
{
#ifdef USE_CUDA
......@@ -224,7 +225,8 @@ int XDevice::GetGPUDevice()
cudaError_t error = cudaGetDevice(&devID);
if (error != cudaSuccess){
fprintf(stderr, "Error! Calling cudaGetDevice(%d) fails(%d:%s)\n", devID, error, cudaGetErrorString(error));
fprintf(stderr, "Error! Calling cudaGetDevice(%d) fails(%d:%s)\n",
devID, error, cudaGetErrorString(error));
exit(1);
}
......@@ -248,7 +250,7 @@ void XDevice::SetFastFlags()
#endif
}
/* reset the cuda flag for more efficient cuda execution (all devices) */
void XDevice::SetFastFlagsAllDevices()
{
#ifdef USE_CUDA
......@@ -274,7 +276,7 @@ XDevManager::~XDevManager()
}
/* initialization */
void XDevManager::Init()
{
srand((unsigned int)time(NULL));
......@@ -318,7 +320,7 @@ void XDevManager::Clear()
#ifdef USE_CUDA
/* get the handle of a given GPU */
cublasHandle_t * XDevManager::GetCudaHandle(const int devID)
{
CheckNTErrors(devID < nGPU, "index of GPU is out of range.");
......@@ -326,7 +328,7 @@ cublasHandle_t * XDevManager::GetCudaHandle(const int devID)
return GPUs[devID].GetCublasHandle();
}
/* get the stream of a given GPU */
cudaStream_t * XDevManager::GetCudaStream(const int devID)
{
CheckNTErrors(devID < nGPU, "index of GPU is out of range.");
......@@ -523,12 +525,12 @@ get device ids for the given device information
devInfo = "0:CPU-1 1:GPU-0 2:CPU-1"
means that the first device is CPU, the second device
is GPU-0, the third device is CPU.
>> devIDs - device IDs specified by devInfo
<< return - number of devices
*/
int XDevManager::GetDeviceIDs(char * devInfo, int * devIDs)
{
StrList* terms = new StrList(1);
SplitALine(devInfo, " ", terms);
for(int i = 0; i < terms->count; i++){
......@@ -565,7 +567,7 @@ int XDevManager::GetDeviceIDs(char * devInfo, int * devIDs)
return devCount;
}
/* show device IDs */
void XDevManager::ShowDeviceIDs(char * devInfo, char * msg)
{
msg[0] = 0;
......@@ -51,7 +51,13 @@ bool CONST_TRUE = true;
int verboseLevel = 0;
bool useBLAS = false;
#ifdef USE_CUDA
bool useCUDA = true;
#else
bool useCUDA = false;
#endif
FILE * tmpLog = NULL;
double myTime = 0;
......@@ -78,7 +78,7 @@ namespace nts {
if(!(x)) \
{ \
fprintf(stderr, "[ERROR] calling '%s' (%s line %d): %s\n", #x, __FILENAME__, __LINE__, msg); \
throw; \
} \
} \
......@@ -87,7 +87,7 @@ namespace nts {
if(!(x)) \
{ \
fprintf(stderr, "[ERROR] calling '%s' (%s line %d): %s\n", #x, __FILENAME__, __LINE__); \
throw; \
} \
} \
......@@ -95,7 +95,7 @@ namespace nts {
{ \
{ \
fprintf(stderr, "[ERROR] (%s line %d): %s\n", __FILENAME__, __LINE__, msg); \
throw; \
} \
} \
......@@ -300,6 +300,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id
if(h == NULL)
return;
if (!t1->enableGrad)
return;
TensorList list(2);
list.Add((XTensor*)t1);
list.Add((XTensor*)t2);
......@@ -320,6 +323,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3,
if (h == NULL)
return;
if (!t1->enableGrad || !t2->enableGrad)
return;
TensorList list(3);
list.Add((XTensor*)t1);
list.Add((XTensor*)t2);
......@@ -370,6 +376,9 @@ create a hyper edge with a input tensors and a list of output tensors
*/
void XLink::MakeLink(XTensor * t, TensorList * list, int id)
{
if (!t->enableGrad)
return;
/* forward */
for(int i = 0; i < list->count; i++){
XTensor * h = (XTensor*)list->GetItem(i);
......@@ -23,15 +23,11 @@
*
*/
#include "XList.h"
#include "time.h"
#include "XMem.h"
#include "XList.h"
#include "XGlobal.h"
#include <ctime>
#include <utility>
#include <algorithm>
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
......@@ -78,7 +74,8 @@ TensorListBase<T>::TensorListBase(int myMaxNum, XMem* myMem)
template <typename T>
TensorListBase<T>::~TensorListBase()
{
if(items && mem)
delete[] items;
}
......@@ -101,7 +98,13 @@ void TensorListBase<T>::Add(T&& item)
maxNum = maxNum * 2 + 1;
}
items[count++] = item;
}
/* return number of elements */
template<typename T>
size_t TensorListBase<T>::Size()
{
return count;
}
/*
......@@ -131,7 +134,7 @@ add a number of items into the list
>> inputItemCount - number of input items
*/
template <typename T>
void TensorListBase<T>::Add(T* inputItems, int inputItemCount)
void TensorListBase<T>::Add(const T* inputItems, int inputItemCount)
{
if (count + inputItemCount >= maxNum) {
int newMaxNum = (count + inputItemCount) * 2 + 1;
......@@ -207,10 +210,10 @@ void TensorListBase<T>::Insert(int pos, T&& item)
template <typename T>
T& TensorListBase<T>::GetItem(int i) const
{
CheckNTErrors(i >= -1 && i < count, "Index of a list item is out of scope!");
CheckNTErrors(i >= -count && i < count, "Index of a list item is out of scope!");
CheckNTErrors(count > 0, "Cannt index the item in an empty list!");
if (i == -1)
return items[count - 1];
if (i < 0)
return items[count + i];
else
return items[i];
}
......@@ -227,7 +230,7 @@ template<typename T>
inline void TensorListBase<T>::SetItem(int i, T&& item)
{
if (i >= 0 && i < count)
items[i] = item;
}
/*
......@@ -246,6 +249,26 @@ inline int TensorListBase<T>::FindFirst(const T& item)
return -1;
}
template <>
inline int TensorListBase<Example>::FindFirst(const Example& item)
{
for (int i = 0; i < count; i++) {
if (item.id == items[i].id)
return i;
}
return -1;
}
template <>
inline int TensorListBase<Result>::FindFirst(const Result& item)
{
for (int i = 0; i < count; i++) {
if (item.id == items[i].id)
return i;
}
return -1;
}
/* clear the data array */
template <typename T>
void TensorListBase<T>::Clear()
......@@ -295,6 +318,17 @@ void TensorListBase<T>::Remove(int i)
count--;
}
template<typename T>
void TensorListBase<T>::Reserve(int n)
{
if (items) {
/* reserve failed */
return;
}
items = new T[n];
}
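/*
note that Reserve only takes effect on a list that does not yet own
storage: if `items` is already allocated it is a silent no-op, so it
should be called right after construction, as DataSet::Init does with
offset.Reserve(exampleNumber)
*/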
/*
copy the list
>> myMem - memory pool used for allocating the data in the new list
......@@ -349,6 +383,8 @@ template struct TensorListBase<long>;
template struct TensorListBase<float>;
template struct TensorListBase<short>;
template struct TensorListBase<XTensor*>;
template struct TensorListBase<Result>;
template struct TensorListBase<Example>;
template struct TensorListBase<void*>;
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
......@@ -66,11 +66,14 @@ public:
/* add an item into the list */
void Add(T&& item);
/* return number of elements */
size_t Size();
/* add an item into the list */
void Add(const T& item);
/* add a number of items into the list */
void Add(const T* inputItems, int inputItemCount);
/* append a list to the current list */
void AddList(TensorListBase* l);
......@@ -105,6 +108,9 @@ public:
/* remove the item at position i */
void Remove(int i);
/* reserve space for data entry */
void Reserve(int n);
/* copy the list */
TensorListBase* Copy(XMem* myMem);
......@@ -112,22 +118,33 @@ public:
void Shuffle(int nround = 10, int beg = -1, int len = 0);
/* short */
T& operator[] (int i) {
return GetItem(i);
};
T& operator[] (int i) { return GetItem(i); };
T& Get(int i) { return GetItem(i); };
void Set(int i, T item) { SetItem(i, item); };
};
struct XTensor;
typedef TensorListBase<void*> XList;
typedef TensorListBase<int> IntList;
typedef TensorListBase<char> CharList;
typedef TensorListBase<char*> StrList;
typedef TensorListBase<long> LongList;
typedef TensorListBase<float> FloatList;
typedef TensorListBase<short> ShortList;
typedef TensorListBase<void*> XList;
struct Example {
int id;
IntList data;
};
struct Result {
int id;
IntList data;
};
typedef TensorListBase<Result> ResultList;
typedef TensorListBase<Example> ExampleList;
typedef TensorListBase<XTensor*> TensorList;
} /* end of the nts (NiuTrans.Tensor) namespace */
......@@ -53,6 +53,7 @@ XMem::XMem()
strcpy(name, "xmem");
signature = 0;
mergeFreeOTF = true;
isInitialized = false;
}
/*
......@@ -63,7 +64,7 @@ constructor
>> myMode - mode of running the memory pool
UNI_FREE: free all the space at the end of using the memory pool
FREE_ON_THE_FLY: normal "malloc" and "free" mode
>> myBlockSize - size of a memory block
>> myBlockNum - number of memory blocks
>> myBufSize - size of buffer
*/
......@@ -108,7 +109,7 @@ initialize it
>> myMode - mode of running the memory pool
UNI_FREE: free all the space at the end of using the memory pool
FREE_ON_THE_FLY: normal "malloc" and "free" mode
>> myBlockSize - size of a memory block
>> myBlockNum - number of memory blocks
>> myBufSize - size of buffer
*/
......@@ -169,6 +170,7 @@ void XMem::Initialize(int myDevID, MEMPOOL_MODE myMode, MTYPE myBlockSize, int m
#endif
signature++;
isInitialized = true;
}
/* free memory */
......@@ -221,9 +223,9 @@ void XMem::Free(int myDevID, void * mem)
}
}
/*
get the signature
<< return - the signature
*/
MTYPE XMem::GetSignature()
{
......@@ -231,7 +233,7 @@ MTYPE XMem::GetSignature()
}
/*
set the name of the memory pool
>> myName - name of the memory pool
*/
void XMem::SetName(const char * myName)
......@@ -264,7 +266,7 @@ void XMem::SetDevice(int myDevID)
}
/*
switch to the device (with fast cuda execution mode) we intend to work on
>> myDevID - device id(-1: CPU memory, >=0: GPU device ID)
*/
void XMem::SetDeviceFast(int myDevID)
......@@ -280,7 +282,7 @@ void XMem::SetDeviceFast(int myDevID)
}
/*
run in the static mode
>> myIsStatic - specify if the memory allocation is static
*/
void XMem::SetStaticMode(bool myIsStatic)
......@@ -1508,16 +1510,27 @@ XMemManager::~XMemManager()
MTYPE XMemManager::GetAvailableMemory()
{
unsigned long freeMem = 0;
#if __APPLE__
int mib[2] = {CTL_HW, HW_MEMSIZE};
unsigned int namelen = sizeof(mib) / sizeof(mib[0]);
unsigned long long size;
size_t len = sizeof(size);
if (sysctl(mib, namelen, &size, &len, NULL, 0) < 0){
ShowNTErrors("Cannot get memory size on Mac!");
}
else{
return size;
}
#elif _WIN32
MEMORYSTATUSEX memoryStatus;
memoryStatus.dwLength = sizeof(memoryStatus);
if (GlobalMemoryStatusEx(&memoryStatus)){
freeMem = memoryStatus.ullAvailPhys;
}
#else
long pages = sysconf(_SC_AVPHYS_PAGES);
long page_size = sysconf(_SC_PAGE_SIZE);
freeMem = pages * page_size;
#endif
return (MTYPE)freeMem;
}
......@@ -1526,8 +1539,9 @@ MTYPE XMemManager::GetAvailableMemory()
MTYPE XMemManager::GetAvailableGPUMemory(int devID)
{
size_t freeMem = 0;
#ifdef USE_CUDA
size_t totalMem = 0;
cudaSetDevice(devID);
if (cudaMemGetInfo(&freeMem, &totalMem) != cudaSuccess){
XPRINT(0, stderr, "cannot get GPU memory information.");
......@@ -1567,11 +1581,6 @@ void XMemManager::Initialize()
/* CPUs (we actually do not care about how many CPUs are using) */
nCPUMem = 1;
/* GPUs */
nGPUMem = 0;
......@@ -1580,23 +1589,16 @@ void XMemManager::Initialize()
XPRINT(0, stderr, "cannot get GPU information.");
exit(1);
}
for (int i = 0; i < nGPUMem; i++) {
MTYPE freeMem = GetAvailableGPUMemory(i);
MTYPE myBufSize = 0;
GetBufferSize(freeMem, &myBufSize);
GPUMems[i].Initialize(i, UNI_FREE, MIN_BLOCK_SIZE_FOR_MEMPOOL, MIN_BLOCK_NUM_FOR_MEMPOOL, myBufSize);
}
#endif
}
/* free it */
void XMemManager::Free()
{
for (int i = 0; i < MAX_CPU_NUM; i++)
for (int i = 0; i < MAX_CPU_MEM_NUM; i++)
CPUMems[i].Free();
for (int i = 0; i < MAX_GPU_NUM; i++)
for (int i = 0; i < MAX_GPU_MEM_NUM; i++)
GPUMems[i].Free();
}
......@@ -1604,13 +1606,34 @@ void XMemManager::Free()
XMem * XMemManager::GetMem(const int devID)
{
XMem * mem = NULL;
if (devID < 0)
if (devID < 0){
if(!CPUMems[0].isInitialized){
MTYPE freeMem = GetAvailableMemory();
MTYPE myBufSize = 0;
GetBufferSize(freeMem, &myBufSize);
CPUMems[0].Initialize(-1, FREE_ON_THE_FLY,
MIN_BLOCK_SIZE_FOR_MEMPOOL,
MIN_BLOCK_NUM_FOR_MEMPOOL,
myBufSize);
}
mem = CPUMems;
}
else{
if (devID < nGPUMem)
if (devID < nGPUMem){
if(!GPUMems[devID].isInitialized){
MTYPE freeMem = GetAvailableGPUMemory(devID);
MTYPE myBufSize = 0;
GetBufferSize(freeMem, &myBufSize);
GPUMems[devID].Initialize(devID, FREE_ON_THE_FLY,
MIN_BLOCK_SIZE_FOR_MEMPOOL,
MIN_BLOCK_NUM_FOR_MEMPOOL,
myBufSize);
}
mem = GPUMems + devID;
}
else{
XPRINT1(0, stderr, "Cannot get the memory (%d). Please check your device id!", devID);
}
}
return mem;
......@@ -1638,12 +1661,12 @@ void XMemManager::ShowMemInfo()
int myBlockNum;
for(int i = 0; i < nCPUMem; i++){
GetMemSize(-1, &myBlockSize, &myBlockNum, &myBufSize);
XPRINT3(1, stderr, " - id:-1 CPU, blockSize:%d, blockNum:%d, bufSize:%d\n", myBlockSize, myBlockNum, myBufSize);
XPRINT3(1, stderr, " - id:-1 CPU, blockSize:%lld, blockNum:%d, bufSize:%lld\n", myBlockSize, myBlockNum, myBufSize);
}
for(int i = 0; i < nGPUMem; i++){
GetMemSize(i, &myBlockSize, &myBlockNum, &myBufSize);
XPRINT4(1, stderr, " - id:%2d GPU, blockSize:%d, blockNum:%d, bufSize:%d\n", i, myBlockSize, myBlockNum, myBufSize);
XPRINT4(1, stderr, " - id:%2d GPU, blockSize:%lld, blockNum:%d, bufSize:%lld\n", i, myBlockSize, myBlockNum, myBufSize);
}
}
......@@ -39,10 +39,13 @@
#include <curand.h>
#endif
#ifdef __APPLE__
#include <sys/types.h>
#include <sys/sysctl.h>
#elif WIN32
#include <windows.h>
#else
#include <unistd.h>
#endif
/* the nts (NiuTrans.Tensor) namespace */
......@@ -57,10 +60,10 @@ typedef long long INT_64;
#define CUDA_HOST_MALLOC 1
#define MY_PITCH CUDA_PITCH
#define BUF_PITCH 256
#define MIN_BLOCK_SIZE_FOR_MEMPOOL 256 * 1024 * 1024
#define MIN_BLOCK_NUM_FOR_MEMPOOL 1024
#define MAX_CPU_MEM_NUM 16
#define MAX_GPU_MEM_NUM 16
/*
mode of running a memory pool
......@@ -210,6 +213,9 @@ public:
MTYPE curUsedPin;
MTYPE bufUsedPin;
/* indicates whether the memory pool is initialized */
bool isInitialized;
#ifdef USE_CUDA
/* handle used for cublas */
cublasHandle_t cublasHandle;
......@@ -426,15 +432,15 @@ a class for the management of memory
*/
class XMemManager
{
private:
/* cpu memory pool information */
XMem CPUMems[MAX_CPU_MEM_NUM];
/* number of cpu memory pools */
int nCPUMem;
/* gpu memory pool information */
XMem GPUMems[MAX_GPU_MEM_NUM];
/* number of gpu memory pools */
int nGPUMem;
......@@ -59,6 +59,8 @@ const char * GetOPName(int type)
return "M_DIV";
else if (type == MATH_DIVDIM)
return "M_DIVDIM";
else if (type == MATH_MASK)
return "M_MASK";
else if (type == MATH_MATRIXMUL)
return "M_MATRIXMUL";
else if (type == MATH_MATRIXMULBATCHED)
......@@ -48,7 +48,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_CLIP MATH_ROUND + 1
#define MATH_DIV MATH_CLIP + 1
#define MATH_DIVDIM MATH_DIV + 1
#define MATH_MASK MATH_DIVDIM + 1
#define MATH_MATRIXMUL MATH_MASK + 1
#define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
#define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1
#define MATH_MULTIPLYDIM MATH_MULTIPLY + 1
......@@ -79,7 +80,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* data and shape related operations */
#define DATA_BASE MATH_BASE * 2
#define GETANDSET DATA_BASE + 1
#define GETANDSET_CONVERTDATATYPE GETANDSET + 1
#define GETANDSET_SELECT GETANDSET_CONVERTDATATYPE + 1
#define MOVEMENT GETANDSET_SELECT + 1
#define MOVEMENT_COPYINDEXED MOVEMENT + 1
......@@ -238,6 +238,9 @@ public:
/* overloading of the minus-sign */
XTensor operator- (const DTYPE shift) const;
/* overloading of the minus-sign */
XTensor operator- () const;
/* overloading of the division-sign */
XTensor operator/ (const XTensor &tensor) const;
......@@ -301,6 +304,9 @@ public:
/* set the tensor with an data array */
void SetData(const void * d, int num, int beg = 0);
/* generate data items with a uniform distribution in [0, 1] */
void Rand(int rNum, int cNum);
/* set tensor items by a uniform distribution */
void SetDataRand(DTYPE lower = 0.0F, DTYPE upper = 1.0F);
......@@ -424,9 +430,15 @@ public:
static
void Dump(const XTensor * tensor, FILE * file, const char * label = NULL, const int n = -1, const int beg = 0, const int verbose = 0);
/* dump data to a binary file */
void BinaryDump(FILE * file);
/* read data from a file */
void Read(FILE * file, const char * label = NULL);
/* read data from a binary file */
void BinaryRead(FILE * file, size_t offset);
/* flush the data to the target device */
void FlushToMem(XMem * targetMem);
......@@ -497,7 +509,7 @@ void InitTensor5D(XTensor * tensor, const int d0, const int d1, const int d2, co
/* initialize a dense 5d tensor V2 */
void InitTensor5DV2(XTensor * tensor, const int d0, const int d1, const int d2, const int d3, const int d4,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1);
/* initialize a tensor with a reference tensor */
void InitTensor(XTensor * tensor, const XTensor * reference);
......@@ -36,13 +36,9 @@
#include "arithmetic/MatrixMulBatched.h"
#include "arithmetic/Multiply.h"
#include "arithmetic/MultiplyDim.h"
#include "arithmetic/Negate.h"
#include "arithmetic/Sign.h"
#include "arithmetic/Sub.h"
#include "arithmetic/SubDim.h"
#include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumDim.h"
#include "arithmetic/XTensorBLAS.h"
#include "arithmetic/MulAndShift.h"
......@@ -56,7 +52,6 @@
#include "math/Clip.h"
#include "math/Compare.h"
#include "math/Normalize.h"
#include "math/Power.h"
#include "math/ScaleAndShift.h"
#include "math/Unary.h"
......@@ -97,5 +92,4 @@
#include "utilities/XMatrixSegment.h"
#include "utilities/FlushToMem.h"
#include "../function/DropoutWithIndex.h"
#endif // __CHEADER_H__
......@@ -195,7 +195,6 @@ void DivDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE alpha)
if (c.enableGrad == true) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
......@@ -151,16 +151,35 @@ XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha)
XTensor c(&a);
c.SetTMPFlag();
/* call _Mask function */
_Mask(&a, &mask, &c, alpha);
/* tensor connections */
//XLink::MakeLink(&a, &mask, &c, MATH_SUM);
//XLink::AddParamToHead(&c, alpha);
// TODO!!
ShowNTErrors("TODO!");
XLink::MakeLink(&a, &mask, &c, MATH_MASK);
XLink::AddParamToHead(&c, alpha);
return c;
}
/*
mask entries of a given tensor (return an XTensor structure):
a(i) = a(i) if mask(i) is non-zero
a(i) = alpha if mask(i) = 0
where i is the index of the element
*/
void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha)
{
if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
InitTensor(&c, &a);
}
/* call _Mask function */
_Mask(&a, &mask, &c, alpha);
if (c.enableGrad) {
XLink::MakeLink(&a, &mask, &c, MATH_MASK);
XLink::AddParamToHead(&c, alpha);
}
}
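/*
usage sketch (editor's illustration, with hypothetical tensors): replace
the entries at padded positions, i.e. where the mask is 0, by a large
negative constant before a softmax
*/
void MaskSketch(const XTensor& logits, const XTensor& padMask, XTensor& masked)
{
    Mask(logits, padMask, masked, (DTYPE)-1e9F);
}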
}
\ No newline at end of file
......@@ -34,7 +34,7 @@ c(i) = a(i) if mask(i) is non-zero
c(i) = alpha if mask(i) = 0
where i is the index of the element
*/
void _Mask(const XTensor * a, const XTensor * mask, XTensor * c, DTYPE alpha = 0.0);
/*
mask entries of a given tensor (on site):
......@@ -42,10 +42,10 @@ a(i) = a(i) if mask(i) is non-zero
a(i) = alpha if mask(i) = 0
where i is the index of the element
*/
void _MaskMe(XTensor * a, const XTensor * mask, DTYPE alpha = 0.0);
void MaskMe(XTensor & a, const XTensor & mask, DTYPE alpha = 0.0);
/*
mask entries of a given tensor (return an XTensor structure):
a(i) = a(i) if mask(i) is non-zero
a(i) = alpha if mask(i) = 0
......@@ -53,6 +53,14 @@ where i is the index of the element
*/
XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha = 0.0);
/*
mask entries of a given tensor (return an XTensor structure):
a(i) = a(i) if mask(i) is non-zero
a(i) = alpha if mask(i) = 0
where i is the index of the element
*/
void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha = 0.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __MASK_H__
......@@ -202,7 +202,9 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
delete cList;
}
bool CheckMMulShape(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c)
{
if (!(a && b && c))
return false;
......@@ -231,10 +233,13 @@ bool CheckMMulShape(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTen
dimSize[sub++] = bm;
for (int i = 0; i < order; i++) {
if (dimSize[i] != c->dimSize[i]) {
delete[] dimSize;
return false;
}
}
delete[] dimSize;
return true;
}
......@@ -303,8 +308,8 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
}
void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
const XTensor &b, MATRIX_TRANS_TYPE transposedB, XTensor &c,
DTYPE alpha, XPRunner * parallelRunner)
const XTensor &b, MATRIX_TRANS_TYPE transposedB, XTensor &c,
DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
......@@ -337,7 +342,7 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
}
/* call _MatrixMul function */
_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, beta, parallelRunner);
if (c.enableGrad) {
/* tensor connections */
......@@ -400,7 +405,7 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
}
void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
DTYPE alpha, XPRunner * parallelRunner)
{
CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
......@@ -40,8 +40,11 @@ bj is the j-th element tensor of B, and c_{i,j} is the (i,j) elementtensor of th
C should be a tensor of z * x * n * m.
Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
*/
void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0,
XPRunner * parallelRunner = NULL);
/*
matrix multiplication (return an XTensor structure) c = trans(a) * trans(b) * alpha
......@@ -56,19 +59,23 @@ bj is the j-th element tensor of B, and c_{i,j} is the (i,j) elementtensor of th
C should be a tensor of z * x * n * m.
Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
*/
XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha = (DTYPE)1.0,
XPRunner * parallelRunner = NULL);
void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
XTensor &c, DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
const XTensor &b, MATRIX_TRANS_TYPE transposedB,
XTensor &c,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0,
XPRunner * parallelRunner = NULL);
/* matrix multiplication with no transposition c = a * b * alpha*/
XTensor MatrixMul(const XTensor &a, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
......@@ -154,7 +154,7 @@ void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors(a && b && c, "Empty input tensors!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Input tensors should have the same data type!");
CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
......@@ -132,6 +132,78 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
}
/*
operation c = trans(x) * trans(w) + b (MulAndShift)
>> x - tensor x
>> transposedA - indicates whether x is transposed
>> w - tensor w
>> transposedB - indicates whether w is transposed
>> b - tensor b
>> alpha - the scaling factor
>> parallelRunner - parallel processing module
<< return - the result of the matrix multiplication
*/
XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA,
const XTensor& w, MATRIX_TRANS_TYPE transposedB,
const XTensor& b, DTYPE alpha, XPRunner* parallelRunner)
{
CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have a order >= 2!");
int xn = transposedA == X_TRANS ? x.dimSizeRDI[0] : x.dimSizeRDI[1];
int xm = transposedA == X_TRANS ? x.dimSizeRDI[1] : x.dimSizeRDI[0];
int wn = transposedB == X_TRANS ? w.dimSizeRDI[0] : w.dimSizeRDI[1];
int wm = transposedB == X_TRANS ? w.dimSizeRDI[1] : w.dimSizeRDI[0];
int order = x.order + w.order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < x.order; i++)
dimSize[sub++] = x.dimSizeRDI[x.order + 1 - i];
for (int i = 2; i < w.order; i++)
dimSize[sub++] = w.dimSizeRDI[w.order + 1 - i];
dimSize[sub++] = xn;
dimSize[sub++] = wm;
float dr = (!x.isSparse || !w.isSparse) ? 1.0F : MAX(x.denseRatio, w.denseRatio);
XTensor * tmp = NewTensorBuf(order, dimSize, x.dataType, dr, x.devID, x.mem);
/* call _MatrixMul function */
_MatrixMul(&x, transposedA, &w, transposedB, tmp, alpha, 0, parallelRunner);
XTensor c(tmp);
c.SetTMPFlag();
int n = GetSumIndex(tmp, b);
if (n == -1) {
/* call _Sum function */
_Sum(tmp, &b, &c);
// TODO!!
ShowNTErrors("TODO!");
}
else if (n >= 0 && n < tmp->order) {
/* call _SumDim function */
_SumDim(tmp, &b, &c, n);
}
else {
ShowNTErrors("Something is wrong!");
}
/* tensor connections */
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
//XLink::AddParamToHead(&c, beta);
/* destroy variables */
delete[] dimSize;
DelTensorBuf(tmp);
return c;
}
}
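/*
usage sketch (editor's illustration, with hypothetical tensors): a linear
layer y = x * w^T + b expressed with the new transposed variant
*/
XTensor LinearSketch(const XTensor& x, const XTensor& w, const XTensor& b)
{
    return MulAndShift(x, X_NOTRANS, w, X_TRANS, b);
}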
\ No newline at end of file
......@@ -29,8 +29,11 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
XTensor MulAndShift(const XTensor &x, MATRIX_TRANS_TYPE transposedA,
const XTensor &w, MATRIX_TRANS_TYPE transposedB,
const XTensor &b, DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
......@@ -123,9 +123,9 @@ where i is the item index
void _CudaMultiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
int leadingDimRDI = a->order - leadingDim - 1;
CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
CheckNTErrors(a->unitNum <= c->unitNum && b->unitNum <= c->unitNum,
"Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!");
CheckNTErrors(a->order == b->order && a->order == c->order, "Unmatched tensors!");
int stride = 1;
int blockSizeA = 1;
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "Negate.h"
#include "Negate.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
set every entry to its minus value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void _Negate(const XTensor * a, XTensor * b)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
_CudaNegate(a, b);
return;
}
#endif
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * d = (DTYPE*)a->data;
DTYPE * db = (DTYPE*)b->data;
for (int i = 0; i < a->unitNum; i++)
db[i] = -d[i];
}
/*
set every entry to its minus value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void _NegateMe(XTensor * a)
{
_Negate(a, a);
}
/*
set every entry to its minus value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void NegateMe(XTensor& a)
{
_Negate(&a, &a);
}
/*
set every entry to its minus value (return an XTensor structure)
make a new tensor to keep the result and return it
>> a - input tensor we are processing
<< return - the minus value of input tensor
*/
XTensor Negate(const XTensor & a)
{
XTensor b(&a);
b.SetTMPFlag();
/* call _Negate function */
_Negate(&a, &b);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_NEGATE);
return b;
}
/*
set every entry to its minus value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void Negate(const XTensor & a, XTensor & b)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _Negate function */
_Negate(&a, &b);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_NEGATE);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
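/* editor's sketch (not part of the commit): the three entry points defined
above; tensor size and names are illustrative */
void NegateUsageSketch()
{
    int dims[2] = {2, 3};
    XTensor* a = NewTensorV2(2, dims, X_FLOAT, -1);
    a->SetZeroAll();
    XTensor b = Negate(*a);  /* returns a new tensor holding -a */
    Negate(*a, b);           /* writes -a into an existing tensor */
    NegateMe(*a);            /* negates a in place */
    DelTensor(a);
}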
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Negate.h"
#include "Negate.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
set each entry to its negative value (CUDA Kernel)
>> a - pointer to the input data array
>> b - pointer to the output data array
>> size - size of the data array
*/
__global__
void KernelNegate(DTYPE * a, DTYPE * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
b[i] = -a[i];
}
/*
set each entry to its negative value (CUDA Kernel)
This is for float16 computation
>> a - pointer to the input data array
>> b - pointer to the output data array
>> size - size of the data array
*/
__global__
void KernelNegate(__half * a, __half * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
if (i < size)
b[i] = __hsub(__float2half(0), a[i]);
#else
if (i < size)
b[i] = __float2half(-__half2float(a[i]));
#endif
}
/*
set each entry to its negative value
>> a - input tensor
>> b - output tensor
*/
void _CudaNegate(const XTensor * a, XTensor * b)
{
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same shape!");
CheckNTErrors((a->isSparse == false), "TODO!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
KernelNegate<<<blocks, threads>>>((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum);
}
else if (a->dataType == X_FLOAT16) {
KernelNegate<<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum);
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
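/* editor's note (not part of the commit): GDevs.GetCudaThread above derives a
1-D launch shape covering a->unitNum elements; conceptually it is the usual
ceil-divide pattern, sketched here under an illustrative name */
int BlocksForSize(int size, int threadsPerBlock)
{
    /* one thread per element, rounded up to whole blocks */
    return (size + threadsPerBlock - 1) / threadsPerBlock;
}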
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __NEGATE_CUH__
#define __NEGATE_CUH__
#include "Negate.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its negative value (CUDA Kernel) */
__global__
void KernelNegate(DTYPE * a, DTYPE * b, int size);
/* set each entry to its negative value (CUDA Kernel) with float16 data type */
__global__
void KernelNegate(__half * a, __half * b, int size);
/* set each entry to its negative value */
void _CudaNegate(const XTensor * a, XTensor * b);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __NEGATE_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __NEGATE_H__
#define __NEGATE_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its minus value */
void _Negate(const XTensor * a, XTensor * b);
/*
set every entry to its minus value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _NegateMe(XTensor * a);
void NegateMe(XTensor & a);
/*
set every entry to its minus value (return an XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Negate(const XTensor & a);
/* set every entry to its minus value */
void Negate(const XTensor & a, XTensor & b);
} // namespace nts(NiuTrans.Tensor)
#endif // __NEGATE_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "Sign.h"
#include "Sign.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
set every entry to its sign value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void _Sign(const XTensor * a, XTensor * b)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
_CudaSign(a, b);
return;
}
#endif
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same shape!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * d = (DTYPE*)a->data;
DTYPE * db = (DTYPE*)b->data;
for (int i = 0; i < a->unitNum; i++) {
if (d[i] > 0)
db[i] = 1.0F;
else if (d[i] == 0)
db[i] = 0.0F;
else
db[i] = -1.0F;
}
}
/*
set every entry to its sign value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void _SignMe(XTensor * a)
{
_Sign(a, a);
}
/*
set every entry to its sign value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void SignMe(XTensor& a)
{
_Sign(&a, &a);
}
/*
set every entry to its sign value (return an XTensor structure)
make a new tensor to keep the result and return it
>> a - input tensor we are processing
<< return - the sign value of the input tensor
*/
XTensor Sign(const XTensor & a)
{
XTensor b(&a);
b.SetTMPFlag();
/* call _Sign function */
_Sign(&a, &b);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SIGN);
return b;
}
/*
set every entry to its sign value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void Sign(const XTensor & a, XTensor & b)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _Sign function */
_Sign(&a, &b);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SIGN);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
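/* editor's sketch (not part of the commit): Sign maps every entry to -1, 0 or
+1; tensor size and names are illustrative */
void SignUsageSketch()
{
    int dims[1] = {4};
    XTensor* a = NewTensorV2(1, dims, X_FLOAT, -1);
    a->SetZeroAll();
    XTensor s = Sign(*a);  /* all entries of a are zero, so s is all zeros */
    SignMe(*a);            /* in-place variant */
    DelTensor(a);
}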
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Sign.h"
#include "Sign.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
set each entry to its sign value (CUDA Kernel)
>> a - pointer to input data array
>> b - pointer to output data array
>> size - size of the data array
*/
__global__
void KernelSign(DTYPE * a, DTYPE * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
if (a[i] > 0)
b[i] = 1.0F;
else if (a[i] == 0)
b[i] = 0.0F;
else
b[i] = -1.0F;
}
}
/*
set each entry to its sign value with float16 data type (CUDA Kernel)
This is for float16 computation
>> a - pointer to input data array
>> b - pointer to output data array
>> size - size of the data array
*/
__global__
void KernelSign(__half * a, __half * b, int size)
{
/* editor's note: unimplemented placeholder, so the float16 path leaves b
unchanged; a possible implementation is sketched after this file */
return;
}
/*
set each entry to its sign value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void _CudaSign(const XTensor * a, XTensor * b)
{
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same shape!");
CheckNTErrors((a->isSparse == false), "TODO!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
KernelSign<<<blocks, threads>>>((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum);
}
else if (a->dataType == X_FLOAT16) {
KernelSign<<<blocks, threads>>>((__half*)a->data, (__half*)b->data, a->unitNum);
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
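/* editor's sketch (not part of the commit): one way the empty float16
KernelSign above could be implemented, converting through float so it works on
any architecture; the name is illustrative */
#include <cuda_fp16.h>

__global__
void KernelSignHalfSketch(__half * a, __half * b, int size)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size) {
        float v = __half2float(a[i]);
        b[i] = __float2half(v > 0.0F ? 1.0F : (v < 0.0F ? -1.0F : 0.0F));
    }
}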