Commit bfa6fc90 by huchi

replace cache type with XTensor

parent e925cfd9
......@@ -41,6 +41,14 @@ int main( int argc, const char ** argv )
//_CrtSetBreakAlloc(2708);
TransformerMain(argc - 1, argv + 1);
/*XTensor x;
InitTensor2D(&x, 2, 2);
float d[]{ 1,2,3,4 };
x.SetData(d, 4);
XTensor y;
y = ReduceSum(x, 0);
y.Dump(stderr);*/
//_CrtDumpMemoryLeaks();
......
......@@ -28,8 +28,6 @@
namespace transformer
{
enum { NONE, SELF, CONTEXT };
/* constructor */
T2TAttention::T2TAttention()
{
......@@ -86,88 +84,55 @@ void T2TAttention::InitModel(int argc, char** argv,
/*
make the network
>> k - keys. It might be of size B * L * H
where B = batch size, L = sequence length,
and H = vector size of each position
>> q - queries
>> v - values
>> mask - as it is
>> isTraining - indicates whether the model is used for training
>> cache - the cache list
>> cacheType - type of the cache
>> cache - the layer's cache of keys and values
>> cacheType - type of the cache (SELF_ATT or EN_DE_ATT)
<< return - multi-attention result
*/
XTensor T2TAttention::Make(XTensor& k, XTensor& q, XTensor& v, XTensor *mask,
bool isTraining, Cache * cache, int cacheType)
XTensor T2TAttention::Make( XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, Cache* cache, int cacheType)
{
bool is_encoder = (!cache) ? true : false;
int k2Dim[]{ k.GetDim(0),k.GetDim(1), wk.GetDim(1) };
int v2Dim[]{ v.GetDim(0),v.GetDim(1), wv.GetDim(1) };
XTensor* q2 = NewTensor3DV2(q.GetDim(0),q.GetDim(1), wq.GetDim(1), X_FLOAT, q.devID);
XTensor* k2 = NULL;
XTensor* v2 = NULL;
XTensor* kNewCache = NULL;
XTensor* vNewCache = NULL;
const bool isEnc = (!cache) ? true : false;
/* linear transformation before self-attention */
/* note that all weights are transposed */
_MatrixMul(&q, X_NOTRANS, &wq, X_TRANS, q2);
_SumDim(q2, &bq, 2);
XTensor q2, k2, v2;
q2 = MatrixMul(q, X_NOTRANS, wq, X_TRANS) + bq;
if (!cache) {
k2 = NewTensor3DV2(k.GetDim(0),k.GetDim(1), wk.GetDim(1), X_FLOAT, k.devID);
v2 = NewTensor3DV2(v.GetDim(0),v.GetDim(1), wv.GetDim(1), X_FLOAT, v.devID);
_MatrixMul(&k, X_NOTRANS, &wk, X_TRANS, k2);
_SumDim(k2, &bk, 2);
_MatrixMul(&v, X_NOTRANS, &wv, X_TRANS, v2);
_SumDim(v2, &bv, 2);
/* self attention for encoder layers */
k2 = MatrixMul(k, X_NOTRANS, wk, X_TRANS) + bk;
v2 = MatrixMul(v, X_NOTRANS, wv, X_TRANS) + bv;
return MakeRPRAttention(k2, q2, v2, mask, isTraining, isEnc);
}
else {
if (cacheType == SELF) {
k2 = NewTensor3DV2(q.GetDim(0), q.GetDim(1), wk.GetDim(1), X_FLOAT, q.devID);
v2 = NewTensor3DV2(q.GetDim(0), q.GetDim(1), wv.GetDim(1), X_FLOAT, q.devID);
_MatrixMul(&q, X_NOTRANS, &wk, X_TRANS, k2);
_SumDim(k2, &bk, 2);
_MatrixMul(&q, X_NOTRANS, &wv, X_TRANS, v2);
_SumDim(v2, &bv, 2);
if (!cache->IsEmpty()) {
XTensor* kOldCache = cache->GetK();
XTensor* vOldCache = cache->GetV();
kNewCache = NewTensor3DV2(kOldCache->GetDim(0), kOldCache->GetDim(1) + k2->GetDim(1), kOldCache->GetDim(2), X_FLOAT, k2->devID);
vNewCache = NewTensor3DV2(vOldCache->GetDim(0), vOldCache->GetDim(1) + v2->GetDim(1), vOldCache->GetDim(2), X_FLOAT, v2->devID);
_Concatenate(kOldCache, k2, kNewCache, 1);
_Concatenate(vOldCache, v2, vNewCache, 1);
DelTensor(k2);
DelTensor(v2);
k2 = kNewCache;
v2 = vNewCache;
if (cacheType == SELF_ATT) {
k2 = MatrixMul(k, X_NOTRANS, wk, X_TRANS) + bk;
v2 = MatrixMul(v, X_NOTRANS, wv, X_TRANS) + bv;
/* on a cache hit, concatenate the cached keys/values with the new token's projections */
if (!cache->miss) {
k2 = Concatenate(cache->key, k2, 1);
v2 = Concatenate(cache->value, v2, 1);
}
cache->Update(k2, v2);
cache->key = k2;
cache->value = v2;
cache->miss = false;
return MakeRPRAttention(cache->key, q2, cache->value, mask, isTraining, isEnc);
}
else if (cacheType == CONTEXT) {
if (cache->IsEmpty()) {
k2 = NewTensor3DV2(k.GetDim(0), k.GetDim(1), wk.GetDim(1), X_FLOAT, k.devID);
v2 = NewTensor3DV2(v.GetDim(0), v.GetDim(1), wv.GetDim(1), X_FLOAT, v.devID);
_MatrixMul(&k, X_NOTRANS, &wk, X_TRANS, k2);
_SumDim(k2, &bk, 2);
_MatrixMul(&v, X_NOTRANS, &wv, X_TRANS, v2);
_SumDim(v2, &bv, 2);
cache->Update(k2, v2);
}
else {
k2 = cache->GetK();
v2 = cache->GetV();
else if (cacheType == EN_DE_ATT) {
if (cache->miss) {
cache->key = MatrixMul(k, X_NOTRANS, wk, X_TRANS) + bk;
cache->value = MatrixMul(v, X_NOTRANS, wv, X_TRANS) + bv;
cache->miss = false;
}
return MakeAttention(cache->key, q2, cache->value, mask, isTraining, isEnc);
}
else {
CheckNTErrors(0, "invalid cache type");
}
CheckNTErrors(0, "invalid cache type");
}
if(cacheType == CONTEXT)
return MakeAttention(k2, q2, v2, mask, isTraining, is_encoder);
return MakeRPRAttention(k2, q2, v2, mask, isTraining, is_encoder);
}
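For concreteness (shapes per the comment above, B * L * H, with illustrative numbers B = 2, H = 512): after three decoded tokens the cached key is (2, 3, 512); step four projects only the newest token to (2, 1, 512), and Concatenate(cache->key, k2, 1) yields (2, 4, 512), which becomes the new cache.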
/*
......@@ -180,16 +145,16 @@ make the attention network given keys, queries and values (after linear transfor
>> mask - as it is
>> isTraining - indicates whether the model is used for training
*/
XTensor T2TAttention::MakeAttention(XTensor *k, XTensor *q, XTensor *v, const XTensor *mask, bool isTraining, bool is_encoder)
XTensor T2TAttention::MakeAttention(XTensor &k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder)
{
XTensor kheads;
XTensor qheads;
XTensor vheads;
/* multi head */
kheads = Split(*k, k->order - 1, nhead);
qheads = Split(*q, q->order - 1, nhead);
vheads = Split(*v, v->order - 1, nhead);
kheads = Split(k, k.order - 1, nhead);
qheads = Split(q, q.order - 1, nhead);
vheads = Split(v, v.order - 1, nhead);
XTensor att;
XTensor dot;
......@@ -198,16 +163,16 @@ XTensor T2TAttention::MakeAttention(XTensor *k, XTensor *q, XTensor *v, const XT
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */
dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
if (isMasked && mask) {
/*if (isMasked && mask) {
_SumMe(&dot, mask);
}
}*/
dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));
scalar = Softmax(dot, -1);
if(isTraining && dropoutP > 0)
scalar = Dropout(scalar, dropoutP);
/*if(isTraining && dropoutP > 0)
scalar = Dropout(scalar, dropoutP);*/
att = BMMul(scalar, vheads);
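A note on the scale factor above: dk is the total key dimension shared across heads, so each of the nhead heads has width dk / nhead, and Linear(dot, 1.0F / sqrt(dk / nhead)) applies the usual 1 / sqrt(d_head) scaling of scaled dot-product attention.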
......@@ -225,32 +190,32 @@ make the attention network by incorporating the relative position representation
>> mask - as it is
>> isTraining - indicates whether the model is used for training
*/
XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTensor *mask, bool isTraining, bool is_encoder)
XTensor T2TAttention::MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder)
{
XTensor kheads;
XTensor qheads;
XTensor vheads;
const int batch_size = q->GetDim(0);
const int len_q = q->GetDim(1);
const int len_kv = k->GetDim(1);
const int batch_size = q.GetDim(0);
const int len_q = q.GetDim(1);
const int len_kv = k.GetDim(1);
/* multi head */
kheads = Split(*k, k->order - 1, nhead);
qheads = Split(*q, q->order - 1, nhead);
vheads = Split(*v, v->order - 1, nhead);
kheads = Split(k, k.order - 1, nhead);
qheads = Split(q, q.order - 1, nhead);
vheads = Split(v, v.order - 1, nhead);
XTensor att;
XTensor dot;
XTensor scalar;
XTensor emb_matrix, relative_key;
InitTensor2DV2(&emb_matrix, len_q, len_kv, X_INT, q->devID);
InitTensor3DV2(&relative_key, len_q, len_kv, kheads.GetDim(-1), X_FLOAT, q->devID);
InitTensor4DV2(&dot, nhead, batch_size, len_q, len_kv, X_FLOAT, q->devID);
InitTensor2DV2(&emb_matrix, len_q, len_kv, X_INT, q.devID);
InitTensor3DV2(&relative_key, len_q, len_kv, kheads.GetDim(-1), X_FLOAT, q.devID);
InitTensor4DV2(&dot, nhead, batch_size, len_q, len_kv, X_FLOAT, q.devID);
/* generate the relative embedding index (L_q, L_kv) */
GetRPEmbedding(&emb_matrix, len_q, len_kv, max_relative_position, q->devID,is_encoder);
GetRPEmbedding(&emb_matrix, len_q, len_kv, max_relative_position, q.devID,is_encoder);
/* generate the relative key from the rp_embedding_k (L_q, L_kv, H/K) */
......@@ -261,8 +226,8 @@ XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTens
RPDotProduct(&qheads, &kheads, &relative_key, &dot, true);
if (isMasked && mask)
_SumMe(&dot, mask);
/*if (isMasked && mask)
_SumMe(&dot, mask);*/
/* scale the dot result */
//dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));
......@@ -270,18 +235,12 @@ XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTens
/* softmax */
scalar = Softmax(dot, -1);
if (isTraining && dropoutP > 0)
scalar = Dropout(scalar, dropoutP);
/*if (isTraining && dropoutP > 0)
scalar = Dropout(scalar, dropoutP);*/
/* generate the relative attention output (K, B, L_q, H/K) */
att = BMMul(scalar, vheads);
if (is_encoder) {
DelTensor(k);
DelTensor(q);
DelTensor(v);
}
/* concatenate the heads */
return MulAndShift(Merge(att, att.order - 1), X_NOTRANS, wa, X_TRANS, ba);
}
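GetRPEmbedding's body is not part of this diff; as a reference for the relative embedding index (L_q, L_kv) it produces, here is a sketch of the standard clipped relative-position index of Shaw et al. (2018) — illustrative only, using a plain array instead of an XTensor (the is_encoder flag presumably adjusts this for the decoder's left-to-right attention):

/* sketch: index[i][j] = clip(j - i, -max, max) + max,
   so values fall in [0, 2 * max] and address the embedding table */
for (int i = 0; i < len_q; i++) {
    for (int j = 0; j < len_kv; j++) {
        int rp = j - i;
        if (rp < -max_relative_position) rp = -max_relative_position;
        if (rp > max_relative_position) rp = max_relative_position;
        index[i * len_kv + j] = rp + max_relative_position;
    }
}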
......
......@@ -28,46 +28,31 @@ using namespace nts;
namespace transformer
{
/* attention type */
enum { NONE, SELF_ATT, EN_DE_ATT };
/* layer cache for key and value */
class Cache {
/* layer cache for keys and values */
class Cache
{
public:
/* cache for keys */
XTensor key;
/* cache for key */
XTensor* k{ NULL };
/* cache for value */
XTensor* v{ NULL };
/* cache for values */
XTensor value;
public:
bool IsEmpty() {
return (k == NULL) && (v == NULL);
}
void Clear() {
if (k && v && k->id > 0 && v->id > 0) {
DelTensor(k);
DelTensor(v);
}
k = NULL;
v = NULL;
}
void Update(XTensor* newK, XTensor* newV) {
if (!newK || (k == newK) || !newV || (v == newV))
return;
Clear();
k = newK;
v = newV;
}
bool miss;
XTensor* GetK() {
return k;
Cache() {
miss = true;
}
XTensor* GetV() {
return v;
void Update(XTensor&& k, XTensor&& v) {
key = k;
value = v;
miss = false;
}
};
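A minimal usage sketch of the cache protocol implied above (hypothetical decoding loop; it mirrors the SELF_ATT branch of T2TAttention::Make):

Cache cache;                           /* miss == true: cold cache */
for (int step = 0; step < maxLen; step++) {
    /* project only the newest token, (B, 1, H) */
    XTensor k2, v2;
    k2 = MatrixMul(q, X_NOTRANS, wk, X_TRANS) + bk;
    v2 = MatrixMul(q, X_NOTRANS, wv, X_TRANS) + bv;
    if (!cache.miss) {                 /* hit: append along the length dim */
        k2 = Concatenate(cache.key, k2, 1);
        v2 = Concatenate(cache.value, v2, 1);
    }
    cache.key = k2;
    cache.value = v2;
    cache.miss = false;
}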
......@@ -153,14 +138,14 @@ public:
int myDevID = -1);
/* make the network */
XTensor Make(XTensor& k, XTensor& q, XTensor& v, XTensor* mask,
bool isTraining, Cache* cache, int cacheType);
XTensor Make( XTensor& k, XTensor& q, XTensor& v,
XTensor* mask, bool isTraining, Cache* cache, int cacheType);
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor MakeAttention(XTensor* k, XTensor* q, XTensor* v, const XTensor* mask, bool isTraining, bool is_encoder);
XTensor MakeAttention(XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder);
/* make the attention network given keys, queries and values (after linear transformation) */
XTensor MakeRPRAttention(XTensor* k, XTensor* q, XTensor* v, XTensor* mask, bool isTraining, bool is_encoder);
XTensor MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder);
void GetRPEmbedding(XTensor* emb_matrix, const int len_q, const int len_kv, const int max_relative_length, const int device_id, const bool is_encoder);
......
......@@ -136,7 +136,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
/******************/
/* self attention */
att = attentions[i].Make(inputNorm, inputNorm, inputNorm, NULL, isTraining, &selfCache[i], 1);
att = attentions[i].Make(inputNorm, inputNorm, inputNorm, NULL, isTraining, &selfCache[i], SELF_ATT);
/* dropout */
if(isTraining && dropoutP > 0)
......@@ -151,7 +151,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
//attNorm.Dump(stderr, "attNorm", 10);
/* encoder-decoder attention */
ende = attentionsEnde[i].Make(outputEnc, attNorm, outputEnc, &maskEncDec, isTraining, &contextCache[i], 2);
ende = attentionsEnde[i].Make(outputEnc, attNorm, outputEnc, &maskEncDec, isTraining, &contextCache[i], EN_DE_ATT);
//ende.Dump(stderr, "ende atten", 10);
......
......@@ -123,6 +123,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor *mask, XTensor &maskEncDec, boo
/* fnn */
x = fnns[i].Make(res, isTraining);
}
x = encodeLayerNorm->Make(x);
......
......@@ -55,9 +55,6 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID)
InitTensor1D(&w, d, X_FLOAT, devID);
InitTensor1D(&b, d, X_FLOAT, devID);
w.SetDataRand(1.0F, 1.0F);
b.SetZeroAll();
}
/*
......
......@@ -491,11 +491,10 @@ void T2TModel::Read(const char * fn)
GetParams(params);
size_t offset = 0;
for(int i = 0; i < params.count; i++){
XTensor * p = (XTensor*)params.Get(i);
p->BinaryRead(file, offset);
offset += p->unitNum;
FastRead(p, file);
// p->Read(file, "");
}
fclose(file);
......@@ -505,4 +504,14 @@ void T2TModel::Read(const char * fn)
XPRINT1(0, stderr, "[INFO] model loaded (took %.1fs)\n", elapsed);
}
void FastRead(XTensor* x, FILE* f) {
float * dataBuf = new float[x->unitNum];
fread(dataBuf, sizeof(char), sizeof(float) * x->unitNum, f);
x->SetData(dataBuf, x->unitNum);
delete[] dataBuf;
}
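FastRead assumes the stream holds exactly unitNum raw floats for each parameter; a slightly more defensive variant (an illustrative suggestion, not part of this commit) would check the read count:

void FastRead(XTensor* x, FILE* f) {
    float* dataBuf = new float[x->unitNum];
    /* read unitNum floats and verify nothing was truncated */
    size_t n = fread(dataBuf, sizeof(float), x->unitNum, f);
    CheckNTErrors((int)n == x->unitNum, "Incomplete parameter block!");
    x->SetData(dataBuf, x->unitNum);
    delete[] dataBuf;
}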
}
\ No newline at end of file
......@@ -103,7 +103,7 @@ public:
/* read the parameters */
void Read(const char * fn);
};
void FastRead(XTensor* x, FILE* f);
}
#endif
......@@ -61,18 +61,7 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID)
InitTensor2D(&w, hSize, vSize, X_FLOAT, devID);
}
/*
make the network
y = softmax(x * w)
>> input - input tensor
<< return - output tensor
*/
XTensor T2TOutput::Make(XTensor &input)
{
XTensor &x = input;
return Softmax(MMul(x, X_NOTRANS, w, X_TRANS), -1);
}
/*
make the network (redefined output tensor)
......
......@@ -21,6 +21,7 @@
#include "T2TPredictor.h"
#include "../../tensor/core/CHeader.h"
#include <iostream>
using namespace nts;
......@@ -91,15 +92,6 @@ create an initial state
*/
void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state)
{
state->layersEnc.Clear();
state->layersDec.Clear();
XTensor * encoding = XLink::SearchNode(top, ENCODING_NAME);
CheckNTErrors(encoding != NULL, "No encoding layers found!");
state->layersEnc.Add(encoding);
state->layersDec.Add(NULL);
int dims[MAX_TENSOR_DIM_NUM];
for (int i = 0; i < input->order - 1; i++)
dims[i] = input->GetDim(i);
......@@ -109,10 +101,18 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input
InitTensor(&state->nstep, input->order, dims, X_FLOAT, input->devID);
InitTensor(&state->endMark, input->order, dims, X_INT, input->devID);
state->probPath.SetZeroAll();
float* data = new float[state->probPath.unitNum];
for (int i = 0; i < state->probPath.unitNum; ++i) {
data[i] = -1e20F;
if (i % beamSize == 0)
data[i] = 0;
}
state->probPath.SetData(data, state->probPath.unitNum);
state->nstep.SetZeroAll();
state->endMark.SetZeroAll();
delete[] data;
state->stateNum = 0;
}
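The probPath initialization above leaves only the first hypothesis of each beam live at step 0: for beamSize = 3 the flattened scores read [0, -1e20, -1e20, 0, -1e20, -1e20, ...], so the first TopK can only select the first slot of every beam.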
......@@ -145,20 +145,13 @@ predict the next state
>> encoding - encoder output
>> inputEnc - input of the encoder
>> paddingEnc - padding of the encoder
>> isStart - whether this is the first decoding step
*/
void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
XTensor * inputEnc, XTensor * paddingEnc)
XTensor * inputEnc, XTensor * paddingEnc, bool isStart)
{
int dims[MAX_TENSOR_DIM_NUM];
next->layersEnc.Clear();
next->layersDec.Clear();
AttDecoder &decoder = *m->decoder;
/* word indices of previous positions */
XTensor * inputLast = (XTensor*)s->layersDec.GetItem(0);
/* word indices of positions up to next state */
XTensor inputDec;
......@@ -171,10 +164,10 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
dims[inputEnc->order - 1] = 1;
InitTensor(&first, inputEnc->order, dims, X_INT, inputEnc->devID);
_SetDataFixedInt(&first, startSymbol);
SetDataFixedInt(first, startSymbol);
/* add a new word into the input sequence of the decoder side */
if (inputLast == NULL) {
if (isStart) {
inputDec = Identity(first);
}
else{
......@@ -186,7 +179,6 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
/* prediction probabilities */
XTensor &output = next->prob;
XTensor decoding;
XTensor decodingStep;
for(int i = 0; i < inputDec.order - 1; i++)
dims[i] = inputDec.GetDim(i);
......@@ -203,38 +195,12 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec, 0);
/* make the decoding network */
decoding = decoder.Make(inputDec, *encoding, &maskDec, maskEncDec, false);
XTensor selectSrc;
XTensor selectTgt;
decoding = m->decoder->Make(inputDec, *encoding, NULL, maskEncDec, false);
CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");
int stride = decoding.GetDim(decoding.order - 2);
//InitTensor1DV2(&selectSrc, 1, X_INT);
//InitTensor1DV2(&selectTgt, 1, X_INT);
//selectSrc.SetInt(stride - 1, 0);
//selectTgt.SetInt(0, 0);
//XTensor srcGPU;
//InitTensor1DV2(&srcGPU, 1, X_INT, decoding.devID);
//_CopyValues(&selectSrc, &srcGPU);
//XTensor tgtGPU;
//InitTensor1DV2(&tgtGPU, 1, X_INT, decoding.devID);
//_CopyValues(&selectTgt, &tgtGPU);
///* the decoder output of the last position */
//decodingStep = CopyIndexed(decoding, decoding.order - 2, srcGPU, tgtGPU);
/* generate the output probabilities */
m->outputLayer->Make(decoding, output);
next->layersEnc.AddList(&s->layersEnc);
next->layersDec.Add(&inputDec);
next->layersDec.Add(&output);
}
/*
......@@ -288,14 +254,10 @@ XTensor T2TPredictor::GetLastPrediction(T2TStateBundle* state)
XTensor lastPred;
InitTensor2D(&lastPred, state->stateNum, 1, X_INT);
lastPred.SetZeroAll();
for (int i = 0; i < state->stateNum; i++) {
T2TState* cur = state->states + i;
while (cur->last != NULL)
cur = cur->last;
lastPred.Set2DInt(cur->prediction, i, 0);
}
......
......@@ -94,13 +94,6 @@ public:
/* step number of each hypothesis */
XTensor nstep;
/* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */
TensorList layersEnc;
/* layers on the decoder side */
TensorList layersDec;
/* list of states */
T2TState * states;
......@@ -155,7 +148,7 @@ public:
void Read(T2TModel * model, T2TStateBundle * state);
/* predict the next state */
void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc);
void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc, bool isStart);
/* generate paths up to the states of the current step */
XTensor GeneratePaths(T2TStateBundle * state);
......
......@@ -59,9 +59,9 @@ void T2TSearch::Init(int argc, char ** argv)
{
LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
LoadParamInt(argc, argv, "endid", endSymbols, -1);
LoadParamInt(argc, argv, "startid", &startSymbol, -1);
LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F);
LoadParamInt(argc, argv, "endid", endSymbols, 2);
LoadParamInt(argc, argv, "startid", &startSymbol, 2);
if(endSymbols[0] >= 0)
endSymbolNum = 1;
......@@ -90,13 +90,9 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* encoder mask */
model->MakeMTMaskEnc(*input, *padding, maskEnc);
//input->Dump(stderr, "input:");
//maskEnc.Dump(stderr, "maskenc:");
/* make the encoding network */
encoding = model->MakeEncoder(*input, &maskEnc, false);
encoding.SetName(ENCODING_NAME);
encodingBeam = Unsqueeze(encoding, encoding.order - 2, beamSize);
inputBeam = Unsqueeze(*input, input->order - 1, beamSize);
......@@ -110,9 +106,11 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
maxLength = input->GetDim(-1) * 2;
CheckNTErrors(maxLength > 0, "no max length specified!");
T2TStateBundle * states = new T2TStateBundle[maxLength + 1];
T2TStateBundle * first = states;
T2TStateBundle* states = new T2TStateBundle[maxLength + 1];
T2TStateBundle* first = states;
T2TStateBundle* cur;
T2TStateBundle* next;
/* create the first state */
predictor.Create(model, &encodingBeam, input, beamSize, first);
predictor.SetStartSymbol(startSymbol);
......@@ -121,14 +119,14 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
/* generate the sequence from left to right */
for(int i = 0 ; i < maxLength; i++){
T2TStateBundle * cur = states + i;
T2TStateBundle * next = states + i + 1;
cur = states + i;
next = states + i + 1;
/* read the current state */
predictor.Read(model, cur);
/* predict the next state */
predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam);
predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam, i==0);
/* compute the model score (given the prediction probability) */
Score(cur, next);
......@@ -144,7 +142,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
}
/* fill the heap with incomplete hypotheses if necessary */
FillHeap(&states[maxLength]);
FillHeap(next);
Dump(output);
......@@ -210,24 +208,25 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
_ScaleAndShift(&lenPrev, &len, 1.0F, 1.0F);
/* the GNMT-like length penalty */
lp = T2TLengthPenalizer::GNMT(len, alpha);
//lp = T2TLengthPenalizer::GNMT(len, alpha);
lp.Reshape(lp.unitNum);
//lp.Reshape(lp.unitNum);
/* score = log-prob/lp */
_DivDim(&probPath, &lp, &score, 0);
//_DivDim(&probPath, &lp, &score, 0);
if (prev->isStart) {
XTensor firstMask = MakeFirstMask(beam);
firstMask.Reshape(firstMask.unitNum);
/* mask the hypotheses in the beam expect the first one */
/* mask the hypotheses in the beam except the first one */
_SumDim(&score, &firstMask, &score, 0);
}
InitTensor(&mask,
prev->endMark.order, prev->endMark.dimSize, X_FLOAT,
prev->endMark.devID);
mask.SetZeroAll();
_SetDataFixedCond(&mask, &prev->endMark, -1e9F);
mask.Reshape(mask.unitNum);
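For reference, the penalty disabled above is the GNMT length normalization of Wu et al. (2016), lp = (5 + len)^alpha / (5 + 1)^alpha; with the _DivDim commented out, hypotheses are ranked by unnormalized log-probability.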
......@@ -279,17 +278,26 @@ void T2TSearch::Generate(T2TStateBundle * beam)
dimsTopK[order - 3] = dimsBeam[order - 3];
dimsTopK[order - 1] = beamSize;
InitTensor(&scoreTopK, order, dimsTopK, score.dataType,
score.devID);
InitTensor(&index, order, dimsTopK, X_INT,
score.devID);
InitTensor(&scoreTopK, order, dimsTopK, score.dataType, score.devID);
InitTensor(&index, order, dimsTopK, X_INT, score.devID);
InitTensor(&preID, order, dimsTopK, X_INT, -1);
/* mask the first and the padding ids */
int dimMask[]{ score.GetDim(-1) };
XTensor mask;
InitTensor(&mask, 1, dimMask, X_FLOAT, -1);
mask.SetZeroAll();
mask.Set1D(-1e20F, 0);
mask.Set1D(-1e20F, 1);
mask.SetDevice(score.devID, score.mem);
//_SumDim(&score, &mask, 2);
score.Reshape(order, dimsBeam);
/* keep the most promising candidates in the beam */
/* TODO: check this line */
TopK(score, scoreTopK, index, -1, beamSize);
CopyValues(index, preID);
/* "preID" represents the id (or the offset) of the previous state used to make the current
......@@ -313,11 +321,14 @@ void T2TSearch::Generate(T2TStateBundle * beam)
/* CPU data (TODO: remove GPU->CPU data copy!!!) */
XTensor indexGPU;
indexGPU = CopyValues(index);
//InitTensorV2(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1);
//CopyValues(index, indexCPU);
for (int i = 0; i < indexGPU.unitNum; i++)
indexGPU.SetInt(i * stride + indexGPU.GetInt(i), i);
for (int i = 0; i < indexGPU.unitNum; i += beamSize) {
for (int j = 0; j < beamSize; j++)
indexGPU.SetInt(i * stride + indexGPU.GetInt(i + j), i + j);
}
/*for (int i = 0; i < indexGPU.unitNum; i++) {
indexGPU.SetInt(i + indexGPU.GetInt(i), i);
}*/
CheckNTErrors(IsSameShaped(prob, probPath), "Wrong tensor shape!");
......@@ -460,6 +471,10 @@ void T2TSearch::Collect(T2TStateBundle * beam)
CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
"Invalid sample id!");
/* check whether this is the first end symbol; it is false
if an end symbol already appeared among previously generated words. */
bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted);
/* we push the hypothesis into the heap when it is completed */
if(state.isEnd != 0)
......
......@@ -93,17 +93,13 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
{
count++;
wordCount = 0;
/*if (count % 10 == 0 && sentBatch < 128)
sentBatch *= 2;*/
/* reset cache for decoder */
for (int i = 0; i < model->decoder->nlayer; ++i) {
model->decoder->selfCache[i].Clear();
model->decoder->contextCache[i].Clear();
model->decoder->selfCache[i].miss = true;
model->decoder->contextCache[i].miss = true;
}
vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID);
XTensor output;
seacher.Search(model, &batchEnc, &paddingEnc, &output);
......@@ -122,7 +118,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
res.id = indices[i];
batchLoader.resBuffer.emplace_back(res);
}
wc = batchEnc.GetDim(-1);
wordCount += wc;
......
......@@ -54,7 +54,7 @@ int TransformerMain(int argc, const char ** argv)
char * rawModel = new char[MAX_LINE_LENGTH];
LoadParamString(argc, args, "model", modelFN, "");
LoadParamString(argc, args, "rawModel", rawModel, "");
LoadParamString(argc, args, "rawmodel", rawModel, "");
LoadParamString(argc, args, "test", testFN, "");
LoadParamString(argc, args, "output", outputFN, "");
LoadParamBool(argc, args, "beamsearch", &isBeamSearch, false);
......