Commit b9871b8d by xiaotong

class of predictors

parent 8cb65ef5
...@@ -68,7 +68,7 @@ void AttDecoder::InitModel(int argc, char ** argv, ...@@ -68,7 +68,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
LoadParamFloat(argc, argv, "dropout", &dropoutP, 0); LoadParamFloat(argc, argv, "dropout", &dropoutP, 0);
CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!"); CheckNTErrors(nlayer >= 1, "We have one encoding layer at least!");
CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsize\""); CheckNTErrors(vSize > 1, "set vocabulary size by \"-vsizetgt\"");
/* embedding model */ /* embedding model */
embedder.InitModel(argc, argv, devID, mem, false); embedder.InitModel(argc, argv, devID, mem, false);
......
...@@ -31,6 +31,9 @@ ...@@ -31,6 +31,9 @@
namespace transformer namespace transformer
{ {
/* a transformer model that keeps parameters of the encoder,
the decoder and the output layer (softmax). Also, it creates
the network used in transformer. */
class T2TModel class T2TModel
{ {
public: public:
......
/* NiuTrans.Tensor - an open-source tensor library /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University. * Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
* All rights reserved. * All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -19,5 +19,9 @@ ...@@ -19,5 +19,9 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
*/ */
#include "T2TSearcher.h" #include "T2TPredictor.h"
namespace transformer
{
}
/* NiuTrans.Tensor - an open-source tensor library /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University. * Copyright (C) 2019, Natural Language Processing Lab, Northeastern University.
* All rights reserved. * All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -20,7 +20,52 @@ ...@@ -20,7 +20,52 @@
* This is the first source file I create in 2019 - new start! * This is the first source file I create in 2019 - new start!
*/ */
#ifndef __T2TSEARCHER_H__ #ifndef __T2TPREDICTOR_H__
#define __T2TSEARCHER_H__ #define __T2TPREDICTOR_H__
#include "T2TModel.h"
namespace transformer
{
/* State in the decoder - it keeps all previously-generated words and their
   hidden states.
   The members are public so that code operating on a state (e.g., a
   predictor that receives a T2TState pointer) can actually read and
   update it; with the default (private) access of a class and no
   accessors, the fields would be unreachable. */
class T2TState
{
public:
    /* layers on the encoder side. We actually use the encoder output instead
       of all hidden layers. */
    XList * encoderLayers;

    /* layers on the decoder side */
    XList * decoderLayers;

    /* NOTE(review): the original left an empty placeholder comment here;
       further state fields are presumably still to be added - confirm. */
};
/* The predictor reads the current state and then predicts the next one.
It follows the same procedure as MT inference -
we get the state of the previous words and then generate the next word.
Here, a state can be regarded as the representation of words (word
indices, hidden states, embeddings, etc.).
NOTE(review): only declarations are visible here; the behavior described
is the intended contract and should be confirmed against the definitions. */
class T2TPredictor
{
/* pointer to the transformer model (private: declared before any
access specifier) */
T2TModel * model;
public:
/* constructor */
T2TPredictor();
/* destructor */
~T2TPredictor();
/* read a state
>> model - the transformer model to predict with; note that this
parameter has the same name as the member "model" and
shadows it inside the method
>> current - the current decoder state */
void Read(T2TModel * model, T2TState * current);
/* predict the next state
>> next - the state to be predicted (presumably filled in by this
call - confirm against the definition) */
void Predict(T2TState * next);
};
}
#endif #endif
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include "T2TModel.h" #include "T2TModel.h"
#include "T2TUtility.h" #include "T2TUtility.h"
#include "T2TTrainer.h" #include "T2TTrainer.h"
#include "T2TSearcher.h" #include "T2TPredictor.h"
#include "../../tensor/XDevice.h" #include "../../tensor/XDevice.h"
#include "../../tensor/XUtility.h" #include "../../tensor/XUtility.h"
#include "../../tensor/XGlobal.h" #include "../../tensor/XGlobal.h"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论