Commit 005739bc by xiaotong

add device id and mem pool to each module

parent 2058ca57
......@@ -43,11 +43,14 @@ T2TAttention::~T2TAttention()
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> devID - device id
>> mem - the memory pool
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TAttention::InitModel(int argc, const char ** argv, int devID, XMem * mem)
void T2TAttention::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "nhead", &nhead, 8);
......@@ -61,4 +64,4 @@ void T2TAttention::InitModel(int argc, const char ** argv, int devID, XMem * mem
w.SetDataRand(-minmax, minmax);
}
}
\ No newline at end of file
}
......@@ -39,6 +39,12 @@ where head_i = Attention(Q * w_i^Q, K * w_i^K, V * w_i^V)
class T2TAttention
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
/* head number */
int nhead;
......@@ -62,9 +68,9 @@ public:
~T2TAttention();
/* initialize the model */
void InitModel(int argc, const char ** argv, int devID = -1, XMem * mem = NULL);
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
};
}
#endif
\ No newline at end of file
#endif
......@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include <math.h>
#include "T2TEncoder.h"
#include "T2TUtility.h"
......@@ -39,11 +40,14 @@ AttEncoder::~AttEncoder()
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> devID - device id
>> mem - the memory pool
>> myDevID - device id
>> myMem - the memory pool
*/
void AttEncoder::InitModel(int argc, const char ** argv, int devID, XMem * mem)
void AttEncoder::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
LoadParamInt(argc, argv, "nstack", &nlayer, 6);
LoadParamInt(argc, argv, "hsize", &hSize, 512);
LoadParamInt(argc, argv, "esize", &eSize, 512);
......
......@@ -98,7 +98,7 @@ public:
~AttEncoder();
/* initialize the model */
void InitModel(int argc, const char ** argv, int devID = -1, XMem * mem = NULL);
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the encoding network */
XTensor * Make(XTensor * input);
......@@ -107,4 +107,4 @@ public:
}
#endif
\ No newline at end of file
#endif
......@@ -42,11 +42,14 @@ T2TFNN::~T2TFNN()
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> devID - device id
>> mem - the memory pool
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TFNN::InitModel(int argc, const char ** argv, int devID, XMem * mem)
void T2TFNN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
float minmax = 0;
LoadParamInt(argc, argv, "d", &inSize, 512);
......
......@@ -69,10 +69,10 @@ public:
~T2TFNN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int devID = -1, XMem * mem = NULL);
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
};
}
#endif
\ No newline at end of file
#endif
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
*/
#include "T2TLayerNormal.h"
namespace transformer
{
/* constructor */
T2TLN::T2TLN()
{
devID = -1;
mem = NULL;
}
/* de-constructor */
T2TLN::~T2TLN()
{
}
/*
initialize the model
>> argc - number of arguments
>> argv - list pf pointers to the arguments
>> myDevID - device id
>> myMem - the memory pool
*/
void T2TLN::InitModel(int argc, const char ** argv, int myDevID, XMem * myMem)
{
devID = myDevID;
mem = myMem;
}
/* make the network */
XTensor * T2TLN::Make(XTensor * input)
{
return NULL;
}
}
......@@ -22,14 +22,36 @@
#ifndef __T2TFNN_H__
#define __T2TFNN_H__
#include "../../network/XNet.h"
using namespace nts;
namespace transformer
{
class T2TLN
{
public:
/* device id */
int devID;
/* memory pool */
XMem * mem;
public:
/* constructor */
T2TLN();
/* de-constructor */
~T2TLN();
/* initialize the model */
void InitModel(int argc, const char ** argv, int myDevID = -1, XMem * myMem = NULL);
/* make the network */
XTensor * Make(XTensor * input);
};
}
#endif
\ No newline at end of file
#endif
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论