Commit 2fea6615 by xuchen

improve the implementation of gather and spread

parent 99225c29
......@@ -21,6 +21,7 @@
#include "XBackwardLoss.h"
#include "../tensor/XName.h"
#include "../tensor/core/getandset/SetData.h"
#include "../tensor/function/HardTanH.h"
#include "../tensor/function/Identity.h"
#include "../tensor/function/LogSoftmax.h"
......@@ -86,9 +87,23 @@ void XLossGrad::Compute(XTensor * gold, XTensor * y,
XTensor * dedy, XTensor * padding,
LOSS_FUNCTION_NAME lossName)
{
if(gold == NULL){
if(dedy->dataType == X_FLOAT)
_SetDataFixedFloat(dedy, 1.0F);
else if(dedy->dataType == X_DOUBLE)
_SetDataFixedDouble(dedy, 1.0);
else if(dedy->dataType == X_INT)
_SetDataFixedInt(dedy, 1);
else{
ShowNTErrors("TODO");
}
return;
}
//_LossBackward(dedy, gold, y, lossName);
if(lossName == CROSSENTROPY)
_CrossEntropyBackward(dedy, y, gold, NULL, padding);
}
}
\ No newline at end of file
......@@ -40,6 +40,8 @@ void XShapeGrad::MakeGrad(XTensor * node, bool isEfficent)
if(operID == MOVEMENT_COPYINDEXED)
GradCopyIndexed(node, isEfficent);
if(operID == MOVEMENT_GATHER)
GradGather(node, isEfficent);
else if(operID == SHAPE_MERGE)
GradMerge(node, isEfficent);
else if(operID == SHAPE_MERGE_LIST)
......@@ -118,6 +120,31 @@ void XShapeGrad::GradCopyIndexed(XTensor * node, bool isEfficent)
}
/*
gradient computation for gather function
for
b = gather(a)
we have
dE/da = spreadforgather(dE/db)
>> node - the node (b) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XShapeGrad::GradGather(XTensor * node, bool isEfficent)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for Gather!");
XTensor * input = income.tails[0];
XTensor * index = income.tails[1];
XNoder::MakeGrad(input);
_SpreadForGather(input->grad, node->grad, index);
node->visitMark = NODE_FINISHED;
}
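For illustration, a minimal hedged sketch of the gather/spread pairing that GradGather relies on (not the engine's exact code path; GPU device 0 and the autodiff headers such as XNoder.h are assumed, since the index-tensor versions of _Gather/_SpreadForGather only dispatch to CUDA in this commit; sizes and names are illustrative):

int rows[3] = {1, 5, 2};

XTensor a, index, b;
InitTensor2D(&a, 8, 4, X_FLOAT, 0, NULL);      /* e.g. an embedding table; values omitted here */
InitTensor1D(&index, 3, X_INT, 0, NULL);       /* ids of the rows to gather */
index.SetData(rows, 3);

/* forward: b = gather(a) along dimension 0, so b has shape (3, 4) */
b = Gather(a, index);

/* backward (what GradGather does): dE/da[index[i]] += dE/db[i] */
XNoder::MakeGrad(&b);                          /* in a real run, filled by the ops above b */
XNoder::MakeGrad(&a);
_SpreadForGather(a.grad, b.grad, &index);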
/*
gradient for merge
for
c = merge(a_0, a_1, ...)
......@@ -154,7 +181,6 @@ void XShapeGrad::GradMerge(XTensor * node, bool isEfficent)
XNoder::MakeGrad(input);
int * dims = new int[input->order];
memset(dims, 0, sizeof(int) * input->order);
for(int i = 0, j = 0; i < input->order; i++){
if(i >= leadDim){
dims[j++] = input->dimSize[i];
......@@ -304,14 +330,9 @@ void XShapeGrad::GradReshape(XTensor * node, bool isEfficent)
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for MERGE!");
int order = income.GetParamInt(0);
int * dimSize = (int *)income.GetParamPointer(1);
node->grad->Reshape(order, dimSize);
node->grad->Reshape(input->order, input->dimSize);
_CopyValues(node->grad, input->grad);
delete[] dimSize;
node->grad->Reshape(node->order, node->dimSize);
node->visitMark = NODE_FINISHED;
}
......@@ -407,6 +428,7 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
if(income.typeID == SHAPE_SPLIT_LIST){
int w = income.GetParamInt(0);
int splitID = income.GetParamInt(1);
if(whereToSplit < 0)
whereToSplit = w;
splitNum++;
......@@ -415,14 +437,16 @@ void XShapeGrad::GradSplitListPost(XTensor * node, bool isEfficient)
CheckNTErrors(income.tailNum == 1, "Something wrong with outgoing edge!");
CheckNTErrors(splitNum - 1 == splitID, "Wrong split id!");
splits.Add(parent);
splits.Add(parent->grad);
}
}
XNoder::MakeGrad(node);
/* we can simply merge the gradient tensor
if the node is used in splitting only */
if(outgo.tailNum == splitNum){
_Merge(&splits, node->grad, whereToSplit + 1);
_Merge(&splits, node->grad, whereToSplit);
}
/* if the tensor is used as input to other nodes
......
......@@ -50,6 +50,10 @@ private:
static
void GradCopyIndexed(XTensor * node, bool isEfficent);
/* gradient computation for copying indexed sub-tensors: b = gather(a, index) */
static
void GradGather(XTensor * node, bool isEfficent);
/* gradient computation for merge: c = merge(a, b, ...) */
static
void GradMerge(XTensor * node, bool isEfficent);
......
......@@ -73,6 +73,25 @@ void XNet::Clear()
}
/*
backward propagation to obtain gradient
>> root - root node (output) of the network
>> loss - name of loss function
*/
void XNet::Backward(XTensor &root, LOSS_FUNCTION_NAME loss)
{
XList roots(1);
roots.Add(&root);
XList golds(1);
golds.Add(NULL);
XList paddings(1);
paddings.Add(NULL);
Backward(roots, golds, paddings, loss);
}
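A brief, hedged usage sketch of this single-root entry point ("output" is assumed to be the root tensor produced by a forward pass built from XTensor operators):

XNet net;

/* single root, no gold standard: expands internally to the list-based
   overload with gold = NULL and padding = NULL */
net.Backward(output, CROSSENTROPY);

/* equivalent explicit form with a list of roots */
XList roots(1);
roots.Add(&output);
net.Backward(roots, CROSSENTROPY);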
/*
backward propagation to obtain gradient wrt. the loss/error function
>> root - root node (output) of the network
>> gold - gold standard for the output
......@@ -115,18 +134,33 @@ void XNet::Backward(XTensor &root, XTensor &gold, XTensor &padding, LOSS_FUNCTIO
/*
backward propagation to obtain gradient
>> root - root node (output) of the network
with a number of root nodes
>> roots - a list of root nodes (output) of the network
>> loss - name of loss function
*/
void XNet::Backward(XTensor &root, LOSS_FUNCTION_NAME loss)
void XNet::Backward(XList &roots, LOSS_FUNCTION_NAME loss)
{
XList roots(1);
roots.Add(&root);
XList golds(1);
XList golds(roots.count);
XList paddings(roots.count);
for (int i = 0; i < roots.count; i++) {
golds.Add(NULL);
paddings.Add(NULL);
}
XList paddings(1);
Backward(roots, golds, paddings, loss);
}
/*
backward propagation to obtain gradient
with a number of root nodes
>> roots - a list of root nodes (output) of the network
>> golds - a list of gold standard for the output
>> loss - name of loss function
*/
void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
{
XList paddings(roots.count);
for (int i = 0; i < roots.count; i++)
paddings.Add(NULL);
Backward(roots, golds, paddings, loss);
......@@ -211,40 +245,6 @@ void XNet::Backward(XList &roots, XList &golds, XList &paddings, LOSS_FUNCTION_N
}
/*
backward propagation to obtain gradient
with a number of root nodes
>> roots - a list of root nodes (output) of the network
>> loss - name of loss function
*/
void XNet::Backward(XList &roots, LOSS_FUNCTION_NAME loss)
{
XList golds(roots.count);
XList paddings(roots.count);
for(int i = 0; i < roots.count; i++) {
golds.Add(NULL);
paddings.Add(NULL);
}
Backward(roots, golds, paddings, loss);
}
/*
backward propagation to obtain gradient
with a number of root nodes
>> roots - a list of root nodes (output) of the network
>> golds - a list of gold standard for the output
>> loss - name of loss function
*/
void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
{
XList paddings(roots.count);
for(int i = 0; i < roots.count; i++)
paddings.Add(NULL);
Backward(roots, golds, paddings, loss);
}
/*
backward computation for a given node
>> node - the node keeps the result of an operation (e.g., activation function)
>> isEfficient - indicates whether the back-propagation is computed in an
......
......@@ -59,19 +59,15 @@ struct XNet
/* clear the network */
void Clear();
/* backward propagation to obtain gradient */
void Backward(XTensor &root, LOSS_FUNCTION_NAME loss = NOLOSS);
/* backward propagation to obtain gradient wrt. the loss/error function */
void Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss = NOLOSS);
/* backward propagation to obtain gradient wrt. the loss/error function */
void Backward(XTensor &root, XTensor &gold, XTensor &padding, LOSS_FUNCTION_NAME loss = NOLOSS);
/* backward propagation to obtain gradient */
void Backward(XTensor &root, LOSS_FUNCTION_NAME loss = NOLOSS);
/* backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes */
void Backward(XList &roots, XList &golds, XList &paddings, LOSS_FUNCTION_NAME loss = NOLOSS);
/* backward propagation to obtain gradient
with a number of root nodes */
void Backward(XList &roots, LOSS_FUNCTION_NAME loss = NOLOSS);
......@@ -80,6 +76,10 @@ struct XNet
with a number of root nodes */
void Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss = NOLOSS);
/* backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes */
void Backward(XList &roots, XList &golds, XList &paddings, LOSS_FUNCTION_NAME loss = NOLOSS);
/* backward computation for a given node */
void BackwardNode(XTensor * node, bool isEfficent = false);
......
......@@ -998,6 +998,7 @@ void ForwardAutoDiff(NGram * ngrams, int batch, XTensor &output, FNNModel &model
XTensor embeddingBig;
XTensor hidden;
XTensor b;
XTensor srcIndex;
int size = batch * (n-1);
int * index = new int[size];
......@@ -1009,8 +1010,11 @@ void ForwardAutoDiff(NGram * ngrams, int batch, XTensor &output, FNNModel &model
}
}
InitTensor1D(&srcIndex, size, X_INT, model.devID, model.mem);
srcIndex.SetData(index, size);
XTensor embedding;
embedding = Gather(model.embeddingW, 0, index, size);
embedding = Gather(model.embeddingW, srcIndex);
delete[] index;
......
......@@ -704,7 +704,7 @@ int T2TTrainer::LoadBatchLM(FILE * file,
dims[1] = max;
dims[2] = vs;
InitTensor(batchEnc, 2, dims, X_INT, 1.0F, -1);
InitTensor2D(batchEnc, sc, max, X_INT, devID, mem);
//InitTensor(batchEnc, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingEnc, sc, max, X_FLOAT, devID, mem);
InitTensor(gold, 3, dims, X_FLOAT, 1.0F, devID, mem);
......@@ -728,25 +728,39 @@ int T2TTrainer::LoadBatchLM(FILE * file,
int seqSize = 0;
int * batchEncValues = new int[batchEnc->unitNum];
MTYPE * paddingEncOffsets = new MTYPE[paddingEnc->unitNum];
MTYPE * goldOffsets = new MTYPE[gold->unitNum];
MTYPE * paddingDecOffsets = new MTYPE[paddingDec->unitNum];
/* need to improve the implementation */
memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum);
int wGold = 0;
//fprintf(tf, "batch %d(%d)\n", tc++, sc);
/* this might be slow on GPUs :( */
for(int s = seq; s < seq + sc; s++){
int len = isDoubledEnd ? seqLen[s] : seqLen[s] - 1;
CheckNTErrors(len <= max, "Something is wrong!");
for(int w = 0; w < len; w++){
batchEnc->Set2DInt(buf[seqOffset[s] + w], s - seq, w);
//batchEnc->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
paddingEnc->Set2D(1.0F, s - seq, w);
paddingDec->Set2D(1.0F, s - seq, w);
int num = buf[seqOffset[s] + w];
//batchEnc->Set2DInt(buf[seqOffset[s] + w], s - seq, w);
//paddingEnc->Set2D(1.0F, s - seq, w);
//paddingDec->Set2D(1.0F, s - seq, w);
batchEncValues[(s - seq) * dims[1] + w] = num;
paddingEncOffsets[wCount] = paddingEnc->GetOffset2D(s - seq, w);
paddingDecOffsets[wCount] = paddingDec->GetOffset2D(s - seq, w);
if (w > 0)
gold->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
//gold->Set3D(1.0F, s - seq, w - 1, buf[seqOffset[s] + w]);
goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w - 1, num);
if (w == len - 1) {
if (isDoubledEnd)
gold->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
//gold->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w]);
goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w, num);
else
gold->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]);
//gold->Set3D(1.0F, s - seq, w, buf[seqOffset[s] + w + 1]);
goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w, buf[seqOffset[s] + w + 1]);
}
wCount++;
......@@ -765,6 +779,16 @@ int T2TTrainer::LoadBatchLM(FILE * file,
}
}
batchEnc->SetData(batchEncValues, batchEnc->unitNum);
paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount);
paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCount);
gold->SetDataBatched(goldOffsets, 1.0F, wGold);
delete[] batchEncValues;
delete[] paddingEncOffsets;
delete[] paddingDecOffsets;
delete[] goldOffsets;
fflush(tf);
return sc;
......
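The rewrite above trades per-element Set2DInt/Set2D/Set3D calls, which can be slow when the tensors live on a GPU, for host-side arrays that are flushed to the device with one SetData/SetDataBatched call per tensor. A condensed, hedged sketch of the pattern (sentNum and lens are illustrative host-side metadata, not variables from the code above):

MTYPE * offsets = new MTYPE[paddingEnc->unitNum];
int count = 0;

/* collect the positions to set on the host side */
for (int s = 0; s < sentNum; s++) {
    for (int w = 0; w < lens[s]; w++)
        offsets[count++] = paddingEnc->GetOffset2D(s, w);
}

/* one batched device write instead of count scalar writes */
paddingEnc->SetDataBatched(offsets, 1.0F, count);

delete[] offsets;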
......@@ -66,7 +66,7 @@ int TransformerMain(int argc, const char ** argv)
/* learn model parameters */
if(strcmp(trainFN, ""))
trainer.Train(trainFN, testFN, modelFN, &model);
trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model);
/* save the final model */
if(strcmp(modelFN, "") && strcmp(trainFN, ""))
......
......@@ -103,6 +103,8 @@ const char * GetOPName(int type)
return "M_COPYINDEXED";
else if (type == MOVEMENT_COPYVALUES)
return "M_COPYVALUES";
else if (type == MOVEMENT_GATHER)
return "M_GATHER";
else if (type == SHAPE_CONCATENATE)
return "S_CONCATENATE";
else if (type == SHAPE_MERGE)
......
......@@ -77,6 +77,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MOVEMENT GETANDSET_SELECT + 1
#define MOVEMENT_COPYINDEXED MOVEMENT + 1
#define MOVEMENT_COPYVALUES MOVEMENT_COPYINDEXED + 1
#define MOVEMENT_GATHER MOVEMENT_COPYVALUES + 1
#define SHAPE MOVEMENT_COPYVALUES + 1
#define SHAPE_CONCATENATE SHAPE + 1
......
......@@ -841,9 +841,9 @@ bool IsFloatEqual(DTYPE a, DTYPE b, float absError, float relError)
if(fabs(a - b) < absError)
return true;
if(fabs(a) < fabs(b))
return (fabs(a - b) / b < relError) ? true : false;
return (fabs((a - b) / b) < relError) ? true : false;
else
return (fabs(a - b) / a < relError) ? true : false;
return (fabs((a - b) / a) < relError) ? true : false;
}
/* check whether the data array is the same as the answer */
......@@ -1278,7 +1278,7 @@ int XTensor::GetNonzeroSize()
if(dataType == DEFAULT_DTYPE){
int count = 0;
for(int i = 0; i < unitNum; i++){
DTYPE value = *(DTYPE*)((char*)data + i * sizeof(DTYPE));
DTYPE value = *((DTYPE*)(char*)data + i * sizeof(DTYPE));
if(value == 0)
count++;
}
......@@ -1585,7 +1585,6 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int beg,
fprintf(file, " dtype=%s dense=%f\n", GetDataTypeName(dataType), denseRatio);
if(!isInit){
fprintf(file, "NULL");
}
......@@ -1601,7 +1600,7 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int beg,
}
}
else if(dataType == X_INT) {
else if (dataType == X_INT) {
int end = MIN(n > 0 ? beg + n : beg + unitNum, unitNum);
for(int i = beg; i < end; i++){
int f = ((int*)d)[i];
......@@ -2261,8 +2260,6 @@ XTensor * NewTensor(const XTensor * a, bool isFilledData)
CheckNTErrors((a != NULL), "Empty input!");
memset(dims, 0, sizeof(int) * MAX_TENSOR_DIM_NUM);
if(a->order > 0)
memcpy(dims, a->dimSize, sizeof(int) * a->order);
......
......@@ -274,7 +274,7 @@ public:
void SetData(const void * d, int num, int beg = 0);
/* set tensor items by a uniform distribution */
void SetDataRand(DTYPE lower, DTYPE upper);
void SetDataRand(DTYPE lower = 0.0F, DTYPE upper = 1.0F);
/* set tensor items by a normal distribution */
void SetDataRandn(DTYPE mean, DTYPE standardDeviation);
......
......@@ -32,7 +32,8 @@ convert data type
*/
void _ConvertDataType(const XTensor * input, XTensor * output)
{
CheckNTErrors((input->unitSize == output->unitSize), "Input and Output must be same in size!");
//CheckNTErrors((input->unitSize == output->unitSize), "Input and Output must be same in size!");
if (input->dataType == output->dataType)
return;
......
......@@ -114,7 +114,8 @@ convert data type (cuda code)
*/
void _CudaConvertDataType(const XTensor * input, XTensor * output)
{
CheckNTErrors((input->unitSize == output->unitSize), "Input and Output must be same in size!");
//CheckNTErrors((input->unitSize == output->unitSize), "Input and Output must be same in size!");
if (input->dataType == output->dataType)
return;
......@@ -133,6 +134,10 @@ void _CudaConvertDataType(const XTensor * input, XTensor * output)
KernelFloatToInt<<<blocks, threads>>>((float*)input->data, (int*)output->data, input->unitNum);
else if(input->dataType == X_INT && output->dataType == X_FLOAT)
KernelIntToFloat<<<blocks, threads>>>((int*)input->data, (float*)output->data, input->unitNum);
else if(input->dataType == X_FLOAT && output->dataType == X_FLOAT16)
KernelFloatToFloat16<<<blocks, threads>>>((float*)input->data, (__half*)output->data, input->unitNum);
else if(input->dataType == X_FLOAT16 && output->dataType == X_FLOAT)
KernelFloat16ToFloat<<<blocks, threads>>>((__half*)input->data, (float*)output->data, input->unitNum);
else{
ShowNTErrors("Unsupported data types for conversion!");
}
......
......@@ -20,8 +20,10 @@
*/
#include "Gather.h"
#include "Gather.cuh"
#include "CopyIndexed.h"
#include "../../XUtility.h"
#include "../../XName.h"
#include "../shape/Reshape.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -37,7 +39,7 @@ gather indexed sub-tensors
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
*/
void _Gather(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize)
void _Gather(XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize)
{
int * tgtIndex = new int[indexSize];
for(int i = 0; i < indexSize; i++)
......@@ -49,32 +51,25 @@ void _Gather(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexS
}
/*
gather indexed sub-tensors (return a XTensor structure)
make a new tensor to keep the result and return it
gather indexed sub-tensors
>> s - the source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
<< return - the result of copying indexed sub-tensors
Notice: the index must be on the CPU!!!
>> t - the target tensor
>> srcIndex - the tensor to save the index of the source tensor
*/
XTensor Gather(const XTensor &s, int dim, int * srcIndex, int indexSize)
void _Gather(XTensor * s, XTensor * t, XTensor * srcIndex)
{
int * tgtIndex = new int[indexSize];
for(int i = 0; i < indexSize; i++)
tgtIndex[i] = i;
/* call CopyIndexed function */
XTensor result;
result = CopyIndexed(s, dim, srcIndex, indexSize, tgtIndex, 1);
delete[] tgtIndex;
return result;
CheckNTErrors((s && t), "Invalid tensors!");
CheckNTErrors((s->devID == t->devID && t->devID == srcIndex->devID),
"the data must be kept on the same device!");
CheckNTErrors((s->unitSize == t->unitSize), "Unmatched tensors!");
#ifdef USE_CUDA
if (s->devID >= 0 && t->devID >= 0 && srcIndex->devID >= 0) {
_CudaGather(s, t, srcIndex);
return;
}
#endif
}
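A minimal usage sketch of this in-place overload (as written it only dispatches to the CUDA path, so GPU device 0 is assumed; the target tensor must be pre-allocated with the gathered shape; values of s are assumed to be set elsewhere):

int sDimSize[2] = {3, 3};
int tDimSize[2] = {2, 3};
int iDimSize[1] = {2};
int rows[2] = {0, 2};

XTensor * s     = NewTensor(2, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * t     = NewTensor(2, tDimSize, X_FLOAT, 1.0F, 0);
XTensor * index = NewTensor(1, iDimSize, X_INT, 1.0F, 0);

/* set the row ids to gather */
index->SetData(rows, 2);

/* rows 0 and 2 of s are copied into t */
_Gather(s, t, index);

delete s;
delete t;
delete index;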
/*
......@@ -83,46 +78,46 @@ make a new tensor to keep the result and return it
>> s - the source tensor(2D)
>> index - the index tensor
<< return - the result of copying indexed sub-tensors
<< return - the result of gathering the indexed sub-tensors
*/
XTensor Gather(const XTensor &s, const XTensor &index)
XTensor Gather(XTensor &s, XTensor &index)
{
int indexSize = index.unitNum;
int dim = 0;
CheckNTErrors(s.order == 2, "The order of the input tensor must be 2!");
int * srcIndex = new int[index.unitNum];
int order = s.order;
int * dimSize = new int[order];
if(index.dataType == X_INT) {
XMemCopy(srcIndex, -1, index.data, index.devID, indexSize * index.unitSize);
}
else if(index.dataType == X_FLOAT || index.dataType == X_DOUBLE) {
DTYPE * tmp = new DTYPE[indexSize];
XMemCopy(tmp, -1, index.data, index.devID, indexSize * index.unitSize);
for(int i = 0; i < indexSize; i++)
srcIndex[i] = (int)tmp[i];
delete[] tmp;
}
else{
ShowNTErrors("Unsupported data type!");
for (int i = 0; i < s.order; i++) {
if (i == dim)
dimSize[i] = index.unitNum;
else
dimSize[i] = s.dimSize[i];
}
XTensor tensor;
tensor = Gather(s, 0, srcIndex, indexSize);
delete[] srcIndex;
float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
XTensor t(order, dimSize, s.dataType, dr, s.devID, s.mem);
t.SetTMPFlag();
_Gather(&s, &t, &index);
/* tensor connection */
XLink::MakeLink(&s, &index, &t, MOVEMENT_GATHER);
if(index.order > 1) {
int * dims = new int[index.order + 1];
memcpy(dims, index.dimSize, index.order * sizeof(int));
dims[index.order] = tensor.GetDim(-1);
dims[index.order] = t.GetDim(-1);
XTensor t;
t = Reshape(tensor, index.order + 1, dims);
XTensor tt;
tt = Reshape(t, index.order + 1, dims);
delete[] dims;
return t;
return tt;
}
else {
return tensor;
return t;
}
}
......
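And a short, hedged example of the functional form, mirroring the new TestGather3 case (again on GPU device 0; when the index tensor has order > 1, the result is additionally reshaped as shown at the end of the function above):

int sDimSize[2] = {3, 3};
int iDimSize[1] = {2};
int rows[2] = {0, 2};

XTensor * s     = NewTensor(2, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * index = NewTensor(1, iDimSize, X_INT, 1.0F, 0);
index->SetData(rows, 2);

XTensor t;
t = Gather(*s, *index);    /* t has shape (2, 3); the MOVEMENT_GATHER link
                              recorded here is what GradGather consumes later */

delete s;
delete index;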
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "Gather.cuh"
#include "CopyBlocksSelected.cuh"
#include "../../XDevice.h"
#include "../../XUtility.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
gather indexed sub-tensors(cuda version)
>> source - the data pointer of the source tensor
>> target - the data pointer of the target tensor
>> srcIndex - the index of the source tensor
>> indexSize - the size of the srcIndex
>> stride - stride of a data block
*/
__global__
void KernelGather(DTYPE * source, DTYPE * target, int * srcIndex, int indexSize, int stride)
{
__shared__ DTYPE * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* offset in each block */
int offset = blockDim.y * blockIdx.y + threadIdx.y;
if(i >= indexSize || offset >= stride)
return;
if(threadIdx.y == 0){
sp[threadIdx.x] = source + srcIndex[i] * stride;
cp[threadIdx.x] = target + i * stride;
}
__syncthreads();
DTYPE * s = sp[threadIdx.x];
DTYPE * c = cp[threadIdx.x];
c[offset] = s[offset];
}
/*
gather indexed sub-tensors(cuda version)
>> s - the source tensor
>> t - the target tensor
>> srcIndex - the tensor to save the index of the source tensor
*/
void _CudaGather(XTensor * s, XTensor * t, XTensor * srcIndex)
{
int devID = s->devID;
int stride = s->GetDim(1);
int indexSize = srcIndex->unitNum;
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup;
ProtectCudaDev(devID, devIDBackup);
GDevs.GetCudaThread2D(devID, indexSize, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
DTYPE * source = (DTYPE*)s->data;
DTYPE * target = (DTYPE*)t->data;
int * si = (int *)srcIndex->data;
KernelGather<<<blocks, threads >>>(source, target, si, indexSize, stride);
BacktoCudaDev(devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __GATHER_CUH__
#define __GATHER_CUH__
#include "../../XTensor.h"
#include "Gather.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* gather indexed sub-tensors(cuda version) */
void _CudaGather(XTensor * s, XTensor * t, XTensor * srcIndex);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __GATHER_CUH__
\ No newline at end of file
......@@ -27,15 +27,14 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* gather selected sub-tensors */
void _Gather(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize);
void _Gather(XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize);
/* gather selected sub-tensors (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Gather(const XTensor &s, int dim, int * srcIndex, int indexSize);
/* gather selected sub-tensors */
void _Gather(XTensor * s, XTensor * t, XTensor * srcIndex);
/* gather selected sub-tensors (return a XTensor structure)
make a new tensor to keep the result and return it */
XTensor Gather(const XTensor &s, const XTensor &index);
XTensor Gather(XTensor &s, XTensor &index);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -197,4 +197,42 @@ void _SpreadForGather(XTensor * source, XTensor * collection, int dim,
}
}
/*
spread a collection tensor to the source tensor.
This is a special spread function for the backward computation of the gather function.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to the source tensor
>> index - the tensor that keeps the indices of the source sub-tensors
*/
void _SpreadForGather(XTensor * source, XTensor * collection, XTensor * index)
{
int dim = 0;
int order = source->order;
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
for(int i = 0; i < order; i++){
if(i < dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
else if(i > dim){
CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
}
}
#ifdef USE_CUDA
if(source->devID >= 0 && collection->devID >= 0 && index->devID >= 0) {
_CudaSpreadForGather(source, collection, index);
return;
}
#endif
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
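For clarity, a hedged CPU reference of what the new overload computes (dim fixed to 0, dense DTYPE data assumed to be host-resident; the function above only dispatches to _CudaSpreadForGather, so this loop is illustrative rather than an actual code path):

/* reference semantics of _SpreadForGather(source, collection, index):
   for every gathered row i, add collection's row i into source's row index[i] */
void SpreadForGatherReference(XTensor * source, XTensor * collection, XTensor * index)
{
    int stride = source->GetDim(1);        /* elements per row (2D case) */
    int rowNum = index->unitNum;           /* number of gathered rows */

    DTYPE * sData = (DTYPE*)source->data;
    DTYPE * cData = (DTYPE*)collection->data;
    int   * idx   = (int*)index->data;

    for (int i = 0; i < rowNum; i++) {
        DTYPE * s = sData + idx[i] * stride;
        DTYPE * c = cData + i * stride;
        for (int j = 0; j < stride; j++)
            s[j] += c[j];                  /* "+=": gradients accumulate */
    }
}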
......@@ -58,6 +58,54 @@ void KernelSpread(DTYPE * sData, DTYPE * cData, int blockNum,
s[j] = c[j];
}
/*
This is the core assignment for the spread function.
>> sData - the data pointer of the source tensor
>> cData - the data pointer of collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of source data block
>> blockSizeColl - size of a collection data block
>> stride - stride of a data block
>> subtensorNum - number of sub-tensors
>> srcIndex - index of the source sub-tensor
>> colIndex - index of the sub-tensor in the collection tensor
*/
__global__
void KernelSpreadFuzed(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride,
int subtensorNum,
int * srcIndex, int * colIndex)
{
__shared__ DTYPE * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* offset in each block */
int offset = blockDim.y * blockIdx.y + threadIdx.y;
int blockId = i % blockNum;
int subtensorId = i / blockNum;
if(subtensorId >= subtensorNum || offset >= stride)
return;
if(threadIdx.y == 0){
sp[threadIdx.x] = sData + srcIndex[subtensorId] * stride;
cp[threadIdx.x] = cData + colIndex[subtensorId] * stride;
}
__syncthreads();
DTYPE * s = sp[threadIdx.x] + blockSizeSrc * blockId;
DTYPE * c = cp[threadIdx.x] + blockSizeColl * blockId;
s[offset] = c[offset];
}
/*
spread a collection tensor to the source tensor (cuda version).
This is the inverse operation of gather.
......@@ -103,6 +151,12 @@ void _CudaSpread(XTensor * source, XTensor * collection, int dim,
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
if(indexSize < 4){
GDevs.GetCudaThread2D(source->devID, blockNum, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
DTYPE * sData = (DTYPE*)source->data;
DTYPE * cData = (DTYPE*)collection->data;
for(int i = 0; i < indexSize; i++) {
......@@ -113,6 +167,33 @@ void _CudaSpread(XTensor * source, XTensor * collection, int dim,
KernelSpread<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl, stride);
}
}
else{
GDevs.GetCudaThread2D(source->devID, blockNum * indexSize, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
DTYPE * s = (DTYPE*)source->data;
DTYPE * c = (DTYPE*)collection->data;
XMem * mem = source->mem;
int * si = mem != NULL ?
(int*)mem->AllocBuf(mem->devID, sizeof(int) * indexSize * 2) :
(int*)XMemAlloc(source->devID, sizeof(int) * indexSize * 2);
int * ci = si + indexSize;
XMemCopy(si, source->devID, srcIndex, -1, sizeof(int) * indexSize);
XMemCopy(ci, source->devID, collIndex, -1, sizeof(int) * indexSize);
KernelSpreadFuzed<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl,
stride, indexSize, si, ci);
if(mem != NULL)
mem->ReleaseBuf(mem->devID, sizeof(int) * indexSize * 2);
else
XMemFree(source->devID, si);
}
BacktoCudaDev(source->devID, devIDBackup);
}
......@@ -196,6 +277,53 @@ void KernelSpreadForGatherFuzed(DTYPE * sData, DTYPE * cData, int blockNum,
}
/*
This is the core assignment for the backward computation of the gather function.
Note the use of "+=" instead of "=".
>> sData - the data pointer of the source tensor
>> cData - the data pointer of collection tensor
>> blockNum - number of data blocks
>> blockSizeSrc - size of source data block
>> blockSizeColl - size of a collection data block
>> stride - stride of a data block
>> subtensorNum - number of sub-tensors
>> srcIndex - index of the source sub-tensor
*/
__global__
void KernelSpreadForGatherFuzed(DTYPE * sData, DTYPE * cData, int blockNum,
int blockSizeSrc, int blockSizeColl, int stride,
int subtensorNum,
int * srcIndex)
{
__shared__ DTYPE * sp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE * cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* offset in each block */
int offset = blockDim.y * blockIdx.y + threadIdx.y;
int blockId = i % blockNum;
int subtensorId = i / blockNum;
if(subtensorId >= subtensorNum || offset >= stride)
return;
if(threadIdx.y == 0){
sp[threadIdx.x] = sData + srcIndex[subtensorId] * stride;
cp[threadIdx.x] = cData + subtensorId * stride;
}
__syncthreads();
DTYPE * s = sp[threadIdx.x] + blockSizeSrc * blockId;
DTYPE * c = cp[threadIdx.x] + blockSizeColl * blockId;
s[offset] += c[offset];
}
/*
spread a collection tensor to the source tensor (cuda version).
This is a special spread function for the backward computation of the gather function.
......@@ -282,6 +410,46 @@ void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
XMemFree(collection->devID, ci);
}
}
}
/*
spread a collection tensor to the source tensor (cuda version).
This is a special spread function for the backward computation of the gather function.
>> source - the source tensor whose data would be modified
>> collection - the collection whose data would be spread to source tensor
>> srcIndex - index of the source sub-tensors
*/
void _CudaSpreadForGather(XTensor * source, XTensor * collection, XTensor * srcIndex)
{
int dim = 0;
int devID = source->devID;
int blockNum = 1;
int stride = source->GetDim(1);
int indexSize = srcIndex->unitNum;
int blockSizeSrc = stride * source->GetDim(dim);
int blockSizeColl = stride * collection->GetDim(dim);
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup;
ProtectCudaDev(source->devID, devIDBackup);
GDevs.GetCudaThread2D(devID, indexSize, stride, MAX_INT, cudaGrids, cudaBlocks);
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
DTYPE * s = (DTYPE*)source->data;
DTYPE * c = (DTYPE*)collection->data;
int * si = (int *)srcIndex->data;
KernelSpreadForGatherFuzed<<<blocks, threads >>>(s, c, blockNum, blockSizeSrc, blockSizeColl,
stride, indexSize, si);
BacktoCudaDev(source->devID, devIDBackup);
}
......
......@@ -34,6 +34,9 @@ void _CudaSpread(XTensor * source, XTensor * collection, int dim,
void _CudaSpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
/* special spread function for backward computation of gather function (cuda version) */
void _CudaSpreadForGather(XTensor * source, XTensor * collection, XTensor * srcIndex);
} // namespace nts(NiuTrans.Tensor)
#endif // __SPREAD_CUH__
\ No newline at end of file
......@@ -39,6 +39,9 @@ void Spread(XTensor * source, XTensor * collection, int dim,
void _SpreadForGather(XTensor * source, XTensor * collection, int dim,
int * srcIndex, int indexSize, int * collIndex);
/* special spread function for backward computation of gather function */
void _SpreadForGather(XTensor * source, XTensor * collection, XTensor * index);
} // namespace nts(NiuTrans.Tensor)
#endif // __SPREAD_H__
\ No newline at end of file
......@@ -208,8 +208,11 @@ merge small tensors into a big tensor
*/
void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
{
whereToMerge = (whereToMerge < 0 ? big->order - 1 : whereToMerge);
CheckNTErrors((smalls != NULL), "Invalid list!");
CheckNTErrors((smalls->count > 0), "Empty list!");
CheckNTErrors((whereToMerge >= 0 && whereToMerge < big->order), "Wrong range of whereToMerge");
bool uniform = true;
......
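With this change, a negative whereToMerge is interpreted as "merge along the last dimension of the big tensor". A hedged usage sketch (a, b, and big are assumed to be compatible XTensor pointers set up elsewhere):

XList smalls(2);
smalls.Add(a);
smalls.Add(b);

/* equivalent to _Merge(&smalls, big, big->order - 1) */
_Merge(&smalls, big, -1);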
......@@ -39,17 +39,11 @@ XTensor Reshape(XTensor &s, int order, int * dimSize)
t.SetTMPFlag();
_CopyValues(&s, &t);
int oriOrder = s.order;
int * oriDimSize = new int[order];
memcpy(oriDimSize, s.dimSize, sizeof(int) * order);
/* call Reshape function */
t.Reshape(order, dimSize);
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
XLink::AddParamToHeadInt(&t, oriOrder);
XLink::AddParamToHeadPointer(&t, oriDimSize);
return t;
}
......
......@@ -126,7 +126,7 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
void * dataTMP = t->data;
if (!isOnSameDevice)
dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(s->devID, size);
dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(mem->devID, size);
int realBlockSize = blockSize * t->unitSize;
int blockSplitSize = blockNum / splitNum;
......@@ -344,22 +344,6 @@ void Split(const XTensor &big, XList &smalls, int whereToSplit, int splitNum)
{
CheckNTErrors(big.GetDim(whereToSplit) % splitNum == 0, "Wrong splitNum!");
int order = big.order;
int * dimSize = new int[order];
for (int i = 0; i < big.order; i++) {
if (i != whereToSplit)
dimSize[i] = big.dimSize[i];
else
dimSize[i] = big.dimSize[whereToSplit] / splitNum;
}
float dr = (!big.isSparse) ? 1.0F : big.denseRatio;
for (int i = 0; i < splitNum; i++) {
XTensor * item = NewTensor(order, dimSize, big.dataType, dr, big.devID, big.mem);
smalls.Add(item);
}
delete[] dimSize;
/* call _Split function */
_Split(&big, &smalls, whereToSplit, splitNum);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-12
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-12
*/
#include "../core/math/Unary.h"
#include "TAbsolute.h"
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-12
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-12
*/
#include "TConvertDataType.h"
#include "../core/arithmetic/MatrixMul.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -102,7 +103,6 @@ bool TestConvertDataType1()
/*
case 2: test ConvertDataType function.
In this case, the int32 data type is converted to float32 data type.
*/
bool TestConvertDataType2()
{
......@@ -175,6 +175,122 @@ bool TestConvertDataType2()
#endif // USE_CUDA
}
/*
case 3: test ConvertDataType function.
In this case, the float data type is converted to float16 data type.
*/
bool TestConvertDataType3()
{
int order = 2;
/* a tensor of size (3, 2) */
int * dimSize1 = new int[order];
dimSize1[0] = 3;
dimSize1[1] = 2;
int unitNum1 = 1;
for (int i = 0; i < order; i++)
unitNum1 *= dimSize1[i];
/* a tensor of size (3, 2) */
int * dimSize2 = new int[order];
dimSize2[0] = 2;
dimSize2[1] = 3;
int unitNum2 = 1;
for (int i = 0; i < order; i++)
unitNum2 *= dimSize2[i];
/* a tensor of size (3, 3) */
int * dimSize3 = new int[order];
dimSize3[0] = 3;
dimSize3[1] = 3;
int unitNum3 = 1;
for (int i = 0; i < order; i++)
unitNum3 *= dimSize3[i];
DTYPE data1[3][2] = { {1.0F, -2.0F},
{0.5F, -4.0F},
{0.0F, 6.0F} };
DTYPE data2[2][3] = { {1.0F, 2.0F, 3.0F},
{0.0F, 4.0F, 5.0F} };
DTYPE answer[3][3] = { {1.0F, -6.0F, -7.0F},
{0.5F, -15.0F, -18.5F},
{0.0F, 24.0F, 30.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(order, dimSize1, X_FLOAT, 1.0F, -1);
XTensor * b = NewTensor(order, dimSize1, X_FLOAT16, 1.0F, -1);
XTensor * c = NewTensor(order, dimSize1, X_FLOAT, 1.0F, -1);
/* initialize variables */
a->SetData(data1, unitNum1);
/* call ConvertDataType function */
//_ConvertDataType(a, b);
//_ConvertDataType(b, c);
/* check results */
cpuTest = a->CheckData(data1, unitNum1, 1e-4F);
c->Dump(stderr, "");
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(order, dimSize1, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize2, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(order, dimSize1, X_FLOAT16, 1.0F, 0);
XTensor * dGPU = NewTensor(order, dimSize2, X_FLOAT16, 1.0F, 0);
XTensor * eGPU = NewTensor(order, dimSize3, X_FLOAT16, 1.0F, 0);
XTensor * fGPU = NewTensor(order, dimSize3, X_FLOAT, 1.0F, 0);
/* Initialize variables */
aGPU->SetData(data1, unitNum1);
bGPU->SetData(data2, unitNum2);
/* call ConvertDataType function */
_ConvertDataType(aGPU, cGPU);
_ConvertDataType(bGPU, dGPU);
_MatrixMul(cGPU, X_NOTRANS, dGPU, X_NOTRANS, eGPU);
_ConvertDataType(eGPU, fGPU);
/* check results */
gpuTest = fGPU->CheckData(answer, unitNum3, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete c;
delete aGPU;
delete bGPU;
delete cGPU;
delete dGPU;
delete eGPU;
delete fGPU;
delete[] dimSize1;
delete[] dimSize2;
delete[] dimSize3;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete[] dimSize1;
delete[] dimSize2;
delete[] dimSize3;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
......@@ -206,6 +322,16 @@ bool TestConvertDataType()
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* case 3 test */
caseFlag = TestConvertDataType3();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 3 failed!\n");
}
else
XPRINT(0, stdout, ">> case 3 passed!\n");
/* other cases test */
/*
TODO!!
......
......@@ -75,7 +75,6 @@ bool TestGather1()
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * t = NewTensor(tOrder, tDimSize);
XTensor tUser;
/* initialize variables */
s->SetData(sData, sUnitNum);
......@@ -83,10 +82,9 @@ bool TestGather1()
/* call Gather function */
_Gather(s, t, dim, srcIndex, indexSize);
tUser = Gather(*s, dim, srcIndex, indexSize);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum) && tUser.CheckData(answer, tUnitNum);
cpuTest = t->CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
......@@ -103,10 +101,9 @@ bool TestGather1()
/* call Gather function */
_Gather(sGPU, tGPU, dim, srcIndex, indexSize);
tUserGPU = Gather(*sGPU, dim, srcIndex, indexSize);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum) && tUserGPU.CheckData(answer, tUnitNum);
gpuTest = tGPU->CheckData(answer, tUnitNum);
/* destroy variables */
delete s;
......@@ -177,7 +174,6 @@ bool TestGather2()
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * t = NewTensor(tOrder, tDimSize);
XTensor tUser;
/* initialize variables */
s->SetData(sData, sUnitNum);
......@@ -185,10 +181,9 @@ bool TestGather2()
/* call Gather function */
_Gather(s, t, dim, srcIndex, indexSize);
tUser = Gather(*s, dim, srcIndex, indexSize);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum) && tUser.CheckData(answer, tUnitNum);
cpuTest = t->CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
......@@ -205,7 +200,6 @@ bool TestGather2()
/* call Gather function */
_Gather(sGPU, tGPU, dim, srcIndex, indexSize);
tUserGPU = Gather(*sGPU, dim, srcIndex, indexSize);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum) && tUserGPU.CheckData(answer, tUnitNum);
......@@ -230,6 +224,120 @@ bool TestGather2()
#endif // USE_CUDA
}
/*
case 3: gather indexed sub-tensors
In this case, (3, 3) -> (2, 3), dim = 0,
srcIndex = [0, 2]
*/
bool TestGather3()
{
/* an input tensor of size (3, 3) */
int sOrder = 2;
int * sDimSize = new int[sOrder];
sDimSize[0] = 3;
sDimSize[1] = 3;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
/* an output tensor of size (2, 3) */
int tOrder = 2;
int * tDimSize = new int[tOrder];
tDimSize[0] = 2;
tDimSize[1] = 3;
int tUnitNum = 1;
for (int i = 0; i < tOrder; i++)
tUnitNum *= tDimSize[i];
/* an index tensor of size (2) */
int indexOrder = 1;
int * indexDimSize = new int[indexOrder];
indexDimSize[0] = 2;
int indexUnitNum = 1;
for (int i = 0; i < indexOrder; i++)
indexUnitNum *= indexDimSize[i];
DTYPE sData[3][3] = { {0.0F, -1.0F, 2.0F},
{2.0F, 1.0F, 3.0F},
{1.0F, 2.0F, 4.0F} };
DTYPE answer[2][3] = { {0.0F, -1.0F, 2.0F},
{1.0F, 2.0F, 4.0F} };
int dim = 0;
int indexSize = 2;
int srcIndex[2] = {0, 2};
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * t = NewTensor(tOrder, tDimSize);
XTensor * index = NewTensor(indexOrder, indexDimSize, X_INT);
XTensor tUser;
/* initialize variables */
s->SetData(sData, sUnitNum);
t->SetZeroAll();
index->SetData(srcIndex, indexSize);
/* call Gather function */
_Gather(s, t, dim, srcIndex, indexSize);
//tUser = Gather(*s, *index);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum);
//tUser2.CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensors */
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * tGPU = NewTensor(tOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor * indexGPU = NewTensor(indexOrder, indexDimSize, X_INT, 1.0F, 0);
XTensor tUserGPU;
/* initialize variables */
sGPU->SetData(sData, sUnitNum);
tGPU->SetZeroAll();
indexGPU->SetData(srcIndex, indexSize);
/* call Gather function */
_Gather(sGPU, tGPU, dim, srcIndex, indexSize);
tUserGPU = Gather(*sGPU, *indexGPU);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum) &&
tUserGPU.CheckData(answer, tUnitNum);
/* destroy variables */
delete s;
delete t;
delete index;
delete sGPU;
delete tGPU;
delete indexGPU;
delete[] sDimSize;
delete[] tDimSize;
delete[] indexDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s;
delete t;
delete index;
delete[] sDimSize;
delete[] tDimSize;
delete[] indexDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
......@@ -259,6 +367,15 @@ bool TestGather()
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* case 3 test */
caseFlag = TestGather3();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 3 failed!\n");
}
else
XPRINT(0, stdout, ">> case 3 passed!\n");
/* other cases test */
/*
TODO!!
......
......@@ -272,6 +272,8 @@ bool TestSplit3()
XTensor * s = NewTensor(sOrder, sDimSize);
XTensor * t1 = NewTensor(tOrder1, tDimSize1);
XTensor * t2 = NewTensor(tOrder2, tDimSize2);
XTensor * t3 = NewTensor(tOrder2, tDimSize2);
XTensor * t4 = NewTensor(tOrder2, tDimSize2);
/* initialize variables */
s->SetData(sData, sUnitNum);
......@@ -282,6 +284,9 @@ bool TestSplit3()
tList->Add(t1);
tList->Add(t2);
tUserList.Add(t3);
tUserList.Add(t4);
/* call split function */
_Split(s, tList, 1, 2);
Split(*s, tUserList, 1, 2);
......@@ -302,6 +307,8 @@ bool TestSplit3()
XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
XTensor * tGPU1 = NewTensor(tOrder1, tDimSize1, X_FLOAT, 1.0F, 0);
XTensor * tGPU2 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
XTensor * tGPU3 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
XTensor * tGPU4 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
/* Initialize variables */
sGPU->SetData(sData, sUnitNum);
......@@ -312,6 +319,9 @@ bool TestSplit3()
tList->Add(tGPU1);
tList->Add(tGPU2);
tUserList.Add(tGPU3);
tUserList.Add(tGPU4);
/* call Split function */
_Split(sGPU, tList, 1, 2);
Split(*sGPU, tUserList, 1, 2);
......@@ -324,9 +334,13 @@ bool TestSplit3()
delete s;
delete t1;
delete t2;
delete t3;
delete t4;
delete sGPU;
delete tGPU1;
delete tGPU2;
delete tGPU3;
delete tGPU4;
delete[] sDimSize;
delete[] tDimSize1;
delete[] tDimSize2;
......@@ -338,6 +352,8 @@ bool TestSplit3()
delete s;
delete t1;
delete t2;
delete t3;
delete t4;
delete[] sDimSize;
delete[] tDimSize1;
delete[] tDimSize2;
......