Commit d221ef9d by xuchen

merge with zhangyuhao branch

parent e0d86e5b
@@ -300,6 +300,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id)
     if(h == NULL)
         return;
 
+    if (!t1->enableGrad)
+        return;
+
     TensorList list(2);
     list.Add((XTensor*)t1);
     list.Add((XTensor*)t2);
@@ -320,6 +323,9 @@ void XLink::MakeLink(const XTensor * t1, const XTensor * t2, const XTensor * t3, XTensor * h, int id)
     if (h == NULL)
         return;
 
+    if (!t1->enableGrad || !t2->enableGrad)
+        return;
+
     TensorList list(3);
     list.Add((XTensor*)t1);
     list.Add((XTensor*)t2);
@@ -370,6 +376,9 @@ create a hyper edge with a input tensors and a list of output tensors
 */
 void XLink::MakeLink(XTensor * t, TensorList * list, int id)
 {
+    if (!t->enableGrad)
+        return;
+
     /* forward */
     for(int i = 0; i < list->count; i++){
         XTensor * h = (XTensor*)list->GetItem(i);
......
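The three guards added above cut the autodiff graph off at its source: XLink::MakeLink now refuses to build an edge when the inputs have gradients disabled, instead of leaving that decision to every caller. A minimal stand-alone sketch of the control flow (the Tensor type and function name are stand-ins, not the library's):

    #include <cstdio>

    struct Tensor { bool enableGrad; };    /* stand-in for XTensor */

    /* mirrors the guard added to the two-input MakeLink: if the input
       does not require gradients, no edge is built at all */
    void makeLinkSketch(const Tensor* t1, Tensor* h)
    {
        if (h == NULL)
            return;
        if (!t1->enableGrad)
            return;                        /* no edge, so nothing to backprop */
        std::printf("edge created\n");     /* stands in for the TensorList setup */
    }

    int main()
    {
        Tensor a{false}, c{true};
        makeLinkSketch(&a, &c);            /* prints nothing: a blocks the link */
        a.enableGrad = true;
        makeLinkSketch(&a, &c);            /* prints "edge created" */
        return 0;
    }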
@@ -23,15 +23,11 @@
  *
  */
 
-#include "XList.h"
+#include "time.h"
 #include "XMem.h"
+#include "XList.h"
 #include "XGlobal.h"
-#include <ctime>
-#include <utility>
-#include <algorithm>
 
 /* the nts (NiuTrans.Tensor) namespace */
 namespace nts {
@@ -78,7 +74,8 @@ TensorListBase<T>::TensorListBase(int myMaxNum, XMem* myMem)
 template <typename T>
 TensorListBase<T>::~TensorListBase()
 {
-    delete[] items;
+    if(items && mem)
+        delete[] items;
 }
@@ -103,6 +100,13 @@ void TensorListBase<T>::Add(T&& item)
     items[count++] = item;
 }
 
+/* return number of elements */
+template<typename T>
+size_t TensorListBase<T>::Size()
+{
+    return count;
+}
+
 /*
 add an item into the list
 >> item - a const reference to the item
@@ -130,7 +134,7 @@ add a number of items into the list
 >> inputItemCount - number of input items
 */
 template <typename T>
-void TensorListBase<T>::Add(T* inputItems, int inputItemCount)
+void TensorListBase<T>::Add(const T* inputItems, int inputItemCount)
 {
     if (count + inputItemCount >= maxNum) {
         int newMaxNum = (count + inputItemCount) * 2 + 1;
@@ -206,10 +210,10 @@ void TensorListBase<T>::Insert(int pos, T&& item)
 template <typename T>
 T& TensorListBase<T>::GetItem(int i) const
 {
-    CheckNTErrors(i >= -1 && i < count, "Index of a list item is out of scope!");
+    CheckNTErrors(i >= -count && i < count, "Index of a list item is out of scope!");
     CheckNTErrors(count > 0, "Cannt index the item in an empty list!");
-    if (i == -1)
-        return items[count - 1];
+    if (i < 0)
+        return items[count + i];
     else
         return items[i];
 }
@@ -226,7 +230,7 @@ template<typename T>
 inline void TensorListBase<T>::SetItem(int i, T&& item)
 {
     if (i >= 0 && i < count)
-        items[i] = std::move(item);
+        items[i] = item;
 }
 
 /*
@@ -245,6 +249,26 @@ inline int TensorListBase<T>::FindFirst(const T& item)
     return -1;
 }
 
+template <>
+inline int TensorListBase<Example>::FindFirst(const Example& item)
+{
+    for (int i = 0; i < count; i++) {
+        if (item.id == items[i].id)
+            return i;
+    }
+    return -1;
+}
+
+template <>
+inline int TensorListBase<Result>::FindFirst(const Result& item)
+{
+    for (int i = 0; i < count; i++) {
+        if (item.id == items[i].id)
+            return i;
+    }
+    return -1;
+}
+
 /* clear the data array */
 template <typename T>
 void TensorListBase<T>::Clear()
@@ -294,6 +318,17 @@ void TensorListBase<T>::Remove(int i)
     count--;
 }
 
+template<typename T>
+void TensorListBase<T>::Reserve(int n)
+{
+    if (items) {
+        /* reserve failed */
+        return;
+    }
+
+    items = new T[n];
+}
+
 /*
 copy the list
 >> myMem - memory pool used for allocating the data in the new list
@@ -348,6 +383,8 @@ template struct TensorListBase<long>;
 template struct TensorListBase<float>;
 template struct TensorListBase<short>;
 template struct TensorListBase<XTensor*>;
+template struct TensorListBase<Result>;
+template struct TensorListBase<Example>;
 template struct TensorListBase<void*>;
 
 } /* end of the nts (NiuTrans.Tensor) namespace */
\ No newline at end of file
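The widened bounds check in GetItem replaces the old special case for -1 with general Python-style negative indexing: -1 now addresses the last element and -count the first. A self-contained sketch of the new rule (names hypothetical):

    #include <cassert>

    /* minimal model of the new TensorListBase<T>::GetItem indexing rule */
    template <typename T>
    T& getItem(T* items, int count, int i)
    {
        assert(count > 0);
        assert(i >= -count && i < count);  /* widened from i >= -1 */
        if (i < 0)
            return items[count + i];       /* -1 -> last, -count -> first */
        return items[i];
    }

    int main()
    {
        int v[4] = {10, 20, 30, 40};
        assert(getItem(v, 4, -1) == 40);   /* last element */
        assert(getItem(v, 4, -4) == 10);   /* first element */
        assert(getItem(v, 4, 2)  == 30);   /* ordinary indexing unchanged */
        return 0;
    }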
@@ -66,11 +66,14 @@ public:
     /* add an item into the list */
     void Add(T&& item);
 
+    /* return number of elements */
+    size_t Size();
+
     /* add an item into the list */
     void Add(const T& item);
 
     /* add a number of items into the list */
-    void Add(T* inputItems, int inputItemCount);
+    void Add(const T* inputItems, int inputItemCount);
 
     /* append a list to the current list */
     void AddList(TensorListBase* l);
@@ -105,6 +108,9 @@ public:
     /* remove the item at position i */
     void Remove(int i);
 
+    /* reserve space for data entry */
+    void Reserve(int n);
+
     /* copy the list */
     TensorListBase* Copy(XMem* myMem);
@@ -112,22 +118,33 @@ public:
     void Shuffle(int nround = 10, int beg = -1, int len = 0);
 
     /* short */
-    T& operator[] (int i) {
-        return GetItem(i);
-    };
+    T& operator[] (int i) { return GetItem(i); };
     T& Get(int i) { return GetItem(i); };
     void Set(int i, T item) { SetItem(i, item); };
 };
 
 struct XTensor;
 
-typedef TensorListBase<void*> XList;
 typedef TensorListBase<int> IntList;
 typedef TensorListBase<char> CharList;
 typedef TensorListBase<char*> StrList;
 typedef TensorListBase<long> LongList;
 typedef TensorListBase<float> FloatList;
 typedef TensorListBase<short> ShortList;
+typedef TensorListBase<void*> XList;
+
+struct Example {
+    int id;
+    IntList data;
+};
+
+struct Result {
+    int id;
+    IntList data;
+};
+
+typedef TensorListBase<Result> ResultList;
+typedef TensorListBase<Example> ExampleList;
 typedef TensorListBase<XTensor*> TensorList;
 
 } /* end of the nts (NiuTrans.Tensor) namespace */
......
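Note that the new Reserve is much weaker than std::vector::reserve: as committed, it allocates only when items is still unset and otherwise bails out through the "/* reserve failed */" branch, so it can neither grow nor preserve an existing buffer. A toy model of that contract (hypothetical names):

    #include <cassert>
    #include <cstddef>

    /* toy list mirroring the committed Reserve()/Size() semantics */
    template <typename T>
    struct MiniList {
        T* items = nullptr;
        int count = 0;

        std::size_t Size() { return count; }   /* the new accessor */

        void Reserve(int n)
        {
            if (items)        /* buffer already exists: silently do nothing */
                return;
            items = new T[n];
        }

        ~MiniList() { delete[] items; }
    };

    int main()
    {
        MiniList<int> l;
        l.Reserve(8);         /* allocates */
        l.Reserve(16);        /* no-op: items is already set, no regrowth */
        assert(l.Size() == 0);
        return 0;
    }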
@@ -1916,6 +1916,26 @@ void XTensor::Dump(const XTensor * tensor, FILE * file, const char * label, const int n, const int beg, const int verbose)
 }
 
 /*
+dump data to a binary file
+>> file - where to dump the data
+*/
+void XTensor::BinaryDump(FILE* file)
+{
+    XTensor tmp;
+    InitTensorOnCPU(&tmp, this);
+    _CopyValues(this, &tmp);
+
+    switch (dataType) {
+        case X_INT: {
+            fwrite(tmp.data, sizeof(int), unitNum, file);
+        }
+        default: {
+            fwrite(tmp.data, sizeof(float), unitNum, file);
+        }
+    }
+}
+
+/*
 read data from a file
 >> file - where to load the data
 >> label - label of the tensor
@@ -2027,6 +2047,30 @@ void XTensor::Read(FILE * file, const char * label)
     delete[](char*)dataBuf;
 }
 
+/*
+read data from a binary file
+>>> file - the file stream pointer
+>>> offset - the distance from the start to this tensor
+*/
+void XTensor::BinaryRead(FILE* file, size_t offset)
+{
+    fseek(file, offset, 0);
+    switch (dataType) {
+        case X_INT: {
+            int * d = new int[unitNum];
+            fread(d, sizeof(int), unitNum, file);
+            SetData(d, unitNum);
+            delete[] d;
+        }
+        default: {
+            float * d = new float[unitNum];
+            fread(d, sizeof(float), unitNum, file);
+            SetData(d, unitNum);
+            delete[] d;
+        }
+    }
+}
+
 /*
 flush the data to the target device
 >> targetMem - memory pool on the target device
......
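Both new switch statements lack a break after case X_INT, so the integer case falls through into default and the float branch runs as well: BinaryDump writes the payload twice, and BinaryRead reads past the int data (BinaryRead also passes a literal 0 to fseek where SEEK_SET is meant, and ignores fread's return value). A corrected sketch of the dump dispatch, with stand-in types rather than the real XTensor:

    #include <cstdio>

    enum DataType { X_INT, X_FLOAT };   /* stand-in for the real type enum */

    /* each case ends with break, so the int payload is written exactly once */
    void binaryDumpSketch(FILE* file, DataType dataType,
                          const void* data, size_t unitNum)
    {
        switch (dataType) {
            case X_INT:
                fwrite(data, sizeof(int), unitNum, file);
                break;
            default:
                fwrite(data, sizeof(float), unitNum, file);
                break;
        }
    }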
@@ -433,9 +433,15 @@ public:
     static
     void Dump(const XTensor * tensor, FILE * file, const char * label = NULL, const int n = -1, const int beg = 0, const int verbose = 0);
 
+    /* dump data to a binary file */
+    void BinaryDump(FILE * file);
+
     /* read data from a file */
     void Read(FILE * file, const char * label = NULL);
 
+    /* read data from a binary file */
+    void BinaryRead(FILE * file, size_t offset);
+
     /* flush the data to the target device */
     void FlushToMem(XMem * targetMem);
......
@@ -215,18 +215,22 @@ XTensor Div(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
         _Div(&a, &b, &c, alpha, leadingDim);
 
         /* tensor connections */
-        XLink::MakeLink(&a, &b, &c, MATH_DIV);
-        XLink::AddParamToHead(&c, alpha);
-        XLink::AddParamToHeadInt(&c, leadingDim);
+        if (a.enableGrad && b.enableGrad) {
+            XLink::MakeLink(&a, &b, &c, MATH_DIV);
+            XLink::AddParamToHead(&c, alpha);
+            XLink::AddParamToHeadInt(&c, leadingDim);
+        }
     }
     else if(n >= 0 && n < a.order){
         /* call _DivDim function */
         _DivDim(&a, &b, &c, n, alpha);
 
         /* tensor connections */
-        XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
-        XLink::AddParamToHeadInt(&c, n);
-        XLink::AddParamToHead(&c, alpha);
+        if (a.enableGrad && b.enableGrad) {
+            XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
+            XLink::AddParamToHeadInt(&c, n);
+            XLink::AddParamToHead(&c, alpha);
+        }
     }
     else{
         ShowNTErrors("Something is wrong!");
@@ -261,7 +265,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadingDim)
         /* call _Div function */
         _Div(&a, &b, &c, 0, leadingDim);
 
-        if (c.enableGrad) {
+        if (a.enableGrad && b.enableGrad) {
             /* tensor connections */
             XLink::MakeLink(&a, &b, &c, MATH_DIV);
             XLink::AddParamToHead(&c, alpha);
@@ -272,7 +276,7 @@ void Div(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadingDim)
         /* call _DivDim function */
         _DivDim(&a, &b, &c, n, alpha);
 
-        if (c.enableGrad) {
+        if (a.enableGrad && b.enableGrad) {
             /* tensor connections */
             XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
             XLink::AddParamToHeadInt(&c, n);
......
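This file sets the pattern for the rest of the commit. The in-place variants used to test c.enableGrad, but c is the output the caller just constructed, and its flag says nothing about whether the operands are being differentiated; testing a.enableGrad && b.enableGrad ties graph construction to the inputs instead. A compilable sketch of the difference (stand-in types):

    struct T { bool enableGrad; };   /* hypothetical stand-in for XTensor */

    /* old guard: inspects the freshly constructed output */
    bool linkOld(const T&, const T&, const T& c) { return c.enableGrad; }

    /* new guard: inspects the operands that would receive gradients */
    bool linkNew(const T& a, const T& b, const T&) { return a.enableGrad && b.enableGrad; }

    int main()
    {
        T a{false}, b{false}, c{true};   /* frozen inputs, default-on output */
        /* the old guard still builds a backward edge through frozen inputs;
           the new guard skips the edge and its bookkeeping */
        return (linkOld(a, b, c) && !linkNew(a, b, c)) ? 0 : 1;
    }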
@@ -164,10 +164,12 @@ XTensor DivDim(const XTensor &a, const XTensor &b, int n, DTYPE alpha)
     _DivDim(&a, &b, &c, n, alpha);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
-    XLink::AddParamToHeadInt(&c, n);
-    XLink::AddParamToHead(&c, alpha);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
+        XLink::AddParamToHeadInt(&c, n);
+        XLink::AddParamToHead(&c, alpha);
+    }
 
     return c;
 }
@@ -193,7 +195,7 @@ void DivDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE alpha)
     /* call _Div function */
     _DivDim(&a, &b, &c, n, alpha);
 
-    if (c.enableGrad == true) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_DIVDIM);
         XLink::AddParamToHeadInt(&c, n);
......
@@ -155,8 +155,10 @@ XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha)
     _Mask(&a, &mask, &c, alpha);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &mask, &c, MATH_MASK);
-    XLink::AddParamToHead(&c, alpha);
+    if (a.enableGrad) {
+        XLink::MakeLink(&a, &mask, &c, MATH_MASK);
+        XLink::AddParamToHead(&c, alpha);
+    }
 
     return c;
 }
@@ -176,7 +178,7 @@ void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha)
     /* call _Mask function */
     _Mask(&a, &mask, &c, alpha);
 
-    if (c.enableGrad) {
+    if (a.enableGrad) {
         XLink::MakeLink(&a, &mask, &c, MATH_MASK);
         XLink::AddParamToHead(&c, alpha);
     }
......
@@ -296,10 +296,12 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
     _MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
-    XLink::AddParamToHeadTrans(&c, transposedA);
-    XLink::AddParamToHeadTrans(&c, transposedB);
-    XLink::AddParamToHead(&c, alpha);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
+        XLink::AddParamToHeadTrans(&c, transposedA);
+        XLink::AddParamToHeadTrans(&c, transposedB);
+        XLink::AddParamToHead(&c, alpha);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -344,7 +346,7 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
     /* call _MatrixMul function */
     _MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, beta, parallelRunner);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
         XLink::AddParamToHeadTrans(&c, transposedA);
@@ -393,10 +395,12 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
     _MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
-    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
-    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
-    XLink::AddParamToHead(&c, alpha);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
+        XLink::AddParamToHeadTrans(&c, X_NOTRANS);
+        XLink::AddParamToHeadTrans(&c, X_NOTRANS);
+        XLink::AddParamToHead(&c, alpha);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -440,7 +444,7 @@ void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
     /* call _MatrixMul function */
     _MatrixMul(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
         XLink::AddParamToHeadTrans(&c, X_NOTRANS);
......
@@ -314,10 +314,12 @@ XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
     _MatrixMulBatched(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
-    XLink::AddParamToHeadTrans(&c, transposedA);
-    XLink::AddParamToHeadTrans(&c, transposedB);
-    XLink::AddParamToHead(&c, alpha);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
+        XLink::AddParamToHeadTrans(&c, transposedA);
+        XLink::AddParamToHeadTrans(&c, transposedB);
+        XLink::AddParamToHead(&c, alpha);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -370,10 +372,12 @@ XTensor MatrixMulBatched(const XTensor &a, const XTensor &b,
     _MatrixMulBatched(&a, X_NOTRANS, &b, X_NOTRANS, &c, alpha, 0, parallelRunner);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
-    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
-    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
-    XLink::AddParamToHead(&c, alpha);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
+        XLink::AddParamToHeadTrans(&c, X_NOTRANS);
+        XLink::AddParamToHeadTrans(&c, X_NOTRANS);
+        XLink::AddParamToHead(&c, alpha);
+    }
 
     /* destroy variables */
     delete[] dimSize;
......
@@ -118,11 +118,87 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
     }
 
     /* tensor connections */
-    XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
-    XLink::AddParamToHeadInt(&c, n);
-    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
-    XLink::AddParamToHeadTrans(&c, X_NOTRANS);
-    //XLink::AddParamToHead(&c, beta);
+    if (w.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
+        XLink::AddParamToHeadInt(&c, n);
+        XLink::AddParamToHeadTrans(&c, X_NOTRANS);
+        XLink::AddParamToHeadTrans(&c, X_NOTRANS);
+    }
+
+    /* destroy variables */
+    delete[] dimSize;
+    DelTensorBuf(tmp);
+
+    return c;
+}
+
+/*
+operation c = x * w + b MulAndShift
+>> x - tensor x
+>> w - tensor w
+>> b - tensor b
+>> parallelRunner - parallel processing module
+<< return - the result of matrix multiplication
+*/
+XTensor MulAndShift(const XTensor& x, MATRIX_TRANS_TYPE transposedA,
+                    const XTensor& w, MATRIX_TRANS_TYPE transposedB,
+                    const XTensor& b, DTYPE alpha, XPRunner* parallelRunner)
+{
+    CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
+    CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have a order >= 2!");
+
+    int xn = transposedA == X_TRANS ? x.dimSizeRDI[0] : x.dimSizeRDI[1];
+    int xm = transposedA == X_TRANS ? x.dimSizeRDI[1] : x.dimSizeRDI[0];
+    int wn = transposedB == X_TRANS ? w.dimSizeRDI[0] : w.dimSizeRDI[1];
+    int wm = transposedB == X_TRANS ? w.dimSizeRDI[1] : w.dimSizeRDI[0];
+
+    int order = x.order + w.order - 2;
+    int sub = 0;
+    int * dimSize = new int[order];
+    for (int i = 2; i < x.order; i++)
+        dimSize[sub++] = x.dimSizeRDI[x.order + 1 - i];
+    for (int i = 2; i < w.order; i++)
+        dimSize[sub++] = w.dimSizeRDI[w.order + 1 - i];
+    dimSize[sub++] = xn;
+    dimSize[sub++] = wm;
+
+    float dr = (!x.isSparse || !w.isSparse) ? 1.0F : MAX(x.denseRatio, w.denseRatio);
+    XTensor * tmp = NewTensorBuf(order, dimSize, x.dataType, dr, x.devID, x.mem);
+
+    /* call _MatrixMul function */
+    _MatrixMul(&x, transposedA, &w, transposedB, tmp, alpha, 0, parallelRunner);
+
+    XTensor c(tmp);
+    c.SetTMPFlag();
+
+    int n = GetSumIndex(tmp, b);
+
+    if (n == -1) {
+        /* call _Sum function */
+        _Sum(tmp, &b, &c);
+
+        // TODO!!
+        ShowNTErrors("TODO!");
+    }
+    else if (n >= 0 && n < tmp->order) {
+        /* call _SumDim function */
+        _SumDim(tmp, &b, &c, n);
+    }
+    else {
+        ShowNTErrors("Something is wrong!");
+    }
+
+    /* tensor connections */
+    if (w.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&x, &w, &b, &c, MATH_MULANDSHIFT);
+        XLink::AddParamToHeadInt(&c, n);
+        XLink::AddParamToHeadTrans(&c, transposedA);
+        XLink::AddParamToHeadTrans(&c, transposedB);
+    }
 
     /* destroy variables */
     delete[] dimSize;
......
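The new overload exposes the transpose flags that the original MulAndShift hard-codes to X_NOTRANS, so a fused c = x * op(w) + b no longer needs a separate transpose step; note that, as in the first variant, the guard tests w.enableGrad && b.enableGrad but not x.enableGrad. A hedged usage sketch (the include path is abbreviated and tensor setup is assumed to happen elsewhere):

    #include "MulAndShift.h"   /* NiuTensor header declaring the new overload */

    using namespace nts;

    /* computes c = x * w^T + b in one fused call, e.g. for x of shape
       (batch, m), w of shape (n, m) and b of shape (n) */
    XTensor Affine(const XTensor& x, const XTensor& w, const XTensor& b)
    {
        return MulAndShift(x, X_NOTRANS, w, X_TRANS, b);
    }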
@@ -31,6 +31,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
                     DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
 
+XTensor MulAndShift(const XTensor &x, MATRIX_TRANS_TYPE transposedA,
+                    const XTensor &w, MATRIX_TRANS_TYPE transposedB,
+                    const XTensor &b, DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
 
 } // namespace nts(NiuTrans.Tensor)
......
@@ -216,18 +216,22 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim)
         _Multiply(&a, &b, &c, 0, leadingDim);
 
         /* tensor connections */
-        XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
-        XLink::AddParamToHead(&c, alpha);
-        XLink::AddParamToHeadInt(&c, leadingDim);
+        if (a.enableGrad && b.enableGrad) {
+            XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
+            XLink::AddParamToHead(&c, alpha);
+            XLink::AddParamToHeadInt(&c, leadingDim);
+        }
     }
     else if(n >= 0 && n < a.order){
         /* call _MultiplyDim function */
         _MultiplyDim(&a, &b, &c, n, alpha);
 
         /* tensor connections */
-        XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
-        XLink::AddParamToHeadInt(&c, n);
-        XLink::AddParamToHead(&c, alpha);
+        if (a.enableGrad && b.enableGrad) {
+            XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
+            XLink::AddParamToHeadInt(&c, n);
+            XLink::AddParamToHead(&c, alpha);
+        }
     }
     else{
         ShowNTErrors("Something is wrong!");
@@ -262,7 +266,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadingDim)
     /* call _Multiply function */
     _Multiply(&a, &b, &c, 0, leadingDim);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_MULTIPLY);
         XLink::AddParamToHead(&c, alpha);
@@ -273,7 +277,7 @@ void Multiply(const XTensor &a, const XTensor &b, XTensor &c, DTYPE alpha, int leadingDim)
     /* call _MultiplyDim function */
     _MultiplyDim(&a, &b, &c, n, alpha);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
         XLink::AddParamToHeadInt(&c, n);
......
@@ -180,9 +180,11 @@ XTensor MultiplyDim(const XTensor &a, const XTensor &b, int n)
     _MultiplyDim(&a, &b, &c, n, 0);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
-    XLink::AddParamToHeadInt(&c, n);
-    XLink::AddParamToHead(&c, 0);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
+        XLink::AddParamToHeadInt(&c, n);
+        XLink::AddParamToHead(&c, 0);
+    }
 
     return c;
 }
@@ -208,7 +210,7 @@ void MultiplyDim(const XTensor &a, const XTensor &b, XTensor &c, int n)
     /* call _Multiply function */
     _MultiplyDim(&a, &b, &c, n, 0);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYDIM);
         XLink::AddParamToHeadInt(&c, n);
@@ -350,8 +352,10 @@ XTensor MultiplyBroadcast(const XTensor &a, const XTensor &b)
     _MultiplyBroadcast(&a, &b, &c, 0);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST);
-    XLink::AddParamToHead(&c, 0);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST);
+        XLink::AddParamToHead(&c, 0);
+    }
 
     return c;
 }
@@ -374,7 +378,7 @@ void MultiplyBroadcast(const XTensor &a, const XTensor &b, XTensor &c)
     /* call _SumBroadcast function */
     _MultiplyBroadcast(&a, &b, &c, 0);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_MULTIPLYBROADCAST);
         XLink::AddParamToHead(&c, 0);
......
@@ -190,17 +190,21 @@ XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
         _Sub(&a, &b, &c, beta);
 
         /* tensor connections */
-        XLink::MakeLink(&a, &b, &c, MATH_SUB);
-        XLink::AddParamToHead(&c, beta);
+        if (a.enableGrad && b.enableGrad) {
+            XLink::MakeLink(&a, &b, &c, MATH_SUB);
+            XLink::AddParamToHead(&c, beta);
+        }
     }
     else if(n >= 0 && n < a.order){
         /* call _SubDim function */
         _SubDim(&a, &b, &c, n, beta);
 
         /* tensor connections */
-        XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
-        XLink::AddParamToHeadInt(&c, n);
-        XLink::AddParamToHead(&c, beta);
+        if (a.enableGrad && b.enableGrad) {
+            XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
+            XLink::AddParamToHeadInt(&c, n);
+            XLink::AddParamToHead(&c, beta);
+        }
     }
     else{
         ShowNTErrors("Something is wrong!");
@@ -229,7 +233,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
     /* call _Sub function */
     _Sub(&a, &b, &c, beta);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_SUB);
         XLink::AddParamToHead(&c, beta);
@@ -239,7 +243,7 @@ void Sub(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
     /* call _SubDim function */
     _SubDim(&a, &b, &c, n, beta);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
         XLink::AddParamToHeadInt(&c, n);
......
@@ -164,9 +164,11 @@ XTensor SubDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
     _SubDim(&a, &b, &c, n, beta);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
-    XLink::AddParamToHeadInt(&c, n);
-    XLink::AddParamToHead(&c, beta);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
+        XLink::AddParamToHeadInt(&c, n);
+        XLink::AddParamToHead(&c, beta);
+    }
 
     return c;
 }
@@ -193,7 +195,7 @@ void SubDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
     /* call _Sub function */
     _SubDim(&a, &b, &c, n, beta);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_SUBDIM);
         XLink::AddParamToHeadInt(&c, n);
......
@@ -224,17 +224,21 @@ XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
         _Sum(&a, &b, &c, beta);
 
         /* tensor connections */
-        XLink::MakeLink(&a, &b, &c, MATH_SUM);
-        XLink::AddParamToHead(&c, beta);
+        if (a.enableGrad && b.enableGrad) {
+            XLink::MakeLink(&a, &b, &c, MATH_SUM);
+            XLink::AddParamToHead(&c, beta);
+        }
     }
     else if(n >= 0 && n < a.order){
         /* call _SumDim function */
         _SumDim(&a, &b, &c, n, beta);
 
         /* tensor connections */
-        XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
-        XLink::AddParamToHeadInt(&c, n);
-        XLink::AddParamToHead(&c, beta);
+        if (a.enableGrad && b.enableGrad) {
+            XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
+            XLink::AddParamToHeadInt(&c, n);
+            XLink::AddParamToHead(&c, beta);
+        }
     }
     else{
         ShowNTErrors("Something is wrong!");
@@ -261,9 +265,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
     if (n == -1) {
         /* call _Sum function */
         _Sum(&a, &b, &c, beta);
 
-        if (c.enableGrad) {
-            /* tensor connections */
+        /* tensor connections */
+        if (a.enableGrad && b.enableGrad) {
             XLink::MakeLink(&a, &b, &c, MATH_SUM);
             XLink::AddParamToHead(&c, beta);
         }
@@ -271,9 +275,9 @@ void Sum(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
     else if (n >= 0 && n < a.order) {
         /* call _SumDim function */
         _SumDim(&a, &b, &c, n, beta);
 
-        if (c.enableGrad) {
-            /* tensor connections */
+        /* tensor connections */
+        if (a.enableGrad && b.enableGrad) {
             XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
             XLink::AddParamToHeadInt(&c, n);
             XLink::AddParamToHead(&c, beta);
......
@@ -181,9 +181,11 @@ XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
     _SumDim(&a, &b, &c, n, beta);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
-    XLink::AddParamToHeadInt(&c, n);
-    XLink::AddParamToHead(&c, beta);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
+        XLink::AddParamToHeadInt(&c, n);
+        XLink::AddParamToHead(&c, beta);
+    }
 
     return c;
 }
@@ -210,7 +212,7 @@ void SumDim(const XTensor &a, const XTensor &b, XTensor &c, int n, DTYPE beta)
     /* call _SumDim function */
     _SumDim(&a, &b, &c, n, beta);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_SUMDIM);
         XLink::AddParamToHeadInt(&c, n);
@@ -353,9 +355,11 @@ XTensor SumBroadcast(const XTensor &a, const XTensor &b, DTYPE beta)
     _SumBroadcast(&a, &b, &c, beta);
 
     /* tensor connections */
-    XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST);
-    XLink::AddParamToHead(&c, beta);
+    if (a.enableGrad && b.enableGrad) {
+        XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST);
+        XLink::AddParamToHead(&c, beta);
+    }
 
     return c;
 }
@@ -377,7 +381,7 @@ void SumBroadcast(const XTensor &a, const XTensor &b, XTensor &c, DTYPE beta)
     /* call _SumBroadcast function */
     _SumBroadcast(&a, &b, &c, beta);
 
-    if (c.enableGrad) {
+    if (a.enableGrad && b.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, &b, &c, MATH_SUMBROADCAST);
         XLink::AddParamToHead(&c, beta);
......
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
 */

#include "../../XTensor.h"
#include "../../XName.h"

@@ -121,7 +121,8 @@ XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType)
     _ConvertDataType(&input, &output);
 
     /* tensor connection */
-    XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
+    if(input.enableGrad)
+        XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
 
     return output;
 }
@@ -136,7 +137,7 @@ void ConvertDataType(const XTensor & input, XTensor & output, TENSOR_DATA_TYPE dataType)
     _ConvertDataType(&input, &output);
 
     /* tensor connection */
-    if (output.enableGrad)
+    if (input.enableGrad)
         XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
 }
......
@@ -117,10 +117,12 @@ XTensor SelectRange(const XTensor &a, int dim, int low, int high)
     _SelectRange(&a, &c, dim, low, high);
 
     /* tensor connection */
-    XLink::MakeLink(&a, NULL, &c, GETANDSET_SELECT);
-    XLink::AddParamToHeadInt(&c, dim);
-    XLink::AddParamToHeadInt(&c, low);
-    XLink::AddParamToHeadInt(&c, high);
+    if (a.enableGrad) {
+        XLink::MakeLink(&a, NULL, &c, GETANDSET_SELECT);
+        XLink::AddParamToHeadInt(&c, dim);
+        XLink::AddParamToHeadInt(&c, low);
+        XLink::AddParamToHeadInt(&c, high);
+    }
 
     /* destroy variables */
     delete[] dimSize;
......
@@ -167,7 +167,9 @@ XTensor funcName(const XTensor &a, T num)
     XTensor b(&a);                                    \
     b.SetTMPFlag();                                   \
     _funcName(&a, &b, num);                           \
-    XLink::MakeLink(&a, NULL, &b, operationId);       \
+    if(a.enableGrad){                                 \
+        XLink::MakeLink(&a, NULL, &b, operationId);   \
+    }                                                 \
     XLink::AddParamToHead(&b, num);                   \
     return b;                                         \
 }                                                     \
@@ -183,7 +185,7 @@ void funcName(const XTensor &a, XTensor &b, T num)
         InitTensor(&b, &a);                           \
     }                                                 \
     _funcName(&a, &b, num);                           \
-    if (b.enableGrad) {                               \
+    if (a.enableGrad) {                               \
         XLink::MakeLink(&a, NULL, &b, operationId);   \
         XLink::AddParamToHead(&b, num);               \
     }                                                 \
......
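One asymmetry worth flagging in the first macro: only XLink::MakeLink moved inside the if(a.enableGrad) block, while XLink::AddParamToHead(&b, num) still runs unconditionally; in the reference-output variant right below, both calls sit inside the guard. A rough expansion of the value-returning macro for one operation makes this visible (the operation name is illustrative; the real functions are stamped out by the macro):

    /* approximate expansion of the macro for a scale-like operation */
    XTensor Scale(const XTensor& a, float num)
    {
        XTensor b(&a);
        b.SetTMPFlag();
        _Scale(&a, &b, num);
        if (a.enableGrad) {
            XLink::MakeLink(&a, NULL, &b, MATH_SCALE);
        }
        XLink::AddParamToHead(&b, num);   /* runs even when no link was made */
        return b;
    }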
@@ -67,7 +67,7 @@ keep the result in the input tensor a and return nothing
 */
 void _ClipMe(XTensor * a, DTYPE lower, DTYPE upper)
 {
     _Clip(a, a, lower, upper);
 }
 
 /*
@@ -92,18 +92,20 @@ make a new tensor to keep the result and return it
 */
 XTensor Clip(const XTensor & a, DTYPE lower, DTYPE upper)
 {
     XTensor b(&a);
     b.SetTMPFlag();
 
     /* call _Clip function */
     _Clip(&a, &b, lower, upper);
 
     /* tensor connections */
-    XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
-    XLink::AddParamToHead(&b, lower);
-    XLink::AddParamToHead(&b, upper);
+    if (a.enableGrad) {
+        XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
+        XLink::AddParamToHead(&b, lower);
+        XLink::AddParamToHead(&b, upper);
+    }
 
     return b;
 }
 
 void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper)
@@ -115,8 +117,8 @@ void Clip(const XTensor & a, XTensor & b, DTYPE lower, DTYPE upper)
     /* call _Clip function */
     _Clip(&a, &b, lower, upper);
 
-    if (b.enableGrad) {
-        /* tensor connections */
+    /* tensor connections */
+    if (a.enableGrad) {
         XLink::MakeLink(&a, NULL, &b, MATH_CLIP);
         XLink::AddParamToHead(&b, lower);
         XLink::AddParamToHead(&b, upper);
......
@@ -173,9 +173,11 @@ XTensor Normalize(const XTensor &input, int dim,
     list.Add((XTensor*)&var);
     list.Add((XTensor*)&a);
     list.Add((XTensor*)&b);
-    XLink::MakeLink(&list, &output, MATH_NORMALIZE);
-    XLink::AddParamToHeadInt(&output, dim);
-    XLink::AddParamToHead(&output, epsilon);
+    if (input.enableGrad) {
+        XLink::MakeLink(&list, &output, MATH_NORMALIZE);
+        XLink::AddParamToHeadInt(&output, dim);
+        XLink::AddParamToHead(&output, epsilon);
+    }
 
     return output;
 }
@@ -208,7 +210,7 @@ void Normalize(const XTensor &input, XTensor &output, int dim,
     /* call _Normalize function */
     _Normalize(&input, &output, dim, &mean, &var, &a, &b, epsilon);
 
-    if (output.enableGrad == true) {
+    if (input.enableGrad == true) {
         /* tensor connections */
         TensorList list(5);
         list.Add((XTensor*)&input);
......
@@ -126,9 +126,11 @@ XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
     _ScaleAndShift(&a, &b, scale, shift);
 
     /* tensor connections */
-    XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
-    XLink::AddParamToHead(&b, scale);
-    XLink::AddParamToHead(&b, shift);
+    if (a.enableGrad) {
+        XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
+        XLink::AddParamToHead(&b, scale);
+        XLink::AddParamToHead(&b, shift);
+    }
 
     return b;
 }
@@ -152,7 +154,7 @@ void ScaleAndShift(const XTensor & a, XTensor & b, DTYPE scale, DTYPE shift)
     /* call _ScaleAndShift function */
     _ScaleAndShift(&a, &b, scale, shift);
 
-    if (b.enableGrad) {
+    if (a.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
         XLink::AddParamToHead(&b, scale);
......
@@ -151,7 +151,9 @@ XTensor funcName(const XTensor & a)
     XTensor b(&a);                                    \
     b.SetTMPFlag();                                   \
     _funcName(&a, &b);                                \
-    XLink::MakeLink(&a, NULL, &b, operationId);       \
+    if(a.enableGrad){                                 \
+        XLink::MakeLink(&a, NULL, &b, operationId);   \
+    }                                                 \
     return b;                                         \
 }
@@ -162,7 +164,7 @@ void funcName(const XTensor & a, XTensor & b)
         InitTensor(&b, &a);                           \
     }                                                 \
     _funcName(&a, &b);                                \
-    if (b.enableGrad) {                               \
+    if (a.enableGrad) {                               \
         XLink::MakeLink(&a, NULL, &b, operationId);   \
     }                                                 \
 }
......
@@ -258,10 +258,12 @@ XTensor CopyIndexed(const XTensor & s, int dim,
     list.Add((XTensor*)&tgtIndex);
 
     /* tensor connection */
-    XLink::MakeLink(&list, &t, MOVEMENT_COPYINDEXED);
-    XLink::AddParamToHeadInt(&t, dim);
-    XLink::AddParamToHeadInt(&t, copyNum);
+    if (s.enableGrad) {
+        XLink::MakeLink(&list, &t, MOVEMENT_COPYINDEXED);
+        XLink::AddParamToHeadInt(&t, dim);
+        XLink::AddParamToHeadInt(&t, copyNum);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -314,13 +316,15 @@ XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum)
     memcpy(saveTgtIndex, tgtIndex, indexSize * sizeof(int));
 
     /* tensor connection */
-    XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED);
-    XLink::AddParamToHeadInt(&t, dim);
-    XLink::AddParamToHeadPointer(&t, saveSrcIndex);
-    XLink::AddParamToHeadInt(&t, indexSize);
-    XLink::AddParamToHeadPointer(&t, saveTgtIndex);
-    XLink::AddParamToHeadInt(&t, copyNum);
+    if (s.enableGrad) {
+        XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED);
+        XLink::AddParamToHeadInt(&t, dim);
+        XLink::AddParamToHeadPointer(&t, saveSrcIndex);
+        XLink::AddParamToHeadInt(&t, indexSize);
+        XLink::AddParamToHeadPointer(&t, saveTgtIndex);
+        XLink::AddParamToHeadInt(&t, copyNum);
+    }
 
     /* destroy variables */
     delete[] dimSize;
......
@@ -134,7 +134,9 @@ XTensor CopyValues(const XTensor &s, XStream * stream)
     _CopyValues(&s, &t, stream);
 
     /* tensor connection */
-    XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYVALUES);
+    if (s.enableGrad) {
+        XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYVALUES);
+    }
 
     return t;
 }
......
@@ -93,9 +93,11 @@ XTensor Gather(XTensor &s, XTensor &index)
     _Gather(&s, &t, &index);
 
     /* tensor connection */
-    XLink::MakeLink(&s, &index, &t, MOVEMENT_GATHER);
+    if (s.enableGrad) {
+        XLink::MakeLink(&s, &index, &t, MOVEMENT_GATHER);
+    }
 
     return t;
 }
 
 } // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
@@ -181,8 +181,10 @@ XTensor ReduceMax(const XTensor &input, int dim)
     _ReduceMax(&input, &output, dim);
 
     /* tensor connection */
-    XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
-    XLink::AddParamToHeadInt(&output, dim);
+    if (input.enableGrad) {
+        XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
+        XLink::AddParamToHeadInt(&output, dim);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -221,7 +223,7 @@ void ReduceMax(const XTensor &input, XTensor &output, int dim)
     /* call _ReduceMax function */
     _ReduceMax(&input, &output, dim);
 
-    if (output.enableGrad) {
+    if (input.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
         XLink::AddParamToHeadInt(&output, dim);
......
@@ -77,8 +77,10 @@ XTensor ReduceMean(const XTensor &input, int dim)
     _ReduceMean(&input, &output, dim);
 
     /* tensor connection */
-    XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN);
-    XLink::AddParamToHeadInt(&output, dim);
+    if (input.enableGrad) {
+        XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN);
+        XLink::AddParamToHeadInt(&output, dim);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -119,7 +121,7 @@ void ReduceMean(const XTensor &input, XTensor &output, int dim)
     /* call _ReduceMean function */
     _ReduceMean(&input, &output, dim);
 
-    if (output.enableGrad) {
+    if (input.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMEAN);
         XLink::AddParamToHeadInt(&output, dim);
......
@@ -306,10 +306,12 @@ XTensor ReduceSum(const XTensor &input, int dim, const XTensor &shift, DTYPE power, bool isExp)
     _ReduceSum(&input, &output, dim, &shift, power, isExp);
 
     /* tensor connection */
-    XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUM);
-    XLink::AddParamToHeadInt(&output, dim);
-    XLink::AddParamToHead(&output, power);
-    XLink::AddParamToHeadBool(&output, isExp);
+    if (input.enableGrad) {
+        XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUM);
+        XLink::AddParamToHeadInt(&output, dim);
+        XLink::AddParamToHead(&output, power);
+        XLink::AddParamToHeadBool(&output, isExp);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -341,7 +343,7 @@ void ReduceSum(const XTensor &input, XTensor &output, int dim, const XTensor &shift, DTYPE power, bool isExp)
     /* call _ReduceSum function */
     _ReduceSum(&input, &output, dim, &shift, power, isExp);
 
-    if (output.enableGrad) {
+    if (input.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUM);
         XLink::AddParamToHeadInt(&output, dim);
@@ -385,10 +387,12 @@ XTensor ReduceSum(const XTensor &input, int dim, DTYPE power, bool isExp)
     _ReduceSum(&input, &output, dim, NULL, power, isExp);
 
     /* tensor connection */
-    XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCESUM);
-    XLink::AddParamToHeadInt(&output, dim);
-    XLink::AddParamToHead(&output, power);
-    XLink::AddParamToHeadBool(&output, isExp);
+    if (input.enableGrad) {
+        XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCESUM);
+        XLink::AddParamToHeadInt(&output, dim);
+        XLink::AddParamToHead(&output, power);
+        XLink::AddParamToHeadBool(&output, isExp);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -434,7 +438,7 @@ void ReduceSum(const XTensor &input, XTensor &output, int dim, DTYPE power, bool isExp)
     /* call _ReduceSum function */
     _ReduceSum(&input, &output, dim, NULL, power, isExp);
 
-    if (output.enableGrad) {
+    if (input.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCESUM);
         XLink::AddParamToHeadInt(&output, dim);
......
@@ -73,8 +73,10 @@ XTensor ReduceSumSquared(const XTensor &input, int dim, const XTensor &shift)
     _ReduceSumSquared(&input, &output, dim, &shift);
 
     /* tensor connection */
-    XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED);
-    XLink::AddParamToHeadInt(&output, dim);
+    if (input.enableGrad) {
+        XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED);
+        XLink::AddParamToHeadInt(&output, dim);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -116,7 +118,7 @@ void ReduceSumSquared(const XTensor &input, XTensor &output, int dim, const XTensor &shift)
     /* call _ReduceSumSquared function */
     _ReduceSumSquared(&input, &output, dim, &shift);
 
-    if (output.enableGrad) {
+    if (input.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&input, &shift, &output, REDUCE_REDUCESUMSQUARED);
         XLink::AddParamToHeadInt(&output, dim);
......
@@ -76,8 +76,10 @@ XTensor ReduceVariance(const XTensor &input, int dim, const XTensor &mean)
     _ReduceVariance(&input, &output, dim, &mean);
 
     /* tensor connection */
-    XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE);
-    XLink::AddParamToHeadInt(&output, dim);
+    if (input.enableGrad) {
+        XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE);
+        XLink::AddParamToHeadInt(&output, dim);
+    }
 
     /* destroy variables */
     delete[] dimSize;
@@ -119,7 +121,7 @@ void ReduceVariance(const XTensor &input, XTensor &output, int dim, const XTensor &mean)
     /* call _ReduceVariance function */
     _ReduceVariance(&input, &output, dim, &mean);
 
-    if (output.enableGrad) {
+    if (input.enableGrad) {
         /* tensor connection */
         XLink::MakeLink(&input, &mean, &output, REDUCE_REDUCEVARIANCE);
         XLink::AddParamToHeadInt(&output, dim);
......
...@@ -99,9 +99,11 @@ XTensor Concatenate(const TensorList &smalls, int dim) ...@@ -99,9 +99,11 @@ XTensor Concatenate(const TensorList &smalls, int dim)
_Merge(&smalls, &big, dim); _Merge(&smalls, &big, dim);
/* tensor connection */ /* tensor connection */
XLink::MakeLink(&smalls, &big, SHAPE_MERGE); if (tensor->enableGrad) {
XLink::AddParamToHeadInt(&big, dim); XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
}
/* destroy variables */ /* destroy variables */
delete[] dimSize; delete[] dimSize;
@@ -127,8 +129,10 @@ XTensor Concatenate(const TensorList &smalls, int dim)
         _ConcatenateSolely(&smalls, &big, dim);

         /* tensor connection */
-        XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
-        XLink::AddParamToHeadInt(&big, dim);
+        if (tensor->enableGrad) {
+            XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
+            XLink::AddParamToHeadInt(&big, dim);
+        }

         /* destroy variables */
         delete[] dimSize;
@@ -309,9 +313,11 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
         _Merge(&smalls, &big, dim);

         /* tensor connection */
-        XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
-        XLink::AddParamToHeadInt(&big, dim);
+        if (tensor->enableGrad) {
+            XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
+            XLink::AddParamToHeadInt(&big, dim);
+        }

         /* destroy variables */
         delete[] dimSize;
@@ -337,8 +343,10 @@ XTensor Concatenate(const XTensor &smallA, const XTensor &smallB, int dim)
         _ConcatenateSolely(&smalls, &big, dim);

         /* tensor connection */
-        XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
-        XLink::AddParamToHeadInt(&big, dim);
+        if (tensor->enableGrad) {
+            XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
+            XLink::AddParamToHeadInt(&big, dim);
+        }

         /* destroy variables */
         delete[] dimSize;
......
@@ -222,9 +222,11 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim)
     _Merge(&s, &t, whereToMerge, leadingDim);

     /* tensor connections */
-    XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
-    XLink::AddParamToHeadInt(&t, whereToMerge);
-    XLink::AddParamToHeadInt(&t, leadingDim);
+    if (s.enableGrad) {
+        XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
+        XLink::AddParamToHeadInt(&t, whereToMerge);
+        XLink::AddParamToHeadInt(&t, leadingDim);
+    }

     /* destroy variables */
     delete[] dimSize;
@@ -261,7 +263,7 @@ void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim)
     /* call _Merge function */
     _Merge(&s, &t, whereToMerge, leadingDim);

-    if (t.enableGrad) {
+    if (s.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&s, NULL, &t, SHAPE_MERGE);
         XLink::AddParamToHeadInt(&t, whereToMerge);
@@ -412,8 +414,10 @@ XTensor Merge(const TensorList &smalls, int whereToMerge)
     _Merge(&smalls, &big, whereToMerge);

     /* tensor connections */
-    XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
-    XLink::AddParamToHeadInt(&big, whereToMerge);
+    if (tensor->enableGrad) {
+        XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
+        XLink::AddParamToHeadInt(&big, whereToMerge);
+    }

     /* destroy variables */
     delete[] dimSize;
@@ -453,8 +457,10 @@ XTensor Merge(const XTensor &smallA, const XTensor &smallB, int whereToMerge)
     _Merge(&smalls, &big, whereToMerge);

     /* tensor connections */
-    XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
-    XLink::AddParamToHeadInt(&big, whereToMerge);
+    if (smallA.enableGrad) {
+        XLink::MakeLink(&smalls, &big, SHAPE_MERGE_LIST);
+        XLink::AddParamToHeadInt(&big, whereToMerge);
+    }

     /* destroy variables */
     delete[] dimSize;
......
@@ -43,7 +43,9 @@ XTensor Reshape(XTensor &s, int order, int * dimSize)
     t.Reshape(order, dimSize);

     /* tensor connections */
-    XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
+    if (s.enableGrad) {
+        XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
+    }

     return t;
 }
@@ -57,7 +59,7 @@ void Reshape(XTensor &s, XTensor &t, int order, int * dimSize)
     /* call Reshape function */
     t.Reshape(order, dimSize);

-    if (t.enableGrad) {
+    if (s.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&s, NULL, &t, SHAPE_RESHAPE);
     }
......
@@ -217,9 +217,11 @@ XTensor Split(const XTensor &s, int whereToSplit, int splitNum)
     _Split(&s, &t, whereToSplit, splitNum);

     /* tensor connections */
-    XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
-    XLink::AddParamToHeadInt(&t, whereToSplit);
-    XLink::AddParamToHeadInt(&t, splitNum);
+    if (s.enableGrad) {
+        XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
+        XLink::AddParamToHeadInt(&t, whereToSplit);
+        XLink::AddParamToHeadInt(&t, splitNum);
+    }

     /* destroy variables */
     delete[] dimSize;
@@ -251,7 +253,7 @@ void Split(const XTensor &s, XTensor &t, int whereToSplit, int splitNum)
     /* call _Split function */
     _Split(&s, &t, whereToSplit, splitNum);

-    if (t.enableGrad) {
+    if (s.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&s, NULL, &t, SHAPE_SPLIT);
         XLink::AddParamToHeadInt(&t, whereToSplit);
@@ -409,12 +411,15 @@ void Split(const XTensor &big, TensorList &smalls, int whereToSplit, int splitNum)
     /* tensor connections */
     for(int i = 0; i < smalls.count; i++){
         XTensor * s = (XTensor*)smalls.Get(i);
-        XLink::MakeLink(&big, NULL, s, SHAPE_SPLIT_LIST);
-        XLink::AddParamToHeadInt(s, whereToSplit);
-
-        /* it is tricky here that we keep the id of each
-           block, rather than the total number of the splits */
-        XLink::AddParamToHeadInt(s, i);
+        if (s->enableGrad) {
+            XLink::MakeLink(&big, NULL, s, SHAPE_SPLIT_LIST);
+            XLink::AddParamToHeadInt(s, whereToSplit);
+
+            /* it is tricky here that we keep the id of each
+               block, rather than the total number of the splits */
+            XLink::AddParamToHeadInt(s, i);
+        }
     }
 }
......
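One apparent exception: the list version of Split guards on s->enableGrad, i.e. on each *output*. That is consistent rather than contradictory, because the tensors in smalls are supplied by the caller, so their flags are meaningful, and each split block can opt in or out of the graph individually. A short hypothetical sketch of the same shape (again not the real XTensor/XLink types):

/* hedged sketch: per-output guard when one input fans out to many outputs */
#include <vector>

struct Tensor {
    bool enableGrad = true;
    std::vector<Tensor*> inputs;
};

/* each caller-provided output decides for itself whether to join the graph */
void SplitInto(Tensor &big, std::vector<Tensor*> &smalls)
{
    for (Tensor *s : smalls) {
        /* ... the real kernel would copy one block of big into *s here ... */
        if (s->enableGrad)            /* output flag is meaningful: the caller set it */
            s->inputs.push_back(&big);
    }
}

int main()
{
    Tensor big, s1, s2;
    s2.enableGrad = false;            /* s2 opts out of backward */
    std::vector<Tensor*> smalls{ &s1, &s2 };
    SplitInto(big, smalls);
    return 0;                         /* s1 records one edge, s2 records none */
}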
@@ -121,7 +121,9 @@ XTensor Squeeze(XTensor & source, int leadingDim)
     _Squeeze(&source, &target, leadingDim);

     /* tensor connections */
-    XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
+    if (source.enableGrad) {
+        XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
+    }

     return target;
 }
@@ -135,7 +137,7 @@ void Squeeze(XTensor & source, XTensor & target, int leadingDim)
     /* call _Squeeze function */
     _Squeeze(&source, &target, leadingDim);

-    if (target.enableGrad) {
+    if (source.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&source, NULL, &target, SHAPE_SQUEEZE);
     }
......
@@ -144,9 +144,11 @@ XTensor Transpose(const XTensor &a, const int i, const int j)
     _Transpose(&a, &b, i, j);

     /* tensor connection */
-    XLink::MakeLink(&a, NULL, &b, SHAPE_TRANSPOSE);
-    XLink::AddParamToHeadInt(&b, i);
-    XLink::AddParamToHeadInt(&b, j);
+    if (a.enableGrad) {
+        XLink::MakeLink(&a, NULL, &b, SHAPE_TRANSPOSE);
+        XLink::AddParamToHeadInt(&b, i);
+        XLink::AddParamToHeadInt(&b, j);
+    }

     /* destroy variables */
     delete[] dimSize;
......
@@ -156,9 +156,11 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
     _Unsqueeze(&a, &b, dim, dSize);

     /* tensor connections */
-    XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
-    XLink::AddParamToHeadInt(&b, dim);
-    XLink::AddParamToHeadInt(&b, dSize);
+    if (a.enableGrad) {
+        XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
+        XLink::AddParamToHeadInt(&b, dim);
+        XLink::AddParamToHeadInt(&b, dSize);
+    }

     /* destroy variables */
     delete[] dimSize;
@@ -191,7 +193,7 @@ void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize)
     /* call _Unsqueeze function */
     _Unsqueeze(&a, &b, dim, dSize);

-    if (b.enableGrad) {
+    if (a.enableGrad) {
         /* tensor connections */
         XLink::MakeLink(&a, NULL, &b, SHAPE_UNSQUEEZE);
         XLink::AddParamToHeadInt(&b, dim);
......
@@ -81,8 +81,10 @@ XTensor DropoutWithIndex(const XTensor &x, XTensor &maskIndex, DTYPE scale)
     _ScaleAndShiftMe(&c, scale);

     /* tensor connections */
-    XLink::MakeLink(&x, &maskIndex, &c, MOVEMENT_DROPOUTWITHINDEX);
-    XLink::AddParamToHead(&c, scale);
+    if (x.enableGrad) {
+        XLink::MakeLink(&x, &maskIndex, &c, MOVEMENT_DROPOUTWITHINDEX);
+        XLink::AddParamToHead(&c, scale);
+    }

     return c;
 }
......
@@ -78,7 +78,9 @@ XTensor HardTanH(const XTensor &x)
     _HardTanH(&x, &y);

     /* tensor connection */
-    XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH);
+    if (x.enableGrad) {
+        XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH);
+    }

     return y;
 }
@@ -92,7 +94,7 @@ void HardTanH(const XTensor &x, XTensor &y)
     /* call _HardTanH function */
     _HardTanH(&x, &y);

-    if (y.enableGrad) {
+    if (x.enableGrad) {
         /* tensor connection */
         XLink::MakeLink(&x, NULL, &y, FUNC_HARDTANH);
     }
......
@@ -54,7 +54,9 @@ XTensor Identity(const XTensor &x)
     _Identity(&x, &y);

     /* tensor connection */
-    XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
+    if (x.enableGrad) {
+        XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
+    }

     return y;
 }
@@ -68,7 +70,7 @@ void Identity(const XTensor &x, XTensor &y)
     /* call _Identity function */
     _Identity(&x, &y);

-    if (y.enableGrad) {
+    if (x.enableGrad) {
         /* tensor connection */
         XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
     }
......
@@ -188,8 +188,10 @@ XTensor LogSoftmax(const XTensor &x, int leadDim)
     _LogSoftmax(&x, &y, ld);

     /* tensor connection */
-    XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
-    XLink::AddParamToHeadInt(&y, ld);
+    if (x.enableGrad) {
+        XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
+        XLink::AddParamToHeadInt(&y, ld);
+    }

     return y;
 }
@@ -215,7 +217,7 @@ void LogSoftmax(const XTensor &x, XTensor &y, int leadDim)
     /* call _LogSoftmax function */
     _LogSoftmax(&x, &y, ld);

-    if (y.enableGrad) {
+    if (x.enableGrad) {
         /* tensor connection */
         XLink::MakeLink(&x, NULL, &y, FUNC_LOGSOFTMAX);
         XLink::AddParamToHeadInt(&y, ld);
......
@@ -70,7 +70,9 @@ XTensor Rectify(const XTensor &x)
     _Rectify(&x, &y);

     /* tensor connection */
-    XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY);
+    if (x.enableGrad) {
+        XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY);
+    }

     return y;
 }
@@ -84,7 +86,7 @@ void Rectify(const XTensor &x, XTensor &y)
     /* call _Rectify function */
     _Rectify(&x, &y);

-    if (y.enableGrad) {
+    if (x.enableGrad) {
         /* tensor connection */
         XLink::MakeLink(&x, NULL, &y, FUNC_RECTIFY);
     }
......
@@ -73,7 +73,9 @@ XTensor Sigmoid(const XTensor &x)
     _Sigmoid(&x, &y);

     /* tensor connection */
-    XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID);
+    if (x.enableGrad) {
+        XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID);
+    }

     return y;
 }
@@ -87,7 +89,7 @@ void Sigmoid(const XTensor &x, XTensor &y)
     /* call _Sigmoid function */
     _Sigmoid(&x, &y);

-    if (y.enableGrad) {
+    if (x.enableGrad) {
         /* tensor connection */
         XLink::MakeLink(&x, NULL, &y, FUNC_SIGMOID);
     }
......
@@ -142,8 +142,10 @@ XTensor Softmax(const XTensor &x, int leadDim)
     _Softmax(&x, &y, ld);

     /* tensor connection */
-    XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
-    XLink::AddParamToHeadInt(&y, ld);
+    if (x.enableGrad) {
+        XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
+        XLink::AddParamToHeadInt(&y, ld);
+    }

     return y;
 }
@@ -161,7 +163,7 @@ void Softmax(const XTensor &x, XTensor &y, int leadDim)
     /* call _Softmax function */
     _Softmax(&x, &y, ld);

-    if (y.enableGrad) {
+    if (x.enableGrad) {
         /* tensor connection */
         XLink::MakeLink(&x, NULL, &y, FUNC_SOFTMAX);
         XLink::AddParamToHeadInt(&y, ld);
......
@@ -277,8 +277,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
     tails.Add((XTensor*)&gold);
     tails.Add(weight);
     tails.Add(padding);
-    XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
-    XLink::AddParamToHeadInt(&loss, dim);
+
+    if (output.enableGrad) {
+        XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
+        XLink::AddParamToHeadInt(&loss, dim);
+    }

     return loss;
 }
@@ -302,8 +305,11 @@ XTensor CrossEntropy(const XTensor & output, const XTensor & gold,
     tails.Add((XTensor*)&gold);
     tails.Add(weight);
     tails.Add((XTensor*)&padding);
-    XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
-    XLink::AddParamToHeadInt(&loss, dim);
+
+    if (output.enableGrad) {
+        XLink::MakeLink(&tails, &loss, LOSS_CROSSENTROPY);
+        XLink::AddParamToHeadInt(&loss, dim);
+    }

     return loss;
 }
......
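A naming trap in the CrossEntropy hunks: here `output` is the operator's *input* (the network's predictions), while `loss` is the result, so `if (output.enableGrad)` still follows the guard-on-input rule. For operators whose inputs are gathered into a TensorList, guarding on the primary differentiable input is typically enough, since `gold`, `weight`, and `padding` usually carry no gradients. A hedged sketch with illustrative names, not the library's API:

/* hedged sketch: multi-input op recording one hyper edge */
#include <vector>

struct Tensor {
    bool enableGrad = true;
    std::vector<Tensor*> inputs;
};

/* loss = f(pred, gold): usually only pred carries gradients */
Tensor Loss(Tensor &pred, Tensor &gold)
{
    Tensor loss;
    /* ... the real kernel would compute the scalar loss here ... */

    if (pred.enableGrad) {            /* guard on the differentiable input */
        loss.inputs.push_back(&pred);
        loss.inputs.push_back(&gold); /* recorded, but gold gets no gradient */
    }
    else {
        loss.enableGrad = false;
    }
    return loss;
}

int main()
{
    Tensor pred, gold;
    gold.enableGrad = false;
    Tensor l = Loss(pred, gold);
    return l.inputs.size() == 2 ? 0 : 1;   /* edge recorded because pred needs grad */
}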
@@ -421,7 +421,7 @@ bool TestSetData6()
     for (int i = 0; i < order; i++)
         unitNum *= dimSize[i];

-    DTYPE answer[5] = { 5.2, 3.2, 1.2, -0.8, -2.8 };
+    DTYPE answer[5] = {5.2F, 3.2F, 1.2F, -0.8F, -2.8F};

     /* CPU test */
     bool cpuTest = true;
......
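The TestSetData6 fix is unrelated to autograd: unsuffixed literals such as 5.2 have type double, so initializing a float-typed DTYPE array from them silently truncates, and some compilers warn about the implicit double-to-float conversion. The F suffix makes the literals float to begin with. A two-line illustration, assuming DTYPE aliases float as in the default build:

/* hedged illustration; the typedef is an assumption about the build config */
typedef float DTYPE;

DTYPE withDoubles[5] = { 5.2, 3.2, 1.2, -0.8, -2.8 };       /* double literals:
                                                               truncation on conversion */
DTYPE withFloats[5]  = { 5.2F, 3.2F, 1.2F, -0.8F, -2.8F };  /* float literals: no conversion */

int main() { return 0; }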