Commit e0c7e275 by xiaotong

update the create function

parent 73f32216
...@@ -81,9 +81,11 @@ T2TPredictor::~T2TPredictor() ...@@ -81,9 +81,11 @@ T2TPredictor::~T2TPredictor()
create an initial state create an initial state
>> model - the t2t model >> model - the t2t model
>> top - the top-most layer of the network >> top - the top-most layer of the network
>> input - input of the network
>> beamSize - beam size
>> state - the state to be initialized >> state - the state to be initialized
*/ */
void T2TPredictor::Create(T2TModel * model, XTensor * top, T2TStateBundle * state) void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state)
{ {
state->layersEnc.Clear(); state->layersEnc.Clear();
state->layersDec.Clear(); state->layersDec.Clear();
...@@ -93,6 +95,17 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, T2TStateBundle * stat ...@@ -93,6 +95,17 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, T2TStateBundle * stat
state->layersEnc.Add(encoding); state->layersEnc.Add(encoding);
state->layersDec.Add(NULL); state->layersDec.Add(NULL);
int dims[MAX_TENSOR_DIM_NUM];
for (int i = 0; i < input->order - 1; i++)
dims[i] = input->GetDim(i);
dims[input->order - 1] = beamSize;
InitTensor(&state->probPath, input->order, dims, X_FLOAT);
InitTensor(&state->nstep, input->order, dims, X_FLOAT);
state->probPath.SetZeroAll();
state->nstep.SetZeroAll();
} }
/* /*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论