Commit 005739bc by xiaotong

add device id and mem pool to each module

parent 2058ca57
...@@ -43,11 +43,14 @@ T2TAttention::~T2TAttention() ...@@ -43,11 +43,14 @@ T2TAttention::~T2TAttention()
initialize the model initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list pf pointers to the arguments >> argv - list pf pointers to the arguments
>> devID - device id >> myDevID - device id
>> mem - the memory pool >> myMem - the memory pool
*/ */
void T2TAttention::InitModel(int argc, const char ** argv, int devID, XMem * mem) void T2TAttention::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{ {
devID = myDevID;
mem = myMem;
float minmax = 0; float minmax = 0;
LoadParamInt(argc, argv, "nhead", &nhead, 8); LoadParamInt(argc, argv, "nhead", &nhead, 8);
...@@ -61,4 +64,4 @@ void T2TAttention::InitModel(int argc, const char ** argv, int devID, XMem * mem ...@@ -61,4 +64,4 @@ void T2TAttention::InitModel(int argc, const char ** argv, int devID, XMem * mem
w.SetDataRand(-minmax, minmax); w.SetDataRand(-minmax, minmax);
} }
} }
\ No newline at end of file
...@@ -39,6 +39,12 @@ where head_i = Attention(Q * w_i^Q, K * w_i^K, V * w_i^V) ...@@ -39,6 +39,12 @@ where head_i = Attention(Q * w_i^Q, K * w_i^K, V * w_i^V)
class T2TAttention class T2TAttention
{ {
public: public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* head number */ /* head number */
int nhead; int nhead;
...@@ -62,9 +68,9 @@ public: ...@@ -62,9 +68,9 @@ public:
~T2TAttention(); ~T2TAttention();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, const char ** argv, int devID = -1, XMem * mem = NULL); void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
}; };
} }
#endif #endif
\ No newline at end of file
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/ */
#include <math.h>
#include "T2TEncoder.h" #include "T2TEncoder.h"
#include "T2TUtility.h" #include "T2TUtility.h"
...@@ -39,11 +40,14 @@ AttEncoder::~AttEncoder() ...@@ -39,11 +40,14 @@ AttEncoder::~AttEncoder()
initialize the model initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list pf pointers to the arguments >> argv - list pf pointers to the arguments
>> devID - device id >> myDevID - device id
>> mem - the memory pool >> myMem - the memory pool
*/ */
void AttEncoder::InitModel(int argc, const char ** argv, int devID, XMem * mem) void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{ {
devID = myDevID;
mem = myMem;
LoadParamInt(argc, argv, "nstack", &nlayer, 6); LoadParamInt(argc, argv, "nstack", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, 512); LoadParamInt(argc, argv, "hsize", &hSize, 512);
LoadParamInt(argc, argv, "esize", &eSize, 512); LoadParamInt(argc, argv, "esize", &eSize, 512);
......
...@@ -98,7 +98,7 @@ public: ...@@ -98,7 +98,7 @@ public:
~AttEncoder(); ~AttEncoder();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, const char ** argv, int devID = -1, XMem * mem = NULL); void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */ /* make the encoding network */
XTensor * Make(XTensor * input); XTensor * Make(XTensor * input);
...@@ -107,4 +107,4 @@ public: ...@@ -107,4 +107,4 @@ public:
} }
#endif #endif
\ No newline at end of file
...@@ -42,11 +42,14 @@ T2TFNN::~T2TFNN() ...@@ -42,11 +42,14 @@ T2TFNN::~T2TFNN()
initialize the model initialize the model
>> argc - number of arguments >> argc - number of arguments
>> argv - list pf pointers to the arguments >> argv - list pf pointers to the arguments
>> devID - device id >> myDevID - device id
>> mem - the memory pool >> myMem - the memory pool
*/ */
void T2TFNN::InitModel(int argc, const char ** argv, int devID, XMem * mem) void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{ {
devID = myDevID;
mem = myMem;
float minmax = 0; float minmax = 0;
LoadParamInt(argc, argv, "d", &inSize, 512); LoadParamInt(argc, argv, "d", &inSize, 512);
......
...@@ -69,10 +69,10 @@ public: ...@@ -69,10 +69,10 @@ public:
~T2TFNN(); ~T2TFNN();
/* initialize the model */ /* initialize the model */
void InitModel(int argc, const char ** argv, int devID = -1, XMem * mem = NULL); void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
}; };
} }
#endif #endif
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "T2TLayerNormal.h"
namespace transformer
{
/* constructor */
T2TLN::T2TLN()
{
devID = -1;
mem = NULL;
}
/* de-constructor */
T2TLN::~T2TLN()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TLN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
}
/* make the network */
XTensor * T2TLN::Make(XTensor * input)
{
return NULL;
}
}
...@@ -22,14 +22,36 @@ ...@@ -22,14 +22,36 @@
#ifndef __T2TFNN_H__ #ifndef __T2TFNN_H__
#define __T2TFNN_H__ #define __T2TFNN_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer namespace transformer
{ {
class T2TLN class T2TLN
{ {
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
public:
/* constructor */
T2TLN();
/* de-constructor */
~T2TLN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * input);
}; };
} }
#endif #endif
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论