Commit e0c7e275 by xiaotong

update the create function

parent 73f32216
......@@ -81,9 +81,11 @@ T2TPredictor::~T2TPredictor()
create an initial state
>> model - the t2t model
>> top - the top-most layer of the network
>> input - input of the network
>> beamSize - beam size
>> state - the state to be initialized
*/
void T2TPredictor::Create(T2TModel * model, XTensor * top, T2TStateBundle * state)
void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state)
{
state->layersEnc.Clear();
state->layersDec.Clear();
......@@ -93,6 +95,17 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, T2TStateBundle * stat
state->layersEnc.Add(encoding);
state->layersDec.Add(NULL);
int dims[MAX_TENSOR_DIM_NUM];
for (int i = 0; i < input->order - 1; i++)
dims[i] = input->GetDim(i);
dims[input->order - 1] = beamSize;
InitTensor(&state->probPath, input->order, dims, X_FLOAT);
InitTensor(&state->nstep, input->order, dims, X_FLOAT);
state->probPath.SetZeroAll();
state->nstep.SetZeroAll();
}
/*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论