Commit 4c0fafc0 by linye

1. update Spread, ReduceSumAll, Multiply, Unsqueeze, Softmax, CrossEntropy with float16 datatype
2. modify the implementation of fnnlm to support float16 computation, but some bugs remain and the loss is NaN
parent c5fc6c59
......@@ -52,9 +52,11 @@ int sentBatch = 0; // batch size at the sentence level
int wordBatch = 1; // batch size at the word level
bool shuffled = false; // whether to shuffle the training data file
bool autoDiff = false; // indicator of automatic differentiation
bool fp16 = false; // indicator of float16 computation
void LoadArgs(int argc, const char ** argv, FNNModel &model);
void Init(FNNModel &model);
void InitFp16(FNNModel &model);
void Check(FNNModel &model);
void Copy(FNNModel &tgt, FNNModel &src);
void Clear(FNNModel &model, bool isNodeGrad);
......@@ -125,7 +127,12 @@ int FNNLMMain(int argc, const char ** argv)
Check(model);
/* initialize model parameters */
Init(model);
if (!fp16) {
Init(model);
}
else {
InitFp16(model);
}
/* learn model parameters */
if(strcmp(trainFN, ""))
......@@ -224,6 +231,10 @@ void LoadArgs(int argc, const char ** argv, FNNModel &model)
autoDiff = true;
fprintf(stderr, " -autodiff=true\n");
}
if (!strcmp(argv[i], "-fp16")) {
fp16 = true;
fprintf(stderr, " -fp16=true\n");
}
if(!strcmp(argv[i], "-dev") && i + 1 < argc){
model.devID = atoi(argv[i + 1]);
fprintf(stderr, " -dev=%d\n", model.devID);
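Float16 mode is then chosen at startup via the new switch; a hypothetical invocation (the binary name and the -train option are placeholders, only -fp16 and -dev are taken from the code above):

fnnlm -train train.txt -dev 0 -fp16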
......@@ -303,6 +314,11 @@ void InitModelTensor1D(XTensor &tensor, int num, FNNModel &model)
InitTensor1DV2(&tensor, num, X_FLOAT, model.devID);
}
void InitModelTensor1DFp16(XTensor &tensor, int num, FNNModel &model)
{
InitTensor1DV2(&tensor, num, X_FLOAT16, model.devID);
}
/*
initialize a 2d tensor using the fnn model setting
>> tensor - the tensor to initialize
......@@ -315,6 +331,10 @@ void InitModelTensor2D(XTensor &tensor, int rowNum, int colNum, FNNModel &model)
InitTensor2DV2(&tensor, rowNum, colNum, X_FLOAT, model.devID);
}
void InitModelTensor2DFp16(XTensor &tensor, int rowNum, int colNum, FNNModel &model)
{
InitTensor2DV2(&tensor, rowNum, colNum, X_FLOAT16, model.devID);
}
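The Fp16 helpers duplicate their fp32 counterparts except for the data-type constant. A possible consolidation (a sketch, not part of this commit; it assumes the TENSOR_DATA_TYPE enum that X_FLOAT and X_FLOAT16 belong to):

/* sketch: one helper parameterized by data type, replacing the per-precision copies */
void InitModelTensor2D(XTensor &tensor, int rowNum, int colNum,
                       FNNModel &model, TENSOR_DATA_TYPE dataType)
{
    InitTensor2DV2(&tensor, rowNum, colNum, dataType, model.devID);
}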
/* initialize the model */
void Init(FNNModel &model)
......@@ -357,6 +377,48 @@ void Init(FNNModel &model)
for(int i = 0; i < model.hDepth; i++)
model.hiddenB[i].SetZeroAll();
}
/* initialize the model with float16 parameters */
void InitFp16(FNNModel &model)
{
/* create embedding parameter matrix: vSize * eSize */
InitModelTensor2DFp16(model.embeddingW, model.vSize, model.eSize, model);
model.embeddingW.SetVarFlag();
/* create hidden layer parameter matrices */
for (int i = 0; i < model.hDepth; i++) {
/* hidden layer parameter matrix: (n-1)eSize * hsize if it is the first layer
hsize * hsize otherwise */
if (i == 0)
InitModelTensor2DFp16(model.hiddenW[i], (model.n - 1) * model.eSize, model.hSize, model);
else
InitModelTensor2DFp16(model.hiddenW[i], model.hSize, model.hSize, model);
model.hiddenW[i].SetVarFlag();
/* bias term: a row vector of hSize entries */
InitModelTensor1DFp16(model.hiddenB[i], model.hSize, model);
model.hiddenB[i].SetVarFlag();
}
/* create the output layer parameter matrix and bias term */
int iSize = model.hDepth == 0 ? (model.n - 1) * model.eSize : model.hSize;
InitModelTensor2DFp16(model.outputW, iSize, model.vSize, model);
InitModelTensor1DFp16(model.outputB, model.vSize, model);
model.outputW.SetVarFlag();
model.outputB.SetVarFlag();
/* then, we initialize model parameters using a uniform distribution in the range [-minmax, minmax] */
_SetDataRand(&model.embeddingW, -minmax, minmax);
_SetDataRand(&model.outputW, -minmax, minmax);
for (int i = 0; i < model.hDepth; i++)
_SetDataRand(&model.hiddenW[i], -minmax, minmax);
/* all bias terms are set to zero */
_SetDataFixed(&model.outputB, 0);
for (int i = 0; i < model.hDepth; i++)
_SetDataFixed(&model.hiddenB[i], 0);
}
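Since the commit message reports a NaN loss, one conservative variant (a sketch of an alternative, not the commit's approach) is to draw the random values in fp32 and round to fp16 once, using only _SetDataRand and _ConvertDataType, both of which appear elsewhere in this commit:

/* sketch: after the fp16 tensors are created, initialize a parameter
   in fp32 and convert once, avoiding an fp16 random-init path */
XTensor w32;
InitModelTensor2D(w32, model.vSize, model.eSize, model);   /* fp32 helper */
_SetDataRand(&w32, -minmax, minmax);
_ConvertDataType(&w32, &model.embeddingW);                 /* fp32 -> fp16 */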
/*
shuffle lines of the file
......@@ -725,6 +787,15 @@ void InitZeroOneTensor2D(XTensor &tensor, int rowNum, int colNum, int * rows, in
tensor.Set2D(1.0F, rows[i], cols[i]);
}
void InitZeroOneTensor2DFp16(XTensor &tensor, int rowNum, int colNum, int * rows, int * cols,
int itemNum, int devID)
{
InitTensor2DV2(&tensor, rowNum, colNum, X_FLOAT16, devID);
/* set every cell to 1 (note: unlike the fp32 version above, the (rows, cols)
pairs are not used here, so the result is all ones rather than one-hot) */
_SetDataFixed(&tensor, 1.0);
}
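A cell-wise alternative for half data would need a half-precision scatter; a minimal sketch (hypothetical kernel, not part of the commit; rows and cols must first be copied to device memory):

__global__
void KernelScatterOnesHalf(__half * data, int colNum,
                           const int * rows, const int * cols, int itemNum)
{
    /* one thread per (row, col) pair: write 1.0 in half precision */
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < itemNum)
        data[rows[i] * colNum + cols[i]] = __float2half(1.0F);
}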
/*
make a tensor that encodes a batch of words
>> batch - the tensor encoding a batch of words
......@@ -744,7 +815,12 @@ void MakeWordBatch(XTensor &batch, NGram * ngrams, int ngramNum, int n, int vSiz
cols[i] = ngrams[i].words[n];
}
InitZeroOneTensor2D(batch, ngramNum, vSize, rows, cols, ngramNum, devID);
if (!fp16) {
InitZeroOneTensor2D(batch, ngramNum, vSize, rows, cols, ngramNum, devID);
}
else {
InitZeroOneTensor2DFp16(batch, ngramNum, vSize, rows, cols, ngramNum, devID);
}
delete[] rows;
delete[] cols;
......@@ -1067,17 +1143,17 @@ void Dump(const char * fn, FNNModel &model)
FILE * file = fopen(fn, "wb");
CheckErrors(file, "Cannot open the model file");
model.embeddingW.Dump(file, "embedding w:");
model.embeddingW.Dump(&model.embeddingW, file, "embedding w:");
for (int i = 0; i < model.hDepth; i++) {
char name[MAX_NAME_LENGTH];
sprintf(name, "hidden %d w:", i);
model.hiddenW[i].Dump(file, name);
model.hiddenW[i].Dump(&model.hiddenW[i], file, name);
sprintf(name, "hidden %d b:", i);
model.hiddenB[i].Dump(file, name);
model.hiddenB[i].Dump(&model.hiddenB[i], file, name);
}
model.outputW.Dump(file, "output w:");
model.outputB.Dump(file, "output b:");
model.outputW.Dump(&model.outputW, file, "output w:");
model.outputB.Dump(&model.outputB, file, "output b:");
fclose(file);
......@@ -1094,17 +1170,17 @@ void Read(const char * fn, FNNModel &model)
FILE * file = fopen(fn, "rb");
CheckErrors(file, "Cannot open the model file");
model.embeddingW.Read(file, "embedding w:");
model.embeddingW.Read(&model.embeddingW, file, "embedding w:");
for (int i = 0; i < model.hDepth; i++) {
char name[MAX_NAME_LENGTH];
sprintf(name, "hidden %d w:", i);
model.hiddenW[i].Read(file, name);
model.hiddenW[i].Read(&model.hiddenW[i], file, name);
sprintf(name, "hidden %d b:", i);
model.hiddenB[i].Read(file, name);
model.hiddenB[i].Read(&model.hiddenB[i], file, name);
}
model.outputW.Read(file, "output w:");
model.outputB.Read(file, "output b:");
model.outputW.Read(&model.outputW, file, "output w:");
model.outputB.Read(&model.outputB, file, "output b:");
fclose(file);
......
......@@ -2043,12 +2043,14 @@ void XTensor::Read(XTensor * tensor, FILE * file, const char * label)
XTensor * a = NewTensor(tensor->order, tensor->dimSize, X_FLOAT, tensor->denseRatio, tensor->devID, tensor->mem);
a->Read(file, label);
_CopyValues(a, tensor);
delete a;
}
else if (tensor->dataType == X_FLOAT16)
{
XTensor * a = NewTensor(tensor->order, tensor->dimSize, X_FLOAT, tensor->denseRatio, tensor->devID, tensor->mem);
a->Read(file, label);
_ConvertDataType(a, tensor);
delete a;
}
else
{
......
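The Read path above converts through an fp32 buffer; the matching Dump path is not shown in this hunk. A symmetric sketch (an assumption, built only from calls that appear in this diff):

else if (tensor->dataType == X_FLOAT16)
{
    /* convert fp16 to fp32 so the on-disk format stays fp32 */
    XTensor * a = NewTensor(tensor->order, tensor->dimSize, X_FLOAT, tensor->denseRatio, tensor->devID, tensor->mem);
    _ConvertDataType(tensor, a);
    a->Dump(file, label);
    delete a;
}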
......@@ -34,8 +34,9 @@ multiplication of data arrays in a element-wise manner c(i) = a(i)*b(i)
>> c - result data array
>> size - size of c
*/
template <class T>
__global__
void KernelMulElementWise(DTYPE * a, DTYPE * b, DTYPE * c, int size)
void KernelMulElementWise(T * a, T * b, T * c, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
......@@ -51,8 +52,9 @@ multiplication of data arrays in a element-wise manner c(i) = a(i)*b(i) + \alpha
>> size - size of c
>> alpha - the coefficient
*/
template <class T>
__global__
void KernelMulElementWiseV2(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE alpha)
void KernelMulElementWiseV2(T * a, T * b, T * c, int size, T alpha)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
......@@ -75,13 +77,13 @@ where |a_lead| means the size of the leading dimension of a
>> ldSizeC - size of the leading dimension of c
>> blockNum - number of blocks
*/
template<int nonZeroAlpha> __global__
void KernelMulElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE alpha,
template<class T, int nonZeroAlpha> __global__
void KernelMulElementWiseTensorDynamic(T * a, T * b, T * c, T alpha,
int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum)
{
__shared__ DTYPE* ap[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE* bp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE* cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T* ap[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T* bp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ T* cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int i = blockDim.x * blockIdx.x + threadIdx.x;
int j = blockDim.y * blockIdx.y + threadIdx.y;
......@@ -160,26 +162,56 @@ void _CudaMultiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alph
dim3 blocks(cudaGridSize[0]), threads(cudaBlockSize[0]);
if (alpha == 0)
KernelMulElementWise << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum);
KernelMulElementWise <<<blocks, threads >>>((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum);
else
KernelMulElementWiseV2 << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum, alpha);
KernelMulElementWiseV2 <<<blocks, threads >>>((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum, alpha);
}
else {
GDevs.GetCudaThread2D(c->devID, stride * blockNum, dimensionSizeC, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[0], cudaGridSize[1]), threads(cudaBlockSize[0], cudaBlockSize[1]);
if (alpha == 0) {
KernelMulElementWiseTensorDynamic<0> << <blocks, threads >> >
KernelMulElementWiseTensorDynamic<DTYPE, 0> <<<blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, 0,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
else {
KernelMulElementWiseTensorDynamic<1> << <blocks, threads >> >
KernelMulElementWiseTensorDynamic<DTYPE, 1> <<<blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, alpha,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
}
}
else if (a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16) {
half alpha1 = __float2half(alpha);
int cudaGridSize[3];
int cudaBlockSize[3];
if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
GDevs.GetCudaThread(a->devID, c->unitNum, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[0]), threads(cudaBlockSize[0]);
if (alpha == 0)
KernelMulElementWise <<<blocks, threads >>>((__half*)a->data, (__half*)b->data, (__half*)c->data, c->unitNum);
else
KernelMulElementWiseV2 <<<blocks, threads >>>((__half*)a->data, (__half*)b->data, (__half*)c->data, c->unitNum, alpha1);
}
else {
GDevs.GetCudaThread2D(c->devID, stride * blockNum, dimensionSizeC, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[0], cudaGridSize[1]), threads(cudaBlockSize[0], cudaBlockSize[1]);
if (alpha == 0) {
KernelMulElementWiseTensorDynamic<__half, 0> <<<blocks, threads>>>
((__half*)a->data, (__half*)b->data, (__half*)c->data, 0,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
else {
KernelMulElementWiseTensorDynamic<__half, 1> <<<blocks, threads>>>
((__half*)a->data, (__half*)b->data, (__half*)c->data, alpha1,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
......
......@@ -29,16 +29,18 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* multiplication of two tensors in an element-wise manner c(i) = a(i)*b(i) */
template <class T>
__global__
void KernelMulElementWise(DTYPE * a, DTYPE * b, DTYPE * c, int size);
void KernelMulElementWise(T * a, T * b, T * c, int size);
/* multiplication of two tensors in an element-wise manner c(i) = a(i)*b(i) + \alpha*c(i) */
template <class T>
__global__
void KernelMulElementWiseV2(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE alpha);
void KernelMulElementWiseV2(T * a, T * b, T * c, int size, T alpha);
/* multiplication of two tensors in an element-wise manner c(i) = a(i)*b(i) + \alpha*c(i) */
template<int nonZeroAlpha>__global__
void KernelMulElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE alpha, int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum);
template<class T, int nonZeroAlpha>__global__
void KernelMulElementWiseTensorDynamic(T * a, T * b, T * c, T alpha, int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum);
/* element-wise product of two tensors */
void _CudaMultiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
......
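The diff truncates the kernel bodies after the thread-index line; a minimal sketch consistent with the declared semantics (an assumption about the elided code; __half arithmetic requires sm_53 or newer):

template <class T>
__global__
void KernelMulElementWise(T * a, T * b, T * c, int size)
{
    /* c(i) = a(i) * b(i), one thread per element */
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size)
        c[i] = a[i] * b[i];
}

template <class T>
__global__
void KernelMulElementWiseV2(T * a, T * b, T * c, int size, T alpha)
{
    /* c(i) = a(i) * b(i) + alpha * c(i) */
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size)
        c[i] = a[i] * b[i] + alpha * c[i];
}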
......@@ -234,7 +234,6 @@ void _SpreadForGather(XTensor * source, XTensor * collection, XTensor * index)
int dim = 0;
int order = source->order;
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(collection->GetDim(-1) == source->GetDim(-1), "Illegal dimension!");
CheckNTErrors(collection->unitNum/collection->GetDim(-1) == index->unitNum,
"Illegal dimension!");
......
......@@ -22,6 +22,7 @@
#include "ReduceSumAll.h"
#include "ReduceSum.h"
#include "../movement/CopyValues.h"
#include "../getandset/ConvertDataType.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -54,8 +55,19 @@ DTYPE _ReduceSumAll(const XTensor * source)
_CopyValues(source, all);
_ReduceSum(all, result, 1);
XTensor result1(result->order, result->dimSize, X_FLOAT, result->denseRatio, result->devID, result->mem);
if (result->dataType == X_FLOAT)
{
_CopyValues(result, &result1);
}
else if (result->dataType == X_FLOAT16)
{
_ConvertDataType(result, &result1);
}
DTYPE r = result->Get1D(0);
DTYPE r = result1.Get1D(0);
DelTensorBuf(result);
DelTensorBuf(all);
......
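A small usage sketch of the new fp16 path (device 0 is an assumption): _ReduceSumAll reduces in the source's data type, and the result is converted to fp32 before the scalar is read.

/* sketch: sum all six entries of a 2x3 fp16 tensor; expected result is 6.0 */
XTensor a;
InitTensor2DV2(&a, 2, 3, X_FLOAT16, 0);
_SetDataFixed(&a, 1.0);
DTYPE sum = _ReduceSumAll(&a);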
......@@ -269,6 +269,14 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
KernelUnsqueezeByCol<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
}
else if (a->dataType == X_FLOAT16 && b->dataType == X_FLOAT16) {
if (cudaBlocks[1] == 1)
KernelUnsqueezeByColBigRow<half> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
else
KernelUnsqueezeByCol<half> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockNumA, dSize, b->data);
}
else {
ShowNTErrors("TODO!");
}
......
......@@ -322,88 +322,88 @@ void _CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
//if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
CheckNTErrors((lossName == CROSSENTROPY ||
               lossName == SQUAREDERROR ||
               lossName == ONEHOTERROR ||
               lossName == NOLOSS),
              "Unknown loss function.");

if(lossName == CROSSENTROPY || lossName == SQUAREDERROR){
    _Sum(y, gold, dedx, -1.0F);
    if(padding != NULL) {
        int paddingOrder = padding->order;
        int * paddingDims = new int[paddingOrder];
        memcpy(paddingDims, padding->dimSize, padding->order * sizeof(int));
        padding->Reshape(padding->unitNum);

        int order = dedx->order;
        int * dims = new int[order];
        memcpy(dims, dedx->dimSize, dedx->order * sizeof(int));
        dedx->Reshape(dedx->unitNum/dedx->GetDim(n), dedx->GetDim(n));
        _MultiplyDimMe(dedx, padding, 0);

        padding->Reshape(paddingOrder, paddingDims);
        dedx->Reshape(order, dims);

        delete[] paddingDims;
        delete[] dims;
    }
}
else if(lossName == ONEHOTERROR){
    ShowNTErrors("TODO!");
}
else if(lossName == NOLOSS){
    /*
    for softmax:
    y_i = e^{x_i} / \sum_{k} e^{x_k}
    we have
    dy_i/dx_j = y_i * (\delta(i,j) - y_j)
    then
    dE/dx_j = \sum_i dE/dy_i * dy_i/dx_j
            = \sum_i dE/dy_i * y_i * (\delta(i,j) - y_j)
            = dE/dy_j * y_j - y_j * \beta
            = y_j * (dE/dy_j - \beta)
    where
    \beta = \sum_i (dE/dy_i * y_i)
    */
    int * dimSize = new int[y->order];
    for(int i = 0; i < y->order; i++){
        if(i < leadDim)
            dimSize[i] = y->dimSize[i];
        else if(i > leadDim)
            dimSize[i - 1] = y->dimSize[i];
    }

    /* make a matrix of the same size as y */
    XTensor * ytmp = NewTensor(y);

    /* make a matrix to keep \beta */
    XTensor * beta = NewTensor(y->order - 1, dimSize, y->dataType, y->denseRatio, y->devID, y->mem);

    /* \beta = \sum_i (dE/dy_i * y_i) */
    _Multiply(dedy, y, ytmp, 0, 0);
    _ReduceSum(ytmp, beta, leadDim);

    /* ytmp = dE/dy_j - \beta */
    _Unsqueeze(beta, ytmp, leadDim, y->dimSize[leadDim]);
    _Sum(dedy, ytmp, ytmp, -1.0F);

    /* dE/dx_j = y_j * ytmp = y_j * (dE/dy_j - \beta) */
    _Multiply(y, ytmp, dedx, 0, 0);

    delete[] dimSize;
    delete ytmp;
    delete beta;
}
else{
    ShowNTErrors("TODO!");
}
}
else
    ShowNTErrors("TODO!");
//}
//else
//    ShowNTErrors("TODO!");
BacktoCudaDev(x->devID, devIDBackup);
}
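For the CROSSENTROPY branch, _Sum(y, gold, dedx, -1.0F) computes dedx = y - gold, which follows from the standard softmax/cross-entropy identity (assuming one-hot gold, \sum_i g_i = 1):

E = -\sum_i g_i \log y_i,   y_i = e^{x_i} / \sum_k e^{x_k}
dE/dx_j = -\sum_i g_i (\delta(i,j) - y_j) = y_j \sum_i g_i - g_j = y_j - g_j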
......
......@@ -60,7 +60,7 @@ void _CrossEntropy(const XTensor * output, const XTensor * gold,
CheckNTErrors(padding == NULL || XTensor::IsSameShaped(padding, loss),
"The loss tensor and padding tensor must be same shape!");
CheckNTErrors(loss->order == output->order - 1, "Wrong loss dimension!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, "TODO!");
//CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE, "TODO!");
XTensor * inter = NewTensor(output);
......@@ -564,8 +564,6 @@ void _CrossEntropyBackward(XTensor * dedy, const XTensor * output,
"Wrong weight tensor!");
CheckNTErrors(padding == NULL || padding->order == output->order - 1,
"Wrong padding tensor!");
CheckNTErrors(gold->dataType == DEFAULT_DTYPE && output->dataType == DEFAULT_DTYPE,
"TODO!");
if(padding != NULL) {
for(int i = 0; i < order; i++){
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-06
*/
#include "TSumByColumnTV.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test SumByColumnTV function
sum of a tensor and a vector (column vector) in a column by column manner
*/
bool TestSumByColumnTV1()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 1) */
int bOrder = 2;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 1;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
/* a tensor of size (2, 4) */
int cOrder = 2;
int * cDimSize = new int[cOrder];
cDimSize[0] = 2;
cDimSize[1] = 4;
int cUnitNum = 1;
for (int i = 0; i < cOrder; i++)
cUnitNum *= cDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2][1] = { {1.0F},
{0.0F} };
DTYPE answer[2][4] = { {1.0F, 2.0F, 3.0F, 4.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(cOrder, cDimSize);
/* initialize variables */
a->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
/* call SumByColumnTV function */
_SumByColumnTV(a, b, c);
/* check results */
cpuTest = c->CheckData(answer, cUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(cOrder, cDimSize, X_FLOAT, 1.0F, 0);
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call SumByColumnTV function */
_SumByColumnTV(aGPU, bGPU, cGPU);
/* check results */
gpuTest = cGPU->CheckData(answer, cUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete aGPU;
delete bGPU;
delete cGPU;
delete[] aDimSize;
delete[] bDimSize;
delete[] cDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete[] aDimSize;
delete[] bDimSize;
delete[] cDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for SumByColumnTV Function */
bool TestSumByColumnTV()
{
XPRINT(0, stdout, "[TEST SumByColumnTV] sum of a tensor and a vector (column vector) in a column by column manner \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestSumByColumnTV1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-06
*/
#ifndef __TEST_SUMBYCOLUMNTV_H__
#define __TEST_SUMBYCOLUMNTV_H__
#include "../core/arithmetic/SumByColumnTV.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for SumByColumnTV Function */
extern "C"
bool TestSumByColumnTV();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_SUMBYCOLUMNTV_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-06
*/
#include "TSumByColumnVT.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test SumByColumnVT function
sum of a vector (column vector) and a tensor in a column by column manner
*/
bool TestSumByColumnVT1()
{
/* a tensor of size (2, 1) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 1;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 4) */
int bOrder = 2;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 4;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
/* a tensor of size (2, 1) */
int cOrder = 2;
int * cDimSize = new int[cOrder];
cDimSize[0] = 2;
cDimSize[1] = 1;
int cUnitNum = 1;
for (int i = 0; i < cOrder; i++)
cUnitNum *= cDimSize[i];
DTYPE aData[2][1] = { {1.0F},
{0.0F} };
DTYPE bData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE answer[2][1] = { {7.0F},
{22.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(cOrder, cDimSize);
/* initialize variables */
a->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call SumByColumnVT function */
_SumByColumnVT(a, b, c);
/* check results */
cpuTest = c->CheckData(answer, cUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(cOrder, cDimSize, X_FLOAT, 1.0F, 0);
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call SumByColumnVT function */
_SumByColumnVT(aGPU, bGPU, cGPU);
/* check results */
gpuTest = cGPU->CheckData(answer, cUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete aGPU;
delete bGPU;
delete cGPU;
delete[] aDimSize;
delete[] bDimSize;
delete[] cDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete[] aDimSize;
delete[] bDimSize;
delete[] cDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for SumByColumnVT Function */
bool TestSumByColumnVT()
{
XPRINT(0, stdout, "[TEST SumByColumnVT] sum of a vector (column vector) and a tensor in a column by column manner \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestSumByColumnVT1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-06
*/
#ifndef __TEST_SUMBYCOLUMNVT_H__
#define __TEST_SUMBYCOLUMNVT_H__
#include "../core/arithmetic/SumByColumnVT.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for SumByColumnVT Function */
extern "C"
bool TestSumByColumnVT();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_SUMBYCOLUMNVT_H__
......@@ -48,9 +48,9 @@ bool Test()
//wrong = !TestMatrixMul2DParallel() || wrong;
//wrong = !TestMatrixMulBatched() || wrong;
//wrong = !TestMerge() || wrong;
//wrong = !TestMultiply() || wrong;
wrong = !TestMultiply() || wrong;
//wrong = !TestMultiplyDim() || wrong;
wrong = !TestNegate() || wrong;
//wrong = !TestNegate() || wrong;
//wrong = !TestNormalize() || wrong;
//wrong = !TestPower() || wrong;
//wrong = !TestReduceMax() || wrong;
......@@ -64,7 +64,7 @@ bool Test()
//wrong = !TestSelect() || wrong;
//wrong = !TestSetAscendingOrder() || wrong;
//wrong = !TestSetData() || wrong;
wrong = !TestSign() || wrong;
//wrong = !TestSign() || wrong;
//wrong = !TestSin() || wrong;
//wrong = !TestSort() || wrong;
//wrong = !TestSplit() || wrong;
......