Commit 909c7236 by xiaotong

coding of predictor

parent df09abef
...@@ -28,6 +28,36 @@ namespace transformer ...@@ -28,6 +28,36 @@ namespace transformer
{ {
/* constructor */ /* constructor */
T2TStateBundle::T2TStateBundle()
{
states = NULL;
}
/* de-constructor */
T2TStateBundle::~T2TStateBundle()
{
if(states != NULL)
delete[] states;
}
/*
create states
>> num - number of states
*/
void T2TStateBundle::MakeStates(int num)
{
CheckNTErrors(num > 0, "invalid number");
if(states != NULL)
delete[] states;
states = new T2TState[num];
for(int i = 0; i < num; i++)
states[i].last = NULL;
}
/* constructor */
T2TPredictor::T2TPredictor() T2TPredictor::T2TPredictor()
{ {
} }
......
...@@ -65,6 +65,19 @@ public: ...@@ -65,6 +65,19 @@ public:
/* layers on the decoder side */ /* layers on the decoder side */
XList layersDec; XList layersDec;
/* list of states */
T2TState * states;
public:
/* constructor */
T2TStateBundle();
/* de-constructor */
~T2TStateBundle();
/* create states */
void MakeStates(int num);
}; };
/* The predictor reads the current state and then predicts the next. /* The predictor reads the current state and then predicts the next.
......
...@@ -118,12 +118,31 @@ void T2TSearch::Generate(T2TStateBundle * beam) ...@@ -118,12 +118,31 @@ void T2TSearch::Generate(T2TStateBundle * beam)
score.Reshape(order, dimsBeam); score.Reshape(order, dimsBeam);
/* keep the most promissing candidates in the beam */
TopK(score, scoreTopK, index, 0, beamSize); TopK(score, scoreTopK, index, 0, beamSize);
score.Reshape(order, dims); score.Reshape(order, dims);
} }
/* /*
expand the search graph
>> beam - the beam that keeps a number of states
*/
void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
{
beam->MakeStates(beam->prediction.unitNum);
T2TState * states = beam->states;
XTensor &predict = beam->prediction;
XTensor index = *NewTensorBuf(predict.order - 1, predict.dimSize, X_FLOAT, 1.0F,
predict.devID, predict.mem);
index.SetAscendingOrder(-1);
DelTensorBuf(&index);
}
/*
save the output sequences in a tensor save the output sequences in a tensor
>> beam - the beam that keeps a number of states >> beam - the beam that keeps a number of states
*/ */
......
...@@ -60,6 +60,9 @@ public: ...@@ -60,6 +60,9 @@ public:
/* generate token indices via beam pruning */ /* generate token indices via beam pruning */
void Generate(T2TStateBundle * beam); void Generate(T2TStateBundle * beam);
/* expand the search graph */
void Expand(T2TStateBundle * prev, T2TStateBundle * beam);
/* save the output sequences in a tensor */ /* save the output sequences in a tensor */
void DumpOutput(T2TStateBundle * beam, XTensor * output); void DumpOutput(T2TStateBundle * beam, XTensor * output);
}; };
......
...@@ -969,8 +969,20 @@ set the cell to the ascending order along a given dimension ...@@ -969,8 +969,20 @@ set the cell to the ascending order along a given dimension
*/ */
void XTensor::SetAscendingOrder(int dim) void XTensor::SetAscendingOrder(int dim)
{ {
CheckNTErrors((dim >= 0 && dim < order), "Wrong dimension specified!"); CheckNTErrors(dim < order, "Wrong dimension specified!");
CheckNTErrors((dataType == X_INT), "TODO!"); CheckNTErrors(dataType == X_INT, "TODO!");
if(dim < 0){
int o = order;
int ds[MAX_TENSOR_DIM_NUM];
memcpy(ds, dimSize, sizeof(int) * order);
Reshape(unitNum);
SetAscendingOrder(0);
Reshape(o, ds);
return;
}
int dimRDI = order - dim - 1; int dimRDI = order - dim - 1;
if(devID >= 0){ if(devID >= 0){
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论