Commit f0937eab by xiaotong

length penalty

parent 2bb37f15
......@@ -27,15 +27,20 @@ namespace transformer
GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha
where n = length of the sequence
>> length - length of the sequence (for each entry)
>> lp - length penalty of the sequence (for each entry)
>> alpha - the parameter controls the length preference
<< return - length penalty of the sequence (for each entry)
*/
/*
GNMT-like length penalty: lp = ((5 + n)/(5 + 1))^\alpha
where n = length of the sequence
>> length - length of the sequence (for each entry)
>> alpha - the parameter that controls the length preference
<< return - length penalty of the sequence (for each entry)
*/
XTensor T2TLengthPenalizer::GNMT(const XTensor & length, float alpha)
{
    XTensor base;
    XTensor lp;

    /* base = (5 + n)/(5 + 1), computed element-wise over the length tensor */
    base = (length + 5)/(1 + 5);

    /* lp = base^alpha: alpha > 0 rewards longer hypotheses, alpha = 0 disables the penalty */
    lp = Power(base, alpha);

    return lp;
}
......@@ -40,9 +40,9 @@ public:
/* GNMT-like length penalty: pl = ((5 + n)/(5 + 1))^\alpha
where n = length of the sequence */
static
void GNMT(const XTensor & length, XTensor & lp, float alpha);
XTensor GNMT(const XTensor & length, float alpha);
};
}
#endif
\ No newline at end of file
#endif
......@@ -111,7 +111,7 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
_ScaleAndShift(&lenPrev, &len, 1.0F, 1.0F);
/* the GNMT-like length penalty */
T2TLengthPenalizer::GNMT(len, lp, alpha);
lp = T2TLengthPenalizer::GNMT(len, alpha);
/* score = log-prob/lp */
_Div(&score, &lp, &score);
......@@ -188,17 +188,43 @@ expand the search graph
*/
/*
expand the search graph: create one T2TState per new prediction and link it
to its parent state in the previous bundle
>> prev - the previous state bundle (parents of the newly expanded states)
>> beam - the state bundle to be expanded
*/
void T2TSearch::Expand(T2TStateBundle * prev, T2TStateBundle * beam)
{
    CheckNTErrors(beam->prediction.unitNum == beam->preID.unitNum, "A problem occurs in the beam!");

    beam->MakeStates(beam->prediction.unitNum);

    T2TState * states = beam->states;
    XTensor & predict = beam->prediction;
    XTensor & idRef = beam->preID;
    XTensor & modelScoreRef = beam->modelScore;
    XTensor & probRef = beam->prob;
    XTensor & probPathRef = beam->probPath;

    /* CPU-side working copies of the beam tensors */
    XTensor id;
    XTensor modelScore;
    XTensor prob;
    XTensor probPath;

    /* NOTE(review): this index tensor is created, filled, and released without
       ever being read — presumably a leftover from an earlier revision; kept
       here so behavior (buffer usage) is unchanged. TODO confirm and remove. */
    XTensor index = *NewTensorBuf(predict.order - 1, predict.dimSize, X_FLOAT, 1.0F,
                                  predict.devID, predict.mem);
    index.SetAscendingOrder(-1);
    DelTensorBuf(&index);

    InitTensorOnCPU(&id, &idRef);
    InitTensorOnCPU(&modelScore, &modelScoreRef);
    InitTensorOnCPU(&prob, &probRef);
    InitTensorOnCPU(&probPath, &probPathRef);

    /* we copy the data to CPU because the frequent access to GPU is slow
       and we can speed-up the process by doing the job on CPU. */
    CopyValues(idRef, id);
    CopyValues(modelScoreRef, modelScore);
    /* FIX: arguments were reversed (CopyValues(prob, probRef)), which copied the
       uninitialized CPU tensor over the source instead of reading from it;
       source-then-destination order now matches the three sibling calls */
    CopyValues(probRef, prob);
    CopyValues(probPathRef, probPath);

    for(int i = 0; i < id.unitNum; i++){
        T2TState & state = states[i];

        /* pointer to the previous state */
        state.last = prev->states + id.GetInt(i);

        /* scores carried over from the search so far */
        state.modelScore = modelScore.Get(i);
        state.prob = prob.Get(i);
        state.probPath = probPath.Get(i);
    }
}
/*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论