Commit c25d47f0 by xiaotong

define the mask function. By using Maks, we do not need to call Multiply to mask…

define the mask function. By using Maks, we do not need to call Multiply to mask tensor and it is cheaper
parent 70f3926a
......@@ -134,6 +134,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
XTensor &lenPrev = prev->nstep;
XTensor &len = beam->nstep;
XTensor lp;
XTensor mask;
InitTensor(&score, &prob);
......@@ -150,6 +151,14 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
/* score = log-prob/lp */
_Div(&score, &lp, &score);
InitTensor(&mask, &prev->endMark);
CopyValues(prev->endMark, mask);
_ScaleAndShiftMe(&mask, -1e9F);
/* mask the completed hypotheses so that they cannot
be involved in further sorting and beam search. */
_Sum(&score, &mask, &score);
}
/*
......@@ -232,16 +241,21 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
XTensor & modelScoreRef = beam->modelScore;
XTensor & probRef = beam->prob;
XTensor & probPathRef = beam->probPath;
XTensor & prediction = beam->prediction;
XTensor & predictionRef = beam->prediction;
XTensor & endMark = beam->endMark;
XTensor id;
XTensor modelScore;
XTensor prob;
XTensor probPath;
XTensor prediction;
XTensor endMarkCPU;
InitTensorOnCPU(&id, &idRef);
InitTensorOnCPU(&modelScore, &modelScoreRef);
InitTensorOnCPU(&prob, &probRef);
InitTensorOnCPU(&probPath, &probPathRef);
InitTensorOnCPU(&prediction, &predictionRef);
InitTensorOnCPU(&endMarkCPU, &predictionRef);
/* we copy the data to CPU because the frequent access to GPU is slow
and we can speed-up the process by doing the job on CPU. */
......@@ -269,7 +283,16 @@ void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
/* prediction */
state.prediction = prediction.GetInt(i);
/* check if it is the end of the sequence */
state.isEnd = IsEnd(state.prediction) ? 1 : 0;
/* set the ending mark */
endMarkCPU.SetInt(state.isEnd, i);
}
/* copy the ending mark from CPU to the target device */
CopyValues(endMarkCPU, endMark);
}
/*
......@@ -288,12 +311,9 @@ void T2TSearch::Collect(T2TStateBundle * beam)
CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!");
if(IsEnd(state.prediction)){
/* we push the hypothesis into the heap when it is completed */
if(state.isEnd != 0)
fullHypos[state.pid].Push(HeapNode<float>(&state, state.modelScore));
state.isEnd = 1;
}
else
state.isEnd = 0;
}
}
......@@ -301,7 +321,7 @@ void T2TSearch::Collect(T2TStateBundle * beam)
save the output sequences in a tensor
>> beam - the beam that keeps a number of states
*/
void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output)
void T2TSearch::Dump(T2TStateBundle * beam, XTensor * output)
{
}
......
......@@ -28,6 +28,7 @@
#include "arithmetic/Div.h"
#include "arithmetic/DivDim.h"
#include "arithmetic/Mask.h"
#include "arithmetic/MatrixMul.h"
#include "arithmetic/MatrixMul2D.h"
#include "arithmetic/MatrixMul2DMultiTheading.h"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论