Commit 4dbd1f23 by xiaotong

debugging

parent 979698b4
...@@ -145,7 +145,8 @@ predict the next state ...@@ -145,7 +145,8 @@ predict the next state
>> inputEnc - input of the encoder >> inputEnc - input of the encoder
>> paddingEnc - padding of the encoder >> paddingEnc - padding of the encoder
*/ */
void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc) void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
XTensor * inputEnc, XTensor * paddingEnc)
{ {
int dims[MAX_TENSOR_DIM_NUM]; int dims[MAX_TENSOR_DIM_NUM];
...@@ -162,7 +163,8 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor * ...@@ -162,7 +163,8 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, XTensor *
/* the first token */ /* the first token */
XTensor first; XTensor first;
CheckNTErrors(inputEnc->order >= 2, "Wrong order of the tensor!");
for(int i = 0; i < inputEnc->order - 1; i++) for(int i = 0; i < inputEnc->order - 1; i++)
dims[i] = inputEnc->GetDim(i); dims[i] = inputEnc->GetDim(i);
dims[inputEnc->order - 1] = 1; dims[inputEnc->order - 1] = 1;
......
...@@ -94,6 +94,14 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -94,6 +94,14 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* make the encoding network */ /* make the encoding network */
encoding = model->MakeEncoder(*input, maskEnc, false); encoding = model->MakeEncoder(*input, maskEnc, false);
encoding.SetName(ENCODING_NAME); encoding.SetName(ENCODING_NAME);
XTensor encodingBeam = Unsqueeze(encoding, encoding.order - 2, beamSize);
XTensor inputBeam = Unsqueeze(*input, input->order - 2, beamSize);
XTensor paddingBeam = Unsqueeze(*padding, padding->order - 2, beamSize);
encodingBeam.ReshapeMerged(encodingBeam.order - 4);
inputBeam.ReshapeMerged(inputBeam.order - 4);
paddingBeam.ReshapeMerged(paddingBeam.order - 4);
/* max output-length = 2 * source-length */ /* max output-length = 2 * source-length */
maxLength = input->GetDim(-1) * 2; maxLength = input->GetDim(-1) * 2;
...@@ -103,7 +111,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -103,7 +111,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
T2TStateBundle * first = states; T2TStateBundle * first = states;
/* create the first state */ /* create the first state */
predictor.Create(model, &encoding, input, beamSize, first); predictor.Create(model, &encodingBeam, input, beamSize, first);
predictor.SetStartSymbol(startSymbol); predictor.SetStartSymbol(startSymbol);
first->isStart = true; first->isStart = true;
...@@ -117,7 +125,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe ...@@ -117,7 +125,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
predictor.Read(model, cur); predictor.Read(model, cur);
/* predict the next state */ /* predict the next state */
predictor.Predict(next, &encoding, input, padding); predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam);
/* compute the model score (given the prediction probability) */ /* compute the model score (given the prediction probability) */
Score(cur, next); Score(cur, next);
......
...@@ -642,6 +642,33 @@ void XTensor::Reshape(const int rowNum, const int colNum) ...@@ -642,6 +642,33 @@ void XTensor::Reshape(const int rowNum, const int colNum)
Reshape(2, dims); Reshape(2, dims);
} }
/*
reshape the tensor by merging two consecutive dimensions
>> i - dimension i
>> j - i + 1
*/
void XTensor::ReshapeMerged(const int i, const int j)
{
if(i < 0)
return;
int di = i;
int dj = j < 0 ? i + 1: j;
CheckNTErrors(di < order, "Wrong dimension index!");
int dims[MAX_TENSOR_DIM_NUM];
for(int k = 0; k < di; k++)
dims[k] = dimSize[k];
dims[di] = dimSize[di] * dimSize[dj];
for(int k = dj + 1; k < order; k++)
dims[k - 1] = dimSize[k];
Reshape(order - 1, dims);
}
/* get the number of items in the data array */ /* get the number of items in the data array */
int XTensor::GetSize() const int XTensor::GetSize() const
{ {
......
...@@ -274,6 +274,9 @@ public: ...@@ -274,6 +274,9 @@ public:
/* reshape the tensor to a matrix */ /* reshape the tensor to a matrix */
void Reshape(const int rowNum, const int colNum); void Reshape(const int rowNum, const int colNum);
/* reshape the tensor by merging two consecutive dimensions */
void ReshapeMerged(const int i, const int j = -1);
/* get the number of items in the data array */ /* get the number of items in the data array */
int GetSize() const; int GetSize() const;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论