Commit 4c5776f0 by huchi

add positional embedding

parent ac6ed3a1
@@ -116,7 +116,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
 {
     XTensor x;
 
-    x = embedder.Make(inputDec, inputDec.GetDim(1));
+    x = embedder.Make(inputDec, inputDec.GetDim(1), true);
 
     /* dropout */
     if(isTraining && dropoutP > 0)
...
@@ -100,7 +100,7 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length, int padIdx)
     /* padding zeros */
     int padStart = padIdx * eSize;
-    for (int i = padStart; i < padStart + eSize; ++i)
+    for (int i = padStart; i < padStart + eSize; i++)
         data[i] = 0.F;
 
     posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);
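
The loop touched here zeroes the embedding row of the padding index so that padded slots contribute nothing when positional and word embeddings are summed. Below is a minimal standalone sketch of how such a table is commonly built, using the standard sinusoidal formula; the repository's MakePosEmbedding may differ in details, and the function name below is illustrative only:

    #include <cmath>
    #include <vector>

    /* build a [length x eSize] sinusoidal position table and zero the
       row that belongs to the padding index, as in the hunk above */
    std::vector<float> MakeSinusoidalTable(int length, int eSize, int padIdx)
    {
        std::vector<float> data(length * eSize);
        for (int pos = 0; pos < length; pos++) {
            for (int k = 0; k < eSize; k++) {
                float v = pos / std::pow(10000.0F, 2.0F * (k / 2) / eSize);
                data[pos * eSize + k] = (k % 2 == 0) ? std::sin(v) : std::cos(v);
            }
        }

        /* padding zeros */
        int padStart = padIdx * eSize;
        for (int i = padStart; i < padStart + eSize; i++)
            data[i] = 0.F;

        return data;
    }
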
@@ -111,55 +111,66 @@ void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length, int padIdx)
 /*
 make the network
 */
-XTensor T2TEmbedder::Make(XTensor &input, int prevLen)
+XTensor T2TEmbedder::Make(XTensor &input, int prevLen, int nstep, bool isDec)
 {
-    ///* assert padding index is 1 */
-    //CheckNTErrors(input.order > 1, "Wrong input tensor size!");
-    //CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!");
-    //CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
-    //CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
-    //
-    //XTensor wordEmbedding, position, posEmbedding;
-    //InitTensor(&position, &input);
-    //int* posData = new int[input.unitNum];
-    //XTensor inputCPU;
-    //InitTensorOnCPU(&inputCPU, &input);
-    //_CopyValues(&input, &inputCPU);
-
-    //for (int i = 0; i < inputCPU.GetDim(0); i++) {
-    //    int startNoPad = 2 + prevLen - 1;
-    //    int* p = ((int*)inputCPU.data) + i * inputCPU.GetDim(1);
-    //    for (int j = 0; j < inputCPU.GetDim(1); j++) {
-    //        if (p[j] == 1) {
-    //            posData[i * inputCPU.GetDim(1) + j] = 1;
-    //        }
-    //        else {
-    //            posData[i * inputCPU.GetDim(1) + j] = startNoPad++;
-    //        }
-    //    }
-    //}
-
-    //position.SetData(posData, position.unitNum);
-    //delete[] posData;
-
-    ///* we make positional embeddings first */
-    //if(true){
-    //    posEmbedding = Gather(posEmbeddingBase, position);
-    //}
+    /* assert padding index is 1 */
+    CheckNTErrors(input.order > 1, "Wrong input tensor size!");
+    CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!");
+    CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
+    CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
+
+    XTensor wordEmbedding, position, posEmbedding;
+    InitTensorV2(&position, &input);
+    int* posData = new int[input.unitNum];
+    XTensor inputCPU;
+    InitTensorOnCPU(&inputCPU, &input);
+    _CopyValues(&input, &inputCPU);
+
+    if (!isDec)
+    {
+        for (int i = 0; i < inputCPU.GetDim(0); i++) {
+            int startNoPad = 2 + prevLen;
+            int* p = ((int*)inputCPU.data) + i * inputCPU.GetDim(1);
+            for (int j = 0; j < inputCPU.GetDim(1); j++) {
+                if (p[j] == 1) {
+                    posData[i * inputCPU.GetDim(1) + j] = 1;
+                }
+                else {
+                    posData[i * inputCPU.GetDim(1) + j] = startNoPad++;
+                }
+            }
+        }
+        position.SetData(posData, position.unitNum);
+    }
+    else
+    {
+        for (int i = 0; i < position.GetDim(0); i++) {
+            for (int j = 0; j < position.GetDim(1); j++) {
+                position.Set2DInt(nstep + 2, i, j);
+            }
+        }
+    }
+    delete[] posData;
+
+    /* we make positional embeddings first */
+    if (true) {
+        posEmbedding = Gather(posEmbeddingBase, position);
+    }
 
     /* then we make word embeddings */
-    XTensor wordEmbedding;
     wordEmbedding = Gather(w, input);
     wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize));
 
     /* we sum over the two embeddings */
-    return wordEmbedding;
+    return Sum(wordEmbedding, posEmbedding);
 }
 
 }
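
In the rewritten Make, position ids are computed on a CPU copy of the input: padding tokens (id 1) keep position 1, real tokens count upward from 2 + prevLen, and at decoding time (isDec) every slot of the current step gets nstep + 2. The ids are then used to Gather rows from posEmbeddingBase, and the result is summed with the scaled word embeddings. A self-contained sketch of just the indexing scheme, on plain arrays rather than XTensor (illustrative only):

    #include <vector>

    /* mirror of the position-id logic above: rows x cols token ids in,
       position ids out */
    std::vector<int> MakePositions(const std::vector<int>& tokens,
                                   int rows, int cols,
                                   int prevLen, int nstep, bool isDec)
    {
        std::vector<int> pos(rows * cols);
        if (!isDec) {
            for (int i = 0; i < rows; i++) {
                int startNoPad = 2 + prevLen;
                for (int j = 0; j < cols; j++) {
                    int t = tokens[i * cols + j];
                    /* padding id 1 -> position 1, real tokens count up */
                    pos[i * cols + j] = (t == 1) ? 1 : startNoPad++;
                }
            }
        }
        else {
            /* one decoding step: every slot carries the same position */
            for (int i = 0; i < rows * cols; i++)
                pos[i] = nstep + 2;
        }
        return pos;
    }
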
@@ -77,7 +77,7 @@ public:
     void MakePosEmbedding(int eSize, int d, int length, int padIdx);
 
     /* make the network */
-    XTensor Make(XTensor &input, int prevLen=0);
+    XTensor Make(XTensor &input, int prevLen=0, int nstep = -1, bool isDec = false);
 };
 
 }
...
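
With the extended signature, the intended call pattern appears to be: the encoder/training path passes the running prefix length as prevLen, while step-by-step decoding passes the current step and isDec = true. A hypothetical call-site sketch (the variable step is a placeholder, not taken from this diff):

    /* encoder / training: positions derived from the tokens themselves */
    XTensor embEnc = embedder.Make(inputEnc);

    /* incremental decoding: every slot of this step gets position step + 2 */
    XTensor embDec = embedder.Make(inputDec, 0, step, true);
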
@@ -166,7 +166,7 @@ void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inp
         inputDec = GetLastPrediction(s);
         inputDec.SetDevice(inputEnc->devID);
     }
 
+    inputDec.Dump(stderr, "inputDec");
 
     /* prediction probabilities */
     XTensor& output = next->prob;
...
@@ -464,8 +464,7 @@ void T2TSearch::Collect(T2TStateBundle* beam)
     for (int i = 0; i < beam->stateNum; i++) {
         T2TState& state = states[i];
 
-        CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
-                      "Invalid sample id!");
+        CheckNTErrors(state.pid >= 0 && state.pid < batchSize, "Invalid sample id!");
 
         /* check if this is the first end symbol. It is false
            if there have been end symbols in previously generated words. */
...