Commit f0937eab by xiaotong

length penalty

parent 2bb37f15
...@@ -27,15 +27,20 @@ namespace transformer ...@@ -27,15 +27,20 @@ namespace transformer
GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha
where n = length of the sequence where n = length of the sequence
>> length - length of the sequence (for each entry) >> length - length of the sequence (for each entry)
>> lp - length penaltyof the sequence (for each entry) >> alpha - the parameter controls the length preference
<< return - length penaltyof the sequence (for each entry)
*/ */
void T2TLengthPenalizer::GNMT(const XTensor & length, XTensor & lp, float alpha) XTensor T2TLengthPenalizer::GNMT(const XTensor & length, float alpha)
{ {
XTensor base; XTensor base;
XTensor lp;
base = ScaleAndShift(ScaleAndShift(length, 0, 5.0F), 1.0F/(5 + 1)); //base = ScaleAndShift(ScaleAndShift(length, 0, 5.0F), 1.0F/(5 + 1));
base = (length + 5)/(1 + 5);
lp = Power(base, alpha); lp = Power(base, alpha);
return lp;
} }
} }
...@@ -40,9 +40,9 @@ public: ...@@ -40,9 +40,9 @@ public:
/* GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha /* GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha
where n = length of the sequence */ where n = length of the sequence */
static static
void GNMT(const XTensor & length, XTensor & lp, float alpha); XTensor GNMT(const XTensor & length, float alpha);
}; };
} }
#endif #endif
\ No newline at end of file
...@@ -111,7 +111,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam) ...@@ -111,7 +111,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
_ScaleAndShift(&lenPrev, &len, 1.0F, 1.0F); _ScaleAndShift(&lenPrev, &len, 1.0F, 1.0F);
/* the GNMT-like length penalty */ /* the GNMT-like length penalty */
T2TLengthPenalizer::GNMT(len, lp, alpha); lp = T2TLengthPenalizer::GNMT(len, alpha);
/* score = log-prob/lp */ /* score = log-prob/lp */
_Div(&score, &lp, &score); _Div(&score, &lp, &score);
...@@ -188,17 +188,43 @@ expand the search graph ...@@ -188,17 +188,43 @@ expand the search graph
*/ */
void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam) void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
{ {
CheckNTErrors(beam->prediction.unitNum == beam->preID.unitNum, "A problem occurs in the beam!");
beam->MakeStates(beam->prediction.unitNum); beam->MakeStates(beam->prediction.unitNum);
T2TState * states = beam->states; T2TState * states = beam->states;
XTensor &predict = beam->prediction; XTensor & idRef = beam->preID;
XTensor & modelScoreRef = beam->modelScore;
XTensor & probRef = beam->prob;
XTensor & probPathRef = beam->probPath;
XTensor id;
XTensor modelScore;
XTensor prob;
XTensor probPath;
XTensor index = *NewTensorBuf(predict.order - 1, predict.dimSize, X_FLOAT, 1.0F, InitTensorOnCPU(&id, &idRef);
predict.devID, predict.mem); InitTensorOnCPU(&modelScore, &modelScoreRef);
InitTensorOnCPU(&prob, &probRef);
index.SetAscendingOrder(-1); InitTensorOnCPU(&probPath, &probPathRef);
DelTensorBuf(&index); /* we copy the data to CPU because the frequent access to GPU is slow
and we can speed-up the process by doing the job on CPU. */
CopyValues(idRef, id);
CopyValues(modelScoreRef, modelScore);
CopyValues(prob, probRef);
CopyValues(probPathRef, probPath);
for(int i = 0; i < id.unitNum; i++){
T2TState & state = states[i];
/* pointer to the previous state */
state.last = prev->states + id.GetInt(i);
/* scores */
state.modelScore = modelScore.Get(i);
state.prob = prob.Get(i);
state.probPath = probPath.Get(i);
}
} }
/* /*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论