Commit 4c5776f0 by huchi

add positional embedding

parent ac6ed3a1
......@@ -116,7 +116,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
{
XTensor x;
x = embedder.Make(inputDec, inputDec.GetDim(1));
x = embedder.Make(inputDec, inputDec.GetDim(1), true);
/* dropout */
if(isTraining && dropoutP > 0)
......
......@@ -100,7 +100,7 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length, int padIdx)
/* padding zeros */
int padStart = padIdx * eSize;
for (int i = padStart; i < padStart + eSize; ++i)
for (int i = padStart; i < padStart + eSize; i++)
data[i] = 0.F;
posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
......@@ -111,55 +111,66 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length, int padIdx)
/*
make the network
*/
XTensor T2TEmbedder::Make(XTensor &input, int prevLen)
XTensor T2TEmbedder::Make(XTensor &input, int prevLen, int nstep, bool isDec)
{
///* assert padding index is 1 */
//CheckNTErrors(input.order > 1, "Wrong input tensor size!");
//CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!");
//CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
//CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
/* assert padding index is 1 */
//
//XTensor wordEmbedding, position, posEmbedding;
//InitTensor(&position, &input);
CheckNTErrors(input.order > 1, "Wrong input tensor size!");
CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!");
CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
//int* posData = new int[input.unitNum];
XTensor wordEmbedding, position, posEmbedding;
InitTensorV2(&position, &input);
//XTensor inputCPU;
//InitTensorOnCPU(&inputCPU, &input);
//_CopyValues(&input, &inputCPU);
int* posData = new int[input.unitNum];
XTensor inputCPU;
InitTensorOnCPU(&inputCPU, &input);
_CopyValues(&input, &inputCPU);
//for (int i = 0; i < inputCPU.GetDim(0); i++) {
// int startNoPad = 2 + prevLen - 1;
// int* p = ((int*)inputCPU.data) + i * inputCPU.GetDim(1);
// for (int j = 0; j < inputCPU.GetDim(1); j++) {
// if (p[j] == 1) {
// posData[i * inputCPU.GetDim(1) + j] = 1;
// }
// else {
// posData[i * inputCPU.GetDim(1) + j] = startNoPad++;
// }
// }
//}
//position.SetData(posData, position.unitNum);
//delete[] posData;
if (!isDec)
{
for (int i = 0; i < inputCPU.GetDim(0); i++) {
int startNoPad = 2 + prevLen;
int* p = ((int*)inputCPU.data) + i * inputCPU.GetDim(1);
for (int j = 0; j < inputCPU.GetDim(1); j++) {
if (p[j] == 1) {
posData[i * inputCPU.GetDim(1) + j] = 1;
}
else {
posData[i * inputCPU.GetDim(1) + j] = startNoPad++;
}
}
}
position.SetData(posData, position.unitNum);
}
else
{
for (int i = 0; i < position.GetDim(0); i++) {
for (int j = 0; j < position.GetDim(1); j++) {
position.Set2DInt(nstep + 2, i, j);
}
}
}
///* we make positional embeddings first */
//if(true){
// posEmbedding = Gather(posEmbeddingBase, position);
//}
delete[] posData;
/* we make positional embeddings first */
if (true) {
posEmbedding = Gather(posEmbeddingBase, position);
}
/* then we make word embeddings */
XTensor wordEmbedding;
wordEmbedding = Gather(w, input);
wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize));
/* we sum over the two embeddings */
return wordEmbedding;
return Sum(wordEmbedding, posEmbedding);
}
}
......@@ -77,7 +77,7 @@ public:
void MakePosEmbedding(int eSize, int d, int length, int padIdx);
/* make the network */
XTensor Make(XTensor &input, int prevLen=0);
XTensor Make(XTensor &input, int prevLen=0, int nstep = -1, bool isDec = false);
};
}
......
......@@ -166,7 +166,7 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
inputDec = GetLastPrediction(s);
inputDec.SetDevice(inputEnc->devID);
}
inputDec.Dump(stderr, "inputDec");
/* prediction probabilities */
XTensor& output = next->prob;
......
......@@ -464,8 +464,7 @@ void T2TSearch::Collect(T2TStateBundle* beam)
for (int i = 0; i < beam->stateNum; i++) {
T2TState& state = states[i];
CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!");
CheckNTErrors(state.pid >= 0 && state.pid < batchSize, "Invalid sample id!");
/* check if this is the first end symbol. It is false
if there have been end symbols in previously generated words. */
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论