Commit 3bc771f4 by xiaotong

updates

parent b199a6ee
...@@ -24,17 +24,18 @@ ...@@ -24,17 +24,18 @@
#include "Encoder.h" #include "Encoder.h"
#include "Decoder.h" #include "Decoder.h"
#include "Utility.h"
#include "submodel/FNN.h" #include "submodel/FNN.h"
#include "submodel/Output.h" #include "submodel/Output.h"
#include "Utility.h"
#include "submodel/Attention.h" #include "submodel/Attention.h"
#include "../../train/XModel.h"
namespace nmt namespace nmt
{ {
/* a nmt model that keeps parameters of the encoder, /* an nmt model that keeps parameters of the encoder,
the decoder and the output layer (softmax). */ the decoder and the output layer (softmax). */
class Model class Model : public XModel
{ {
public: public:
/* device id */ /* device id */
...@@ -114,6 +115,13 @@ public: ...@@ -114,6 +115,13 @@ public:
/* read the parameters */ /* read the parameters */
void Read(FILE* file); void Read(FILE* file);
public:
/* clone the model (overloaded method of XModel) */
XModel * Clone(int devID);
/* run the neural network (overloaded method of XModel) */
bool RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses);
}; };
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "../../../tensor/XList.h" #include "../../../tensor/XList.h"
#include "../../../tensor/XTensor.h" #include "../../../tensor/XTensor.h"
#include "../../../tensor/XGlobal.h" #include "../../../tensor/XGlobal.h"
#include "../../../train/XBaseTemplate.h"
#define MAX_WORD_NUM 120 #define MAX_WORD_NUM 120
...@@ -56,7 +57,8 @@ struct TrainExample { ...@@ -56,7 +57,8 @@ struct TrainExample {
}; };
/* A `TrainDataSet` is associated with a file which contains training data. */ /* A `TrainDataSet` is associated with a file which contains training data. */
struct TrainDataSet { struct TrainDataSet : public DataDistributeBase
{
public: public:
/* the data buffer */ /* the data buffer */
TrainBufferType buffer; TrainBufferType buffer;
...@@ -98,21 +100,6 @@ public: ...@@ -98,21 +100,6 @@ public:
XTensor* batchDec, XTensor* paddingDec, XTensor* label, XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID); size_t minSentBatch, size_t batchSize, int devID);
/* load the samples into the buffer (a list) */
bool LoadBatchToBuf(XList * buf);
/* load the samples into tensors from the buffer */
static
bool LoadBatch(XList * buf,
XTensor* batchEnc, XTensor* paddingEnc,
XTensor* batchDec, XTensor* paddingDec, XTensor* label,
size_t minSentBatch, size_t batchSize, int devID,
int &wc, int &sc);
/* release the samples in a buffer */
static
void ClearSamples(XList * buf);
/* initialization function */ /* initialization function */
void Init(const char* dataFile, int bucketSize, bool training); void Init(const char* dataFile, int bucketSize, bool training);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论