Commit 1faabe78 by xiaotong

probability of one step

parent c5fb044c
......@@ -135,6 +135,8 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* prediction probabilities */
XTensor &output = next->prob;
XTensor decoding;
XTensor decodingStep;
XTensor paddingDec;
InitTensor3D(&paddingDec, inputDec.GetDim(0), inputDec.GetDim(1), m->outputLayer->vSize, X_INT);
......@@ -146,8 +148,33 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* decoder mask */
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec);
/* make the decoding network and generate the output probabilities */
output = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false);
/* make the decoding network */
decoding = decoder.Make(inputDec, *encoding, maskDec, maskEncDec, false);
XTensor selectSrc;
XTensor selectTgt;
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
int dims[MAX_TENSOR_DIM_NUM];
for(int i = 0; i < decoding.order - 1; i++)
dims[i] = decoding.GetDim(i);
dims[decoding.order - 2] = 1;
InitTensor(&selectSrc, decoding.order - 1, dims, X_INT);
InitTensor(&selectTgt, decoding.order - 1, dims, X_INT);
int stride = decoding.GetDim(decoding.order - 2);
for(int i = 0; i < selectSrc.unitNum; i++){
selectSrc.SetInt(i * stride + stride - 1, i);
selectTgt.SetInt(i, i);
}
/* the decoder output of the last position */
decodingStep = CopyIndexed(decoding, decoding.order - 2, selectSrc, selectTgt);
/* generate the output probabilities */
m->outputLayer->Make(decodingStep, output);
next->layersEnc.AddList(&s->layersEnc);
next->layersDec.Add(&inputDec);
......
......@@ -319,10 +319,17 @@ void T2TSearch::Collect(T2TStateBundle * beam)
/*
save the output sequences in a tensor
The body shown here only sets up the shape of "output" and initializes it;
no sequence content is written into it in the visible code.
>> beam - the beam that keeps a number of states
          (parameter of the T2TStateBundle-based signature)
>> input - input sequences; its order and all of its dimensions except the
           last one are reused as the shape of the output
>> output - output sequences (for return); initialized as an X_INT tensor
            whose last dimension is maxLength
*/
void T2TSearch::Dump(T2TStateBundle * beam, XTensor * output)
void T2TSearch::Dump(XTensor * input, XTensor * output)
{
/* shape of the output tensor */
int dims[MAX_TENSOR_DIM_NUM];
/* copy every dimension of the input except the last one */
for(int i = 0; i < input->order - 1; i++)
dims[i] = input->GetDim(i);
/* the last dimension is replaced by maxLength (defined outside this function;
   presumably the maximum length of an output sequence - TODO confirm) */
dims[input->order - 1] = maxLength;
/* initialize the output with the same order as the input and data type X_INT */
InitTensor(output, input->order, dims, X_INT);
}
/*
......
......@@ -88,7 +88,7 @@ public:
void Collect(T2TStateBundle * beam);
/* save the output sequences in a tensor */
void Dump(T2TStateBundle * beam, XTensor * output);
void Dump(XTensor * input, XTensor * output);
/* check if the token is an end symbol */
bool IsEnd(int token);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论