Commit a26caf40 by xiaotong

better code of SumDim

parent daa2f801
......@@ -41,8 +41,8 @@ int main( int argc, const char ** argv )
//TransposeTest();
//return 0;
SumDimTest();
return 0;
//SumDimTest();
//return 0;
if(argc > 1 && !strcmp(argv[1], "-test"))
1;//Test();
......
......@@ -37,6 +37,8 @@ void XMathGrad::MakeGrad(XTensor * node)
if(operID == MATH_SUM)
GradSum(node);
else if(operID == MATH_SUMDIM)
GradSumDim(node);
else if(operID == MATH_MULTIPLY)
GradMultiply(node);
else if(operID == MATH_MATRIXMUL)
......@@ -80,6 +82,90 @@ void XMathGrad::GradSum(XTensor * node)
}
/*
gradient for sum with one dimension
c = a + b * \beta
where the size of b is equal to dimension n of a, i.e., |b| = a.dimSize[n]
dE/da = dE/dc
dE/db = dE/dc * b.reduce(0,...,n-1,n+1,...) * \beta
*/
void XMathGrad::GradSumDim(XTensor * node)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 2, "Wrong input tensor number for SUM!");
XTensor * a = income.tails[0];
XTensor * b = income.tails[1];
int n = income.GetParamInt(0);
DTYPE beta = income.GetParam(1);
XNoder::MakeGrad(a);
XNoder::MakeGrad(b);
_Sum(a->grad, node->grad, a->grad);
int order = a->order;
int dimSize[MAX_TENSOR_DIM_NUM];
memcpy(dimSize, a->dimSize, sizeof(int) * a->order);
if(n == order - 1){
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = a->unitNum/dimSize[order - 1];
reshapedSize[1] = dimSize[order - 1];
/* we reshape dE/dc to a matrix whose column number is equal to the
size of b. Then we can reduce the matrix into a row vector. */
node->grad->Reshape(2, reshapedSize);
if(b->outgo.tailNum > 1){
XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);
_ReduceSum(node->grad, bGradTMP, 0);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
}
else
_ReduceSum(node->grad, b->grad, 0);
node->grad->Reshape(order, dimSize);
}
else{
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = 1;
reshapedSize[1] = dimSize[n];
reshapedSize[2] = 1;
for(int i = 0; i < order; i++){
if(i < n)
reshapedSize[0] *= dimSize[i];
}
reshapedSize[2] = a->unitNum / (reshapedSize[0] * reshapedSize[1]);
/* we reshape dE/dc to a 3D tensor of size (x, y, z) where y = |b|.
Then reduce along with z and x to obtain dE/db. */
node->grad->Reshape(3, reshapedSize);
XTensor * interGrad = NewTensorBuf(2, reshapedSize, b->devID, b->mem, b->dataType, b->denseRatio);
_ReduceSum(node->grad, interGrad, 2);
if(b->outgo.tailNum > 1){
XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);
_ReduceSum(interGrad, bGradTMP, 0);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
}
else
_ReduceSum(interGrad, b->grad, 0);
node->grad->Reshape(order, dimSize);
DelTensorBuf(interGrad);
}
node->visitMark = NODE_FINISHED;
}
/*
gradient for multiply (dot production)
for
c = a * b
......
......@@ -44,6 +44,11 @@ private:
static
void GradSum(XTensor * node);
/* gradient for sum with one dimension: c = a + b * \beta
where the size of b is equal to that of one dimension of a */
static
void GradSumDim(XTensor * node);
/* gradient for multiply (dot production): c = a * b */
static
void GradMultiply(XTensor * node);
......
......@@ -999,15 +999,11 @@ void ForwardAutoDiff(XTensor inputs[], XTensor &output, FNNModel &model)
hidden = Merge(hidden, 2, 0);
/* hidden layers */
for(int i = 0; i < depth; i++){
b = Unsqueeze(model.hiddenB[i], 0, batchSize);
hidden = MMul(hidden, model.hiddenW[i]) + b;
}
b = Unsqueeze(model.outputB, 0, batchSize);
for(int i = 0; i < depth; i++)
hidden = MMul(hidden, model.hiddenW[i]) + model.hiddenB[i];
/* output layer */
output = LogSoftmax(MMul(hidden, model.outputW) + b, 1);
output = LogSoftmax(MMul(hidden, model.outputW) + model.outputB, 1);
//XLink::ShowNetwork(stderr, &output);
}
......
......@@ -41,6 +41,8 @@ const char * GetOPName(int type)
return "M_SIGN";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_SUMDIM)
return "M_SUMDIM";
else if (type == MATH_LOG)
return "M_LOG";
else if (type == MATH_NORMALIZE)
......
......@@ -1885,12 +1885,13 @@ generate a XTensor which allocates data on the buffer
>> myDimSize - the size of each dimension
>> myMem - memory pool used to allocating the data array.
we actually allocate the data on the buffer associated with
the memory pool.
the memory pool
>> devID - device id
>> myDataType - unit size (e.g., int, float, and double)
>> myDenseRatio - how often an element has non-zero value
*/
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem,
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, int devID, XMem * myMem,
const TENSOR_DATA_TYPE myDataType, const float myDenseRatio)
{
CheckNTErrors(myMem != NULL, "No memory pool specified!");
......@@ -1901,12 +1902,31 @@ XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem,
dims[0] = -abs(dims[0]);
XTensor * tensor = NewTensor(myOrder, dims, myDataType, myDenseRatio, -1, myMem);
if(myMem != NULL)
tensor->data = myMem->AllocBuf(myMem->devID, tensor->unitNum * tensor->unitSize);
else
tensor->data = XMemAlloc(devID, tensor->unitNum * tensor->unitSize);
return tensor;
}
/*
generate a XTensor which allocates data on the buffer
>> reference - reference tensor
>> devID - device id
>> myMem - memory pool used to allocating the data array.
we actually allocate the data on the buffer associated with
the memory pool
*/
XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem)
{
return NewTensorBuf(reference->order, reference->dimSize,
devID, myMem,
reference->dataType, reference->denseRatio);
}
/*
generate a dense vector
>> num - number of entries
>> myDataType - unit size (e.g., int, float, and double)
......@@ -2056,7 +2076,7 @@ XTensor * NewTensor(XTensor * a, bool isFilledData)
free the data space of a given tensor
>> tensor - pointer to the tensor
*/
void DelTensor(const XTensor * tensor)
void DelTensor(XTensor * tensor)
{
delete tensor;
}
......@@ -2065,10 +2085,13 @@ void DelTensor(const XTensor * tensor)
free the data space of a given tensor (on the buffer)
>> tensor - pointer to the tensor
*/
void DelTensorBuf(const XTensor * tensor)
void DelTensorBuf(XTensor * tensor)
{
CheckNTErrors(tensor->mem != NULL, "No memory pool found!");
if(tensor->mem != NULL)
tensor->mem->ReleaseBuf(tensor->devID, tensor->unitNum * tensor->unitSize);
else
XMemFree(tensor->devID, tensor->data);
tensor->data = NULL;
delete tensor;
}
......
......@@ -391,9 +391,12 @@ XTensor * NewTensor(const int myOrder, const int * myDimSize, const TENSOR_DATA_
const float myDenseRatio = 1.0F, const int myDevID = -1, XMem * myMem = NULL);
/* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, XMem * myMem,
XTensor * NewTensorBuf(const int myOrder, const int * myDimSize, int devID, XMem * myMem,
const TENSOR_DATA_TYPE myDataType = X_FLOAT, const float myDenseRatio = 1.0F);
/* generate a XTensor which allocates data on the buffer */
XTensor * NewTensorBuf(const XTensor * reference, int devID, XMem * myMem);
/* generate a dense vector */
XTensor * NewTensor1D(const int num, const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1,
XMem * myMem = NULL);
......@@ -422,10 +425,10 @@ XTensor * NewTensor5D(const int d0, const int d1, const int d2, const int d3, co
XTensor * NewTensor(XTensor * a, bool isFilledData = true);
/* free the data space of a given tensor */
void DelTensor(const XTensor * tensor);
void DelTensor(XTensor * tensor);
/* free the data space of a given tensor (on the buffer) */
void DelTensorBuf(const XTensor * tensor);
void DelTensorBuf(XTensor * tensor);
} /* end of the nts (NiuTrans.Tensor) namespace */
......
......@@ -24,6 +24,7 @@
#include "../../XUtility.h"
#include "Sum.h"
#include "Sum.cuh"
#include "SumDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -125,6 +126,33 @@ void _SumMe(XTensor * a, const XTensor * b, DTYPE beta)
}
/*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h
>> a - a tensor
>> b - another tensor for sum
*/
int GetSumDimIndex(const XTensor &a, const XTensor &b)
{
if(a.order < b.order)
return -1;
int hitCount = 0;
int hitDim = -1;
for(int i = 0; i < b.order; i++){
if(b.dimSize[b.order - 1 - i] == 1)
continue;
else if(b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]){
hitCount++;
hitDim = a.order - b.order + i;
}
}
if(hitCount == 1)
return hitDim;
else
return -1;
}
/*
tensor summation c = a + b * \beta (return a XTensor structure)
make a new tensor c to keep the result and return it
......@@ -138,12 +166,28 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
XTensor c(&a);
c.SetTMP();
int n = GetSumDimIndex(a, b);
if(n == -1){
/* call _Sum function */
_Sum(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta);
}
else if(n >= 0 && n < a.order){
/* call _Sum function */
_SumDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
else{
ShowNTErrors("Something is wrong!");
}
return c;
}
......
......@@ -151,7 +151,7 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
c.SetTMP();
/* call _Sum function */
_Sum(&a, &b, &c, beta);
_SumDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
......
......@@ -138,6 +138,9 @@ void _CudaSumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE
int cudaGrids[3];
int cudaBlocks[3];
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE){
if(stride > 1){
GDevs.GetCudaThread2D(a->devID, stride * blockNum, blockSize, MAX_INT, cudaGrids, cudaBlocks);
......@@ -168,6 +171,8 @@ void _CudaSumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif
......
......@@ -73,17 +73,23 @@ void _CopyIndexed(const XTensor * s, XTensor * t, int dim, int * srcIndex, int i
int * realSrcIndex = new int[realIndexSize];
int * realTgtIndex = new int[realIndexSize];
for (int i = 0; i < indexOffsetNum; i++) {
int base = i * indexSize * copyNum;
int baseSrc = i * leadDimSizeSrc;
int baseTgt = i * leadDimSizeTgt;
for (int j = 0; j < indexSize; j++) {
int offset = base + j * copyNum;
int * rsi = realSrcIndex + offset;
int * rti = realTgtIndex + offset;
for (int k = 0; k < copyNum; k++) {
realSrcIndex[i * indexSize * copyNum + j * copyNum + k] = i * leadDimSizeSrc + srcIndex[j] + k;
realTgtIndex[i * indexSize * copyNum + j * copyNum + k] = i * leadDimSizeTgt + tgtIndex[j] + k;
rsi[k] = baseSrc + srcIndex[j] + k;
rti[k] = baseTgt + tgtIndex[j] + k;
}
}
}
for (int i = 0; i < indexSize; i++) {
CheckNTErrors((srcIndex[i] < blockNumSrc), "Index is out of range!");
CheckNTErrors((tgtIndex[i] < blockNumTgt), "Index is out of range!");
CheckNTErrors((srcIndex[i] < blockNumSrc), "Index is out of scope!");
CheckNTErrors((tgtIndex[i] < blockNumTgt), "Index is out of scope!");
}
_CopyBlocks(s->data, blockSizeSrc * s->unitSize, realSrcIndex, realIndexSize, t->data, realTgtIndex, s->mem, s->devID);
......
......@@ -32,12 +32,108 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
>> s - pointer to the source data array
>> blockSize - size of a block
>> totalSize - total size of the blocks (i.e., blockSIze * n)
>> t - pointer to the target data array
>> n - number of blocks to copy data
*/
template<class T>
__global__
void KernelUnsqueezeFlat(void * s, int blockSize, int totalSize, void * t, int n)
{
/* index of data items */
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= blockSize)
return;
T value = ((T*)s)[i];
T * tData = (T*)t;
__syncthreads();
for (int k = i; k < totalSize; k += blockSize)
tData[k] = value;
}
/*
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
>> s - pointer to the source data array
>> blockSize - size of a block
>> totalSize - total size of the blocks (i.e., blockSIze * n)
>> t - pointer to the target data array
>> n - number of blocks to copy data
*/
template<class T>
__global__
void KernelUnsqueezeFlatBigram(void * s, int blockSize, int totalSize, void * t, int n)
{
/* index of data items */
int i = (blockDim.x * blockIdx.x + threadIdx.x) * 2;
if (i >= blockSize)
return;
T value = ((T*)s)[i];
T value2 = ((T*)s)[i + 1];
T * tData = (T*)t;
__syncthreads();
for (int k = i; k < totalSize; k += blockSize){
tData[k] = value;
tData[k + 1] = value2;
}
}
/*
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
>> s - pointer to the source data array
>> blockSize - size of a block
>> totalSize - total size of the blocks (i.e., blockSIze * n)
>> t - pointer to the target data array
>> n - number of blocks to copy data
*/
template<class T>
__global__
void KernelUnsqueezeFlat2D(void * s, int blockSize, int totalSize, void * t, int n)
{
__shared__ T data[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ int offsets[MAX_CUDA_THREAD_NUM_PER_BLOCK];
/* index of data items */
int i = blockDim.x * blockIdx.x + threadIdx.x;
/* index of data items */
int j = blockDim.y * blockIdx.y + threadIdx.y;
if (i >= blockSize || j >= n)
return;
if(threadIdx.y == 0)
data[threadIdx.x] = ((T*)s)[i];
if(threadIdx.x == 0)
offsets[threadIdx.y] = blockSize * j;
__syncthreads();
((T*)t)[offsets[threadIdx.y] + i] = data[threadIdx.x];
}
/*
insert a dimension by copying the blocks for n times (where n is the size of the inerted dimension)
>> s - pointer to the source data array
>> blockSize - size of a block
>> blockNum - number of the blocks
>> totalSize - total size of the blocks (i.e., blockSIze * n)
>> t - pointer to the target data array
>> n - number of blocks to copy data
*/
template<class T>
__global__
void KernelUnsqueeze(void * s, int blockSize, int blockNum, void * t, int n)
void KernelUnsqueeze(void * s, int blockSize, int blockNum, int totalSize, void * t, int n)
{
/* index of data items */
int i = blockDim.x * blockIdx.x + threadIdx.x;
......@@ -51,11 +147,10 @@ void KernelUnsqueeze(void * s, int blockSize, int blockNum, void * t, int n)
MTYPE offset = blockSize * j;
T value = ((T*)s)[offset + i];
T * tData = (T*)t + offset * n;
int length = blockSize * n;
__syncthreads();
for (int k = i; k < length; k += blockSize)
for (int k = i; k < totalSize; k += blockSize)
tData[k] = value;
}
......@@ -83,22 +178,72 @@ void _CudaUnsqueeze(const XTensor * a, XTensor * b, int dim, int dSize)
int cudaGrids[3];
int cudaBlocks[3];
GDevs.GetCudaThread2D(a->devID, blockSize, blockNumA, MAX_INT, cudaGrids, cudaBlocks);
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
if(blockNumA > 1){
GDevs.GetCudaThread2D(a->devID, blockSize, blockNumA, MAX_INT, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT && a->dataType == X_FLOAT) {
KernelUnsqueeze<float> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockNumA, b->data, dSize);
(a->data, blockSize, blockNumA, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT && a->dataType == X_INT) {
KernelUnsqueeze<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockNumA, b->data, dSize);
(a->data, blockSize, blockNumA, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
}
else if(blockNumA == 1 && blockSize < MAX_CUDA_THREAD_NUM_PER_BLOCK){
GDevs.GetCudaThread2D(a->devID, blockSize, dSize, MAX_CUDA_THREAD_NUM_PER_BLOCK/4, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT && a->dataType == X_FLOAT) {
KernelUnsqueezeFlat2D<float> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT && a->dataType == X_INT) {
KernelUnsqueezeFlat2D<int> << <dim3(cudaGrids[0], cudaGrids[1]), dim3(cudaBlocks[0], cudaBlocks[1]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
}
else if(blockNumA == 1 && blockSize % 2 == 0){
GDevs.GetCudaThread(a->devID, blockSize/2, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT && a->dataType == X_FLOAT) {
KernelUnsqueezeFlatBigram<float> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT && a->dataType == X_INT) {
KernelUnsqueezeFlatBigram<int> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
}
else if(blockNumA == 1){
GDevs.GetCudaThread(a->devID, blockSize, cudaGrids, cudaBlocks);
if (a->dataType == X_FLOAT && a->dataType == X_FLOAT) {
KernelUnsqueezeFlat<float> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else if (a->dataType == X_INT && a->dataType == X_INT) {
KernelUnsqueezeFlat<int> << <dim3(cudaGrids[0]), dim3(cudaBlocks[0]) >> >
(a->data, blockSize, blockSize * dSize, b->data, dSize);
}
else {
ShowNTErrors("TODO!");
}
}
else{
ShowNTErrors("Something is wrong!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论