Commit 3bc771f4 by xiaotong

updates

parent b199a6ee
......@@ -24,17 +24,18 @@
#include "Encoder.h"
#include "Decoder.h"
#include "Utility.h"
#include "submodel/FNN.h"
#include "submodel/Output.h"
#include "Utility.h"
#include "submodel/Attention.h"
#include "../../train/XModel.h"
namespace nmt
{
/* a nmt model that keeps parameters of the encoder,
/* an nmt model that keeps parameters of the encoder,
the decoder and the output layer (softmax). */
class Model
class Model : public XModel
{
public:
/* device id */
......@@ -85,26 +86,26 @@ public:
/* make the encoding network */
XTensor MakeDecoder(XTensor& inputEnc, XTensor& inputDec, XTensor* mask,
XTensor& MaskEncDec, bool isTraining);
XTensor& MaskEncDec, bool isTraining);
/* make the network for language modeling (with the output softmax layer) */
void MakeLM(XTensor& input, XTensor& output, XTensor& padding, bool isTraining);
/* make the network for machine translation (with the output softmax layer) */
void MakeMT(XTensor& inputEnc, XTensor& inputDec, XTensor& output,
XTensor& paddingEnc, XTensor& paddingDec, bool isTraining);
XTensor& paddingEnc, XTensor& paddingDec, bool isTraining);
/* make the mask for training MT models */
void MakeMTMask(XTensor& inputEnc, XTensor& inputDec,
XTensor& paddingEnc, XTensor& paddingDec,
XTensor& maskEnc, XTensor& maskDec, XTensor& maskEncDec);
XTensor& paddingEnc, XTensor& paddingDec,
XTensor& maskEnc, XTensor& maskDec, XTensor& maskEncDec);
/* make the mask of the encoder */
void MakeMTMaskEnc(XTensor& paddingEnc, XTensor& maskEnc);
/* make the mask of the decoder */
void MakeMTMaskDec(XTensor& paddingEnc, XTensor& paddingDec,
XTensor& maskDec, XTensor& maskEncDec);
XTensor& maskDec, XTensor& maskEncDec);
/* get parameter matrices */
void GetParams(TensorList& list);
......@@ -114,6 +115,13 @@ public:
/* read the parameters */
void Read(FILE* file);
public:
/* clone the model (overloaded method of XModel) */
XModel * Clone(int devID);
/* run the neural network (overloaded method of XModel) */
bool RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses);
};
}
......
......@@ -29,6 +29,7 @@
#include "../../../tensor/XList.h"
#include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h"
#include "../../../train/XBaseTemplate.h"
#define MAX_WORD_NUM 120
......@@ -56,7 +57,8 @@ struct TrainExample {
};
/* A `TrainDataSet` is associated with a file which contains training data. */
struct TrainDataSet {
struct TrainDataSet : public DataDistributeBase
{
public:
/* the data buffer */
TrainBufferType buffer;
......@@ -98,21 +100,6 @@ public:
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID);
/* load the samples into the buffer (a list) */
bool LoadBatchToBuf(XList * buf);
/* load the samples into tensors from the buffer */
static
bool LoadBatch(XList * buf,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID,
int &wc, int &sc);
/* release the samples in a buffer */
static
void ClearSamples(XList * buf);
/* initialization function */
void Init(const char* dataFile, int bucketSize, bool training);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论