Commit 394e8340 by xuchen

1. redefine the interfaces; 2. update the tests; 3. update the manual; 4. merged with xiao's changes

parent 9b11391e
......@@ -39,6 +39,10 @@ int main( int argc, const char ** argv )
fprintf(stderr, "neural networks in an easy way. \n\n");
fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
}
XNet net;
XTensor a;
net.Backward(a);
//_CrtDumpMemoryLeaks();
......
......@@ -23,4 +23,126 @@
namespace nts{
unsigned int netIDGlobal = 0;
MUTEX_HANDLE netMutex;
/* generate a network id */
unsigned int MakeNetID()
{
if(netIDGlobal == 0)
MUTEX_INIT(netMutex);
MUTEX_LOCK(netMutex);
netIDGlobal += 3;
unsigned int id = netIDGlobal;
MUTEX_UNLOCK(netMutex);
return id;
}
/* constructor */
XNet::XNet()
{
nodes.Clear();
}
/* de-constructor */
XNet::~XNet()
{
}
/* clear the network */
void XNet::Clear()
{
nodes.Clear();
gradNodes.Clear();
outputs.Clear();
inputs.Clear();
}
/*
backward propagation to obtain gradient wrt. the loss/error function
>> root - root node (output) of the network
>> gold - gold standard for the output
>> loss - name of loss function
*/
void XNet::Backward(XTensor &root, XTensor &gold, LOSS_FUNCTION_NAME loss)
{
XList roots(1);
roots.Add(&root);
XList golds(1);
golds.Add(&gold);
Backward(roots, golds, loss);
}
/*
backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes
>> roots - a list of root nodes (output) of the network
>> golds - a list of gold standards for the outputs
>> loss - name of loss function
*/
void XNet::Backward(XList &roots, XList &golds, LOSS_FUNCTION_NAME loss)
{
Traverse(roots);
}
/*
traverse the net and find the topological order by
depth-first search (Tarjan's algorithm)
>> root - root node (or output of the net)
*/
void XNet::Traverse(XTensor &root)
{
XList roots(1);
roots.Add(&root);
Traverse(roots);
}
/*
traverse the net and find the topological order by
depth-first search (Tarjan's algorithm)
>> roots - a list of roots (or output nodes)
*/
void XNet::Traverse(XList &roots)
{
id = MakeNetID();
nodes.Clear();
for (int i = 0; i < roots.count; i++)
TarjanVisit((XTensor*)roots.Get(i), nodes, id);
}
/*
depth-first search given a node (Tarjan's algorithm for topological ordering)
>> node - the node to visit (mark 0:unvisited, 1:visiting, 2:done)
>> orders - topological order of the nodes
>> code - code of the network
*/
void XNet::TarjanVisit(XTensor * node, XList &orders, const unsigned int code)
{
if(node == NULL)
return;
if(node->visitMark == code + 1){
ShowNTErrors("There is a circle in the network\n");
}
else if(node->visitMark <= code || node->visitMark >= code + 2){
node->visitMark = code + 1;
XLink &income = node->income;
for(int i = 0; i < income.tailNum; i++){
XTensor * child = income.tails[i];
if(child == NULL)
continue;
TarjanVisit(child, orders, code);
}
node->visitMark = code + 2;
orders.Add(node);
}
}
}
\ No newline at end of file
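
A minimal usage sketch of the new traversal interface, mirroring the call added to main.cpp above. The tensor setup helpers (InitTensor2D, SetDataRand) are assumed to behave as elsewhere in the library; Backward currently only performs the traversal.

/* sketch: build a tiny graph y = Sum(MatrixMul(x, w), b) and list it in topological order */
XTensor x, w, b;
InitTensor2D(&x, 4, 8);                          /* assumed 2d initializer */
InitTensor2D(&w, 8, 2);
InitTensor2D(&b, 4, 2);
x.SetDataRand(); w.SetDataRand(); b.SetDataRand();

XTensor y = Sum(MatrixMul(x, X_NOTRANS, w, X_NOTRANS), b);

XNet net;
net.Traverse(y);                                 /* fills net.nodes with the topological order */
for(int i = 0; i < net.nodes.count; i++){
    XTensor * node = (XTensor*)net.nodes.GetItem(i);
    fprintf(stderr, "node %d: %s\n", i, GetOPName(node->income.typeID));
}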
......@@ -30,9 +30,31 @@
namespace nts{
/* management of tensor net (or graph) */
class XNet
struct XNet
{
public:
/* id of the network */
unsigned int id;
/* tensor nodes of the network (in order) */
XList nodes;
/* tensor nodes to keep gradient for output (e.g., SGD)*/
XList gradNodes;
/* output nodes of the network */
XList outputs;
/* input nodes of the network */
XList inputs;
/* constructor */
XNet();
/* de-constructor */
~XNet();
/* clear the network */
void Clear();
/* backward propagation to obtain gradient wrt. the loss/error function */
void Backward(XTensor &root, XTensor &gold = NULLTensor, LOSS_FUNCTION_NAME loss = NOLOSS);
......@@ -40,8 +62,24 @@ public:
/* backward propagation to obtain gradient wrt. the loss/error function
with a number of root nodes */
void Backward(XList &roots, XList &golds = NULLList, LOSS_FUNCTION_NAME loss = NOLOSS);
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XTensor &root);
/* traverse the net and find the topological order by
depth-first search (Tarjan's algorithm) */
void Traverse(XList &roots);
/* depth-first search given a node (Tarjan's algorithm for topological ordering) */
void TarjanVisit(XTensor * node, XList &orders, const unsigned int code);
};
/* we make a unique id for every network */
extern unsigned int netIDGlobal;
extern MUTEX_HANDLE netMutex;
extern unsigned int MakeNetID();
}
#endif
\ No newline at end of file
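
Why MakeNetID advances netIDGlobal by 3: a traversal with id = code marks nodes with code + 1 (visiting) and code + 2 (done), so consecutive ids must be at least 3 apart for the marks of different traversals not to collide. As a worked example, the first traversal may get id 3 and write marks 4 and 5; the next traversal gets id 6, and a stale mark of 4 or 5 then satisfies visitMark <= code and the node is treated as unvisited again.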
......@@ -53,8 +53,8 @@ int main( int argc, const char ** argv )
if(argc > 1 && !strcmp(argv[1], "-test"))
Test();
else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
FNNLMMain(argc - 1, argv + 1);
//else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
// FNNLMMain(argc - 1, argv + 1);
else{
fprintf(stderr, "Thanks for using NiuTrans.Tensor! This is a library that eases the\n");
fprintf(stderr, "use of tensors. All you need is to ... \n\n");
......
......@@ -82,7 +82,7 @@ _XINLINE_ float Float16ToFloat(unsigned short h)
}
/*
data conversion
data type conversion
>> devID - device id
>> s - source data array
>> typeS - source data type
......@@ -92,7 +92,7 @@ data conversion
*/
void ConvertDataType(int devID, void * s, TENSOR_DATA_TYPE typeS, void * t, TENSOR_DATA_TYPE typeT, int size)
{
CheckNTErrors((devID < 0), "This code must be run on GPUs!");
CheckNTErrors((devID < 0), "This code must be run on CPUs!");
if(typeS == typeT)
return;
......
......@@ -37,6 +37,7 @@ XLink::XLink()
paramNum = 0;
type[0] = 0;
typeID = 0;
caculator = NULL;
}
/* deconstructor */
......@@ -59,6 +60,8 @@ void XLink::Reset()
tailNum = 0;
paramNum = 0;
type[0] = 0;
typeID = 0;
caculator = NULL;
}
/* clear it */
......@@ -68,6 +71,8 @@ void XLink::Clear()
tailNum = 0;
paramNum = 0;
type[0] = 0;
typeID = 0;
caculator = NULL;
}
/* reset tails */
......@@ -224,6 +229,7 @@ void XLink::AddParam(void * param, int size)
paramNum++;
delete[] (char*)ps;
}
/*
create a hyperedge with two input tensors and a output tensor
>> t1 - a tail tensor
......@@ -249,7 +255,7 @@ create a hyper edge with a list of tensors and a output tensor
>> h - head tensor
>> id - id of the edge type
*/
void XLink::MakeLink(XList * list, XTensor * h, int id)
void XLink::MakeLink(const XList * list, XTensor * h, int id)
{
/* forward */
XLink &income = h->income;
......@@ -302,6 +308,43 @@ void XLink::AddParamToHeadInt(XTensor * h, int param)
}
/*
add a MATRIX_TRANS_TYPE parameter
>> h - head
>> param - parameter we want to introduce
*/
void XLink::AddParamToHeadTrans(XTensor * h, MATRIX_TRANS_TYPE param)
{
if(h == NULL)
return;
h->income.AddParam(&param, sizeof(MATRIX_TRANS_TYPE));
}
/*
add a boolean parameter
>> h - head
>> param - parameter we want to introduce
*/
void XLink::AddParamToHeadBool(XTensor * h, bool param)
{
if(h == NULL)
return;
h->income.AddParam(&param, sizeof(bool));
}
/*
add a pointer parameter
>> h - head
>> param - parameter we want to introduce
*/
void XLink::AddParamToHeadPointer(XTensor * h, void * param)
{
if(h == NULL)
return;
h->income.AddParam(&param, sizeof(param));
}
/*
replace a node with another, i.e., we redirect the links to the new node
>> oldOne - the node to be replaced
>> newOne - the new node
......
......@@ -76,6 +76,9 @@ struct XLink
/* type id */
int typeID;
/* caculator (pointer to the class for computation) */
void * caculator;
/* constructor */
XLink();
......@@ -124,7 +127,7 @@ struct XLink
/* create a hyper edge with a list of input tensors and a output tensor */
static
void MakeLink(XList * list, XTensor * h, int id);
void MakeLink(const XList * list, XTensor * h, int id);
/* add a parameter */
static
......@@ -134,6 +137,18 @@ struct XLink
static
void AddParamToHeadInt(XTensor * h, int param);
/* add a MATRIX_TRANS_TYPE parameter */
static
void AddParamToHeadTrans(XTensor * h, MATRIX_TRANS_TYPE param);
/* add a boolean parameter */
static
void AddParamToHeadBool(XTensor * h, bool param);
/* add a pointer parameter */
static
void AddParamToHeadPointer(XTensor * h, void * param);
/* replace a node with another, i.e., we redirect the links to the new node */
static
void Replace(const XTensor * oldOne, XTensor * newOne);
......
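
The new AddParamToHead* helpers are what the upper-level operations in this commit use to record their extra arguments on the head tensor's income link, right after MakeLink sets the tails and the type id. A condensed sketch of the pattern (taken from the MatrixMul wrapper further below):

/* sketch: record an operation and its parameters on the output tensor c */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);     /* c.income.tails = {a, b}, typeID = MATH_MATRIXMUL */
XLink::AddParamToHeadTrans(&c, transposedA);     /* MATRIX_TRANS_TYPE parameter */
XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);                /* DTYPE parameter */
XLink::AddParamToHead(&c, beta);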
......@@ -206,7 +206,7 @@ void XList::Insert(int pos, void * item)
}
/* get the item at position i */
void * XList::GetItem(int i)
void * XList::GetItem(int i) const
{
if( i >= 0 && i < count )
return items[i];
......
......@@ -74,7 +74,7 @@ public:
void AddList(XList * l);
void AddInt(int i);
void Insert(int pos, void * item);
void * GetItem(int i);
void * GetItem(int i) const;
int GetItemInt(int i);
void SetItem(int i, void * item);
void SetItemInt(int i, int item);
......
......@@ -27,12 +27,56 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
const char * GetOPName(int type)
{
if((type & MATH_ARITHMETIC) != 0){
if(type == MATH_SUM)
return "M_SUM";
if(type == MATH_ABSOLUTE)
return "M_ABSOLUTE";
else if(type == MATH_MATRIXMUL)
return "M_MATRIXMUL";
else if(type == MATH_MATRIXMULBATCHED)
return "M_MATRIXMULBATCHED";
else if(type == MATH_MULTIPLY)
return "M_MULTIPLY";
else if(type == MATH_NEGATE)
return "M_NEGATE";
else if(type == MATH_SIGN)
return "M_SIGN";
else if(type == MATH_SUM)
return "M_SUM";
else if(type == MATH_LOG)
return "M_LOG";
else if(type == MATH_NORMALIZE)
return "M_NORMALIZE";
else if(type == MATH_POWER)
return "M_POWER";
else if(type == MATH_SCALEANDSHIFT)
return "M_SCALEANDSHIFT";
else if(type == GETANDSET_SELECT)
return "G_SELECT";
else if(type == MOVEMENT_COPYINDEXED)
return "M_COPYINDEXED";
else if(type == MOVEMENT_COPYVALUES)
return "M_COPYVALUES";
else if(type == REDUCE_REDUCEMAX)
return "R_REDUCEMAX";
else if(type == REDUCE_REDUCEMEAN)
return "R_REDUCEMEAN";
else if(type == REDUCE_REDUCESUM)
return "R_REDUCESUM";
else if(type == REDUCE_REDUCESUMSQUARED)
return "R_REDUCESUMSQUARED";
else if(type == REDUCE_REDUCEVARIANCE)
return "R_REDUCEVARIANCE";
else if(type == SHAPE_CONCATENATE)
return "S_CONCATENATE";
else if(type == SHAPE_MERGE)
return "S_MERGE";
else if(type == SHAPE_PERMUTE)
return "S_PERMUTE";
else if(type == SHAPE_SPLIT)
return "S_SPLIT";
else if(type == SHAPE_TRANSPOSE)
return "S_TRANSPOSE";
else if(type == SHAPE_UNSQUEEZE)
return "S_UNSQUEEZE";
}
return "NULL";
......
......@@ -29,9 +29,40 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
#define MATH_ARITHMETIC 0x00001000
#define MATH_SUM MATH_ARITHMETIC + 1
#define MATH_MULTIPLY MATH_SUM + 1
#define MATH_SCALEANDSHIFT MATH_MULTIPLY + 1
#define MATH_ABSOLUTE MATH_ARITHMETIC + 1
#define MATH_MATRIXMUL MATH_ABSOLUTE + 1
#define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
#define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1
#define MATH_NEGATE MATH_MULTIPLY + 1
#define MATH_SIGN MATH_NEGATE + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_LOG MATH_SUM + 1
#define MATH_NORMALIZE MATH_LOG + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
#define GETANDSET MATH_SCALEANDSHIFT + 1
#define GETANDSET_SELECT GETANDSET + 1
#define MOVEMENT GETANDSET_SELECT + 1
#define MOVEMENT_COPYINDEXED MOVEMENT + 1
#define MOVEMENT_COPYVALUES MOVEMENT_COPYINDEXED + 1
#define REDUCE MOVEMENT_COPYVALUES + 1
#define REDUCE_REDUCEMAX REDUCE + 1
#define REDUCE_REDUCEMEAN REDUCE_REDUCEMAX + 1
#define REDUCE_REDUCESUM REDUCE_REDUCEMEAN + 1
#define REDUCE_REDUCESUMSQUARED REDUCE_REDUCESUM + 1
#define REDUCE_REDUCEVARIANCE REDUCE_REDUCESUMSQUARED + 1
#define SHAPE REDUCE_REDUCEVARIANCE + 1
#define SHAPE_CONCATENATE SHAPE + 1
#define SHAPE_MERGE SHAPE_CONCATENATE + 1
#define SHAPE_PERMUTE SHAPE_MERGE + 1
#define SHAPE_SPLIT SHAPE_PERMUTE + 1
#define SHAPE_TRANSPOSE SHAPE_SPLIT + 1
#define SHAPE_UNSQUEEZE SHAPE_TRANSPOSE + 1
/* get operator name */
const char * GetOPName(int type);
......
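
Since each new constant is one larger than the previous one, the whole table forms a single consecutive id range starting at MATH_ARITHMETIC + 1, and GetOPName can resolve any of them by exact comparison. A small sketch of reading an operator name back from a linked tensor (a and b are assumed to be initialized 2d tensors):

XTensor c = MatrixMul(a, X_NOTRANS, b, X_NOTRANS);
fprintf(stderr, "%s\n", GetOPName(c.income.typeID));   /* prints "M_MATRIXMUL" */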
......@@ -173,7 +173,7 @@ XTensor::XTensor(const XTensor &reference)
devID = reference.devID;
mem = reference.mem;
InitTensor(this, &reference);
CopyValues(&reference, this);
_CopyValues(&reference, this);
}
if(reference.isTmp)
......@@ -237,6 +237,7 @@ void XTensor::Init()
memset(isAllValued, 0, sizeof(bool) * MAX_TENSOR_DIM_NUM);
isInit = false;
isTmp = false;
visitMark = 0;
}
/* delete data arrays */
......@@ -299,7 +300,7 @@ XTensor& XTensor::operator= (const XTensor& tensor)
}
Resize(tensor.order, tensor.dimSize, tensor.dataType, tensor.denseRatio);
CopyValues(&tensor, this);
_CopyValues(&tensor, this);
}
/* copy member variables */
......@@ -344,7 +345,7 @@ judge whether the two matrices are in the same type and size
>> b - another tensor to compare with
<< return - whether the two input tensors are identical
*/
bool XTensor::IsIdentical(XTensor * a, XTensor * b)
bool XTensor::IsIdentical(const XTensor * a, const XTensor * b)
{
if(a->order != b->order)
return false;
......@@ -426,7 +427,7 @@ void XTensor::Reshape(const int myOrder, const int * myDimSize)
}
/* get the number of items in the data array */
int XTensor::GetSize()
int XTensor::GetSize() const
{
if(isSparse)
return unitNumNonZero;
......@@ -742,7 +743,7 @@ get the pointer to a cell
>> size - size of index
<< return - pointer to the cell
*/
void * XTensor::GetCell(int index[], int size)
void * XTensor::GetCell(int index[], int size) const
{
CheckNTErrors((size == order), "Illegal index!");
......@@ -794,7 +795,7 @@ get the value of a cell in a 2d tensor in default type
>> mi - column index
<< return - value of cell(ni, mi) in float
*/
DTYPE XTensor::Get2D(int ni, int mi)
DTYPE XTensor::Get2D(int ni, int mi) const
{
CheckNTErrors((order == 2), "Cannot get a 2d cell for a tensor whose order is not 2!");
CheckNTErrors((ni >= 0 && ni < dimSize[0]), "dimension 0 is out of range!");
......@@ -1242,7 +1243,7 @@ binary search to find an element in a sparse tensor
it is the previous one if there is no hit
<< return - find it or not?
*/
bool XTensor::BinarySearch(int key, DTYPE &value, void * &position)
bool XTensor::BinarySearch(int key, DTYPE &value, void * &position) const
{
CheckNTErrors((isSparse), "A sparse tensor is required!");
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in the default type.");
......
......@@ -138,6 +138,9 @@ public:
/* indicates whether the tensor is created temporarily */
bool isTmp;
/* mark for traversing the graph */
unsigned int visitMark;
/*
the link used to form networks. Note that when we compute on tensors, we actually create a
......@@ -198,7 +201,7 @@ public:
/* judge whether the two matrices are in the same type and size */
static
bool IsIdentical(XTensor * a, XTensor * b);
bool IsIdentical(const XTensor * a, const XTensor * b);
/* judge whether the three matrices are in the same type and size */
static
......@@ -214,7 +217,7 @@ public:
void Reshape(const int order, const int * myDimSize);
/* get the number of items in the data array */
int GetSize();
int GetSize() const;
/* get size of the memory used */
int GetDataSizeInChar();
......@@ -250,13 +253,13 @@ public:
DTYPE Get(int index[], int size = -1);
/* get the pointer to a cell */
void * GetCell(int index[], int size = -1);
void * GetCell(int index[], int size = -1) const;
/* get the default type value of a cell in a 1d tensor */
DTYPE Get1D(int i);
/* get the default type value of a cell in a 2d tensor */
DTYPE Get2D(int ni, int mi);
DTYPE Get2D(int ni, int mi) const;
/* get the default type value of a cell in a 3d tensor */
DTYPE Get3D(int d0, int d1, int d2);
......@@ -311,7 +314,7 @@ public:
bool Resize(const XTensor * myTensor);
/* binary search to find an element in a sparse matrix*/
bool BinarySearch(int key, DTYPE &value, void * &position);
bool BinarySearch(int key, DTYPE &value, void * &position) const;
/* dump data to a file */
void Dump(FILE * file, const char * label = NULL, const int n = -1, const int verbose = 0);
......
......@@ -19,6 +19,7 @@
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include <math.h>
#include "../../XTensor.h"
#include "Absolute.h"
#include "Absolute.cuh"
......@@ -29,12 +30,12 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
set every entry to its absolute value
>> a - the tensor we are processing
*/
void Absolute(XTensor * a)
void _Absolute(XTensor * a)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
CudaAbsolute(a);
_CudaAbsolute(a);
return;
}
#endif
......
......@@ -58,7 +58,7 @@ set each entry to its with float16 data type value
>> a - the tensor
*/
extern "C"
void CudaAbsolute(XTensor * a)
void _CudaAbsolute(XTensor * a)
{
CheckNTErrors((a->isSparse == false), "TODO!");
......
......@@ -35,7 +35,7 @@ void KernelAbsolute(__half * d, int size);
/* set each entry to its absolute value */
extern "C"
void CudaAbsolute(XTensor * a);
void _CudaAbsolute(XTensor * a);
#endif // USE_CUDA
......
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its absolute value */
extern "C"
void Absolute(XTensor * a);
void _Absolute(XTensor * a);
} // namespace nts(NiuTrans.Tensor)
......
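
The leading underscore introduced here marks the low-level routines that modify a tensor in place, as opposed to the upper-level functions (e.g. Sum, Multiply, MatrixMul) that allocate a new result and register it in the graph. A minimal sketch, with the 2d initializer and setter assumed:

XTensor a;
InitTensor2D(&a, 2, 2);          /* assumed 2d initializer */
a.Set2D(-1.5F, 0, 0);            /* assumed element setter */
_Absolute(&a);                   /* every entry becomes its absolute value, in place */
_Negate(&a);                     /* every entry becomes its minus value, in place */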
......@@ -37,8 +37,8 @@ c_i = trans(a_i) * trans(b_i) * \alpha + c_i * \beta for each i in [0,count-1]
>> alpha - scalar
>> beta - scalar
*/
void MatrixMULBatchedCPU(XList * a, MATRIX_TRANS_TYPE transposedA,
XList * b, MATRIX_TRANS_TYPE transposedB,
void _MatrixMULBatchedCPU(const XList * a, MATRIX_TRANS_TYPE transposedA,
const XList * b, MATRIX_TRANS_TYPE transposedB,
XList * c, DTYPE alpha, DTYPE beta)
{
CheckNTErrors((a && b && c), "Empty input lists!");
......@@ -73,11 +73,11 @@ void MatrixMULBatchedCPU(XList * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors((ci->order == 2), "2d tensor (i.e., matrix) is required!");
#ifdef USE_BLAS
if (useBLAS)
MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta);
_MatrixMULCPU(ai, transposedA, bi, transposedB, ci, alpha, beta);
else
MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#else
MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta);
#endif
}
//}
......
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* matrix multiplication in batch mode (CPU code) */
extern "C"
void MatrixMULBatchedCPU(XList * a, MATRIX_TRANS_TYPE transposedA, XList * b, MATRIX_TRANS_TYPE transposedB, XList * c,
void _MatrixMULBatchedCPU(const XList * a, MATRIX_TRANS_TYPE transposedA, const XList * b, MATRIX_TRANS_TYPE transposedB, XList * c,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -30,34 +30,34 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
matrix multiplication. For the input tensors a and b, we perform matrix multiplication
on the first two dimentsions. E.g., let A be a tensor of size y * z * m and B be
a tensor of size x * y * n. For A * B, we go over each order-2 tensor of A (of size x * y)
and each order-2 tensor B (of size z * x), like this
c_{i,j} = trans(ai) * trans(bj) * alpha + c_{i,j} * beta
where trans() returns the transposed matrix if the flag is fired, ai is the i-th
element tensor of A, bj is the j-th element tensor of B, and c_{i,j} is the (i,j) element
tensor of the result C. C should be a tensor of z * x * n * m. Obviously C = A * B performs
normal matrix multiplication if A = y * z and B = x * y.
matrix multiplication
For the input tensors a and b, we perform matrix multiplication on the first two dimensions.
E.g., let A be a tensor of size y * z * m and B be a tensor of size x * y * n.
For A * B, we go over each order-2 tensor of A (of size x * y) and each order-2 tensor B (of size z * x),
like this c_{i,j} = trans(ai) * trans(bj) * alpha + c_{i,j} * beta
where trans() returns the transposed matrix if the flag is fired, ai is the i-th element tensor of A,
bj is the j-th element tensor of B, and c_{i,j} is the (i,j) element tensor of the result C.
C should be a tensor of z * x * n * m.
Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
>> transposedB - indicates whether the matrices in b are transposed
>> c - where we keep a*b
>> alpha - a coefficient
>> beta - another coefficient
>> parallelRunner - parallel processing module
*/
void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta,
XPRunner * parallelRunner)
void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
"Input tensors should have the same data type!");
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
"Input tensors must have a order > 2!");
"Input tensors must have a order >= 2!");
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
......@@ -132,7 +132,7 @@ void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * ai = (XTensor*)aList->GetItem(i);
XTensor * bi = (XTensor*)bList->GetItem(i);
XTensor * ci = (XTensor*)cList->GetItem(i);
MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta, parallelRunner);
_MatrixMul2D(ai, transposedA, bi, transposedB, ci, alpha, beta, parallelRunner);
}
}
else if (a->devID >= 0 && b->devID >= 0 && c->devID >= 0) {
......@@ -144,7 +144,7 @@ void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA,
ProtectCudaDev(a->devID, devIDBackup);
cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID);
CudaBLASMatrixMULList(handle,
_CudaBLASMatrixMULList(handle,
aList, transposedA,
bList, transposedB,
cList, aList->count,
......@@ -157,7 +157,7 @@ void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA,
}
else {
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
MatrixMULBatchedCPU(aList, transposedA,
_MatrixMULBatchedCPU(aList, transposedA,
bList, transposedB,
cList, alpha, beta);
}
......@@ -184,4 +184,74 @@ void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA,
delete bList;
delete cList;
}
/*
matrix multiplication (return a XTensor structure)
make a new tensor to keep the result and return it
For the input tensors a and b, we perform matrix multiplication on the first two dimensions.
E.g., let A be a tensor of size y * z * m and B be a tensor of size x * y * n.
For A * B, we go over each order-2 tensor of A (of size x * y) and each order-2 tensor B (of size z * x),
like this c_{i,j} = trans(ai) * trans(bj) * alpha + c_{i,j} * beta
where trans() returns the transposed matrix if the flag is fired, ai is the i-th element tensor of A,
bj is the j-th element tensor of B, and c_{i,j} is the (i,j) element tensor of the result C.
The result C should be a tensor of z * x * n * m.
Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
>> transposedB - indicates whether the matrices in b are transposed
>> alpha - a coefficient
>> beta - another coefficient
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication
*/
XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors((&a && &b), "Empty input tensors!");
CheckNTErrors((a.dataType == b.dataType), "Input tensors should have the same data type!");
CheckNTErrors((a.order >= 2 && b.order >= 2), "Input tensors must have a order >= 2!");
int an = transposedA == X_TRANS ? a.dimSizeRDI[0] : a.dimSizeRDI[1];
int am = transposedA == X_TRANS ? a.dimSizeRDI[1] : a.dimSizeRDI[0];
int bn = transposedB == X_TRANS ? b.dimSizeRDI[0] : b.dimSizeRDI[1];
int bm = transposedB == X_TRANS ? b.dimSizeRDI[1] : b.dimSizeRDI[0];
CheckNTErrors(am == bn, "Unmatched tensors in multiplication!");
int order = a.order + b.order - 2;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < a.order; i++)
dimSize[sub++] = a.dimSizeRDI[i];
for (int i = 2; i < b.order; i++)
dimSize[sub++] = b.dimSizeRDI[i];
dimSize[sub++] = an;
dimSize[sub++] = bm;
XTensor c = NewTensor(order, dimSize, a.dataType, a.denseRatio, a.devID, a.mem);
c.SetZeroAll();
c.SetTMP();
/* call _MatrixMul function */
_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, beta, parallelRunner);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMUL);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHead(&c, beta);
/* destroy variables */
delete[] dimSize;
return c;
}
} // namespace nts(NiuTrans.Tensor)
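
A concrete sketch of the shape rule for the plain 2d case: with the dimension bookkeeping above, an = 2, am = bn = 3, bm = 4, and the returned c has order 2 and size 2 * 4 (InitTensor2D and SetDataRand are assumed helpers):

XTensor a, b;
InitTensor2D(&a, 2, 3);          /* 2 x 3 */
InitTensor2D(&b, 3, 4);          /* 3 x 4 */
a.SetDataRand();
b.SetDataRand();
XTensor c = MatrixMul(a, X_NOTRANS, b, X_NOTRANS);   /* 2 x 4, alpha = 1.0, beta = 0 */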
......@@ -27,20 +27,36 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
matrix multiplication. For the input tensors a and b, we perform matrix multiplication
on the first two dimentsions. E.g., let A be a tensor of size y * z * m and B be
a tensor of size x * y * n. For A * B, we go over each order-2 tensor of A (of size x * y)
and each order-2 tensor B (of size z * x), like this
c_{i,j} = trans(ai) * trans(bj) * alpha + c_{i,j} * beta
where trans() returns the transposed matrix if the flag is fired, ai is the i-th
element tensor of A, bj is the j-th element tensor of B, and c_{i,j} is the (i,j) element
tensor of the result C. C should be a tensor of z * x * n * m. Obviously C = A * B performs
normal matrix multiplication if A = y * z and B = x * y.
matrix multiplication
For the input tensors a and b, we perform matrix multiplication on the first two dimensions.
E.g., let A be a tensor of size y * z * m and B be a tensor of size x * y * n.
For A * B, we go over each order-2 tensor of A (of size x * y) and each order-2 tensor B (of size z * x),
like this c_{i,j} = trans(ai) * trans(bj) * alpha + c_{i,j} * beta
where trans() returns the transposed matrix if the flag is fired, ai is the i-th element tensor of A,
bj is the j-th element tensor of B, and c_{i,j} is the (i,j) element tensor of the result C.
C should be a tensor of z * x * n * m.
Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
*/
extern "C"
void MatrixMul(XTensor * a, MATRIX_TRANS_TYPE transposedA, XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL);
/*
matrix multiplication (return a XTensor structure)
make a new tensor c to keep the result and return it
For the input tensors a and b, we perform matrix multiplication on the first two dimensions.
E.g., let A be a tensor of size y * z * m and B be a tensor of size x * y * n.
For A * B, we go over each order-2 tensor of A (of size x * y) and each order-2 tensor B (of size z * x),
like this c_{i,j} = trans(ai) * trans(bj) * alpha + c_{i,j} * beta
where trans() returns the transposed matrix if the flag is fired, ai is the i-th element tensor of A,
bj is the j-th element tensor of B, and c_{i,j} is the (i,j) element tensor of the result C.
C should be a tensor of z * x * n * m.
Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
*/
XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
#endif // __MATRIXMUL_H__
\ No newline at end of file
......@@ -30,8 +30,10 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
matrix multiplication (for 2d tensors)
c = trans(a) * trans(b) * alpha + c * beta
where trans() return the transposed matrix if the flag is fired
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
......@@ -42,8 +44,8 @@ where trans() return the transposed matrix if the flag is fired
>> parallelRunner - parallel processing module
>> stream - the string for creating the job pipeline
*/
void MatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * b, MATRIX_TRANS_TYPE transposedB,
void _MatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta,
XPRunner * parallelRunner, XStream * stream)
{
......@@ -67,7 +69,7 @@ void MatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA,
#ifdef USE_CUDA
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
CudaMatrixMul2D(a, transposedA, b, transposedB, c, alpha, beta, stream);
_CudaMatrixMul2D(a, transposedA, b, transposedB, c, alpha, beta, stream);
return;
}
#endif
......@@ -81,9 +83,9 @@ void MatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA,
c->dataType == DEFAULT_DTYPE)
{
if (useBLAS)
MatrixMULCPU(a, transposedA, b, transposedB, c, alpha, beta);
_MatrixMULCPU(a, transposedA, b, transposedB, c, alpha, beta);
else
MatrixMul2DParallel(a, transposedA, b, transposedB, c, alpha, beta, parallelRunner);
_MatrixMul2DParallel(a, transposedA, b, transposedB, c, alpha, beta, parallelRunner);
}
else {
// TODO!!
......
......@@ -108,8 +108,10 @@ void KernelMatrixMulDenseMSparseMV2(DTYPE * a, MATRIX_TRANS_TYPE transposedA, in
/*
matrix multiplication (for 2d tensors) (cuda version)
c = trans(a) * trans(b) * alpha + c * beta
where trans() return the transposed matrix if the flag is fired
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
......@@ -119,8 +121,8 @@ where trans() return the transposed matrix if the flag is fired
>> beta - another coefficient
>> stream - the string for creating the job pipeline
*/
void CudaMatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * b, MATRIX_TRANS_TYPE transposedB,
void _CudaMatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c,
DTYPE alpha, DTYPE beta, XStream * stream)
{
......@@ -156,7 +158,7 @@ void CudaMatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA,
cublasSetStream(*handle, stream->stream);
if (a->dataType == X_FLOAT && b->dataType == X_FLOAT && c->dataType == X_FLOAT) {
CudaBLASMatrixMUL(handle, a->data, transposedA, a->dataType, b->data, transposedB, a->dataType, c->data, c->dataType,
_CudaBLASMatrixMUL(handle, a->data, transposedA, a->dataType, b->data, transposedB, a->dataType, c->data, c->dataType,
a->dimSize[0], a->dimSize[1], b->dimSize[0], b->dimSize[1], c->dimSize[0], c->dimSize[1],
alpha, beta);
}
......
......@@ -43,7 +43,7 @@ c = trans(a) * trans(b) * alpha + c * beta
where trans() return the transposed matrix if the flag is fired
*/
extern "C"
void CudaMatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA, XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
void _CudaMatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XStream * stream = NULL);
#endif // USE_CUDA
......
......@@ -31,9 +31,8 @@ matrix multiplication (for 2d tensors)
c = trans(a) * trans(b) * alpha + c * beta
where trans() return the transposed matrix if the flag is fired
*/
extern "C"
void MatrixMul2D(XTensor * a, MATRIX_TRANS_TYPE transposedA, XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL, XStream * stream = NULL);
void _MatrixMul2D(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL, XStream * stream = NULL);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -38,7 +38,7 @@ argument5: matrix a
argument6: matrix b
argument7: matrix c (c=a*b*\alpha + c*beta)
*/
void MatrixMul2DMultiTheading(XList * args)
void _MatrixMul2DMultiTheading(XList * args)
{
int x1 = *(int*)args->GetItem(0);
int y1 = *(int*)args->GetItem(1);
......
......@@ -31,7 +31,7 @@ matrix multiplication for a block (x1,y1) - (x2,y2)
where (x1,y1) is the upper-left corner and (x2,y2) is the bottom-right corner
*/
extern "C"
void MatrixMul2DMultiTheading(XList * args);
void _MatrixMul2DMultiTheading(XList * args);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -30,6 +30,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
matrix multiplication (for 2d tensors) with multi-threading
c = trans(a) * trans(b) * alpha + c * beta
where trans() return the transposed matrix if the flag is fired
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
......@@ -39,10 +40,9 @@ where trans() return the transposed matrix if the flag is fired
>> beta - another coefficient
>> parallelRunner - parallel processing module
*/
void MatrixMul2DParallel(XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta,
XPRunner * parallelRunner)
void _MatrixMul2DParallel(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2),
......@@ -56,7 +56,7 @@ void MatrixMul2DParallel(XTensor * a, MATRIX_TRANS_TYPE transposedA,
/* a * b */
if (transposedA == X_NOTRANS && transposedB == X_NOTRANS) {
RunParallel2D(parallelRunner, (void*)MatrixMul2DMultiTheading, an * am * bm,
RunParallel2D(parallelRunner, (void*)_MatrixMul2DMultiTheading, an * am * bm,
cn, cm, 5,
a, b, c, &alpha, &beta);
}
......
......@@ -27,12 +27,12 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
matrix multiplication (for 2d tensors) with multi-threading
matrix multiplication (for 2d tensors) with multi-threading.
c = trans(a) * trans(b) * alpha + c * beta
where trans() return the transposed matrix if the flag is fired
where trans() return the transposed matrix if the flag is fired.
*/
extern "C"
void MatrixMul2DParallel(XTensor * a, MATRIX_TRANS_TYPE transposedA, XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
void _MatrixMul2DParallel(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -30,10 +30,12 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
matrix multiplication of the two tensors
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
ci = trans(ai) * trans(bi) * alpha + ci * beta
where trans() returns the transposed matrix if the flag is fired
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
......@@ -43,8 +45,8 @@ where trans() returns the transposed matrix if the flag is fired
>> beta - another coefficient
>> parallelRunner - parallel processing module
*/
void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * b, MATRIX_TRANS_TYPE transposedB,
void _MatrixMulBatched(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta,
XPRunner * parallelRunner)
{
......@@ -52,7 +54,9 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
"Input tensors should have the same data type!");
CheckNTErrors((a->order >= 2 && b->order >= 2 && c->order >= 2),
"Input tensors must have a order > 2!");
"Input tensors must have a order >= 2!");
CheckNTErrors((a->order == b->order && a->order == c->order),
"Input tensor and output tensor must have same order!");
int an = transposedA == X_TRANS ? a->dimSizeRDI[0] : a->dimSizeRDI[1];
int am = transposedA == X_TRANS ? a->dimSizeRDI[1] : a->dimSizeRDI[0];
......@@ -109,7 +113,7 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
ProtectCudaDev(a->devID, devIDBackup);
cublasHandle_t * handle = a->mem != NULL ? a->mem->GetCublasHandle() : GDevs.GetCudaHandle(a->devID);
CudaBLASMatrixMULList(handle,
_CudaBLASMatrixMULList(handle,
aList, transposedA,
bList, transposedB,
cList, aList->count,
......@@ -122,7 +126,7 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
}
else {
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
MatrixMULBatchedCPU(aList, transposedA,
_MatrixMULBatchedCPU(aList, transposedA,
bList, transposedB,
cList, alpha, beta);
}
......@@ -150,4 +154,65 @@ void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA,
delete cList;
}
/*
matrix multiplication of the two tensors (return a XTensor structure)
make a new tensor to keep the result and return it
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
ci = trans(ai) * trans(bi) * alpha + ci * beta
where trans() returns the transposed matrix if the flag is fired.
>> a - tensor a
>> transposedA - indicates whether the matrices in a are transposed
>> b - tensor b
>> transposedB - indicates whether the matrices in b are transposed
>> alpha - a coefficient
>> beta - another coefficient
>> parallelRunner - parallel processing module
<< return - the result of matrix multiplication of the two tensors
*/
XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
{
CheckNTErrors((&a && &b), "Empty input tensors!");
CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
CheckNTErrors((a.order >= 2 && b.order >= 2), "Input tensors must have a order >= 2!");
CheckNTErrors(a.order == b.order, "Input tensor and output tensor must have same order!");
int an = transposedA == X_TRANS ? a.dimSizeRDI[0] : a.dimSizeRDI[1];
int am = transposedA == X_TRANS ? a.dimSizeRDI[1] : a.dimSizeRDI[0];
int bn = transposedB == X_TRANS ? b.dimSizeRDI[0] : b.dimSizeRDI[1];
int bm = transposedB == X_TRANS ? b.dimSizeRDI[1] : b.dimSizeRDI[0];
CheckNTErrors(am == bn, "Unmatched tensors in multiplication!");
int order = a.order;
int sub = 0;
int * dimSize = new int[order];
for (int i = 2; i < a.order; i++)
dimSize[sub++] = a.dimSizeRDI[i];
dimSize[sub++] = an;
dimSize[sub++] = bm;
XTensor c = NewTensor(order, dimSize, a.dataType, a.denseRatio, a.devID, a.mem);
c.SetZeroAll();
c.SetTMP();
/* call _MatrixMulBatched function */
_MatrixMulBatched(&a, transposedA, &b, transposedB, &c, alpha, beta, parallelRunner);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_MATRIXMULBATCHED);
XLink::AddParamToHeadTrans(&c, transposedA);
XLink::AddParamToHeadTrans(&c, transposedB);
XLink::AddParamToHead(&c, alpha);
XLink::AddParamToHead(&c, beta);
/* destroy variables */
delete[] dimSize;
return c;
}
} // namespace nts(NiuTrans.Tensor)
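
Following the dimSize construction above, a worked shape example: with a of size 8 * 2 * 3 and b of size 8 * 3 * 4 (no transposition), an = 2, am = bn = 3, bm = 4, and the returned c has size 8 * 2 * 4, i.e. one 2 * 3 by 3 * 4 product per leading index. A sketch, with InitTensor3D and SetDataRand assumed:

XTensor a, b;
InitTensor3D(&a, 8, 2, 3);       /* assumed 3d initializer */
InitTensor3D(&b, 8, 3, 4);
a.SetDataRand();
b.SetDataRand();
XTensor c = MatrixMulBatched(a, X_NOTRANS, b, X_NOTRANS);   /* size 8 * 2 * 4 */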
......@@ -28,13 +28,25 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
matrix multiplication of the two tensors
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
ci = trans(ai) * trans(bi) * alpha + ci * beta
where trans() returns the transposed matrix if the flag is fired
*/
void _MatrixMulBatched(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL);
/*
matrix multiplication of the two tensors (return a XTensor structure)
make a new tensor to keep the result and return it
for each 2-dimensional data array in a (denoted as ai) and
each 2-dimensional data array in b (denoted as bi), we have
ci = trans(ai) * trans(bi) * alpha + ci * beta
where trans() returns the transposed matrix if the flag is fired
*/
extern "C"
void MatrixMulBatched(XTensor * a, MATRIX_TRANS_TYPE transposedA, XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
XTensor MatrixMulBatched(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -28,14 +28,15 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
element-wise product of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the item
>> a - matrix a
>> b - matrix b
>> c - result matrix
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
>>
*/
void _Multiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
......@@ -121,9 +122,12 @@ void _Multiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, i
}
/*
element-wise product of two tensors and keep the result in the input
element-wise product of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)*b(i) + \alpha * a(i)
where i is the index of the item
>> a - tensor a (where keep the result)
>> b - tensor b
>> alpha - the coefficient
......@@ -135,9 +139,12 @@ void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
}
/*
make a tensor of the element-wise product for two input tensors:
element-wise product of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> alpha - the coefficient
......@@ -151,7 +158,7 @@ XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha, int leadingDim
XTensor c(&a);
c.SetTMP();
/* computation */
/* call _Multiply function */
_Multiply(&a, &b, &c, alpha, leadingDim);
/* tensor connections */
......
......@@ -26,19 +26,27 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* element-wise product of two tensors:
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the element */
/*
element-wise product of two tensors:
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the element
*/
void _Multiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
/* element-wise product of two tensors and keep the result in the input tensor:
a(i) = a(i)*b(i) + \alpha * a(i)
where i is the index of the element */
/*
element-wise product of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)*b(i) + \alpha * a(i)
where i is the index of the element
*/
void _MultiplyMe(XTensor * a, const XTensor * b, DTYPE alpha = 0, int leadingDim = 0);
/* make a tensor of the element-wise product for two input tensors:
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the element */
/*
element-wise product of two tensors (return a XTensor structure)
make a new tensor to keep the result and return it
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the element
*/
XTensor Multiply(const XTensor &a, const XTensor &b, DTYPE alpha = 0, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor)
......
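
The three flavors now read consistently across the arithmetic functions: the underscored routine writes into an existing c, the Me variant overwrites a, and the plain function returns a fresh, graph-linked tensor. A short sketch (2d initializer assumed):

XTensor a, b, c;
InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 2);
InitTensor2D(&c, 2, 2);
a.SetDataRand(); b.SetDataRand(); c.SetZeroAll();

_Multiply(&a, &b, &c);           /* c(i) = a(i) * b(i), alpha defaults to 0 */
_MultiplyMe(&a, &b);             /* a(i) = a(i) * b(i), result kept in a */
XTensor d = Multiply(a, b);      /* new tensor, linked with MATH_MULTIPLY */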
......@@ -29,12 +29,12 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
set every entry to its minus value
>> a - the tensor we are processing
*/
void Negate(XTensor * a)
void _Negate(XTensor * a)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
CudaNegate(a);
_CudaNegate(a);
return;
}
#endif
......
......@@ -66,7 +66,7 @@ set each entry to its negtive value
>> a - the tensor
*/
extern "C"
void CudaNegate(XTensor * a)
void _CudaNegate(XTensor * a)
{
CheckNTErrors((a->isSparse == false), "TODO!");
......
......@@ -19,6 +19,9 @@
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __NEGATE_CUH__
#define __NEGATE_CUH__
#include "Negate.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -35,8 +38,10 @@ void KernelNegate(__half * d, int size);
/* set each entry to its negative value */
extern "C"
void CudaNegate(XTensor * a);
void _CudaNegate(XTensor * a);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
} // namespace nts(NiuTrans.Tensor)
#endif // __NEGATE_CUH__
\ No newline at end of file
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its minus value */
extern "C"
void Negate(XTensor * a);
void _Negate(XTensor * a);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -29,12 +29,12 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
set every entry to its sign value
>> a - the tensor we are processing
*/
void Sign(XTensor * a)
void _Sign(XTensor * a)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
CudaSign(a);
_CudaSign(a);
return;
}
#endif
......
......@@ -64,7 +64,7 @@ set each entry to its with float16 data type value
>> a - the tensor
*/
extern "C"
void CudaSign(XTensor * a)
void _CudaSign(XTensor * a)
{
CheckNTErrors((a->isSparse == false), "TODO!");
......
......@@ -19,6 +19,9 @@
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#ifndef __SIGN_CUH__
#define __SIGN_CUH__
#include "Sign.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -35,8 +38,10 @@ void KernelSign(__half * d, int size);
/* set each entry to its sign value */
extern "C"
void CudaSign(XTensor * a);
void _CudaSign(XTensor * a);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
} // namespace nts(NiuTrans.Tensor)
#endif // __SIGN_CUH__
\ No newline at end of file
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its sign value */
extern "C"
void Sign(XTensor * a);
void _Sign(XTensor * a);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -29,7 +29,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor summation c = a + b * \beta
return a pointer
>> a - a tensor
>> b - another tensor
>> c - where we put a+b*\beta. we save it in a if c is NULL
......@@ -112,8 +112,9 @@ void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
}
/*
tensor summation a = a + b * \beta
do it on site
tensor summation a = a + b * \beta (do it on site)
keep the result in the tensor a and return nothing
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
......@@ -124,18 +125,20 @@ void _SumMe(XTensor * a, const XTensor * b, DTYPE beta)
}
/*
tensor summation a = a + b * \beta
return a XTensor structure
tensor summation c = a + b * \beta (return a XTensor structure)
make a new tensor c to keep the result and return it
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
<< return - the result of tensor summation
*/
XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
/* computation */
/* call _Sum function */
_Sum(&a, &b, &c, beta);
/* tensor connections */
......
......@@ -29,10 +29,16 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* tensor summation c = a + b * \beta */
void _Sum(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
/* tensor summation a = a + b * \beta (return a pointer) */
/*
tensor summation a = a + b * \beta
keep the result in the input tensor a and return nothing
*/
void _SumMe(XTensor * a, const XTensor * b, DTYPE beta = (DTYPE)1.0);
/* tensor summation c = a + b * \beta (return a structure) */
/*
tensor summation c = a + b * \beta
make a new tensor c to keep the result and return it
*/
XTensor Sum(const XTensor &a, const XTensor &b, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
......
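
The same three-flavor pattern for summation, with the scaling factor beta defaulting to 1.0 (2d initializer assumed):

XTensor a, b, c;
InitTensor2D(&a, 2, 2);
InitTensor2D(&b, 2, 2);
InitTensor2D(&c, 2, 2);
a.SetDataRand(); b.SetDataRand();

_Sum(&a, &b, &c, 0.5F);          /* c = a + b * 0.5 */
_SumMe(&a, &b);                  /* a = a + b, beta defaults to 1.0 */
XTensor d = Sum(a, b, 2.0F);     /* d = a + b * 2.0, recorded as a MATH_SUM node */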
......@@ -37,11 +37,8 @@ where b is a vector.
>> c - where we put a+b. we save it in a if c is NULL
>> beta - the scaling factor
*/
void SumByColumnTV(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
void _SumByColumnTV(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
if (c == NULL)
c = a;
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((XTensor::IsIdentical(a, c)), "Unmatched tensors in addition!");
CheckNTErrors((b->order == 2 && b->dimSizeRDI[0] == 1 && b->dimSizeRDI[1] == a->dimSizeRDI[1]),
......@@ -56,7 +53,7 @@ void SumByColumnTV(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
CudaSumByColumnTV(a, b, c, beta);
_CudaSumByColumnTV(a, b, c, beta);
#endif
}
else {
......
......@@ -64,11 +64,8 @@ where b is a vector.
>> c - where we put a+b. we save it in a if c is NULL
>> beta - the scaling factor
*/
void CudaSumByColumnTV(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
void _CudaSumByColumnTV(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
if (c == NULL)
c = a;
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((XTensor::IsIdentical(a, c)), "Unmatched tensors in addition!");
CheckNTErrors((b->order == 2 && b->dimSizeRDI[0] == 1 && b->dimSizeRDI[1] == a->dimSizeRDI[1]),
......
......@@ -30,7 +30,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* summation of a tensor and a vector (column vector) */
extern "C"
void CudaSumByColumnTV(XTensor * a, XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
void _CudaSumByColumnTV(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
#endif // USE_CUDA
......
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* sum of a tensor and a (column) vector */
extern "C"
void SumByColumnTV(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
void _SumByColumnTV(const XTensor * a, const XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
......
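
Judging from the dimension check in the implementation above (b must be of order 2 with b->dimSizeRDI[0] == 1 and as many rows as a), the column vector b is broadcast over the columns of the tensor, i.e. c(i, j) = a(i, j) + b(i, 0) * beta. As a worked example, with a of size 4 * 8 and b of size 4 * 1, _SumByColumnTV(&a, &b, &c, beta) produces a 4 * 8 result.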
......@@ -37,11 +37,8 @@ where c and a are vectors, and b_col is a column in b.
>> c - where we put a+b. we save it in a if c is NULL
>> beta - the scaling factor
*/
void SumByColumnVT(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
void _SumByColumnVT(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
if (c == NULL)
c = a;
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((XTensor::IsIdentical(a, c)), "Unmatched tensors in addition!");
CheckNTErrors((a->order == 2 && a->dimSizeRDI[0] == 1 && b->dimSizeRDI[1] == a->dimSizeRDI[1]),
......@@ -49,7 +46,7 @@ void SumByColumnVT(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
CudaSumByColumnVT(a, b, c, beta);
_CudaSumByColumnVT(a, b, c, beta);
#endif
}
else {
......
......@@ -80,11 +80,8 @@ where c and a are vectors, and b_col is a column in b.
>> c - where we put a+b. we save it in a if c is NULL
>> beta - the scaling factor
*/
void CudaSumByColumnVT(XTensor * a, XTensor * b, XTensor * c, DTYPE beta)
void _CudaSumByColumnVT(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
if (c == NULL)
c = a;
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((XTensor::IsIdentical(a, c)), "Unmatched tensors in addition!");
CheckNTErrors((a->order == 2 && a->dimSizeRDI[0] == 1 && b->dimSizeRDI[1] == a->dimSizeRDI[1]),
......
......@@ -30,7 +30,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* summation of a vector (column vector) and a tensor */
extern "C"
void CudaSumByColumnVT(XTensor * a, XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
void _CudaSumByColumnVT(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
#endif // USE_CUDA
......
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* sum of a (column) vector and a tensor */
extern "C"
void SumByColumnVT(XTensor * a, XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
void _SumByColumnVT(const XTensor * a, const XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -36,9 +36,9 @@ c = trans(a) * trans(b) * \alpha + c * \beta
>> beta - scalar
>> c - output matrix (2d tensor)
*/
void MatrixMULCPU(XTensor * a, MATRIX_TRANS_TYPE transposedA,
XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta)
void _MatrixMULCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
const XTensor * b, MATRIX_TRANS_TYPE transposedB,
XTensor * c, DTYPE alpha, DTYPE beta)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((a->order == 2 && b->order == 2 && c->order == 2),
......
......@@ -31,9 +31,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
matrix multiplication via cuda version BLAS
*/
void CudaBLASMatrixMUL(cublasHandle_t * handle,
void * a, MATRIX_TRANS_TYPE transposedA, TENSOR_DATA_TYPE dataTypeA,
void * b, MATRIX_TRANS_TYPE transposedB, TENSOR_DATA_TYPE dataTypeB,
void _CudaBLASMatrixMUL(cublasHandle_t * handle,
const void * a, MATRIX_TRANS_TYPE transposedA, TENSOR_DATA_TYPE dataTypeA,
const void * b, MATRIX_TRANS_TYPE transposedB, TENSOR_DATA_TYPE dataTypeB,
void * c, TENSOR_DATA_TYPE dataTypeC,
int na, int ma, int nb, int mb, int nc, int mc,
DTYPE alpha, DTYPE beta)
......@@ -88,7 +88,7 @@ void CudaBLASMatrixMUL(cublasHandle_t * handle,
/*
matrix multiplication via cuda version BLAS
*/
void CudaBLASMatrixMULBatched(cublasHandle_t * handle,
void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
const void ** a, MATRIX_TRANS_TYPE transposedA, TENSOR_DATA_TYPE dataTypeA,
const void ** b, MATRIX_TRANS_TYPE transposedB, TENSOR_DATA_TYPE dataTypeB,
void ** c, TENSOR_DATA_TYPE dataTypeC,
......@@ -144,7 +144,7 @@ void CudaBLASMatrixMULBatched(cublasHandle_t * handle,
/* matrix multiplication in batch and strided mode via cuda version BLAS */
extern "C"
void CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
const void * a, MATRIX_TRANS_TYPE transposedA, TENSOR_DATA_TYPE dataTypeA, long long int strideA,
const void * b, MATRIX_TRANS_TYPE transposedB, TENSOR_DATA_TYPE dataTypeB, long long int strideB,
void * c, TENSOR_DATA_TYPE dataTypeC, long long int strideC,
......@@ -201,9 +201,9 @@ void CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
/*
matrix multiplication via cuda version BLAS
*/
void CudaBLASMatrixMULList(cublasHandle_t * handle,
XList * a, MATRIX_TRANS_TYPE transposedA,
XList * b, MATRIX_TRANS_TYPE transposedB,
void _CudaBLASMatrixMULList(cublasHandle_t * handle,
const XList * a, MATRIX_TRANS_TYPE transposedA,
const XList * b, MATRIX_TRANS_TYPE transposedB,
XList * c,
int count, DTYPE alpha, DTYPE beta)
{
......@@ -255,7 +255,7 @@ void CudaBLASMatrixMULList(cublasHandle_t * handle,
if (isUniform) {
XMem * mem = a0->mem;
if (isStrided && a->count > 1) {
CudaBLASMatrixMULBatchedStrided(handle,
_CudaBLASMatrixMULBatchedStrided(handle,
a0->data, transposedA, a0->dataType, strideA / a0->unitSize,
b0->data, transposedB, b0->dataType, strideB / b0->unitSize,
c0->data, c0->dataType, strideC / c0->unitSize, a->count,
......@@ -297,7 +297,7 @@ void CudaBLASMatrixMULList(cublasHandle_t * handle,
cudaMemcpy(bpGPU, bp, sizeof(DTYPE*) * b->count, cudaMemcpyHostToDevice);
cudaMemcpy(cpGPU, cp, sizeof(DTYPE*) * c->count, cudaMemcpyHostToDevice);
CudaBLASMatrixMULBatched(handle,
_CudaBLASMatrixMULBatched(handle,
(const void**)apGPU, transposedA, a0->dataType,
(const void**)bpGPU, transposedB, b0->dataType,
(void**)cpGPU, c0->dataType, a->count,
......@@ -324,7 +324,7 @@ void CudaBLASMatrixMULList(cublasHandle_t * handle,
XTensor * bi = (XTensor*)b->GetItem(i);
XTensor * ci = (XTensor*)c->GetItem(i);
CudaBLASMatrixMUL(handle,
_CudaBLASMatrixMUL(handle,
ai->data, transposedA, ai->dataType,
bi->data, transposedB, bi->dataType,
ci->data, ci->dataType,
......
......@@ -28,21 +28,21 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* matrix multiplication (BLAS) */
extern "C"
void MatrixMULCPU(XTensor * a, MATRIX_TRANS_TYPE transposedA, XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0);
void _MatrixMULCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0);
#ifdef USE_CUDA
/* matrix multiplication via cuda version BLAS */
extern "C"
void CudaBLASMatrixMUL(cublasHandle_t * handle,
void * a, MATRIX_TRANS_TYPE transposedA, TENSOR_DATA_TYPE dataTypeA,
void * b, MATRIX_TRANS_TYPE transposedB, TENSOR_DATA_TYPE dataTypeB,
void _CudaBLASMatrixMUL(cublasHandle_t * handle,
const void * a, MATRIX_TRANS_TYPE transposedA, TENSOR_DATA_TYPE dataTypeA,
const void * b, MATRIX_TRANS_TYPE transposedB, TENSOR_DATA_TYPE dataTypeB,
void * c, TENSOR_DATA_TYPE dataTypeC,
int na, int ma, int nb, int mb, int nc, int mc, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 1.0);
/* matrix multiplication in batch mode via cuda version BLAS */
extern "C"
void CudaBLASMatrixMULBatched(cublasHandle_t * handle,
void _CudaBLASMatrixMULBatched(cublasHandle_t * handle,
const void ** a, MATRIX_TRANS_TYPE transposedA, TENSOR_DATA_TYPE dataTypeA,
const void ** b, MATRIX_TRANS_TYPE transposedB, TENSOR_DATA_TYPE dataTypeB,
void ** c, TENSOR_DATA_TYPE dataTypeC,
......@@ -50,7 +50,7 @@ void CudaBLASMatrixMULBatched(cublasHandle_t * handle,
/* matrix multiplication in batch and strided mode via cuda version BLAS */
extern "C"
void CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
void _CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
const void * a, MATRIX_TRANS_TYPE transposedA, TENSOR_DATA_TYPE dataTypeA, long long int strideA,
const void * b, MATRIX_TRANS_TYPE transposedB, TENSOR_DATA_TYPE dataTypeB, long long int strideB,
void * c, TENSOR_DATA_TYPE dataTypeC, long long int strideC,
......@@ -58,7 +58,7 @@ void CudaBLASMatrixMULBatchedStrided(cublasHandle_t * handle,
/* matrix multiplication in batch mode via cuda version BLAS */
extern "C"
void CudaBLASMatrixMULList(cublasHandle_t * handle, XList * a, MATRIX_TRANS_TYPE transposedA, XList * b, MATRIX_TRANS_TYPE transposedB, XList * c,
void _CudaBLASMatrixMULList(cublasHandle_t * handle, const XList * a, MATRIX_TRANS_TYPE transposedA, const XList * b, MATRIX_TRANS_TYPE transposedB, XList * c,
int count, DTYPE alpha = (DTYPE)1.0, DTYPE beta = 1.0);
#endif
......
......@@ -30,15 +30,15 @@ convert data type
>> input - input tensor
>> output - output tensor
*/
void ConvertTensorDataType(XTensor * input, XTensor * output)
void _ConvertDataType(const XTensor * input, XTensor * output)
{
CheckNTErrors(XTensor::IsIdentical(input, output), "Input and Output are different in type or size!");
CheckNTErrors((input->unitSize == output->unitSize), "Input and Output must be same in size!");
if (input->dataType == output->dataType)
return;
#ifdef USE_CUDA
/* run it on GPUs */
if (input->devID >= 0) {
CudaConvertDataType(input, output);
_CudaConvertDataType(input, output);
return;
}
#endif
......
......@@ -78,7 +78,7 @@ data conversion (cuda code)
>> typeT - target data type
>> size - number of the items in s (and t)
*/
void CudaConvertDataType(int devID, void * s, TENSOR_DATA_TYPE typeS, void * t, TENSOR_DATA_TYPE typeT, int size)
void _CudaConvertDataType(int devID, void * s, TENSOR_DATA_TYPE typeS, void * t, TENSOR_DATA_TYPE typeT, int size)
{
CheckNTErrors((devID >= 0), "This code must be run on GPUs!");
......@@ -112,9 +112,9 @@ convert data type (cuda code)
>> input - input tensor
>> output - output tensor
*/
void CudaConvertDataType(XTensor * input, XTensor * output)
void _CudaConvertDataType(const XTensor * input, XTensor * output)
{
CheckNTErrors(XTensor::IsIdentical(input, output), "Input and Output are different in type or size!");
CheckNTErrors((input->unitSize == output->unitSize), "Input and Output must be same in size!");
if (input->dataType == output->dataType)
return;
......
......@@ -19,6 +19,9 @@
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#ifndef __CONVERTDATATYPE_CUH__
#define __CONVERTDATATYPE_CUH__
#include "ConvertDataType.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -42,8 +45,10 @@ __global__
void KernelIntToFloat(int * inputData, float * outputData, int size);
/* convert data type */
void CudaConvertDataType(XTensor * input, XTensor * output);
void _CudaConvertDataType(const XTensor * input, XTensor * output);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
} // namespace nts(NiuTrans.Tensor)
#endif // __CONVERTDATATYPE_CUH__
\ No newline at end of file
......@@ -27,7 +27,7 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* convert data type */
void ConvertDataType(XTensor * input, XTensor * output);
void _ConvertDataType(const XTensor * input, XTensor * output);
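A hedged sketch of the renamed conversion entry point (not part of this commit): float and int share the same unitSize, which the size check in the implementation requires; NewTensor2D and X_INT are assumed to exist as in the rest of the library.
XTensor * f = NewTensor2D(2, 2, X_FLOAT, -1);   /* source tensor on the CPU */
XTensor * i = NewTensor2D(2, 2, X_INT, -1);     /* target tensor of the same shape */
f->SetZeroAll();
_ConvertDataType(f, i);   /* dispatches to the CUDA kernel instead when f lives on a GPU */
delete f;
delete i;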
} // namespace nts(NiuTrans.Tensor)
......
......@@ -26,8 +26,10 @@
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
generate a tensor with seleccted data in range[low,high] along the given dimension
generate a tensor with selected data in range[low,high] along the given dimension
c = select(a)
>> a - input tensor
>> c - result tensor
>> dim - the dimension along which we do the job
......@@ -35,7 +37,7 @@ c = select(a)
>> high - higher bound.
Note that range [1,3] means that we select 1 and 2.
*/
void SelectRange(XTensor * a, XTensor * c, int dim, int low, int high)
void _SelectRange(const XTensor * a, XTensor * c, int dim, int low, int high)
{
CheckNTErrors(a != NULL && c != NULL, "empty tensors!");
CheckNTErrors(a->order == c->order, "The input and output tensors must be in the same order!");
......@@ -76,4 +78,55 @@ void SelectRange(XTensor * a, XTensor * c, int dim, int low, int high)
}
}
/*
generate a tensor with selected data in range[low,high] along the given dimension (return a XTensor structure)
make a new tensor to keep the result and return it
c = select(a)
>> a - input tensor
>> dim - the dimension along which we do the job
>> low - lower bound
>> high - higher bound.
Note that range [1,3] means that we select 1 and 2.
<< return - the result of the generated tensor with selected data
*/
XTensor SelectRange(const XTensor &a, int dim, int low, int high)
{
int order = a.order;
int * dimSize = new int[order];
CheckNTErrors(&a != NULL, "Empty input tensors!");
CheckNTErrors(dim >= 0 && dim < a.order, "The input dimension is out of bounds!");
CheckNTErrors(low < high, "Illegal range specified!");
for(int i = 0; i < a.order; i++){
if(i == dim){
CheckNTErrors(low >= 0 && low < a.dimSize[dim], "Illegal range specified!");
CheckNTErrors(high > 0 && high <= a.dimSize[dim], "Illegal range specified!");
dimSize[i] = high - low;
}
else
dimSize[i] = a.dimSize[i];
}
XTensor c(order, dimSize, a.dataType, a.denseRatio, a.devID, a.mem);
c.SetZeroAll();
c.SetTMP();
/* call _SelectRange function */
_SelectRange(&a, &c, dim, low, high);
/* tensor connection */
XLink::MakeLink(&a, NULL, &c, GETANDSET_SELECT);
XLink::AddParamToHead(&c, low);
XLink::AddParamToHead(&c, high);
/* destroy variables */
delete[] dimSize;
return c;
}
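A hedged usage sketch for the new tensor-returning interface (the InitTensor2D helper is assumed): range [1,3] keeps indices 1 and 2, so a 2 x 4 input yields a 2 x 2 result.
XTensor a;
InitTensor2D(&a, 2, 4, X_FLOAT);
a.SetZeroAll();
XTensor c = SelectRange(a, 1, 1, 3);   /* dim = 1, low = 1, high = 3 */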
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-07-04
*/
#ifndef __SELECT_CUH__
#define __SELECT_CUH__
#include "Select.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
/* generate a tensor with selected data c = select(a) */
extern "C"
void _CudaSelect(const XTensor * a, XTensor * c, XTensor * indexCPU);
/*
generate a tensor with selected data in range[low,high] along the given dimension
c = select(a)
*/
extern "C"
void _CudaSelectRange(const XTensor * a, XTensor * c, int dim, int low, int high);
} // namespace nts(NiuTrans.Tensor)
#endif // __SELECT_CUH__
\ No newline at end of file
......@@ -26,14 +26,29 @@
namespace nts{ // namespace nts(NiuTrans.Tensor)
/* generate a tensor with seleccted data c = select(a) */
extern "C"
void Select(XTensor * a, XTensor * c, XTensor * indexCPU);
/* generate a tensor with selected data c = select(a) */
extern "C"
void _Select(const XTensor * a, XTensor * c, XTensor * indexCPU);
/*
generate a tensor with selected data c = select(a) (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Select(const XTensor &a, XTensor &indexCPU);
/* generate a tensor with seleccted data in range[low,high] along the given dimension
c = select(a) */
/*
generate a tensor with selected data in range[low,high] along the given dimension
c = select(a)
*/
extern "C"
void SelectRange(XTensor * a, XTensor * c, int dim, int low, int high);
void _SelectRange(const XTensor * a, XTensor * c, int dim, int low, int high);
/*
generate a tensor with selected data in range[low,high] along the given dimension (return a XTensor structure)
make a new tensor to keep the result and return it
c = select(a)
*/
XTensor SelectRange(const XTensor &a, int dim, int low, int high);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -77,7 +77,7 @@ void SetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
else{
XTensor * t2 = NewTensor(tensor->order, tensor->dimSize, tensor->dataType, tensor->denseRatio, -1);
SetDataRand(t2, low, high);
CopyValues(t2, tensor);
_CopyValues(t2, tensor);
delete t2;
}
}
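A hedged sketch of the caller's side (InitTensor2D is assumed): for a GPU tensor, SetDataRand takes the CPU-then-copy path shown above.
XTensor t;
InitTensor2D(&t, 2, 3, X_FLOAT);
SetDataRand(&t, -1.0F, 1.0F);   /* uniform random values in [-1, 1] */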
......
......@@ -22,6 +22,7 @@
#include "../../XTensor.h"
#include "Log.h"
#include "Log.cuh"
#include <math.h>
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -29,12 +30,12 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
set every entry to its log value
>> a - the tensor we are processing
*/
void Log(XTensor * a)
void _Log(XTensor * a)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
CudaLog(a);
_CudaLog(a);
return;
}
#endif
......
......@@ -58,7 +58,7 @@ set each entry to its log value
>> a - the tensor
*/
extern "C"
void CudaLog(XTensor * a)
void _CudaLog(XTensor * a)
{
CheckNTErrors((a->isSparse == false), "TODO!");
......
......@@ -19,6 +19,9 @@
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#ifndef __LOG_CUH__
#define __LOG_CUH__
#include "Log.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -35,8 +38,10 @@ void KernelLog(__half * d, int size);
/* set each entry to its log value */
extern "C"
void CudaLog(XTensor * a);
void _CudaLog(XTensor * a);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
} // namespace nts(NiuTrans.Tensor)
#endif // __LOG_CUH__
\ No newline at end of file
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its log value */
extern "C"
void Log(XTensor * a);
void _Log(XTensor * a);
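A hedged sketch of the renamed in-place routine (InitTensor1D is assumed):
XTensor a;
InitTensor1D(&a, 8, X_FLOAT);
/* fill a with positive values here */
_Log(&a);   /* every entry is replaced by its natural log */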
} // namespace nts(NiuTrans.Tensor)
......
......@@ -21,15 +21,18 @@
#include <math.h>
#include "../../XTensor.h"
#include "../../XName.h"
#include "Normalize.h"
#include "Normalize.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
normalized the data with normal distribution. For an input x,
y = a * (x-mean)/sqrt(variance+\epsilon) + b
normalize the data with normal distribution
For an input x, y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
>> input - the input tensor
>> output - the output tensor
>> dim - dimension along which we generate the mean and variance
......@@ -39,7 +42,7 @@ where a and b are the scalar and bias respectively, and \epsilon is the adjustme
>> b - the bias
>> epsilon - a parameter
*/
void Normalize(XTensor * input, XTensor * output, int dim, XTensor * mean, XTensor * var, XTensor * a, XTensor * b, DTYPE epsilon)
void _Normalize(const XTensor * input, XTensor * output, int dim, const XTensor * mean, const XTensor * var, const XTensor * a, const XTensor * b, DTYPE epsilon)
{
int dimRDI = input->order - dim - 1;
CheckNTErrors((XTensor::IsIdentical(input, output)), "Unmatched input tensors!");
......@@ -68,7 +71,7 @@ void Normalize(XTensor * input, XTensor * output, int dim, XTensor * mean, XTens
if (input->devID >= 0 || output->devID >= 0) {
#ifdef USE_CUDA
CudaNormalize(input, output, dim, mean, var, a, b, epsilon);
_CudaNormalize(input, output, dim, mean, var, a, b, epsilon);
#else
ShowNTErrors("Please specify USE_CUDA and recompile the code!");
#endif
......@@ -91,4 +94,61 @@ void Normalize(XTensor * input, XTensor * output, int dim, XTensor * mean, XTens
}
}
}
/*
normalize the data with normal distribution (do it on site)
keep the result in the input tensor and return nothing
For an input x, x = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
>> input - the input tensor
>> dim - dimension along which we generate the mean and variance
>> mean - the mean of the input
>> var - the variance of the input
>> a - the scalar
>> b - the bias
>> epsilon - a parameter
*/
void _NormalizeMe(XTensor * input, int dim, const XTensor * mean, const XTensor * var, const XTensor * a, const XTensor * b, DTYPE epsilon)
{
_Normalize(input, input, dim, mean, var, a, b, epsilon);
}
/*
normalize the data with normal distribution (return a XTensor structure)
make a new tensor to keep the result and return it
For an input x, y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
>> input - the input tensor
>> dim - dimension along which we generate the mean and variance
>> mean - the mean of the input
>> var - the variance of the input
>> a - the scalar
>> b - the bias
>> epsilon - a parameter
<< return - the result of normalizing the data with normal distribution
*/
XTensor Normalize(const XTensor &input, int dim, const XTensor &mean, const XTensor &var, const XTensor &a, const XTensor &b, DTYPE epsilon)
{
XTensor output(&input);
output.SetTMP();
/* call _Normalize function */
_Normalize(&input, &output, dim, &mean, &var, &a, &b, epsilon);
/* tensor connections */
XList list(5);
list.Add(&input);
list.Add(&mean);
list.Add(&var);
list.Add(&a);
list.Add(&b);
XLink::MakeLink(&list, &output, MATH_NORMALIZE);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, epsilon);
return output;
}
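A hedged layer-normalization style sketch of the new tensor-returning interface (not part of this commit); the shapes of mean, var, a and b and the InitTensor helpers are assumptions: with dim = 1 the mean and variance are per-row statistics of x, while a and b act as the gain and bias over the normalized dimension.
XTensor x, mean, var, a, b;
InitTensor2D(&x, 32, 512, X_FLOAT);
InitTensor1D(&mean, 32, X_FLOAT);   /* per-row mean of x, computed elsewhere */
InitTensor1D(&var, 32, X_FLOAT);    /* per-row variance of x, computed elsewhere */
InitTensor1D(&a, 512, X_FLOAT);     /* gain */
InitTensor1D(&b, 512, X_FLOAT);     /* bias */
XTensor y = Normalize(x, 1, mean, var, a, b, (DTYPE)1e-5);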
} // namespace nts(NiuTrans.Tensor)
......@@ -89,9 +89,9 @@ where a and b are the scalar and bias respectively, and \epsilon is the adjustme
>> epsilon - a parameter
*/
extern "C"
void CudaNormalize(XTensor * input, XTensor * output, int dim,
XTensor * mean, XTensor * var,
XTensor * a, XTensor * b,
void _CudaNormalize(const XTensor * input, XTensor * output, int dim,
const XTensor * mean, const XTensor * var,
const XTensor * a, const XTensor * b,
DTYPE epsilon)
{
CheckNTErrors((input->dataType == DEFAULT_DTYPE), "TODO!");
......
......@@ -44,9 +44,9 @@ y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter
*/
extern "C"
void CudaNormalize(XTensor * input, XTensor * output, int dim,
XTensor * mean, XTensor * var,
XTensor * a, XTensor * b, DTYPE epsilon);
void _CudaNormalize(const XTensor * input, XTensor * output, int dim,
const XTensor * mean, const XTensor * var,
const XTensor * a, const XTensor * b, DTYPE epsilon);
#endif // USE_CUDA
......
......@@ -27,12 +27,29 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
normalized the data with normal distribution. For an input x,
y = a * (x-mean)/sqrt(variance+\epsilon) + b
normalize the data with normal distribution.
For an input x, y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
*/
extern "C"
void Normalize(XTensor * input, XTensor * output, int dim, XTensor * mean, XTensor * var, XTensor * a, XTensor * b, DTYPE epsilon);
void _Normalize(const XTensor * input, XTensor * output, int dim, const XTensor * mean, const XTensor * var, const XTensor * a, const XTensor * b, DTYPE epsilon);
/*
normalize the data with normal distribution (do it on site)
keep the result in the input tensor and return nothing
For an input x, x = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
*/
extern "C"
void _NormalizeMe(XTensor * input, int dim, const XTensor * mean, const XTensor * var, const XTensor * a, const XTensor * b, DTYPE epsilon);
/*
normalize the data with normal distribution (return a XTensor structure)
make a new tensor to keep the result and return it
For an input x, y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
*/
XTensor Normalize(const XTensor &input, int dim, const XTensor &mean, const XTensor &var, const XTensor &a, const XTensor &b, DTYPE epsilon);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -31,12 +31,12 @@ get the power(a, p)
>> a - the tensor
>> p - as it is
*/
void Power(XTensor * a, DTYPE p)
void _Power(XTensor * a, DTYPE p)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
CudaPower(a, p);
_CudaPower(a, p);
return;
}
#endif
......
......@@ -96,7 +96,7 @@ void KernelPower(__half * d, __half p, int size)
/* get the power of the entries */
extern "C"
void CudaPower(XTensor * a, DTYPE p)
void _CudaPower(XTensor * a, DTYPE p)
{
int gridSize[3];
int blockSize[3];
......
......@@ -38,7 +38,7 @@ void KernelSqrtV2(__half * d, int size);
/* get the power of the entries */
extern "C"
void CudaPower(XTensor * a, DTYPE p);
void _CudaPower(XTensor * a, DTYPE p);
#endif // USE_CUDA
......
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* get the power(a, p) */
extern "C"
void Power(XTensor * a, DTYPE p);
void _Power(XTensor * a, DTYPE p);
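A hedged sketch of the renamed in-place routine (InitTensor2D is assumed):
XTensor a;
InitTensor2D(&a, 3, 3, X_FLOAT);
a.SetZeroAll();
_Power(&a, (DTYPE)2.0);   /* square every entry in place */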
} // namespace nts(NiuTrans.Tensor)
......
......@@ -28,8 +28,10 @@
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
scale and shift all tensor entires b = a * scale + shift
scale and shift all tensor entries
b = a * scale + shift
>> a - the input tensor
>> b - the output tensor
>> scale - the scale factor
......@@ -76,8 +78,11 @@ void _ScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift)
}
/*
scale and shift all tensor entires on site b = a * scale + shift
b = a * scale + shift
scale and shift all tensor entries (do it on site)
keep the result in the input tensor a and return nothing
a = a * scale + shift
>> a - the input/output tensor
>> scale - the scale factor
>> shift - the shift factor
......@@ -88,19 +93,22 @@ void _ScaleAndShiftMe(XTensor * a, DTYPE scale, DTYPE shift)
}
/*
scale and shift all tensor entires b = a * scale + shift
scale and shift all tensor entries (return a XTensor structure)
make a new tensor to keep the result and return it
b = a * scale + shift
>> a - the input tensor
>> b - the output tensor
>> scale - the scale factor
>> shift - the shift factor
<< return - the result of scaling and shifting all tensor entries
*/
XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
{
XTensor b(&a);
b.SetTMP();
/* computation */
/* call _ScaleAndShift function */
_ScaleAndShift(&a, &b, scale, shift);
/* tensor connections */
......
......@@ -30,13 +30,24 @@ namespace nts{ // namespace nts(NiuTrans.Tensor)
#define _LinearMe _ScaleAndShiftMe
#define Linear ScaleAndShift
/* scale and shift all tensor entires b = a * scale + shift */
/*
scale and shift all tensor entries
b = a * scale + shift
*/
void _ScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift = 0);
/* scale and shift all tensor entires on site a = a * scale + shift */
/*
scale and shift all tensor entries
keep the result in the input tensor a and return nothing
a = a * scale + shift
*/
void _ScaleAndShiftMe(XTensor * a, DTYPE scale, DTYPE shift = 0);
/* scale and shift all tensor entires b = a * scale + shift, and return the result tensor b */
/*
scale and shift all tensor entries
make a new tensor to keep the result and return it
b = a * scale + shift
*/
XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift = 0);
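A hedged sketch of the three entry points declared above (InitTensor2D is assumed); Linear is the alias defined at the top of this header.
XTensor a;
InitTensor2D(&a, 2, 3, X_FLOAT);
a.SetZeroAll();
XTensor b = ScaleAndShift(a, (DTYPE)2.0, (DTYPE)1.0);   /* b = a * 2 + 1 */
_ScaleAndShiftMe(&a, (DTYPE)2.0, (DTYPE)1.0);           /* a = a * 2 + 1, in place */
XTensor c = Linear(a, (DTYPE)0.5);                      /* same as ScaleAndShift(a, 0.5) */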
} // namespace nts(NiuTrans.Tensor)
......
......@@ -36,7 +36,7 @@ copy a number of blocks to target positions
>> targetBlocks - target positions of the copy
>> myMem - the memory pool
*/
void CopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem)
void _CopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem)
{
if (myMem != NULL && myMem->devID >= 0) {
#ifdef USE_CUDA
......@@ -44,7 +44,7 @@ void CopyBlocks(void * source, int blockSize, int blockNum, void * target, int *
int * targetBlocksTMP = (int*)myMem->AllocBuf(myMem->devID, blockNum * sizeof(int));
XMemCopy(targetBlocksTMP, myMem->devID, targetBlocks, -1, blockNum * sizeof(int));
CopyBlocksOnSite(source, blockSize, blockNum, target, targetBlocksTMP, myMem);
_CopyBlocksOnSite(source, blockSize, blockNum, target, targetBlocksTMP, myMem);
myMem->ReleaseBuf(myMem->devID, blockNum * sizeof(int));
#else
......@@ -52,7 +52,7 @@ void CopyBlocks(void * source, int blockSize, int blockNum, void * target, int *
#endif
}
else {
CopyBlocksOnSite(source, blockSize, blockNum, target, targetBlocks, myMem);
_CopyBlocksOnSite(source, blockSize, blockNum, target, targetBlocks, myMem);
}
}
......@@ -66,14 +66,14 @@ copy a number of blocks from source positions to target positions
>> targetBlocks - target positions of the copy
>> myMem - the memory pool
*/
void CopyBlocks(void * source, int blockSize, int * sourceBlocks, int blockNum, void * target, int * targetBlocks, XMem * myMem, int devID)
void _CopyBlocks(void * source, int blockSize, int * sourceBlocks, int blockNum, void * target, int * targetBlocks, XMem * myMem, int devID)
{
if (myMem != NULL)
CheckNTErrors((myMem->devID == devID), "DevIDs are different between memory pool and input devID!");
if (devID >= 0) {
#ifdef USE_CUDA
CudaCopyBlocksSelected(source, blockSize, sourceBlocks, blockNum, target, targetBlocks, myMem, devID);
_CudaCopyBlocksSelected(source, blockSize, sourceBlocks, blockNum, target, targetBlocks, myMem, devID);
#else
ShowNTErrors("Plesae specify USE_CUDA and recompile the code!");
#endif
......
......@@ -27,10 +27,10 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy a number of blocks to target positions */
void CopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem);
void _CopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem);
/* copy a number of blocks from source positions to target positions */
void CopyBlocks(void * source, int blockSize, int * sourceBlocks, int blockNum, void * target, int * targetBlocks, XMem * myMem, int devID);
void _CopyBlocks(void * source, int blockSize, int * sourceBlocks, int blockNum, void * target, int * targetBlocks, XMem * myMem, int devID);
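A hedged CPU-side sketch of the first overload (myMem = NULL): targetBlocks[i] is read as the destination block index of source block i, so the call below reverses the order of four 16-float blocks.
float src[64] = {0};
float tgt[64] = {0};
int targetBlocks[4] = {3, 2, 1, 0};
_CopyBlocks(src, 16 * sizeof(float), 4, tgt, targetBlocks, NULL);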
} // namespace nts(NiuTrans.Tensor)
......
......@@ -38,7 +38,7 @@ Note that a grid may have a number of blocks
>> myMem - the memory pool
>> isIndexOnDev - indicates whether the index is on the device already
*/
void CopyBlocksInGrid(void * source, int blockSize, int blockNum, int gridNum, void * target,
void _CopyBlocksInGrid(void * source, int blockSize, int blockNum, int gridNum, void * target,
int * index, int unitSize, bool isIndexOnDev, XMem * myMem)
{
CheckNTErrors((unitSize == sizeof(int)), "TODO!");
......@@ -51,7 +51,7 @@ void CopyBlocksInGrid(void * source, int blockSize, int blockNum, int gridNum, v
XMemCopy(indexGPU, myMem->devID, index, -1, blockNum * gridNum * sizeof(int));
}
CudaCopyBlocksInGrid(source, blockSize, blockNum, gridNum, target, indexGPU, unitSize, myMem);
_CudaCopyBlocksInGrid(source, blockSize, blockNum, gridNum, target, indexGPU, unitSize, myMem);
if (!isIndexOnDev)
myMem->ReleaseBuf(myMem->devID, blockNum * gridNum * sizeof(int));
......
......@@ -216,7 +216,7 @@ Note that a grid may have a number of blocks
>> itemSize - size of each data item
>> myMem - the memory pool
*/
void CudaCopyBlocksInGrid(void * source, int blockSize, int blockNum, int gridNum, void * target, int * index, int itemSize, XMem * myMem)
void _CudaCopyBlocksInGrid(void * source, int blockSize, int blockNum, int gridNum, void * target, int * index, int itemSize, XMem * myMem)
{
CheckNTErrors((myMem != NULL && myMem->devID >= 0), "This code must be run on GPUs!");
CheckNTErrors((itemSize == sizeof(int)), "TODO!");
......
......@@ -30,7 +30,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy data by index */
extern "C"
void CudaCopyBlocksInGrid(void * source, int blockSize, int blockNum, int gridNum, void * target, int * index, int unitSize, XMem * myMem);
void _CudaCopyBlocksInGrid(void * source, int blockSize, int blockNum, int gridNum, void * target, int * index, int unitSize, XMem * myMem);
#endif // USE_CUDA
......
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy a number of blocks in grid */
extern "C"
void CopyBlocksInGrid(void * source, int blockSize, int blockNum, int gridNum, void * target, int * index, int unitSize, bool isIndexOnDev, XMem * myMem);
void _CopyBlocksInGrid(void * source, int blockSize, int blockNum, int gridNum, void * target, int * index, int unitSize, bool isIndexOnDev, XMem * myMem);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -36,11 +36,11 @@ all the data has been on the device (CPU/GPU) already.
>> targetBlocks - target positions of the copy
>> myMem - the memory pool
*/
void CopyBlocksOnSite(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem)
void _CopyBlocksOnSite(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem)
{
if (myMem != NULL && myMem->devID >= 0) {
#ifdef USE_CUDA
CudaCopyBlocks(source, blockSize, blockNum, target, targetBlocks, myMem);
_CudaCopyBlocks(source, blockSize, blockNum, target, targetBlocks, myMem);
#else
ShowNTErrors("Plesae specify USE_CUDA and recompile the code!");
#endif
......
......@@ -80,7 +80,7 @@ copy a number of blocks to target positions (cuda version)
>> targetBlocks - target positions of the copy (on the device)
>> myMem - memory pool
*/
void CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem)
void _CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem)
{
CheckNTErrors((myMem != NULL), "No memory pool!");
CheckNTErrors((myMem->devID >= 0), "Wrong device to run!");
......
......@@ -34,7 +34,7 @@ void KernelCopyBlocks(DTYPE * source, int blockSize, int blockNum, DTYPE * targe
/* copy a number of blocks to target positions (cuda version) */
extern "C"
void CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem);
void _CudaCopyBlocks(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem);
#endif // USE_CUDA
......
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy a number of blocks to target positions (on site) */
extern "C"
void CopyBlocksOnSite(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem);
void _CopyBlocksOnSite(void * source, int blockSize, int blockNum, void * target, int * targetBlocks, XMem * myMem);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -70,7 +70,7 @@ copy a number of blocks from source positions to target positions (cuda version)
>> targetBlocks - target positions of the copy
>> myMem - memory pool
*/
void CudaCopyBlocksSelected(void * source, int blockSize, int * sourceBlocks, int blockNum, void * target, int * targetBlocks, XMem * myMem, int devID)
void _CudaCopyBlocksSelected(void * source, int blockSize, int * sourceBlocks, int blockNum, void * target, int * targetBlocks, XMem * myMem, int devID)
{
CheckNTErrors((devID >= 0), "Wrong device to run!");
CheckNTErrors((blockSize % sizeof(DTYPE) == 0), "Unsupported block size!");
......
......@@ -34,7 +34,7 @@ void KernelCopyBlocksSelected(DTYPE * source, int blockSize, int * sourceBlocks,
/* copy a number of blocks from source positions to target positions (cuda version) */
extern "C"
void CudaCopyBlocksSelected(void * source, int blockSize, int * sourceBlocks, int blockNum, void * target, int * targetBlocks, XMem * myMem, int devID);
void _CudaCopyBlocksSelected(void * source, int blockSize, int * sourceBlocks, int blockNum, void * target, int * targetBlocks, XMem * myMem, int devID);
#endif // USE_CUDA
......
......@@ -36,7 +36,7 @@ copy data blocks by 2d layout
>> n - height of each block
>> myMem - the memory pool
*/
void CopyData2D(void ** s, int sPitch, void ** t, int tPitch, int blockNum, int mSize, int n, XMem * myMem)
void _CopyData2D(void ** s, int sPitch, void ** t, int tPitch, int blockNum, int mSize, int n, XMem * myMem)
{
int devID = myMem != NULL ? myMem->devID : -1;
......
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy data blocks by 2d layout */
extern "C"
void CopyData2D(void ** s, int sPitch, void ** t, int tPitch, int count, int mSize, int n, XMem * myMem);
void _CopyData2D(void ** s, int sPitch, void ** t, int tPitch, int count, int mSize, int n, XMem * myMem);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -36,7 +36,7 @@ in the k-th grid
>> blockNumInGrid - number of blocks in each grid
>> isIndexOnDev - indicates whether the index is on the device already
*/
void CopyInGrid(XTensor * s, XTensor * t, int * index, int blockDim, int blockNumInGrid, bool isIndexOnDev)
void _CopyInGrid(const XTensor * s, XTensor * t, int * index, int blockDim, int blockNumInGrid, bool isIndexOnDev)
{
CheckNTErrors((XTensor::IsIdentical(s, t)), "Unmatched tensors!");
......@@ -50,7 +50,7 @@ void CopyInGrid(XTensor * s, XTensor * t, int * index, int blockDim, int blockNu
CheckNTErrors((s->unitNum % (blockSize * blockNum) == 0), "Illegal block number!");
gridNum = s->unitNum / (blockSize * blockNum);
CopyBlocksInGrid(s->data, blockSize, blockNum, gridNum, t->data, index, s->unitSize, isIndexOnDev, s->mem);
_CopyBlocksInGrid(s->data, blockSize, blockNum, gridNum, t->data, index, s->unitSize, isIndexOnDev, s->mem);
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -28,7 +28,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy a number of blocks in grid. i.e., reorder the data blocks in the same memory piece*/
extern "C"
void CopyInGrid(XTensor * s, XTensor * t, int * index, int blockDim, int blockNumInGrid, bool isIndexOnDev = false);
void _CopyInGrid(const XTensor * s, XTensor * t, int * index, int blockDim, int blockNumInGrid, bool isIndexOnDev = false);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -21,11 +21,13 @@
#include "CopyIndexed.h"
#include "CopyBlocks.h"
#include "../../XName.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
copy indexed sub-tensors
>> s - the source tensor
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
......@@ -34,11 +36,11 @@ copy indexed sub-tensors
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
>> tgtIndex - index of the target sub-tensors
>> copyNum - number of the sub-tensors we copy for each source index, e.g.,
for srcIndex = [1,4] and copyNum = 2, we actually copy the source sub-tensors 1, 2, 4, 5
<< return - whether copy indexed operation was successful
>> copyNum - number of the sub-tensors we copy for each source index,
e.g., for srcIndex = [1,4] and copyNum = 2,
we actually copy the source sub-tensors 1, 2, 4, 5
*/
bool CopyIndexed(XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum)
void _CopyIndexed(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum)
{
CheckNTErrors((s && t), "Invalid tensors!");
CheckNTErrors((s->devID == t->devID || (s->devID < 0 && t->devID < 0)),
......@@ -84,12 +86,62 @@ bool CopyIndexed(XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSiz
CheckNTErrors((tgtIndex[i] < blockNumTgt), "Index is out of range!");
}
CopyBlocks(s->data, blockSizeSrc * s->unitSize, realSrcIndex, realIndexSize, t->data, realTgtIndex, s->mem, s->devID);
_CopyBlocks(s->data, blockSizeSrc * s->unitSize, realSrcIndex, realIndexSize, t->data, realTgtIndex, s->mem, s->devID);
delete[] realSrcIndex;
delete[] realTgtIndex;
}
/*
copy indexed sub-tensors (return a XTensor structure)
make a new tensor to keep the result and return it
>> s - the source tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3,2)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
>> tgtIndex - index of the target sub-tensors
>> copyNum - number of the sub-tensors we copy for each source index,
e.g., for srcIndex = [1,4] and copyNum = 2,
we actually copy the source sub-tensors 1, 2, 4, 5
<< return - the result of copying indexed sub-tensors
*/
XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum)
{
CheckNTErrors(&s, "Empty input tensor!");
CheckNTErrors((dim >= 0 && dim < s.order), "A too large dimension specified!");
int order = s.order;
int * dimSize = new int[order];
for (int i = 0; i < s.order; i++) {
if (i == dim)
dimSize[i] = indexSize * copyNum;
else
dimSize[i] = s.dimSize[i];
}
XTensor t(order, dimSize, s.dataType, s.denseRatio, s.devID, s.mem);
t.SetZeroAll();
t.SetTMP();
/* call _CopyIndexed function */
_CopyIndexed(&s, &t, dim, srcIndex, indexSize, tgtIndex, copyNum);
/* destroy variables */
delete[] dimSize;
/* tensor connection */
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYINDEXED);
XLink::AddParamToHead(&t, dim);
XLink::AddParamToHeadPointer(&t, srcIndex);
XLink::AddParamToHead(&t, indexSize);
XLink::AddParamToHeadPointer(&t, tgtIndex);
XLink::AddParamToHead(&t, copyNum);
return true;
return t;
}
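A hedged usage sketch matching the comment above (InitTensor3D is assumed): with srcIndex = {1, 4} and copyNum = 2 we copy sub-tensors 1, 2, 4 and 5 of s along dim = 2 into positions 0-3 of the result.
XTensor s;
InitTensor3D(&s, 3, 2, 6, X_FLOAT);
s.SetZeroAll();
int srcIndex[2] = {1, 4};
int tgtIndex[2] = {0, 2};
XTensor t = CopyIndexed(s, 2, srcIndex, 2, tgtIndex, 2);   /* t is of size (3, 2, 4) */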
} // namespace nts(NiuTrans.Tensor)
......@@ -28,7 +28,13 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy selected sub-tensors */
extern "C"
bool CopyIndexed(XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum);
void _CopyIndexed(const XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum);
/*
copy selected sub-tensors (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor CopyIndexed(const XTensor &s, int dim, int * srcIndex, int indexSize, int * tgtIndex, int copyNum);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -27,18 +27,15 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
copy s to t
>> s - source
>> t - target
>> stream - the stream for creating the job pipeline
<< return - succeeded or not
*/
bool CopyValues(const XTensor * s, XTensor * t, XStream * stream)
void _CopyValues(const XTensor * s, XTensor * t, XStream * stream)
{
if (s == NULL || t == NULL)
return false;
if (s->data == NULL || t->data == NULL)
return false;
CheckNTErrors((s != NULL && t != NULL), "The input tensor and output tensor must be nonempty!");
CheckNTErrors((s->data != NULL), "Cannot copy from an empty data array!");
CheckNTErrors((t->data != NULL), "Cannot copy to an empty data array!");
CheckNTErrors((s->unitNum == t->unitNum), "Unmatched data item number!");
......@@ -48,12 +45,13 @@ bool CopyValues(const XTensor * s, XTensor * t, XStream * stream)
"The code must be run on the same device!");
CheckNTErrors((s->isSparse || t->isSparse), "TODO!");
ConvertDataType(s->devID, s->data, s->dataType, t->data, t->dataType, s->unitNum);
return true;
}
#ifdef USE_CUDA
if (s->devID >= 0 || t->devID >= 0)
return CudaCopyValues(s, t, stream);
if (s->devID >= 0 || t->devID >= 0) {
_CudaCopyValues(s, t, stream);
return;
}
#endif
if (!s->isSparse && !t->isSparse) {
......@@ -68,8 +66,28 @@ bool CopyValues(const XTensor * s, XTensor * t, XStream * stream)
else {
ShowNTErrors("TODO!");
}
}
/*
copy s to t (return a XTensor structure)
make a new tensor to keep the result and return it
>> s - source
>> stream - the stream for creating the job pipeline
<< return - the copied tensor t
*/
XTensor CopyValues(const XTensor &s, XStream * stream)
{
XTensor t(&s);
t.SetTMP();
/* call _CopyValues function */
_CopyValues(&s, &t, stream);
/* tensor connection */
XLink::MakeLink(&s, NULL, &t, MOVEMENT_COPYVALUES);
return true;
return t;
}
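A hedged sketch of the new copy-and-return interface (InitTensor2D is assumed):
XTensor s;
InitTensor2D(&s, 2, 2, X_FLOAT);
s.SetZeroAll();
XTensor t = CopyValues(s);   /* t holds a copy of the values in s */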
} // namespace nts(NiuTrans.Tensor)
......@@ -35,11 +35,9 @@ copy a range of elements from a source vector to a target vector
>> stream - the stream for creating the job pipeline
<< return - succeed or not
*/
bool CudaCopyValues(const XTensor * s, XTensor * t, XStream * stream)
void _CudaCopyValues(const XTensor * s, XTensor * t, XStream * stream)
{
if (s == NULL || t == NULL)
return false;
CheckNTErrors((s != NULL && t != NULL), "The input tensor and output tensor must be nonempty!");
CheckNTErrors(s->dataType == t->dataType, "Unmatched data type!");
CheckNTErrors((s->unitSize == t->unitSize), "Incompatible vectors in value copy.");
CheckNTErrors((s->denseRatio <= t->denseRatio), "Incompatible vectors in value copy.");
......@@ -83,8 +81,6 @@ bool CudaCopyValues(const XTensor * s, XTensor * t, XStream * stream)
else {
ShowNTErrors("TODO!");
}
return true;
}
......
......@@ -30,7 +30,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy all elements from a source matrix to a target matrix */
extern "C"
bool CudaCopyValues(const XTensor * s, XTensor * t, XStream * stream = NULL);
void _CudaCopyValues(const XTensor * s, XTensor * t, XStream * stream = NULL);
#endif // USE_CUDA
......
......@@ -28,7 +28,13 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* copy s to t */
extern "C"
bool CopyValues(const XTensor * s, XTensor * t, XStream * stream = NULL);
void _CopyValues(const XTensor * s, XTensor * t, XStream * stream = NULL);
/*
copy s to t (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor CopyValues(const XTensor &s, XStream * stream = NULL);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -27,12 +27,13 @@
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
get the max value of the items along a dimension of the tensor.
get the max value of the items along a dimension of the tensor
>> input - the input tensor
>> output - the output tensor
>> dim - the dimension along which the reduction is performed
*/
void ReduceMax(XTensor * input, XTensor * output, int dim)
void _ReduceMax(const XTensor * input, XTensor * output, int dim)
{
CheckNTErrors((input->devID == output->devID || (input->devID < 0 && output->devID < 0)),
"This code must be run on the same device!");
......@@ -55,7 +56,7 @@ void ReduceMax(XTensor * input, XTensor * output, int dim)
if(input->devID >= 0){
#ifdef USE_CUDA
CudaReduceMax(input, output, dim);
_CudaReduceMax(input, output, dim);
#endif
}
else{
......@@ -90,4 +91,43 @@ void ReduceMax(XTensor * input, XTensor * output, int dim)
}
}
/*
get the max value of the items along a dimension of the tensor (return a XTensor structure).
make a new tensor to keep the result and return it
>> input - the input tensor
>> dim - the dimension along which the reduction is performed
<< return - the max value of the items along a dimension of the tensor
*/
XTensor ReduceMax(const XTensor &input, int dim)
{
CheckNTErrors(&input, "Empty input tensor!");
CheckNTErrors((dim >= 0 && dim < input.order), "Illegal dimension to reduce!");
int order = input.order - 1;
int * dimSize = new int[order];
for(int i = 0; i < order; i++){
if(i < dim)
dimSize[i] = input.dimSize[i];
else
dimSize[i] = input.dimSize[i + 1];
}
XTensor output(order, dimSize, input.dataType, input.denseRatio, input.devID, input.mem);
output.SetZeroAll();
output.SetTMP();
/* call _ReduceMax function */
_ReduceMax(&input, &output, dim);
/* destroy variables */
delete[] dimSize;
/* tensor connection */
XLink::MakeLink(&input, NULL, &output, REDUCE_REDUCEMAX);
XLink::AddParamToHead(&output, dim);
return output;
}
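A hedged usage sketch (InitTensor2D is assumed): reducing a (2, 4) tensor along dim = 1 yields a tensor of size (2).
XTensor input;
InitTensor2D(&input, 2, 4, X_FLOAT);
input.SetZeroAll();
XTensor output = ReduceMax(input, 1);   /* output[i] = max over input[i][*] */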
} // namespace nts(NiuTrans.Tensor)