Commit 80ab70a2 authored Nov 19, 2018 by xuchen
Merge branch 'xuchen' into xiaotong-working
Parents: 411cff4c, b83a6798
Showing 10 changed files with 239 additions and 139 deletions.
source/network/XBackwardShape.cpp          +1    -1
source/network/XNet.cpp                    +4    -4
source/sample/transformer/T2TTrainer.cpp   +35   -40
source/sample/transformer/T2TTrainer.h     +1    -5
source/sample/transformer/Transformer.cpp  +10   -65
source/tensor/Main.cpp                     +165  -1
source/tensor/XTensor.cpp                  +1    -1
source/tensor/core/shape/Merge.cpp         +0    -1
source/tensor/core/shape/Split.cpp         +18   -0
source/tensor/test/TSplit.cpp              +4    -21
source/network/XBackwardShape.cpp

@@ -375,7 +375,7 @@ void XShapeGrad::GradSplitList(XTensor * node, bool isEfficient)
     XTensor * input = income.tails[0];

     CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for SPLIT!");
-    CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");
+    //CheckNTErrors(node->order == input->order + 1, "Wrong tensor orders!");

     node->visitMark = NODE_DOING;
 }
source/network/XNet.cpp

@@ -96,7 +96,7 @@ void XNet::Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss)
 backward propagation to obtain gradient wrt. the loss/error function
 >> root - root node (output) of the network
 >> gold - gold standard for the output
->> padding - specify a target value that is ignored and does not contribute to the loss computation
+>> padding - specify a target value that is ignored and does not contribute to the gradient computation
 >> loss - name of loss function
 */
 void XNet::Backward(XTensor &root, XTensor &gold, XTensor &padding, LOSS_FUNCTION_NAME loss)

@@ -135,9 +135,9 @@ void XNet::Backward(XTensor &root, LOSS_FUNCTION_NAME loss)
 /*
 backward propagation to obtain gradient wrt. the loss/error function
 with a number of root nodes
->> root - a list of root nodes (output) of the network
->> gold - a list of gold standard for the output
->> padding - specify a target value that is ignored
+>> roots - a list of root nodes (output) of the network
+>> golds - a list of gold standard for the output
+>> paddings - specify a target value that is ignored
 >> loss - name of loss function
 */
 void XNet::Backward(XList &roots, XList &golds, XList &paddings, LOSS_FUNCTION_NAME loss)
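The second hunk simply renames the documented parameters to match the list-based signature. For orientation, a minimal caller sketch of that overload; the variable names are hypothetical and CROSSENTROPY stands in for whichever LOSS_FUNCTION_NAME value the caller actually uses:

    XNet net;
    XList roots, golds, paddings;
    roots.Add(&output);          /* root node (output) of the network */
    golds.Add(&goldTensor);      /* gold standard for that output */
    paddings.Add(&paddingMask);  /* target positions to ignore */
    net.Backward(roots, golds, paddings, CROSSENTROPY);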
source/sample/transformer/T2TTrainer.cpp

@@ -125,9 +125,6 @@ void T2TTrainer::Init(int argc, char ** argv)
     adamBeta1T = 1.0F;
     adamBeta2T = 1.0F;
-
-    validStep = 0;
-    curEpoch = 0;
 }

 int tc = 0;

@@ -139,10 +136,8 @@ train the model
 >> modelFN - where we keep the model
 >> model - model to train
 */
-bool T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
+void T2TTrainer::Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model)
 {
-    curEpoch += 1;
-
     int step = 0;
     int wc = 0;
     int wordCount = 0;

@@ -154,7 +149,8 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
     int nCheckpoint = 0;
     int nSkipped = 0;
     int gradStep = 0;
-    //int validStep = 0;
+    int validStep = 0;
+    int epoch = 0;

     char * trainFN = new char[(int)strlen(fn) + 10];
     strcpy(trainFN, fn);

@@ -172,10 +168,10 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
     double startT = GetClockSec();

-    //for(epoch = 1; epoch <= nepoch; epoch++){
+    for(epoch = 1; epoch <= nepoch; epoch++){
 #ifndef WIN32
-    if(isShuffled)
-        Shuffle(fn, trainFN);
+        if(isShuffled)
+            Shuffle(fn, trainFN);
 #endif

         FILE * file = fopen(trainFN, "rb");

@@ -204,7 +200,6 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
         {
             CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");
-            //CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");

             /* output probabilities */
             XTensor output;

@@ -271,25 +266,27 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
                 break;
             }

-            if(step % 1 == 0) {
+            if(step % 100 == 0) {
                 double elapsed = GetClockSec() - startT;
                 XPRINT8(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f, sppl=%.3f",
-                        lr, elapsed, step, curEpoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
+                        lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount), exp(-prob/wc));
                 if(!doUpdate)
                     XPRINT(0, stderr, " (no update)");
                 XPRINT(0, stderr, "\n");
             }

-            XMem * mem = model->mem;
-            MTYPE used = 0;
-            MTYPE total = 0;
-            for(int i = 0; i < mem->blockNum; i++){
-                if(mem->blocks[i].mem != NULL){
-                    used += mem->blocks[i].used;
-                    total += mem->blocks[i].size;
-                }
-            }
-            fprintf(stderr, "%d %d %d %d mem: %lld %lld\n", paddingEnc.GetDim(0), paddingEnc.GetDim(1),
-                    paddingDec.GetDim(0), paddingDec.GetDim(1), used, total);
+            //XMem * mem = model->mem;
+            //MTYPE used = 0;
+            //MTYPE total = 0;
+            //for(int i = 0; i < mem->blockNum; i++){
+            //    if(mem->blocks[i].mem != NULL){
+            //        used += mem->blocks[i].used;
+            //        total += mem->blocks[i].size;
+            //    }
+            //}
+            //fprintf(stderr, "%d %d %d %d mem: %lld %lld\n", paddingEnc.GetDim(0), paddingEnc.GetDim(1),
+            //        paddingDec.GetDim(0), paddingDec.GetDim(1), used, total);

             if(nStepCheckpoint > 0 && ++nStepCheck >= nStepCheckpoint){
                 MakeCheckpoint(model, validFN, modelFN, "step", step);

@@ -299,22 +296,22 @@ bool T2TTrainer::Train(const char * fn, const char * validFN, const char * model
         }

         fclose(file);

-        if(isEnd)
-            return false;
-        return true;
-
-        //    if(useEpochCheckpoint)
-        //        MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
-        //}
-
-        //double elapsed = GetClockSec() - startT;
-        //
-        //epoch = MIN(epoch, nepoch);
-        //
-        //XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
-        //        lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
-        //XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
-        //        elapsed, step, nSkipped, epoch);
-        break;
+        if(useEpochCheckpoint)
+            MakeCheckpoint(model, validFN, modelFN, "epoch", epoch);
+    }
+
+    double elapsed = GetClockSec() - startT;
+
+    epoch = MIN(epoch, nepoch);
+
+    XPRINT7(0, stderr, "[INFO] lr=%.2e, elapsed=%.1fs, step=%d, epoch=%d, word=%d, loss=%.3f, ppl=%.3f\n",
+            lr, elapsed, step, epoch, wordCountTotal, loss/wordCount, exp(loss/wordCount));
+    XPRINT4(0, stderr, "[INFO] training finished (took %.1fs, step=%d, skipped=%d and epoch=%d)\n",
+            elapsed, step, nSkipped, epoch);

     delete[] trainFN;
 }

@@ -368,8 +365,6 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
                       seqs, vSize, vSizeTgt, 1, 1, false, wc, devID, mem, false))
         {
-            //CheckNTErrors(batchEnc.order == 3, "wrong tensor order of the sequence batch");
-
             CheckNTErrors(batchEnc.order == 2, "wrong tensor order of the sequence batch");

             /* output probabilities */
source/sample/transformer/T2TTrainer.h

@@ -103,10 +103,6 @@ public:
     /* indicates whether we use adam */
     bool useAdam;

-    int validStep;
-    int curEpoch;
-
     /* hyper parameters of adam*/
     float adamBeta1;
     float adamBeta2;

@@ -157,7 +153,7 @@ public:
     void Init(int argc, char ** argv);

     /* train the model */
-    bool Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);
+    void Train(const char * fn, const char * validFN, const char * modelFN, T2TModel * model);

     /* test the model */
     void Test(const char * fn, const char * ofn, T2TModel * model);
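With the return type changed to void, a caller no longer drives an epoch loop on a bool status; the loop now lives inside Train itself (compare the Transformer.cpp hunk below). A minimal caller sketch matching the new interface, taken from the simplified TransformerMain:

    T2TTrainer trainer;
    trainer.Init(argc, args);

    T2TModel model;
    model.InitModel(argc, args);

    trainer.Train(trainFN, testFN, modelFN, &model);  /* runs all epochs internally */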
source/sample/transformer/Transformer.cpp

@@ -58,75 +58,20 @@ int TransformerMain(int argc, const char ** argv)
     LoadParamString(argc, args, "test", testFN, "");
     LoadParamString(argc, args, "output", outputFN, "");

+    T2TTrainer trainer;
+    trainer.Init(argc, args);
+
+    T2TModel model;
+    model.InitModel(argc, args);
+
     /* learn model parameters */
     if(strcmp(trainFN, ""))
-    {
-        double startT = GetClockSec();
-        T2TTrainer trainer;
-        trainer.Init(argc, args);
-        char * fn = new char[MAX_LINE_LENGTH];
-        char * fn1 = new char[MAX_LINE_LENGTH];
-        char * fn2 = new char[MAX_LINE_LENGTH];
-        //modelFN = strcmp(modelFN, "") ? modelFN : (char *)"checkpoint.model";
-        int epoch;
-        bool isTrain;
-        for(epoch = 1; epoch <= trainer.nepoch; epoch++)
-        {
-            sprintf(fn, "%s.%s.%03d", modelFN, "epoch", epoch - 1);
-            sprintf(fn1, "%s.%s.%03d", modelFN, "epoch", epoch);
-            sprintf(fn2, "%s.%s.%03d.output", modelFN, "epoch", epoch);
-            if(epoch == 1)
-            {
-                T2TModel model;
-                model.InitModel(argc, args);
-                isTrain = trainer.Train(trainFN, testFN, modelFN, &model);
-                //model.Dump(fn1);
-            }
-            else
-            {
-                T2TModel model;
-                model.InitModel(argc, args);
-                model.Read(fn);
-                isTrain = trainer.Train(trainFN, testFN, modelFN, &model);
-                //model.Dump(fn1);
-            }
-            if(trainer.useEpochCheckpoint && strcmp(testFN, ""))
-            {
-                T2TTrainer tester;
-                tester.Init(argc, args);
-                T2TModel model;
-                model.InitModel(argc, args);
-                //model.Read(fn1);
-                //tester.Test(testFN, fn2, &model);
-            }
-            if(!isTrain)
-                break;
-        }
-        double elapsed = GetClockSec() - startT;
-        epoch = MIN(epoch, trainer.nepoch);
-        XPRINT2(0, stderr, "[INFO] training finished (took %.1fs and epoch=%d)\n", elapsed, epoch);
-        delete[] fn;
-        delete[] fn1;
-        delete[] fn2;
-    }
+        trainer.Train(trainFN, testFN, modelFN, &model);

-    /* don't dump the final model */
-    //if(strcmp(modelFN, "") && strcmp(trainFN, ""))
-    //    model.Dump(modelFN);
-
-    T2TModel model;
-    model.InitModel(argc, args);
+    /* save the final model */
+    if(strcmp(modelFN, "") && strcmp(trainFN, ""))
+        model.Dump(modelFN);

     /* load the model if neccessary */
     if(strcmp(modelFN, ""))
         model.Read(modelFN);
source/tensor/Main.cpp

@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
 */

 /*
 *
 * This is the entrance of the low-level tensor library : NiuTrans.Tensor

@@ -39,9 +39,20 @@ using namespace nts;
 void SmallTest();
 void TransposeTest();
+void LittleTest();
+void T2TTest();
+void T2TTest2();
+void PowerTest();

 int main( int argc, const char ** argv )
 {
+    //PowerTest();
+    //LittleTest();
+    //T2TTest();
+    //T2TTest2();
+
+    //return 0;
+    //_CrtSetBreakAlloc(123);

     /* a tiny test */

@@ -63,6 +74,34 @@ int main( int argc, const char ** argv )
     return 0;
 }

+void myRead(XTensor * tensor, const char * filename, const char * label)
+{
+    FILE * file = fopen(filename, "rb");
+    if(file == NULL)
+        printf("%s\n", filename);
+    tensor->Read(file, label);
+}
+
+void myDump(XTensor * tensor, const char * filename, const char * label)
+{
+    FILE * file = fopen(filename, "wb");
+    if(file == NULL)
+        printf("%s\n", filename);
+    tensor->Dump(file, label);
+}
+
+void PowerTest()
+{
+    XTensor input;
+    XTensor output;
+
+    InitTensor2D(&input, 256, 10000, X_FLOAT, 0);
+    InitTensor2D(&output, 256, 10000, X_FLOAT, 0);
+
+    myRead(&input, "1.txt", "");
+
+    _Power(&input, &output, 2);
+    output.Dump(stderr, "", 200);
+}

@@ -126,3 +165,128 @@ void TransposeTest()
     delete[] data;
 }

+void LittleTest()
+{
+    int a = 5000;
+    int b = 100000;
+    int c = a * b;
+    printf("%d\n", c);
+
+    exit(1);
+}
+
+void T2TTest()
+{
+    XTensor * input;
+    XTensor * weight;
+    XTensor * output;
+    XTensor * gold;
+    XTensor * dedy;
+    XTensor * dedx;
+    XTensor * dedxTmp;
+    XTensor * dedw;
+    XTensor * padding;
+
+    DTYPE loss;
+
+    int * dimSize = new int[2];
+    dimSize[0] = 256;
+    dimSize[1] = 10001;
+
+    int * dimSize2 = new int[3];
+    dimSize2[0] = 2;
+    dimSize2[1] = 31;
+    dimSize2[2] = 256;
+
+    int * dimSize3 = new int[3];
+    dimSize3[0] = 2;
+    dimSize3[1] = 31;
+    dimSize3[2] = 10001;
+
+    int * dimSize4 = new int[2];
+    dimSize4[0] = 2;
+    dimSize4[1] = 31;
+
+    input = NewTensor(3, dimSize2, X_FLOAT, 1.0F, 0);
+    weight = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
+    dedw = NewTensor(2, dimSize, X_FLOAT, 1.0F, 0);
+    gold = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    output = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    dedy = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    dedx = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    dedxTmp = NewTensor(3, dimSize3, X_FLOAT, 1.0F, 0);
+    padding = NewTensor(2, dimSize4, X_FLOAT, 1.0F, 0);
+
+    //weight = NewTensor(2, dimSize);
+    //dedw = NewTensor(2, dimSize);
+    //input = NewTensor(3, dimSize2);
+    //gold = NewTensor(3, dimSize3);
+    //output = NewTensor(3, dimSize3);
+    //dedy = NewTensor(3, dimSize3);
+    //dedx = NewTensor(3, dimSize3);
+    //dedxTmp = NewTensor(3, dimSize3);
+    //padding = NewTensor(2, dimSize4);
+
+    myRead(input, "x.txt", "x");
+    myRead(weight, "w.txt", "w");
+    myRead(gold, "gold.txt", "gold");
+    myRead(padding, "padding.txt", "padding");
+
+    XTensor inter;
+    inter = MMul(*input, *weight);
+    _Softmax(&inter, output, 2);
+
+    //_LogMe(output);
+    loss = _CrossEntropyFast(output, gold, REDUCE_MEAN, NULL, padding);
+    printf("loss: %f\n", loss);
+
+    _CrossEntropyBackward(dedy, output, gold, NULL);
+    //_CrossEntropyBackward(dedy, output, gold, NULL, padding);
+
+    myDump(dedy, "dedy.txt", "dedy");
+
+    _SoftmaxBackward(NULL, output, input, dedy, dedx, NULL, -1, NOLOSS);
+    _Sub(output, gold, dedxTmp);
+
+    myDump(dedx, "dedx.txt", "dedx");
+    dedx->Dump(stderr, "dedx", 200);
+    dedxTmp->Dump(stderr, "dedxTmp", 200);
+
+    input->Reshape(input->unitNum / input->GetDim(-1), input->GetDim(-1));
+    dedx->Reshape(dedx->unitNum / dedx->GetDim(-1), dedx->GetDim(-1));
+
+    _MatrixMulBatched(input, X_TRANS, dedx, X_NOTRANS, dedw);
+
+    myDump(dedw, "dedw.txt", "dedw");
+}
+
+void T2TTest2()
+{
+    int dimSize[3];
+    dimSize[0] = 161;
+    dimSize[1] = 47;
+    dimSize[2] = 10001;
+    XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, 0);
+    //XTensor * probs = NewTensor(3, dimSize, X_FLOAT, 1.0F, -1);
+
+    //myRead(probs, "probs.txt", " ");
+    _SetDataFixedFloat(probs, 1.0F);
+
+    probs->Reshape(1, probs->unitNum);
+
+    DTYPE sum = _ReduceSumAll(probs);
+    printf("%e\n", sum);
+
+    //XTensor tmp;
+    //tmp = IsNonZero(*probs);
+    //DTYPE nonZeroNum = ReduceSumAll(tmp);
+    //printf("%f\n", nonZeroNum);
+    //
+    //DTYPE gpu = ReduceSum(*probs, 1).Get2D(0, 0);
+    //printf("%e\n", gpu);
+}
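A note on what T2TTest appears to cross-check: for a softmax output y and gold distribution t under cross-entropy, the gradient wrt. the pre-softmax input is

    dL/dx_i = y_i - t_i

The function therefore computes dedx along the backward path (_CrossEntropyBackward followed by _SoftmaxBackward) and, separately, dedxTmp = output - gold via _Sub; dumping both side by side lets one confirm that the two agree, up to whatever normalization the REDUCE_MEAN option of _CrossEntropyFast introduces.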
source/tensor/XTensor.cpp

@@ -1121,7 +1121,7 @@ bool XTensor::Set3D(DTYPE value, int d0, int d1, int d2)
     CheckNTErrors(order == 3, "Cannot get a 2d cell for a tensor whose order is not 2!");
     CheckNTErrors(d0 >= 0 && d0 < dimSize[0], "dimension 0 is out of range!");
     CheckNTErrors(d1 >= 0 && d1 < dimSize[1], "dimension 1 is out of range!");
-    CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 1 is out of range!");
+    CheckNTErrors(d2 >= 0 && d2 < dimSize[2], "dimension 2 is out of range!");
     CheckNTErrors(dataType == DEFAULT_DTYPE, "The tensor is not in default type.");

     int dims[3] = {d0, d1, d2};
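For reference, a hedged sketch of the bounds these assertions enforce; InitTensor3D is assumed to exist alongside the InitTensor2D used in Main.cpp above, with devID -1 taken to mean CPU:

    XTensor t;
    InitTensor3D(&t, 2, 3, 4, X_FLOAT, -1);  /* assumed 3D initializer; dimSize = {2, 3, 4} */
    t.Set3D(1.0F, 1, 2, 3);                  /* in range: d0 < 2, d1 < 3, d2 < 4 */
    t.Set3D(1.0F, 1, 2, 4);                  /* d2 == dimSize[2] now trips "dimension 2 is out of range!" */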
source/tensor/core/shape/Merge.cpp

@@ -217,7 +217,6 @@ void _Merge(const XList * smalls, XTensor * big, int whereToMerge)
     XTensor * smallsItem0 = (XTensor*)(smalls->GetItem(0));
     int itemSize = smallsItem0->unitNum * smallsItem0->unitSize;
-
     for (int i = 0; i < smalls->count; i++) {
         XTensor * smallsItem = (XTensor*)smalls->GetItem(i);
         CheckNTErrors((big->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!");
source/tensor/core/shape/Split.cpp

@@ -342,6 +342,24 @@ split a big tensor into small tensors
 */
 void Split(const XTensor &big, XList &smalls, int whereToSplit, int splitNum)
 {
+    CheckNTErrors(big.GetDim(whereToSplit) % splitNum == 0, "Wrong splitNum!");
+
+    int order = big.order;
+    int * dimSize = new int[order];
+
+    for (int i = 0; i < big.order; i++) {
+        if (i != whereToSplit)
+            dimSize[i] = big.dimSize[i];
+        else
+            dimSize[i] = big.dimSize[whereToSplit] / splitNum;
+    }
+
+    float dr = (!big.isSparse) ? 1.0F : big.denseRatio;
+    for (int i = 0; i < splitNum; i++) {
+        XTensor * item = NewTensor(order, dimSize, big.dataType, dr, big.devID, big.mem);
+        smalls.Add(item);
+    }
+
+    delete[] dimSize;
+
     /* call _Split function */
     _Split(&big, &smalls, whereToSplit, splitNum);
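Unlike _Split, which writes into tensors the caller has already created, this wrapper allocates each output tensor itself and appends it to the list; that is why the TSplit.cpp test below stops pre-creating tUser1/tUser2 and reads the results back with tUserList.Get(...) instead. A minimal sketch, assuming a dense 4x6 input on the CPU (devID -1):

    XTensor big;
    InitTensor2D(&big, 4, 6, X_FLOAT, -1);  /* 4 x 6 input */

    XList smalls;
    Split(big, smalls, 1, 2);               /* two 4 x 3 tensors, allocated by Split */

    XTensor * first = (XTensor*)smalls.Get(0);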
source/tensor/test/TSplit.cpp

@@ -272,29 +272,23 @@ bool TestSplit3()
     XTensor * s = NewTensor(sOrder, sDimSize);
     XTensor * t1 = NewTensor(tOrder1, tDimSize1);
     XTensor * t2 = NewTensor(tOrder2, tDimSize2);
-    XTensor * tUser1 = NewTensor(tOrder1, tDimSize1);
-    XTensor * tUser2 = NewTensor(tOrder2, tDimSize2);

     /* initialize variables */
     s->SetData(sData, sUnitNum);
     t1->SetZeroAll();
     t2->SetZeroAll();
-    tUser1->SetZeroAll();
-    tUser2->SetZeroAll();

     /* add tensors to list */
     tList->Add(t1);
     tList->Add(t2);
-    tUserList.Add(tUser1);
-    tUserList.Add(tUser2);

     /* call split function */
     _Split(s, tList, 1, 2);
     Split(*s, tUserList, 1, 2);

     /* check results */
-    cpuTest = t1->CheckData(answer1, tUnitNum1) && tUser1->CheckData(answer1, tUnitNum1) &&
-              t2->CheckData(answer2, tUnitNum2) && tUser2->CheckData(answer2, tUnitNum2);
+    cpuTest = t1->CheckData(answer1, tUnitNum1) && ((XTensor*)tUserList.Get(0))->CheckData(answer1, tUnitNum1) &&
+              t2->CheckData(answer2, tUnitNum2) && ((XTensor*)tUserList.Get(1))->CheckData(answer2, tUnitNum2);

 #ifdef USE_CUDA
     /* GPU test */

@@ -308,42 +302,31 @@ bool TestSplit3()
     XTensor * sGPU = NewTensor(sOrder, sDimSize, X_FLOAT, 1.0F, 0);
     XTensor * tGPU1 = NewTensor(tOrder1, tDimSize1, X_FLOAT, 1.0F, 0);
     XTensor * tGPU2 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);
-    XTensor * tUserGPU1 = NewTensor(tOrder1, tDimSize1, X_FLOAT, 1.0F, 0);
-    XTensor * tUserGPU2 = NewTensor(tOrder2, tDimSize2, X_FLOAT, 1.0F, 0);

     /* Initialize variables */
     sGPU->SetData(sData, sUnitNum);
     tGPU1->SetZeroAll();
     tGPU2->SetZeroAll();
-    tUserGPU1->SetZeroAll();
-    tUserGPU2->SetZeroAll();

     /* add tensors to list */
     tList->Add(tGPU1);
     tList->Add(tGPU2);
-    tUserList.Add(tUserGPU1);
-    tUserList.Add(tUserGPU2);

     /* call Split function */
     _Split(sGPU, tList, 1, 2);
     Split(*sGPU, tUserList, 1, 2);

     /* check results */
-    gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && tUserGPU1->CheckData(answer1, tUnitNum1) &&
-              tGPU2->CheckData(answer2, tUnitNum2) && tUserGPU2->CheckData(answer2, tUnitNum2);
+    gpuTest = tGPU1->CheckData(answer1, tUnitNum1) && ((XTensor*)tUserList.Get(0))->CheckData(answer1, tUnitNum1) &&
+              tGPU2->CheckData(answer2, tUnitNum2) && ((XTensor*)tUserList.Get(1))->CheckData(answer2, tUnitNum2);

     /* destroy variables */
     delete s;
     delete t1;
     delete t2;
-    delete tUser1;
-    delete tUser2;
     delete sGPU;
     delete tGPU1;
     delete tGPU2;
-    delete tUserGPU1;
-    delete tUserGPU2;
     delete[] sDimSize;
     delete[] tDimSize1;
     delete[] tDimSize2;