Commit 7294cb66 by xiaotong

code cleaning and better device management

parent fcc18004
@@ -302,9 +302,10 @@ load a batch of sequences (for LM)
 >> paddingEnc - padding of the input sequences
 >> batchDec - the batch of the output sequences
 >> paddingDec - padding of the output sequences
->> gold - gold standard
+>> gold - gold standard (distribution of every position)
+>> label - (gold standard) label index of every position
 >> seqs - keep the sequences in an array
->> vs - vocabulary size
+>> vSize - vocabulary size
 >> sBatch - batch size of sequences
 >> wBatch - batch size of words
 >> isSorted - indicates whether the sequences are sorted by length
@@ -318,7 +319,7 @@ int T2TBatchLoader::LoadBatchLM(FILE * file,
                                 XTensor * batchDec, XTensor * paddingDec,
                                 XTensor * gold, XTensor * label,
                                 int * seqs,
-                                int vs, int sBatch, int wBatch,
+                                int vSize, int sBatch, int wBatch,
                                 bool isSorted, int &wCount,
                                 int devID, XMem * mem,
                                 bool isTraining)
@@ -355,7 +356,7 @@ int T2TBatchLoader::LoadBatchLM(FILE * file,
     int dims[MAX_TENSOR_DIM_NUM];
     dims[0] = sc;
     dims[1] = max;
-    dims[2] = vs;
+    dims[2] = vSize;
     InitTensor2D(batchEnc, sc, max, X_INT, devID, mem);
     InitTensor2D(label, sc, max, X_INT, devID, mem);
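
For reference, a minimal calling sketch for LoadBatchLM with the renamed vSize argument. This is not part of the commit: it assumes the surrounding project headers are included and that file, batchLoader, seqs, devID, mem, vSize, sBatch and wBatch are already set up by the caller, as in the trainer code further below.

    /* sketch only: argument order follows the signature shown above */
    XTensor batchEnc, paddingEnc, batchDec, paddingDec, gold, label;
    int wCount = 0;
    int count = batchLoader.LoadBatchLM(file,
                                        &batchEnc, &paddingEnc,
                                        &batchDec, &paddingDec,
                                        &gold, &label,
                                        seqs,
                                        vSize,          /* vocabulary size, renamed from vs */
                                        sBatch, wBatch,
                                        true,           /* isSorted */
                                        wCount,         /* filled by the loader (by reference) */
                                        devID, mem,
                                        true);          /* isTraining */
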
@@ -459,10 +460,11 @@ load a batch of sequences (for MT)
 >> paddingEnc - padding of the input sequences
 >> batchDec - the batch of the output sequences
 >> paddingDec - padding of the output sequences
->> gold - gold standard
+>> gold - gold standard (distribution of every position)
+>> label - (gold standard) label index of every position
 >> seqs - keep the sequences in an array
->> vsEnc - size of the encoder vocabulary
->> vsDec - size of the decoder vocabulary
+>> vSizeEnc - size of the encoder vocabulary
+>> vSizeDec - size of the decoder vocabulary
 >> sBatch - batch size of sequences
 >> wBatch - batch size of words
 >> isSorted - indicates whether the sequences are sorted by length
@@ -476,7 +478,7 @@ int T2TBatchLoader::LoadBatchMT(FILE * file,
                                 XTensor * batchDec, XTensor * paddingDec,
                                 XTensor * gold, XTensor * label,
                                 int * seqs,
-                                int vsEnc, int vsDec, int sBatch, int wBatch,
+                                int vSizeEnc, int vSizeDec, int sBatch, int wBatch,
                                 bool isSorted, int &ws, int &wCount,
                                 int devID, XMem * mem,
                                 bool isTraining)
......
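
The MT loader follows the same calling pattern; a short sketch with the renamed arguments, under the same assumptions as the LM sketch above (the two vocabulary sizes and the extra ws reference are the only differences visible in this diff):

    /* sketch only: vSizeEnc/vSizeDec replace the old vsEnc/vsDec names */
    int ws = 0, wCount = 0;
    int count = batchLoader.LoadBatchMT(file,
                                        &batchEnc, &paddingEnc, &batchDec, &paddingDec,
                                        &gold, &label, seqs,
                                        vSizeEnc, vSizeDec, /* encoder / decoder vocabulary sizes */
                                        sBatch, wBatch,
                                        true,               /* isSorted */
                                        ws, wCount,         /* filled by the loader (by reference) */
                                        devID, mem, true);  /* isTraining */
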
@@ -160,7 +160,7 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
     dims[inputDec.order - 1] = inputDec.GetDim(-1);
     XTensor paddingDec;
-    InitTensor(&paddingDec, inputDec.order, dims, X_INT);
+    InitTensor(&paddingDec, inputDec.order, dims, X_INT, 1.0F, paddingEnc->devID, paddingEnc->mem);
     SetDataFixedInt(paddingDec, 1);
     XTensor maskDec;
@@ -190,6 +190,9 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
         selectTgt.SetInt(i, i);
     }
+    selectSrc.SetDevice(decoding.devID, decoding.mem);
+    selectTgt.SetDevice(decoding.devID, decoding.mem);
     /* the decoder output of the last position */
     decodingStep = CopyIndexed(decoding, decoding.order - 2, selectSrc, selectTgt);
......
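
The two predictor changes above serve the same purpose: paddingDec is now created directly on paddingEnc's device and memory pool, and the index tensors are moved onto decoding's device before CopyIndexed reads them. A minimal sketch of that pattern, using only calls that appear in this commit; the index values and variable declarations are illustrative, and the SetInt argument order of value-then-offset is an assumption.

    /* sketch: build small index tensors on the default (CPU) device,
       then relocate them to the device that holds `decoding` */
    XTensor selectSrc, selectTgt;
    InitTensor1D(&selectSrc, 1, X_INT);
    InitTensor1D(&selectTgt, 1, X_INT);
    selectSrc.SetInt(decoding.GetDim(-2) - 1, 0);       /* read from the last decoding position */
    selectTgt.SetInt(0, 0);                             /* write it to position 0 of the result */
    selectSrc.SetDevice(decoding.devID, decoding.mem);  /* same device/memory pool as the data */
    selectTgt.SetDevice(decoding.devID, decoding.mem);
    XTensor decodingStep = CopyIndexed(decoding, decoding.order - 2, selectSrc, selectTgt);
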
@@ -344,7 +344,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
     batchLoader.ClearBuf();
     while(batchLoader.LoadBatch(file, model->isLM,
-                                &batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label,
+                                &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
                                 seqs, vSize, vSizeTgt,
                                 1, 1, false, ws, wc, devID, mem, false))
     {
@@ -369,8 +369,12 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
         XTensor probs;
         InitTensor1D(&probs, bSize * length);
+        XTensor labelOnehot;
+        labelOnehot = IndexToOnehot(label, vSizeTgt, 0);
         /* get probabilities */
-        float prob = GetProb(&output, &gold, &probs);
+        float prob = GetProb(&output, &labelOnehot, &probs);
         /* dump the test result */
         for(int s = 0; s < bSize; s++){
......
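
The updated Test routine scores the model output against a one-hot tensor built from the label indices instead of the gold tensor. Since label is a 2D int tensor (one gold index per position; see InitTensor2D(label, sc, max, X_INT, ...) above), IndexToOnehot presumably expands it over the target vocabulary, giving one distribution per position with a single 1 at the gold index. A condensed sketch of that step; the third IndexToOnehot argument is simply kept at 0, as in the commit.

    /* sketch: indices -> one-hot distributions over the target vocabulary */
    XTensor labelOnehot;
    labelOnehot = IndexToOnehot(label, vSizeTgt, 0);
    /* the gold probabilities are then read off the output against labelOnehot */
    float prob = GetProb(&output, &labelOnehot, &probs);
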
@@ -81,11 +81,8 @@ int TransformerMain(int argc, const char ** argv)
     if(strcmp(modelFN, ""))
         model.Read(modelFN);
     /* test the model on the new data */
-    if(strcmp(testFN, "") && strcmp(outputFN, ""))
-    {
+    if(strcmp(testFN, "") && strcmp(outputFN, "")){
         /* beam search */
         if(isBeamSearch){
             T2TTester searcher;
......
@@ -512,6 +512,21 @@ XTensor XTensor::Lin(DTYPE scale, DTYPE shift) const
 }
+/*
+relocate the data on the target device
+>> myDevId - target device id
+>> myMem - memory pool on the target device
+*/
+void XTensor::SetDevice(int myDevId, XMem * myMem)
+{
+    if (myMem != NULL) {
+        FlushToMem(myMem);
+    }
+    else {
+        ShowNTErrors("TODO!");
+    }
+}
 /*
 judge whether the two matrices are in the same type and size
 >> a - input tensor
 >> b - another tensor to compare with
......
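
XTensor::SetDevice is the helper the predictor changes above rely on. A usage sketch under the behavior shown in this commit: with a memory pool it forwards to FlushToMem, and the pool-less path is still marked "TODO!".

    /* sketch: relocate a CPU-built tensor into a device memory pool;
       devID and mem (the target XMem pool) are assumed to be available */
    XTensor t;
    InitTensor1D(&t, 8, X_INT);   /* created on the default (CPU) device */
    t.SetDevice(devID, mem);      /* moves the data via FlushToMem(mem) */
    /* t.SetDevice(devID) alone would currently hit the ShowNTErrors("TODO!") branch */

Note that on the pool path the myDevId argument is not consulted; the destination is determined entirely by the memory pool that is passed in.
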
@@ -248,6 +248,9 @@ public:
     /* linear transformation */
     XTensor Lin(DTYPE scale, DTYPE shift = 0) const;
+    /* relocate the data on the target device */
+    void SetDevice(int myDevId, XMem * myMem = NULL);
     /* judge whether the two matrices are in the same type and size */
     static
     bool IsSameShaped(const XTensor * a, const XTensor * b);
......