Commit 1f5627f9 by huchi

refactor XList: now a list has the same basic function as a STL vector

parent 72c9551b
......@@ -53,6 +53,9 @@ void TestDataManager() {
s[9] = 0;
str.Add(s);
cout << str.Get(0);
vector<int> x;
}
int main()
......
......@@ -25,6 +25,8 @@
#include "XMem.h"
#include "XGlobal.h"
#include <utility>
#ifndef __XLIST_H__
#define __XLIST_H__
......@@ -32,14 +34,16 @@
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
/* the XListType class */
/* the XListBase class */
template <typename T>
class XListType {
struct XListBase {
public:
typedef int (*ListCompare)(const T* item1, const T* item2);
/* compare function */
typedef int (*ListCompare)(const T item1, const T item2);
/* data items */
T** items;
T *items;
/* number of items */
int count;
......@@ -50,42 +54,48 @@ public:
/* the memory pool for data array allocation */
XMem* mem;
/* indicates whether data items are integers */
bool isIntList;
public:
/* constructor */
XListType();
XListBase();
/* constructor */
XListType(int myMaxNum, bool isIntListOrNot = false);
XListBase(int myMaxNum);
/* constructor */
XListType(int myMaxNum, XMem* myMem, bool isIntListOrNot = false);
XListBase(int myMaxNum, XMem* myMem);
/* de-constructor */
~XListType();
~XListBase();
/* add an item into the list */
void Add(T* item);
void Add(T&& item);
/* add an item into the list */
void Add(const T& item);
/* add a number of items into the list */
void Add(T** inputItems, int inputItemCount);
void Add(T* inputItems, int inputItemCount);
/* append a list to the current list */
void AddList(XListType* l);
void AddList(XListBase* l);
/* insert an item to the given position of the list */
void Insert(int pos, const T& item);
/* insert an item to the given position of the list */
void Insert(int pos, T* item);
void Insert(int pos, T&& item);
/* get the item at position i */
T* GetItem(int i) const;
T& GetItem(int i) const;
/* set the item at position i */
void SetItem(int i, T* item);
void SetItem(int i, const T& item);
/* set the item at position i */
void SetItem(int i, T&& item);
/* find the position of the first matched item */
int FindFirst(T* item);
int FindFirst(const T& item);
/* clear the data array */
void Clear();
......@@ -100,25 +110,27 @@ public:
void Remove(int i);
/* copy the list */
XListType* Copy(XMem* myMem);
XListBase* Copy(XMem* myMem);
/* shuffle the list */
void Shuffle(int nround = 10, int beg = -1, int len = 0);
/* short */
T* Get(int i) { return GetItem(i); };
void Set(int i, T* item) { SetItem(i, item); };
T& operator[] (int i) {
return GetItem(i);
};
T& Get(int i) { return GetItem(i); };
void Set(int i, T item) { SetItem(i, item); };
};
/* constructor */
template <typename T>
XListType<T>::XListType()
XListBase<T>::XListBase()
{
mem = NULL;
maxNum = 0;
count = 0;
items = NULL;
isIntList = false;
}
/*
......@@ -127,13 +139,12 @@ constructor
>> isIntListOrNot - specify if the list keeps int items
*/
template <typename T>
XListType<T>::XListType(int myMaxNum, bool isIntListOrNot)
XListBase<T>::XListBase(int myMaxNum)
{
mem = NULL;
maxNum = myMaxNum;
count = 0;
items = new T*[myMaxNum];
isIntList = isIntListOrNot;
items = new T[myMaxNum];
}
/*
......@@ -143,44 +154,59 @@ constructor
>> isIntListOrNot - specify if the list keeps int items
*/
template <typename T>
XListType<T>::XListType(int myMaxNum, XMem* myMem, bool isIntListOrNot)
XListBase<T>::XListBase(int myMaxNum, XMem* myMem)
{
mem = myMem;
maxNum = myMaxNum;
count = 0;
items = (T**)mem->Alloc(mem->devID, sizeof(T*) * maxNum);
isIntList = isIntListOrNot;
items = (T*)mem->Alloc(mem->devID, sizeof(T) * maxNum);
}
/* de-constructor */
template <typename T>
XListType<T>::~XListType()
XListBase<T>::~XListBase()
{
if (isIntList) {
for (int i = 0; i < count; i++) {
int* p = (int*)items[i];
delete[] p;
}
}
if (mem == NULL)
delete[] items;
}
/*
add an item into the list
>> item - pointer to the item
>> item - a right value
*/
template <typename T>
void XListBase<T>::Add(T&& item)
{
if (count == maxNum) {
T* newItems;
if (mem == NULL)
newItems = new T[maxNum * 2 + 1];
else
newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL)
delete[] items;
items = newItems;
maxNum = maxNum * 2 + 1;
}
items[count++] = item;
}
/*
add an item into the list
>> item - a const reference to the item
*/
template <typename T>
void XListType<T>::Add(T* item)
void XListBase<T>::Add(const T& item)
{
if (count == maxNum) {
T** newItems;
T* newItems;
if (mem == NULL)
newItems = new T*[maxNum * 2 + 1];
newItems = new T[maxNum * 2 + 1];
else
newItems = (T**)mem->Alloc(mem->devID, sizeof(T*) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T*) * maxNum);
newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL)
delete[] items;
items = newItems;
......@@ -196,22 +222,22 @@ add a number of items into the list
>> inputItemCount - number of input items
*/
template <typename T>
void XListType<T>::Add(T** inputItems, int inputItemCount)
void XListBase<T>::Add(T* inputItems, int inputItemCount)
{
if (count + inputItemCount >= maxNum) {
int newMaxNum = (count + inputItemCount) * 2 + 1;
T** newItems;
T* newItems;
if (mem == NULL)
newItems = new T*[newMaxNum];
newItems = new T[newMaxNum];
else
newItems = (T**)mem->Alloc(mem->devID, sizeof(T*) * newMaxNum);
memcpy(newItems, items, sizeof(T*) * maxNum);
newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * newMaxNum);
memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL)
delete[] items;
items = newItems;
maxNum = newMaxNum;
}
memcpy(items + count, inputItems, sizeof(T*) * inputItemCount);
memcpy(items + count, inputItems, sizeof(T) * inputItemCount);
count += inputItemCount;
}
......@@ -220,7 +246,7 @@ append a list to the current list
>> l - the list we use to append
*/
template <typename T>
void XListType<T>::AddList(XListType* l)
void XListBase<T>::AddList(XListBase* l)
{
Add(l->items, l->count);
}
......@@ -231,15 +257,37 @@ insert an item to the given position of the list
>> item - the item for insertion
*/
template <typename T>
void XListType<T>::Insert(int pos, T* item)
void XListBase<T>::Insert(int pos, const T& item)
{
if (count == maxNum) {
T** newItems;
T* newItems;
if (mem == NULL)
newItems = new T*[maxNum * 2 + 1];
newItems = new T[maxNum * 2 + 1];
else
newItems = (T**)mem->Alloc(mem->devID, sizeof(T*) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T*) * maxNum);
newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL)
delete[] items;
items = newItems;
maxNum = maxNum * 2 + 1;
}
for (int i = count - 1; i >= pos; i--)
items[i + 1] = items[i];
items[pos] = item;
count++;
}
template<typename T>
void XListBase<T>::Insert(int pos, T&& item)
{
if (count == maxNum) {
T* newItems;
if (mem == NULL)
newItems = new T[maxNum * 2 + 1];
else
newItems = (T*)mem->Alloc(mem->devID, sizeof(T) * (maxNum * 2 + 1));
memcpy(newItems, items, sizeof(T) * maxNum);
if (mem == NULL)
delete[] items;
items = newItems;
......@@ -254,7 +302,7 @@ void XListType<T>::Insert(int pos, T* item)
/* get the item at position i */
template <typename T>
T* XListType<T>::GetItem(int i) const
T& XListBase<T>::GetItem(int i) const
{
CheckNTErrors(i >= -1 && i < count, "Index of a list item is out of scope!");
CheckNTErrors(count > 0, "Cannt index the item in an empty list!");
......@@ -266,12 +314,19 @@ T* XListType<T>::GetItem(int i) const
/* set the item at position i */
template <typename T>
void XListType<T>::SetItem(int i, T* item)
void XListBase<T>::SetItem(int i, const T& item)
{
if (i >= 0 && i < count)
items[i] = item;
}
template<typename T>
inline void XListBase<T>::SetItem(int i, T&& item)
{
if (i >= 0 && i < count)
items[i] = std::move(item);
}
/*
find the position of the first matched item
>> item - the item for matching
......@@ -279,7 +334,7 @@ find the position of the first matched item
*/
template <typename T>
int XListType<T>::FindFirst(T* item)
int XListBase<T>::FindFirst(const T& item)
{
for (int i = 0; i < count; i++) {
if (item == items[i])
......@@ -290,14 +345,9 @@ int XListType<T>::FindFirst(T* item)
/* clear the data array */
template <typename T>
void XListType<T>::Clear()
void XListBase<T>::Clear()
{
if (isIntList) {
for (int i = 0; i < count; i++) {
delete[](int*) items[i];
}
count = 0;
} else
delete[] items;
count = 0;
}
......@@ -307,26 +357,26 @@ sort the list
>> comp - the comparison function used in sorting
*/
template <typename T>
void XListType<T>::Sort(int itemSize, ListCompare comp)
void XListBase<T>::Sort(int itemSize, ListCompare comp)
{
qsort(items, count, itemSize, comp);
}
/* reverse the list */
template <typename T>
void XListType<T>::Reverse()
void XListBase<T>::Reverse()
{
int half = count / 2;
for (int i = 0; i < half; i++) {
T* tmp = items[i];
items[i] = items[count - i - 1];
items[count - i - 1] = tmp;
T tmp(std::move(items[i]));
items[i] = std::move(items[count - i - 1]);
items[count - i - 1] = std::move(tmp);
}
}
/* remove the item at position i */
template <typename T>
void XListType<T>::Remove(int i)
void XListBase<T>::Remove(int i)
{
if (i >= count || i < 0)
return;
......@@ -342,9 +392,9 @@ copy the list
<< hard copy of the list
*/
template <typename T>
XListType<T>* XListType<T>::Copy(XMem* myMem)
XListBase<T>* XListBase<T>::Copy(XMem* myMem)
{
XListType<T>* newList = new XListType<T>(maxNum, myMem);
XListBase<T>* newList = new XListBase<T>(maxNum, myMem);
for (int i = 0; i < count; i++) {
newList->Add(GetItem(i));
}
......@@ -358,7 +408,7 @@ shuffle the list
>> len - how many items are used in shuffling
*/
template <typename T>
void XListType<T>::Shuffle(int nround, int beg, int len)
void XListBase<T>::Shuffle(int nround, int beg, int len)
{
if (beg < 0) {
beg = 0;
......@@ -375,7 +425,7 @@ void XListType<T>::Shuffle(int nround, int beg, int len)
for (int i = 0; i < len; i++) {
float a = (float)rand() / RAND_MAX;
size_t j = (unsigned int)(a * (i + 1));
T* t = items[beg + j];
T t = items[beg + j];
items[beg + j] = items[beg + i];
items[beg + i] = t;
}
......@@ -385,12 +435,12 @@ void XListType<T>::Shuffle(int nround, int beg, int len)
struct XTensor;
/* typedef for list */
typedef XListType<int> IntList;
typedef XListType<char> CharList;
typedef XListType<long> LongList;
typedef XListType<float> FloatList;
typedef XListType<short> ShortList;
typedef XListType<XTensor> XList;
typedef XListBase<int> IntList;
typedef XListBase<char*> CharList;
typedef XListBase<long> LongList;
typedef XListBase<float> FloatList;
typedef XListBase<short> ShortList;
typedef XListBase<XTensor*> XList;
} /* end of the nts (NiuTrans.Tensor) namespace */
......
......@@ -51,10 +51,10 @@ void _MatrixMul2DMultiTheading(XList * args)
XTensor * c = matrixArgs->GetItem(2);
DTYPE alpha = *(DTYPE*)(matrixArgs->GetItem(3));
DTYPE beta = *(DTYPE*)(matrixArgs->GetItem(4));
int x1 = *(indexArgs->GetItem(0));
int y1 = *(indexArgs->GetItem(1));
int x2 = *(indexArgs->GetItem(2));
int y2 = *(indexArgs->GetItem(3));
int x1 = indexArgs->GetItem(0);
int y1 = indexArgs->GetItem(1);
int x2 = indexArgs->GetItem(2);
int y2 = indexArgs->GetItem(3);
#ifdef FAST_MATRIX
int am = a->dimSize[1];
......
......@@ -217,15 +217,15 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
bool uniform = true;
int mergeNum = smalls->count;
XTensor* smallsItem0 = (XTensor*)(smalls->GetItem(0));
XTensor* smallsItem0 = smalls->GetItem(0);
int itemSize = smallsItem0->unitNum * smallsItem0->unitSize;
for (int i = 0; i < smalls->count; i++) {
XTensor* smallsItem = (XTensor*)smalls->GetItem(i);
XTensor* smallsItem = smalls->GetItem(i);
CheckNTErrors((big->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!");
if (i > 0) {
XTensor * preItem = (XTensor*)smalls->GetItem(i - 1);
XTensor * preItem = smalls->GetItem(i - 1);
if (smallsItem->unitNum * smallsItem->unitSize != (char*)smallsItem->data - (char*)preItem->data)
uniform = false;
}
......@@ -237,7 +237,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
int gridNum = 1;
int mergedNum = smalls->count;
XTensor * s0 = (XTensor*)smalls->GetItem(0);
XTensor * s0 = smalls->GetItem(0);
int whereToMergeRDI = s0->order - whereToMerge - 1;
for (int i = 0; i < s0->order; i++) {
if (i <= whereToMergeRDI)
......@@ -263,7 +263,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
for (int g = 0; g < gridNum; g++) {
char * tData = (char*)big->data + g * blockSize * blockNum * big->unitSize;
for (int k = 0; k < mergedNum; k++) {
XTensor * s = (XTensor*)smalls->GetItem(k);
XTensor * s = smalls->GetItem(k);
char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize;
XMemCopy2D(tData + k * tStep, tPtich, big->devID,
sData + k * sStep, sPitch, s->devID,
......@@ -295,7 +295,7 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
/* copy from source to tmp */
if (!uniform) {
for (int i = 0; i < mergeNum; i++) {
XTensor* smallsItem = (XTensor*)smalls->GetItem(i);
XTensor* smallsItem = smalls->GetItem(i);
XMemCopy((char*)(tensorTMP->data) + (itemSize * i), tensorTMP->devID, smallsItem->data, smallsItem->devID, itemSize);
}
}
......@@ -324,7 +324,7 @@ make a new tensor to keep the result and return it
*/
XTensor Merge(const XList &smalls, int whereToMerge)
{
XTensor * tensor = (XTensor*)smalls.GetItem(0);
XTensor * tensor = smalls.GetItem(0);
int order = tensor->order;
int * dimSize = new int[order];
for (int i = 0; i < tensor->order; i++) {
......
......@@ -81,10 +81,10 @@ void RunParallel2D(XPRunner * parallelRunner, void * job,
XList * blockArgs = new XList(argNum);
int * blockIndex = indexList + i * 4;
indexArgs->Add(blockIndex);
indexArgs->Add(blockIndex + 1);
indexArgs->Add(blockIndex + 2);
indexArgs->Add(blockIndex + 3);
indexArgs->Add(blockIndex[0]);
indexArgs->Add(blockIndex[1]);
indexArgs->Add(blockIndex[2]);
indexArgs->Add(blockIndex[3]);
for (int j = 0; j < argNum; j++)
blockArgs->Add(jobArgList->GetItem(j));
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论