Commit bfa6fc90 by huchi

replace cache type with XTensor

parent e925cfd9
@@ -41,6 +41,14 @@ int main( int argc, const char ** argv )
     //_CrtSetBreakAlloc(2708);
     TransformerMain(argc - 1, argv + 1);
+    /*XTensor x;
+    InitTensor2D(&x, 2, 2);
+    float d[]{ 1,2,3,4 };
+    x.SetData(d, 4);
+    XTensor y;
+    y = ReduceSum(x, 0);
+    y.Dump(stderr);*/
     //_CrtDumpMemoryLeaks();
......
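Note: for the commented-out smoke test above, ReduceSum(x, 0) reduces the 2 x 2 tensor along dimension 0, so with the data {1, 2, 3, 4} the dumped result should be [4, 6] (assuming SetData fills the tensor row by row).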
@@ -28,8 +28,6 @@
 namespace transformer
 {
-enum { NONE, SELF, CONTEXT };
 /* constructor */
 T2TAttention::T2TAttention()
 {
@@ -86,88 +84,55 @@ void T2TAttention::InitModel(int argc, char** argv,
 /*
 make the network
 >> k - keys. It might be of size B * L * H
       where B = batch size, L = sequence length,
       and H = vector size of each position
 >> q - queries
 >> v - values
 >> mask - as it is
 >> isTraining - indicates whether the model is used for training
->> cache - the cache list
->> cacheType - type of the cache
+>> cache - layer cache list
+>> cacheType - which type that cache is
 << return - multi-attention result
 */
-XTensor T2TAttention::Make(XTensor& k, XTensor& q, XTensor& v, XTensor *mask,
-                           bool isTraining, Cache * cache, int cacheType)
+XTensor T2TAttention::Make( XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, Cache* cache, int cacheType)
 {
-    bool is_encoder = (!cache) ? true : false;
+    const bool isEnc = (!cache) ? true : false;
-    int k2Dim[]{ k.GetDim(0),k.GetDim(1), wk.GetDim(1) };
-    int v2Dim[]{ v.GetDim(0),v.GetDim(1), wv.GetDim(1) };
-    XTensor* q2 = NewTensor3DV2(q.GetDim(0),q.GetDim(1), wq.GetDim(1), X_FLOAT, q.devID);
-    XTensor* k2 = NULL;
-    XTensor* v2 = NULL;
-    XTensor* kNewCache = NULL;
-    XTensor* vNewCache = NULL;

     /* linear transformation before self-attention */
-    /* notice that all weights are transposed!!! */
-    _MatrixMul(&q, X_NOTRANS, &wq, X_TRANS, q2);
-    _SumDim(q2, &bq, 2);
+    XTensor q2, k2, v2;
+    q2 = MatrixMul(q, X_NOTRANS, wq, X_TRANS) + bq;

     if (!cache) {
-        k2 = NewTensor3DV2(k.GetDim(0),k.GetDim(1), wk.GetDim(1), X_FLOAT, k.devID);
-        v2 = NewTensor3DV2(v.GetDim(0),v.GetDim(1), wv.GetDim(1), X_FLOAT, v.devID);
-        _MatrixMul(&k, X_NOTRANS, &wk, X_TRANS, k2);
-        _SumDim(k2, &bk, 2);
-        _MatrixMul(&v, X_NOTRANS, &wv, X_TRANS, v2);
-        _SumDim(v2, &bv, 2);
+        /* self attention for encoder layers */
+        k2 = MatrixMul(k, X_NOTRANS, wk, X_TRANS) + bk;
+        v2 = MatrixMul(v, X_NOTRANS, wv, X_TRANS) + bv;
+        return MakeRPRAttention(k2, q2, v2, mask, isTraining, isEnc);
     }
     else {
-        if (cacheType == SELF) {
-            k2 = NewTensor3DV2(q.GetDim(0), q.GetDim(1), wk.GetDim(1), X_FLOAT, q.devID);
-            v2 = NewTensor3DV2(q.GetDim(0), q.GetDim(1), wv.GetDim(1), X_FLOAT, q.devID);
-            _MatrixMul(&q, X_NOTRANS, &wk, X_TRANS, k2);
-            _SumDim(k2, &bk, 2);
-            _MatrixMul(&q, X_NOTRANS, &wv, X_TRANS, v2);
-            _SumDim(v2, &bv, 2);
-            if (!cache->IsEmpty()) {
-                XTensor* kOldCache = cache->GetK();
-                XTensor* vOldCache = cache->GetV();
-                kNewCache = NewTensor3DV2(kOldCache->GetDim(0), kOldCache->GetDim(1) + k2->GetDim(1), kOldCache->GetDim(2), X_FLOAT, k2->devID);
-                vNewCache = NewTensor3DV2(vOldCache->GetDim(0), vOldCache->GetDim(1) + v2->GetDim(1), vOldCache->GetDim(2), X_FLOAT, v2->devID);
-                _Concatenate(kOldCache, k2, kNewCache, 1);
-                _Concatenate(vOldCache, v2, vNewCache, 1);
-                DelTensor(k2);
-                DelTensor(v2);
-                k2 = kNewCache;
-                v2 = vNewCache;
-            }
-            cache->Update(k2, v2);
+        if (cacheType == SELF_ATT) {
+            k2 = MatrixMul(k, X_NOTRANS, wk, X_TRANS) + bk;
+            v2 = MatrixMul(v, X_NOTRANS, wv, X_TRANS) + bv;
+
+            /* if hit, we only concat the cache with the new token */
+            if (!cache->miss) {
+                k2 = Concatenate(cache->key, k2, 1);
+                v2 = Concatenate(cache->value, v2, 1);
+            }
+            cache->key = k2;
+            cache->value = v2;
+            cache->miss = false;
+            return MakeRPRAttention(cache->key, q2, cache->value, mask, isTraining, isEnc);
         }
-        else if (cacheType == CONTEXT) {
-            if (cache->IsEmpty()) {
-                k2 = NewTensor3DV2(k.GetDim(0), k.GetDim(1), wk.GetDim(1), X_FLOAT, k.devID);
-                v2 = NewTensor3DV2(v.GetDim(0), v.GetDim(1), wv.GetDim(1), X_FLOAT, v.devID);
-                _MatrixMul(&k, X_NOTRANS, &wk, X_TRANS, k2);
-                _SumDim(k2, &bk, 2);
-                _MatrixMul(&v, X_NOTRANS, &wv, X_TRANS, v2);
-                _SumDim(v2, &bv, 2);
-                cache->Update(k2, v2);
-            }
-            else {
-                k2 = cache->GetK();
-                v2 = cache->GetV();
-            }
+        else if (cacheType == EN_DE_ATT) {
+            if (cache->miss) {
+                cache->key = MatrixMul(k, X_NOTRANS, wk, X_TRANS) + bk;
+                cache->value = MatrixMul(v, X_NOTRANS, wv, X_TRANS) + bv;
+                cache->miss = false;
+            }
+            return MakeAttention(cache->key, q2, cache->value, mask, isTraining, isEnc);
         }
-        else {
-            CheckNTErrors(0, "invalid cache type");
-        }
+        CheckNTErrors(0, "invalid cache type");
     }
-    if(cacheType == CONTEXT)
-        return MakeAttention(k2, q2, v2, mask, isTraining, is_encoder);
-    return MakeRPRAttention(k2, q2, v2, mask, isTraining, is_encoder);
 }

 /*
@@ -180,16 +145,16 @@ make the attention network given keys, queries and values (after linear transformation)
 >> mask - as it is
 >> isTraining - indicates whether the model is used for training
 */
-XTensor T2TAttention::MakeAttention(XTensor *k, XTensor *q, XTensor *v, const XTensor *mask, bool isTraining, bool is_encoder)
+XTensor T2TAttention::MakeAttention(XTensor &k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder)
 {
     XTensor kheads;
     XTensor qheads;
     XTensor vheads;

     /* multi head */
-    kheads = Split(*k, k->order - 1, nhead);
-    qheads = Split(*q, q->order - 1, nhead);
-    vheads = Split(*v, v->order - 1, nhead);
+    kheads = Split(k, k.order - 1, nhead);
+    qheads = Split(q, q.order - 1, nhead);
+    vheads = Split(v, v.order - 1, nhead);

     XTensor att;
     XTensor dot;
@@ -198,16 +163,16 @@ XTensor T2TAttention::MakeAttention(XTensor *k, XTensor *q, XTensor *v, const XTensor *mask, bool isTraining, bool is_encoder)
     /* scalar = softmax(Q * K^T / sqrt(dk)) * V */
     dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);

-    if (isMasked && mask) {
-        _SumMe(&dot, mask);
-    }
+    /*if (isMasked && mask) {
+        _SumMe(&dot, mask);
+    }*/

     dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));

     scalar = Softmax(dot, -1);

-    if(isTraining && dropoutP > 0)
-        scalar = Dropout(scalar, dropoutP);
+    /*if(isTraining && dropoutP > 0)
+        scalar = Dropout(scalar, dropoutP);*/

     att = BMMul(scalar, vheads);
@@ -225,32 +190,32 @@ make the attention network by incorporating the relative position representation
 >> mask - as it is
 >> isTraining - indicates whether the model is used for training
 */
-XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTensor *mask, bool isTraining, bool is_encoder)
+XTensor T2TAttention::MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder)
 {
     XTensor kheads;
     XTensor qheads;
     XTensor vheads;

-    const int batch_size = q->GetDim(0);
-    const int len_q = q->GetDim(1);
-    const int len_kv = k->GetDim(1);
+    const int batch_size = q.GetDim(0);
+    const int len_q = q.GetDim(1);
+    const int len_kv = k.GetDim(1);

     /* multi head */
-    kheads = Split(*k, k->order - 1, nhead);
-    qheads = Split(*q, q->order - 1, nhead);
-    vheads = Split(*v, v->order - 1, nhead);
+    kheads = Split(k, k.order - 1, nhead);
+    qheads = Split(q, q.order - 1, nhead);
+    vheads = Split(v, v.order - 1, nhead);

     XTensor att;
     XTensor dot;
     XTensor scalar;

     XTensor emb_matrix, relative_key;
-    InitTensor2DV2(&emb_matrix, len_q, len_kv, X_INT, q->devID);
-    InitTensor3DV2(&relative_key, len_q, len_kv, kheads.GetDim(-1), X_FLOAT, q->devID);
-    InitTensor4DV2(&dot, nhead, batch_size, len_q, len_kv, X_FLOAT, q->devID);
+    InitTensor2DV2(&emb_matrix, len_q, len_kv, X_INT, q.devID);
+    InitTensor3DV2(&relative_key, len_q, len_kv, kheads.GetDim(-1), X_FLOAT, q.devID);
+    InitTensor4DV2(&dot, nhead, batch_size, len_q, len_kv, X_FLOAT, q.devID);

     /* generate the relative emb index (L_q, L_kv) */
-    GetRPEmbedding(&emb_matrix, len_q, len_kv, max_relative_position, q->devID,is_encoder);
+    GetRPEmbedding(&emb_matrix, len_q, len_kv, max_relative_position, q.devID,is_encoder);

     /* generate the relative key from the rp_embedding_k (L_q, L_kv, H/K) */
@@ -261,8 +226,8 @@ XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTensor *mask, bool isTraining, bool is_encoder)
     RPDotProduct(&qheads, &kheads, &relative_key, &dot, true);

-    if (isMasked && mask)
-        _SumMe(&dot, mask);
+    /*if (isMasked && mask)
+        _SumMe(&dot, mask);*/

     /* scale the dot result */
     //dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));
@@ -270,18 +235,12 @@ XTensor T2TAttention::MakeRPRAttention(XTensor *k, XTensor *q, XTensor *v, XTensor *mask, bool isTraining, bool is_encoder)
     /* softmax */
     scalar = Softmax(dot, -1);

-    if (isTraining && dropoutP > 0)
-        scalar = Dropout(scalar, dropoutP);
+    /*if (isTraining && dropoutP > 0)
+        scalar = Dropout(scalar, dropoutP);*/

     /* generate the relative attention output (K, B, L_q, H/K) */
     att = BMMul(scalar, vheads);

-    if (is_encoder) {
-        DelTensor(k);
-        DelTensor(q);
-        DelTensor(v);
-    }
-
     /* concatenate the heads */
     return MulAndShift(Merge(att, att.order - 1), X_NOTRANS, wa, X_TRANS, ba);
 }
......
@@ -28,46 +28,31 @@ using namespace nts;
 namespace transformer
 {

+/* attention type */
+enum { NONE, SELF_ATT, EN_DE_ATT };
+
-/* layer cache for key and value */
-class Cache {
+/* layer cache for keys and values */
+class Cache
+{
 public:
-    /* cache for key */
-    XTensor* k{ NULL };
-
-    /* cache for value */
-    XTensor* v{ NULL };
+    /* cache for keys */
+    XTensor key;
+
+    /* cache for values */
+    XTensor value;

 public:
-    bool IsEmpty() {
-        return (k == NULL) && (v == NULL);
-    }
-    void Clear() {
-        if (k && v && k->id > 0 && v->id > 0) {
-            DelTensor(k);
-            DelTensor(v);
-        }
-        k = NULL;
-        v = NULL;
-    }
-    void Update(XTensor* newK, XTensor* newV) {
-        if (!newK || (k == newK) || !newV || (v == newV))
-            return;
-        Clear();
-        k = newK;
-        v = newV;
-    }
-    XTensor* GetK() {
-        return k;
-    }
-    XTensor* GetV() {
-        return v;
-    }
+    bool miss;
+
+    Cache() {
+        miss = true;
+    }
+
+    void Update(XTensor&& k, XTensor&& v) {
+        key = k;
+        value = v;
+        miss = false;
+    }
 };
@@ -153,14 +138,14 @@
     int myDevID = -1);

     /* make the network */
-    XTensor Make(XTensor& k, XTensor& q, XTensor& v, XTensor* mask,
-                 bool isTraining, Cache* cache, int cacheType);
+    XTensor Make( XTensor& k, XTensor& q, XTensor& v,
+                  XTensor* mask, bool isTraining, Cache* cache, int cacheType);

     /* make the attention network given keys, queries and values (after linear transformation) */
-    XTensor MakeAttention(XTensor* k, XTensor* q, XTensor* v, const XTensor* mask, bool isTraining, bool is_encoder);
+    XTensor MakeAttention(XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder);

     /* make the attention network given keys, queries and values (after linear transformation) */
-    XTensor MakeRPRAttention(XTensor* k, XTensor* q, XTensor* v, XTensor* mask, bool isTraining, bool is_encoder);
+    XTensor MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder);

     void GetRPEmbedding(XTensor* emb_matrix, const int len_q, const int len_kv, const int max_relative_length, const int device_id, const bool is_encoder);
......
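Note: a minimal sketch of how the new XTensor-based Cache is meant to be driven during incremental decoding, mirroring the SELF_ATT branch of T2TAttention::Make above. The names AppendToSelfAttCache, stepKey and stepValue are hypothetical, and dimension 1 is assumed to be the sequence-length axis:

    /* sketch only: append the projections of the newly generated position
       to the cached keys/values, then refresh the cache */
    void AppendToSelfAttCache(Cache& cache, XTensor& stepKey, XTensor& stepValue)
    {
        if (!cache.miss) {
            /* cache hit: concatenate the cached tensors with the new step */
            stepKey = Concatenate(cache.key, stepKey, 1);
            stepValue = Concatenate(cache.value, stepValue, 1);
        }
        cache.key = stepKey;
        cache.value = stepValue;
        cache.miss = false;
    }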
@@ -136,7 +136,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
     /******************/
     /* self attention */
-    att = attentions[i].Make(inputNorm, inputNorm, inputNorm, NULL, isTraining, &selfCache[i], 1);
+    att = attentions[i].Make(inputNorm, inputNorm, inputNorm, NULL, isTraining, &selfCache[i], SELF_ATT);

     /* dropout */
     if(isTraining && dropoutP > 0)
@@ -151,7 +151,7 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
     //attNorm.Dump(stderr, "attNorm", 10);

     /* encoder-decoder attention */
-    ende = attentionsEnde[i].Make(outputEnc, attNorm, outputEnc, &maskEncDec, isTraining, &contextCache[i], 2);
+    ende = attentionsEnde[i].Make(outputEnc, attNorm, outputEnc, &maskEncDec, isTraining, &contextCache[i], EN_DE_ATT);

     //ende.Dump(stderr, "ende atten", 10);
......
@@ -123,6 +123,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor *mask, XTensor &maskEncDec, boo
     /* fnn */
     x = fnns[i].Make(res, isTraining);
 }

 x = encodeLayerNorm->Make(x);
......
@@ -55,9 +55,6 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID)
     InitTensor1D(&w, d, X_FLOAT, devID);
     InitTensor1D(&b, d, X_FLOAT, devID);
-
-    w.SetDataRand(1.0F, 1.0F);
-    b.SetZeroAll();
 }

 /*
......
@@ -491,11 +491,10 @@ void T2TModel::Read(const char * fn)
     GetParams(params);

-    size_t offset = 0;
     for(int i = 0; i < params.count; i++){
         XTensor * p = (XTensor*)params.Get(i);
-        p->BinaryRead(file, offset);
-        offset += p->unitNum;
+        FastRead(p, file);
+        // p->Read(file, "");
     }

     fclose(file);
@@ -505,4 +504,14 @@ void T2TModel::Read(const char * fn)
     XPRINT1(0, stderr, "[INFO] model loaded (took %.1fs)\n", elapsed);
 }

+void FastRead(XTensor* x, FILE* f) {
+    float * dataBuf = new float[x->unitNum];
+    fread(dataBuf, sizeof(char), sizeof(float) * x->unitNum, f);
+    x->SetData(dataBuf, x->unitNum);
+    delete[] dataBuf;
+}
+
 }
\ No newline at end of file
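Note: FastRead above assumes every parameter is stored in the model file as raw 32-bit floats, and the value returned by fread is not checked. A hedged variant of the same idea with the read size verified (CheckNTErrors is the repository's assertion macro; <vector> is assumed to be included):

    void FastReadChecked(XTensor* x, FILE* f)
    {
        std::vector<float> buf(x->unitNum);
        size_t want = sizeof(float) * (size_t)x->unitNum;
        size_t got = fread(buf.data(), 1, want, f);
        CheckNTErrors(got == want, "model file is shorter than expected");
        x->SetData(buf.data(), x->unitNum);
    }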
@@ -103,7 +103,7 @@ public:
     /* read the parameters */
     void Read(const char * fn);
 };

+void FastRead(XTensor* x, FILE* f);
+
 }

 #endif
@@ -61,18 +61,7 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID)
     InitTensor2D(&w, hSize, vSize, X_FLOAT, devID);
 }

-/*
-make the network
-y = softmax(x * w)
->> input - input tensor
-<< return - output tensor
-*/
-XTensor T2TOutput::Make(XTensor &input)
-{
-    XTensor &x = input;
-
-    return Softmax(MMul(x, X_NOTRANS, w, X_TRANS), -1);
-}
-
 /*
 make the network (redefined output tensor)
......
@@ -21,6 +21,7 @@
 #include "T2TPredictor.h"
 #include "../../tensor/core/CHeader.h"
+#include <iostream>

 using namespace nts;
@@ -91,15 +92,6 @@ create an initial state
 */
 void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state)
 {
-    state->layersEnc.Clear();
-    state->layersDec.Clear();
-
-    XTensor * encoding = XLink::SearchNode(top, ENCODING_NAME);
-    CheckNTErrors(encoding != NULL, "No encoding layers found!");
-
-    state->layersEnc.Add(encoding);
-    state->layersDec.Add(NULL);
-
     int dims[MAX_TENSOR_DIM_NUM];
     for (int i = 0; i < input->order - 1; i++)
         dims[i] = input->GetDim(i);
@@ -109,10 +101,18 @@ void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state)
     InitTensor(&state->nstep, input->order, dims, X_FLOAT, input->devID);
     InitTensor(&state->endMark, input->order, dims, X_INT, input->devID);

-    state->probPath.SetZeroAll();
+    float* data = new float[state->probPath.unitNum];
+    for (int i = 0; i < state->probPath.unitNum; ++i) {
+        data[i] = -1e20F;
+        if (i % beamSize == 0)
+            data[i] = 0;
+    }
+    state->probPath.SetData(data, state->probPath.unitNum);
+
     state->nstep.SetZeroAll();
     state->endMark.SetZeroAll();

+    delete[] data;
+
     state->stateNum = 0;
 }
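Note: the new initialization above sets every hypothesis to an effectively impossible log-probability (-1e20) except the first slot of each beam block, presumably so that the first expansion can only grow out of the start hypothesis. A self-contained sketch of the same pattern, assuming unitNum is a multiple of beamSize:

    #include <vector>

    std::vector<float> InitProbPath(int unitNum, int beamSize)
    {
        std::vector<float> data(unitNum, -1e20F);   /* "dead" hypotheses */
        for (int i = 0; i < unitNum; i += beamSize)
            data[i] = 0.0F;                         /* first hypothesis of each beam */
        return data;
    }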
@@ -145,20 +145,13 @@ predict the next state
 >> encoding - encoder output
 >> inputEnc - input of the encoder
 >> paddingEnc - padding of the encoder
+>>> isStart - is the start or not
 */
 void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
-                           XTensor * inputEnc, XTensor * paddingEnc)
+                           XTensor * inputEnc, XTensor * paddingEnc, bool isStart)
 {
     int dims[MAX_TENSOR_DIM_NUM];

-    next->layersEnc.Clear();
-    next->layersDec.Clear();
-
-    AttDecoder &decoder = *m->decoder;
-
-    /* word indices of previous positions */
-    XTensor * inputLast = (XTensor*)s->layersDec.GetItem(0);
-
     /* word indices of positions up to next state */
     XTensor inputDec;
@@ -171,10 +164,10 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
     dims[inputEnc->order - 1] = 1;
     InitTensor(&first, inputEnc->order, dims, X_INT, inputEnc->devID);
-    _SetDataFixedInt(&first, startSymbol);
+    SetDataFixedInt(first, startSymbol);

     /* add a new word into the input sequence of the decoder side */
-    if (inputLast == NULL) {
+    if (isStart) {
         inputDec = Identity(first);
     }
     else{
@@ -186,7 +179,6 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
     /* prediction probabilities */
     XTensor &output = next->prob;
     XTensor decoding;
-    XTensor decodingStep;

     for(int i = 0; i < inputDec.order - 1; i++)
         dims[i] = inputDec.GetDim(i);
@@ -203,38 +195,12 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
     m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec, 0);

     /* make the decoding network */
-    decoding = decoder.Make(inputDec, *encoding, &maskDec, maskEncDec, false);
-
-    XTensor selectSrc;
-    XTensor selectTgt;
+    decoding = m->decoder->Make(inputDec, *encoding, NULL, maskEncDec, false);

     CheckNTErrors(decoding.order >= 2, "The tensor must be of order 2 or larger!");

-    int stride = decoding.GetDim(decoding.order - 2);
-
-    //InitTensor1DV2(&selectSrc, 1, X_INT);
-    //InitTensor1DV2(&selectTgt, 1, X_INT);
-    //selectSrc.SetInt(stride - 1, 0);
-    //selectTgt.SetInt(0, 0);
-
-    //XTensor srcGPU;
-    //InitTensor1DV2(&srcGPU, 1, X_INT, decoding.devID);
-    //_CopyValues(&selectSrc, &srcGPU);
-
-    //XTensor tgtGPU;
-    //InitTensor1DV2(&tgtGPU, 1, X_INT, decoding.devID);
-    //_CopyValues(&selectTgt, &tgtGPU);
-
-    ///* the decoder output of the last position */
-    //decodingStep = CopyIndexed(decoding, decoding.order - 2, srcGPU, tgtGPU);
-
     /* generate the output probabilities */
     m->outputLayer->Make(decoding, output);
-
-    next->layersEnc.AddList(&s->layersEnc);
-    next->layersDec.Add(&inputDec);
-    next->layersDec.Add(&output);
 }

 /*
@@ -288,14 +254,10 @@ XTensor T2TPredictor::GetLastPrediction(T2TStateBundle* state)
     XTensor lastPred;
     InitTensor2D(&lastPred, state->stateNum, 1, X_INT);
-    lastPred.SetZeroAll();

     for (int i = 0; i < state->stateNum; i++) {
         T2TState* cur = state->states + i;
-
-        while (cur->last != NULL)
-            cur = cur->last;
-
         lastPred.Set2DInt(cur->prediction, i, 0);
     }
......
@@ -94,13 +94,6 @@ public:
     /* step number of each hypothesis */
     XTensor nstep;

-    /* layers on the encoder side. We actually use the encoder output instead
-       of all hidden layers. */
-    TensorList layersEnc;
-
-    /* layers on the decoder side */
-    TensorList layersDec;
-
     /* list of states */
     T2TState * states;
@@ -155,7 +148,7 @@ public:
     void Read(T2TModel * model, T2TStateBundle * state);

     /* predict the next state */
-    void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc);
+    void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc, bool isStart);

     /* generate paths up to the states of the current step */
     XTensor GeneratePaths(T2TStateBundle * state);
......
@@ -59,9 +59,9 @@ void T2TSearch::Init(int argc, char ** argv)
 {
     LoadParamInt(argc, argv, "beamsize", &beamSize, 1);
     LoadParamInt(argc, argv, "batchsize", &batchSize, 1);
-    LoadParamFloat(argc, argv, "lenalpha", &alpha, 0.2F);
-    LoadParamInt(argc, argv, "endid", endSymbols, -1);
-    LoadParamInt(argc, argv, "startid", &startSymbol, -1);
+    LoadParamFloat(argc, argv, "lenalpha", &alpha, 1.0F);
+    LoadParamInt(argc, argv, "endid", endSymbols, 2);
+    LoadParamInt(argc, argv, "startid", &startSymbol, 2);

     if(endSymbols[0] >= 0)
         endSymbolNum = 1;
@@ -90,13 +90,9 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
     /* encoder mask */
     model->MakeMTMaskEnc(*input, *padding, maskEnc);

-    //input->Dump(stderr, "input:");
-    //maskEnc.Dump(stderr, "maskenc:");
-
     /* make the encoding network */
     encoding = model->MakeEncoder(*input, &maskEnc, false);
-    encoding.SetName(ENCODING_NAME);

     encodingBeam = Unsqueeze(encoding, encoding.order - 2, beamSize);
     inputBeam = Unsqueeze(*input, input->order - 1, beamSize);
@@ -110,9 +106,11 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
     maxLength = input->GetDim(-1) * 2;
     CheckNTErrors(maxLength > 0, "no max length specified!");

-    T2TStateBundle * states = new T2TStateBundle[maxLength + 1];
-    T2TStateBundle * first = states;
+    T2TStateBundle* states = new T2TStateBundle[maxLength + 1];
+    T2TStateBundle* first = states;
+    T2TStateBundle* cur;
+    T2TStateBundle* next;

     /* create the first state */
     predictor.Create(model, &encodingBeam, input, beamSize, first);
     predictor.SetStartSymbol(startSymbol);
@@ -121,14 +119,14 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
     /* generate the sequence from left to right */
     for(int i = 0 ; i < maxLength; i++){
-        T2TStateBundle * cur = states + i;
-        T2TStateBundle * next = states + i + 1;
+        cur = states + i;
+        next = states + i + 1;

         /* read the current state */
         predictor.Read(model, cur);

         /* predict the next state */
-        predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam);
+        predictor.Predict(next, &encodingBeam, &inputBeam, &paddingBeam, i==0);

         /* compute the model score (given the prediction probability) */
         Score(cur, next);
@@ -144,7 +142,7 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
     }

     /* fill the heap with imcomplete hypotheses if neccesary */
-    FillHeap(&states[maxLength]);
+    FillHeap(next);

     Dump(output);
@@ -210,24 +208,25 @@ void T2TSearch::Score(T2TStateBundle * prev, T2TStateBundle * beam)
     _ScaleAndShift(&lenPrev, &len, 1.0F, 1.0F);

     /* the GNMT-like length penalty */
-    lp = T2TLengthPenalizer::GNMT(len, alpha);
-    lp.Reshape(lp.unitNum);
+    //lp = T2TLengthPenalizer::GNMT(len, alpha);
+    //lp.Reshape(lp.unitNum);

     /* score = log-prob/lp */
-    _DivDim(&probPath, &lp, &score, 0);
+    //_DivDim(&probPath, &lp, &score, 0);

     if (prev->isStart) {
         XTensor firstMask = MakeFirstMask(beam);
         firstMask.Reshape(firstMask.unitNum);

-        /* mask the hypotheses in the beam expect the first one */
+        /* mask the hypotheses in the beam except the first one */
         _SumDim(&score, &firstMask, &score, 0);
     }

     InitTensor(&mask,
                prev->endMark.order, prev->endMark.dimSize, X_FLOAT,
                prev->endMark.devID);
+    mask.SetZeroAll();
     _SetDataFixedCond(&mask, &prev->endMark, -1e9F);

     mask.Reshape(mask.unitNum);
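Note: the length penalty referenced above (T2TLengthPenalizer::GNMT, commented out by this commit) is, in its usual GNMT form, lp(len) = ((5 + len) / 6) ^ alpha, and the hypothesis score is score = log-prob / lp(len); alpha is the "lenalpha" parameter loaded in Init. This is the standard formulation and is assumed here, not taken from the penalizer's source.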
@@ -279,17 +278,26 @@ void T2TSearch::Generate(T2TStateBundle * beam)
     dimsTopK[order - 3] = dimsBeam[order - 3];
     dimsTopK[order - 1] = beamSize;

-    InitTensor(&scoreTopK, order, dimsTopK, score.dataType,
-               score.devID);
-    InitTensor(&index, order, dimsTopK, X_INT,
-               score.devID);
+    InitTensor(&scoreTopK, order, dimsTopK, score.dataType, score.devID);
+    InitTensor(&index, order, dimsTopK, X_INT, score.devID);
     InitTensor(&preID, order, dimsTopK, X_INT, -1);

+    /* mask the first and the padding id */
+    int dimMask[]{ score.GetDim(-1) };
+    XTensor mask;
+    InitTensor(&mask, 1, dimMask, X_FLOAT, -1);
+    mask.SetZeroAll();
+    mask.Set1D(-1e20F, 0);
+    mask.Set1D(-1e20F, 1);
+    mask.SetDevice(score.devID, score.mem);
+
+    //_SumDim(&score, &mask, 2);
+
     score.Reshape(order, dimsBeam);

     /* keep the most promissing candidates in the beam */
+    /* TODO: check this line */
     TopK(score, scoreTopK, index, -1, beamSize);

     CopyValues(index, preID);

     /* "preID" represents the id (or the offset) of the previous state used to make the current
@@ -313,11 +321,14 @@ void T2TSearch::Generate(T2TStateBundle * beam)
     /* CPU data (TODO: remove GPU->CPU data copy!!!) */
     XTensor indexGPU;
     indexGPU = CopyValues(index);
-    //InitTensorV2(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1);
-    //CopyValues(index, indexCPU);

-    for (int i = 0; i < indexGPU.unitNum; i++)
-        indexGPU.SetInt(i * stride + indexGPU.GetInt(i), i);
+    for (int i = 0; i < indexGPU.unitNum; i += beamSize) {
+        for (int j = 0; j < beamSize; j++)
+            indexGPU.SetInt(i * stride + indexGPU.GetInt(i + j), i + j);
+    }
+
+    /*for (int i = 0; i < indexGPU.unitNum; i++) {
+        indexGPU.SetInt(i + indexGPU.GetInt(i), i);
+    }*/

     CheckNTErrors(IsSameShaped(prob, probPath), "Wrong tensor shape!");
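Note: the rewritten loop above walks the TopK output one beam block (beamSize entries) at a time and turns each local index into a flat offset into the score tensor by adding the block's base offset i * stride. A self-contained sketch of the same arithmetic, with hypothetical names (stride is assumed to be the size of the dimension that TopK searched over):

    void FlattenTopKIndices(int* index, int unitNum, int beamSize, int stride)
    {
        for (int i = 0; i < unitNum; i += beamSize)
            for (int j = 0; j < beamSize; j++)
                index[i + j] = i * stride + index[i + j];
    }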
@@ -460,6 +471,10 @@ void T2TSearch::Collect(T2TStateBundle * beam)
     CheckNTErrors(state.pid >= 0 && state.pid < batchSize,
                   "Invalid sample id!");

+    /* check if this is the first end symbol. It is false
+       if there have been end symbols in previously generated words. */
+    bool isCompleted = state.isCompleted && (state.last == NULL || !state.last->isCompleted);
+
     /* we push the hypothesis into the heap when it is completed */
     if(state.isEnd != 0)
......
@@ -93,17 +93,13 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
 {
     count++;
     wordCount = 0;

-    /*if (count % 10 == 0 && sentBatch < 128)
-        sentBatch *= 2;*/
-
+    /* reset cache for decoder */
     for (int i = 0; i < model->decoder->nlayer; ++i) {
-        model->decoder->selfCache[i].Clear();
-        model->decoder->contextCache[i].Clear();
+        model->decoder->selfCache[i].miss = true;
+        model->decoder->contextCache[i].miss = true;
     }

     vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID);

     XTensor output;
     seacher.Search(model, &batchEnc, &paddingEnc, &output);
@@ -122,7 +118,6 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
     res.id = indices[i];
     batchLoader.resBuffer.emplace_back(res);
 }

 wc = batchEnc.GetDim(-1);
 wordCount += wc;
......
@@ -54,7 +54,7 @@ int TransformerMain(int argc, const char ** argv)
     char * rawModel = new char[MAX_LINE_LENGTH];

     LoadParamString(argc, args, "model", modelFN, "");
-    LoadParamString(argc, args, "rawModel", rawModel, "");
+    LoadParamString(argc, args, "rawmodel", rawModel, "");
     LoadParamString(argc, args, "test", testFN, "");
     LoadParamString(argc, args, "output", outputFN, "");
     LoadParamBool(argc, args, "beamsearch", &isBeamSearch, false);
......