Commit 909c7236 by xiaotong

coding of predictor

parent df09abef
......@@ -28,6 +28,36 @@ namespace transformer
{
/* constructor */
T2TStateBundle::T2TStateBundle()
{
states = NULL;
}
/* de-constructor */
T2TStateBundle::~T2TStateBundle()
{
if(states != NULL)
delete[] states;
}
/*
create states
>> num - number of states
*/
void T2TStateBundle::MakeStates(int num)
{
CheckNTErrors(num > 0, "invalid number");
if(states != NULL)
delete[] states;
states = new T2TState[num];
for(int i = 0; i < num; i++)
states[i].last = NULL;
}
/* constructor */
T2TPredictor::T2TPredictor()
{
}
......
......@@ -65,6 +65,19 @@ public:
/* layers on the decoder side */
XList layersDec;
/* list of states */
T2TState * states;
public:
/* constructor */
T2TStateBundle();
/* de-constructor */
~T2TStateBundle();
/* create states */
void MakeStates(int num);
};
/* The predictor reads the current state and then predicts the next.
......
......@@ -118,12 +118,31 @@ void T2TSearch::Generate(T2TStateBundle * beam)
score.Reshape(order, dimsBeam);
/* keep the most promissing candidates in the beam */
TopK(score, scoreTopK, index, 0, beamSize);
score.Reshape(order, dims);
}
/*
expand the search graph
>> beam - the beam that keeps a number of states
*/
void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
{
beam->MakeStates(beam->prediction.unitNum);
T2TState * states = beam->states;
XTensor &predict = beam->prediction;
XTensor index = *NewTensorBuf(predict.order - 1, predict.dimSize, X_FLOAT, 1.0F,
predict.devID, predict.mem);
index.SetAscendingOrder(-1);
DelTensorBuf(&index);
}
/*
save the output sequences in a tensor
>> beam - the beam that keeps a number of states
*/
......
......@@ -60,6 +60,9 @@ public:
/* generate token indices via beam pruning */
void Generate(T2TStateBundle * beam);
/* expand the search graph */
void Expand(T2TStateBundle * prev, T2TStateBundle * beam);
/* save the output sequences in a tensor */
void DumpOutput(T2TStateBundle * beam, XTensor * output);
};
......
......@@ -969,8 +969,20 @@ set the cell to the ascending order along a given dimension
*/
void XTensor::SetAscendingOrder(int dim)
{
CheckNTErrors((dim >= 0 && dim < order), "Wrong dimension specified!");
CheckNTErrors((dataType == X_INT), "TODO!");
CheckNTErrors(dim < order, "Wrong dimension specified!");
CheckNTErrors(dataType == X_INT, "TODO!");
if(dim < 0){
int o = order;
int ds[MAX_TENSOR_DIM_NUM];
memcpy(ds, dimSize, sizeof(int) * order);
Reshape(unitNum);
SetAscendingOrder(0);
Reshape(o, ds);
return;
}
int dimRDI = order - dim - 1;
if(devID >= 0){
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论