Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
Emmay
NiuTrans.Tensor
Commits
7a7dc4c6
Commit
7a7dc4c6
authored
Jul 08, 2018
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
new tensor connections
parent
90c12836
隐藏空白字符变更
内嵌
并排
正在显示
13 个修改的文件
包含
153 行增加
和
85 行删除
+153
-85
source/XLink.cpp
+13
-11
source/XLink.h
+9
-7
source/XName.cpp
+46
-0
source/XName.h
+8
-18
source/core/arithmetic/MatrixMul2DParallel.cpp
+3
-3
source/core/arithmetic/Sum.cpp
+39
-5
source/core/arithmetic/Sum.cu
+3
-4
source/core/arithmetic/Sum.cuh
+3
-4
source/core/arithmetic/Sum.h
+25
-21
source/core/movement/CopyValues.cpp
+0
-3
source/core/shape/ConcatenateSolely.cpp
+1
-6
source/core/utilities/XMatrixSegment.cpp
+1
-1
source/test/TSum.cpp
+2
-2
没有找到文件。
source/XLink.cpp
查看文件 @
7a7dc4c6
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include <stdio.h>
#include <stdio.h>
#include "XLink.h"
#include "XLink.h"
#include "XName.h"
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
...
@@ -35,6 +36,7 @@ XLink::XLink()
...
@@ -35,6 +36,7 @@ XLink::XLink()
tailNum
=
0
;
tailNum
=
0
;
paramNum
=
0
;
paramNum
=
0
;
type
[
0
]
=
0
;
type
[
0
]
=
0
;
typeID
=
0
;
}
}
/* deconstructor */
/* deconstructor */
...
@@ -59,14 +61,14 @@ void XLink::Reset()
...
@@ -59,14 +61,14 @@ void XLink::Reset()
/*
/*
set edge type name
set edge type name
>>
typeName - type name in string
>>
id - id of the type
*/
*/
void
XLink
::
SetType
(
const
char
*
typeName
)
void
XLink
::
SetType
(
int
id
)
{
{
type
[
0
]
=
0
;
type
[
0
]
=
0
;
if
(
typeName
==
NULL
)
strcpy
(
type
,
GetOPName
(
id
));
return
;
typeID
=
id
;
strcpy
(
type
,
typeName
);
CheckNTErrors
(
!
strcmp
(
type
,
"NULL"
),
"illegal edge type name!"
);
}
}
/*
/*
...
@@ -141,9 +143,9 @@ create a hyperedge with two input tensors and a output tensor
...
@@ -141,9 +143,9 @@ create a hyperedge with two input tensors and a output tensor
>> t1 - a tail tensor
>> t1 - a tail tensor
>> t2 - another tail tensor
>> t2 - another tail tensor
>> h - head tensor
>> h - head tensor
>>
typeName - name of
edge type
>>
id - id of the
edge type
*/
*/
void
XLink
::
MakeLink
(
XTensor
*
t1
,
XTensor
*
t2
,
XTensor
*
h
,
const
char
*
typeName
)
void
XLink
::
MakeLink
(
XTensor
*
t1
,
XTensor
*
t2
,
XTensor
*
h
,
int
id
)
{
{
if
(
h
!=
NULL
)
if
(
h
!=
NULL
)
return
;
return
;
...
@@ -159,7 +161,7 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN
...
@@ -159,7 +161,7 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN
else
{
else
{
ShowNTErrors
(
"TODO!"
);
ShowNTErrors
(
"TODO!"
);
}
}
income
.
SetType
(
typeName
);
income
.
SetType
(
id
);
/* backward for t1 */
/* backward for t1 */
if
(
t1
!=
NULL
){
if
(
t1
!=
NULL
){
...
@@ -180,15 +182,15 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN
...
@@ -180,15 +182,15 @@ void XLink::MakeLink(XTensor * t1, XTensor * t2, XTensor * h, const char * typeN
create a hyper edge with a list of tensors and a output tensor
create a hyper edge with a list of tensors and a output tensor
>> list - a list of input tensors
>> list - a list of input tensors
>> h - head tensor
>> h - head tensor
>>
typeName - name of
edge type
>>
id - id of the
edge type
*/
*/
void
XLink
::
MakeLink
(
XList
*
list
,
XTensor
*
h
,
const
char
*
typeName
)
void
XLink
::
MakeLink
(
XList
*
list
,
XTensor
*
h
,
int
id
)
{
{
/* forward */
/* forward */
XLink
&
income
=
h
->
income
;
XLink
&
income
=
h
->
income
;
income
.
Reset
();
income
.
Reset
();
income
.
SetHead
(
h
);
income
.
SetHead
(
h
);
income
.
SetType
(
typeName
);
income
.
SetType
(
id
);
for
(
int
i
=
0
;
i
<
list
->
count
;
i
++
){
for
(
int
i
=
0
;
i
<
list
->
count
;
i
++
){
XTensor
*
t
=
(
XTensor
*
)
list
->
GetItem
(
i
);
XTensor
*
t
=
(
XTensor
*
)
list
->
GetItem
(
i
);
...
...
source/XLink.h
查看文件 @
7a7dc4c6
...
@@ -74,6 +74,9 @@ struct XLink
...
@@ -74,6 +74,9 @@ struct XLink
/* name of the hyperedge type. e.g., sum, mul ... */
/* name of the hyperedge type. e.g., sum, mul ... */
char
type
[
MAX_OP_NAME_LENGTH
];
char
type
[
MAX_OP_NAME_LENGTH
];
/* type id */
int
typeID
;
/* constuctor */
/* constuctor */
XLink
();
XLink
();
...
@@ -83,8 +86,8 @@ struct XLink
...
@@ -83,8 +86,8 @@ struct XLink
/* reset it */
/* reset it */
void
Reset
();
void
Reset
();
/* set edge type name */
/* set edge type
id and
name */
void
SetType
(
const
char
*
typeName
);
void
SetType
(
int
id
);
/* set head */
/* set head */
void
SetHead
(
XTensor
*
h
);
void
SetHead
(
XTensor
*
h
);
...
@@ -103,11 +106,11 @@ struct XLink
...
@@ -103,11 +106,11 @@ struct XLink
/* create a hyper edge with two input tensors and a output tensor */
/* create a hyper edge with two input tensors and a output tensor */
static
static
void
MakeLink
(
XTensor
*
t1
,
XTensor
*
t2
,
XTensor
*
h
,
const
char
*
typeName
);
void
MakeLink
(
XTensor
*
t1
,
XTensor
*
t2
,
XTensor
*
h
,
int
id
);
/* create a hyper edge with a list of tensors and a output tensor */
/* create a hyper edge with a list of
input
tensors and a output tensor */
static
static
void
MakeLink
(
XList
*
list
,
XTensor
*
h
,
const
char
*
typeName
);
void
MakeLink
(
XList
*
list
,
XTensor
*
h
,
int
id
);
/* add a parameter */
/* add a parameter */
static
static
...
@@ -120,4 +123,4 @@ struct XLink
...
@@ -120,4 +123,4 @@ struct XLink
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
#endif // __XLINK_H__
#endif // __XLINK_H__
\ No newline at end of file
source/XName.cpp
0 → 100644
查看文件 @
7a7dc4c6
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05
*/
#ifndef __XNAME_H__
#define __XNAME_H__
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
#define MATH_ARITHMETIC 0x00001000
#define MATH_SUM MATH_ARITHMETIC + 1
#define MATH_MULTIPLY MATH_SUM + 1
/* get operator name */
const
char
*
GetOPName
(
int
type
)
{
if
((
type
&
MATH_ARITHMETIC
)
!=
0
){
if
(
type
==
MATH_SUM
)
return
"M_SUM"
;
else
if
(
type
==
MATH_MULTIPLY
)
return
"M_MULTIPLY"
;
}
return
"NULL"
;
}
}
// namespace nts(NiuTrans.Tensor)
#endif // __XNAME_H__
source/XName.h
查看文件 @
7a7dc4c6
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
* We define various names here
* We define various names here
*
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05
* It was really HOT these days. I can't imagine
what a hot day
here in Shenyang!
* It was really HOT these days. I can't imagine
it is SO hot
here in Shenyang!
*/
*/
#ifndef __XNAME_H__
#ifndef __XNAME_H__
...
@@ -28,22 +28,13 @@
...
@@ -28,22 +28,13 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
#define MATH_MATMUL "M_MATMUL"
#define MATH_ARITHMETIC 10000
#define MATH_CONCATENATESOLY "M_CONCATENATESOLY"
#define MATH_SUM MATH_ARITHMETIC + 1
#define MATH_COPYVALUES "M_COPYVALUES"
#define MATH_MULTIPLY MATH_SUM + 1
#define MATH_MATRIXMUL "M_MATRIXMUL"
#define MATH_MATRIXMUL2D "M_MATRIXMUL2D"
/* get operator name */
#define MATH_MATRIXMULBATCHED "M_MATRIXMULBATCHED"
const
char
*
GetOPName
(
int
type
);
#define MATH_MERGE "M_MERGE"
#define MATH_MULTIPLY "M_MULTIPLY"
#define MATH_REDUCEMAX "M_REDUCEMAX"
#define MATH_REDUCESUM "M_REDUCESUM"
#define MATH_SELECTRANGE "M_SELECTRANGE"
#define MATH_SORT "M_SORT"
#define MATH_SUM "M_SUM"
#define MATH_TOPK "M_TOPK"
#define MATH_UNSQUEEZE "M_UNSQUEEZE"
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
#endif // __XNAME_H__
#endif // __XNAME_H__
\ No newline at end of file
source/core/arithmetic/MatrixMul2DParallel.cpp
查看文件 @
7a7dc4c6
...
@@ -40,9 +40,9 @@ where trans() return the transposed matrix if the flag is fired
...
@@ -40,9 +40,9 @@ where trans() return the transposed matrix if the flag is fired
>> parallelRunner - parallel processing module
>> parallelRunner - parallel processing module
*/
*/
void
MatrixMul2DParallel
(
XTensor
*
a
,
MATRIX_TRANS_TYPE
transposedA
,
void
MatrixMul2DParallel
(
XTensor
*
a
,
MATRIX_TRANS_TYPE
transposedA
,
XTensor
*
b
,
MATRIX_TRANS_TYPE
transposedB
,
XTensor
*
b
,
MATRIX_TRANS_TYPE
transposedB
,
XTensor
*
c
,
DTYPE
alpha
,
DTYPE
beta
,
XTensor
*
c
,
DTYPE
alpha
,
DTYPE
beta
,
XPRunner
*
parallelRunner
)
XPRunner
*
parallelRunner
)
{
{
CheckNTErrors
((
a
&&
b
&&
c
),
"Empty input tensors!"
);
CheckNTErrors
((
a
&&
b
&&
c
),
"Empty input tensors!"
);
CheckNTErrors
((
a
->
order
==
2
&&
b
->
order
==
2
&&
c
->
order
==
2
),
CheckNTErrors
((
a
->
order
==
2
&&
b
->
order
==
2
&&
c
->
order
==
2
),
...
...
source/core/arithmetic/Sum.cpp
查看文件 @
7a7dc4c6
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include "../../XTensor.h"
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "Sum.h"
#include "Sum.h"
#include "Sum.cuh"
#include "Sum.cuh"
...
@@ -28,12 +29,13 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
...
@@ -28,12 +29,13 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
/*
tensor summation c = a + b * \beta
tensor summation c = a + b * \beta
return a pointer
>> a - a tensor
>> a - a tensor
>> b - another tensor
>> b - another tensor
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> beta - the scaling factor
>> beta - the scaling factor
*/
*/
void
Sum
(
XTensor
*
a
,
XTensor
*
b
,
XTensor
*
c
,
DTYPE
beta
)
void
_
Sum
(
XTensor
*
a
,
XTensor
*
b
,
XTensor
*
c
,
DTYPE
beta
)
{
{
if
(
c
==
NULL
)
if
(
c
==
NULL
)
c
=
a
;
c
=
a
;
...
@@ -59,17 +61,16 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
...
@@ -59,17 +61,16 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
ShowNTErrors
(
"Cannot run this method on multiple devices simultaneously!"
);
ShowNTErrors
(
"Cannot run this method on multiple devices simultaneously!"
);
}
}
else
else
CudaSum
(
a
,
b
,
c
,
beta
);
_
CudaSum
(
a
,
b
,
c
,
beta
);
}
}
else
else
CudaSum
(
a
,
b
,
c
,
beta
);
_
CudaSum
(
a
,
b
,
c
,
beta
);
#endif
#endif
}
}
else
{
else
{
if
(
!
a
->
isSparse
&&
!
b
->
isSparse
)
{
if
(
!
a
->
isSparse
&&
!
b
->
isSparse
)
{
CheckNTErrors
(
!
c
->
isSparse
,
CheckNTErrors
(
!
c
->
isSparse
,
"Illegal use of sparse matrix in addition!"
);
"Illegal use of sparse matrix in addition!"
);
if
(
a
->
dataType
==
DEFAULT_DTYPE
&&
if
(
a
->
dataType
==
DEFAULT_DTYPE
&&
b
->
dataType
==
DEFAULT_DTYPE
&&
b
->
dataType
==
DEFAULT_DTYPE
&&
...
@@ -112,5 +113,38 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
...
@@ -112,5 +113,38 @@ void Sum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
}
}
}
}
}
}
/*
tensor summation a = a + b * \beta
do it on site
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
*/
void
_SumMe
(
XTensor
*
a
,
XTensor
*
b
,
DTYPE
beta
)
{
_Sum
(
a
,
b
,
a
,
beta
);
}
/*
tensor summation a = a + b * \beta
return a XTensor structure
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
*/
XTensor
Sum
(
XTensor
&
a
,
XTensor
&
b
,
DTYPE
beta
)
{
XTensor
c
(
&
a
);
/* computation */
_Sum
(
&
a
,
&
b
,
&
c
,
beta
);
/* tensor connections */
//XLink::MakeLink(&a, &b, &c, MATH_SUM);
//XLink::AddParamToHead(&c, beta);
return
c
;
}
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
source/core/arithmetic/Sum.cu
查看文件 @
7a7dc4c6
...
@@ -51,7 +51,7 @@ tensor summation c = a + b * \beta (cuda version)
...
@@ -51,7 +51,7 @@ tensor summation c = a + b * \beta (cuda version)
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> beta - the scaling factor
>> beta - the scaling factor
*/
*/
void CudaSum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
void
_
CudaSum(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
{
{
if (c == NULL)
if (c == NULL)
c = a;
c = a;
...
@@ -124,7 +124,7 @@ tensor summation c = a + b * \beta (cuda version) with an input handle
...
@@ -124,7 +124,7 @@ tensor summation c = a + b * \beta (cuda version) with an input handle
>> size - size of the array
>> size - size of the array
>> beta - the coefficient
>> beta - the coefficient
*/
*/
void CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
void
_
CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
{
{
if (size == 0)
if (size == 0)
return;
return;
...
@@ -160,4 +160,4 @@ void CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b,
...
@@ -160,4 +160,4 @@ void CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b,
#endif // USE_CUDA
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
source/core/arithmetic/Sum.cuh
查看文件 @
7a7dc4c6
...
@@ -34,14 +34,14 @@ void KernelADD(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.
...
@@ -34,14 +34,14 @@ void KernelADD(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.
/* tensor summation c = a + b * \beta (cuda version) */
/* tensor summation c = a + b * \beta (cuda version) */
extern "C"
extern "C"
void CudaSum(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
void
_
CudaSum(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta (cuda version) with an input handle */
/* tensor summation c = a + b * \beta (cuda version) with an input handle */
extern "C"
extern "C"
void CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.0);
void
_
CudaSumWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.0);
#endif // USE_CUDA
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
} // namespace nts(NiuTrans.Tensor)
#endif // __SUM_CUH__
#endif // __SUM_CUH__
\ No newline at end of file
source/core/arithmetic/Sum.h
查看文件 @
7a7dc4c6
/* NiuTrans.Tensor - an open-source tensor library
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017
, Natural Language Processing Lab, Northestern University.
* Copyright (C) 2018
, Natural Language Processing Lab, Northestern University.
* All rights reserved.
* All rights reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing, software
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License.
* limitations under the License.
*/
*/
/*
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
*/
#ifndef __SUM_H__
#ifndef __SUM_H__
#define __SUM_H__
#define __SUM_H__
...
@@ -27,9 +27,14 @@
...
@@ -27,9 +27,14 @@
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
namespace
nts
{
// namespace nts(NiuTrans.Tensor)
/* tensor summation c = a + b * \beta */
/* tensor summation c = a + b * \beta */
extern
"C"
void
_Sum
(
XTensor
*
a
,
XTensor
*
b
,
XTensor
*
c
=
NULL
,
DTYPE
beta
=
(
DTYPE
)
1
.
0
);
void
Sum
(
XTensor
*
a
,
XTensor
*
b
,
XTensor
*
c
=
NULL
,
DTYPE
beta
=
(
DTYPE
)
1
.
0
);
/* tensor summation a = a + b * \beta (return a pointer) */
void
_SumMe
(
XTensor
*
a
,
XTensor
*
b
,
DTYPE
beta
=
(
DTYPE
)
1
.
0
);
/* tensor summation c = a + b * \beta (return a structure) */
XTensor
Sum
(
XTensor
&
a
,
XTensor
&
b
,
DTYPE
beta
=
(
DTYPE
)
1
.
0
);
}
// namespace nts(NiuTrans.Tensor)
}
// namespace nts(NiuTrans.Tensor)
#endif // __SUM_H__
#endif // __SUM_H__
\ No newline at end of file
source/core/movement/CopyValues.cpp
查看文件 @
7a7dc4c6
...
@@ -42,9 +42,6 @@ bool CopyValues(XTensor * s, XTensor * t, XStream * stream)
...
@@ -42,9 +42,6 @@ bool CopyValues(XTensor * s, XTensor * t, XStream * stream)
CheckNTErrors
((
t
->
data
!=
NULL
),
"Cannot copy to an empty data array!"
);
CheckNTErrors
((
t
->
data
!=
NULL
),
"Cannot copy to an empty data array!"
);
CheckNTErrors
((
s
->
unitNum
==
t
->
unitNum
),
"Unmatched data item number!"
);
CheckNTErrors
((
s
->
unitNum
==
t
->
unitNum
),
"Unmatched data item number!"
);
/* make tensor connections */
XLink
::
MakeLink
(
s
,
NULL
,
t
,
MATH_COPYVALUES
);
if
((
s
->
dataType
==
X_FLOAT16
&&
t
->
dataType
==
X_FLOAT
)
||
if
((
s
->
dataType
==
X_FLOAT16
&&
t
->
dataType
==
X_FLOAT
)
||
(
s
->
dataType
==
X_FLOAT
&&
t
->
dataType
==
X_FLOAT16
))
{
(
s
->
dataType
==
X_FLOAT
&&
t
->
dataType
==
X_FLOAT16
))
{
CheckNTErrors
(((
s
->
devID
<
0
&&
t
->
devID
<
0
)
||
s
->
devID
==
t
->
devID
),
CheckNTErrors
(((
s
->
devID
<
0
&&
t
->
devID
<
0
)
||
s
->
devID
==
t
->
devID
),
...
...
source/core/shape/ConcatenateSolely.cpp
查看文件 @
7a7dc4c6
...
@@ -37,10 +37,6 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim)
...
@@ -37,10 +37,6 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim)
{
{
CheckNTErrors
((
big
->
order
>
dim
&&
dim
>=
0
),
"Illegal dimension to concatenate!"
);
CheckNTErrors
((
big
->
order
>
dim
&&
dim
>=
0
),
"Illegal dimension to concatenate!"
);
/* make tensor connections */
XLink
::
MakeLink
(
smalls
,
big
,
MATH_CONCATENATESOLY
);
XLink
::
AddParamToHeadInt
(
big
,
dim
);
int
catDimSize
=
0
;
int
catDimSize
=
0
;
int
dimRDI
=
big
->
order
-
dim
-
1
;
int
dimRDI
=
big
->
order
-
dim
-
1
;
...
@@ -102,4 +98,4 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim)
...
@@ -102,4 +98,4 @@ void ConcatenateSolely(XList * smalls, XTensor * big, int dim)
delete
sourceArrays
;
delete
sourceArrays
;
}
}
}
}
}
//
namespace
nts
(
NiuTrans
.
Tensor
)
}
// namespace nts(NiuTrans.Tensor)
\ No newline at end of file
source/core/utilities/XMatrixSegment.cpp
查看文件 @
7a7dc4c6
...
@@ -36,7 +36,7 @@ segment a 2d tensor (i.e., matrix) into blocks and run jobs in parallel
...
@@ -36,7 +36,7 @@ segment a 2d tensor (i.e., matrix) into blocks and run jobs in parallel
>> ... - arguments of the jobs
>> ... - arguments of the jobs
*/
*/
void
RunParallel2D
(
XPRunner
*
parallelRunner
,
void
*
job
,
void
RunParallel2D
(
XPRunner
*
parallelRunner
,
void
*
job
,
int
opNum
,
int
rowNum
,
int
colNum
,
int
argNum
,
...)
int
opNum
,
int
rowNum
,
int
colNum
,
int
argNum
,
...)
{
{
if
(
rowNum
==
0
||
colNum
==
0
)
if
(
rowNum
==
0
||
colNum
==
0
)
return
;
return
;
...
...
source/test/TSum.cpp
查看文件 @
7a7dc4c6
...
@@ -55,7 +55,7 @@ bool TestSum1()
...
@@ -55,7 +55,7 @@ bool TestSum1()
b
->
SetData
(
bData
,
unitNum
);
b
->
SetData
(
bData
,
unitNum
);
/* call sum function */
/* call sum function */
Sum
(
a
,
b
);
_
Sum
(
a
,
b
);
/* check results */
/* check results */
cpuTest
=
a
->
CheckData
(
answer
,
unitNum
);
cpuTest
=
a
->
CheckData
(
answer
,
unitNum
);
...
@@ -131,7 +131,7 @@ bool TestSum2()
...
@@ -131,7 +131,7 @@ bool TestSum2()
c
->
SetZeroAll
();
c
->
SetZeroAll
();
/* call Sum function */
/* call Sum function */
Sum
(
a
,
b
,
c
,
beta
);
_
Sum
(
a
,
b
,
c
,
beta
);
/* check results */
/* check results */
cpuTest
=
c
->
CheckData
(
answer
,
unitNum
);
cpuTest
=
c
->
CheckData
(
answer
,
unitNum
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论