Commit be552310 by xiaotong

add dropout to fnn and attention sub-layers in t2t

parent 8b0e06ab
......@@ -69,6 +69,7 @@ void T2TAttention::InitModel(int argc, char ** argv,
LoadParamInt(argc, argv, "d", &dv, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "attminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutatt", &dropoutP, 0);
InitTensor2D(&wk, d, dk, X_FLOAT, devID, mem);
InitTensor2D(&wq, d, dk, X_FLOAT, devID, mem);
......@@ -90,10 +91,11 @@ make the network
and H = vector size of each position
>> q - queries
>> v - values
>> maske - as it is
>> mask - as it is
>> isTraining - indicates whether the model is used for training
<< return - multi-attention result
*/
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
XTensor k2;
XTensor q2;
......@@ -126,6 +128,9 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask)
dot = Linear(dot, 1.0F/(float)sqrt((float)dk));
scalar = Softmax(dot, -1);
if(isTraining && dropoutP > 0)
scalar = Dropout(scalar, dropoutP);
att = BMMul(scalar, vheads);
......
......@@ -75,6 +75,9 @@ public:
/* indicates whether the model is used for training */
bool isTraining;
/* dropout probability */
DTYPE dropoutP;
public:
/* constructor */
......@@ -89,7 +92,7 @@ public:
int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask);
XTensor Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining);
};
}
......
......@@ -90,7 +90,7 @@ make the encoding network
>> input - the input tensor of the encoder
>> mask - the mask that indicate each position is valid
>> skipInputRes - indicates whether we skip the residual connection of the first layer
>> isTraining - indicates whether the model is for training
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool isTraining)
......@@ -113,7 +113,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
the encoder is used in language modeling. */
if(skipInputRes && i == 0){
/* self attention */
att = attentions[i].Make(x, x, x, mask);
att = attentions[i].Make(x, x, x, mask, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
......@@ -125,7 +125,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
else{
/* self attention */
att = attentions[i].Make(x, x, x, mask);
att = attentions[i].Make(x, x, x, mask, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
......@@ -139,7 +139,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
}
/* fnn */
fnn = fnns[i].Make(x);
fnn = fnns[i].Make(x, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
......
......@@ -60,6 +60,7 @@ void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_EMBEDDING_SIZE * 4);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0);
InitTensor2D(&w1, inSize, hSize, X_FLOAT, devID, mem);
InitTensor1D(&b1, hSize, X_FLOAT, devID, mem);
......@@ -83,12 +84,15 @@ y = max(0, x * w1 + b1) * w2 + b2
>> input - the input tensor
>> return - the output tensor
*/
XTensor T2TFNN::Make(XTensor &input)
XTensor T2TFNN::Make(XTensor &input, bool isTraining)
{
XTensor t1;
/* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MMul(input, w1) + b1);
if(isTraining && dropoutP > 0)
t1 = Dropout(t1, dropoutP);
/* result = t1 * w2 + b2 */
return MMul(t1, w2) + b2;
......
......@@ -59,6 +59,9 @@ public:
/* bias of transformation 2 */
XTensor b2;
/* dropout probability */
DTYPE dropoutP;
public:
......@@ -72,7 +75,7 @@ public:
void InitModel(int argc, char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor Make(XTensor &input);
XTensor Make(XTensor &input, bool isTraining);
};
......
......@@ -215,8 +215,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (step % 1 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f, sppl=%.3f\n",
lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount), exp(-prob/wc));
XPRINT8(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f\n",
lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
}
if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
......@@ -239,8 +239,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
epoch = MIN(epoch, nepoch);
XPRINT6(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, ppl=%.3f\n",
lr, elapsed, step, epoch, wordCountTotal, exp(loss / wordCount));
XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
elapsed, step, epoch);
......
......@@ -148,6 +148,7 @@ extern bool useCUDA;
#define XPRINT5(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5);FFLUSH(FILEH);}}
#define XPRINT6(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6);FFLUSH(FILEH);}}
#define XPRINT7(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7);FFLUSH(FILEH);}}
#define XPRINT8(VERBOSE,FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8) {if(VERBOSE<=verboseLevel) {fprintf(FILEH,STR,ARG,ARG2,ARG3,ARG4,ARG5,ARG6,ARG7,ARG8);FFLUSH(FILEH);}}
#define B2I(V) V==0?false:true
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论