Commit ce92843d by xiaotong

speed-up data fetching

parent 229db8c6
......@@ -862,18 +862,27 @@ int T2TTrainer::LoadBatchMT(FILE * file,
wCount = 0;
MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2];
MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
/* batch of the source-side sequences */
for(int s = seq; s < seq + sc; s += 2){
int len = seqLen[s];
int sent = (s - seq)/2;
for(int w = 0; w < len; w++){
batchEnc->Set2DInt(buf[seqOffset[s] + w], sent, w);
//batchEnc->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
paddingEnc->Set2D(1.0F, sent, w);
//paddingEnc->Set2D(1.0F, sent, w);
paddingEncOffsets[wCount] = paddingEnc->GetOffset2D(sent, w);
wCount++;
}
}
paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount);
int wCountDec = 0;
int wGold = 0;
/* batch of the target-side sequences */
for(int s = seq + 1; s < seq + sc; s += 2){
int len = isDoubledEnd ? seqLen[s] : seqLen[s] - 1;
......@@ -881,15 +890,21 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int sent = (s - seq - 1)/2;
for(int w = 0; w < len; w++){
batchDec->Set2DInt(buf[seqOffset[s] + w], sent, w);
//batchDec->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
paddingDec->Set2D(1.0F, sent, w);
if(w > 0)
gold->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
//paddingDec->Set2D(1.0F, sent, w);
paddingDecOffsets[wCountDec++] = paddingDec->GetOffset2D(sent, w);
if (w > 0) {
//gold->Set3D(1.0F, sent, w - 1, buf[seqOffset[s] + w]);
goldOffsets[wGold++] = gold->GetOffset3D(sent, w - 1, buf[seqOffset[s] + w]);
}
if (w == len - 1) {
if(isDoubledEnd)
gold->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
else
gold->Set3D(1.0F, sent, w, buf[seqOffset[s] + w + 1]);
if (isDoubledEnd) {
//gold->Set3D(1.0F, sent, w, buf[seqOffset[s] + w]);
goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w]);
}
else {
//gold->Set3D(1.0F, sent, w, buf[seqOffset[s] + w + 1]);
goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w + 1]);
}
}
wCount++;
......@@ -903,6 +918,12 @@ int T2TTrainer::LoadBatchMT(FILE * file,
}
}
paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCountDec);
gold->SetDataBatched(goldOffsets, 1.0F, wGold);
delete[] paddingEncOffsets;
delete[] paddingDecOffsets;
return sc;
}
......
......@@ -46,6 +46,7 @@
#include "core/arithmetic/Sub.h"
#include "core/arithmetic/Div.h"
#include "core/math/ScaleAndShift.h"
#include "core/getandset/SetData.h"
#include "function/Identity.h"
#ifdef USE_CUDA
......@@ -585,6 +586,36 @@ int XTensor::GetUnitSize(TENSOR_DATA_TYPE myDataType)
}
/*
get offset (2D)
>> row - index of demension 0
>> col - index of demension 1
*/
MTYPE XTensor::GetOffset2D(int row, int col)
{
CheckNTErrors(order == 2, "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors(row >= 0 && row < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors(col >= 0 && col < dimSize[1], "dimension 1 is out of range!");
return row * dimSize[1] + col;
}
/*
get offset (3D)
>> d0 - index of demension 0
>> d1 - index of demension 1
>> d2 - index of demension 2
*/
MTYPE XTensor::GetOffset3D(int d0, int d1, int d2)
{
CheckNTErrors(order == 3, "Cannot get a 3d cell for a tensor whose order is not 2!");
CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 2 is out of range!");
return (d0 * dimSize[1] + d1) * dimSize[2] + d2;
}
/*
a vector with all entries of 0
>> stream - stream for the job pipeline
*/
......@@ -756,6 +787,17 @@ void XTensor::SetDataRandn(DTYPE mean, DTYPE standardDeviation)
}
}
/*
set tensor items with an array of offsets
>> offsets - offset for each data item
>> values - value for each data item
>> num - number of the data items
*/
void XTensor::SetDataBatched(MTYPE * offsets, DTYPE value, int num)
{
_SetDataWithOffset(this, offsets, value, num);
}
/* check whether the data array is the same as the answer
>> d - input data. it must be on CPU
>> num - number of data items
......
......@@ -261,18 +261,27 @@ public:
/* get unit size in terms of "dataType" */
int GetUnitSize(TENSOR_DATA_TYPE myDataType);
/* get offset (2D) */
MTYPE GetOffset2D(int row, int col);
/* get offset (3D) */
MTYPE GetOffset3D(int d0, int d1, int d2);
/* a tensor with all entries of 0 */
void SetZeroAll(XStream * stream = NULL);
/* set the tensor with an data array */
void SetData(const void * d, int num, int beg = 0);
/* set the tensor items by a uniform distribution */
/* set tensor items by a uniform distribution */
void SetDataRand(DTYPE lower, DTYPE upper);
/* set the tensor items by a normal distribution */
/* set tensor items by a normal distribution */
void SetDataRandn(DTYPE mean, DTYPE standardDeviation);
/* set tensor items with an array of offsets */
void SetDataBatched(MTYPE * offsets, DTYPE value, int num);
/* check whether the data array is the same as the answer */
bool CheckData(const void * answer, int num, int beg = 0);
......
......@@ -432,6 +432,7 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
/*
generate data items with a normal distribution with specified mean and standard deviation
>> tensor - the tensor that keeps the data
>> mean - mean or expectation of the distribution
>> standardDeviation - standard deviation of the distribution
*/
......@@ -441,5 +442,53 @@ void _SetDataRandN(XTensor * tensor, DTYPE mean, DTYPE standardDeviation)
tensor->SetDataRandn(mean, standardDeviation);
}
/*
set the data with an array of offsets
>> tensor - the tensor that keeps the data
>> offsets - offset for each data item
>> num - number of the data items
>> value - value of the data items
*/
void _SetDataWithOffset(XTensor * tensor, MTYPE * offsets, DTYPE value, MTYPE num)
{
CheckNTErrors(tensor->dataType == X_FLOAT, "Data type is incorrect!");
if (tensor->devID < 0) {
DTYPE * d = (DTYPE*)tensor->data;
for (int i = 0; i < num; i++) {
d[offsets[i]] = value;
}
}
else {
#ifdef USE_CUDA
XMem * mem = tensor->mem;
MTYPE size = num * sizeof(MTYPE);
MTYPE * offsetsCuda = mem != NULL ? (MTYPE*)mem->AllocBuf(mem->devID, size) : (MTYPE*)XMemAlloc(mem->devID, size);
XMemCopy(offsetsCuda, tensor->devID, offsets, -1, num * sizeof(MTYPE));
_CudaSetDataWithOffset(tensor, offsetsCuda, value, num);
if (mem != NULL)
mem->ReleaseBuf(mem->devID, size);
else
XMemFree(mem->devID, offsetsCuda);
#else
ShowNTErrors("Please recompile the code with USE_CUDA");
#endif
}
}
/*
set the data with an array of values
>> tensor - the tensor that keeps the data
>> offsets - offset for each data item
>> values - value for each data item
>> num - number of the data items
*/
void _SetDataWithOffsetAndValue(XTensor * tensor, MTYPE * offsets, DTYPE * values, MTYPE num)
{
ShowNTErrors("TODO!");
}
} // namespace nts(NiuTrans.Tensor)
......@@ -466,4 +466,48 @@ void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
set the data with an array of offsets (kernel version)
>> data - pointer to the data array
>> offsets - offset for each data item
>> num - number of the data items
>> value - value of the data items
*/
__global__
void _KernelSetDataWithOffset(DTYPE * data, MTYPE * offsets, DTYPE value, MTYPE num)
{
/* index */
int i = blockDim.x * blockIdx.x + threadIdx.x;
if(i < num)
data[offsets[i]] = value;
}
/*
set the data with an array of offsets (cuda version)
>> tensor - the tensor that keeps the data
>> offsets - offset for each data item
>> num - number of the data items
>> value - value of the data items
*/
void _CudaSetDataWithOffset(XTensor * tensor, MTYPE * offsets, DTYPE value, MTYPE num)
{
CheckNTErrors(tensor->dataType == X_FLOAT, "Data type is incorrect!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(tensor->devID, (int)num, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
_KernelSetDataWithOffset << <blocks, threads >> > ((DTYPE*)tensor->data, offsets, value, num);
BacktoCudaDev(tensor->devID, devIDBackup);
}
} // namespace nts(NiuTrans.Tensor)
......@@ -49,6 +49,9 @@ void _CudaSetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [lower, upper] */
void _CudaSetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
/* set the data with an array of offsets */
void _CudaSetDataWithOffset(XTensor * tensor, MTYPE * offsets, DTYPE value, MTYPE num);
} // namespace nts(NiuTrans.Tensor)
#endif // __SETDATA_CUH__
\ No newline at end of file
......@@ -60,6 +60,12 @@ void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a normal distribution with specified mean and standard deviation */
void _SetDataRandN(XTensor * tensor, DTYPE mean = 0.0F, DTYPE standardDeviation = 1.0F);
/* set the data with an array of offsets */
void _SetDataWithOffset(XTensor * tensor, MTYPE * offsets, DTYPE value, MTYPE num);
/* set the data with an array of values */
void _SetDataWithOffsetAndValue(XTensor * tensor, MTYPE * offsets, DTYPE * values, MTYPE num);
} // namespace nts(NiuTrans.Tensor)
#endif // __SETDATA_H__
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论