Commit b9871b8d by xiaotong

class of predictors

parent 8cb65ef5
......@@ -68,7 +68,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
LoadParamFloat(argc, argv, "dropout", &dropoutP, 0);
CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!");
-CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\"");
+CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsizetgt\"");
/* embedding model */
embedder.InitModel(argc, argv, devID, mem, false);
......
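Note on the change above: the decoder now points users at a target-side vocabulary option, "-vsizetgt", rather than the shared "-vsize". A minimal sketch of how the target vocabulary size might be read, assuming a LoadParamInt helper that parallels the LoadParamFloat call in the hunk (the helper name and the default value are assumptions, not shown in this diff):

/* hypothetical sketch - read the target-side vocabulary size.
   LoadParamInt is assumed by analogy with LoadParamFloat above. */
int vSizeTgt = 1;
LoadParamInt(argc, argv, "vsizetgt", &vSizeTgt, 1);
CheckNTErrors(vSizeTgt > 1, "set vocabulary size by \"-vsizetgt\"");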
......@@ -31,6 +31,9 @@
namespace transformer
{
/* a transformer model that keeps the parameters of the encoder,
the decoder and the output layer (softmax). It also creates
the network used in the transformer. */
class T2TModel
{
public:
......
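The comment above says T2TModel holds the encoder, the decoder and the output layer, and builds the network. Since the hunk truncates the class body, here is a minimal sketch of the members that description implies; the concrete types AttEncoder and T2TOutput are assumptions (only AttDecoder appears elsewhere in this commit):

class T2TModel
{
public:
    /* the encoder (type assumed) */
    AttEncoder * encoder;

    /* the decoder - AttDecoder is touched by this commit */
    AttDecoder * decoder;

    /* the output layer, i.e., the softmax over the vocabulary (type assumed) */
    T2TOutput * outputLayer;
};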
/* NiuTrans.Tensor - an open-source tensor library
- * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
+ * Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -19,5 +19,9 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
*/
#include "T2TSearcher.h"
#include "T2TPredictor.h"
namespace transformer
{
}
/* NiuTrans.Tensor - an open-source tensor library
- * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
+ * Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -20,7 +20,52 @@
* This is the first source file I created in 2019 - a new start!
*/
-#ifndef __T2TSEARCHER_H__
-#define __T2TSEARCHER_H__
+#ifndef __T2TPREDICTOR_H__
+#define __T2TPREDICTOR_H__
#include "T2TModel.h"
namespace transformer
{
/* a state in the decoder - it keeps all previously generated words and their
hidden states */
class T2TState
{
/* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */
XList * encoderLayers;
/* layers on the decoder side */
XList * decoderLayers;
/* */
};
/* The predictor reads the current state and then predicts the next.
This is exactly the procedure of MT inference -
we take the states of the previous words and then generate the next word.
Here, a state can be regarded as the representation of words (word
indices, hidden states, embeddings, etc.). */
class T2TPredictor
{
/* pointer to the transformer model */
T2TModel * model;
public:
/* constructor */
T2TPredictor();
/* destructor */
~T2TPredictor();
/* read a state */
void Read(T2TModel * model, T2TState * current);
/* predict the next state */
void Predict(T2TState * next);
};
}
#endif
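Taken together, T2TState and T2TPredictor suggest a simple inference loop: read the state of the words generated so far, predict the next state, and repeat. A hypothetical driver under the declarations above (the maximum length, the state array and the repeated Read call are assumptions; the header only declares Read and Predict):

/* hypothetical driver - not part of this commit */
const int maxLength = 32;            /* assumed maximum output length */

T2TModel model;
T2TPredictor predictor;
T2TState states[maxLength + 1];      /* one state per generated word */

predictor.Read(&model, &states[0]);  /* start from the initial state */

for(int i = 0; i < maxLength; i++){
    predictor.Predict(&states[i + 1]);      /* generate the next state */
    predictor.Read(&model, &states[i + 1]); /* and make it the current one */
}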
......@@ -25,7 +25,7 @@
#include "T2TModel.h"
#include "T2TUtility.h"
#include "T2TTrainer.h"
#include "T2TSearcher.h"
#include "T2TPredictor.h"
#include "../../tensor/XDevice.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/XGlobal.h"
......