Commit c8cb9219 by xiaotong

improve the code: remove the skipInputRes option so that every encoder layer uses the same residual-connection path

parent 20e9678d
@@ -89,11 +89,10 @@ void AttEncoder::InitModel(int argc, char ** argv,
 make the encoding network
 >> input - the input tensor of the encoder
 >> mask - the mask that indicates which positions are valid
->> skipInputRes - indicates whether we skip the residual connection of the first layer
 >> isTraining - indicates whether the model is used for training
 << return - the output tensor of the encoder
 */
-XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining)
+XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
 {
     XTensor x;
@@ -109,34 +108,18 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining)
         XTensor fnn;
         XTensor res;
 
-        /* we skip the residual connection for the first layer if
-           the encoder is used in language modeling. */
-        if(skipInputRes && i == 0){
-            /* self attention */
-            att = attentions[i].Make(x, x, x, mask, isTraining);
+        /* self attention */
+        att = attentions[i].Make(x, x, x, mask, isTraining);
 
-            /* dropout */
-            if(isTraining && dropoutP > 0)
-                att = Dropout(att, dropoutP);
+        /* dropout */
+        if(isTraining && dropoutP > 0)
+            att = Dropout(att, dropoutP);
 
-            /* layer normalization */
-            x = attLayerNorms[i].Make(att);
-        }
-        else{
-            /* self attention */
-            att = attentions[i].Make(x, x, x, mask, isTraining);
-
-            /* dropout */
-            if(isTraining && dropoutP > 0)
-                att = Dropout(att, dropoutP);
-
-            /* residual connection */
-            res = Sum(att, x);
+        /* residual connection */
+        res = Sum(att, x);
 
-            /* layer normalization */
-            x = attLayerNorms[i].Make(res);
-        }
+        /* layer normalization */
+        x = attLayerNorms[i].Make(res);
 
         /* fnn */
         fnn = fnns[i].Make(x, isTraining);
@@ -150,9 +133,6 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining)
 
         /* layer normalization */
         x = fnnLayerNorms[i].Make(res);
-
-        if(isTraining && dropoutP > 0)
-            x = Dropout(x, dropoutP);
     }
 
     return x;
......
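With skipInputRes removed, every encoder layer now runs the same sequence: self-attention, dropout, residual connection, layer normalization, then the FNN sub-layer with its own residual and layer norm. This is exactly the old else branch; the language-modeling shortcut that skipped the first layer's residual connection is gone. Below is a minimal self-contained sketch of that post-layer-norm ordering, using plain std::vector<float> stand-ins instead of the library's XTensor; the identity Attention/FNN placeholders are invented for illustration, and dropout is omitted as in inference. It is not the repository's actual code.

#include <cstdio>
#include <cmath>
#include <vector>

typedef std::vector<float> Vec;

/* placeholder sub-layers: identity stand-ins for the real attention/FNN */
Vec Attention(const Vec &x) { return x; }
Vec FNN(const Vec &x) { return x; }

/* element-wise residual sum */
Vec Sum(const Vec &a, const Vec &b) {
    Vec r(a.size());
    for (int i = 0; i < (int)a.size(); ++i)
        r[i] = a[i] + b[i];
    return r;
}

/* layer normalization over the vector */
Vec LayerNorm(const Vec &v) {
    float mean = 0.0f, var = 0.0f;
    for (int i = 0; i < (int)v.size(); ++i) mean += v[i];
    mean /= (float)v.size();
    for (int i = 0; i < (int)v.size(); ++i) var += (v[i] - mean) * (v[i] - mean);
    var /= (float)v.size();
    Vec r(v.size());
    for (int i = 0; i < (int)v.size(); ++i)
        r[i] = (v[i] - mean) / std::sqrt(var + 1e-6f);
    return r;
}

int main() {
    Vec x;
    x.push_back(1); x.push_back(2); x.push_back(3); x.push_back(4);
    for (int i = 0; i < 6; ++i) {           /* nLayer layers, all identical now */
        Vec att = Attention(x);             /* self attention */
        x = LayerNorm(Sum(att, x));         /* residual connection + layer norm */
        Vec fnn = FNN(x);                   /* feed-forward sub-layer */
        x = LayerNorm(Sum(fnn, x));         /* residual connection + layer norm */
    }
    std::printf("x[0] = %f\n", x[0]);
    return 0;
}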
@@ -40,7 +40,7 @@ class T2TEncoder
 {
 public:
     virtual
-    XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining) = 0;
+    XTensor Make(XTensor &input, XTensor &mask, bool isTraining) = 0;
 };
 
 /*
@@ -49,7 +49,7 @@ the encoder based on RNN
 */
 class RNNEncoder : T2TEncoder
 {
 public:
-    XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
+    XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
 };
@@ -118,7 +118,7 @@ public:
                    int myDevID = -1, XMem * myMem = NULL);
 
     /* make the encoding network */
-    XTensor Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
+    XTensor Make(XTensor &input, XTensor &mask, bool isTraining);
 };
......
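All encoders now share the same three-argument pure-virtual Make. A compilable sketch of what plugging in a new encoder would look like, with a dummy XTensor struct and a hypothetical CNNEncoder, both invented here for illustration:

#include <cstdio>

struct XTensor {};   /* stand-in for the real NiuTrans XTensor */

class T2TEncoder {
public:
    virtual XTensor Make(XTensor &input, XTensor &mask, bool isTraining) = 0;
    virtual ~T2TEncoder() {}
};

/* hypothetical encoder: overriding the shared 3-argument Make is the whole contract */
class CNNEncoder : public T2TEncoder {
public:
    XTensor Make(XTensor &input, XTensor &mask, bool isTraining) {
        (void)mask; (void)isTraining;
        return input;            /* identity placeholder */
    }
};

int main() {
    CNNEncoder enc;
    XTensor in, mask;
    T2TEncoder &base = enc;
    base.Make(in, mask, false);  /* dispatched through the virtual interface */
    std::printf("ok\n");
    return 0;
}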
@@ -84,13 +84,12 @@ void T2TModel::InitModel(int argc, char ** argv)
 make the encoding network
 >> input - input tensor
 >> mask - the mask for positions that are or are not involved in the computation
->> skipInputRes - indicates whether we skip the residual connection of the first layer
 >> isTraining - indicates whether we are training the model
 << return - encoding result
 */
-XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining)
+XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool isTraining)
 {
-    return encoder.Make(input, mask, skipInputRes, isTraining);
+    return encoder.Make(input, mask, isTraining);
 }
 
 /*
@@ -142,9 +141,9 @@ void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding, bool isTraining)
     _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
 
-    //_Sum(&mask, padding3, &mask);
+    _Sum(&mask, padding3, &mask);
 
-    encoding = MakeEncoding(input, mask, false, isTraining);
+    encoding = MakeEncoding(input, mask, isTraining);
 
     outputLayer.Make(encoding, output);
 
     delete[] dims;
......
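For reference, the _ScaleAndShiftMe(padding3, 1e9F, -1e9F) call in the hunk above maps each padding value p to p * 1e9 - 1e9, so real tokens (p = 1) contribute 0 to the attention mask while padded slots (p = 0) contribute -1e9, which drives their softmax weight to zero once summed into the mask. A tiny stand-alone sketch of that arithmetic, using plain arrays rather than the XTensor call:

#include <cstdio>

int main() {
    /* padding marks real tokens with 1 and padded slots with 0 */
    float padding[4] = {1.0F, 1.0F, 1.0F, 0.0F};
    float mask[4];

    for (int i = 0; i < 4; ++i)
        mask[i] = padding[i] * 1e9F + (-1e9F);   /* scale by 1e9, shift by -1e9 */

    /* prints 0 0 0 -1e+09: added to the attention scores before softmax,
       the -1e9 entry effectively removes the padded position */
    for (int i = 0; i < 4; ++i)
        std::printf("%g ", mask[i]);
    std::printf("\n");
    return 0;
}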
@@ -69,7 +69,7 @@ public:
     void InitModel(int argc, char ** argv);
 
     /* make the encoding network */
-    XTensor MakeEncoding(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining);
+    XTensor MakeEncoding(XTensor &input, XTensor &mask, bool isTraining);
 
     /* make the entire network (with the output softmax layer) */
     void Make(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
......
@@ -181,6 +181,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
     XTensor gold;
 
     while(LoadBatch(file, true, &batch, &padding, &gold, NULL, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){
 
+        CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch");
+
         /* output probabilities */
         XTensor output;
@@ -258,6 +260,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
     int wc = 0;
     int wordCount = 0;
     int wordCountTotal = 0;
+    int sentCount = 0;
     float loss = 0;
 
     /* data files */
@@ -289,7 +292,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
     ClearBuf();
 
-    while(LoadBatch(file, true, &batch, &padding, &gold, seqs, 1, vSize, 1, 1, isLenSorted, wc, devID, mem)){
+    while(LoadBatch(file, true, &batch, &padding, &gold, seqs, 1, vSize, 1, 1, false, wc, devID, mem)){
 
         CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch");
@@ -336,6 +339,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
         loss += -prob;
         wordCount += wc;
         wordCountTotal += wc;
+        sentCount += 1;
     }
 
     fclose(file);
......
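The Test loop above accumulates loss += -prob and wordCount += wc per batch, and the new sentCount makes per-sentence reporting possible. Assuming loss is the summed negative log-probability, perplexity would come out as exp(loss / wordCount). A sketch of that bookkeeping with made-up numbers; the formula is the standard perplexity definition, not code taken from the trainer:

#include <cmath>
#include <cstdio>

int main() {
    float loss = 0;
    int wordCount = 0, sentCount = 0;

    /* pretend we scored three sentences; prob is the log-probability
       of a whole sentence, wc its word count (made-up values) */
    float prob[3] = {-12.3f, -20.1f, -7.8f};
    int   wc[3]   = {5, 9, 3};

    for (int i = 0; i < 3; ++i) {
        loss += -prob[i];        /* as in the Test loop */
        wordCount += wc[i];
        sentCount += 1;
    }

    std::printf("ppl = %.3f over %d sentences\n",
                std::exp(loss / wordCount), sentCount);
    return 0;
}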