Commit 1f5627f9 by huchi

refactor XList: now a list has the same basic function as a STL vector

parent 72c9551b
...@@ -53,6 +53,9 @@ void TestDataManager() { ...@@ -53,6 +53,9 @@ void TestDataManager() {
s[9] = 0; s[9] = 0;
str.Add(s); str.Add(s);
cout << str.Get(0); cout << str.Get(0);
vector<int> x;
} }
int main() int main()
......
...@@ -25,21 +25,25 @@ ...@@ -25,21 +25,25 @@
#include "XMem.h" #include "XMem.h"
#include "XGlobal.h" #include "XGlobal.h"
#include <utility>
#ifndef __XLIST_H__ #ifndef __XLIST_H__
#define __XLIST_H__ #define __XLIST_H__
/* the nts (NiuTrans.Tensor) namespace */ /* the nts (NiuTrans.Tensor) namespace */
namespace nts { namespace nts {
/* the XListType class */ /* the XListBase class */
template <typename T> template <typename T>
class XListType { struct XListBase {
public: public:
typedef int (*ListCompare)(const T* item1, const T* item2);
/* compare function */
typedef int (*ListCompare)(const T item1, const T item2);
/* data items */ /* data items */
T** items; T *items;
/* number of items */ /* number of items */
int count; int count;
...@@ -50,42 +54,48 @@ public: ...@@ -50,42 +54,48 @@ public:
/* the memory pool for data array allocation */ /* the memory pool for data array allocation */
XMem* mem; XMem* mem;
/* indicates whether data items are integers */
bool isIntList;
public: public:
/* constructor */ /* constructor */
XListType(); XListBase();
/* constructor */ /* constructor */
XListType(int myMaxNum, bool isIntListOrNot = false); XListBase(int myMaxNum);
/* constructor */ /* constructor */
XListType(int myMaxNum, XMem* myMem, bool isIntListOrNot = false); XListBase(int myMaxNum, XMem* myMem);
/* de-constructor */ /* de-constructor */
~XListType(); ~XListBase();
/* add an item into the list */ /* add an item into the list */
void Add(T* item); void Add(T&& item);
/* add an item into the list */
void Add(const T& item);
/* add a number of items into the list */ /* add a number of items into the list */
void Add(T** inputItems, int inputItemCount); void Add(T* inputItems, int inputItemCount);
/* append a list to the current list */ /* append a list to the current list */
void AddList(XListType* l); void AddList(XListBase* l);
/* insert an item to the given position of the list */
void Insert(int pos, const T& item);
/* insert an item to the given position of the list */ /* insert an item to the given position of the list */
void Insert(int pos, T* item); void Insert(int pos, T&& item);
/* get the item at position i */ /* get the item at position i */
T* GetItem(int i) const; T& GetItem(int i) const;
/* set the item at position i */ /* set the item at position i */
void SetItem(int i, T* item); void SetItem(int i, const T& item);
/* set the item at position i */
void SetItem(int i, T&& item);
/* find the position of the first matched item */ /* find the position of the first matched item */
int FindFirst(T* item); int FindFirst(const T& item);
/* clear the data array */ /* clear the data array */
void Clear(); void Clear();
...@@ -100,25 +110,27 @@ public: ...@@ -100,25 +110,27 @@ public:
void Remove(int i); void Remove(int i);
/* copy the list */ /* copy the list */
XListType* Copy(XMem* myMem); XListBase* Copy(XMem* myMem);
/* shuffle the list */ /* shuffle the list */
void Shuffle(int nround = 10, int beg = -1, int len = 0); void Shuffle(int nround = 10, int beg = -1, int len = 0);
/* short */ /* short */
T* Get(int i) { return GetItem(i); }; T& operator[] (int i) {
void Set(int i, T* item) { SetItem(i, item); }; return GetItem(i);
};
T& Get(int i) { return GetItem(i); };
void Set(int i, T item) { SetItem(i, item); };
}; };
/* constructor */ /* constructor */
template <typename T> template <typename T>
XListType<T>::XListType() XListBase<T>::XListBase()
{ {
mem = NULL; mem = NULL;
maxNum = 0; maxNum = 0;
count = 0; count = 0;
items = NULL; items = NULL;
isIntList = false;
} }
/* /*
...@@ -127,13 +139,12 @@ constructor ...@@ -127,13 +139,12 @@ constructor
>> isIntListOrNot - specify if the list keeps int items >> isIntListOrNot - specify if the list keeps int items
*/ */
template <typename T> template <typename T>
XListType<T>::XListType(int myMaxNum, bool isIntListOrNot) XListBase<T>::XListBase(int myMaxNum)
{ {
mem = NULL; mem = NULL;
maxNum = myMaxNum; maxNum = myMaxNum;
count = 0; count = 0;
items = new T*[myMaxNum]; items = new T[myMaxNum];
isIntList = isIntListOrNot;
} }
/* /*
...@@ -143,44 +154,36 @@ constructor ...@@ -143,44 +154,36 @@ constructor
>> isIntListOrNot - specify if the list keeps int items >> isIntListOrNot - specify if the list keeps int items
*/ */
template <typename T> template <typename T>
XListType<T>::XListType(int myMaxNum, XMem* myMem, bool isIntListOrNot) XListBase<T>::XListBase(int myMaxNum, XMem* myMem)
{ {
mem = myMem; mem = myMem;
maxNum = myMaxNum; maxNum = myMaxNum;
count = 0; count = 0;
items = (T**)mem->Alloc(mem->devID, sizeof(T*) * maxNum); items = (T*)mem->Alloc(mem->devID, sizeof(T) * maxNum);
isIntList = isIntListOrNot;
} }
/* de-constructor */ /* de-constructor */
template <typename T> template <typename T>
XListType<T>::~XListType() XListBase<T>::~XListBase()
{ {
if (isIntList) { delete[] items;
for (int i = 0; i < count; i++) {
int* p = (int*)items[i];
delete[] p;
}
}
if (mem == NULL)
delete[] items;
} }
/* /*
add an item into the list add an item into the list
>> item - pointer to the item >> item - a right value
*/ */
template <typename T> template <typename T>
void XListType<T>::Add(T* item) void XListBase<T>::Add(T&& item)
{ {
if (count == maxNum) { if (count == maxNum) {
T** newItems; T* newItems;
if (mem == NULL) if (mem == NULL)
newItems = new T*[maxNum * 2 + 1]; newItems = new T[maxNum * 2 + 1];
else else
newItems = (T**)mem->Alloc(mem->devID, sizeof(T*) * (maxNum * 2 + 1)); newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T*) * maxNum); memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL) if (mem == NULL)
delete[] items; delete[] items;
items = newItems; items = newItems;
...@@ -190,28 +193,51 @@ void XListType<T>::Add(T* item) ...@@ -190,28 +193,51 @@ void XListType<T>::Add(T* item)
items[count++] = item; items[count++] = item;
} }
/*
add an item into the list
>> item - a const reference to the item
*/
template <typename T>
void XListBase<T>::Add(const T& item)
{
if (count == maxNum) {
T* newItems;
if (mem == NULL)
newItems = new T[maxNum * 2 + 1];
else
newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL)
delete[] items;
items = newItems;
maxNum = maxNum * 2 + 1;
}
items[count++] = item;
}
/* /*
add a number of items into the list add a number of items into the list
>> inputItems - pointer to the array of items >> inputItems - pointer to the array of items
>> inputItemCount - number of input items >> inputItemCount - number of input items
*/ */
template <typename T> template <typename T>
void XListType<T>::Add(T** inputItems, int inputItemCount) void XListBase<T>::Add(T* inputItems, int inputItemCount)
{ {
if (count + inputItemCount >= maxNum) { if (count + inputItemCount >= maxNum) {
int newMaxNum = (count + inputItemCount) * 2 + 1; int newMaxNum = (count + inputItemCount) * 2 + 1;
T** newItems; T* newItems;
if (mem == NULL) if (mem == NULL)
newItems = new T*[newMaxNum]; newItems = new T[newMaxNum];
else else
newItems = (T**)mem->Alloc(mem->devID, sizeof(T*) * newMaxNum); newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * newMaxNum);
memcpy(newItems, items, sizeof(T*) * maxNum); memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL) if (mem == NULL)
delete[] items; delete[] items;
items = newItems; items = newItems;
maxNum = newMaxNum; maxNum = newMaxNum;
} }
memcpy(items + count, inputItems, sizeof(T*) * inputItemCount); memcpy(items + count, inputItems, sizeof(T) * inputItemCount);
count += inputItemCount; count += inputItemCount;
} }
...@@ -220,7 +246,7 @@ append a list to the current list ...@@ -220,7 +246,7 @@ append a list to the current list
>> l - the list we use to append >> l - the list we use to append
*/ */
template <typename T> template <typename T>
void XListType<T>::AddList(XListType* l) void XListBase<T>::AddList(XListBase* l)
{ {
Add(l->items, l->count); Add(l->items, l->count);
} }
...@@ -231,15 +257,15 @@ insert an item to the given position of the list ...@@ -231,15 +257,15 @@ insert an item to the given position of the list
>> item - the item for insertion >> item - the item for insertion
*/ */
template <typename T> template <typename T>
void XListType<T>::Insert(int pos, T* item) void XListBase<T>::Insert(int pos, const T& item)
{ {
if (count == maxNum) { if (count == maxNum) {
T** newItems; T* newItems;
if (mem == NULL) if (mem == NULL)
newItems = new T*[maxNum * 2 + 1]; newItems = new T[maxNum * 2 + 1];
else else
newItems = (T**)mem->Alloc(mem->devID, sizeof(T*) * (maxNum * 2 + 1)); newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T*) * maxNum); memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL) if (mem == NULL)
delete[] items; delete[] items;
items = newItems; items = newItems;
...@@ -252,9 +278,31 @@ void XListType<T>::Insert(int pos, T* item) ...@@ -252,9 +278,31 @@ void XListType<T>::Insert(int pos, T* item)
count++; count++;
} }
template<typename T>
void XListBase<T>::Insert(int pos, T&& item)
{
if (count == maxNum) {
T* newItems;
if (mem == NULL)
newItems = new T[maxNum * 2 + 1];
else
newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL)
delete[] items;
items = newItems;
maxNum = maxNum * 2 + 1;
}
for (int i = count - 1; i >= pos; i--)
items[i + 1] = items[i];
items[pos] = item;
count++;
}
/* get the item at position i */ /* get the item at position i */
template <typename T> template <typename T>
T* XListType<T>::GetItem(int i) const T& XListBase<T>::GetItem(int i) const
{ {
CheckNTErrors(i >= -1 && i < count, "Index of a list item is out of scope!"); CheckNTErrors(i >= -1 && i < count, "Index of a list item is out of scope!");
CheckNTErrors(count > 0, "Cannt index the item in an empty list!"); CheckNTErrors(count > 0, "Cannt index the item in an empty list!");
...@@ -266,12 +314,19 @@ T* XListType<T>::GetItem(int i) const ...@@ -266,12 +314,19 @@ T* XListType<T>::GetItem(int i) const
/* set the item at position i */ /* set the item at position i */
template <typename T> template <typename T>
void XListType<T>::SetItem(int i, T* item) void XListBase<T>::SetItem(int i, const T& item)
{ {
if (i >= 0 && i < count) if (i >= 0 && i < count)
items[i] = item; items[i] = item;
} }
template<typename T>
inline void XListBase<T>::SetItem(int i, T&& item)
{
if (i >= 0 && i < count)
items[i] = std::move(item);
}
/* /*
find the position of the first matched item find the position of the first matched item
>> item - the item for matching >> item - the item for matching
...@@ -279,7 +334,7 @@ find the position of the first matched item ...@@ -279,7 +334,7 @@ find the position of the first matched item
*/ */
template <typename T> template <typename T>
int XListType<T>::FindFirst(T* item) int XListBase<T>::FindFirst(const T& item)
{ {
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
if (item == items[i]) if (item == items[i])
...@@ -290,15 +345,10 @@ int XListType<T>::FindFirst(T* item) ...@@ -290,15 +345,10 @@ int XListType<T>::FindFirst(T* item)
/* clear the data array */ /* clear the data array */
template <typename T> template <typename T>
void XListType<T>::Clear() void XListBase<T>::Clear()
{ {
if (isIntList) { delete[] items;
for (int i = 0; i < count; i++) { count = 0;
delete[](int*) items[i];
}
count = 0;
} else
count = 0;
} }
/* /*
...@@ -307,26 +357,26 @@ sort the list ...@@ -307,26 +357,26 @@ sort the list
>> comp - the comparison function used in sorting >> comp - the comparison function used in sorting
*/ */
template <typename T> template <typename T>
void XListType<T>::Sort(int itemSize, ListCompare comp) void XListBase<T>::Sort(int itemSize, ListCompare comp)
{ {
qsort(items, count, itemSize, comp); qsort(items, count, itemSize, comp);
} }
/* reverse the list */ /* reverse the list */
template <typename T> template <typename T>
void XListType<T>::Reverse() void XListBase<T>::Reverse()
{ {
int half = count / 2; int half = count / 2;
for (int i = 0; i < half; i++) { for (int i = 0; i < half; i++) {
T* tmp = items[i]; T tmp(std::move(items[i]));
items[i] = items[count - i - 1]; items[i] = std::move(items[count - i - 1]);
items[count - i - 1] = tmp; items[count - i - 1] = std::move(tmp);
} }
} }
/* remove the item at position i */ /* remove the item at position i */
template <typename T> template <typename T>
void XListType<T>::Remove(int i) void XListBase<T>::Remove(int i)
{ {
if (i >= count || i < 0) if (i >= count || i < 0)
return; return;
...@@ -342,9 +392,9 @@ copy the list ...@@ -342,9 +392,9 @@ copy the list
<< hard copy of the list << hard copy of the list
*/ */
template <typename T> template <typename T>
XListType<T>* XListType<T>::Copy(XMem* myMem) XListBase<T>* XListBase<T>::Copy(XMem* myMem)
{ {
XListType<T>* newList = new XListType<T>(maxNum, myMem); XListBase<T>* newList = new XListBase<T>(maxNum, myMem);
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
newList->Add(GetItem(i)); newList->Add(GetItem(i));
} }
...@@ -358,7 +408,7 @@ shuffle the list ...@@ -358,7 +408,7 @@ shuffle the list
>> len - how many items are used in shuffling >> len - how many items are used in shuffling
*/ */
template <typename T> template <typename T>
void XListType<T>::Shuffle(int nround, int beg, int len) void XListBase<T>::Shuffle(int nround, int beg, int len)
{ {
if (beg < 0) { if (beg < 0) {
beg = 0; beg = 0;
...@@ -375,7 +425,7 @@ void XListType<T>::Shuffle(int nround, int beg, int len) ...@@ -375,7 +425,7 @@ void XListType<T>::Shuffle(int nround, int beg, int len)
for (int i = 0; i < len; i++) { for (int i = 0; i < len; i++) {
float a = (float)rand() / RAND_MAX; float a = (float)rand() / RAND_MAX;
size_t j = (unsigned int)(a * (i + 1)); size_t j = (unsigned int)(a * (i + 1));
T* t = items[beg + j]; T t = items[beg + j];
items[beg + j] = items[beg + i]; items[beg + j] = items[beg + i];
items[beg + i] = t; items[beg + i] = t;
} }
...@@ -385,12 +435,12 @@ void XListType<T>::Shuffle(int nround, int beg, int len) ...@@ -385,12 +435,12 @@ void XListType<T>::Shuffle(int nround, int beg, int len)
struct XTensor; struct XTensor;
/* typedef for list */ /* typedef for list */
typedef XListType<int> IntList; typedef XListBase<int> IntList;
typedef XListType<char> CharList; typedef XListBase<char*> CharList;
typedef XListType<long> LongList; typedef XListBase<long> LongList;
typedef XListType<float> FloatList; typedef XListBase<float> FloatList;
typedef XListType<short> ShortList; typedef XListBase<short> ShortList;
typedef XListType<XTensor> XList; typedef XListBase<XTensor*> XList;
} /* end of the nts (NiuTrans.Tensor) namespace */ } /* end of the nts (NiuTrans.Tensor) namespace */
......
...@@ -51,10 +51,10 @@ void _MatrixMul2DMultiTheading(XList * args) ...@@ -51,10 +51,10 @@ void _MatrixMul2DMultiTheading(XList * args)
XTensor * c = matrixArgs->GetItem(2); XTensor * c = matrixArgs->GetItem(2);
DTYPE alpha = *(DTYPE*)(matrixArgs->GetItem(3)); DTYPE alpha = *(DTYPE*)(matrixArgs->GetItem(3));
DTYPE beta = *(DTYPE*)(matrixArgs->GetItem(4)); DTYPE beta = *(DTYPE*)(matrixArgs->GetItem(4));
int x1 = *(indexArgs->GetItem(0)); int x1 = indexArgs->GetItem(0);
int y1 = *(indexArgs->GetItem(1)); int y1 = indexArgs->GetItem(1);
int x2 = *(indexArgs->GetItem(2)); int x2 = indexArgs->GetItem(2);
int y2 = *(indexArgs->GetItem(3)); int y2 = indexArgs->GetItem(3);
#ifdef FAST_MATRIX #ifdef FAST_MATRIX
int am = a->dimSize[1]; int am = a->dimSize[1];
......
...@@ -217,15 +217,15 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge) ...@@ -217,15 +217,15 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
bool uniform = true; bool uniform = true;
int mergeNum = smalls->count; int mergeNum = smalls->count;
XTensor* smallsItem0 = (XTensor*)(smalls->GetItem(0)); XTensor* smallsItem0 = smalls->GetItem(0);
int itemSize = smallsItem0->unitNum * smallsItem0->unitSize; int itemSize = smallsItem0->unitNum * smallsItem0->unitSize;
for (int i = 0; i < smalls->count; i++) { for (int i = 0; i < smalls->count; i++) {
XTensor* smallsItem = (XTensor*)smalls->GetItem(i); XTensor* smallsItem = smalls->GetItem(i);
CheckNTErrors((big->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!"); CheckNTErrors((big->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!");
if (i > 0) { if (i > 0) {
XTensor * preItem = (XTensor*)smalls->GetItem(i - 1); XTensor * preItem = smalls->GetItem(i - 1);
if (smallsItem->unitNum * smallsItem->unitSize != (char*)smallsItem->data - (char*)preItem->data) if (smallsItem->unitNum * smallsItem->unitSize != (char*)smallsItem->data - (char*)preItem->data)
uniform = false; uniform = false;
} }
...@@ -237,7 +237,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge) ...@@ -237,7 +237,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
int gridNum = 1; int gridNum = 1;
int mergedNum = smalls->count; int mergedNum = smalls->count;
XTensor * s0 = (XTensor*)smalls->GetItem(0); XTensor * s0 = smalls->GetItem(0);
int whereToMergeRDI = s0->order - whereToMerge - 1; int whereToMergeRDI = s0->order - whereToMerge - 1;
for (int i = 0; i < s0->order; i++) { for (int i = 0; i < s0->order; i++) {
if (i <= whereToMergeRDI) if (i <= whereToMergeRDI)
...@@ -263,7 +263,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge) ...@@ -263,7 +263,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
for (int g = 0; g < gridNum; g++) { for (int g = 0; g < gridNum; g++) {
char * tData = (char*)big->data + g * blockSize * blockNum * big->unitSize; char * tData = (char*)big->data + g * blockSize * blockNum * big->unitSize;
for (int k = 0; k < mergedNum; k++) { for (int k = 0; k < mergedNum; k++) {
XTensor * s = (XTensor*)smalls->GetItem(k); XTensor * s = smalls->GetItem(k);
char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize; char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize;
XMemCopy2D(tData + k * tStep, tPtich, big->devID, XMemCopy2D(tData + k * tStep, tPtich, big->devID,
sData + k * sStep, sPitch, s->devID, sData + k * sStep, sPitch, s->devID,
...@@ -295,7 +295,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge) ...@@ -295,7 +295,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
/* copy from source to tmp */ /* copy from source to tmp */
if (!uniform) { if (!uniform) {
for (int i = 0; i < mergeNum; i++) { for (int i = 0; i < mergeNum; i++) {
XTensor* smallsItem = (XTensor*)smalls->GetItem(i); XTensor* smallsItem = smalls->GetItem(i);
XMemCopy((char*)(tensorTMP->data) + (itemSize * i), tensorTMP->devID, smallsItem->data, smallsItem->devID, itemSize); XMemCopy((char*)(tensorTMP->data) + (itemSize * i), tensorTMP->devID, smallsItem->data, smallsItem->devID, itemSize);
} }
} }
...@@ -324,7 +324,7 @@ make a new tensor to keep the result and return it ...@@ -324,7 +324,7 @@ make a new tensor to keep the result and return it
*/ */
XTensor Merge(const XList &smalls, int whereToMerge) XTensor Merge(const XList &smalls, int whereToMerge)
{ {
XTensor * tensor = (XTensor*)smalls.GetItem(0); XTensor * tensor = smalls.GetItem(0);
int order = tensor->order; int order = tensor->order;
int * dimSize = new int[order]; int * dimSize = new int[order];
for (int i = 0; i < tensor->order; i++) { for (int i = 0; i < tensor->order; i++) {
......
...@@ -81,10 +81,10 @@ void RunParallel2D(XPRunner * parallelRunner, void * job, ...@@ -81,10 +81,10 @@ void RunParallel2D(XPRunner * parallelRunner, void * job,
XList * blockArgs = new XList(argNum); XList * blockArgs = new XList(argNum);
int * blockIndex = indexList + i * 4; int * blockIndex = indexList + i * 4;
indexArgs->Add(blockIndex); indexArgs->Add(blockIndex[0]);
indexArgs->Add(blockIndex + 1); indexArgs->Add(blockIndex[1]);
indexArgs->Add(blockIndex + 2); indexArgs->Add(blockIndex[2]);
indexArgs->Add(blockIndex + 3); indexArgs->Add(blockIndex[3]);
for (int j = 0; j < argNum; j++) for (int j = 0; j < argNum; j++)
blockArgs->Add(jobArgList->GetItem(j)); blockArgs->Add(jobArgList->GetItem(j));
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论