Commit 1faabe78 by xiaotong

probability of one step

parent c5fb044c
...@@ -135,6 +135,8 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * ...@@ -135,6 +135,8 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* prediction probabilities */ /* prediction probabilities */
XTensor &output = next->prob; XTensor &output = next->prob;
XTensor decoding;
XTensor decodingStep;
XTensor paddingDec; XTensor paddingDec;
InitTensor3D(&paddingDec, inputDec.GetDim(0), inputDec.GetDim(1), m->outputLayer->vSize, X_INT); InitTensor3D(&paddingDec, inputDec.GetDim(0), inputDec.GetDim(1), m->outputLayer->vSize, X_INT);
...@@ -146,8 +148,33 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * ...@@ -146,8 +148,33 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* decoder mask */ /* decoder mask */
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec); m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
/* make the decoding network and generate the output probabilities */ /* make the decoding network */
output = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false); decoding = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false);
XTensor selectSrc;
XTensor selectTgt;
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
int dims[MAX_TENSOR_DIM_NUM];
for(int i = 0; i < decoding.order - 1; i++)
dims[i] = decoding.GetDim(i);
dims[decoding.order - 2] = 1;
InitTensor(&selectSrc, decoding.order - 1, dims, X_INT);
InitTensor(&selectTgt, decoding.order - 1, dims, X_INT);
int stride = decoding.GetDim(decoding.order - 2);
for(int i = 0; i < selectSrc.unitNum; i++){
selectSrc.SetInt(i * stride + stride - 1, i);
selectTgt.SetInt(i, i);
}
/* the decoder output of the last position */
decodingStep = CopyIndexed(decoding, decoding.order - 2, selectSrc, selectTgt);
/* generate the output probabilities */
m->outputLayer->Make(decodingStep, output);
next->layersEnc.AddList(&s->layersEnc); next->layersEnc.AddList(&s->layersEnc);
next->layersDec.Add(&inputDec); next->layersDec.Add(&inputDec);
......
...@@ -319,10 +319,17 @@ void T2TSearch::Collect(T2TStateBundle * beam) ...@@ -319,10 +319,17 @@ void T2TSearch::Collect(T2TStateBundle * beam)
/* /*
save the output sequences in a tensor save the output sequences in a tensor
>> beam - the beam that keeps a number of states >> input - input sequences
>> output - output sequences (for return)
*/ */
void T2TSearch::Dump(T2TStateBundle * beam, XTensor * output) void T2TSearch::Dump(XTensor * input, XTensor * output)
{ {
int dims[MAX_TENSOR_DIM_NUM];
for(int i = 0; i < input->order - 1; i++)
dims[i] = input->GetDim(i);
dims[input->order - 1] = maxLength;
InitTensor(output, input->order, dims, X_INT);
} }
/* /*
......
...@@ -88,7 +88,7 @@ public: ...@@ -88,7 +88,7 @@ public:
void Collect(T2TStateBundle * beam); void Collect(T2TStateBundle * beam);
/* save the output sequences in a tensor */ /* save the output sequences in a tensor */
void Dump(T2TStateBundle * beam, XTensor * output); void Dump(XTensor * input, XTensor * output);
/* check if the token is an end symbol */ /* check if the token is an end symbol */
bool IsEnd(int token); bool IsEnd(int token);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论