Commit 591d6121 by 姜雨帆

implement MulAndShift and bug fix

parent 0b43acf6
...@@ -99,6 +99,8 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient) ...@@ -99,6 +99,8 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
GradReduceSumSquared(node, isEfficient); GradReduceSumSquared(node, isEfficient);
else if(operID == REDUCE_REDUCEVARIANCE) else if(operID == REDUCE_REDUCEVARIANCE)
GradReduceVariance(node, isEfficient); GradReduceVariance(node, isEfficient);
else if (operID == MATH_MULANDSHIFT)
GradMulAndShift(node, isEfficient);
else{ else{
ShowNTErrors("TODO!"); ShowNTErrors("TODO!");
} }
...@@ -1487,4 +1489,126 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient) ...@@ -1487,4 +1489,126 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient)
node->visitMark = NODE_FINISHED; node->visitMark = NODE_FINISHED;
} }
/*
gradient for operation
for c = matmul(x, w) + b
we have
dE/dx = dE/dc * w^T
dE/dw = x^T * dE/dc
dE/db = dE/dc * x.reduce(0,...,n-1,n+1,...)
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 3, "wrong input tensor number")
XTensor * x = income.tails[0];
XTensor * w = income.tails[1];
XTensor * b = income.tails[2];
int n = income.GetParamInt(0);
MATRIX_TRANS_TYPE transW = income.GetParamTrans(1);
MATRIX_TRANS_TYPE transX = income.GetParamTrans(2);
if (!isEfficient || w->isGrad)
XNoder::MakeGrad(w);
if (!isEfficient || x->isGrad)
XNoder::MakeGrad(x);
if (!isEfficient || b->isGrad)
XNoder::MakeGrad(b);
int order = node->order;
int dimSize[MAX_TENSOR_DIM_NUM];
memcpy(dimSize, node->dimSize, sizeof(int) * node->order);
/* compute dE/db */
if (n == order - 1) {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = node->unitNum / dimSize[order - 1];
reshapedSize[1] = dimSize[order - 1];
/* we reshape dE/dc to a matrix whose column number is equal to the
size of b. Then we can reduce the matrix into a row vector. */
node->grad->Reshape(2, reshapedSize);
XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);
_ReduceSum(node->grad, bGradTMP, 0);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
node->grad->Reshape(order, dimSize);
}
else {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = 1;
reshapedSize[1] = dimSize[n];
reshapedSize[2] = 1;
for (int i = 0; i < order; i++) {
if (i < n)
reshapedSize[0] *= dimSize[i];
}
reshapedSize[2] = node->unitNum / (reshapedSize[0] * reshapedSize[1]);
/* we reshape dE/dc to a 3D tensor of size (x, y, z) where y = |b|.
Then reduce along with z and x to obtain dE/db. */
node->grad->Reshape(3, reshapedSize);
XTensor * interGrad = NewTensorBuf(2, reshapedSize, b->dataType, b->denseRatio, b->devID, b->mem);
_ReduceSum(node->grad, interGrad, 2);
XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);
_ReduceSum(interGrad, bGradTMP, 0);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
node->grad->Reshape(order, dimSize);
DelTensorBuf(interGrad);
}
/* compute dE/dx, dE/dw */
XTensor * c = node;
XTensor * dedc = node->grad;
XTensor * dedw = w->grad;
XTensor * dedx = x->grad;
if (x->order == 2 && w->order == 2)
GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
else if (transX == X_NOTRANS && x->order > 2 && w->order == 2){
int orderBackupX = x->order;
int orderBackupC = c->order;
int dimsBackupX[MAX_TENSOR_DIM_NUM];
int dimsBackupC[MAX_TENSOR_DIM_NUM];
memcpy(dimsBackupX, x->dimSize, sizeof(int) * x->order);
memcpy(dimsBackupC, c->dimSize, sizeof(int) * c->order);
x->Reshape(x->unitNum / x->GetDim(-1), x->GetDim(-1));
c->Reshape(c->unitNum / c->GetDim(-1), c->GetDim(-1));
if (!isEfficient || x->isGrad)
dedx->Reshape(dedx->unitNum / dedx->GetDim(-1), dedx->GetDim(-1));
dedc->Reshape(dedc->unitNum / dedc->GetDim(-1), dedc->GetDim(-1));
GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
x->Reshape(orderBackupX, dimsBackupX);
c->Reshape(orderBackupC, dimsBackupC);
if (!isEfficient || x->isGrad)
dedx->Reshape(orderBackupX, dimsBackupX);
dedc->Reshape(orderBackupC, dimsBackupC);
}
node->visitMark = NODE_FINISHED;
}
} }
...@@ -168,6 +168,10 @@ private: ...@@ -168,6 +168,10 @@ private:
/* gradient for reduceVariance */ /* gradient for reduceVariance */
static static
void GradReduceVariance(XTensor * node, bool isEfficient); void GradReduceVariance(XTensor * node, bool isEfficient);
/* gradient for operation */
static
void GradMulAndShift(XTensor * node, bool isEfficient);
}; };
} }
......
...@@ -61,6 +61,7 @@ public: ...@@ -61,6 +61,7 @@ public:
XTensor wa; XTensor wa;
XTensor wbig; XTensor wbig;
/* size of transformed Q and K */ /* size of transformed Q and K */
int dk; int dk;
......
...@@ -80,7 +80,6 @@ void AttDecoder::InitModel(int argc, char ** argv, ...@@ -80,7 +80,6 @@ void AttDecoder::InitModel(int argc, char ** argv,
attentionsEnde = new T2TAttention[nlayer]; attentionsEnde = new T2TAttention[nlayer];
attEndeLayerNorms = new T2TLN[nlayer]; attEndeLayerNorms = new T2TLN[nlayer];
/* initialize the stacked layers */ /* initialize the stacked layers */
for (int i = 0; i < nlayer; i++) { for (int i = 0; i < nlayer; i++) {
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem); attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
...@@ -89,9 +88,7 @@ void AttDecoder::InitModel(int argc, char ** argv, ...@@ -89,9 +88,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem); fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID, myMem); attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID, myMem);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem); attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
} }
} }
/* /*
......
...@@ -100,4 +100,4 @@ public: ...@@ -100,4 +100,4 @@ public:
} }
#endif #endif
\ No newline at end of file
...@@ -103,8 +103,6 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo ...@@ -103,8 +103,6 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
x = embedder.Make(input); x = embedder.Make(input);
//x.Dump(tmpFILE, "embedding: ");
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP); x = Dropout(x, dropoutP);
...@@ -160,4 +158,3 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining) ...@@ -160,4 +158,3 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
} }
} }
...@@ -89,13 +89,15 @@ XTensor T2TFNN::Make(XTensor &input, bool isTraining) ...@@ -89,13 +89,15 @@ XTensor T2TFNN::Make(XTensor &input, bool isTraining)
XTensor t1; XTensor t1;
/* t1 = max(0, x * w1 + b1) */ /* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MMul(input, w1) + b1); //t1 = Rectify(MMul(input, w1) + b1);
t1 = Rectify(MulAndShift(input, w1, b1));
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
t1 = Dropout(t1, dropoutP); t1 = Dropout(t1, dropoutP);
/* result = t1 * w2 + b2 */ /* result = t1 * w2 + b2 */
return MMul(t1, w2) + b2; //return MMul(t1, w2) + b2;
return MulAndShift(t1, w2, b2);
} }
......
...@@ -219,7 +219,7 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -219,7 +219,7 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
dims[i + 1] = inputDec.GetDim(i); dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead; dims[0] = nhead;
dims[inputDec.order + 1] = len; dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem); InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec.mem);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9. /* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in this matrix can be used to prevent the attention to current or following words in
...@@ -236,10 +236,10 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -236,10 +236,10 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem); XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1)); _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_Unsqueeze(&paddingDec, maskEncDecTMPDec, paddingEnc.order, paddingEnc.GetDim(-1)); //_Unsqueeze(&paddingDec, maskEncDecTMPDec, paddingEnc.order, paddingEnc.GetDim(-1));
_Multiply(maskEncDecTMPDec, maskEncDecTMPEnc, maskEncDecTMPDec); //_Multiply(maskEncDecTMPDec, maskEncDecTMPEnc, maskEncDecTMPDec);
_ScaleAndShiftMe(maskEncDecTMPDec, 1e9F, -1e9F); _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
_Unsqueeze(maskEncDecTMPDec, &maskEncDec, 0, dims[0]); _Unsqueeze(maskEncDecTMPEnc, &maskEncDec, 0, dims[0]);
DelTensorBuf(maskEncDecTMPDec); DelTensorBuf(maskEncDecTMPDec);
DelTensorBuf(maskEncDecTMPEnc); DelTensorBuf(maskEncDecTMPEnc);
...@@ -274,10 +274,9 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe ...@@ -274,10 +274,9 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
_Sum(&maskEnc, padding3, &maskEnc); _Sum(&maskEnc, padding3, &maskEnc);
encoding = MakeEncoder(inputEnc, maskEnc, isTraining); encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
//encoding.Dump(stderr, "encoding",10);
decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining); decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);
//decoding.Dump(stderr, "decoding", 10);
outputLayer->Make(decoding, output); outputLayer->Make(decoding, output);
delete[] dims; delete[] dims;
......
...@@ -122,6 +122,7 @@ void T2TTrainer::Init(int argc, char ** argv) ...@@ -122,6 +122,7 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false); LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false);
LoadParamBool(argc, argv, "debug", &isDebugged, false); LoadParamBool(argc, argv, "debug", &isDebugged, false);
LoadParamBool(argc, argv, "randbatch", &isRandomBatch, false); LoadParamBool(argc, argv, "randbatch", &isRandomBatch, false);
LoadParamInt(argc, argv, "bucketsize", &bucketSize, 0);
buf = new int[bufSize]; buf = new int[bufSize];
buf2 = new int[bufSize]; buf2 = new int[bufSize];
...@@ -147,8 +148,11 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -147,8 +148,11 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
{ {
int step = 0; int step = 0;
int wc = 0; int wc = 0;
int ws =0;
int wordCount = 0; int wordCount = 0;
int totalW;
int wordCountTotal = 0; int wordCountTotal = 0;
int wordCountBatch = 0;
bool isEnd = false; bool isEnd = false;
float loss = 0; float loss = 0;
float lr = 0; float lr = 0;
...@@ -177,9 +181,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -177,9 +181,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
PrepareModel(model); PrepareModel(model);
double startT = GetClockSec(); double startT = GetClockSec();
FILE * fileen = fopen("enc.txt", "w");
FILE * filede = fopen("dec.txt", "w");
for(epoch = 1; epoch <= nepoch; epoch++){ for(epoch = 1; epoch <= nepoch; epoch++){
#ifndef WIN32 #ifndef WIN32
...@@ -197,6 +198,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -197,6 +198,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XTensor batchEnc; XTensor batchEnc;
XTensor batchDec; XTensor batchDec;
/* labels */
XTensor label;
/* padding */ /* padding */
XTensor paddingEnc; XTensor paddingEnc;
XTensor paddingDec; XTensor paddingDec;
...@@ -207,17 +211,13 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -207,17 +211,13 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* label smoothed gold standard (if needed) */ /* label smoothed gold standard (if needed) */
XTensor goldSmoothed; XTensor goldSmoothed;
while (LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, while (LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
NULL, vSize, vSizeTgt, NULL, vSize, vSizeTgt,
sBatchSize, wBatchSize, isLenSorted, wc, devID, mem, true)) sBatchSize, wBatchSize, isLenSorted, ws, wc, devID, mem, true))
{ {
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
//batchEnc.Dump(stderr, "enc",1);
//batchDec.Dump(stderr, "dec",1);
//paddingDec.Dump(stderr, "paddec");
/* output probabilities */ /* output probabilities */
XTensor output; XTensor output;
...@@ -231,35 +231,41 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -231,35 +231,41 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
} }
/* back-propagation for obtaining gradients */ /* back-propagation for obtaining gradients */
if (labelSmoothingP > 0) //if (labelSmoothingP > 0)
LabelSmooth(&gold, &goldSmoothed, labelSmoothingP); // LabelSmooth(&gold, &goldSmoothed, labelSmoothingP);
XTensor labelOnehot;
labelOnehot = IndexToOnehot(label, vSizeTgt, labelSmoothingP);
/* make paddings for the output */ /* make paddings for the output */
if (output.GetDim(0) > 0) if (output.GetDim(0) > 0)
PadOutput(&output, &gold, &paddingDec); PadOutput(&output, &labelOnehot, &paddingDec);
/* get probabilities */ /* get probabilities */
float prob = GetProb(&output, &gold, NULL); float prob = GetProb(&output, &labelOnehot, NULL);
//printf("%f\n", prob);
//float prob = 0;
DTYPE lossLocal = -prob / wc; DTYPE lossLocal = -prob / wc;
bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F); bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F);
XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold; //XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
//doUpdate = false;
if (doUpdate) { if (doUpdate) {
/* recale the output for normalized loss */ /* recale the output for normalized loss */
RescaleOutput(&output, &g, &paddingDec); RescaleOutput(&output, &labelOnehot, &paddingDec);
/* back-propagation */ /* back-propagation */
net.Backward(output, g, paddingDec, CROSSENTROPY); net.Backward(output, labelOnehot, paddingDec, CROSSENTROPY);
//net.Backward(output, label, labelSmoothingP, CROSSENTROPY);
gradStep += 1; gradStep += 1;
loss += -prob; loss += -prob;
wordCount += wc; wordCount += wc;
wordCountTotal += wc; wordCountTotal += wc;
//totalW = wc + ws;
wordCountBatch += ws;
/* update the parameters */ /* update the parameters */
if(gradStep == updateStep){ if(gradStep == updateStep){
...@@ -283,8 +289,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -283,8 +289,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (step % 100 == 0) { if (step % 100 == 0) {
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
XPRINT8(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f", XPRINT8(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, tword=%d, sword=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc)); elapsed, step, epoch, wordCountTotal, wordCountBatch, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
if (!doUpdate) if (!doUpdate)
XPRINT(0, stderr, " (no update)"); XPRINT(0, stderr, " (no update)");
XPRINT(0, stderr, "\n"); XPRINT(0, stderr, "\n");
...@@ -306,9 +312,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model ...@@ -306,9 +312,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch); MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
} }
fclose(fileen);
fclose(filede);
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch); epoch = MIN(epoch, nepoch);
...@@ -330,6 +333,7 @@ test the model ...@@ -330,6 +333,7 @@ test the model
void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
{ {
int wc = 0; int wc = 0;
int ws = 0;
int wordCount = 0; int wordCount = 0;
int wordCountTotal = 0; int wordCountTotal = 0;
int sentCount = 0; int sentCount = 0;
...@@ -354,6 +358,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -354,6 +358,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
XTensor batchEnc; XTensor batchEnc;
XTensor batchDec; XTensor batchDec;
/* label */
XTensor label;
/* padding */ /* padding */
XTensor paddingEnc; XTensor paddingEnc;
XTensor paddingDec; XTensor paddingDec;
...@@ -366,9 +373,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -366,9 +373,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
ClearBuf(); ClearBuf();
while(LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, while(LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt, seqs, vSize, vSizeTgt,
1, 1, false, wc, devID, mem, false)) 1, 1, false, ws, wc, devID, mem, false))
{ {
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch"); CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
...@@ -470,6 +477,7 @@ struct SampleNode ...@@ -470,6 +477,7 @@ struct SampleNode
int * p; int * p;
int size; int size;
int value; int value;
int key;
}; };
int CompareSampleNode(const void * a, const void * b) int CompareSampleNode(const void * a, const void * b)
...@@ -477,6 +485,11 @@ int CompareSampleNode(const void * a, const void * b) ...@@ -477,6 +485,11 @@ int CompareSampleNode(const void * a, const void * b)
return ((SampleNode*)b)->value - ((SampleNode*)a)->value; return ((SampleNode*)b)->value - ((SampleNode*)a)->value;
} }
int CompareSampleNodeV2(const void * a, const void * b)
{
return ((SampleNode*)b)->key - ((SampleNode*)a)->key;
}
/* /*
load data to buffer load data to buffer
>> file - where to load data >> file - where to load data
...@@ -490,8 +503,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step) ...@@ -490,8 +503,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
int wordCount = 0; int wordCount = 0;
while(fgets(line, MAX_SEQUENCE_LENGTH - 1, file)){ while(fgets(line, MAX_SEQUENCE_LENGTH - 1, file)){
int len = (int)strlen(line); int len = (int)strlen(line);
if(line[0]=='b')
break;
while(line[len - 1] == '\r' || line[len - 1] == '\n'){ while(line[len - 1] == '\r' || line[len - 1] == '\n'){
line[len - 1] = 0; line[len - 1] = 0;
len--; len--;
...@@ -563,19 +575,41 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step) ...@@ -563,19 +575,41 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
node.p = buf + offset; node.p = buf + offset;
node.size = 0; node.size = 0;
int max = 0; int max = 0;
for(int j = 0; j < step; j++){ for (int j = 0; j < step; j++) {
node.size += seqLen[i + j]; node.size += seqLen[i + j];
max = MAX(max, seqLen[i + j]); max = MAX(max, seqLen[i + j]);
} }
//node.value = seqLen[i+1]+seqLen[i];
//node.value = MAX(seqLen[i+1],seqLen[i]);
node.value = max; node.value = max;
node.key = rand();
count++; count++;
offset += node.size; offset += node.size;
} }
qsort(nodes, count, sizeof(SampleNode), CompareSampleNode); qsort(nodes, count, sizeof(SampleNode), CompareSampleNode);
/* distribute samples into buckets. In each bucket, sequences have
similar a length */
if (bucketSize > 0) {
int bucketCount = 0;
int low = 0;
int high = low + bucketSize;
int n = count - 1;
int m = n;
int num = 0;
while (num < count) {
for (m = n; m >= 0; m--) {
if (nodes[m].value > high)
break;
}
qsort(nodes + m + 1, n - m, sizeof(SampleNode), CompareSampleNodeV2);
num += (n - m);
n = m;
low += bucketSize;
high = low + bucketSize;
}
}
count = 0; count = 0;
offset = 0; offset = 0;
for(int i = 0; i < seqCount; i += step){ for(int i = 0; i < seqCount; i += step){
...@@ -633,22 +667,22 @@ load a batch of sequences ...@@ -633,22 +667,22 @@ load a batch of sequences
int T2TTrainer::LoadBatch(FILE * file, bool isLM, int T2TTrainer::LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * gold, XTensor * label,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &ws, int &wCount,
int devID, XMem * mem, int devID, XMem * mem,
bool isTraining) bool isTraining)
{ {
if(isLM){ if(isLM){
return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label,
seqs, vsEnc, sBatch, wBatch, seqs, vsEnc, sBatch, wBatch,
isSorted, wCount, devID, mem, isTraining); isSorted, wCount, devID, mem, isTraining);
} }
else{ else{
return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label,
seqs, vsEnc, vsDec, sBatch, wBatch, seqs, vsEnc, vsDec, sBatch, wBatch,
isSorted, wCount, devID, mem, isTraining); isSorted, ws, wCount, devID, mem, isTraining);
} }
} }
...@@ -674,7 +708,7 @@ load a batch of sequences (for LM) ...@@ -674,7 +708,7 @@ load a batch of sequences (for LM)
int T2TTrainer::LoadBatchLM(FILE * file, int T2TTrainer::LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * gold, XTensor * label,
int * seqs, int * seqs,
int vs, int sBatch, int wBatch, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
...@@ -716,11 +750,13 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -716,11 +750,13 @@ int T2TTrainer::LoadBatchLM(FILE * file,
dims[2] = vs; dims[2] = vs;
InitTensor2D(batchEnc, sc, max, X_INT, devID, mem); InitTensor2D(batchEnc, sc, max, X_INT, devID, mem);
InitTensor2D(label, sc, max, X_INT, devID, mem);
InitTensor(gold, 3, dims, X_FLOAT, 1.0F, devID, mem); InitTensor(gold, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingEnc, sc, max, X_FLOAT, devID, mem); InitTensor2D(paddingEnc, sc, max, X_FLOAT, devID, mem);
InitTensor2D(paddingDec, sc, max, X_FLOAT, devID, mem); InitTensor2D(paddingDec, sc, max, X_FLOAT, devID, mem);
batchEnc->SetZeroAll(); batchEnc->SetZeroAll();
label->SetZeroAll();
gold->SetZeroAll(); gold->SetZeroAll();
paddingEnc->SetZeroAll(); paddingEnc->SetZeroAll();
paddingDec->SetZeroAll(); paddingDec->SetZeroAll();
...@@ -728,13 +764,15 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -728,13 +764,15 @@ int T2TTrainer::LoadBatchLM(FILE * file,
int seqSize = 0; int seqSize = 0;
int * batchEncValues = new int[batchEnc->unitNum]; int * batchEncValues = new int[batchEnc->unitNum];
int * labelValues = new int[label->unitNum];
MTYPE * goldOffsets = new MTYPE[gold->unitNum]; MTYPE * goldOffsets = new MTYPE[gold->unitNum];
//MTYPE * paddingEncOffsets = new MTYPE[paddingEnc->unitNum]; MTYPE * paddingEncOffsets = new MTYPE[paddingEnc->unitNum];
//MTYPE * paddingDecOffsets = new MTYPE[paddingDec->unitNum]; MTYPE * paddingDecOffsets = new MTYPE[paddingDec->unitNum];
int wGold = 0; int wGold = 0;
memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum); memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum);
memset(labelValues, 0, sizeof(int) * label->unitNum);
for(int s = seq; s < seq + sc; s++){ for(int s = seq; s < seq + sc; s++){
int len = isDoubledEnd ? seqLen[s] : seqLen[s] - 1; int len = isDoubledEnd ? seqLen[s] : seqLen[s] - 1;
...@@ -742,16 +780,23 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -742,16 +780,23 @@ int T2TTrainer::LoadBatchLM(FILE * file,
for(int w = 0; w < len; w++){ for(int w = 0; w < len; w++){
int num = buf[seqOffset[s] + w]; int num = buf[seqOffset[s] + w];
batchEncValues[(int)batchEnc->GetOffset2D(s - seq, w)] = num; batchEncValues[(int)batchEnc->GetOffset2D(s - seq, w)] = num;
//paddingEncOffsets[wCount] = paddingEnc->GetOffset2D(s - seq, w); paddingEncOffsets[wCount] = paddingEnc->GetOffset2D(s - seq, w);
//paddingDecOffsets[wCount] = paddingDec->GetOffset2D(s - seq, w); paddingDecOffsets[wCount] = paddingDec->GetOffset2D(s - seq, w);
if (w > 0) if (w > 0) {
goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w - 1, num); goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w - 1, num);
labelValues[(int)label->GetOffset2D(s - seq, w - 1)] = buf[seqOffset[s] + w];
}
if (w == len - 1) { if (w == len - 1) {
if (isDoubledEnd) if (isDoubledEnd) {
goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w, num); goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w, num);
else labelValues[(int)label->GetOffset2D(s - seq, w)] = buf[seqOffset[s] + w];
}
else {
goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w, buf[seqOffset[s] + w + 1]); goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w, buf[seqOffset[s] + w + 1]);
labelValues[(int)label->GetOffset2D(s - seq, w)] = buf[seqOffset[s] + w + 1];
}
} }
wCount++; wCount++;
...@@ -767,11 +812,12 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -767,11 +812,12 @@ int T2TTrainer::LoadBatchLM(FILE * file,
} }
batchEnc->SetData(batchEncValues, batchEnc->unitNum); batchEnc->SetData(batchEncValues, batchEnc->unitNum);
label->SetData(labelValues, label->unitNum);
gold->SetDataBatched(goldOffsets, 1.0F, wGold); gold->SetDataBatched(goldOffsets, 1.0F, wGold);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount); paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount);
//paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCount); paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCount);
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem); /*XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
_ConvertDataType(batchEnc, tmp); _ConvertDataType(batchEnc, tmp);
_NotEqual(tmp, paddingEnc, 0); _NotEqual(tmp, paddingEnc, 0);
DelTensorBuf(tmp); DelTensorBuf(tmp);
...@@ -779,12 +825,13 @@ int T2TTrainer::LoadBatchLM(FILE * file, ...@@ -779,12 +825,13 @@ int T2TTrainer::LoadBatchLM(FILE * file,
XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem); XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem);
_ConvertDataType(batchEnc, tmp2); _ConvertDataType(batchEnc, tmp2);
_NotEqual(tmp2, paddingDec, 0); _NotEqual(tmp2, paddingDec, 0);
DelTensorBuf(tmp2); DelTensorBuf(tmp2);*/
delete[] batchEncValues; delete[] batchEncValues;
delete[] labelValues;
delete[] goldOffsets; delete[] goldOffsets;
//delete[] paddingEncOffsets; delete[] paddingEncOffsets;
//delete[] paddingDecOffsets; delete[] paddingDecOffsets;
fflush(tf); fflush(tf);
...@@ -819,15 +866,13 @@ load a batch of sequences (for MT) ...@@ -819,15 +866,13 @@ load a batch of sequences (for MT)
int T2TTrainer::LoadBatchMT(FILE * file, int T2TTrainer::LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * gold, XTensor * label,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &ws, int &wCount,
int devID, XMem * mem, int devID, XMem * mem,
bool isTraining) bool isTraining)
{ {
//if (nextSeq < 0 || nextSeq >= nseqBuf)
// LoadBuf(file, isSorted, 2);
if (nextBatch < 0 || nextBatch >= bufBatchSize) { if (nextBatch < 0 || nextBatch >= bufBatchSize) {
LoadBuf(file, isSorted, 2); LoadBuf(file, isSorted, 2);
...@@ -849,30 +894,30 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -849,30 +894,30 @@ int T2TTrainer::LoadBatchMT(FILE * file,
while (seq + sc < nseqBuf) { while (seq + sc < nseqBuf) {
/* source-side sequence */ /* source-side sequence */
wnEnc = seqLen[seq + sc]; wnEnc = seqLen[seq + sc];
/* target-side sequence */
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc): MAX(maxEnc, wnEnc) * (sc + 2) / 2;
int tcDec = isBigBatch ? (wcDec + wnDec): MAX(maxDec, wnDec) * (sc + 2) / 2;
if(sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch)) /* target-side sequence */
break; wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
wcEnc += wnEnc; int tcEnc = isBigBatch ? (wcEnc + wnEnc) : MAX(maxEnc, wnEnc) * (sc + 2) / 2;
sc += 1; int tcDec = isBigBatch ? (wcDec + wnDec) : MAX(maxDec, wnDec) * (sc + 2) / 2;
if(maxEnc < wnEnc) if (sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch))
maxEnc = wnEnc; break;
wcDec += wnDec; wcEnc += wnEnc;
sc += 1; sc += 1;
if(maxDec < wnDec) if (maxEnc < wnEnc)
maxDec = wnDec; maxEnc = wnEnc;
}
wcDec += wnDec;
sc += 1;
if (maxDec < wnDec)
maxDec = wnDec;
}
BatchNode & batch = bufBatch[bufBatchSize]; BatchNode & batch = bufBatch[bufBatchSize];
batch.beg = seq; batch.beg = seq;
...@@ -889,46 +934,6 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -889,46 +934,6 @@ int T2TTrainer::LoadBatchMT(FILE * file,
qsort(bufBatch, bufBatchSize, sizeof(BatchNode), CompareBatchNode); qsort(bufBatch, bufBatchSize, sizeof(BatchNode), CompareBatchNode);
} }
/*int seq = MAX(nextSeq, 0);
int wcEnc = 0;
int wcDec = 0;
int wnEnc = 0;
int wnDec = 0;
int maxEnc = 0;
int maxDec = 0;
int sc = 0;
CheckNTErrors((nseqBuf - seq) % 2 == 0, "Input sequence must be paired!");
while(seq + sc < nseqBuf){
wnEnc = seqLen[seq + sc];
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc): MAX(maxEnc, wnEnc) * (sc + 2) / 2;
int tcDec = isBigBatch ? (wcDec + wnDec): MAX(maxDec, wnDec) * (sc + 2) / 2;
if(sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch))
break;
wcEnc += wnEnc;
sc += 1;
if(maxEnc < wnEnc)
maxEnc = wnEnc;
wcDec += wnDec;
sc += 1;
if(maxDec < wnDec)
maxDec = wnDec;
}
nextSeq = seq + sc;
if(sc <= 0)
return 0;*/
if(bufBatchSize <= 0) if(bufBatchSize <= 0)
return 0; return 0;
...@@ -948,13 +953,15 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -948,13 +953,15 @@ int T2TTrainer::LoadBatchMT(FILE * file,
InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem); InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
InitTensor2D(batchDec, sCount, maxDec, X_INT, devID, mem); InitTensor2D(batchDec, sCount, maxDec, X_INT, devID, mem);
InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem); InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem); InitTensor2D(label, sCount, maxDec, X_INT, devID, mem);
//InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
batchEnc->SetZeroAll(); batchEnc->SetZeroAll();
paddingEnc->SetZeroAll(); paddingEnc->SetZeroAll();
batchDec->SetZeroAll(); batchDec->SetZeroAll();
paddingDec->SetZeroAll(); paddingDec->SetZeroAll();
gold->SetZeroAll(); label->SetZeroAll();
//gold->SetZeroAll();
int wCountEnc = 0; int wCountEnc = 0;
int wCountDec = 0; int wCountDec = 0;
...@@ -964,12 +971,14 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -964,12 +971,14 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int * batchEncValues = new int[batchEnc->unitNum]; int * batchEncValues = new int[batchEnc->unitNum];
int * batchDecValues = new int[batchDec->unitNum]; int * batchDecValues = new int[batchDec->unitNum];
int * labelValues = new int[label->unitNum];
//MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2]; //MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2]; MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2];
MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2]; //MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum); memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum);
memset(batchDecValues, 0, sizeof(int) * batchDec->unitNum); memset(batchDecValues, 0, sizeof(int) * batchDec->unitNum);
memset(labelValues, 0, sizeof(int) * batchDec->unitNum);
/* batch of the source-side sequences */ /* batch of the source-side sequences */
for(int s = seq; s < seq + sc; s += 2){ for(int s = seq; s < seq + sc; s += 2){
...@@ -982,7 +991,7 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -982,7 +991,7 @@ int T2TTrainer::LoadBatchMT(FILE * file,
wCountEnc++; wCountEnc++;
} }
} }
ws = wCountEnc;
batchEnc->SetData(batchEncValues, batchEnc->unitNum); batchEnc->SetData(batchEncValues, batchEnc->unitNum);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc); //paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem); XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
...@@ -1003,14 +1012,19 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -1003,14 +1012,19 @@ int T2TTrainer::LoadBatchMT(FILE * file,
paddingDecOffsets[wCountPad++] = paddingDec->GetOffset2D(sent, w); paddingDecOffsets[wCountPad++] = paddingDec->GetOffset2D(sent, w);
wCount++; wCount++;
} }
if (w > 0) if (w > 0) {
goldOffsets[wGold++] = gold->GetOffset3D(sent, w - 1, buf[seqOffset[s] + w]); //goldOffsets[wGold++] = gold->GetOffset3D(sent, w - 1, buf[seqOffset[s] + w]);
labelValues[label->GetOffset2D(sent, w - 1)] = buf[seqOffset[s] + w];
}
if (w == len - 1) { if (w == len - 1) {
if (isDoubledEnd) if (isDoubledEnd) {
goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w]); //goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w]);
else labelValues[label->GetOffset2D(sent, w)] = buf[seqOffset[s] + w];
goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w + 1]); }
else {
//goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w + 1]);
labelValues[label->GetOffset2D(sent, w)] = buf[seqOffset[s] + w + 1];
}
} }
//wCount++; //wCount++;
wCountDec++; wCountDec++;
...@@ -1025,6 +1039,7 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -1025,6 +1039,7 @@ int T2TTrainer::LoadBatchMT(FILE * file,
} }
batchDec->SetData(batchDecValues, batchDec->unitNum); batchDec->SetData(batchDecValues, batchDec->unitNum);
label->SetData(labelValues, label->unitNum);
paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCountPad); paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCountPad);
//XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem); //XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem);
...@@ -1032,13 +1047,14 @@ int T2TTrainer::LoadBatchMT(FILE * file, ...@@ -1032,13 +1047,14 @@ int T2TTrainer::LoadBatchMT(FILE * file,
//_NotEqual(tmp2, paddingDec, 0); //_NotEqual(tmp2, paddingDec, 0);
//DelTensorBuf(tmp2); //DelTensorBuf(tmp2);
gold->SetDataBatched(goldOffsets, 1.0F, wGold); //gold->SetDataBatched(goldOffsets, 1.0F, wGold);
delete[] batchEncValues; delete[] batchEncValues;
delete[] batchDecValues; delete[] batchDecValues;
delete[] labelValues;
//delete[] paddingEncOffsets; //delete[] paddingEncOffsets;
delete[] paddingDecOffsets; delete[] paddingDecOffsets;
delete[] goldOffsets; //delete[] goldOffsets;
return sc; return sc;
} }
...@@ -1071,12 +1087,6 @@ float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs) ...@@ -1071,12 +1087,6 @@ float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
XTensor probs; XTensor probs;
InitTensor(&probs, output); InitTensor(&probs, output);
//XTensor logOutput;
//InitTensor(&logOutput, output);
//_Log(output, &logOutput);
/* probs[i,j] = output[i,j] * gold[i,j] */
//_Multiply(&logOutput, gold, &probs);
_Multiply(output, gold, &probs); _Multiply(output, gold, &probs);
/* probability of each word */ /* probability of each word */
......
...@@ -176,6 +176,9 @@ public: ...@@ -176,6 +176,9 @@ public:
/* indicates whether we intend to debug the net */ /* indicates whether we intend to debug the net */
bool isDebugged; bool isDebugged;
/* bucket size */
int bucketSize;
public: public:
/* constructor */ /* constructor */
T2TTrainer(); T2TTrainer();
...@@ -205,10 +208,10 @@ public: ...@@ -205,10 +208,10 @@ public:
int LoadBatch(FILE * file, bool isLM, int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * gold, XTensor * label,
int * seqs, int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &ws, int &wCount,
int devID, XMem * mem, int devID, XMem * mem,
bool isTraining); bool isTraining);
...@@ -216,7 +219,7 @@ public: ...@@ -216,7 +219,7 @@ public:
int LoadBatchLM(FILE * file, int LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * gold, XTensor * label,
int * seqs, int vs, int sBatch, int wBatch, int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &wCount,
int devID, XMem * mem, int devID, XMem * mem,
...@@ -226,9 +229,9 @@ public: ...@@ -226,9 +229,9 @@ public:
int LoadBatchMT(FILE * file, int LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc, XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec, XTensor * batchDec, XTensor * paddingDec,
XTensor * gold, XTensor * gold, XTensor * label,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch, int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount, bool isSorted, int &ws, int &wCount,
int devID, XMem * mem, int devID, XMem * mem,
bool isTraining); bool isTraining);
......
...@@ -36,8 +36,6 @@ int TransformerMain(int argc, const char ** argv) ...@@ -36,8 +36,6 @@ int TransformerMain(int argc, const char ** argv)
{ {
if(argc == 0) if(argc == 0)
return 1; return 1;
fprintf(stderr, "%e\n", log(1e-8F));
char ** args = new char*[argc]; char ** args = new char*[argc];
for(int i = 0; i < argc; i++){ for(int i = 0; i < argc; i++){
...@@ -66,10 +64,7 @@ int TransformerMain(int argc, const char ** argv) ...@@ -66,10 +64,7 @@ int TransformerMain(int argc, const char ** argv)
T2TModel model; T2TModel model;
model.InitModel(argc, args); model.InitModel(argc, args);
//if(strcmp(modelFN, ""))
//model.Read(modelFN);
/* learn model parameters */ /* learn model parameters */
if(strcmp(trainFN, "")) if(strcmp(trainFN, ""))
trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model); trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model);
......
...@@ -307,6 +307,27 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id ...@@ -307,6 +307,27 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id
MakeLink(&list, h, id); MakeLink(&list, h, id);
} }
/*
create a hyperedge with two input tensors and a output tensor
>> t1 - a tail tensor
>> t2 - the second tail tensor
>> t3 - the third tail tensor
>> h - head tensor
>> id - id of the edge type
*/
void XLink::MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3,XTensor * h, int id)
{
if (h == NULL)
return;
XList list(3);
list.Add(t1);
list.Add(t2);
list.Add(t3);
MakeLink(&list, h, id);
}
/* /*
create a hyper edge with a list of tensors and a output tensor create a hyper edge with a list of tensors and a output tensor
>> list - a list of input tensors >> list - a list of input tensors
......
...@@ -138,6 +138,10 @@ struct XLink ...@@ -138,6 +138,10 @@ struct XLink
static static
void MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id); void MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id);
/* create a hyper edge with three input tensors and a output tensor */
static
void MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3, XTensor * h, int id);
/* create a hyper edge with a list of input tensors and a output tensor */ /* create a hyper edge with a list of input tensors and a output tensor */
static static
void MakeLink(const XList * list, XTensor * h, int id); void MakeLink(const XList * list, XTensor * h, int id);
......
...@@ -77,6 +77,8 @@ const char * GetOPName(int type) ...@@ -77,6 +77,8 @@ const char * GetOPName(int type)
return "M_POWER"; return "M_POWER";
else if (type == MATH_SCALEANDSHIFT) else if (type == MATH_SCALEANDSHIFT)
return "M_SCALEANDSHIFT"; return "M_SCALEANDSHIFT";
else if (type == MATH_MULANDSHIFT)
return "M_OPERATION";
else if (type == MATH_SIGN) else if (type == MATH_SIGN)
return "M_SIGN"; return "M_SIGN";
else if (type == MATH_SUB) else if (type == MATH_SUB)
......
...@@ -57,7 +57,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -57,7 +57,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_NORMALIZE MATH_NEGATE + 1 #define MATH_NORMALIZE MATH_NEGATE + 1
#define MATH_POWER MATH_NORMALIZE + 1 #define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1 #define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_SIGN MATH_SCALEANDSHIFT + 1 #define MATH_MULANDSHIFT MATH_SCALEANDSHIFT + 1
#define MATH_SIGN MATH_MULANDSHIFT + 1
#define MATH_SUB MATH_SIGN + 1 #define MATH_SUB MATH_SIGN + 1
#define MATH_SUBDIM MATH_SUB + 1 #define MATH_SUBDIM MATH_SUB + 1
#define MATH_SUM MATH_SUBDIM + 1 #define MATH_SUM MATH_SUBDIM + 1
......
...@@ -1614,17 +1614,11 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int beg, ...@@ -1614,17 +1614,11 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int beg,
else if (dataType == X_INT) { else if (dataType == X_INT) {
int end = MIN(n > 0 ? beg + n : beg + unitNum, unitNum); int end = MIN(n > 0 ? beg + n : beg + unitNum, unitNum);
for(int i = beg; i < end; i++){ for(int i = beg; i < end; i++){
if((i%(dimSize[1]) == 0)&&(i!=0)) {
fprintf(file, " \n");
}
int f = ((int*)d)[i]; int f = ((int*)d)[i];
if(i == beg) if(i == beg)
fprintf(file, "%d", f); fprintf(file, "%d", f);
else else
fprintf(file, " %d", f); fprintf(file, " %d", f);
//if((i%(dimSize[1]-1) == 0)&&(i!=0)) {
//fprintf(file, " \n");
//}
} }
} }
else else
......
...@@ -44,6 +44,7 @@ ...@@ -44,6 +44,7 @@
#include "arithmetic/SumByColumnVT.h" #include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumDim.h" #include "arithmetic/SumDim.h"
#include "arithmetic/XTensorBLAS.h" #include "arithmetic/XTensorBLAS.h"
#include "arithmetic/MulAndShift.h"
#include "getandset/ConvertDataType.h" #include "getandset/ConvertDataType.h"
#include "getandset/OnehotAndIndex.h" #include "getandset/OnehotAndIndex.h"
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-02-27
*/
#include "../../XTensor.h"
#include "../../XDevice.h"
#include "../../XName.h"
#include "MulAndShift.h"
#include "MatrixMul.h"
#include "Sum.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h)
>> a - a tensor
>> b - another tensor for sum
*/
int GetSumIndex(const XTensor &a, const XTensor &b)
{
if (a.order < b.order)
return -1;
if (XTensor::IsSameShaped(&a, &b))
return -1;
int hitCount = 0;
int hitDim = -1;
for (int i = 0; i < b.order; i++) {
if (b.dimSize[b.order - 1 - i] == 1)
continue;
else if (b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]) {
hitCount++;
hitDim = a.order - b.order + i;
}
}
if (hitCount == 1)
return hitDim;
else
return -1;
}
/*
operation c = x * w + b MulAndShift
>> x - tensor x
>> w - tensor w
>> b - tensor b
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
*/
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
DTYPE alpha, XPRunner * parallelRunner)
{
CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have a order >= 2!");
int xn = x.dimSizeRDI[1];
int xm = x.dimSizeRDI[0];
int wn = w.dimSizeRDI[1];
int wm = w.dimSizeRDI[0];
CheckNTErrors(xm == wn, "Unmatched tensors in multiplication!");
int order = x.order + w.order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < x.order; i++)
dimSize[sub++] = x.dimSizeRDI[x.order + 1 - i];
for (int i = 2; i < w.order; i++)
dimSize[sub++] = w.dimSizeRDI[w.order + 1 - i];
dimSize[sub++] = xn;
dimSize[sub++] = wm;
float dr = (!x.isSparse || !w.isSparse) ? 1.0F : MAX(x.denseRatio, w.denseRatio);
XTensor * tmp = NewTensorBuf(order, dimSize, x.dataType, dr, x.devID, x.mem);
/* call _MatrixMul function */
_MatrixMul(&x, X_NOTRANS, &w, X_NOTRANS, tmp, alpha, 0, parallelRunner);
XTensor c(tmp);
c.SetTMPFlag();
int n = GetSumIndex(tmp, b);
if (n == -1) {
/* call _Sum function */
_Sum(tmp, &b, &c);
// TODO!!
ShowNTErrors("TODO!");
}
else if (n >= 0 && n < tmp->order) {
/* call _SumDim function */
_SumDim(tmp, &b, &c, n);
}
else {
ShowNTErrors("Something is wrong!");
}
/* tensor connections */
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
//XLink::AddParamToHead(&c, beta);
/* destroy variables */
delete[] dimSize;
DelTensorBuf(tmp);
return c;
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-02-27
*/
#ifndef __MULANDSHIFT_H__
#define __MULANDSHIFT_H__
#include "../../XTensor.h"
#include "../CHeader.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
#endif // __OPERATION_H__
...@@ -99,11 +99,11 @@ convert index tensor to onehot tensor ...@@ -99,11 +99,11 @@ convert index tensor to onehot tensor
>> onehot - onehot tensor, which value is 0 or 1 >> onehot - onehot tensor, which value is 0 or 1
>> size - the last dimension size of the onehot tensor >> size - the last dimension size of the onehot tensor
*/ */
void _IndexToOnehot(XTensor * index, XTensor * onehot, int size) void _IndexToOnehot(XTensor * index, XTensor * onehot, int size, float labelSmoothingP)
{ {
CheckNTErrors(onehot->GetDim(-1) == size, "Illegal tensor dimension!"); CheckNTErrors(onehot->GetDim(-1) == size, "Illegal tensor dimension!");
CheckNTErrors(onehot->order == index->order + 1, "Illegal tensor order!"); CheckNTErrors(onehot->order == index->order + 1, "Illegal tensor order!");
CheckNTErrors(onehot->dataType == X_INT, "The onehot tensor must be in X_INT!") //CheckNTErrors(onehot->dataType == X_INT, "The onehot tensor must be in X_INT!")
CheckNTErrors(index->dataType == X_INT, "The index tensor must be in X_INT!") CheckNTErrors(index->dataType == X_INT, "The index tensor must be in X_INT!")
for (int i = 0; i < index->order; i++) for (int i = 0; i < index->order; i++)
...@@ -111,9 +111,12 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size) ...@@ -111,9 +111,12 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size)
onehot->SetZeroAll(); onehot->SetZeroAll();
float confidence = 1 - labelSmoothingP;
float lowconfidence = labelSmoothingP / size;
#ifdef USE_CUDA #ifdef USE_CUDA
if(onehot->devID >= 0 && index->devID >= 0) { if(onehot->devID >= 0 && index->devID >= 0) {
_CudaIndexToOnehot(index, onehot, size); _CudaIndexToOnehot(index, onehot, size, confidence, lowconfidence);
return; return;
} }
#endif #endif
...@@ -122,12 +125,13 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size) ...@@ -122,12 +125,13 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size)
int stride = size; int stride = size;
int * indexData = (int *)index->data; int * indexData = (int *)index->data;
int * onehotData = (int *)onehot->data; DTYPE * onehotData = (DTYPE *)onehot->data;
for (int i = 0; i < blockNum; i++) { for (int i = 0; i < blockNum; i++) {
int id = indexData[i]; int id = indexData[i];
int * od = onehotData + i * stride; DTYPE * od = onehotData + i * stride;
od[id] = 1; od[id] = 2;
//onehotData[i * stride + id] = 1;
} }
} }
...@@ -138,9 +142,10 @@ make a new tensor to keep the result and return it ...@@ -138,9 +142,10 @@ make a new tensor to keep the result and return it
>> index - index tensor, which value is an integer num >> index - index tensor, which value is an integer num
>> size - the last dimension size of the onehot tensor >> size - the last dimension size of the onehot tensor
>> confidence - labelsmoothing
<< return - the onehot tensor << return - the onehot tensor
*/ */
XTensor IndexToOnehot(XTensor & index, int size) XTensor IndexToOnehot(XTensor & index, int size, float labelSmoothingP)
{ {
CheckNTErrors(index.dataType == X_INT, "The onehot tensor must be in X_INT!") CheckNTErrors(index.dataType == X_INT, "The onehot tensor must be in X_INT!")
...@@ -151,9 +156,9 @@ XTensor IndexToOnehot(XTensor & index, int size) ...@@ -151,9 +156,9 @@ XTensor IndexToOnehot(XTensor & index, int size)
int * dim = new int[order + 1]; int * dim = new int[order + 1];
memcpy(dim, index.dimSize, order * sizeof(int)); memcpy(dim, index.dimSize, order * sizeof(int));
dim[order] = size; dim[order] = size;
InitTensor(&onehot, index.order + 1, dim, X_INT, 1.0F, index.devID, index.mem); InitTensor(&onehot, index.order + 1, dim, X_FLOAT, 1.0F, index.devID, index.mem);
_IndexToOnehot(&index, &onehot, size); _IndexToOnehot(&index, &onehot, size, labelSmoothingP);
delete[] dim; delete[] dim;
......
...@@ -96,7 +96,7 @@ convert index tensor to onehot tensor (kernel version) ...@@ -96,7 +96,7 @@ convert index tensor to onehot tensor (kernel version)
>> stride - stride of a data block >> stride - stride of a data block
*/ */
__global__ __global__
void KernelIndexToOnehot(int * onehotData, int * indexData, int blockNum, int stride) void KernelIndexToOnehot(DTYPE * onehotData, int * indexData, int blockNum, int stride, float confidence, float lowconfidence)
{ {
/* block id */ /* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x; int i = blockDim.x * blockIdx.x + threadIdx.x;
...@@ -107,10 +107,17 @@ void KernelIndexToOnehot(int * onehotData, int * indexData, int blockNum, int st ...@@ -107,10 +107,17 @@ void KernelIndexToOnehot(int * onehotData, int * indexData, int blockNum, int st
if (i >= blockNum || offset >= stride) if (i >= blockNum || offset >= stride)
return; return;
int * od = onehotData + i * stride; DTYPE * od = onehotData + i * stride;
int id = indexData[i]; int id = indexData[i];
od[id] = 1;
//od[id] = 2.0;
//onehotData[i * stride + id] = 0.1;
if (offset == id)
od[offset] = confidence;
else{
od[offset] = lowconfidence;
}
} }
/* /*
...@@ -120,7 +127,7 @@ convert index tensor to onehot tensor (cuda version) ...@@ -120,7 +127,7 @@ convert index tensor to onehot tensor (cuda version)
>> onehot - onehot tensor, which value is 0 or 1 >> onehot - onehot tensor, which value is 0 or 1
>> size - the last dimension size of the onehot tensor >> size - the last dimension size of the onehot tensor
*/ */
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size) void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size, float confidence, float lowconfidence)
{ {
int devID = onehot->devID; int devID = onehot->devID;
...@@ -138,10 +145,10 @@ void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size) ...@@ -138,10 +145,10 @@ void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size)
dim3 blocks(cudaGrids[0], cudaGrids[1]); dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]); dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int * onehotData = (int *)onehot->data; DTYPE * onehotData = (DTYPE *)onehot->data;
int * indexData = (int *)index->data; int * indexData = (int *)index->data;
KernelIndexToOnehot<<<blocks, threads >>>(onehotData, indexData, blockNum, stride); KernelIndexToOnehot<<<blocks, threads >>>(onehotData, indexData, blockNum, stride, confidence, lowconfidence);
BacktoCudaDev(devID, devIDBackup); BacktoCudaDev(devID, devIDBackup);
} }
......
...@@ -30,7 +30,7 @@ namespace nts{ // namespace nts(NiuTrans.Tensor) ...@@ -30,7 +30,7 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
void _CudaOnehotToIndex(XTensor * onehot, XTensor * index, int size); void _CudaOnehotToIndex(XTensor * onehot, XTensor * index, int size);
/* convert index tensor to onehot tensor (cuda version) */ /* convert index tensor to onehot tensor (cuda version) */
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size); void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size, float confidence, float lowconfidence);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
...@@ -34,11 +34,11 @@ make a new tensor to keep the result and return it */ ...@@ -34,11 +34,11 @@ make a new tensor to keep the result and return it */
XTensor OnehotToIndex(XTensor & onehot, int num); XTensor OnehotToIndex(XTensor & onehot, int num);
/* convert index tensor to onehot tensor */ /* convert index tensor to onehot tensor */
void _IndexToOnehot(XTensor * index, XTensor * onehot, int size); void _IndexToOnehot(XTensor * index, XTensor * onehot, int size, float labelSmoothingP);
/* convert index tensor to onehot tensor (return an XTensor structure) /* convert index tensor to onehot tensor (return an XTensor structure)
make a new tensor to keep the result and return it */ make a new tensor to keep the result and return it */
XTensor IndexToOnehot(XTensor & index, int num); XTensor IndexToOnehot(XTensor & index, int num, float labelSmoothingP);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论