Commit a7223650 by xiaotong

coding

parent db8e0968
......@@ -25,4 +25,63 @@ using namespace nts;
namespace transformer
{
/*
search for the most promising states
>> model - the transformer model
>> input - input of the model
>> output - output that represents the sequences as rows
*/
void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * output)
{
XTensor maskNULL;
XTensor encoding;
/* make the encoding network */
encoding = model->MakeEncoder(*input, maskNULL, false);
encoding.SetName(ENCODING_NAME);
T2TPredictor predictor;
T2TStateBundle state1, state2;
T2TStateBundle * cur = &state1;
T2TStateBundle * next = &state2;
/* initialize the predictor */
predictor.Init(model, &encoding, cur);
/* generate the sequence from left-to-right */
for(int i = 0 ; i < maxLength; i++){
/* read the current state */
predictor.Read(model, cur);
/* predict the next state */
predictor.Predict(next);
/* pruning */
Prune(next);
T2TStateBundle * backup = cur;
cur = next;
next = backup;
}
}
/*
beam pruning
>> beam - the beam that keeps a number of states
*/
void T2TSearch::Prune(T2TStateBundle * beam)
{
}
/*
save the output sequences in a tensor
>> beam -
*/
void T2TSearch::DumpOutput(T2TStateBundle * beam, XTensor * output)
{
}
}
\ No newline at end of file
......@@ -23,6 +23,7 @@
#define __T2TSEARCH_H__
#include "T2TModel.h"
#include "T2TPredictor.h"
namespace transformer
{
......@@ -34,12 +35,24 @@ namespace transformer
class T2TSearch
{
public:
/* max length of the generated sequence */
int maxLength;
public:
/* constructor */
T2TSearch() {};
/* de-constructor */
~T2TSearch() {};
/* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * output);
/* beam pruning */
void Prune(T2TStateBundle * beam);
/* save the output sequences in a tensor */
void DumpOutput(T2TStateBundle * beam, XTensor * output);
};
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论