Commit d221ef9d by xuchen

merge with zhangyuhao branch

parent e0d86e5b
......@@ -300,6 +300,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id
if(h == NULL)
return;
if (!t1->enableGrad)
return;
TensorList list(2);
list.Add((XTensor*)t1);
list.Add((XTensor*)t2);
......@@ -320,6 +323,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3,
if (h == NULL)
return;
if (!t1->enableGrad || !t2->enableGrad)
return;
TensorList list(3);
list.Add((XTensor*)t1);
list.Add((XTensor*)t2);
......@@ -370,6 +376,9 @@ create a hyper edge with an input tensor and a list of output tensors
*/
void XLink::MakeLink(XTensor * t, TensorList * list, int id)
{
if (!t->enableGrad)
return;
/* forward */
for(int i = 0; i < list->count; i++){
XTensor * h = (XTensor*)list->GetItem(i);
......
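The early return added to MakeLink skips graph construction whenever the first input does not require gradients, and the rewritten wrappers throughout this commit apply the same test before calling MakeLink at all. A minimal sketch of the observable effect, assuming the library's InitTensor2D helper (tensor names are hypothetical):

/* freezing an input suppresses the backward edge but not the forward value */
XTensor a, b;
InitTensor2D(&a, 2, 2, X_FLOAT);
InitTensor2D(&b, 2, 2, X_FLOAT);
a.enableGrad = false;     /* a no longer requires gradients */
XTensor c = Sum(a, b);    /* computed as usual, but no MATH_SUM link is attached to c */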
......@@ -23,15 +23,11 @@
*
*/
#include "XList.h"
#include "time.h"
#include "XMem.h"
#include "XList.h"
#include "XGlobal.h"
#include <ctime>
#include <utility>
#include <algorithm>
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
......@@ -78,7 +74,8 @@ TensorListBase<T>::TensorListBase(int myMaxNum, XMem* myMem)
template <typename T>
TensorListBase<T>::~TensorListBase()
{
delete[] items;
/* the array is owned by the list only when no memory pool is used */
if(items && !mem)
delete[] items;
}
......@@ -103,6 +100,13 @@ void TensorListBase<T>::Add(T&& item)
items[count++] = item;
}
/* return number of elements */
template<typename T>
size_t TensorListBase<T>::Size()
{
return count;
}
/*
add an item into the list
>> item - a const reference to the item
......@@ -130,7 +134,7 @@ add a number of items into the list
>> inputItemCount - number of input items
*/
template <typename T>
void TensorListBase<T>::Add(T* inputItems, int inputItemCount)
void TensorListBase<T>::Add(const T* inputItems, int inputItemCount)
{
if (count + inputItemCount >= maxNum) {
int newMaxNum = (count + inputItemCount) * 2 + 1;
......@@ -206,10 +210,10 @@ void TensorListBase<T>::Insert(int pos, T&& item)
template <typename T>
T& TensorListBase<T>::GetItem(int i) const
{
CheckNTErrors(i >= -1 && i < count, "Index of a list item is out of scope!");
CheckNTErrors(i >= -count && i < count, "Index of a list item is out of scope!");
CheckNTErrors(count > 0, "Cannot index the item in an empty list!");
if (i == -1)
return items[count - 1];
if (i < 0)
return items[count + i];
else
return items[i];
}
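GetItem now accepts Python-style negative indices: -1 addresses the last element and -count the first. A small usage sketch (t0, t1, t2 are hypothetical tensor pointers):

TensorList list(4);
list.Add(t0); list.Add(t1); list.Add(t2);   /* count == 3, Size() == 3 */
XTensor * last  = list[-1];                 /* items[3 - 1], i.e. t2 */
XTensor * first = list[-3];                 /* items[3 - 3], i.e. t0 */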
......@@ -226,7 +230,7 @@ template<typename T>
inline void TensorListBase<T>::SetItem(int i, T&& item)
{
if (i >= 0 && i < count)
items[i] = std::move(item);
items[i] = item;
}
/*
......@@ -245,6 +249,26 @@ inline int TensorListBase<T>::FindFirst(const T& item)
return -1;
}
template <>
inline int TensorListBase<Example>::FindFirst(const Example& item)
{
for (int i = 0; i < count; i++) {
if (item.id == items[i].id)
return i;
}
return -1;
}
template <>
inline int TensorListBase<Result>::FindFirst(const Result& item)
{
for (int i = 0; i < count; i++) {
if (item.id == items[i].id)
return i;
}
return -1;
}
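The generic FindFirst compares whole items, which Example and Result cannot support because their IntList member has no equality operator; the two specializations above therefore match on the id field alone. A minimal sketch, assuming the default constructor leaves an IntList empty:

Example e;
e.id = 7;
ExampleList samples(8);
samples.Add(e);

Example probe;
probe.id = 7;                          /* the data field is ignored */
int pos = samples.FindFirst(probe);    /* returns 0: the ids match */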
/* clear the data array */
template <typename T>
void TensorListBase<T>::Clear()
......@@ -294,6 +318,17 @@ void TensorListBase<T>::Remove(int i)
count--;
}
template<typename T>
void TensorListBase<T>::Reserve(int n)
{
if (items) {
/* reserve fails if the data array has already been allocated */
return;
}

items = new T[n];
maxNum = n;
}
/*
copy the list
>> myMem - memory pool used for allocating the data in the new list
......@@ -348,6 +383,8 @@ template struct TensorListBase<long>;
template struct TensorListBase<float>;
template struct TensorListBase<short>;
template struct TensorListBase<XTensor*>;
template struct TensorListBase<Result>;
template struct TensorListBase<Example>;
template struct TensorListBase<void*>;
} /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
......@@ -66,11 +66,14 @@ public:
/* add an item into the list */
void Add(T&& item);
/* return number of elements */
size_t Size();
/* add an item into the list */
void Add(const T& item);
/* add a number of items into the list */
void Add(T* inputItems, int inputItemCount);
void Add(const T* inputItems, int inputItemCount);
/* append a list to the current list */
void AddList(TensorListBase* l);
......@@ -105,6 +108,9 @@ public:
/* remove the item at position i */
void Remove(int i);
/* reserve space for data entry */
void Reserve(int n);
/* copy the list */
TensorListBase* Copy(XMem* myMem);
......@@ -112,22 +118,33 @@ public:
void Shuffle(int nround = 10, int beg = -1, int len = 0);
/* short */
T& operator[] (int i) {
return GetItem(i);
};
T& operator[] (int i) { return GetItem(i); };
T& Get(int i) { return GetItem(i); };
void Set(int i, T item) { SetItem(i, item); };
};
struct XTensor;
typedef TensorListBase<void*> XList;
typedef TensorListBase<int> IntList;
typedef TensorListBase<char> CharList;
typedef TensorListBase<char*> StrList;
typedef TensorListBase<long> LongList;
typedef TensorListBase<float> FloatList;
typedef TensorListBase<short> ShortList;
typedef TensorListBase<void*> XList;
struct Example {
int id;
IntList data;
};
struct Result {
int id;
IntList data;
};
typedef TensorListBase<Result> ResultList;
typedef TensorListBase<Example> ExampleList;
typedef TensorListBase<XTensor*> TensorList;
} /* end of the nts (NiuTrans.Tensor) namespace */
......
......@@ -1916,6 +1916,26 @@ void XTensor::Dump(const XTensor * tensor, FILE * file, const char * label, cons
}
/*
dump data to a binary file
>> file - where to dump the data
*/
void XTensor::BinaryDump(FILE* file)
{
XTensor tmp;
InitTensorOnCPU(&tmp, this);
_CopyValues(this, &tmp);
switch (dataType) {
case X_INT: {
fwrite(tmp.data, sizeof(int), unitNum, file);
break;
}
default: {
fwrite(tmp.data, sizeof(float), unitNum, file);
}
}
}
/*
read data from a file
>> file - where to load the data
>> label - label of the tensor
......@@ -2027,6 +2047,30 @@ void XTensor::Read(FILE * file, const char * label)
delete[](char*)dataBuf;
}
/*
read data from a binary file
>> file - the file stream pointer
>> offset - the distance from the start of the file to this tensor's data
*/
void XTensor::BinaryRead(FILE* file, size_t offset)
{
fseek(file, (long)offset, SEEK_SET);
switch (dataType) {
case X_INT: {
int * d = new int[unitNum];
fread(d, sizeof(int), unitNum, file);
SetData(d, unitNum);
delete[] d;
break;
}
default: {
float * d = new float[unitNum];
fread(d, sizeof(float), unitNum, file);
SetData(d, unitNum);
delete[] d;
}
}
}
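Together, BinaryDump and BinaryRead form a raw round trip: the dump writes unitNum elements with no header, so the reader must already know the tensor's shape and data type. A usage sketch (the file name is hypothetical):

XTensor t;
InitTensor2D(&t, 2, 3, X_FLOAT);
t.SetDataRand(-1.0F, 1.0F);

FILE * fp = fopen("tensor.bin", "wb");
t.BinaryDump(fp);
fclose(fp);

XTensor u;
InitTensor2D(&u, 2, 3, X_FLOAT);    /* shape and type must match the dump */
fp = fopen("tensor.bin", "rb");
u.BinaryRead(fp, 0);                /* this tensor's data starts at offset 0 */
fclose(fp);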
/*
flush the data to the target device
>> targetMem - memory pool on the target device
......
......@@ -433,9 +433,15 @@ public:
static
void Dump(const XTensor * tensor, FILE * file, const char * label = NULL, const int n = -1, const int beg = 0, const int verbose = 0);
/* dump data to a binary file */
void BinaryDump(FILE * file);
/* read data from a file */
void Read(FILE * file, const char * label = NULL);
/* read data from a binary file */
void BinaryRead(FILE * file, size_t offset);
/* flush the data to the target device */
void FlushToMem(XMem * targetMem);
......
......@@ -215,18 +215,22 @@ XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
_Div(&a, &b, &c, alpha, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
}
else if(n >= 0 && n < a.order){
/* call _DivDim function */
_DivDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
}
else{
ShowNTErrors("Something is wrong!");
......@@ -261,7 +265,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
/* call _Div function */
_Div(&a, &b, &c, 0, leadingDim);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHead(&c, alpha);
......@@ -272,7 +276,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadin
/* call _DivDim function */
_DivDim(&a, &b, &c, n, alpha);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
......
......@@ -164,10 +164,12 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
_DivDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
return c;
}
......@@ -193,7 +195,7 @@ void DivDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE alpha)
/* call _Div function */
_DivDim(&a, &b, &c, n, alpha);
if (c.enableGrad == true) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
XLink::AddParamToHeadInt(&c, n);
......
......@@ -155,8 +155,10 @@ XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha)
_Mask(&a, &mask, &c, alpha);
/* tensor connections */
XLink::MakeLink(&a, &mask, &c, MATH_MASK);
XLink::AddParamToHead(&c, alpha);
if (a.enableGrad) {
XLink::MakeLink(&a, &mask, &c, MATH_MASK);
XLink::AddParamToHead(&c, alpha);
}
return c;
}
......@@ -176,7 +178,7 @@ void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha)
/* call _Mask function */
_Mask(&a, &mask, &c, alpha);
if (c.enableGrad) {
if (a.enableGrad) {
XLink::MakeLink(&a, &mask, &c, MATH_MASK);
XLink::AddParamToHead(&c, alpha);
}
......
......@@ -296,10 +296,12 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);
}
/* destroy variables */
delete[] dimSize;
......@@ -344,7 +346,7 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
/* call _MatrixMul function */
_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, beta, parallelRunner);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, transposedA);
......@@ -393,10 +395,12 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
_MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha);
}
/* destroy variables */
delete[] dimSize;
......@@ -440,7 +444,7 @@ void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
/* call _MatrixMul function */
_MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
......
......@@ -314,10 +314,12 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const
_MatrixMulBatched(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);
}
/* destroy variables */
delete[] dimSize;
......@@ -370,10 +372,12 @@ XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
_MatrixMulBatched(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHead(&c, alpha);
}
/* destroy variables */
delete[] dimSize;
......
......@@ -118,11 +118,87 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
}
/* tensor connections */
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
//XLink::AddParamToHead(&c, beta);
if (w.enableGrad && b.enableGrad) {
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
XLink::AddParamToHeadTrans(&c, X_NOTRANS);
}
/* destroy variables */
delete[] dimSize;
DelTensorBuf(tmp);
return c;
}
/*
operation c = x * w + b (MulAndShift), where x and w may be transposed
>> x - tensor x
>> transposedA - indicates whether x is transposed
>> w - tensor w
>> transposedB - indicates whether w is transposed
>> b - tensor b
>> alpha - scalar factor applied to x * w
>> parallelRunner - parallel processing module
<< return - the result of the matrix multiplication plus shift
*/
XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA,
const XTensor& w, MATRIX_TRANS_TYPE transposedB,
const XTensor& b, DTYPE alpha, XPRunner* parallelRunner)
{
CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have an order >= 2!");
int xn = transposedA == X_TRANS ? x.dimSizeRDI[0] : x.dimSizeRDI[1];
int xm = transposedA == X_TRANS ? x.dimSizeRDI[1] : x.dimSizeRDI[0];
int wn = transposedB == X_TRANS ? w.dimSizeRDI[0] : w.dimSizeRDI[1];
int wm = transposedB == X_TRANS ? w.dimSizeRDI[1] : w.dimSizeRDI[0];
int order = x.order + w.order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < x.order; i++)
dimSize[sub++] = x.dimSizeRDI[x.order + 1 - i];
for (int i = 2; i < w.order; i++)
dimSize[sub++] = w.dimSizeRDI[w.order + 1 - i];
dimSize[sub++] = xn;
dimSize[sub++] = wm;
float dr = (!x.isSparse || !w.isSparse) ? 1.0F : MAX(x.denseRatio, w.denseRatio);
XTensor * tmp = NewTensorBuf(order, dimSize, x.dataType, dr, x.devID, x.mem);
/* call _MatrixMul function */
_MatrixMul(&x, transposedA, &w, transposedB, tmp, alpha, 0, parallelRunner);
XTensor c(tmp);
c.SetTMPFlag();
int n = GetSumIndex(tmp, b);
if (n == -1) {
/* call _Sum function */
_Sum(tmp, &b, &c);
// TODO!!
ShowNTErrors("TODO!");
}
else if (n >= 0 && n < tmp->order) {
/* call _SumDim function */
_SumDim(tmp, &b, &c, n);
}
else {
ShowNTErrors("Something is wrong!");
}
/* tensor connections */
if (w.enableGrad && b.enableGrad) {
XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
}
/* destroy variables */
delete[] dimSize;
......
......@@ -31,6 +31,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
XTensor MulAndShift(const XTensor &x, MATRIX_TRANS_TYPE transposedA,
const XTensor &w, MATRIX_TRANS_TYPE transposedB,
const XTensor &b, DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
......
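The new overload allows either operand of the multiply-add to be transposed, e.g. c = x^T * w + b. A usage sketch (shapes are hypothetical):

XTensor x, w, b;
InitTensor2D(&x, 4, 8, X_FLOAT);                         /* x: 4 x 8 */
InitTensor2D(&w, 4, 2, X_FLOAT);                         /* w: 4 x 2 */
InitTensor1D(&b, 2, X_FLOAT);                            /* b: 2     */
XTensor c = MulAndShift(x, X_TRANS, w, X_NOTRANS, b);    /* c: 8 x 2 */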
......@@ -216,18 +216,22 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim
_Multiply(&a, &b, &c, 0, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHeadInt(&c, leadingDim);
}
}
else if(n >= 0 && n < a.order){
/* call _MultiplyDim function */
_MultiplyDim(&a, &b, &c, n, alpha);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, alpha);
}
}
else{
ShowNTErrors("Something is wrong!");
......@@ -262,7 +266,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
/* call _Multiply function */
_Multiply(&a, &b, &c, 0, leadingDim);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
XLink::AddParamToHead(&c, alpha);
......@@ -273,7 +277,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int l
/* call _MultiplyDim function */
_MultiplyDim(&a, &b, &c, n, alpha);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
......
......@@ -180,9 +180,11 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n)
_MultiplyDim(&a, &b, &c, n, 0);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, 0);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, 0);
}
return c;
}
......@@ -208,7 +210,7 @@ void MultiplyDim(const XTensor &a, const XTensor &b, XTensor &c, int n)
/* call _Multiply function */
_MultiplyDim(&a, &b, &c, n, 0);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
XLink::AddParamToHeadInt(&c, n);
......@@ -350,8 +352,10 @@ XTensor MultiplyBroadcast(const XTensor &a, const XTensor &b)
_MultiplyBroadcast(&a, &b, &c, 0);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST);
XLink::AddParamToHead(&c, 0);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST);
XLink::AddParamToHead(&c, 0);
}
return c;
}
......@@ -374,7 +378,7 @@ void MultiplyBroadcast(const XTensor &a, const XTensor &b, XTensor &c)
/* call _SumBroadcast function */
_MultiplyBroadcast(&a, &b, &c, 0);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST);
XLink::AddParamToHead(&c, 0);
......
......@@ -190,17 +190,21 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
_Sub(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
}
}
else if(n >= 0 && n < a.order){
/* call _SubDim function */
_SubDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
}
else{
ShowNTErrors("Something is wrong!");
......@@ -229,7 +233,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
/* call _Sub function */
_Sub(&a, &b, &c, beta);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
......@@ -239,7 +243,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
/* call _SubDim function */
_SubDim(&a, &b, &c, n, beta);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
......
......@@ -164,9 +164,11 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
_SubDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
return c;
}
......@@ -193,7 +195,7 @@ void SubDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
/* call _Sub function */
_SubDim(&a, &b, &c, n, beta);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
XLink::AddParamToHeadInt(&c, n);
......
......@@ -224,17 +224,21 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
_Sum(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta);
}
}
else if(n >= 0 && n < a.order){
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
}
else{
ShowNTErrors("Something is wrong!");
......@@ -261,9 +265,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
if (n == -1) {
/* call _Sum function */
_Sum(&a, &b, &c, beta);
if (c.enableGrad) {
/* tensor connections */
/* tensor connections */
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUM);
XLink::AddParamToHead(&c, beta);
}
......@@ -271,9 +275,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
else if (n >= 0 && n < a.order) {
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
if (c.enableGrad) {
/* tensor connections */
/* tensor connections */
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
......
......@@ -181,9 +181,11 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
_SumDim(&a, &b, &c, n, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
XLink::AddParamToHead(&c, beta);
}
return c;
}
......@@ -210,7 +212,7 @@ void SumDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
/* call _SumDim function */
_SumDim(&a, &b, &c, n, beta);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
XLink::AddParamToHeadInt(&c, n);
......@@ -353,9 +355,11 @@ XTensor SumBroadcast(const XTensor &a, const XTensor &b, DTYPE beta)
_SumBroadcast(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST);
XLink::AddParamToHead(&c, beta);
if (a.enableGrad && b.enableGrad) {
XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST);
XLink::AddParamToHead(&c, beta);
}
return c;
}
......@@ -377,7 +381,7 @@ void SumBroadcast(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
/* call _SumBroadcast function */
_SumBroadcast(&a, &b, &c, beta);
if (c.enableGrad) {
if (a.enableGrad && b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST);
XLink::AddParamToHead(&c, beta);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XTensor.h"
#include "../../XName.h"
......@@ -121,7 +121,8 @@ XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType)
_ConvertDataType(&input, &output);
/* tensor connection */
XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
if(input.enableGrad)
XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
return output;
}
......@@ -136,7 +137,7 @@ void ConvertDataType(const XTensor & input, XTensor & output, TENSOR_DATA_TYPE d
_ConvertDataType(&input, &output);
/* tensor connection */
if (output.enableGrad)
if (input.enableGrad)
XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
}
......
......@@ -117,10 +117,12 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
_SelectRange(&a, &c, dim, low, high);
/* tensor connection */
XLink::MakeLink(&a, NULL, &c, GETANDSET_SELECT);
XLink::AddParamToHeadInt(&c, dim);
XLink::AddParamToHeadInt(&c, low);
XLink::AddParamToHeadInt(&c, high);
if (a.enableGrad) {
XLink::MakeLink(&a, NULL, &c, GETANDSET_SELECT);
XLink::AddParamToHeadInt(&c, dim);
XLink::AddParamToHeadInt(&c, low);
XLink::AddParamToHeadInt(&c, high);
}
/* destroy variables */
delete[] dimSize;
......
......@@ -167,7 +167,9 @@ XTensor funcName(const XTensor &a, T num)
XTensor b(&a); \
b.SetTMPFlag(); \
_funcName(&a, &b, num); \
XLink::MakeLink(&a, NULL, &b, operationId); \
if(a.enableGrad){ \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
XLink::AddParamToHead(&b, num); \
return b; \
} \
......@@ -183,7 +185,7 @@ void funcName(const XTensor &a, XTensor &b, T num)
InitTensor(&b, &a); \
} \
_funcName(&a, &b, num); \
if (b.enableGrad) { \
if (a.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
XLink::AddParamToHead(&b, num); \
} \
......
......@@ -67,7 +67,7 @@ keep the result in the input tensor a and return nothing
*/
void _ClipMe(XTensor * a, DTYPE lower, DTYPE upper)
{
_Clip(a, a, lower, upper);
_Clip(a, a, lower, upper);
}
/*
......@@ -92,18 +92,20 @@ make a new tensor to keep the result and return it
*/
XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper)
{
XTensor b(&a);
b.SetTMPFlag();
XTensor b(&a);
b.SetTMPFlag();
/* call _Clip function */
_Clip(&a, &b, lower, upper);
/* call _Clip function */
_Clip(&a, &b, lower, upper);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
XLink::AddParamToHead(&b, lower);
XLink::AddParamToHead(&b, upper);
/* tensor connections */
if (a.enableGrad) {
XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
XLink::AddParamToHead(&b, lower);
XLink::AddParamToHead(&b, upper);
}
return b;
return b;
}
void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper)
......@@ -115,8 +117,8 @@ void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper)
/* call _Clip function */
_Clip(&a, &b, lower, upper);
if (b.enableGrad) {
/* tensor connections */
/* tensor connections */
if (a.enableGrad) {
XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
XLink::AddParamToHead(&b, lower);
XLink::AddParamToHead(&b, upper);
......
......@@ -173,9 +173,11 @@ XTensor Normalize(const XTensor &input, int dim,
list.Add((XTensor*)&var);
list.Add((XTensor*)&a);
list.Add((XTensor*)&b);
XLink::MakeLink(&list, &output, MATH_NORMALIZE);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, epsilon);
if (input.enableGrad) {
XLink::MakeLink(&list, &output, MATH_NORMALIZE);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, epsilon);
}
return output;
}
......@@ -208,7 +210,7 @@ void Normalize(const XTensor &input, XTensor &output, int dim,
/* call _Normalize function */
_Normalize(&input, &output, dim, &mean, &var, &a, &b, epsilon);
if (output.enableGrad == true) {
if (input.enableGrad == true) {
/* tensor connections */
TensorList list(5);
list.Add((XTensor*)&input);
......
......@@ -126,9 +126,11 @@ XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
_ScaleAndShift(&a, &b, scale, shift);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
XLink::AddParamToHead(&b, scale);
XLink::AddParamToHead(&b, shift);
if (a.enableGrad) {
XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
XLink::AddParamToHead(&b, scale);
XLink::AddParamToHead(&b, shift);
}
return b;
}
......@@ -152,7 +154,7 @@ void ScaleAndShift(const XTensor & a, XTensor & b, DTYPE scale, DTYPE shift)
/* call _ScaleAndShift function */
_ScaleAndShift(&a, &b, scale, shift);
if (b.enableGrad) {
if (a.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
XLink::AddParamToHead(&b, scale);
......
......@@ -151,7 +151,9 @@ XTensor funcName(const XTensor & a)
XTensor b(&a); \
b.SetTMPFlag(); \
_funcName(&a, &b); \
XLink::MakeLink(&a, NULL, &b, operationId); \
if(a.enableGrad){ \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
return b; \
}
......@@ -162,7 +164,7 @@ void funcName(const XTensor & a, XTensor & b)
InitTensor(&b, &a); \
} \
_funcName(&a, &b); \
if (b.enableGrad) { \
if (a.enableGrad) { \
XLink::MakeLink(&a, NULL, &b, operationId); \
} \
}
......
......@@ -258,10 +258,12 @@ XTensor CopyIndexed(const XTensor & s, int dim,
list.Add((XTensor*)&tgtIndex);
/* tensor connection */
XLink::MakeLink(&list, &t, MOVEMENT_COPYINDEXED);
XLink::AddParamToHeadInt(&t, dim);
XLink::AddParamToHeadInt(&t, copyNum);
if (s.enableGrad) {
XLink::MakeLink(&list, &t, MOVEMENT_COPYINDEXED);
XLink::AddParamToHeadInt(&t, dim);
XLink::AddParamToHeadInt(&t, copyNum);
}
/* destroy variables */
delete[] dimSize;
......@@ -314,13 +316,15 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, in
memcpy(saveTgtIndex, tgtIndex, indexSize * sizeof(int));
/* tensor connection */
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED);
XLink::AddParamToHeadInt(&t, dim);
XLink::AddParamToHeadPointer(&t, saveSrcIndex);
XLink::AddParamToHeadInt(&t, indexSize);
XLink::AddParamToHeadPointer(&t, saveTgtIndex);
XLink::AddParamToHeadInt(&t, copyNum);
if (s.enableGrad) {
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED);
XLink::AddParamToHeadInt(&t, dim);
XLink::AddParamToHeadPointer(&t, saveSrcIndex);
XLink::AddParamToHeadInt(&t, indexSize);
XLink::AddParamToHeadPointer(&t, saveTgtIndex);
XLink::AddParamToHeadInt(&t, copyNum);
}
/* destroy variables */
delete[] dimSize;
......
......@@ -134,7 +134,9 @@ XTensor CopyValues(const XTensor &s, XStream * stream)
_CopyValues(&s, &t, stream);
/* tensor connection */
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYVALUES);
if (s.enableGrad) {
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYVALUES);
}
return t;
}
......
......@@ -93,9 +93,11 @@ XTensor Gather(XTensor &s, XTensor &index)
_Gather(&s, &t, &index);
/* tensor connection */
XLink::MakeLink(&s, &index, &t, MOVEMENT_GATHER);
if (s.enableGrad) {
XLink::MakeLink(&s, &index, &t, MOVEMENT_GATHER);
}
return t;
}
} // namespace nts(NiuTrans.Tensor)
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -181,8 +181,10 @@ XTensor ReduceMax(const XTensor &input, int dim)
_ReduceMax(&input, &output, dim);
/* tensor connection */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
XLink::AddParamToHeadInt(&output, dim);
if (input.enableGrad) {
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
XLink::AddParamToHeadInt(&output, dim);
}
/* destroy variables */
delete[] dimSize;
......@@ -221,7 +223,7 @@ void ReduceMax(const XTensor &input, XTensor &output, int dim)
/* call _ReduceMax function */
_ReduceMax(&input, &output, dim);
if (output.enableGrad) {
if (input.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
XLink::AddParamToHeadInt(&output, dim);
......
......@@ -77,8 +77,10 @@ XTensor ReduceMean(const XTensor &input, int dim)
_ReduceMean(&input, &output, dim);
/* tensor connection */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN);
XLink::AddParamToHeadInt(&output, dim);
if (input.enableGrad) {
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN);
XLink::AddParamToHeadInt(&output, dim);
}
/* destroy variables */
delete[] dimSize;
......@@ -119,7 +121,7 @@ void ReduceMean(const XTensor &input, XTensor &output, int dim)
/* call _ReduceMean function */
_ReduceMean(&input, &output, dim);
if (output.enableGrad) {
if (input.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN);
XLink::AddParamToHeadInt(&output, dim);
......
......@@ -306,10 +306,12 @@ XTensor ReduceSum(const XTensor &input, int dim, const XTensor &shift, DTYPE pow
_ReduceSum(&input, &output, dim, &shift, power, isExp);
/* tensor connection */
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUM);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, power);
XLink::AddParamToHeadBool(&output, isExp);
if (input.enableGrad) {
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUM);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, power);
XLink::AddParamToHeadBool(&output, isExp);
}
/* destroy variables */
delete[] dimSize;
......@@ -341,7 +343,7 @@ void ReduceSum(const XTensor &input, XTensor &output, int dim, const XTensor &sh
/* call _ReduceSum function */
_ReduceSum(&input, &output, dim, &shift, power, isExp);
if (output.enableGrad) {
if (input.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUM);
XLink::AddParamToHeadInt(&output, dim);
......@@ -385,10 +387,12 @@ XTensor ReduceSum(const XTensor &input, int dim, DTYPE power, bool isExp)
_ReduceSum(&input, &output, dim, NULL, power, isExp);
/* tensor connection */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCESUM);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, power);
XLink::AddParamToHeadBool(&output, isExp);
if (input.enableGrad) {
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCESUM);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, power);
XLink::AddParamToHeadBool(&output, isExp);
}
/* destroy variables */
delete[] dimSize;
......@@ -434,7 +438,7 @@ void ReduceSum(const XTensor &input, XTensor &output, int dim, DTYPE power, bool
/* call _ReduceSum function */
_ReduceSum(&input, &output, dim, NULL, power, isExp);
if (output.enableGrad) {
if (input.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCESUM);
XLink::AddParamToHeadInt(&output, dim);
......
......@@ -73,8 +73,10 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift)
_ReduceSumSquared(&input, &output, dim, &shift);
/* tensor connection */
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED);
XLink::AddParamToHeadInt(&output, dim);
if (input.enableGrad) {
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED);
XLink::AddParamToHeadInt(&output, dim);
}
/* destroy variables */
delete[] dimSize;
......@@ -116,7 +118,7 @@ void ReduceSumSquared(const XTensor &input, XTensor &output, int dim, const XTen
/* call _ReduceSumSquared function */
_ReduceSumSquared(&input, &output, dim, &shift);
if (output.enableGrad) {
if (input.enableGrad) {
/* tensor connections */
XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED);
XLink::AddParamToHeadInt(&output, dim);
......
......@@ -76,8 +76,10 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean)
_ReduceVariance(&input, &output, dim, &mean);
/* tensor connection */
XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE);
XLink::AddParamToHeadInt(&output, dim);
if (input.enableGrad) {
XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE);
XLink::AddParamToHeadInt(&output, dim);
}
/* destroy variables */
delete[] dimSize;
......@@ -119,7 +121,7 @@ void ReduceVariance(const XTensor &input, XTensor &output, int dim, const XTenso
/* call _ReduceVariance function */
_ReduceVariance(&input, &output, dim, &mean);
if (output.enableGrad) {
if (input.enableGrad) {
/* tensor connection */
XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE);
XLink::AddParamToHeadInt(&output, dim);
......
......@@ -99,9 +99,11 @@ XTensor Concatenate(const TensorList &smalls, int dim)
_Merge(&smalls, &big, dim);
/* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
if (tensor->enableGrad) {
XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
}
/* destroy variables */
delete[] dimSize;
......@@ -127,8 +129,10 @@ XTensor Concatenate(const TensorList &smalls, int dim)
_ConcatenateSolely(&smalls, &big, dim);
/* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
if (tensor->enableGrad) {
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
}
/* destroy variables */
delete[] dimSize;
......@@ -309,9 +313,11 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
_Merge(&smalls, &big, dim);
/* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
if (tensor->enableGrad) {
XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
}
/* destroy variables */
delete[] dimSize;
......@@ -337,8 +343,10 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
_ConcatenateSolely(&smalls, &big, dim);
/* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
if (tensor->enableGrad) {
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
}
/* destroy variables */
delete[] dimSize;
......
......@@ -222,9 +222,11 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
_Merge(&s, &t, whereToMerge, leadingDim);
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
XLink::AddParamToHeadInt(&t, whereToMerge);
XLink::AddParamToHeadInt(&t, leadingDim);
if (s.enableGrad) {
XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
XLink::AddParamToHeadInt(&t, whereToMerge);
XLink::AddParamToHeadInt(&t, leadingDim);
}
/* destroy variables */
delete[] dimSize;
......@@ -261,7 +263,7 @@ void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim)
/* call _Merge function */
_Merge(&s, &t, whereToMerge, leadingDim);
if (t.enableGrad) {
if (s.enableGrad) {
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
XLink::AddParamToHeadInt(&t, whereToMerge);
......@@ -412,8 +414,10 @@ XTensor Merge(const TensorList &smalls, int whereToMerge)
_Merge(&smalls, &big, whereToMerge);
/* tensor connections */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
XLink::AddParamToHeadInt(&big, whereToMerge);
if (tensor->enableGrad) {
XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
XLink::AddParamToHeadInt(&big, whereToMerge);
}
/* destroy variables */
delete[] dimSize;
......@@ -453,8 +457,10 @@ XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
_Merge(&smalls, &big, whereToMerge);
/* tensor connections */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
XLink::AddParamToHeadInt(&big, whereToMerge);
if (smallA.enableGrad) {
XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
XLink::AddParamToHeadInt(&big, whereToMerge);
}
/* destroy variables */
delete[] dimSize;
......
......@@ -43,7 +43,9 @@ XTensor Reshape(XTensor &s, int order, int * dimSize)
t.Reshape(order, dimSize);
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
if (s.enableGrad) {
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
}
return t;
}
......@@ -57,7 +59,7 @@ void Reshape(XTensor &s, XTensor &t, int order, int * dimSize)
/* call Reshape function */
t.Reshape(order, dimSize);
if (t.enableGrad) {
if (s.enableGrad) {
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
}
......
......@@ -217,9 +217,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
_Split(&s, &t, whereToSplit, splitNum);
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
XLink::AddParamToHeadInt(&t, whereToSplit);
XLink::AddParamToHeadInt(&t, splitNum);
if (s.enableGrad) {
XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
XLink::AddParamToHeadInt(&t, whereToSplit);
XLink::AddParamToHeadInt(&t, splitNum);
}
/* destroy variables */
delete[] dimSize;
......@@ -251,7 +253,7 @@ void Split(const XTensor &s, XTensor &t, int whereToSplit, int splitNum)
/* call _Split function */
_Split(&s, &t, whereToSplit, splitNum);
if (t.enableGrad) {
if (s.enableGrad) {
/* tensor connections */
XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
XLink::AddParamToHeadInt(&t, whereToSplit);
......@@ -409,12 +411,15 @@ void Split(const XTensor &big, TensorList &smalls, int whereToSplit, int splitNu
/* tensor connections */
for(int i = 0; i < smalls.count; i++){
XTensor * s = (XTensor*)smalls.Get(i);
XLink::MakeLink(&big, NULL, s, SHAPE_SPLIT_LIST);
XLink::AddParamToHeadInt(s, whereToSplit);
/* it is tricky here that we keep the id of each
block, rather than the total number of the splits */
XLink::AddParamToHeadInt(s, i);
if (s->enableGrad) {
XLink::MakeLink(&big, NULL, s, SHAPE_SPLIT_LIST);
XLink::AddParamToHeadInt(s, whereToSplit);
/* it is tricky here that we keep the id of each
block, rather than the total number of the splits */
XLink::AddParamToHeadInt(s, i);
}
}
}
......
......@@ -121,7 +121,9 @@ XTensor Squeeze(XTensor & source, int leadingDim)
_Squeeze(&source, &target, leadingDim);
/* tensor connections */
XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
if (source.enableGrad) {
XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
}
return target;
}
......@@ -135,7 +137,7 @@ void Squeeze(XTensor & source, XTensor & target, int leadingDim)
/* call _Squeeze function */
_Squeeze(&source, &target, leadingDim);
if (target.enableGrad) {
if (source.enableGrad) {
/* tensor connections */
XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
}
......
......@@ -144,9 +144,11 @@ XTensor Transpose(const XTensor &a, const int i, const int j)
_Transpose(&a, &b, i, j);
/* tensor connection */
XLink::MakeLink(&a, NULL, &b, SHAPE_TRANSPOSE);
XLink::AddParamToHeadInt(&b, i);
XLink::AddParamToHeadInt(&b, j);
if (a.enableGrad) {
XLink::MakeLink(&a, NULL, &b, SHAPE_TRANSPOSE);
XLink::AddParamToHeadInt(&b, i);
XLink::AddParamToHeadInt(&b, j);
}
/* destroy variables */
delete[] dimSize;
......
......@@ -156,9 +156,11 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
_Unsqueeze(&a, &b, dim, dSize);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
XLink::AddParamToHeadInt(&b, dim);
XLink::AddParamToHeadInt(&b, dSize);
if (a.enableGrad) {
XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
XLink::AddParamToHeadInt(&b, dim);
XLink::AddParamToHeadInt(&b, dSize);
}
/* destroy variables */
delete[] dimSize;
......@@ -191,7 +193,7 @@ void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize)
/* call _Unsqueeze function */
_Unsqueeze(&a, &b, dim, dSize);
if (b.enableGrad) {
if (a.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
XLink::AddParamToHeadInt(&b, dim);
......
......@@ -81,8 +81,10 @@ XTensor DropoutWithIndex(const XTensor &x, XTensor &maskIndex, DTYPE scale)
_ScaleAndShiftMe(&c, scale);
/* tensor connections */
XLink::MakeLink(&x, &maskIndex, &c, MOVEMENT_DROPOUTWITHINDEX);
XLink::AddParamToHead(&c, scale);
if (x.enableGrad) {
XLink::MakeLink(&x, &maskIndex, &c, MOVEMENT_DROPOUTWITHINDEX);
XLink::AddParamToHead(&c, scale);
}
return c;
}
......
......@@ -78,7 +78,9 @@ XTensor HardTanH(const XTensor &x)
_HardTanH(&x, &y);
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH);
if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH);
}
return y;
}
......@@ -92,7 +94,7 @@ void HardTanH(const XTensor &x, XTensor &y)
/* call _HardTanH function */
_HardTanH(&x, &y);
if (y.enableGrad) {
if (x.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH);
}
......
......@@ -54,7 +54,9 @@ XTensor Identity(const XTensor &x)
_Identity(&x, &y);
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
}
return y;
}
......@@ -68,7 +70,7 @@ void Identity(const XTensor &x, XTensor &y)
/* call _Identity function */
_Identity(&x, &y);
if (y.enableGrad) {
if (x.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
}
......
......@@ -188,8 +188,10 @@ XTensor LogSoftmax(const XTensor &x, int leadDim)
_LogSoftmax(&x, &y, ld);
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
}
return y;
}
......@@ -215,7 +217,7 @@ void LogSoftmax(const XTensor &x, XTensor &y, int leadDim)
/* call _LogSoftmax function */
_LogSoftmax(&x, &y, ld);
if (y.enableGrad) {
if (x.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
......
......@@ -70,7 +70,9 @@ XTensor Rectify(const XTensor &x)
_Rectify(&x, &y);
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY);
if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY);
}
return y;
}
......@@ -84,7 +86,7 @@ void Rectify(const XTensor &x, XTensor &y)
/* call _Rectify function */
_Rectify(&x, &y);
if (y.enableGrad) {
if (x.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY);
}
......
......@@ -73,7 +73,9 @@ XTensor Sigmoid(const XTensor &x)
_Sigmoid(&x, &y);
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID);
if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID);
}
return y;
}
......@@ -87,7 +89,7 @@ void Sigmoid(const XTensor &x, XTensor &y)
/* call _Sigmoid function */
_Sigmoid(&x, &y);
if (y.enableGrad) {
if (x.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID);
}
......
......@@ -142,8 +142,10 @@ XTensor Softmax(const XTensor &x, int leadDim)
_Softmax(&x, &y, ld);
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
if (x.enableGrad) {
XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
}
return y;
}
......@@ -161,7 +163,7 @@ void Softmax(const XTensor &x, XTensor &y, int leadDim)
/* call _Softmax function */
_Softmax(&x, &y, ld);
if (y.enableGrad) {
if (x.enableGrad) {
/* tensor connection */
XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
XLink::AddParamToHeadInt(&y, ld);
......
......@@ -277,8 +277,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
tails.Add((XTensor*)&gold);
tails.Add(weight);
tails.Add(padding);
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim);
if (output.enableGrad) {
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim);
}
return loss;
}
......@@ -302,8 +305,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
tails.Add((XTensor*)&gold);
tails.Add(weight);
tails.Add((XTensor*)&padding);
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim);
if (output.enableGrad) {
XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
XLink::AddParamToHeadInt(&loss, dim);
}
return loss;
}
......
......@@ -421,7 +421,7 @@ bool TestSetData6()
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE answer[5] = { 5.2, 3.2, 1.2, -0.8, -2.8 };
DTYPE answer[5] = {5.2F, 3.2F, 1.2F, -0.8F, -2.8F};
/* CPU test */
bool cpuTest = true;
......