Commit 591d6121 by 姜雨帆

implement MulAndShift and bug fix

parent 0b43acf6
......@@ -99,6 +99,8 @@ void XMathGrad::MakeGrad(XTensor * node, bool isEfficient)
GradReduceSumSquared(node, isEfficient);
else if(operID == REDUCE_REDUCEVARIANCE)
GradReduceVariance(node, isEfficient);
else if (operID == MATH_MULANDSHIFT)
GradMulAndShift(node, isEfficient);
else{
ShowNTErrors("TODO!");
}
......@@ -1487,4 +1489,126 @@ void XMathGrad::GradReduceVariance(XTensor * node, bool isEfficient)
node->visitMark = NODE_FINISHED;
}
/*
gradient for operation
for c = matmul(x, w) + b
we have
dE/dx = dE/dc * w^T
dE/dw = x^T * dE/dc
dE/db = dE/dc * x.reduce(0,...,n-1,n+1,...)
>> node - the node (c) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XMathGrad::GradMulAndShift(XTensor * node, bool isEfficient)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 3, "wrong input tensor number")
XTensor * x = income.tails[0];
XTensor * w = income.tails[1];
XTensor * b = income.tails[2];
int n = income.GetParamInt(0);
MATRIX_TRANS_TYPE transW = income.GetParamTrans(1);
MATRIX_TRANS_TYPE transX = income.GetParamTrans(2);
if (!isEfficient || w->isGrad)
XNoder::MakeGrad(w);
if (!isEfficient || x->isGrad)
XNoder::MakeGrad(x);
if (!isEfficient || b->isGrad)
XNoder::MakeGrad(b);
int order = node->order;
int dimSize[MAX_TENSOR_DIM_NUM];
memcpy(dimSize, node->dimSize, sizeof(int) * node->order);
/* compute dE/db */
if (n == order - 1) {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = node->unitNum / dimSize[order - 1];
reshapedSize[1] = dimSize[order - 1];
/* we reshape dE/dc to a matrix whose column number is equal to the
size of b. Then we can reduce the matrix into a row vector. */
node->grad->Reshape(2, reshapedSize);
XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);
_ReduceSum(node->grad, bGradTMP, 0);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
node->grad->Reshape(order, dimSize);
}
else {
int reshapedSize[MAX_TENSOR_DIM_NUM];
reshapedSize[0] = 1;
reshapedSize[1] = dimSize[n];
reshapedSize[2] = 1;
for (int i = 0; i < order; i++) {
if (i < n)
reshapedSize[0] *= dimSize[i];
}
reshapedSize[2] = node->unitNum / (reshapedSize[0] * reshapedSize[1]);
/* we reshape dE/dc to a 3D tensor of size (x, y, z) where y = |b|.
Then reduce along with z and x to obtain dE/db. */
node->grad->Reshape(3, reshapedSize);
XTensor * interGrad = NewTensorBuf(2, reshapedSize, b->dataType, b->denseRatio, b->devID, b->mem);
_ReduceSum(node->grad, interGrad, 2);
XTensor * bGradTMP = NewTensorBuf(b->grad, b->devID, b->mem);
_ReduceSum(interGrad, bGradTMP, 0);
_Sum(bGradTMP, b->grad, b->grad);
DelTensorBuf(bGradTMP);
node->grad->Reshape(order, dimSize);
DelTensorBuf(interGrad);
}
/* compute dE/dx, dE/dw */
XTensor * c = node;
XTensor * dedc = node->grad;
XTensor * dedw = w->grad;
XTensor * dedx = x->grad;
if (x->order == 2 && w->order == 2)
GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
else if (transX == X_NOTRANS && x->order > 2 && w->order == 2){
int orderBackupX = x->order;
int orderBackupC = c->order;
int dimsBackupX[MAX_TENSOR_DIM_NUM];
int dimsBackupC[MAX_TENSOR_DIM_NUM];
memcpy(dimsBackupX, x->dimSize, sizeof(int) * x->order);
memcpy(dimsBackupC, c->dimSize, sizeof(int) * c->order);
x->Reshape(x->unitNum / x->GetDim(-1), x->GetDim(-1));
c->Reshape(c->unitNum / c->GetDim(-1), c->GetDim(-1));
if (!isEfficient || x->isGrad)
dedx->Reshape(dedx->unitNum / dedx->GetDim(-1), dedx->GetDim(-1));
dedc->Reshape(dedc->unitNum / dedc->GetDim(-1), dedc->GetDim(-1));
GradMatrixMul(x, dedx, transX, w, dedw, transW, dedc, 1.0F, isEfficient);
x->Reshape(orderBackupX, dimsBackupX);
c->Reshape(orderBackupC, dimsBackupC);
if (!isEfficient || x->isGrad)
dedx->Reshape(orderBackupX, dimsBackupX);
dedc->Reshape(orderBackupC, dimsBackupC);
}
node->visitMark = NODE_FINISHED;
}
}
......@@ -168,6 +168,10 @@ private:
/* gradient for reduceVariance */
static
void GradReduceVariance(XTensor * node, bool isEfficient);
/* gradient for operation */
static
void GradMulAndShift(XTensor * node, bool isEfficient);
};
}
......
......@@ -61,6 +61,7 @@ public:
XTensor wa;
XTensor wbig;
/* size of transformed Q and K */
int dk;
......
......@@ -80,7 +80,6 @@ void AttDecoder::InitModel(int argc, char ** argv,
attentionsEnde = new T2TAttention[nlayer];
attEndeLayerNorms = new T2TLN[nlayer];
/* initialize the stacked layers */
for (int i = 0; i < nlayer; i++) {
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
......@@ -89,9 +88,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
fnnLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID, myMem);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
}
}
/*
......
......@@ -103,8 +103,6 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, boo
x = embedder.Make(input);
//x.Dump(tmpFILE, "embedding: ");
/* dropout */
if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
......@@ -160,4 +158,3 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool isTraining)
}
}
......@@ -89,13 +89,15 @@ XTensor T2TFNN::Make(XTensor &input, bool isTraining)
XTensor t1;
/* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MMul(input, w1) + b1);
//t1 = Rectify(MMul(input, w1) + b1);
t1 = Rectify(MulAndShift(input, w1, b1));
if(isTraining && dropoutP > 0)
t1 = Dropout(t1, dropoutP);
/* result = t1 * w2 + b2 */
return MMul(t1, w2) + b2;
//return MMul(t1, w2) + b2;
return MulAndShift(t1, w2, b2);
}
......
......@@ -219,7 +219,7 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
dims[i + 1] = inputDec.GetDim(i);
dims[0] = nhead;
dims[inputDec.order + 1] = len;
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingEnc.devID, paddingEnc.mem);
InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, 1.0F, paddingDec.devID, paddingDec.mem);
/* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
this matrix can be used to prevent the attention to current or following words in
......@@ -236,10 +236,10 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
XTensor * maskEncDecTMPDec = NewTensorBuf(maskEncDecTMPEnc, paddingEnc.devID, paddingEnc.mem);
_Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
_Unsqueeze(&paddingDec, maskEncDecTMPDec, paddingEnc.order, paddingEnc.GetDim(-1));
_Multiply(maskEncDecTMPDec, maskEncDecTMPEnc, maskEncDecTMPDec);
_ScaleAndShiftMe(maskEncDecTMPDec, 1e9F, -1e9F);
_Unsqueeze(maskEncDecTMPDec, &maskEncDec, 0, dims[0]);
//_Unsqueeze(&paddingDec, maskEncDecTMPDec, paddingEnc.order, paddingEnc.GetDim(-1));
//_Multiply(maskEncDecTMPDec, maskEncDecTMPEnc, maskEncDecTMPDec);
_ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
_Unsqueeze(maskEncDecTMPEnc, &maskEncDec, 0, dims[0]);
DelTensorBuf(maskEncDecTMPDec);
DelTensorBuf(maskEncDecTMPEnc);
......@@ -274,10 +274,9 @@ void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTe
_Sum(&maskEnc, padding3, &maskEnc);
encoding = MakeEncoder(inputEnc, maskEnc, isTraining);
//encoding.Dump(stderr, "encoding",10);
decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);
//decoding.Dump(stderr, "decoding", 10);
outputLayer->Make(decoding, output);
delete[] dims;
......
......@@ -122,6 +122,7 @@ void T2TTrainer::Init(int argc, char ** argv)
LoadParamBool(argc, argv, "bigbatch", &isBigBatch, false);
LoadParamBool(argc, argv, "debug", &isDebugged, false);
LoadParamBool(argc, argv, "randbatch", &isRandomBatch, false);
LoadParamInt(argc, argv, "bucketsize", &bucketSize, 0);
buf = new int[bufSize];
buf2 = new int[bufSize];
......@@ -147,8 +148,11 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
{
int step = 0;
int wc = 0;
int ws =0;
int wordCount = 0;
int totalW;
int wordCountTotal = 0;
int wordCountBatch = 0;
bool isEnd = false;
float loss = 0;
float lr = 0;
......@@ -178,9 +182,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
double startT = GetClockSec();
FILE * fileen = fopen("enc.txt", "w");
FILE * filede = fopen("dec.txt", "w");
for(epoch = 1; epoch <= nepoch; epoch++){
#ifndef WIN32
if(isShuffled)
......@@ -197,6 +198,9 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XTensor batchEnc;
XTensor batchDec;
/* labels */
XTensor label;
/* padding */
XTensor paddingEnc;
XTensor paddingDec;
......@@ -207,17 +211,13 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
/* label smoothed gold standard (if needed) */
XTensor goldSmoothed;
while (LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold,
while (LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &batchDec, &paddingDec, &gold, &label,
NULL, vSize, vSizeTgt,
sBatchSize, wBatchSize, isLenSorted, wc, devID, mem, true))
sBatchSize, wBatchSize, isLenSorted, ws, wc, devID, mem, true))
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
//batchEnc.Dump(stderr, "enc",1);
//batchDec.Dump(stderr, "dec",1);
//paddingDec.Dump(stderr, "paddec");
/* output probabilities */
XTensor output;
......@@ -231,35 +231,41 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
}
/* back-propagation for obtaining gradients */
if (labelSmoothingP > 0)
LabelSmooth(&gold, &goldSmoothed, labelSmoothingP);
//if (labelSmoothingP > 0)
// LabelSmooth(&gold, &goldSmoothed, labelSmoothingP);
XTensor labelOnehot;
labelOnehot = IndexToOnehot(label, vSizeTgt, labelSmoothingP);
/* make paddings for the output */
if (output.GetDim(0) > 0)
PadOutput(&output, &gold, &paddingDec);
PadOutput(&output, &labelOnehot, &paddingDec);
/* get probabilities */
float prob = GetProb(&output, &gold, NULL);
//printf("%f\n", prob);
//float prob = 0;
float prob = GetProb(&output, &labelOnehot, NULL);
DTYPE lossLocal = -prob / wc;
bool doUpdate = (!IsNAN(lossLocal) && !IsINF(lossLocal) && lossLocal < 1e3F);
XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
//doUpdate = false;
//XTensor &g = labelSmoothingP > 0 ? goldSmoothed : gold;
if (doUpdate) {
/* recale the output for normalized loss */
RescaleOutput(&output, &g, &paddingDec);
RescaleOutput(&output, &labelOnehot, &paddingDec);
/* back-propagation */
net.Backward(output, g, paddingDec, CROSSENTROPY);
net.Backward(output, labelOnehot, paddingDec, CROSSENTROPY);
//net.Backward(output, label, labelSmoothingP, CROSSENTROPY);
gradStep += 1;
loss += -prob;
wordCount += wc;
wordCountTotal += wc;
//totalW = wc + ws;
wordCountBatch += ws;
/* update the parameters */
if(gradStep == updateStep){
......@@ -283,8 +289,8 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
if (step % 100 == 0) {
double elapsed = GetClockSec() - startT;
XPRINT8(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
XPRINT8(0, stderr, "[INFO] elapsed=%.1fs, step=%d, epoch=%d, tword=%d, sword=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
elapsed, step, epoch, wordCountTotal, wordCountBatch, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
if (!doUpdate)
XPRINT(0, stderr, " (no update)");
XPRINT(0, stderr, "\n");
......@@ -306,9 +312,6 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
}
fclose(fileen);
fclose(filede);
double elapsed = GetClockSec() - startT;
epoch = MIN(epoch, nepoch);
......@@ -330,6 +333,7 @@ test the model
void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
{
int wc = 0;
int ws = 0;
int wordCount = 0;
int wordCountTotal = 0;
int sentCount = 0;
......@@ -354,6 +358,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
XTensor batchEnc;
XTensor batchDec;
/* label */
XTensor label;
/* padding */
XTensor paddingEnc;
XTensor paddingDec;
......@@ -366,9 +373,9 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
ClearBuf();
while(LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold,
while(LoadBatch(file, model->isLM, &batchEnc, &paddingEnc, &paddingDec, &paddingDec, &gold, &label,
seqs, vSize, vSizeTgt,
1, 1, false, wc, devID, mem, false))
1, 1, false, ws, wc, devID, mem, false))
{
CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
......@@ -470,6 +477,7 @@ struct SampleNode
int * p;
int size;
int value;
int key;
};
int CompareSampleNode(const void * a, const void * b)
......@@ -477,6 +485,11 @@ int CompareSampleNode(const void * a, const void * b)
return ((SampleNode*)b)->value - ((SampleNode*)a)->value;
}
int CompareSampleNodeV2(const void * a, const void * b)
{
return ((SampleNode*)b)->key - ((SampleNode*)a)->key;
}
/*
load data to buffer
>> file - where to load data
......@@ -490,8 +503,7 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
int wordCount = 0;
while(fgets(line, MAX_SEQUENCE_LENGTH - 1, file)){
int len = (int)strlen(line);
if(line[0]=='b')
break;
while(line[len - 1] == '\r' || line[len - 1] == '\n'){
line[len - 1] = 0;
len--;
......@@ -563,19 +575,41 @@ int T2TTrainer::LoadBuf(FILE * file, bool isSorted, int step)
node.p = buf + offset;
node.size = 0;
int max = 0;
for(int j = 0; j < step; j++){
for (int j = 0; j < step; j++) {
node.size += seqLen[i + j];
max = MAX(max, seqLen[i + j]);
}
//node.value = seqLen[i+1]+seqLen[i];
//node.value = MAX(seqLen[i+1],seqLen[i]);
node.value = max;
node.key = rand();
count++;
offset += node.size;
}
qsort(nodes, count, sizeof(SampleNode), CompareSampleNode);
/* distribute samples into buckets. In each bucket, sequences have
similar a length */
if (bucketSize > 0) {
int bucketCount = 0;
int low = 0;
int high = low + bucketSize;
int n = count - 1;
int m = n;
int num = 0;
while (num < count) {
for (m = n; m >= 0; m--) {
if (nodes[m].value > high)
break;
}
qsort(nodes + m + 1, n - m, sizeof(SampleNode), CompareSampleNodeV2);
num += (n - m);
n = m;
low += bucketSize;
high = low + bucketSize;
}
}
count = 0;
offset = 0;
for(int i = 0; i < seqCount; i += step){
......@@ -633,22 +667,22 @@ load a batch of sequences
int T2TTrainer::LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
XTensor * gold, XTensor * label,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining)
{
if(isLM){
return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, gold,
return LoadBatchLM(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label,
seqs, vsEnc, sBatch, wBatch,
isSorted, wCount, devID, mem, isTraining);
}
else{
return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, gold,
return LoadBatchMT(file, batchEnc, paddingEnc, batchDec, paddingDec, gold, label,
seqs, vsEnc, vsDec, sBatch, wBatch,
isSorted, wCount, devID, mem, isTraining);
isSorted, ws, wCount, devID, mem, isTraining);
}
}
......@@ -674,7 +708,7 @@ load a batch of sequences (for LM)
int T2TTrainer::LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
XTensor * gold, XTensor * label,
int * seqs,
int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
......@@ -716,11 +750,13 @@ int T2TTrainer::LoadBatchLM(FILE * file,
dims[2] = vs;
InitTensor2D(batchEnc, sc, max, X_INT, devID, mem);
InitTensor2D(label, sc, max, X_INT, devID, mem);
InitTensor(gold, 3, dims, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(paddingEnc, sc, max, X_FLOAT, devID, mem);
InitTensor2D(paddingDec, sc, max, X_FLOAT, devID, mem);
batchEnc->SetZeroAll();
label->SetZeroAll();
gold->SetZeroAll();
paddingEnc->SetZeroAll();
paddingDec->SetZeroAll();
......@@ -728,13 +764,15 @@ int T2TTrainer::LoadBatchLM(FILE * file,
int seqSize = 0;
int * batchEncValues = new int[batchEnc->unitNum];
int * labelValues = new int[label->unitNum];
MTYPE * goldOffsets = new MTYPE[gold->unitNum];
//MTYPE * paddingEncOffsets = new MTYPE[paddingEnc->unitNum];
//MTYPE * paddingDecOffsets = new MTYPE[paddingDec->unitNum];
MTYPE * paddingEncOffsets = new MTYPE[paddingEnc->unitNum];
MTYPE * paddingDecOffsets = new MTYPE[paddingDec->unitNum];
int wGold = 0;
memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum);
memset(labelValues, 0, sizeof(int) * label->unitNum);
for(int s = seq; s < seq + sc; s++){
int len = isDoubledEnd ? seqLen[s] : seqLen[s] - 1;
......@@ -742,16 +780,23 @@ int T2TTrainer::LoadBatchLM(FILE * file,
for(int w = 0; w < len; w++){
int num = buf[seqOffset[s] + w];
batchEncValues[(int)batchEnc->GetOffset2D(s - seq, w)] = num;
//paddingEncOffsets[wCount] = paddingEnc->GetOffset2D(s - seq, w);
//paddingDecOffsets[wCount] = paddingDec->GetOffset2D(s - seq, w);
if (w > 0)
paddingEncOffsets[wCount] = paddingEnc->GetOffset2D(s - seq, w);
paddingDecOffsets[wCount] = paddingDec->GetOffset2D(s - seq, w);
if (w > 0) {
goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w - 1, num);
labelValues[(int)label->GetOffset2D(s - seq, w - 1)] = buf[seqOffset[s] + w];
}
if (w == len - 1) {
if (isDoubledEnd)
if (isDoubledEnd) {
goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w, num);
else
labelValues[(int)label->GetOffset2D(s - seq, w)] = buf[seqOffset[s] + w];
}
else {
goldOffsets[wGold++] = gold->GetOffset3D(s - seq, w, buf[seqOffset[s] + w + 1]);
labelValues[(int)label->GetOffset2D(s - seq, w)] = buf[seqOffset[s] + w + 1];
}
}
wCount++;
......@@ -767,11 +812,12 @@ int T2TTrainer::LoadBatchLM(FILE * file,
}
batchEnc->SetData(batchEncValues, batchEnc->unitNum);
label->SetData(labelValues, label->unitNum);
gold->SetDataBatched(goldOffsets, 1.0F, wGold);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount);
//paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCount);
paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCount);
paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCount);
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
/*XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
_ConvertDataType(batchEnc, tmp);
_NotEqual(tmp, paddingEnc, 0);
DelTensorBuf(tmp);
......@@ -779,12 +825,13 @@ int T2TTrainer::LoadBatchLM(FILE * file,
XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem);
_ConvertDataType(batchEnc, tmp2);
_NotEqual(tmp2, paddingDec, 0);
DelTensorBuf(tmp2);
DelTensorBuf(tmp2);*/
delete[] batchEncValues;
delete[] labelValues;
delete[] goldOffsets;
//delete[] paddingEncOffsets;
//delete[] paddingDecOffsets;
delete[] paddingEncOffsets;
delete[] paddingDecOffsets;
fflush(tf);
......@@ -819,15 +866,13 @@ load a batch of sequences (for MT)
int T2TTrainer::LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
XTensor * gold, XTensor * label,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining)
{
//if (nextSeq < 0 || nextSeq >= nseqBuf)
// LoadBuf(file, isSorted, 2);
if (nextBatch < 0 || nextBatch >= bufBatchSize) {
LoadBuf(file, isSorted, 2);
......@@ -855,22 +900,22 @@ int T2TTrainer::LoadBatchMT(FILE * file,
/* target-side sequence */
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc): MAX(maxEnc, wnEnc) * (sc + 2) / 2;
int tcDec = isBigBatch ? (wcDec + wnDec): MAX(maxDec, wnDec) * (sc + 2) / 2;
int tcEnc = isBigBatch ? (wcEnc + wnEnc) : MAX(maxEnc, wnEnc) * (sc + 2) / 2;
int tcDec = isBigBatch ? (wcDec + wnDec) : MAX(maxDec, wnDec) * (sc + 2) / 2;
if(sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch))
if (sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch))
break;
wcEnc += wnEnc;
sc += 1;
if(maxEnc < wnEnc)
if (maxEnc < wnEnc)
maxEnc = wnEnc;
wcDec += wnDec;
sc += 1;
if(maxDec < wnDec)
if (maxDec < wnDec)
maxDec = wnDec;
}
......@@ -889,46 +934,6 @@ int T2TTrainer::LoadBatchMT(FILE * file,
qsort(bufBatch, bufBatchSize, sizeof(BatchNode), CompareBatchNode);
}
/*int seq = MAX(nextSeq, 0);
int wcEnc = 0;
int wcDec = 0;
int wnEnc = 0;
int wnDec = 0;
int maxEnc = 0;
int maxDec = 0;
int sc = 0;
CheckNTErrors((nseqBuf - seq) % 2 == 0, "Input sequence must be paired!");
while(seq + sc < nseqBuf){
wnEnc = seqLen[seq + sc];
wnDec = isDoubledEnd ? seqLen[seq + sc + 1] : seqLen[seq + sc + 1] - 1;
int tcEnc = isBigBatch ? (wcEnc + wnEnc): MAX(maxEnc, wnEnc) * (sc + 2) / 2;
int tcDec = isBigBatch ? (wcDec + wnDec): MAX(maxDec, wnDec) * (sc + 2) / 2;
if(sc != 0 && sc > sBatch * 2 && (tcEnc > wBatch || tcDec > wBatch))
break;
wcEnc += wnEnc;
sc += 1;
if(maxEnc < wnEnc)
maxEnc = wnEnc;
wcDec += wnDec;
sc += 1;
if(maxDec < wnDec)
maxDec = wnDec;
}
nextSeq = seq + sc;
if(sc <= 0)
return 0;*/
if(bufBatchSize <= 0)
return 0;
......@@ -948,13 +953,15 @@ int T2TTrainer::LoadBatchMT(FILE * file,
InitTensor2D(paddingEnc, sCount, maxEnc, X_FLOAT, devID, mem);
InitTensor2D(batchDec, sCount, maxDec, X_INT, devID, mem);
InitTensor2D(paddingDec, sCount, maxDec, X_FLOAT, devID, mem);
InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
InitTensor2D(label, sCount, maxDec, X_INT, devID, mem);
//InitTensor(gold, 3, dimsDec, X_FLOAT, 1.0F, devID, mem);
batchEnc->SetZeroAll();
paddingEnc->SetZeroAll();
batchDec->SetZeroAll();
paddingDec->SetZeroAll();
gold->SetZeroAll();
label->SetZeroAll();
//gold->SetZeroAll();
int wCountEnc = 0;
int wCountDec = 0;
......@@ -964,12 +971,14 @@ int T2TTrainer::LoadBatchMT(FILE * file,
int * batchEncValues = new int[batchEnc->unitNum];
int * batchDecValues = new int[batchDec->unitNum];
int * labelValues = new int[label->unitNum];
//MTYPE * paddingEncOffsets = new MTYPE[sc * maxEnc / 2];
MTYPE * paddingDecOffsets = new MTYPE[sc * maxDec / 2];
MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
//MTYPE * goldOffsets = new MTYPE[sc * maxDec / 2];
memset(batchEncValues, 0, sizeof(int) * batchEnc->unitNum);
memset(batchDecValues, 0, sizeof(int) * batchDec->unitNum);
memset(labelValues, 0, sizeof(int) * batchDec->unitNum);
/* batch of the source-side sequences */
for(int s = seq; s < seq + sc; s += 2){
......@@ -982,7 +991,7 @@ int T2TTrainer::LoadBatchMT(FILE * file,
wCountEnc++;
}
}
ws = wCountEnc;
batchEnc->SetData(batchEncValues, batchEnc->unitNum);
//paddingEnc->SetDataBatched(paddingEncOffsets, 1.0F, wCountEnc);
XTensor * tmp = NewTensorBuf(paddingEnc, devID, mem);
......@@ -1003,14 +1012,19 @@ int T2TTrainer::LoadBatchMT(FILE * file,
paddingDecOffsets[wCountPad++] = paddingDec->GetOffset2D(sent, w);
wCount++;
}
if (w > 0)
goldOffsets[wGold++] = gold->GetOffset3D(sent, w - 1, buf[seqOffset[s] + w]);
if (w > 0) {
//goldOffsets[wGold++] = gold->GetOffset3D(sent, w - 1, buf[seqOffset[s] + w]);
labelValues[label->GetOffset2D(sent, w - 1)] = buf[seqOffset[s] + w];
}
if (w == len - 1) {
if (isDoubledEnd)
goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w]);
else
goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w + 1]);
if (isDoubledEnd) {
//goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w]);
labelValues[label->GetOffset2D(sent, w)] = buf[seqOffset[s] + w];
}
else {
//goldOffsets[wGold++] = gold->GetOffset3D(sent, w, buf[seqOffset[s] + w + 1]);
labelValues[label->GetOffset2D(sent, w)] = buf[seqOffset[s] + w + 1];
}
}
//wCount++;
wCountDec++;
......@@ -1025,6 +1039,7 @@ int T2TTrainer::LoadBatchMT(FILE * file,
}
batchDec->SetData(batchDecValues, batchDec->unitNum);
label->SetData(labelValues, label->unitNum);
paddingDec->SetDataBatched(paddingDecOffsets, 1.0F, wCountPad);
//XTensor * tmp2 = NewTensorBuf(paddingDec, devID, mem);
......@@ -1032,13 +1047,14 @@ int T2TTrainer::LoadBatchMT(FILE * file,
//_NotEqual(tmp2, paddingDec, 0);
//DelTensorBuf(tmp2);
gold->SetDataBatched(goldOffsets, 1.0F, wGold);
//gold->SetDataBatched(goldOffsets, 1.0F, wGold);
delete[] batchEncValues;
delete[] batchDecValues;
delete[] labelValues;
//delete[] paddingEncOffsets;
delete[] paddingDecOffsets;
delete[] goldOffsets;
//delete[] goldOffsets;
return sc;
}
......@@ -1071,12 +1087,6 @@ float T2TTrainer::GetProb(XTensor * output, XTensor * gold, XTensor * wordProbs)
XTensor probs;
InitTensor(&probs, output);
//XTensor logOutput;
//InitTensor(&logOutput, output);
//_Log(output, &logOutput);
/* probs[i,j] = output[i,j] * gold[i,j] */
//_Multiply(&logOutput, gold, &probs);
_Multiply(output, gold, &probs);
/* probability of each word */
......
......@@ -176,6 +176,9 @@ public:
/* indicates whether we intend to debug the net */
bool isDebugged;
/* bucket size */
int bucketSize;
public:
/* constructor */
T2TTrainer();
......@@ -205,10 +208,10 @@ public:
int LoadBatch(FILE * file, bool isLM,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
XTensor * gold, XTensor * label,
int * seqs,
int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining);
......@@ -216,7 +219,7 @@ public:
int LoadBatchLM(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
XTensor * gold, XTensor * label,
int * seqs, int vs, int sBatch, int wBatch,
bool isSorted, int &wCount,
int devID, XMem * mem,
......@@ -226,9 +229,9 @@ public:
int LoadBatchMT(FILE * file,
XTensor * batchEnc, XTensor * paddingEnc,
XTensor * batchDec, XTensor * paddingDec,
XTensor * gold,
XTensor * gold, XTensor * label,
int * seqs, int vsEnc, int vsDec, int sBatch, int wBatch,
bool isSorted, int &wCount,
bool isSorted, int &ws, int &wCount,
int devID, XMem * mem,
bool isTraining);
......
......@@ -37,8 +37,6 @@ int TransformerMain(int argc, const char ** argv)
if(argc == 0)
return 1;
fprintf(stderr, "%e\n", log(1e-8F));
char ** args = new char*[argc];
for(int i = 0; i < argc; i++){
args[i] = new char[strlen(argv[i]) + 1];
......@@ -67,9 +65,6 @@ int TransformerMain(int argc, const char ** argv)
T2TModel model;
model.InitModel(argc, args);
//if(strcmp(modelFN, ""))
//model.Read(modelFN);
/* learn model parameters */
if(strcmp(trainFN, ""))
trainer.Train(trainFN, testFN, strcmp(modelFN, "") ? modelFN : "checkpoint.model", &model);
......
......@@ -308,6 +308,27 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id
}
/*
create a hyperedge with two input tensors and a output tensor
>> t1 - a tail tensor
>> t2 - the second tail tensor
>> t3 - the third tail tensor
>> h - head tensor
>> id - id of the edge type
*/
void XLink::MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3,XTensor * h, int id)
{
if (h == NULL)
return;
XList list(3);
list.Add(t1);
list.Add(t2);
list.Add(t3);
MakeLink(&list, h, id);
}
/*
create a hyper edge with a list of tensors and a output tensor
>> list - a list of input tensors
>> h - head tensor
......
......@@ -138,6 +138,10 @@ struct XLink
static
void MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id);
/* create a hyper edge with three input tensors and a output tensor */
static
void MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3, XTensor * h, int id);
/* create a hyper edge with a list of input tensors and a output tensor */
static
void MakeLink(const XList * list, XTensor * h, int id);
......
......@@ -77,6 +77,8 @@ const char * GetOPName(int type)
return "M_POWER";
else if (type == MATH_SCALEANDSHIFT)
return "M_SCALEANDSHIFT";
else if (type == MATH_MULANDSHIFT)
return "M_OPERATION";
else if (type == MATH_SIGN)
return "M_SIGN";
else if (type == MATH_SUB)
......
......@@ -57,7 +57,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_NORMALIZE MATH_NEGATE + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define MATH_SIGN MATH_SCALEANDSHIFT + 1
#define MATH_MULANDSHIFT MATH_SCALEANDSHIFT + 1
#define MATH_SIGN MATH_MULANDSHIFT + 1
#define MATH_SUB MATH_SIGN + 1
#define MATH_SUBDIM MATH_SUB + 1
#define MATH_SUM MATH_SUBDIM + 1
......
......@@ -1614,17 +1614,11 @@ void XTensor::Dump(FILE * file, const char * label, const int n, const int beg,
else if (dataType == X_INT) {
int end = MIN(n > 0 ? beg + n : beg + unitNum, unitNum);
for(int i = beg; i < end; i++){
if((i%(dimSize[1]) == 0)&&(i!=0)) {
fprintf(file, " \n");
}
int f = ((int*)d)[i];
if(i == beg)
fprintf(file, "%d", f);
else
fprintf(file, " %d", f);
//if((i%(dimSize[1]-1) == 0)&&(i!=0)) {
//fprintf(file, " \n");
//}
}
}
else
......
......@@ -44,6 +44,7 @@
#include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumDim.h"
#include "arithmetic/XTensorBLAS.h"
#include "arithmetic/MulAndShift.h"
#include "getandset/ConvertDataType.h"
#include "getandset/OnehotAndIndex.h"
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-02-27
*/
#include "../../XTensor.h"
#include "../../XDevice.h"
#include "../../XName.h"
#include "MulAndShift.h"
#include "MatrixMul.h"
#include "Sum.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
return a dimension if the sum is performed as SumDim (in more details in SumDim.h)
>> a - a tensor
>> b - another tensor for sum
*/
int GetSumIndex(const XTensor &a, const XTensor &b)
{
if (a.order < b.order)
return -1;
if (XTensor::IsSameShaped(&a, &b))
return -1;
int hitCount = 0;
int hitDim = -1;
for (int i = 0; i < b.order; i++) {
if (b.dimSize[b.order - 1 - i] == 1)
continue;
else if (b.dimSize[b.order - 1 - i] == a.dimSize[a.order - 1 - i]) {
hitCount++;
hitDim = a.order - b.order + i;
}
}
if (hitCount == 1)
return hitDim;
else
return -1;
}
/*
operation c = x * w + b MulAndShift
>> x - tensor x
>> w - tensor w
>> b - tensor b
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
*/
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
DTYPE alpha, XPRunner * parallelRunner)
{
CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have a order >= 2!");
int xn = x.dimSizeRDI[1];
int xm = x.dimSizeRDI[0];
int wn = w.dimSizeRDI[1];
int wm = w.dimSizeRDI[0];
CheckNTErrors(xm == wn, "Unmatched tensors in multiplication!");
int order = x.order + w.order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < x.order; i++)
dimSize[sub++] = x.dimSizeRDI[x.order + 1 - i];
for (int i = 2; i < w.order; i++)
dimSize[sub++] = w.dimSizeRDI[w.order + 1 - i];
dimSize[sub++] = xn;
dimSize[sub++] = wm;
float dr = (!x.isSparse || !w.isSparse) ? 1.0F : MAX(x.denseRatio, w.denseRatio);
XTensor * tmp = NewTensorBuf(order, dimSize, x.dataType, dr, x.devID, x.mem);
/* call _MatrixMul function */
_MatrixMul(&x, X_NOTRANS, &w, X_NOTRANS, tmp, alpha, 0, parallelRunner);
XTensor c(tmp);
c.SetTMPFlag();
int n = GetSumIndex(tmp, b);
if (n == -1) {
/* call _Sum function */
_Sum(tmp, &b, &c);
// TODO!!
ShowNTErrors("TODO!");
}
else if (n >= 0 && n < tmp->order) {
/* call _SumDim function */
_SumDim(tmp, &b, &c, n);
}
else {
ShowNTErrors("Something is wrong!");
}
/* tensor connections */
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
//XLink::AddParamToHead(&c, beta);
/* destroy variables */
delete[] dimSize;
DelTensorBuf(tmp);
return c;
}
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-02-27
*/
#ifndef __MULANDSHIFT_H__
#define __MULANDSHIFT_H__
#include "../../XTensor.h"
#include "../CHeader.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
#endif // __OPERATION_H__
......@@ -99,11 +99,11 @@ convert index tensor to onehot tensor
>> onehot - onehot tensor, which value is 0 or 1
>> size - the last dimension size of the onehot tensor
*/
void _IndexToOnehot(XTensor * index, XTensor * onehot, int size)
void _IndexToOnehot(XTensor * index, XTensor * onehot, int size, float labelSmoothingP)
{
CheckNTErrors(onehot->GetDim(-1) == size, "Illegal tensor dimension!");
CheckNTErrors(onehot->order == index->order + 1, "Illegal tensor order!");
CheckNTErrors(onehot->dataType == X_INT, "The onehot tensor must be in X_INT!")
//CheckNTErrors(onehot->dataType == X_INT, "The onehot tensor must be in X_INT!")
CheckNTErrors(index->dataType == X_INT, "The index tensor must be in X_INT!")
for (int i = 0; i < index->order; i++)
......@@ -111,9 +111,12 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size)
onehot->SetZeroAll();
float confidence = 1 - labelSmoothingP;
float lowconfidence = labelSmoothingP / size;
#ifdef USE_CUDA
if(onehot->devID >= 0 && index->devID >= 0) {
_CudaIndexToOnehot(index, onehot, size);
_CudaIndexToOnehot(index, onehot, size, confidence, lowconfidence);
return;
}
#endif
......@@ -122,12 +125,13 @@ void _IndexToOnehot(XTensor * index, XTensor * onehot, int size)
int stride = size;
int * indexData = (int *)index->data;
int * onehotData = (int *)onehot->data;
DTYPE * onehotData = (DTYPE *)onehot->data;
for (int i = 0; i < blockNum; i++) {
int id = indexData[i];
int * od = onehotData + i * stride;
od[id] = 1;
DTYPE * od = onehotData + i * stride;
od[id] = 2;
//onehotData[i * stride + id] = 1;
}
}
......@@ -138,9 +142,10 @@ make a new tensor to keep the result and return it
>> index - index tensor, which value is an integer num
>> size - the last dimension size of the onehot tensor
>> confidence - labelsmoothing
<< return - the onehot tensor
*/
XTensor IndexToOnehot(XTensor & index, int size)
XTensor IndexToOnehot(XTensor & index, int size, float labelSmoothingP)
{
CheckNTErrors(index.dataType == X_INT, "The onehot tensor must be in X_INT!")
......@@ -151,9 +156,9 @@ XTensor IndexToOnehot(XTensor & index, int size)
int * dim = new int[order + 1];
memcpy(dim, index.dimSize, order * sizeof(int));
dim[order] = size;
InitTensor(&onehot, index.order + 1, dim, X_INT, 1.0F, index.devID, index.mem);
InitTensor(&onehot, index.order + 1, dim, X_FLOAT, 1.0F, index.devID, index.mem);
_IndexToOnehot(&index, &onehot, size);
_IndexToOnehot(&index, &onehot, size, labelSmoothingP);
delete[] dim;
......
......@@ -96,7 +96,7 @@ convert index tensor to onehot tensor (kernel version)
>> stride - stride of a data block
*/
__global__
void KernelIndexToOnehot(int * onehotData, int * indexData, int blockNum, int stride)
void KernelIndexToOnehot(DTYPE * onehotData, int * indexData, int blockNum, int stride, float confidence, float lowconfidence)
{
/* block id */
int i = blockDim.x * blockIdx.x + threadIdx.x;
......@@ -107,10 +107,17 @@ void KernelIndexToOnehot(int * onehotData, int * indexData, int blockNum, int st
if (i >= blockNum || offset >= stride)
return;
int * od = onehotData + i * stride;
DTYPE * od = onehotData + i * stride;
int id = indexData[i];
od[id] = 1;
//od[id] = 2.0;
//onehotData[i * stride + id] = 0.1;
if (offset == id)
od[offset] = confidence;
else{
od[offset] = lowconfidence;
}
}
/*
......@@ -120,7 +127,7 @@ convert index tensor to onehot tensor (cuda version)
>> onehot - onehot tensor, which value is 0 or 1
>> size - the last dimension size of the onehot tensor
*/
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size)
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size, float confidence, float lowconfidence)
{
int devID = onehot->devID;
......@@ -138,10 +145,10 @@ void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size)
dim3 blocks(cudaGrids[0], cudaGrids[1]);
dim3 threads(cudaBlocks[0], cudaBlocks[1]);
int * onehotData = (int *)onehot->data;
DTYPE * onehotData = (DTYPE *)onehot->data;
int * indexData = (int *)index->data;
KernelIndexToOnehot<<<blocks, threads >>>(onehotData, indexData, blockNum, stride);
KernelIndexToOnehot<<<blocks, threads >>>(onehotData, indexData, blockNum, stride, confidence, lowconfidence);
BacktoCudaDev(devID, devIDBackup);
}
......
......@@ -30,7 +30,7 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
void _CudaOnehotToIndex(XTensor * onehot, XTensor * index, int size);
/* convert index tensor to onehot tensor (cuda version) */
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size);
void _CudaIndexToOnehot(XTensor * index, XTensor * onehot, int size, float confidence, float lowconfidence);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -34,11 +34,11 @@ make a new tensor to keep the result and return it */
XTensor OnehotToIndex(XTensor & onehot, int num);
/* convert index tensor to onehot tensor */
void _IndexToOnehot(XTensor * index, XTensor * onehot, int size);
void _IndexToOnehot(XTensor * index, XTensor * onehot, int size, float labelSmoothingP);
/* convert index tensor to onehot tensor (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor IndexToOnehot(XTensor & index, int num);
XTensor IndexToOnehot(XTensor & index, int num, float labelSmoothingP);
} // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论