Commit 63e2cfa7 by xiaotong

Improve the implementation of softmax

parent 7f483801
...@@ -98,6 +98,8 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining) ...@@ -98,6 +98,8 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
x = embedder.Make(input); x = embedder.Make(input);
x.Dump(tmpFILE, "embedding: ");
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP); x = Dropout(x, dropoutP);
......
...@@ -406,7 +406,7 @@ void T2TTrainer::MakeCheckpoint(T2TModel * model, const char * validFN, const ch ...@@ -406,7 +406,7 @@ void T2TTrainer::MakeCheckpoint(T2TModel * model, const char * validFN, const ch
sprintf(fn, "%s.%s.%03d", modelFN, label, id); sprintf(fn, "%s.%s.%03d", modelFN, label, id);
sprintf(fn2, "%s.%s.%03d.output", modelFN, label, id); sprintf(fn2, "%s.%s.%03d.output", modelFN, label, id);
//model->Dump(fn); model->Dump(fn);
if(validFN != NULL){ if(validFN != NULL){
T2TTrainer trainer; T2TTrainer trainer;
trainer.Init(argNum, argArray); trainer.Init(argNum, argArray);
......
...@@ -34,7 +34,7 @@ int TransformerMain(int argc, const char ** argv) ...@@ -34,7 +34,7 @@ int TransformerMain(int argc, const char ** argv)
if(argc == 0) if(argc == 0)
return 1; return 1;
fprintf(stderr, "%e\n", exp(-1e9F)); fprintf(stderr, "%e\n", log(1e-8F));
char ** args = new char*[argc]; char ** args = new char*[argc];
for(int i = 0; i < argc; i++){ for(int i = 0; i < argc; i++){
......
...@@ -55,7 +55,7 @@ namespace nts { ...@@ -55,7 +55,7 @@ namespace nts {
#define DTYPE_MIN (DTYPE)-3.40E+38 #define DTYPE_MIN (DTYPE)-3.40E+38
#endif #endif
#define LOGPROB_MIN (DTYPE)-1E+15 #define LOGPROB_MIN (DTYPE)-2E+1
#define GRAD_MAX (DTYPE)1E+5 #define GRAD_MAX (DTYPE)1E+5
#if WIN32 #if WIN32
......
...@@ -78,6 +78,7 @@ void KernelLogSoftmaxComputeByRow(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y ...@@ -78,6 +78,7 @@ void KernelLogSoftmaxComputeByRow(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y
if (i < rowNum && j < colNum) { if (i < rowNum && j < colNum) {
int key = i * colNum + j; int key = i * colNum + j;
DTYPE r = log(exp(x[key] - inputMax[threadIdx.x]) / inputSum[threadIdx.x]); DTYPE r = log(exp(x[key] - inputMax[threadIdx.x]) / inputSum[threadIdx.x]);
if (isnan(r)) if (isnan(r))
r = LOGPROB_MIN; r = LOGPROB_MIN;
if (isinf(r)) if (isinf(r))
...@@ -124,6 +125,12 @@ void KernelLogSoftmaxComputeByCol(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y ...@@ -124,6 +125,12 @@ void KernelLogSoftmaxComputeByCol(DTYPE * x, DTYPE * max, DTYPE * sum, DTYPE * y
if (i < rowNum && j < colNum) { if (i < rowNum && j < colNum) {
int key = i * colNum + j; int key = i * colNum + j;
DTYPE r = log(exp(x[key] - inputMax[threadIdx.y]) / inputSum[threadIdx.y]); DTYPE r = log(exp(x[key] - inputMax[threadIdx.y]) / inputSum[threadIdx.y]);
/*if (r < LOGPROB_MIN)
{
printf("min %e %e, %e %e, %e %e\n", r, x[key] - inputMax[threadIdx.y], x[key], inputMax[threadIdx.y], exp(x[key] - inputMax[threadIdx.y]), inputSum[threadIdx.y]);
}*/
if (isnan(r)) if (isnan(r))
r = LOGPROB_MIN; r = LOGPROB_MIN;
if (isinf(r)) if (isinf(r))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论