Commit 99097e41 by huchi

add support for greedy search

parent bfa6fc90
...@@ -19,6 +19,10 @@ ...@@ -19,6 +19,10 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-10 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-10
*/ */
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
//#include <crtdbg.h>
#include <stdio.h> #include <stdio.h>
#include "./network/XNet.h" #include "./network/XNet.h"
#include "./tensor/XUtility.h" #include "./tensor/XUtility.h"
...@@ -27,9 +31,7 @@ ...@@ -27,9 +31,7 @@
#include "./sample/fnnlm/FNNLM.h" #include "./sample/fnnlm/FNNLM.h"
#include "./sample/transformer/Transformer.h" #include "./sample/transformer/Transformer.h"
//#define CRTDBG_MAP_ALLOC
//#include <stdlib.h>
//#include <crtdbg.h>
using namespace nts; using namespace nts;
using namespace fnnlm; using namespace fnnlm;
...@@ -37,19 +39,10 @@ using namespace transformer; ...@@ -37,19 +39,10 @@ using namespace transformer;
int main( int argc, const char ** argv ) int main( int argc, const char ** argv )
{ {
//_CrtSetDbgFlag(_CrtSetDbgFlag(_CRTDBG_REPORT_FLAG) | _CRTDBG_LEAK_CHECK_DF); /*_CrtSetDbgFlag(_CrtSetDbgFlag(_CRTDBG_REPORT_FLAG) | _CRTDBG_LEAK_CHECK_DF);
//_CrtSetBreakAlloc(2708); _CrtSetBreakAlloc(2708);*/
TransformerMain(argc - 1, argv + 1); TransformerMain(argc - 1, argv + 1);
/*XTensor x;
InitTensor2D(&x, 2, 2);
float d[]{ 1,2,3,4 };
x.SetData(d, 4);
XTensor y;
y = ReduceSum(x, 0);
y.Dump(stderr);*/
//_CrtDumpMemoryLeaks(); //_CrtDumpMemoryLeaks();
return 0; return 0;
......
...@@ -34,7 +34,7 @@ T2TAttention::T2TAttention() ...@@ -34,7 +34,7 @@ T2TAttention::T2TAttention()
nhead = -1; nhead = -1;
dk = -1; dk = -1;
dv = -1; dv = -1;
d = -1; d = -1;
isMasked = false; isMasked = false;
ignored = 0; ignored = 0;
} }
...@@ -62,7 +62,7 @@ void T2TAttention::InitModel(int argc, char** argv, ...@@ -62,7 +62,7 @@ void T2TAttention::InitModel(int argc, char** argv,
float minmax = 0; float minmax = 0;
LoadParamInt(argc, argv, "nhead", &nhead, 8); LoadParamInt(argc, argv, "nhead", &nhead, 4);
LoadParamInt(argc, argv, "d", &dk, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &dk, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &dv, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &dv, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
...@@ -70,15 +70,15 @@ void T2TAttention::InitModel(int argc, char** argv, ...@@ -70,15 +70,15 @@ void T2TAttention::InitModel(int argc, char** argv,
LoadParamFloat(argc, argv, "attminmax", &minmax, 0.1F); LoadParamFloat(argc, argv, "attminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutatt", &dropoutP, 0); LoadParamFloat(argc, argv, "dropoutatt", &dropoutP, 0);
InitTensor2D(&wq, d, d, X_FLOAT, devID); InitTensor2DV2(&wq, d, d, X_FLOAT, devID);
InitTensor1D(&bq, d, X_FLOAT, devID); InitTensor1DV2(&bq, d, X_FLOAT, devID);
InitTensor2D(&wk, d, d, X_FLOAT, devID); InitTensor2DV2(&wk, d, d, X_FLOAT, devID);
InitTensor1D(&bk, d, X_FLOAT, devID); InitTensor1DV2(&bk, d, X_FLOAT, devID);
InitTensor2D(&wv, d, d, X_FLOAT, devID); InitTensor2DV2(&wv, d, d, X_FLOAT, devID);
InitTensor1D(&bv, d, X_FLOAT, devID); InitTensor1DV2(&bv, d, X_FLOAT, devID);
InitTensor2D(&rp_embedding_k, max_relative_position * 2 + 1, d/nhead, X_FLOAT, devID); InitTensor2DV2(&rp_embedding_k, max_relative_position * 2 + 1, d/nhead, X_FLOAT, devID);
InitTensor2D(&wa, d, d, X_FLOAT, devID); InitTensor2DV2(&wo, d, d, X_FLOAT, devID);
InitTensor1D(&ba, d, X_FLOAT, devID); InitTensor1DV2(&bo, d, X_FLOAT, devID);
} }
/* /*
...@@ -94,24 +94,27 @@ make the network ...@@ -94,24 +94,27 @@ make the network
>> cacheType - which type that cache is >> cacheType - which type that cache is
<< return - multi-attention result << return - multi-attention result
*/ */
XTensor T2TAttention::Make( XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, Cache* cache, int cacheType) XTensor T2TAttention::Make(XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, Cache* cache, int cacheType)
{ {
const bool isEnc = (!cache) ? true : false; const bool isEnc = (!cache) ? true : false;
/* linear transformation before self-attention */ /* linear transformation before self-attention */
XTensor q2, k2, v2; XTensor q2, k2, v2;
q2 = MatrixMul(q, X_NOTRANS, wq, X_TRANS) + bq;
q2 = MatrixMul(q, wq) + bq;
if (!cache) { if (!cache) {
/* self attention for encoder layers */ /* self attention for encoder layers */
k2 = MatrixMul(k, X_NOTRANS, wk, X_TRANS) + bk; k2 = MatrixMul(k, wk) + bk;
v2 = MatrixMul(v, X_NOTRANS, wv, X_TRANS) + bv; v2 = MatrixMul(v, wv) + bv;
return MakeRPRAttention(k2, q2, v2, mask, isTraining, isEnc); return MakeRPRAttention(k2, q2, v2, mask, isTraining, isEnc);
} }
else { else {
if (cacheType == SELF_ATT) { if (cacheType == SELF_ATT) {
k2 = MatrixMul(k, X_NOTRANS, wk, X_TRANS) + bk; k2 = MatrixMul(k, wk) + bk;
v2 = MatrixMul(v, X_NOTRANS, wv, X_TRANS) + bv; v2 = MatrixMul(v, wv) + bv;
/* if hit, we only concat the cache with the new token */ /* if hit, we only concat the cache with the new token */
if (!cache->miss) { if (!cache->miss) {
...@@ -121,12 +124,13 @@ XTensor T2TAttention::Make( XTensor& k, XTensor& q, XTensor& v, XTensor* mask, ...@@ -121,12 +124,13 @@ XTensor T2TAttention::Make( XTensor& k, XTensor& q, XTensor& v, XTensor* mask,
cache->key = k2; cache->key = k2;
cache->value = v2; cache->value = v2;
cache->miss = false; cache->miss = false;
return MakeRPRAttention(cache->key, q2, cache->value, mask, isTraining, isEnc); return MakeRPRAttention(cache->key, q2, cache->value, mask, isTraining, isEnc);
} }
else if (cacheType == EN_DE_ATT) { else if (cacheType == EN_DE_ATT) {
if (cache->miss) { if (cache->miss) {
cache->key = MatrixMul(k, X_NOTRANS, wk, X_TRANS) + bk; cache->key = MatrixMul(k, wk) + bk;
cache->value = MatrixMul(v, X_NOTRANS, wv, X_TRANS) + bv; cache->value = MatrixMul(v, wv) + bv;
cache->miss = false; cache->miss = false;
} }
return MakeAttention(cache->key, q2, cache->value, mask, isTraining, isEnc); return MakeAttention(cache->key, q2, cache->value, mask, isTraining, isEnc);
...@@ -134,50 +138,49 @@ XTensor T2TAttention::Make( XTensor& k, XTensor& q, XTensor& v, XTensor* mask, ...@@ -134,50 +138,49 @@ XTensor T2TAttention::Make( XTensor& k, XTensor& q, XTensor& v, XTensor* mask,
CheckNTErrors(0, "invalid cache type"); CheckNTErrors(0, "invalid cache type");
} }
} }
/* /*
make the attention network given keys, queries and values (after linear transformation) make the attention network given keys, queries and values (after linear transformation)
>> k - keys. It might be of size B * L * H >> k - keys. It might be of size B * L * H
where B = batch size, L = sequence length, where B = batch size, L = sequence length,
and H = vector size of each position and H = vector size of each position
>> q - queries >> q - queries
>> v - values >> v - values
>> mask - as it is >> mask - as it is
>> isTraining - indicates whether the model is used for training >> isTraining - indicates whether the model is used for training
*/ */
XTensor T2TAttention::MakeAttention(XTensor &k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder) XTensor T2TAttention::MakeAttention(XTensor& k, XTensor& q, XTensor& v, XTensor* mask, bool isTraining, bool is_encoder)
{ {
XTensor kheads; XTensor kheads;
XTensor qheads; XTensor qheads;
XTensor vheads; XTensor vheads;
/* multi head */ /* multi head */
kheads = Split(k, k.order - 1, nhead); kheads = Split(k, k.order - 1, nhead);
qheads = Split(q, q.order - 1, nhead); qheads = Split(q, q.order - 1, nhead);
vheads = Split(v, v.order - 1, nhead); vheads = Split(v, v.order - 1, nhead);
XTensor att; XTensor att;
XTensor dot; XTensor dot;
XTensor scalar; XTensor scalar;
/* scalar = softmax(Q * K^T / sqrt(dk)) * V */ /* scalar = softmax(Q * K^T / sqrt(dk)) * V */
dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS); dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
/*if (isMasked && mask) { /*if (isMasked && mask)
_SumMe(&dot, mask); _SumMe(&dot, mask);*/
}*/
dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead)); dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));
scalar = Softmax(dot, -1); scalar = Softmax(dot, -1);
/*if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
scalar = Dropout(scalar, dropoutP);*/ scalar = Dropout(scalar, dropoutP);
att = BMMul(scalar, vheads); att = BMMul(scalar, vheads);
/* concatenate the heads */ /* concatenate the heads */
return MulAndShift(Merge(att, att.order - 1), X_NOTRANS, wa, X_TRANS, ba); return MulAndShift(Merge(att, att.order - 1), wo, bo);
} }
/* /*
...@@ -215,34 +218,32 @@ XTensor T2TAttention::MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v, XTens ...@@ -215,34 +218,32 @@ XTensor T2TAttention::MakeRPRAttention(XTensor& k, XTensor& q, XTensor& v, XTens
InitTensor4DV2(&dot, nhead, batch_size, len_q, len_kv, X_FLOAT, q.devID); InitTensor4DV2(&dot, nhead, batch_size, len_q, len_kv, X_FLOAT, q.devID);
/* generate the relative emb index (L_q, L_kv) */ /* generate the relative emb index (L_q, L_kv) */
GetRPEmbedding(&emb_matrix, len_q, len_kv, max_relative_position, q.devID,is_encoder); GetRPEmbedding(&emb_matrix, len_q, len_kv, max_relative_position, q.devID, is_encoder);
/* generate the relative key from the rp_embedding_k (L_q, L_kv, H/K) */ /* generate the relative key from the rp_embedding_k (L_q, L_kv, H/K) */
_Gather(&rp_embedding_k, &relative_key, &emb_matrix); _Gather(&rp_embedding_k, &relative_key, &emb_matrix);
/* RPR dot product (K, B, L_q, L_kv)*/ /* RPR dot product (K, B, L_q, L_kv)*/
qheads = qheads / float(nhead);
RPDotProduct(&qheads, &kheads, &relative_key, &dot, true); RPDotProduct(&qheads, &kheads, &relative_key, &dot, true);
/*if (isMasked && mask) /*if (isMasked && mask)
_SumMe(&dot, mask);*/ _SumMe(&dot, mask);*/
/* scale the dot result */ /* scale the dot result */
//dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead)); dot = Linear(dot, 1.0F / (float)sqrt((float)dk / nhead));
/* softmax */ /* softmax */
scalar = Softmax(dot, -1); scalar = Softmax(dot, -1);
/*if (isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
scalar = Dropout(scalar, dropoutP);*/ scalar = Dropout(scalar, dropoutP);
/* generate the relative attention output (K, B, L_q, H/K) */ /* generate the relative attention output (K, B, L_q, H/K) */
att = BMMul(scalar, vheads); att = BMMul(scalar, vheads);
/* concatenate the heads */ /* concatenate the heads */
return MulAndShift(Merge(att, att.order - 1), X_NOTRANS, wa, X_TRANS, ba); return MulAndShift(Merge(att, att.order - 1), wo, bo);
} }
void T2TAttention::GetRPEmbedding(XTensor* emb_matrix, const int len_q, const int len_kv, const int max_relative_length, const int devID, const bool is_encoder) void T2TAttention::GetRPEmbedding(XTensor* emb_matrix, const int len_q, const int len_kv, const int max_relative_length, const int devID, const bool is_encoder)
...@@ -251,10 +252,11 @@ void T2TAttention::GetRPEmbedding(XTensor* emb_matrix, const int len_q, const in ...@@ -251,10 +252,11 @@ void T2TAttention::GetRPEmbedding(XTensor* emb_matrix, const int len_q, const in
XTensor range; XTensor range;
InitTensor1DV2(&range, len_kv, X_INT, devID); InitTensor1DV2(&range, len_kv, X_INT, devID);
int* index = new int[len_kv]; int* index = new int[len_kv];
// for encoder self-attention which the L_q = L_kv // for encoder self-attention which the L_q = L_kv
if (is_encoder) if (is_encoder)
{ {
for (int i = 0; i <len_kv; i++) for (int i = 0; i < len_kv; i++)
index[i] = i; index[i] = i;
range.SetData(index, len_kv); range.SetData(index, len_kv);
XTensor range_2D, range_2D_t; XTensor range_2D, range_2D_t;
...@@ -267,7 +269,7 @@ void T2TAttention::GetRPEmbedding(XTensor* emb_matrix, const int len_q, const in ...@@ -267,7 +269,7 @@ void T2TAttention::GetRPEmbedding(XTensor* emb_matrix, const int len_q, const in
// for decoder self-attention which the L_q != L_kv, and L_q is 1 // for decoder self-attention which the L_q != L_kv, and L_q is 1
else else
{ {
for (int i = 0; i <len_kv; i++) for (int i = 0; i < len_kv; i++)
index[i] = -len_kv + i + 1; index[i] = -len_kv + i + 1;
range.SetData(index, len_kv); range.SetData(index, len_kv);
_Unsqueeze(&range, emb_matrix, 0, len_q); _Unsqueeze(&range, emb_matrix, 0, len_q);
...@@ -299,7 +301,6 @@ void T2TAttention::RPDotProduct(XTensor* x, XTensor* y, XTensor* z, XTensor* att ...@@ -299,7 +301,6 @@ void T2TAttention::RPDotProduct(XTensor* x, XTensor* y, XTensor* z, XTensor* att
XTensor context; XTensor context;
InitTensor4DV2(&context, head_num, batch_size, len_q, last_dim, X_FLOAT, x->devID); InitTensor4DV2(&context, head_num, batch_size, len_q, last_dim, X_FLOAT, x->devID);
_MatrixMulBatched(x, X_NOTRANS, y, transpose_flag, &context); _MatrixMulBatched(x, X_NOTRANS, y, transpose_flag, &context);
//if (profiler_) profiler_->FinishTimer("RPDotPro-BMM");
// reshape and transpose x to (L_q, K*B, H/K or L_kv) // reshape and transpose x to (L_q, K*B, H/K or L_kv)
int merge_dims[] = { head_num * batch_size, len_q, x->dimSize[3] }; int merge_dims[] = { head_num * batch_size, len_q, x->dimSize[3] };
...@@ -323,5 +324,6 @@ void T2TAttention::RPDotProduct(XTensor* x, XTensor* y, XTensor* z, XTensor* att ...@@ -323,5 +324,6 @@ void T2TAttention::RPDotProduct(XTensor* x, XTensor* y, XTensor* z, XTensor* att
relative_t.Reshape(4, split_dims); relative_t.Reshape(4, split_dims);
_Sum(&context, &relative_t, attention); _Sum(&context, &relative_t, attention);
} }
} }
...@@ -90,14 +90,18 @@ public: ...@@ -90,14 +90,18 @@ public:
/* bias for V */ /* bias for V */
XTensor bv; XTensor bv;
XTensor wBig;
XTensor bBig;
/* RPR emb */ /* RPR emb */
XTensor rp_embedding_k; XTensor rp_embedding_k;
/* transformation after dot-product attention */ /* transformation after dot-product attention */
XTensor wa; XTensor wo;
/* bias after dot-product attention */ /* bias after dot-product attention */
XTensor ba; XTensor bo;
/* size of transformed Q and K */ /* size of transformed Q and K */
int dk; int dk;
......
...@@ -31,27 +31,27 @@ namespace transformer ...@@ -31,27 +31,27 @@ namespace transformer
/* constructor */ /* constructor */
AttDecoder::AttDecoder() AttDecoder::AttDecoder()
{ {
attentions = NULL; selfAtt = NULL;
fnns = NULL; fnns = NULL;
attLayerNorms = NULL; selfAttLayerNorms = NULL;
attentionsEnde = NULL; enDeAtt = NULL;
attEndeLayerNorms = NULL; enDeAttLayerNorms = NULL;
decodeLayerNorm = NULL; decoderLayerNorm = NULL;
selfCache = NULL; selfAttCache = NULL;
contextCache = NULL; enDeAttCache = NULL;
} }
/* de-constructor */ /* de-constructor */
AttDecoder::~AttDecoder() AttDecoder::~AttDecoder()
{ {
delete[] selfCache; delete[] selfAttCache;
delete[] contextCache; delete[] enDeAttCache;
delete[] attentions; delete[] selfAtt;
delete[] fnns; delete[] fnns;
delete[] attLayerNorms; delete[] selfAttLayerNorms;
delete[] attentionsEnde; delete[] enDeAtt;
delete[] attEndeLayerNorms; delete[] enDeAttLayerNorms;
delete decodeLayerNorm; delete decoderLayerNorm;
} }
/* /*
...@@ -71,7 +71,7 @@ void AttDecoder::InitModel(int argc, char ** argv, ...@@ -71,7 +71,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
devID = myDevID; devID = myDevID;
ignored = myIgnored; ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 3); LoadParamInt(argc, argv, "nlayer", &nlayer, 4);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "vsizetgt", &vSize, 34040); LoadParamInt(argc, argv, "vsizetgt", &vSize, 34040);
...@@ -83,24 +83,24 @@ void AttDecoder::InitModel(int argc, char ** argv, ...@@ -83,24 +83,24 @@ void AttDecoder::InitModel(int argc, char ** argv,
/* embedding model */ /* embedding model */
embedder.InitModel(argc, argv, devID, false); embedder.InitModel(argc, argv, devID, false);
attentions = new T2TAttention[nlayer]; selfAtt = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer]; fnns = new T2TFNN[nlayer];
attLayerNorms = new T2TLN[nlayer]; selfAttLayerNorms = new T2TLN[nlayer];
attentionsEnde = new T2TAttention[nlayer]; enDeAtt = new T2TAttention[nlayer];
attEndeLayerNorms = new T2TLN[nlayer]; enDeAttLayerNorms = new T2TLN[nlayer];
decodeLayerNorm = new T2TLN; decoderLayerNorm = new T2TLN;
selfCache = new Cache[nlayer]; selfAttCache = new Cache[nlayer];
contextCache = new Cache[nlayer]; enDeAttCache = new Cache[nlayer];
/* initialize the stacked layers */ /* initialize the stacked layers */
for (int i = 0; i < nlayer; i++) { for (int i = 0; i < nlayer; i++) {
attentions[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID); selfAtt[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID);
fnns[i].InitModel(argc, argv, myDevID); fnns[i].InitModel(argc, argv, myDevID);
attLayerNorms[i].InitModel(argc, argv, myDevID); selfAttLayerNorms[i].InitModel(argc, argv, myDevID);
attentionsEnde[i].InitModel(argc, argv, true, myIgnored, myDevID); enDeAtt[i].InitModel(argc, argv, true, myIgnored, myDevID);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID); enDeAttLayerNorms[i].InitModel(argc, argv, myDevID);
} }
decodeLayerNorm->InitModel(argc, argv, myDevID); decoderLayerNorm->InitModel(argc, argv, myDevID);
} }
/* /*
...@@ -131,48 +131,38 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X ...@@ -131,48 +131,38 @@ XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor *mask, X
XTensor attNorm; XTensor attNorm;
/* layer normalization */ /* layer normalization */
inputNorm = attLayerNorms[i].Make(x); inputNorm = selfAttLayerNorms[i].Make(x);
//inputNorm.Dump(stderr, "inputNorm", 10);
/******************/ /******************/
/* self attention */ /* self attention */
att = attentions[i].Make(inputNorm, inputNorm, inputNorm, NULL, isTraining, &selfCache[i], SELF_ATT); att = selfAtt[i].Make(inputNorm, inputNorm, inputNorm, NULL, isTraining, &selfAttCache[i], SELF_ATT);
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP); att = Dropout(att, dropoutP);
/* residual connection */ /* residual connection */
_SumMe(&att, &x); att = att + x;
//att.Dump(stderr, "Sum(att, x)", 10);
/* layer normalization */ /* layer normalization */
attNorm = attEndeLayerNorms[i].Make(att); attNorm = enDeAttLayerNorms[i].Make(att);
//attNorm.Dump(stderr, "attNorm", 10);
/* encoder-decoder attention */ /* encoder-decoder attention */
ende = attentionsEnde[i].Make(outputEnc, attNorm, outputEnc, &maskEncDec, isTraining, &contextCache[i], EN_DE_ATT); ende = enDeAtt[i].Make(outputEnc, attNorm, outputEnc, &maskEncDec, isTraining, &enDeAttCache[i], EN_DE_ATT);
//ende.Dump(stderr, "ende atten", 10);
/* dropout */ /* dropout */
if(isTraining && dropoutP > 0) if(isTraining && dropoutP > 0)
ende = Dropout(ende, dropoutP); ende = Dropout(ende, dropoutP);
/* residual connection */ /* residual connection */
_SumMe(&ende, &att); ende = ende + att;
//res.Dump(stderr, "Sum(ende, att)", 10);
/* fnn */ /* fnn */
x = fnns[i].Make(ende, isTraining); x = fnns[i].Make(ende, isTraining);
//x.Dump(stderr, "fnns[i]", 10);
} }
x = decodeLayerNorm->Make(x); x = decoderLayerNorm->Make(x);
//x.Dump(stderr, "decodeLayerNorm", 10);
x.SetName(DECODING_NAME);
return x; return x;
} }
......
...@@ -63,13 +63,13 @@ public: ...@@ -63,13 +63,13 @@ public:
T2TFNN * fnns; T2TFNN * fnns;
/* attention model of each layer */ /* attention model of each layer */
T2TAttention * attentions; T2TAttention * selfAtt;
/* layer normalization for attention */ /* layer normalization for attention */
T2TLN * attLayerNorms; T2TLN * selfAttLayerNorms;
/* layer normalization for decoder */ /* layer normalization for decoder */
T2TLN * decodeLayerNorm; T2TLN * decoderLayerNorm;
/* input tensor of the encoder */ /* input tensor of the encoder */
XTensor * input; XTensor * input;
...@@ -78,16 +78,16 @@ public: ...@@ -78,16 +78,16 @@ public:
XTensor * output; XTensor * output;
/* encoder-decoder attention model of each layer */ /* encoder-decoder attention model of each layer */
T2TAttention * attentionsEnde; T2TAttention * enDeAtt;
/* layer normalization for encoder-decoder attention */ /* layer normalization for encoder-decoder attention */
T2TLN * attEndeLayerNorms; T2TLN * enDeAttLayerNorms;
/* layer cache list */ /* layer cache list */
Cache* selfCache; Cache* selfAttCache;
/* layer cache list */ /* layer cache list */
Cache* contextCache; Cache* enDeAttCache;
public: public:
/* constructor */ /* constructor */
......
...@@ -62,7 +62,7 @@ void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, bool isEnc) ...@@ -62,7 +62,7 @@ void T2TEmbedder::InitModel(int argc, char ** argv, int myDevID, bool isEnc)
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "pad", &padIdx, 1); LoadParamInt(argc, argv, "pad", &padIdx, 1);
InitTensor2D(&w, vSize, eSize, X_FLOAT, devID); InitTensor2DV2(&w, vSize, eSize, X_FLOAT, devID);
maxLength = maxLength + 1 + 1; maxLength = maxLength + 1 + 1;
DTYPE v = 1.0F/(float)sqrt((float)eSize); DTYPE v = 1.0F/(float)sqrt((float)eSize);
...@@ -80,7 +80,7 @@ make positional embeddings (of size eSize * length) ...@@ -80,7 +80,7 @@ make positional embeddings (of size eSize * length)
*/ */
void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length, int padIdx) void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length, int padIdx)
{ {
InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT, devID); InitTensor2DV2(&posEmbeddingBase, length, eSize, X_FLOAT, devID);
float * data = new float[posEmbeddingBase.unitNum]; float * data = new float[posEmbeddingBase.unitNum];
...@@ -113,47 +113,47 @@ make the network ...@@ -113,47 +113,47 @@ make the network
*/ */
XTensor T2TEmbedder::Make(XTensor &input, int prevLen) XTensor T2TEmbedder::Make(XTensor &input, int prevLen)
{ {
/* assert padding index is 1 */ ///* assert padding index is 1 */
CheckNTErrors(input.order > 1, "Wrong input tensor size!"); //CheckNTErrors(input.order > 1, "Wrong input tensor size!");
CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!"); //CheckNTErrors(input.dimSize[input.order - 1] < maxLength, "The sequence is too long!");
CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\""); //CheckNTErrors(vSize > 0, "set vocabulary size by \"-vsize\"");
CheckNTErrors(eSize > 0, "set embedding size by \"-esize\""); //CheckNTErrors(eSize > 0, "set embedding size by \"-esize\"");
//
XTensor wordEmbedding, position, posEmbedding; //XTensor wordEmbedding, position, posEmbedding;
InitTensor(&position, &input); //InitTensor(&position, &input);
int* posData = new int[input.unitNum];
XTensor inputCPU;
InitTensorOnCPU(&inputCPU, &input);
_CopyValues(&input, &inputCPU);
for (int i = 0; i < inputCPU.GetDim(0); i++) {
int startNoPad = 2 + prevLen - 1;
int* p = ((int*)inputCPU.data) + i * inputCPU.GetDim(1);
for (int j = 0; j < inputCPU.GetDim(1); j++) {
if (p[j] == 1) {
posData[i * inputCPU.GetDim(1) + j] = 1;
}
else {
posData[i * inputCPU.GetDim(1) + j] = startNoPad++;
}
}
}
position.SetData(posData, position.unitNum); //int* posData = new int[input.unitNum];
delete[] posData;
/* we make positional embeddings first */ //XTensor inputCPU;
if(true){ //InitTensorOnCPU(&inputCPU, &input);
posEmbedding = Gather(posEmbeddingBase, position); //_CopyValues(&input, &inputCPU);
}
/* then we make word embeddings */
//for (int i = 0; i < inputCPU.GetDim(0); i++) {
// int startNoPad = 2 + prevLen - 1;
// int* p = ((int*)inputCPU.data) + i * inputCPU.GetDim(1);
// for (int j = 0; j < inputCPU.GetDim(1); j++) {
// if (p[j] == 1) {
// posData[i * inputCPU.GetDim(1) + j] = 1;
// }
// else {
// posData[i * inputCPU.GetDim(1) + j] = startNoPad++;
// }
// }
//}
//position.SetData(posData, position.unitNum);
//delete[] posData;
///* we make positional embeddings first */
//if(true){
// posEmbedding = Gather(posEmbeddingBase, position);
//}
/* then we make word embeddings */
XTensor wordEmbedding;
wordEmbedding = Gather(w, input); wordEmbedding = Gather(w, input);
wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize)); wordEmbedding = Linear(wordEmbedding, (float)sqrt((float)eSize));
......
...@@ -29,7 +29,7 @@ using namespace nts; ...@@ -29,7 +29,7 @@ using namespace nts;
namespace transformer namespace transformer
{ {
#define DEFAULT_EMBEDDING_SIZE 512 #define DEFAULT_EMBEDDING_SIZE 128
/* /*
embedding (of word at position i): embedding (of word at position i):
......
...@@ -34,7 +34,7 @@ AttEncoder::AttEncoder() ...@@ -34,7 +34,7 @@ AttEncoder::AttEncoder()
attentions = NULL; attentions = NULL;
fnns = NULL; fnns = NULL;
attLayerNorms = NULL; attLayerNorms = NULL;
encodeLayerNorm = NULL; encoderLayerNorm = NULL;
} }
/* de-constructor */ /* de-constructor */
...@@ -43,7 +43,7 @@ AttEncoder::~AttEncoder() ...@@ -43,7 +43,7 @@ AttEncoder::~AttEncoder()
delete[] attentions; delete[] attentions;
delete[] fnns; delete[] fnns;
delete[] attLayerNorms; delete[] attLayerNorms;
delete encodeLayerNorm; delete encoderLayerNorm;
} }
/* /*
...@@ -61,7 +61,7 @@ void AttEncoder::InitModel(int argc, char ** argv, ...@@ -61,7 +61,7 @@ void AttEncoder::InitModel(int argc, char ** argv,
devID = myDevID; devID = myDevID;
ignored = myIgnored; ignored = myIgnored;
LoadParamInt(argc, argv, "nlayer", &nlayer, 35); LoadParamInt(argc, argv, "nlayer", &nlayer, 20);
LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "hsize", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "esize", &eSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "vsize", &vSize, 34040); LoadParamInt(argc, argv, "vsize", &vSize, 34040);
...@@ -76,7 +76,7 @@ void AttEncoder::InitModel(int argc, char ** argv, ...@@ -76,7 +76,7 @@ void AttEncoder::InitModel(int argc, char ** argv,
attentions = new T2TAttention[nlayer]; attentions = new T2TAttention[nlayer];
fnns = new T2TFNN[nlayer]; fnns = new T2TFNN[nlayer];
attLayerNorms = new T2TLN[nlayer]; attLayerNorms = new T2TLN[nlayer];
encodeLayerNorm = new T2TLN; encoderLayerNorm = new T2TLN;
/* initialize the stacked layers */ /* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){ for(int i = 0; i < nlayer; i++){
...@@ -84,7 +84,7 @@ void AttEncoder::InitModel(int argc, char ** argv, ...@@ -84,7 +84,7 @@ void AttEncoder::InitModel(int argc, char ** argv,
fnns[i].InitModel(argc, argv, myDevID); fnns[i].InitModel(argc, argv, myDevID);
attLayerNorms[i].InitModel(argc, argv, myDevID); attLayerNorms[i].InitModel(argc, argv, myDevID);
} }
encodeLayerNorm->InitModel(argc, argv, myDevID); encoderLayerNorm->InitModel(argc, argv, myDevID);
} }
/* /*
...@@ -123,13 +123,9 @@ XTensor AttEncoder::Make(XTensor &input, XTensor *mask, XTensor &maskEncDec, boo ...@@ -123,13 +123,9 @@ XTensor AttEncoder::Make(XTensor &input, XTensor *mask, XTensor &maskEncDec, boo
/* fnn */ /* fnn */
x = fnns[i].Make(res, isTraining); x = fnns[i].Make(res, isTraining);
} }
x = encodeLayerNorm->Make(x); x = encoderLayerNorm->Make(x);
x.SetName(ENCODING_NAME);
input.SetName(ENCODING_INPUT_NAME);
return x; return x;
} }
......
...@@ -93,11 +93,11 @@ public: ...@@ -93,11 +93,11 @@ public:
/* attention model of each layer */ /* attention model of each layer */
T2TAttention * attentions; T2TAttention * attentions;
/* layer normalization for attention */ /* layer normalizations for attention */
T2TLN * attLayerNorms; T2TLN * attLayerNorms;
/* layer normalization for encoder */ /* layer normalization for encoder */
T2TLN * encodeLayerNorm; T2TLN * encoderLayerNorm;
/* input tensor of the encoder */ /* input tensor of the encoder */
XTensor * input; XTensor * input;
......
/* NiuTrans.Tensor - an open-source tensor library /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University. * Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved. * All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
* limitations under the License. * limitations under the License.
*/ */
/* /*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/ */
#include <math.h> #include <math.h>
#include "T2TFNN.h" #include "T2TFNN.h"
...@@ -32,9 +32,9 @@ namespace transformer ...@@ -32,9 +32,9 @@ namespace transformer
/* constructor */ /* constructor */
T2TFNN::T2TFNN() T2TFNN::T2TFNN()
{ {
inSize = -1; inSize = -1;
outSize = -1; outSize = -1;
hSize = -1; hSize = -1;
} }
/* deconstructor */ /* deconstructor */
...@@ -42,28 +42,28 @@ T2TFNN::~T2TFNN() ...@@ -42,28 +42,28 @@ T2TFNN::~T2TFNN()
{ {
} }
/* /*
initialize the model initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list of pointers to the arguments >> argv - list of pointers to the arguments
>> myDevID - device id >> myDevID - device id
*/ */
void T2TFNN::InitModel(int argc, char ** argv, int myDevID) void T2TFNN::InitModel(int argc, char** argv, int myDevID)
{ {
devID = myDevID; devID = myDevID;
float minmax = 0; float minmax = 0;
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &outSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "fnnh", &hSize, outSize * 4); LoadParamInt(argc, argv, "fnnh", &hSize, outSize * 8);
LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F); LoadParamFloat(argc, argv, "fnnminmax", &minmax, 0.1F);
LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0); LoadParamFloat(argc, argv, "dropoutfnn", &dropoutP, 0);
InitTensor2DV2(&w1, hSize, inSize, X_FLOAT, devID); InitTensor2DV2(&w1, inSize, hSize, X_FLOAT, devID);
InitTensor1DV2(&b1, hSize, X_FLOAT, devID); InitTensor1DV2(&b1, hSize, X_FLOAT, devID);
InitTensor2DV2(&w2, outSize, hSize, X_FLOAT, devID); InitTensor2DV2(&w2, hSize, outSize, X_FLOAT, devID);
InitTensor1DV2(&b2, outSize, X_FLOAT, devID); InitTensor1DV2(&b2, outSize, X_FLOAT, devID);
fnnLayerNorm.InitModel(argc, argv, myDevID); fnnLayerNorm.InitModel(argc, argv, myDevID);
...@@ -78,25 +78,25 @@ void T2TFNN::InitModel(int argc, char ** argv, int myDevID) ...@@ -78,25 +78,25 @@ void T2TFNN::InitModel(int argc, char ** argv, int myDevID)
//b2.SetZeroAll(); //b2.SetZeroAll();
} }
/* /*
make the network make the network
y = max(0, x * w1 + b1) * w2 + b2 y = max(0, x * w1 + b1) * w2 + b2
>> input - the input tensor >> input - the input tensor
>> return - the output tensor >> return - the output tensor
*/ */
XTensor T2TFNN::Make(XTensor &input, bool isTraining) XTensor T2TFNN::Make(XTensor& input, bool isTraining)
{ {
XTensor t1; XTensor t1;
/* t1 = max(0, x * w1 + b1) */ /* t1 = max(0, x * w1 + b1) */
t1 = Rectify(MulAndShift(fnnLayerNorm.Make(input), X_NOTRANS, w1, X_TRANS, b1)); t1 = Rectify(MulAndShift(fnnLayerNorm.Make(input), w1, b1));
if(isTraining && dropoutP > 0) if (isTraining && dropoutP > 0)
t1 = Dropout(t1, dropoutP); t1 = Dropout(t1, dropoutP);
/* result = t1 * w2 + b2 */ /* result = t1 * w2 + b2 */
XTensor res; XTensor res;
res = MulAndShift(t1, X_NOTRANS, w2, X_TRANS, b2); res = MulAndShift(t1, w2, b2);
_SumMe(&res, &input); _SumMe(&res, &input);
return res; return res;
} }
......
...@@ -53,8 +53,8 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID) ...@@ -53,8 +53,8 @@ void T2TLN::InitModel(int argc, char ** argv, int myDevID)
d = 0; d = 0;
LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &d, DEFAULT_EMBEDDING_SIZE);
InitTensor1D(&w, d, X_FLOAT, devID); InitTensor1DV2(&w, d, X_FLOAT, devID);
InitTensor1D(&b, d, X_FLOAT, devID); InitTensor1DV2(&b, d, X_FLOAT, devID);
} }
/* /*
...@@ -78,7 +78,7 @@ XTensor T2TLN::Make(XTensor &input) ...@@ -78,7 +78,7 @@ XTensor T2TLN::Make(XTensor &input)
mean = ReduceMean(x, x.order - 1); mean = ReduceMean(x, x.order - 1);
/* \sigma = (sum_i (x_i - \mu)^2)/m */ /* \sigma = (sum_i (x_i - \mu)^2)/m */
variance = ReduceVariance(x, x.order - 1, mean); variance = ReduceVariance(x, x.order - 1, mean) + 1e-5F;
/* standard = sqrt(variance) */ /* standard = sqrt(variance) */
standard = Power(variance, 0.5F); standard = Power(variance, 0.5F);
...@@ -92,7 +92,7 @@ XTensor T2TLN::Make(XTensor &input) ...@@ -92,7 +92,7 @@ XTensor T2TLN::Make(XTensor &input)
xn = (x - meanFilled) / standardFilled; xn = (x - meanFilled) / standardFilled;
/* result = x' * w + b */ /* result = x' * w + b */
return xn * w + b; return xn * w + b;
} }
} }
...@@ -103,7 +103,7 @@ public: ...@@ -103,7 +103,7 @@ public:
/* read the parameters */ /* read the parameters */
void Read(const char * fn); void Read(const char * fn);
}; };
void FastRead(XTensor* x, FILE* f);
} }
#endif #endif
...@@ -56,13 +56,11 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID) ...@@ -56,13 +56,11 @@ void T2TOutput::InitModel(int argc, char ** argv, int myDevID)
LoadParamInt(argc, argv, "vsizetgt", &vSize, -1); LoadParamInt(argc, argv, "vsizetgt", &vSize, -1);
LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &inSize, DEFAULT_EMBEDDING_SIZE);
LoadParamInt(argc, argv, "d", &hSize, DEFAULT_EMBEDDING_SIZE); LoadParamInt(argc, argv, "d", &hSize, DEFAULT_EMBEDDING_SIZE);
LoadParamFloat(argc, argv, "outputminmax", &minmax, 0.08F);
InitTensor2D(&w, hSize, vSize, X_FLOAT, devID); InitTensor2DV2(&w, vSize, hSize, X_FLOAT, devID);
} }
/* /*
make the network (redefined output tensor) make the network (redefined output tensor)
>> input - input tensor >> input - input tensor
...@@ -72,9 +70,7 @@ void T2TOutput::Make(XTensor &input, XTensor &output) ...@@ -72,9 +70,7 @@ void T2TOutput::Make(XTensor &input, XTensor &output)
{ {
XTensor &x = input; XTensor &x = input;
output = LogSoftmax(MMul(x, X_NOTRANS, w, X_NOTRANS), -1); output = LogSoftmax(MMul(x, X_NOTRANS, w, X_TRANS), -1);
output.SetName(OUTPUT_NAME);
} }
} }
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
* limitations under the License. * limitations under the License.
*/ */
/* /*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
*/ */
#include "T2TPredictor.h" #include "T2TPredictor.h"
#include "../../tensor/core/CHeader.h" #include "../../tensor/core/CHeader.h"
...@@ -38,24 +38,24 @@ T2TStateBundle::T2TStateBundle() ...@@ -38,24 +38,24 @@ T2TStateBundle::T2TStateBundle()
/* de-constructor */ /* de-constructor */
T2TStateBundle::~T2TStateBundle() T2TStateBundle::~T2TStateBundle()
{ {
if(states != NULL) if (states != NULL)
delete[] states; delete[] states;
} }
/* /*
create states create states
>> num - number of states >> num - number of states
*/ */
void T2TStateBundle::MakeStates(int num) void T2TStateBundle::MakeStates(int num)
{ {
CheckNTErrors(num > 0, "invalid number"); CheckNTErrors(num > 0, "invalid number");
if(states != NULL) if (states != NULL)
delete[] states; delete[] states;
states = new T2TState[num]; states = new T2TState[num];
for(int i = 0; i < num; i++){ for (int i = 0; i < num; i++) {
states[i].prediction = -1; states[i].prediction = -1;
states[i].pid = T2T_PID_EMPTY; states[i].pid = T2T_PID_EMPTY;
states[i].isEnd = false; states[i].isEnd = false;
...@@ -74,7 +74,7 @@ void T2TStateBundle::MakeStates(int num) ...@@ -74,7 +74,7 @@ void T2TStateBundle::MakeStates(int num)
/* constructor */ /* constructor */
T2TPredictor::T2TPredictor() T2TPredictor::T2TPredictor()
{ {
startSymbol = -1; startSymbol = 2;
} }
/* de-constructor */ /* de-constructor */
...@@ -82,37 +82,44 @@ T2TPredictor::~T2TPredictor() ...@@ -82,37 +82,44 @@ T2TPredictor::~T2TPredictor()
{ {
} }
/* /*
create an initial state create an initial state
>> model - the t2t model >> model - the t2t model
>> top - the top-most layer of the network >> top - the top-most layer of the network
>> input - input of the network >> input - input of the network
>> beamSize - beam size >> beamSize - beam size
>> state - the state to be initialized >> state - the state to be initialized
*/ */
void T2TPredictor::Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state) void T2TPredictor::Create(T2TModel* model, XTensor* top, const XTensor* input, int beamSize, T2TStateBundle* state)
{ {
int dims[MAX_TENSOR_DIM_NUM]; int dims[MAX_TENSOR_DIM_NUM];
for (int i = 0; i < input->order - 1; i++) for (int i = 0; i < input->order - 1; i++)
dims[i] = input->GetDim(i); dims[i] = input->GetDim(i);
dims[input->order - 1] = beamSize; dims[input->order - 1] = beamSize;
InitTensor(&state->probPath, input->order, dims, X_FLOAT, input->devID); InitTensorV2(&state->probPath, input->order, dims, X_FLOAT, 1.0F, input->devID);
InitTensor(&state->nstep, input->order, dims, X_FLOAT, input->devID); InitTensorV2(&state->nstep, input->order, dims, X_FLOAT, 1.0F, input->devID);
InitTensor(&state->endMark, input->order, dims, X_INT, input->devID); InitTensorV2(&state->endMark, input->order, dims, X_INT, 1.0F, input->devID);
float* data = new float[state->probPath.unitNum]; /*float* data = new float[state->probPath.unitNum];
for (int i = 0; i < state->probPath.unitNum; ++i) { for (int i = 0; i < state->probPath.unitNum; ++i) {
data[i] = -1e20F; data[i] = -1e20F;
if (i % beamSize == 0) if (i % beamSize == 0)
data[i] = 0; data[i] = 0;
} }
state->probPath.SetData(data, state->probPath.unitNum); state->probPath.SetData(data, state->probPath.unitNum);
delete[] data;*/
SetDataFixed(state->probPath, -1e9F);
for (int i = 0; i < state->probPath.unitNum; ++i) {
if (i % beamSize == 0)
state->probPath.Set(0.0F, i);
}
state->nstep.SetZeroAll(); state->nstep.SetZeroAll();
state->endMark.SetZeroAll(); state->endMark.SetZeroAll();
delete[] data;
state->stateNum = 0; state->stateNum = 0;
} }
...@@ -125,15 +132,15 @@ void T2TPredictor::SetStartSymbol(int symbol) ...@@ -125,15 +132,15 @@ void T2TPredictor::SetStartSymbol(int symbol)
startSymbol = symbol; startSymbol = symbol;
} }
/* /*
read a state read a state
>> model - the t2t model that keeps the network created so far >> model - the t2t model that keeps the network created so far
>> state - a set of states. It keeps >> state - a set of states. It keeps
1) hypotheses (states) 1) hypotheses (states)
                2) probabilities of hypotheses                 2) probabilities of hypotheses
3) parts of the network for expanding toward the next state 3) parts of the network for expanding toward the next state
*/ */
void T2TPredictor::Read(T2TModel * model, T2TStateBundle * state) void T2TPredictor::Read(T2TModel* model, T2TStateBundle* state)
{ {
m = model; m = model;
s = state; s = state;
...@@ -147,8 +154,7 @@ predict the next state ...@@ -147,8 +154,7 @@ predict the next state
>> paddingEnc - padding of the encoder >> paddingEnc - padding of the encoder
>>> isStart - is the start or not >>> isStart - is the start or not
*/ */
void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, void T2TPredictor::Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inputEnc, XTensor* paddingEnc, bool isStart)
XTensor * inputEnc, XTensor * paddingEnc, bool isStart)
{ {
int dims[MAX_TENSOR_DIM_NUM]; int dims[MAX_TENSOR_DIM_NUM];
...@@ -157,42 +163,43 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, ...@@ -157,42 +163,43 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
/* the first token */ /* the first token */
XTensor first; XTensor first;
CheckNTErrors(inputEnc->order >= 2, "Wrong order of the tensor!"); CheckNTErrors(inputEnc->order >= 2, "Wrong order of the tensor!");
for(int i = 0; i < inputEnc->order - 1; i++) for (int i = 0; i < inputEnc->order - 1; i++)
dims[i] = inputEnc->GetDim(i); dims[i] = inputEnc->GetDim(i);
dims[inputEnc->order - 1] = 1; dims[inputEnc->order - 1] = 1;
InitTensor(&first, inputEnc->order, dims, X_INT, inputEnc->devID); InitTensorV2(&first, inputEnc->order, dims, X_INT, 1.0F, inputEnc->devID);
SetDataFixedInt(first, startSymbol); SetDataFixedInt(first, startSymbol);
/* add a new word into the input sequence of the decoder side */ /* add a new word into the input sequence of the decoder side */
if (isStart) { if (isStart) {
inputDec = Identity(first); inputDec = Identity(first);
} }
else{ else {
/* only pass one step to the decoder */ /* only pass one step to the decoder */
inputDec = GetLastPrediction(s); inputDec = GetLastPrediction(s);
inputDec.SetDevice(inputEnc->devID); inputDec.SetDevice(inputEnc->devID);
} }
/* prediction probabilities */ /* prediction probabilities */
XTensor &output = next->prob; XTensor& output = next->prob;
XTensor decoding; XTensor decoding;
for(int i = 0; i < inputDec.order - 1; i++) for (int i = 0; i < inputDec.order - 1; i++)
dims[i] = inputDec.GetDim(i); dims[i] = inputDec.GetDim(i);
dims[inputDec.order - 1] = inputDec.GetDim(-1); dims[inputDec.order - 1] = inputDec.GetDim(-1);
XTensor paddingDec; XTensor paddingDec;
InitTensor(&paddingDec, inputDec.order, dims, X_INT, paddingEnc->devID); InitTensorV2(&paddingDec, inputDec.order, dims, X_INT, 1.0F, paddingEnc->devID);
SetDataFixedInt(paddingDec, 1); SetDataFixedInt(paddingDec, 1);
XTensor maskDec; XTensor maskDec;
XTensor maskEncDec; XTensor maskEncDec;
/* decoder mask */ /* decoder mask */
m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec, 0); //m->MakeMTMaskDec(*inputEnc, inputDec, *paddingEnc, paddingDec, maskDec, maskEncDec, 0);
/* make the decoding network */ /* make the decoding network */
decoding = m->decoder->Make(inputDec, *encoding, NULL, maskEncDec, false); decoding = m->decoder->Make(inputDec, *encoding, NULL, maskEncDec, false);
...@@ -203,38 +210,38 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding, ...@@ -203,38 +210,38 @@ void T2TPredictor::Predict(T2TStateBundle * next, XTensor * encoding,
m->outputLayer->Make(decoding, output); m->outputLayer->Make(decoding, output);
} }
/* /*
generate paths up to the states of the current step generate paths up to the states of the current step
>> state - state bundle of the current step >> state - state bundle of the current step
*/ */
XTensor T2TPredictor::GeneratePaths(T2TStateBundle * state) XTensor T2TPredictor::GeneratePaths(T2TStateBundle* state)
{ {
CheckNTErrors(state->stateNum >= 0, "Illegal state!"); CheckNTErrors(state->stateNum >= 0, "Illegal state!");
int distance = -1; int distance = -1;
for(int i = 0; i < state->stateNum; i++){ for (int i = 0; i < state->stateNum; i++) {
T2TState * cur = state->states + i; T2TState* cur = state->states + i;
int nsteps = 0; int nsteps = 0;
while(cur != NULL){ while (cur != NULL) {
nsteps++; nsteps++;
cur = cur->last; cur = cur->last;
} }
if(nsteps > distance) if (nsteps > distance)
distance = nsteps; distance = nsteps;
} }
XTensor path; XTensor path;
InitTensor2D(&path, state->stateNum, distance, X_INT); InitTensor2DV2(&path, state->stateNum, distance, X_INT);
path.SetZeroAll(); path.SetZeroAll();
for(int i = 0; i < state->stateNum; i++){ for (int i = 0; i < state->stateNum; i++) {
T2TState * cur = state->states + i; T2TState* cur = state->states + i;
int nsteps = 0; int nsteps = 0;
while(cur != NULL){ while (cur != NULL) {
nsteps++; nsteps++;
path.Set2DInt(cur->prediction, i, distance - nsteps); path.Set2DInt(cur->prediction, i, distance - nsteps);
cur = cur->last; cur = cur->last;
...@@ -253,7 +260,7 @@ XTensor T2TPredictor::GetLastPrediction(T2TStateBundle* state) ...@@ -253,7 +260,7 @@ XTensor T2TPredictor::GetLastPrediction(T2TStateBundle* state)
CheckNTErrors(state->stateNum >= 0, "Illegal state!"); CheckNTErrors(state->stateNum >= 0, "Illegal state!");
XTensor lastPred; XTensor lastPred;
InitTensor2D(&lastPred, state->stateNum, 1, X_INT); InitTensor2DV2(&lastPred, state->stateNum, 1, X_INT);
for (int i = 0; i < state->stateNum; i++) { for (int i = 0; i < state->stateNum; i++) {
T2TState* cur = state->states + i; T2TState* cur = state->states + i;
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
* limitations under the License. * limitations under the License.
*/ */
/* /*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
* This is the first source file I create in 2019 - new start! * This is the first source file I create in 2019 - new start!
*/ */
#ifndef __T2TPREDICTOR_H__ #ifndef __T2TPREDICTOR_H__
#define __T2TPREDICTOR_H__ #define __T2TPREDICTOR_H__
...@@ -39,8 +39,8 @@ public: ...@@ -39,8 +39,8 @@ public:
/* we assume that the prediction is an integer */ /* we assume that the prediction is an integer */
int prediction; int prediction;
/* id of the problem. One can regard it as the sentence id when we /* id of the problem. One can regard it as the sentence id when we
translate a number of sentences in the batched manner. The hypothesis translate a number of sentences in the batched manner. The hypothesis
is empty if id = -1 */ is empty if id = -1 */
int pid; int pid;
...@@ -66,7 +66,7 @@ public: ...@@ -66,7 +66,7 @@ public:
int nstep; int nstep;
/* pointer to the previous state */ /* pointer to the previous state */
T2TState * last; T2TState* last;
}; };
/* a bundle of states */ /* a bundle of states */
...@@ -75,7 +75,7 @@ class T2TStateBundle ...@@ -75,7 +75,7 @@ class T2TStateBundle
public: public:
/* predictions */ /* predictions */
XTensor prediction; XTensor prediction;
/* id of the previous state that generates the current one */ /* id of the previous state that generates the current one */
XTensor preID; XTensor preID;
...@@ -95,7 +95,7 @@ public: ...@@ -95,7 +95,7 @@ public:
XTensor nstep; XTensor nstep;
/* list of states */ /* list of states */
T2TState * states; T2TState* states;
/* number of states */ /* number of states */
int stateNum; int stateNum;
...@@ -114,19 +114,19 @@ public: ...@@ -114,19 +114,19 @@ public:
void MakeStates(int num); void MakeStates(int num);
}; };
/* The predictor reads the current state and then predicts the next. /* The predictor reads the current state and then predicts the next.
It is exactly the same procedure of MT inference - It is exactly the same procedure of MT inference -
we get the state of previous words and then generate the next word. we get the state of previous words and then generate the next word.
   Here, a state can be regarded as the representation of words (word    Here, a state can be regarded as the representation of words (word
   indices, hidden states, embeddings, etc.). */    indices, hidden states, embeddings, etc.). */
class T2TPredictor class T2TPredictor
{ {
private: private:
/* pointer to the transformer model */ /* pointer to the transformer model */
T2TModel * m; T2TModel* m;
/* current state */ /* current state */
T2TStateBundle * s; T2TStateBundle* s;
/* start symbol */ /* start symbol */
int startSymbol; int startSymbol;
...@@ -139,19 +139,19 @@ public: ...@@ -139,19 +139,19 @@ public:
~T2TPredictor(); ~T2TPredictor();
/* create an initial state */ /* create an initial state */
void Create(T2TModel * model, XTensor * top, const XTensor * input, int beamSize, T2TStateBundle * state); void Create(T2TModel* model, XTensor* top, const XTensor* input, int beamSize, T2TStateBundle* state);
/* set the start symbol */ /* set the start symbol */
void SetStartSymbol(int symbol); void SetStartSymbol(int symbol);
/* read a state */ /* read a state */
void Read(T2TModel * model, T2TStateBundle * state); void Read(T2TModel* model, T2TStateBundle* state);
/* predict the next state */ /* predict the next state */
void Predict(T2TStateBundle * next, XTensor * encoding, XTensor * inputEnc, XTensor * paddingEnc, bool isStart); void Predict(T2TStateBundle* next, XTensor* encoding, XTensor* inputEnc, XTensor* paddingEnc, bool isStart);
/* generate paths up to the states of the current step */ /* generate paths up to the states of the current step */
XTensor GeneratePaths(T2TStateBundle * state); XTensor GeneratePaths(T2TStateBundle* state);
/* get the predictions of the previous step */ /* get the predictions of the previous step */
XTensor GetLastPrediction(T2TStateBundle* state); XTensor GetLastPrediction(T2TStateBundle* state);
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
* limitations under the License. * limitations under the License.
*/ */
/* /*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
*/ */
#ifndef __T2TSEARCH_H__ #ifndef __T2TSEARCH_H__
#define __T2TSEARCH_H__ #define __T2TSEARCH_H__
...@@ -40,10 +40,10 @@ private: ...@@ -40,10 +40,10 @@ private:
/* predictor */ /* predictor */
T2TPredictor predictor; T2TPredictor predictor;
/* max length of the generated sequence */ /* max length of the generated sequence */
int maxLength; int maxLength;
/* beam size */ /* beam size */
int beamSize; int beamSize;
...@@ -51,10 +51,10 @@ private: ...@@ -51,10 +51,10 @@ private:
int batchSize; int batchSize;
/* we keep the final hypotheses in a heap for each sentence in the batch. */ /* we keep the final hypotheses in a heap for each sentence in the batch. */
XHeap<MIN_HEAP, float> * fullHypos; XHeap<MIN_HEAP, float>* fullHypos;
/* array of the end symbols */ /* array of the end symbols */
int * endSymbols; int* endSymbols;
/* number of the end symbols */ /* number of the end symbols */
int endSymbolNum; int endSymbolNum;
...@@ -68,42 +68,42 @@ public: ...@@ -68,42 +68,42 @@ public:
/* de-constructor */ /* de-constructor */
~T2TSearch(); ~T2TSearch();
/* initialize the model */ /* initialize the model */
void Init(int argc, char ** argv); void Init(int argc, char** argv);
/* search for the most promising states */ /* search for the most promising states */
void Search(T2TModel * model, XTensor * input, XTensor * padding, XTensor * output); void Search(T2TModel* model, XTensor* input, XTensor* padding, XTensor* output);
/* preparation */ /* preparation */
void Prepare(int myBatchSize,int myBeamSize); void Prepare(int myBatchSize, int myBeamSize);
/* compute the model score for each hypothesis */ /* compute the model score for each hypothesis */
void Score(T2TStateBundle * prev, T2TStateBundle * beam); void Score(T2TStateBundle* prev, T2TStateBundle* beam);
/* generate token indices via beam pruning */ /* generate token indices via beam pruning */
void Generate(T2TStateBundle * beam); void Generate(T2TStateBundle* beam);
/* expand the search graph */ /* expand the search graph */
void Expand(T2TStateBundle * prev, T2TStateBundle * beam); void Expand(T2TStateBundle* prev, T2TStateBundle* beam);
/* collect hypotheses with ending symbol */ /* collect hypotheses with ending symbol */
void Collect(T2TStateBundle * beam); void Collect(T2TStateBundle* beam);
    /* fill the hypothesis heap with incomplete hypotheses */     /* fill the hypothesis heap with incomplete hypotheses */
void FillHeap(T2TStateBundle * beam); void FillHeap(T2TStateBundle* beam);
/* save the output sequences in a tensor */ /* save the output sequences in a tensor */
void Dump(XTensor * output); void Dump(XTensor* output);
/* check if the token is an end symbol */ /* check if the token is an end symbol */
bool IsEnd(int token); bool IsEnd(int token);
/* set end symbols for search */ /* set end symbols for search */
void SetEnd(const int * tokens, const int tokenNum); void SetEnd(const int* tokens, const int tokenNum);
/* make a mask to prevent duplicated entries in beam expansion for the first position */ /* make a mask to prevent duplicated entries in beam expansion for the first position */
XTensor MakeFirstMask(T2TStateBundle * beam); XTensor MakeFirstMask(T2TStateBundle* beam);
}; };
} }
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
* limitations under the License. * limitations under the License.
*/ */
/* /*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
*/ */
#include <math.h> #include <math.h>
#include "T2TUtility.h" #include "T2TUtility.h"
...@@ -44,23 +44,23 @@ T2TTester::~T2TTester() ...@@ -44,23 +44,23 @@ T2TTester::~T2TTester()
} }
/* initialize the model */ /* initialize the model */
void T2TTester::Init(int argc, char ** argv) void T2TTester::Init(int argc, char** argv)
{ {
LoadParamInt(argc, argv, "vsize", &vSize, 34040); LoadParamInt(argc, argv, "vsize", &vSize, 34040);
LoadParamInt(argc, argv, "vsizetgt", &vSizeTgt, vSize); LoadParamInt(argc, argv, "vsizetgt", &vSizeTgt, vSize);
LoadParamInt(argc, argv, "sentbatch", &sentBatch, 1); LoadParamInt(argc, argv, "sentbatch", &sentBatch, 1);
LoadParamBool(argc, argv, "sort", &batchLoader.sortBuffer, true); LoadParamBool(argc, argv, "sort", &batchLoader.sortBuffer, true);
seacher.Init(argc, argv); seacher.Init(argc, argv);
} }
/* /*
test the model test the model
>> fn - test data file >> fn - test data file
>> ofn - output data file >> ofn - output data file
>> model - model that is trained >> model - model that is trained
*/ */
void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) void T2TTester::Test(const char* fn, const char* ofn, T2TModel* model)
{ {
int wc = 0; int wc = 0;
int wordCount = 0; int wordCount = 0;
...@@ -86,7 +86,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -86,7 +86,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
int* seqs = new int[MILLION]; int* seqs = new int[MILLION];
batchLoader.Init(fn); batchLoader.Init(fn);
int count = 0; int count = 0;
while (!batchLoader.IsEmpty()) while (!batchLoader.IsEmpty())
...@@ -94,23 +94,23 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -94,23 +94,23 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
count++; count++;
wordCount = 0; wordCount = 0;
for (int i = 0; i < model->decoder->nlayer; ++i) { for (int i = 0; i < model->decoder->nlayer; ++i) {
model->decoder->selfCache[i].miss = true; model->decoder->selfAttCache[i].miss = true;
model->decoder->contextCache[i].miss = true; model->decoder->enDeAttCache[i].miss = true;
} }
vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID); vector<int> indices = batchLoader.LoadBatch(&batchEnc, &paddingEnc, sentBatch, devID);
XTensor output; XTensor output;
seacher.Search(model, &batchEnc, &paddingEnc, &output); seacher.Search(model, &batchEnc, &paddingEnc, &output);
output.Dump(stderr);
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < indices.size(); ++i) {
Result res; Result res;
XTensor sent, srcIdx, tgtIdx; XTensor sent, srcIdx, tgtIdx;
InitTensor1D(&srcIdx, 1, X_INT, output.devID); InitTensor1DV2(&srcIdx, 1, X_INT, output.devID);
int idx[]{i}; int idx[]{ i };
srcIdx.SetData(idx, 1); srcIdx.SetData(idx, 1);
InitTensor(&tgtIdx, &srcIdx); InitTensorV2(&tgtIdx, &srcIdx);
SetAscendingOrder(tgtIdx, 0); SetAscendingOrder(tgtIdx, 0);
sent = CopyIndexed(output, 0, srcIdx, tgtIdx); sent = CopyIndexed(output, 0, srcIdx, tgtIdx);
...@@ -127,9 +127,9 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -127,9 +127,9 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
if (batchCount % 1 == 0) { if (batchCount % 1 == 0) {
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, XPRINT3(0, stderr,
"[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n", "[INFO] elapsed=%.1fs, sentence=%d, sword=%d\n",
elapsed, sentCount, wordCount); elapsed, sentCount, wordCount);
} }
} }
...@@ -138,11 +138,11 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model) ...@@ -138,11 +138,11 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
for (auto res : batchLoader.resBuffer) { for (auto res : batchLoader.resBuffer) {
Dump(ofile, &res.values); Dump(ofile, &res.values);
} }
fclose(ofile); fclose(ofile);
delete[] seqs; delete[] seqs;
double elapsed = GetClockSec() - startT; double elapsed = GetClockSec() - startT;
XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, sent=%d)\n", elapsed, wordCountTotal, sentCount); XPRINT3(0, stderr, "[INFO] test finished (took %.1fs, word=%d, sent=%d)\n", elapsed, wordCountTotal, sentCount);
...@@ -153,7 +153,7 @@ dump the result into the file ...@@ -153,7 +153,7 @@ dump the result into the file
>> file - data file >> file - data file
>> output - output tensor >> output - output tensor
*/ */
void T2TTester::Dump(FILE * file, XTensor * output) void T2TTester::Dump(FILE* file, XTensor* output)
{ {
int seqLength = output->GetDim(-1); int seqLength = output->GetDim(-1);
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
* limitations under the License. * limitations under the License.
*/ */
/* /*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
* A week with no trips :) * A week with no trips :)
*/ */
#ifndef __T2TTESTER_H__ #ifndef __T2TTESTER_H__
#define __T2TTESTER_H__ #define __T2TTESTER_H__
...@@ -41,7 +41,7 @@ public: ...@@ -41,7 +41,7 @@ public:
/* batch size for sentences */ /* batch size for sentences */
int sentBatch; int sentBatch;
/* for batching */ /* for batching */
DataSet batchLoader; DataSet batchLoader;
...@@ -56,13 +56,13 @@ public: ...@@ -56,13 +56,13 @@ public:
~T2TTester(); ~T2TTester();
/* initialize the model */ /* initialize the model */
void Init(int argc, char ** argv); void Init(int argc, char** argv);
/* test the model */ /* test the model */
void Test(const char * fn, const char * ofn, T2TModel * model); void Test(const char* fn, const char* ofn, T2TModel* model);
/* dump the result into the file */ /* dump the result into the file */
void Dump(FILE * file, XTensor * output); void Dump(FILE* file, XTensor* output);
}; };
} }
......
...@@ -38,7 +38,7 @@ namespace transformer ...@@ -38,7 +38,7 @@ namespace transformer
{ {
/* entrance of the program */ /* entrance of the program */
int TransformerMain(int argc, const char ** argv); int TransformerMain(int argc, const char** argv);
} }
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "XList.h" #include "XList.h"
#include "XGlobal.h" #include "XGlobal.h"
/* the nts (NiuTrans.Tensor) namespace */ /* the nts (NiuTrans.Tensor) namespace */
namespace nts { namespace nts {
...@@ -363,6 +364,8 @@ template struct TensorListBase<long>; ...@@ -363,6 +364,8 @@ template struct TensorListBase<long>;
template struct TensorListBase<float>; template struct TensorListBase<float>;
template struct TensorListBase<short>; template struct TensorListBase<short>;
template struct TensorListBase<XTensor*>; template struct TensorListBase<XTensor*>;
template struct TensorListBase<uint64_t>;
template struct TensorListBase<void*>; template struct TensorListBase<void*>;
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include "XMem.h" #include "XMem.h"
#include "XGlobal.h" #include "XGlobal.h"
#include <cstdint>
#ifndef __TensorList_H__ #ifndef __TensorList_H__
#define __TensorList_H__ #define __TensorList_H__
...@@ -118,7 +120,14 @@ public: ...@@ -118,7 +120,14 @@ public:
void Shuffle(int nround = 10, int beg = -1, int len = 0); void Shuffle(int nround = 10, int beg = -1, int len = 0);
/* short */ /* short */
T& operator[] (int i) { return GetItem(i); }; T& operator[] (int i) {
CheckNTErrors(i >= -count && i < count, "Index of a list item is out of scope!");
CheckNTErrors(count > 0, "Cannt index the item in an empty list!");
if (i < 0)
return items[count + i];
else
return items[i];
};
T& Get(int i) { return GetItem(i); }; T& Get(int i) { return GetItem(i); };
void Set(int i, T item) { SetItem(i, item); }; void Set(int i, T item) { SetItem(i, item); };
}; };
...@@ -132,7 +141,7 @@ typedef TensorListBase<char*> StrList; ...@@ -132,7 +141,7 @@ typedef TensorListBase<char*> StrList;
typedef TensorListBase<long> LongList; typedef TensorListBase<long> LongList;
typedef TensorListBase<float> FloatList; typedef TensorListBase<float> FloatList;
typedef TensorListBase<short> ShortList; typedef TensorListBase<short> ShortList;
typedef TensorListBase<uint64_t> UInt64List;
typedef TensorListBase<XTensor*> TensorList; typedef TensorListBase<XTensor*> TensorList;
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
......
...@@ -86,7 +86,7 @@ void _funcCPUName(const XTensor * input, XTensor * output, int dim) ...@@ -86,7 +86,7 @@ void _funcCPUName(const XTensor * input, XTensor * output, int dim)
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip)+j * vecBufLength); \ vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip)+j * vecBufLength); \
} \ } \
for (int j = 1; j < strideNum / 32; j++) { \ for (int j = 1; j < strideNum / 32; j++) { \
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength); \ const DTYPE* ptr = (DTYPE*)(ip + j * 4 * vecBufLength); \
vecBuf[0] = vecBuf[0]._vectorOp(VectorBuffer::loadu(ptr + 0 * vecBufLength)); \ vecBuf[0] = vecBuf[0]._vectorOp(VectorBuffer::loadu(ptr + 0 * vecBufLength)); \
vecBuf[1] = vecBuf[1]._vectorOp(VectorBuffer::loadu(ptr + 1 * vecBufLength)); \ vecBuf[1] = vecBuf[1]._vectorOp(VectorBuffer::loadu(ptr + 1 * vecBufLength)); \
vecBuf[2] = vecBuf[2]._vectorOp(VectorBuffer::loadu(ptr + 2 * vecBufLength)); \ vecBuf[2] = vecBuf[2]._vectorOp(VectorBuffer::loadu(ptr + 2 * vecBufLength)); \
...@@ -106,7 +106,7 @@ void _funcCPUName(const XTensor * input, XTensor * output, int dim) ...@@ -106,7 +106,7 @@ void _funcCPUName(const XTensor * input, XTensor * output, int dim)
else { \ else { \
/* data is separated */ \ /* data is separated */ \
for(int i = 0; i < blockNum; i++){ \ for(int i = 0; i < blockNum; i++){ \
for(int j = 0; j < input->dimSize[input->order - 1] / 32; j++){ \ for(int j = 0; j < stride / 32; j++){ \
DTYPE * ip = (DTYPE*)input->data + blockSize * i; \ DTYPE * ip = (DTYPE*)input->data + blockSize * i; \
DTYPE * op = (DTYPE*)output->data + stride * i; \ DTYPE * op = (DTYPE*)output->data + stride * i; \
VectorBuffer vecBuf[4]; \ VectorBuffer vecBuf[4]; \
......
...@@ -42,7 +42,7 @@ void _ReduceMean(const XTensor * input, XTensor * output, int dim) ...@@ -42,7 +42,7 @@ void _ReduceMean(const XTensor * input, XTensor * output, int dim)
int num = input->dimSize[dim]; int num = input->dimSize[dim];
_ReduceSum(input, output, dim); _ReduceSum(input, output, dim);
_ScaleAndShiftMe(output, (DTYPE)1/num, 0); _ScaleAndShiftMe(output, 1.0F/(DTYPE)(num), 0);
} }
/* /*
......
...@@ -105,7 +105,7 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor ...@@ -105,7 +105,7 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip) + j * vecBufLength, isExp, power, bias); vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip) + j * vecBufLength, isExp, power, bias);
} }
for(int j = 1; j < strideNum / 32; j++){ for(int j = 1; j < strideNum / 32; j++){
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength); const DTYPE* ptr = (DTYPE*)(ip + (j * 4) * vecBufLength);
vecBuf[0] = vecBuf[0] + VectorBuffer::loadu(ptr + 0 * vecBufLength, isExp, power, bias); vecBuf[0] = vecBuf[0] + VectorBuffer::loadu(ptr + 0 * vecBufLength, isExp, power, bias);
vecBuf[1] = vecBuf[1] + VectorBuffer::loadu(ptr + 1 * vecBufLength, isExp, power, bias); vecBuf[1] = vecBuf[1] + VectorBuffer::loadu(ptr + 1 * vecBufLength, isExp, power, bias);
vecBuf[2] = vecBuf[2] + VectorBuffer::loadu(ptr + 2 * vecBufLength, isExp, power, bias); vecBuf[2] = vecBuf[2] + VectorBuffer::loadu(ptr + 2 * vecBufLength, isExp, power, bias);
...@@ -122,7 +122,7 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor ...@@ -122,7 +122,7 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
} else{ } else{
//data is separated //data is separated
for(int i = 0; i < blockNum; i++){ for(int i = 0; i < blockNum; i++){
for(int j = 0; j < input->dimSize[input->order - 1] / 32; j++){ for(int j = 0; j < stride / 32; j++){
DTYPE * ip = (DTYPE*)input->data + blockSize * i; DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + stride * i; DTYPE * op = (DTYPE*)output->data + stride * i;
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + stride * i : NULL; DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + stride * i : NULL;
...@@ -133,8 +133,7 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor ...@@ -133,8 +133,7 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
} }
VectorBuffer vecBuf[4]; VectorBuffer vecBuf[4];
for(int k = 0; k < 4; k++){ for(int k = 0; k < 4; k++){
vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE), isExp, power, bias + j * 32 / sizeof(DTYPE)); vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE), isExp, power, bias + k * 32 / sizeof(DTYPE));
} }
for(int k = 1; k < strideNum; k++){ for(int k = 1; k < strideNum; k++){
DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength; DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论