Commit efe32603 by xiaotong

bug fixes

parent a0a38702
@@ -125,7 +125,7 @@ XTensor T2TAttention::Make(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bo
if(isMasked)
dot = dot + mask;
- dot = Linear(dot, 1.0F/(float)sqrt((float)dk));
+ dot = Linear(dot, 1.0F/(float)sqrt((float)dk/nhead));
scalar = Softmax(dot, -1);
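Note: the attention logits are now rescaled by 1/sqrt(dk/nhead) instead of 1/sqrt(dk). Assuming dk is the total key dimension that is split evenly across the nhead heads, dk/nhead is the width of a single head, so the new factor matches standard scaled dot-product attention:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_head)) V, with d_head = dk / nhead.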
@@ -135,7 +135,7 @@ XTensor T2TEmbedder::Make(XTensor &input)
}
/* then we make word embeddings */
- wordEmbedding = Linear(MMul(input, w), (float)sqrt((float)d));
+ wordEmbedding = Linear(MMul(input, w), (float)sqrt((float)eSize));
/* we sum over the two embeddings */
return wordEmbedding + posEmbedding;
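Note: word embeddings are now scaled by sqrt(eSize), the width of the embedding matrix w, rather than sqrt(d). If d and eSize are allowed to differ, the scale should follow the actual embedding width:

wordEmbedding = sqrt(eSize) * MMul(input, w), then output = wordEmbedding + posEmbedding.

This is the usual sqrt(d_model) embedding scaling from the Transformer; when d equals eSize the two versions coincide.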
@@ -58,7 +58,7 @@ void T2TFNN::InitModel(int argc, char ** argv, int myDevID, XMem * myMem)
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE);
- LoadParamInt(argc, argv, "fnnh", &hSize, DEFAULT_EMBEDDING_SIZE * 4);
+ LoadParamInt(argc, argv, "fnnh", &hSize, outSize * 4);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0);
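Note: the default feed-forward hidden size now tracks the configured width (outSize * 4) instead of the compile-time constant, so overriding "d" on the command line also resizes the hidden layer. Assuming the usual position-wise Transformer FFN, the shapes would be

FFN(x) = max(0, x W1 + b1) W2 + b2, with W1 of size d x 4d and W2 of size 4d x d.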
@@ -105,8 +105,8 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamInt(argc, argv, "bufsize", &bufSize, 50000);
LoadParamBool(argc, argv, "adam", &useAdam, false);
LoadParamFloat(argc, argv, "adambeta1", &adamBeta1, 0.9F);
- LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.999F);
- LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-8F);
+ LoadParamFloat(argc, argv, "adambeta2", &adamBeta2, 0.98F);
+ LoadParamFloat(argc, argv, "adamdelta", &adamDelta, 1e-9F);
LoadParamBool(argc, argv, "shuffled", &isShuffled, false);
LoadParamFloat(argc, argv, "labelsmoothing", &labelSmoothingP, 0);
LoadParamInt(argc, argv, "nstepcheckpoint", &nStepCheckpoint, -1);
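Note: the new defaults adambeta2 = 0.98 and adamdelta = 1e-9 are the Adam settings used for Transformer training in "Attention Is All You Need" (beta1 = 0.9, beta2 = 0.98, epsilon = 1e-9). As a rough reading of the beta2 change, an exponential moving average with decay beta averages over roughly 1/(1 - beta) steps, so the second-moment estimate now looks back about 50 steps instead of about 1000 and adapts faster to the changing gradient scale early in training.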
@@ -143,6 +143,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
float lr = 0;
int nStepCheck = 0;
int nCheckpoint = 0;
+ int nSkipped = 0;
char * trainFN = new char[(int)strlen(fn) + 10];
strcpy(trainFN, fn);
@@ -184,40 +185,46 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* label smoothed gold standard (if needed) */
XTensor goldSmoothed;
- while(LoadBatch(file, true, &batch, &padding, &gold, NULL, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)){
+ while (LoadBatch(file, true, &batch, &padding, &gold, NULL, 1, vSize, sBatchSize, wBatchSize, isLenSorted, wc, devID, mem)) {
CheckNTErrors(batch.order == 3, "wrong tensor order of the sequence batch");
/* output probabilities */
XTensor output;
/* make the network */
model->Make(batch, output, padding, true);
/* back-propagation for obtaining gradients */
- if(labelSmoothingP > 0)
+ if (labelSmoothingP > 0)
LabelSmooth(&gold, &goldSmoothed, labelSmoothingP);
/* make paddings for the output */
- if(output.GetDim(0) > 1)
+ if (output.GetDim(0) > 1)
PadOutput(&output, &gold, &padding);
- XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
- net.Backward(output, g, CROSSENTROPY);
- /* learning rate */
- lr = lrate * (1.0F / (float)sqrt((float)d)) * (float)MIN(pow((float)step + 1, -0.5F - lrbias), ((float)step + 1) * pow((float)nwarmup, -1.5F - lrbias));
- /* update the parameters */
- Update(model, lr);
/* get probabilities */
float prob = GetProb(&output, &gold, NULL);
- loss += -prob;
- wordCount += wc;
- wordCountTotal += wc;
+ DTYPE lossLocal = -prob / wc;
+ bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F);
+ XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
+ if (doUpdate) {
+ net.Backward(output, g, CROSSENTROPY);
+ /* learning rate */
+ lr = lrate * (1.0F / (float)sqrt((float)d)) * (float)MIN(pow((float)step + 1, -0.5F - lrbias), ((float)step + 1) * pow((float)nwarmup, -1.5F - lrbias));
+ /* update the parameters */
+ Update(model, lr);
+ loss += -prob;
+ wordCount += wc;
+ wordCountTotal += wc;
+ }
+ else
+ nSkipped++;
if(++step >= nstep){
isEnd = true;
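Note: the loop now computes the batch loss first and skips the backward pass and parameter update whenever lossLocal = -prob / wc is NaN, Inf, or at least 1e3; skipped batches are counted in nSkipped and reported when training finishes. The learning-rate line itself is unchanged apart from moving inside the guard; it is the Transformer warmup schedule

lr = lrate * d^(-0.5) * min((step + 1)^(-0.5 - lrbias), (step + 1) * nwarmup^(-1.5 - lrbias)),

which for lrbias = 0 reduces to the standard lr = lrate * d^(-0.5) * min(step^(-0.5), step * nwarmup^(-1.5)).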
@@ -226,8 +233,11 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (step % 1 == 0) {
double elapsed = GetClockSec() - startT;
- XPRINT8(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f\n",
+ XPRINT8(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
+ if (!doUpdate)
+ XPRINT(0, stderr, " (no update)");
+ XPRINT(0, stderr, "\n");
}
if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
@@ -252,8 +262,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
- XPRINT3(0, stderr, "[INFO] training finished (took %.1fs, step=%d and epoch=%d)\n",
- elapsed, step, epoch);
+ XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
+ elapsed, step, nSkipped, epoch);
delete[] trainFN;
}
@@ -732,12 +742,12 @@ void T2TTrainer::Update(T2TModel * model, const float lr)
DTYPE e = lr * (DTYPE)sqrt(1 - adamBeta2T) / (1 - adamBeta1T);
DTYPE d = adamDelta * (DTYPE)sqrt(1 - adamBeta2T);
- /* m = beat_1 * m + (1-beta_1) * grad */
+ /* m = beta_1 * m + (1-beta_1) * grad */
XTensor * m = (XTensor*)moments.Get(i);
_ScaleAndShiftMe(m, adamBeta1, 0);
_Sum(m, paraGrad, m, (1.0F - adamBeta1));
- /* v = beat_2 * v + (1-beta_2) * grad * grad*/
+ /* v = beta_2 * v + (1-beta_2) * grad * grad*/
XTensor * v = (XTensor*)moments2nd.Get(i);
_Multiply(paraGrad, paraGrad, v, adamBeta2/(1.0F - adamBeta2));
_ScaleAndShiftMe(v, (1.0F - adamBeta2), 0);
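Note: the two comment fixes ("beat" to "beta") describe the standard Adam moment updates m = beta1 * m + (1 - beta1) * grad and v = beta2 * v + (1 - beta2) * grad * grad, which is what the _Sum and _Multiply/_ScaleAndShiftMe pairs compute. Assuming the parameter step applied further down is p = p - e * m / (sqrt(v) + d), the precomputed

e = lr * sqrt(1 - beta2^t) / (1 - beta1^t) and d = adamDelta * sqrt(1 - beta2^t)

make it algebraically identical to the bias-corrected form p = p - lr * m_hat / (sqrt(v_hat) + adamDelta), with m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t).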
@@ -846,17 +856,15 @@ void T2TTrainer::LabelSmooth(XTensor * gold, XTensor * smoothed, DTYPE lsP)
int n = gold->GetDim(-1);
DTYPE q = 1.0F - p;
- DTYPE gift = p / (n - 1);
+ DTYPE gift = p / n;
InitTensor(smoothed, gold);
_CopyValues(gold, smoothed);
if(p == 0)
return;
- _ScaleAndShiftMe(smoothed, gift/q, -gift/q);
- _Sum(smoothed, gold, smoothed);
- _ScaleAndShiftMe(smoothed, q);
+ _ScaleAndShiftMe(smoothed, q, gift);
}
}
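Note: with _ScaleAndShiftMe(smoothed, q, gift) the smoothed target becomes smoothed = (1 - p) * gold + p / n, so a one-hot row gets 1 - p + p/n on the gold class and p/n everywhere else, and each row still sums to 1. The earlier three-step sequence (scale-and-shift, sum, scale) appears to leave non-gold entries at -p/(n - 1), i.e. negative target probabilities, which is presumably the bug being fixed here.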
@@ -34,7 +34,7 @@ int TransformerMain(int argc, const char ** argv)
if(argc == 0)
return 1;
- fprintf(stderr, "%e\n", log(1e-45F));
+ fprintf(stderr, "%e\n", exp(DTYPE_MIN));
char ** args = new char*[argc];
for(int i = 0; i < argc; i++){
@@ -55,6 +55,9 @@ namespace nts {
#define DTYPE_MIN (DTYPE)-3.40E+38
#endif
+ #define LOGPROB_MIN (DTYPE)-1E+15
+ #define GRAD_MAX (DTYPE)1E+5
#if WIN32
#define DELIMITER '\\'
#else
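Note: LOGPROB_MIN and GRAD_MAX are new constants. LOGPROB_MIN = -1e15 gives log probabilities a large but finite floor: with DTYPE_MIN (about -3.4e38) as the fallback, adding just two such values already overflows 32-bit float to -inf, whereas -1e15 leaves ample headroom for summing log probabilities or computing perplexities. GRAD_MAX looks like a gradient-clipping bound, but it is not referenced in the hunks shown in this commit.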
@@ -122,10 +122,11 @@ void _LogSoftmax(const XTensor * x, XTensor * y, int leadDim)
for (int i = 0; i < n; i++) {
DTYPE r = (DTYPE)log(exp(ip[i * m + j] - mp[j]) / sp[j]);
if (IsNAN(r))
- r = DTYPE_MIN;
+ r = LOGPROB_MIN;
if (IsINF(r))
- r = DTYPE_MIN;
- op[i * m + j] = r;
+ r = LOGPROB_MIN;
+ op[i * m + j] = MAX(r, LOGPROB_MIN);
}
}
}
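Note: on NaN/Inf the log-softmax output now falls back to LOGPROB_MIN instead of DTYPE_MIN (which behaves like -inf in later arithmetic), and the stored value is additionally clamped with MAX(r, LOGPROB_MIN). A minimal standalone sketch of the same guard (hypothetical helper, not the library API), assuming LOGPROB_MIN = -1e15:

#include <cmath>
#include <algorithm>

/* log-softmax of one element, clamped to a finite floor */
static float ClampedLogSoftmax(float x, float rowMax, float rowSum)
{
    const float logFloor = -1e15f;               /* LOGPROB_MIN */
    float r = std::log(std::exp(x - rowMax) / rowSum);
    if (std::isnan(r) || std::isinf(r))          /* e.g. exp underflow gives log(0) = -inf */
        r = logFloor;
    return std::max(r, logFloor);                /* enforce the floor */
}

The CUDA kernels below (KernelLogSoftmaxComputeByRow and KernelLogSoftmaxComputeByCol) receive the same change.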
@@ -79,10 +79,11 @@ void KernelLogSoftmaxComputeByRow(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y
int key = i * colNum + j;
DTYPE r = log(exp(x[key] - inputMax[threadIdx.x]) / inputSum[threadIdx.x]);
if (isnan(r))
- r = DTYPE_MIN;
+ r = LOGPROB_MIN;
if (isinf(r))
- r = DTYPE_MIN;
- y[key] = r;
+ r = LOGPROB_MIN;
+ y[key] = MAX(r, LOGPROB_MIN);
}
}
@@ -124,10 +125,11 @@ void KernelLogSoftmaxComputeByCol(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y
int key = i * colNum + j;
DTYPE r = log(exp(x[key] - inputMax[threadIdx.y]) / inputSum[threadIdx.y]);
if (isnan(r))
- r = DTYPE_MIN;
+ r = LOGPROB_MIN;
if (isinf(r))
- r = DTYPE_MIN;
- y[key] = r;
+ r = LOGPROB_MIN;
+ y[key] = MAX(r, LOGPROB_MIN);
}
}
@@ -228,21 +230,29 @@ void KernelLogSoftmaxBackwardDEDS(DTYPE * dedy, DTYPE * dedx, DTYPE * gold, DTYP
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
+ DTYPE r = 0;
/* dE/ds_j = exp(y_j) */
if (lossName == CROSSENTROPY)
- dedx[i] = -gold[i] + exp(y[i]);
+ r = -gold[i] + exp(y[i]);
/* dE/ds_j = exp(y_j) */
else if (lossName == SQUAREDERROR)
- dedx[i] = -gold[i] + exp(y[i]);
+ r = -gold[i] + exp(y[i]);
else if (lossName == ONEHOTERROR) {
if (gold[i] == 1.0F)
- dedx[i] = -gold[i] + exp(y[i]);
+ r = -gold[i] + exp(y[i]);
else
- dedx[i] = 0;
+ r = 0;
}
else {
- dedx[i] = dedy[i];
+ r = dedy[i];
}
+ if (isnan(r))
+ r = 0;
+ if (isinf(r))
+ r = 0;
+ dedx[i] = r;
}
}
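Note: every loss branch now writes into the local r, and NaN/Inf values are zeroed before storing dedx[i], so one bad element no longer poisons the whole gradient; this pairs with the doUpdate / nSkipped guard in T2TTrainer::Train. For cross entropy on top of log-softmax output the branch value is the usual gradient with respect to the pre-softmax input:

dE/dx_j = exp(y_j) - gold_j = p_j - gold_j, since y_j = log p_j.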