Commit 7f9c0b47 by xiaotong

improve XMem for less memory footprint

parent c90e83fe
......@@ -101,7 +101,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
/* dropout */
if(isTraining && dropoutP > 0)
x = Dropout(x);
x = Dropout(x, dropoutP);
for(int i = 0; i < nlayer; i++){
XTensor att;
......@@ -119,7 +119,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att);
att = Dropout(att, dropoutP);
/* layer normalization */
x = attLayerNorms[i].Make(att);
......@@ -131,7 +131,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att);
att = Dropout(att, dropoutP);
/* residual connection */
res = Sum(att, x);
......@@ -145,7 +145,7 @@ XTensor AttEncoder::Make(XTensor &input, XTensor &mask, bool skipInputRes, bool
/* dropout */
if(isTraining && dropoutP > 0)
fnn = Dropout(fnn);
fnn = Dropout(fnn, dropoutP);
/* residual connection */
res = Sum(fnn, x);
......
......@@ -747,6 +747,64 @@ void * XMem::AllocStandard(int myDevID, MTYPE mySize, bool myIsRebuiltIndex)
CheckNTErrors(nodeNumUsed < nodeNum, "No enough index nodes for the memory pool!");
}
/*if(testxmemid == 30){
recordp = result;
}
if(curBlockID >= 25){
MHeader * head = blocks[25].head;
while(head != NULL){
fprintf(stderr, "head: %ld %ld\n", head->indexNode->pReal, head->indexNode->size);
head = head->next;
}
}
if(testxmemid == 32){
int nnn = 0;
}
if(recordp != NULL){
MTYPE size = mySize;
if(size <= minSizeIndex[0])
size = minSizeIndex[0];
MPieceNode * entry = NULL;
MPieceNode * node = NULL;
MPieceNode * hit = NULL;
MPieceNode * last = NULL;
entry = memIndex + indexEntryNum + FindIndexEntry(size);
last = entry;
node = entry->next;
while(node != NULL){
CheckNTErrors(node->pre == last, "Something is wrong!");
CheckNTErrors(last->next == node, "Something is wrong!");
CheckNTErrors(node->head.state == 2, "Something is wrong!");
last = node;
if(node->size == 0){
MPieceNode * next = node->next;
RemoveFreeIndexNode(node, entry);
node = next;
ShowNTErrors("Something is wrong!");
}
else{
CheckNTErrors(node->pReal != NULL, "Illegal pointer!");
if(node->pReal == recordp){
hit = node;
break;
}
node = node->next;
}
}
if(hit == NULL){
int nnn = 0;
}
}*/
return result;
}
......@@ -918,6 +976,8 @@ void XMem::ReleaseStandard(int myDevID, void * p, MTYPE size)
hit->head.state = 1;
RemoveAllocIndexNode(hit);
hit->size = (char*)hit->p + hit->head.size - (char*)GetPitchedAddress((char*)hit->p, MY_PITCH);
AddFreeIndexNode(hit);
blocks[hit->head.blockID].used -= hit->head.size;
......@@ -981,8 +1041,9 @@ void XMem::RebuildIndex()
/* make a new index node */
MPieceNode * newNode = memIndex2 + nodeNumUsed2++;
newNode->p = p;
newNode->size = (char*)p + head->size -
( head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);
newNode->size = node->size;
//newNode->size = (char*)p + head->size -
// ( head->state == 1 ? (char*)GetPitchedAddress((char*)p, MY_PITCH) : (char*)head->indexNode->pReal);
newNode->pre = NULL;
newNode->next = NULL;
......
......@@ -65,9 +65,10 @@ bool TestXMemCase1()
for (int i = 0; i < testNum * scalar; i++) {
testxmemid++;
//fprintf(stderr, "%d %d\n", testxmemid, ok);
int j = rand() % caseNum;
//fprintf(stderr, "%d %d %d\n", testxmemid, j, ok);
if (p[j] == NULL) {
p[j] = (int*)mem.AllocStandard(mem.devID, size[j] * sizeof(int));
for (int k = 0; k < size[j]; k++)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论