Commit 67bbdfd2 by xuchen

merge with the latest branch of xuchen

parents a33c3231 d664c0a0
......@@ -56,6 +56,74 @@ private:
/* gradient for matrix multiply: c = matmul(a, b) */
static
void GradMatrixMul(XTensor * node);
/* gradient for log: c = log(a) */
static
void GradLog(XTensor * node);
/* gradient for power */
static
void GradPower(XTensor * node);
/* gradient for negate */
static
void GradNegate(XTensor * node);
/* gradient for ScaleAndShift */
static
void GradScaleAndShift(XTensor * node);
/* gradient for Minus */
static
void GradSub(XTensor * node);
/* gradient for Divide */
static
void GradDiv(XTensor * node);
/* gradient for reduceMean */
static
void GradReduceMean(XTensor * node);
/* gradient for reduceSum */
static
void GradReduceSum(XTensor * node);
/* gradient for reduceSumSquared */
static
void GradReduceSumSquared(XTensor * node);
/* gradient for reduceVariance */
static
void GradReduceVariance(XTensor * node);
/* gradient for sin */
static
void GradSin(XTensor * node);
/* gradient for cos */
static
void GradCos(XTensor * node);
/* gradient for tan */
static
void GradTan(XTensor * node);
/* gradient for exp */
static
void GradExp(XTensor * node);
/* gradient for normalize */
static
void GradNormalize(XTensor * node);
/* gradient for absolute */
static
void GradAbsolute(XTensor * node);
/* gradient for sign */
static
void GradSign(XTensor * node);
};
}
......
......@@ -47,6 +47,8 @@ void XShapeGrad::MakeGrad(XTensor * node)
GradSplit(node);
else if(operID == SHAPE_SPLIT_LIST)
GradSplitList(node);
else if (operID == SHAPE_TRANSPOSE)
GradTranspose(node);
else{
ShowNTErrors("TODO!");
}
......@@ -370,4 +372,36 @@ void XShapeGrad::GradUnsqueeze(XTensor * node)
node->visitMark = NODE_FINISHED;
}
/*
gradient for transposing a tensor
for
c = Transpose(a)
we have
dE/da = Transpose(dE/dc)
>> node - the node (c) for backward computation
*/
void XShapeGrad::GradTranspose(XTensor * node)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum == 1, "Wrong input tensor number for TRANSPOSE!");
XTensor * output = node;
XTensor * input = income.tails[0];
XTensor * b = NewTensor(input);
XNoder::MakeGrad(input);
int i = income.GetParamInt(0);
int j = income.GetParamInt(1);
CheckNTErrors(input->order > i && i >= 0, "index of dimension is out of scope!");
CheckNTErrors(input->order > j && j >= 0, "index of dimension is out of scope!");
_Transpose(output->grad, b, i, j);
_Sum(input->grad, b, input->grad);
node->visitMark = NODE_FINISHED;
delete b;
}
}
\ No newline at end of file
......@@ -70,6 +70,10 @@ private:
/* gradient computation for unsqueezing a tensor : c = unsqueeze(a) */
static
void GradUnsqueeze(XTensor * node);
/* gradient computation for unsqueezing a tensor : c = unsqueeze(a) */
static
void GradTranspose(XTensor * node);
};
......
......@@ -37,6 +37,7 @@
using namespace nts;
void SetDataTest();
void SmallTest();
void TransposeTest();
......
......@@ -29,22 +29,34 @@ const char * GetOPName(int type)
if ((type & MATH_BASE) != 0){
if (type == MATH_ABSOLUTE)
return "M_ABSOLUTE";
else if (type == MATH_EXP)
return "M_EXP";
else if (type == MATH_LOG)
return "M_LOG";
else if (type == MATH_SIN)
return "M_SIN";
else if (type == MATH_COS)
return "M_COS";
else if (type == MATH_TAN)
return "M_TAN";
else if (type == MATH_MATRIXMUL)
return "M_MATRIXMUL";
else if (type == MATH_MATRIXMULBATCHED)
return "M_MATRIXMULBATCHED";
else if (type == MATH_MULTIPLY)
return "M_MULTIPLY";
else if (type == MATH_DIV)
return "M_DIV";
else if (type == MATH_NEGATE)
return "M_NEGATE";
else if (type == MATH_SIGN)
return "M_SIGN";
else if (type == MATH_SUM)
return "M_SUM";
else if (type == MATH_SUB)
return "M_SUB";
else if (type == MATH_SUMDIM)
return "M_SUMDIM";
else if (type == MATH_LOG)
return "M_LOG";
else if (type == MATH_NORMALIZE)
return "M_NORMALIZE";
else if (type == MATH_POWER)
......
......@@ -31,16 +31,23 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* math operations */
#define MATH_BASE 0x00001000
#define MATH_ABSOLUTE MATH_BASE + 1
#define MATH_MATRIXMUL MATH_ABSOLUTE + 1
#define MATH_EXP MATH_ABSOLUTE + 1
#define MATH_LOG MATH_EXP + 1
#define MATH_SIN MATH_LOG + 1
#define MATH_COS MATH_SIN + 1
#define MATH_TAN MATH_COS + 1
#define MATH_NEGATE MATH_TAN + 1
#define MATH_MATRIXMUL MATH_TAN + 1
#define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
#define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1
#define MATH_NEGATE MATH_MULTIPLY + 1
#define MATH_SIGN MATH_NEGATE + 1
#define MATH_DIV MATH_MULTIPLY + 1
#define MATH_SIGN MATH_DIV + 1
#define MATH_SUM MATH_SIGN + 1
#define MATH_SUMDIM MATH_SUM + 1
#define MATH_SUB MATH_SUM + 1
#define MATH_SUMDIM MATH_SUB + 1
#define MATH_LOG MATH_SUMDIM + 1
#define MATH_NORMALIZE MATH_LOG + 1
#define MATH_NORMALIZE MATH_SUMDIM + 1
#define MATH_POWER MATH_NORMALIZE + 1
#define MATH_SCALEANDSHIFT MATH_POWER + 1
......
......@@ -26,49 +26,62 @@
#include "../XTensor.h"
#include "shape/Concatenate.h"
#include "shape/ConcatenateSolely.h"
#include "movement/CopyBlocks.h"
#include "movement/CopyBlocksInGrid.h"
#include "movement/CopyBlocksOnSite.h"
#include "movement/CopyData2D.h"
#include "movement/CopyIndexed.h"
#include "movement/CopyInGrid.h"
#include "movement/CopyValues.h"
#include "utilities/FlushToMem.h"
#include "shape/MakeMergeBlockIndex.h"
#include "shape/MakeSplitBlockIndex.h"
#include "arithmetic/Div.h"
#include "arithmetic/MatrixMul.h"
#include "arithmetic/MatrixMul2D.h"
#include "arithmetic/MatrixMul2DMultiTheading.h"
#include "arithmetic/MatrixMul2DParallel.h"
#include "arithmetic/MatrixMulBatched.h"
#include "shape/Merge.h"
#include "shape/MergeBlockLists.h"
#include "arithmetic/Multiply.h"
#include "arithmetic/Negate.h"
#include "arithmetic/Sign.h"
#include "arithmetic/Sub.h"
#include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumDim.h"
#include "arithmetic/XTensorBLAS.h"
#include "getandset/ConvertDataType.h"
#include "getandset/Select.h"
#include "getandset/SetData.h"
#include "math/Normalize.h"
#include "shape/Permute.h"
#include "math/Power.h"
#include "math/ScaleAndShift.h"
#include "math/Unary.h"
#include "movement/CopyBlocks.h"
#include "movement/CopyBlocksInGrid.h"
#include "movement/CopyBlocksOnSite.h"
#include "movement/CopyData2D.h"
#include "movement/CopyIndexed.h"
#include "movement/CopyInGrid.h"
#include "movement/CopyValues.h"
#include "reduce/ReduceMax.h"
#include "reduce/ReduceMean.h"
#include "reduce/ReduceStandardVariance.h"
#include "reduce/ReduceSum.h"
#include "reduce/ReduceSumSquared.h"
#include "reduce/ReduceVariance.h"
#include "math/ScaleAndShift.h"
#include "getandset/Select.h"
#include "getandset/SetData.h"
#include "sort/Sort.h"
#include "shape/Concatenate.h"
#include "shape/ConcatenateSolely.h"
#include "shape/MakeMergeBlockIndex.h"
#include "shape/MakeSplitBlockIndex.h"
#include "shape/Merge.h"
#include "shape/MergeBlockLists.h"
#include "shape/Permute.h"
#include "shape/Split.h"
#include "arithmetic/Sum.h"
#include "arithmetic/SumByColumnTV.h"
#include "arithmetic/SumByColumnVT.h"
#include "arithmetic/SumDim.h"
#include "sort/TopK.h"
#include "shape/Transpose.h"
#include "shape/Unsqueeze.h"
#include "sort/Sort.h"
#include "sort/TopK.h"
#include "utilities/XMatrixSegment.h"
#include "arithmetic/XTensorBLAS.h"
#include "utilities/FlushToMem.h"
#endif // __CHEADER_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include <math.h>
#include "../../XTensor.h"
#include "../../XName.h"
#include "Absolute.h"
#include "Absolute.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
set every entry to its absolute value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void _Absolute(const XTensor * a, XTensor * b)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
_CudaAbsolute(a, b);
return;
}
#endif
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * d = (DTYPE*)a->data;
DTYPE * db = (DTYPE*)b->data;
for (int i = 0; i < a->unitNum; i++)
db[i] = (DTYPE)fabs(d[i]);
}
/*
set every entry to its absolute value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void _AbsoluteMe(XTensor * a)
{
_Absolute(a, a);
}
/*
set every entry to its absolute value (return a XTensor structure)
make a new tensor to keep the result and return it
>> a - input tensor we are processing
<< return - the absolute value of input tensor
*/
XTensor Absolute(const XTensor & a)
{
XTensor b(&a);
b.SetTMP();
/* call _Absolute function */
_Absolute(&a, &b);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_ABSOLUTE);
return b;
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Absolute.h"
#include "Absolute.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
set each entry to its absolute value (CUDA Kernel)
>> a - pointer to input data array
>> b - pointer to output data array
>> size - size of the data array
*/
__global__
void KernelAbsolute(DTYPE * a, DTYPE * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
b[i] = fabs(a[i]);
}
/*
set each entry to its absolute value (CUDA Kernel)
This is for float16 computation
>> a - pointer to input data array
>> b - pointer to output data array
>> size - size of the data array
*/
__global__
void KernelAbsolute(__half * a, __half * b, int size)
{
return;
}
/*
set each entry to its absolute value
>> a - input tensor
>> b - output tensor
*/
void _CudaAbsolute(const XTensor * a, XTensor * b)
{
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->isSparse == false), "TODO!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
KernelAbsolute << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum);
}
else if (a->dataType == X_FLOAT16) {
KernelAbsolute << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum);
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "Div.h"
#include "Div.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
element-wise division of two tensors
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
int leadingDimRDI = a->order - leadingDim - 1;
CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
"Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order),
"Unmatched tensors!");
#ifdef USE_CUDA
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
_CudaDiv(a, b, c, alpha, leadingDim);
return;
}
#endif
int stride = 1;
int blockSizeA = 1;
int blockSizeB = 1;
int blockSizeC = 1;
int blockNum = 1;
int dimensionSizeA = a->dimSizeRDI[leadingDimRDI];
int dimensionSizeB = b->dimSizeRDI[leadingDimRDI];
int dimensionSizeC = c->dimSizeRDI[leadingDimRDI];
for (int i = 0; i < a->order; i++) {
if (i != leadingDimRDI) {
CheckNTErrors((a->dimSizeRDI[i] == b->dimSizeRDI[i] && a->dimSizeRDI[i] == c->dimSizeRDI[i]),
"Unmatched tensors!");
}
if (i < leadingDimRDI)
stride *= a->dimSizeRDI[i];
}
blockSizeA = stride * dimensionSizeA;
blockSizeB = stride * dimensionSizeB;
blockSizeC = stride * dimensionSizeC;
blockNum = a->unitNum / blockSizeA;
if (!a->isSparse && !b->isSparse) {
if (a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE) {
if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
int size = a->unitNum;
DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data;
if (alpha == 0) {
for (int i = 0; i < size; i++)
cp[i] = ap[i] / bp[i];
}
else {
for (int i = 0; i < size; i++)
cp[i] = ap[i] / bp[i] + alpha * cp[i];
}
}
else {
for (int k = 0; k < blockNum; k++) {
for (int ci = 0, ai = 0, bi = 0; ci < dimensionSizeC; ci++, ai++, bi++) {
if (ai >= dimensionSizeA)
ai = 0;
if (bi >= dimensionSizeB)
bi = 0;
DTYPE * ap = (DTYPE*)a->data + k * blockSizeA + ai * stride;
DTYPE * bp = (DTYPE*)b->data + k * blockSizeB + bi * stride;
DTYPE * cp = (DTYPE*)c->data + k * blockSizeC + ci * stride;
for (int j = 0; j < stride; j++)
cp[j] = ap[j] / bp[j] + cp[j] * alpha;
}
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
/*
element-wise division of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)*b(i) + \alpha * a(i)
where i is the index of the item
>> a - tensor a (where keep the result)
>> b - tensor b
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha, int leadingDim)
{
_Div(a, b, a, alpha, leadingDim);
}
/*
element-wise division of two tensors (return a XTensor structure)
make a new tensor c to keep the result and return it
c(i) = a(i)*b(i)
where i is the index of the item
>> a - tensor a
>> b - tensor b
>> leadingDim - the dimension along which we perform broadcasting
<< return - the product of the tensors
*/
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim)
{
CheckNTErrors(a.dimSize[leadingDim] == b.dimSize[leadingDim], "TODO!");
XTensor c(&a);
c.SetTMP();
/* call _Multiply function */
_Div(&a, &b, &c, 0, leadingDim);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_DIV);
XLink::AddParamToHeadInt(&c, leadingDim);
return c;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Div.h"
#include "Div.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
division of data arrays in a element-wise manner c(i) = a(i)/b(i)
>> a - data array a
>> b - data array b
>> c - result data array
>> size - size of c
*/
__global__
void KernelDivElementWise(DTYPE * a, DTYPE * b, DTYPE * c, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] / b[i];
}
/*
division of data arrays in a element-wise manner c(i) = a(i)/b(i) + \alpha*c(i)
>> a - data array a
>> b - data array b
>> c - result data array
>> size - size of c
>> alpha - the coefficient
*/
__global__
void KernelDivElementWiseV2(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE alpha)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] / b[i] + alpha * c[i];
}
/*
division of two tensors in a element-wise manner c(i) = a(i)/b(i).
Note that a and b can be of different sizes here, i.e.,
|a_lead| <= |c_lead| and |b_lead| <= |c_lead|
where |a_lead| means the size of the leading dimension of a
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> stride - the number of items we go over when move next along the leading dimension in a block
>> ldSizeA - size of the leading dimension of a
>> ldSizeB - size of the leading dimension of b
>> ldSizeC - size of the leading dimension of c
>> blockNum - number of blocks
*/
template<int nonZeroAlpha> __global__
void KernelDivElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE alpha,
int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum)
{
__shared__ DTYPE* ap[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE* bp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
__shared__ DTYPE* cp[MAX_CUDA_THREAD_NUM_PER_BLOCK];
int i = blockDim.x * blockIdx.x + threadIdx.x;
int j = blockDim.y * blockIdx.y + threadIdx.y;
if (i >= blockNum * stride || j >= ldSizeC)
return;
if (threadIdx.y == 0) {
int block = i / stride;
int size = block * stride;
ap[threadIdx.x] = a + size * ldSizeA;
bp[threadIdx.x] = b + size * ldSizeB;
cp[threadIdx.x] = c + size * ldSizeC;
}
__syncthreads();
int aj = j >= ldSizeA ? j % ldSizeA : j;
int bj = j >= ldSizeB ? j % ldSizeB : j;
int offseti = i % stride;
if (nonZeroAlpha == 0)
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] / bp[threadIdx.x][bj * ldSizeB + offseti];
else
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] / bp[threadIdx.x][bj * ldSizeB + offseti]
+ alpha * cp[threadIdx.x][j * ldSizeC + offseti];
}
/*
element-wise division of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the item index
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - dimension along which we perform broadcasting
*/
void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
{
int leadingDimRDI = a->order - leadingDim - 1;
CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
"Unmatched tensors in multiplication!");
CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!");
int stride = 1;
int blockSizeA = 1;
int blockNum = 1;
int dimensionSizeA = a->dimSizeRDI[leadingDimRDI];
int dimensionSizeB = b->dimSizeRDI[leadingDimRDI];
int dimensionSizeC = c->dimSizeRDI[leadingDimRDI];
for (int i = 0; i < a->order; i++) {
if (i != leadingDimRDI) {
CheckNTErrors((a->dimSizeRDI[i] == b->dimSizeRDI[i] &&
a->dimSizeRDI[i] == c->dimSizeRDI[i]),
"Unmatched tensors!");
}
if (i < leadingDimRDI)
stride *= a->dimSizeRDI[i];
}
blockSizeA = stride * dimensionSizeA;
blockNum = a->unitNum / blockSizeA;
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (!a->isSparse && !b->isSparse) {
if (a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE) {
int cudaGridSize[3];
int cudaBlockSize[3];
if (a->unitNum == c->unitNum && b->unitNum == c->unitNum) {
GDevs.GetCudaThread(a->devID, c->unitNum, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[0]), threads(cudaBlockSize[0]);
if (alpha == 0)
KernelDivElementWise << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum);
else
KernelDivElementWiseV2 << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, c->unitNum, alpha);
}
else {
GDevs.GetCudaThread2D(c->devID, stride * blockNum, dimensionSizeC, MAX_INT, cudaGridSize, cudaBlockSize);
dim3 blocks(cudaGridSize[0], cudaGridSize[1]), threads(cudaBlockSize[0], cudaBlockSize[1]);
if (alpha == 0) {
KernelDivElementWiseTensorDynamic<0> << <blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, 0,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
else {
KernelDivElementWiseTensorDynamic<1> << <blocks, threads >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, alpha,
stride, dimensionSizeA, dimensionSizeB, dimensionSizeC, blockNum);
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __DIV_CUH__
#define __DIV_CUH__
#include "Div.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* division of two tensors in a element-wise manner c(i) = a(i)/b(i) */
__global__
void KernelDivElementWise(DTYPE * a, DTYPE * b, DTYPE * c, int size);
/* division of two tensors in a element-wise manner c(i) = a(i)/b(i) + \alpha*c(i) */
__global__
void KernelDivElementWiseV2(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE alpha);
/* division of two tensors in a element-wise manner c(i) = a(i)/b(i)+ \alpha*c(i) */
template<int nonZeroAlpha>__global__
void KernelDivElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE alpha, int stride, int ldSizeA, int ldSizeB, int ldSizeC, int blockNum);
/* element-wise division of two tensors */
void _CudaDiv(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __DIV_CUH__
......@@ -16,31 +16,39 @@
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __LOG_H__
#define __LOG_H__
#ifndef __DIV_H__
#define __DIV_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its log value */
void _Log(const XTensor * a, XTensor * b);
/*
element-wise division of two tensors:
c(i) = a(i)/b(i) + \alpha * c(i)
where i is the index of the element
*/
void _Div(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha = 0, int leadingDim = 0);
/*
set every entry to its log value (do it on site)
element-wise division of two tensors (do it on site)
keep the result in the input tensor a and return nothing
a(i) = a(i)/b(i) + \alpha * a(i)
where i is the index of the element
*/
void _LogMe(XTensor * a);
void _DivMe(XTensor * a, const XTensor * b, DTYPE alpha = 0, int leadingDim = 0);
/*
set every entry to its log value (return a XTensor structure)
element-wise division of two tensors (return a XTensor structure)
make a new tensor to keep the result and return it
c(i) = a(i)/b(i)
where i is the index of the element
*/
XTensor Log(const XTensor & a);
XTensor Div(const XTensor &a, const XTensor &b, int leadingDim = 0);
} // namespace nts(NiuTrans.Tensor)
#endif // __LOG_H__
#endif // __DIV_H__
\ No newline at end of file
......@@ -32,9 +32,9 @@ element-wise product of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
where i is the index of the item
>> a - matrix a
>> b - matrix b
>> c - result matrix
>> a - tensor a
>> b - tensor b
>> c - result tensor
>> alpha - the coefficient
>> leadingDim - the dimension along which we perform broadcasting
*/
......
......@@ -104,9 +104,9 @@ void KernelMulElementWiseTensorDynamic(DTYPE * a, DTYPE * b, DTYPE * c, DTYPE al
int offseti = i % stride;
if (nonZeroAlpha == 0)
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj* ldSizeA + offseti] * bp[threadIdx.x][bj* ldSizeB + offseti];
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] * bp[threadIdx.x][bj * ldSizeB + offseti];
else
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj* ldSizeA + offseti] * bp[threadIdx.x][bj* ldSizeB + offseti] +
cp[threadIdx.x][j * ldSizeC + offseti] = ap[threadIdx.x][aj * ldSizeA + offseti] * bp[threadIdx.x][bj * ldSizeB + offseti] +
alpha * cp[threadIdx.x][j * ldSizeC + offseti];
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "Sub.h"
#include "Sub.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor subtraction c = a - b * \beta
>> a - a tensor
>> b - another tensor
>> c - where we put a-b*\beta. we save it in a if c is NULL
>> beta - the scaling factor
*/
void _Sub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors(a->unitNum == b->unitNum && a->unitNum == c->unitNum,
"Unmatched tensors in addition!");
CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
"Unmatched tensors in addition!");
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
if (a == c) {
int P2PAccesible = 0;
#ifdef CUDA_UVA
cudaDeviceCanAccessPeer(&P2PAccesible, a->devID, b->devID);
#endif
if ((a->devID < 0 && b->devID >= 0) ||
(a->devID >= 0 && b->devID < 0) ||
(a->devID >= 0 && b->devID >= 0 && a->devID != b->devID && !P2PAccesible))
{
ShowNTErrors("Cannot run this method on multiple devices simultaneously!");
}
else
_CudaSub(a, b, c, beta);
}
else
_CudaSub(a, b, c, beta);
#endif
}
else {
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse tensor in addition!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE)
{
DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data;
/* unrolling */
int num = a->unitNum;
if (num % 4 == 0) {
for (int i = 0; i < num; i += 4) {
cp[i] = ap[i] - bp[i] * beta;
cp[i + 1] = ap[i + 1] - bp[i + 1] * beta;
cp[i + 2] = ap[i + 2] - bp[i + 2] * beta;
cp[i + 3] = ap[i + 3] - bp[i + 3] * beta;
}
}
else if (num % 2 == 0) {
for (int i = 0; i < num; i += 2) {
cp[i] = ap[i] - bp[i] * beta;
cp[i + 1] = ap[i + 1] - bp[i + 1] * beta;
}
}
else {
for (int i = 0; i < num; i++) {
cp[i] = ap[i] - bp[i] * beta;
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
}
/*
tensor subtraction a = a - b * \beta (do it on site)
keep the result in the tensor a and return nothing
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
*/
void _SubMe(XTensor * a, const XTensor * b, DTYPE beta)
{
_Sub(a, b, a, beta);
}
/*
tensor subtraction c = a - b * \beta (return a XTensor structure)
make a new tensor c to keep the result and return it
>> a - a tensor
>> b - another tensor
>> beta - the scaling factor
<< return - the result of tensor subtraction
*/
XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta)
{
XTensor c(&a);
c.SetTMP();
/* call _Sub function */
_Sub(&a, &b, &c, beta);
/* tensor connections */
XLink::MakeLink(&a, &b, &c, MATH_SUB);
XLink::AddParamToHead(&c, beta);
return c;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "../../XDevice.h"
#include "../../XUtility.h"
#include "Sub.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
subtraction of data arrays (CUDA Kernel)
c = a - b * \beta
>> a - A matrix
>> b - another matrix
>> c - where we put a-b
>> size - the size of a/b/c
>> beta - the coefficient
*/
__global__
void KernelSUB(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
c[i] = a[i] - b[i] * beta;
}
/*
tensor subtraction c = a - b * \beta (cuda version)
>> a - a tensor
>> b - another tensor
>> c - where we put a-b*\beta.
>> beta - the scaling factor
*/
void _CudaSub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors(a && b && c, "Empty tensor input!");
CheckNTErrors((a->unitNum == b->unitNum && a->unitNum == c->unitNum),
"Unmatched tensors in addition!");
CheckNTErrors((a->dataType == b->dataType && a->dataType == c->dataType),
"Unmatched tensors in addition!");
CheckNTErrors((a->devID == b->devID && a->devID == c->devID),
"The tensors must be on the same!");
int devIDBackup = XDevice::GetGPUDevice();
XDevice::SetGPUDevice(a->devID);
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "Illegal use of sparse matrix in addition!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE)
{
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
KernelSUB << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, a->unitNum, beta);
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
XDevice::SetGPUDevice(devIDBackup);
}
/* subtraction over arrays
tensor subtraction c = a - b * \beta (cuda version) with an input handle
>> devID - device ID (MUST >= 0)
>> handle - cuda handle
>> a - an array
>> b - another array
>> c - where we put a-b
>> size - size of the array
>> beta - the coefficient
*/
void _CudaSubWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta)
{
if (size == 0)
return;
if (c == NULL)
c = a;
CheckNTErrors((a && b && c), "Empty arrays in addition!");
int devIDBackup;
ProtectCudaDev(devID, devIDBackup);
if (c == a) {
#ifdef DOUBELPRICSION
cublasDaxpy(*handle, size, &beta, b, 1, a, 1);
#else
cublasSaxpy(*handle, size, &beta, b, 1, a, 1);
#endif
}
else {
int gridSize[3], blockSize[3];
GDevs.GetCudaThread(devID, size, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
KernelSUB<<<blocks, threads>>>((DTYPE*)a, (DTYPE*)b, (DTYPE*)c, size, beta);
}
BacktoCudaDev(devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __SUB_CUH__
#define __SUB_CUH__
#include "Sub.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* subtraction of data arrays (CUDA Kernel) */
__global__
void KernelSUB(DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta (cuda version) */
void _CudaSub(const XTensor * a, const XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
/* tensor subtraction c = a - b * \beta (cuda version) with an input handle */
void _CudaSubWithHandle(int devID, cublasHandle_t * handle, DTYPE * a, DTYPE * b, DTYPE * c, int size, DTYPE beta = (DTYPE)1.0);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __SUB_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
* Today is the first day of August. It's still very hot.
*/
#ifndef __ABSOLUTE_H__
#define __ABSOLUTE_H__
#ifndef __SUB_H__
#define __SUB_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its absolute value */
void _Absolute(const XTensor * a, XTensor * b);
/* tensor subtraction c = a - b * \beta */
void _Sub(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
/*
set every entry to its absolute value (do it on site)
/*
tensor subtraction a = a - b * \beta
keep the result in the input tensor a and return nothing
*/
void _AbsoluteMe(XTensor * a);
/*
set every entry to its absolute value (return a XTensor structure)
make a new tensor to keep the result and return it
void _SubMe(XTensor * a, const XTensor * b, DTYPE beta = (DTYPE)1.0);
/*
tensor subtraction c = a - b * \beta
make a new tensor c to keep the result and return it
*/
XTensor Absolute(const XTensor & a);
XTensor Sub(const XTensor &a, const XTensor &b, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __ABSOLUTE_H__
#endif // __SUB_H__
......@@ -116,7 +116,8 @@ void _SumDim(const XTensor * a, const XTensor * b, XTensor * c, int n, DTYPE bet
}
/*
tensor summation (on site)
tensor summation (do it on site)
keep the result in the input tensor and return nothing
a = a + b * \beta
where the size of b is equal to the n-th dimension of a,
......@@ -133,7 +134,8 @@ void _SumDim(XTensor * a, const XTensor * b, int n, DTYPE beta)
}
/*
tensor summation (return a structure and make tensor connections)
tensor summation (return a XTensor structure and make tensor connections)
make a new tensor to keep the result and return it
c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
......@@ -141,9 +143,9 @@ i.e., a is summed with b by broadcasting
>> a - a tensor
>> b - another tensor whose size is equal to that of dimension n of a
>> c - where we put a+b*\beta. we save it in a if c is NULL
>> n - the dimension index
>> beta - the scaling factor
<< return - the result tensor by tensor summation
*/
XTensor SumDim(const XTensor &a, const XTensor &b, int n, DTYPE beta)
{
......
......@@ -20,6 +20,7 @@
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-05-08
*/
#include <math.h>
#include "SetData.h"
#include "SetData.cuh"
#include "../../XUtility.h"
......@@ -37,6 +38,43 @@
namespace nts{ // namespace nts(NiuTrans.Tensor)
/*
Fills the input Tensor or Variable with values according to the method described in
"Understanding the difficulty of training deep feedforward neural networks" - Glorot, X. & Bengio, Y. (2010),
using a uniform distribution. The resulting tensor will have values sampled from :math:`U(-a, a)`
where :math:`a = gain \times \sqrt{2 / (fan\_in + fan\_out)} \times \sqrt{3}`. Also known as Glorot initialisation.
>> tensor - the tensor whose data array would be initialized
>> gain - an optional scaling factor
*/
void _SetDataFanInOut(XTensor * tensor, DTYPE gain)
{
CheckNTErrors(tensor->dataType == X_FLOAT, "the tensor must be in X_FLOAT!");
CheckNTErrors(tensor->order >= 2, "the tensor dimension must be no less than 2!");
int fanIn = 1;
int fanOut = 1;
int order = tensor->order;
if (order == 2) {
fanIn = tensor->dimSize[1];
fanOut = tensor->dimSize[0];
}
else {
int numInputFmaps = tensor->dimSize[1];
int numOutputFmaps = tensor->dimSize[0];
int receptiveFieldSize = 0;
for (int i = 2; i < order; i++)
receptiveFieldSize += tensor->dimSize[i];
fanIn = numInputFmaps * receptiveFieldSize;
fanOut = numOutputFmaps * receptiveFieldSize;
}
DTYPE std = gain * sqrt(2.0/(fanIn + fanOut));
DTYPE a = sqrt(3.0) * std;
_SetDataRand(tensor, -a, a);
}
/*
generate data items with a fixed value p
>> tensor - the tensor whose data array would be initialized
......@@ -65,7 +103,7 @@ void _SetDataFixed(XTensor * tensor, void * valuePointer)
}
else{
#ifdef USE_CUDA
CudaSetDataFixedInt(tensor, p);
_CudaSetDataFixedInt(tensor, p);
#endif
}
}
......@@ -88,7 +126,7 @@ void _SetDataFixed(XTensor * tensor, void * valuePointer)
}
else{
#ifdef USE_CUDA
CudaSetDataFixedFloat(tensor, p);
_CudaSetDataFixedFloat(tensor, p);
#endif
}
}
......@@ -111,7 +149,7 @@ void _SetDataFixed(XTensor * tensor, void * valuePointer)
}
else{
#ifdef USE_CUDA
CudaSetDataFixedDouble(tensor, p);
_CudaSetDataFixedDouble(tensor, p);
#endif
}
}
......@@ -137,7 +175,7 @@ generate data items with a fixed value p (in integer)
*/
void _SetDataFixedInt(XTensor * tensor, int p)
{
CheckNTErrors(tensor->dataType == X_INT, "the tensor must be in X_INT");
CheckNTErrors(tensor->dataType == X_INT, "the tensor must be in X_INT!");
if(p == 0)
tensor->SetZeroAll();
......@@ -152,7 +190,7 @@ generate data items with a fixed value p (in float)
*/
void _SetDataFixedFloat(XTensor * tensor, float p)
{
CheckNTErrors(tensor->dataType == X_FLOAT, "the tensor must be in X_INT");
CheckNTErrors(tensor->dataType == X_FLOAT, "the tensor must be in X_FLOAT!");
if(p == 0)
tensor->SetZeroAll();
......@@ -167,7 +205,7 @@ generate data items with a fixed value p (in double)
*/
void _SetDataFixedDouble(XTensor * tensor, double p)
{
CheckNTErrors(tensor->dataType == X_DOUBLE, "the tensor must be in X_INT");
CheckNTErrors(tensor->dataType == X_DOUBLE, "the tensor must be in X_DOUBLE!");
if(p == 0)
tensor->SetZeroAll();
......@@ -183,6 +221,8 @@ generate data items with a uniform distribution in [low,high]
*/
void _SetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
{
CheckNTErrors(high > low, "the high value must be greater than low value!");
if(tensor == NULL)
return;
......@@ -215,10 +255,13 @@ void _SetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
TODO: generate data points on GPUs straightforwardly.
*/
else{
XTensor * t2 = NewTensor(tensor->order, tensor->dimSize, tensor->dataType, tensor->denseRatio, -1);
_SetDataRand(t2, low, high);
_CopyValues(t2, tensor);
delete t2;
#ifdef USE_CUDA
_CudaSetDataRand(tensor, low, high);
#endif
//XTensor * t2 = NewTensor(tensor->order, tensor->dimSize, tensor->dataType, tensor->denseRatio, -1);
//_SetDataRand(t2, low, high);
//_CopyValues(t2, tensor);
//delete t2;
}
}
......
......@@ -21,7 +21,10 @@
* I'm surprised that I did not write this file till today.
*/
#include <curand.h>
#include <time.h>
#include "SetData.cuh"
#include <curand_kernel.h>
#include "../../XDevice.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -46,7 +49,7 @@ generate data items with a fixed value p (in int)
>> tensor - the tensor for initialization
>> p - the initial value
*/
void CudaSetDataFixedInt(XTensor * tensor, int p)
void _CudaSetDataFixedInt(XTensor * tensor, int p)
{
CheckNTErrors(tensor->dataType == X_INT, "the tensor must be in X_INT!");
......@@ -86,7 +89,7 @@ generate data items with a fixed value p (in float)
>> tensor - the tensor for initialization
>> p - the initial value
*/
void CudaSetDataFixedFloat(XTensor * tensor, float p)
void _CudaSetDataFixedFloat(XTensor * tensor, float p)
{
CheckNTErrors(tensor->dataType == X_FLOAT, "the tensor must be in X_FLOAT!");
......@@ -126,7 +129,7 @@ generate data items with a fixed value p (in double)
>> tensor - the tensor for initialization
>> p - the initial value
*/
void CudaSetDataFixedDouble(XTensor * tensor, double p)
void _CudaSetDataFixedDouble(XTensor * tensor, double p)
{
CheckNTErrors(tensor->dataType == X_DOUBLE, "the tensor must be in X_DOUBLE!");
......@@ -146,4 +149,115 @@ void CudaSetDataFixedDouble(XTensor * tensor, double p)
BacktoCudaDev(tensor->devID, devIDBackup);
}
/*
call curand_init function on each kernel with the same random seed
and init the rng states
*/
__global__
void KernelInitializeCurand(curandState * state, unsigned long seed)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
curand_init(seed, i, 0, &state[i]);
}
/* */
__device__
float GenerateFloat(curandState* globalState, int i)
{
//copy state to local mem
curandState localState = globalState[i];
//apply uniform distribution with calculated random
float randNum = curand_uniform(&localState);
//update state
globalState[i] = localState;
//return value
return randNum;
}
/**/
__device__
double GenerateDouble(curandState* globalState, int i)
{
//copy state to local mem
curandState localState = globalState[i];
//apply uniform distribution with calculated random
double randNum = curand_uniform_double(&localState);
//update state
globalState[i] = localState;
//return value
return randNum;
}
/*
set data array with a uniform distribution in [low, high]
>> deviceStates - the state of curand
>> d - float datatype pointer to the data array
>> size - size of the array
>> low - low value of the range
>> high - high value of the range
*/
__global__
void KernelSetDataRandFloat(curandState* deviceStates, float * d, int size, DTYPE low, DTYPE variance)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
float randNum = GenerateFloat(deviceStates, i);
d[i] = randNum * variance + low;
}
}
/*
set data array with a uniform distribution in [low, high]
>> deviceStates - the state of curand
>> d - double datatype pointer to the data array
>> size - size of the array
>> low - low value of the range
>> high - high value of the range
*/
__global__
void KernelSetDataRandDouble(curandState* deviceStates, double * d, int size, DTYPE low, DTYPE variance)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size){
double randNum = GenerateDouble(deviceStates, i);
d[i] = randNum * variance + low;
}
}
/*
generate data items with a uniform distribution in [low,high]
>> tensor - the tensor whose data array would be initialized
>> low - lower value of the range
>> high - higher value of the range
*/
void _CudaSetDataRand(XTensor * tensor, DTYPE low, DTYPE high)
{
CheckNTErrors(high > low, "the high value must be greater than low value!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(tensor->devID, tensor->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(tensor->devID, devIDBackup);
curandState *deviceStates;
cudaMalloc(&deviceStates, sizeof(curandState));
DTYPE variance = high - low;
KernelInitializeCurand<<<blocks, threads>>>(deviceStates, unsigned(time(NULL)));
if (tensor->dataType == X_FLOAT)
KernelSetDataRandFloat <<<blocks, threads >>>(deviceStates, (float*)tensor->data, tensor->unitNum, low, variance);
else if (tensor->dataType == X_DOUBLE)
KernelSetDataRandDouble <<<blocks, threads >>>(deviceStates, (double*)tensor->data, tensor->unitNum, low, variance);
BacktoCudaDev(tensor->devID, devIDBackup);
}
} // namespace nts(NiuTrans.Tensor)
......@@ -29,13 +29,16 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* generate data items with a fixed value p (in int) */
void CudaSetDataFixedInt(XTensor * tensor, int p);
void _CudaSetDataFixedInt(XTensor * tensor, int p);
/* generate data items with a fixed value p (in float) */
void CudaSetDataFixedFloat(XTensor * tensor, float p);
void _CudaSetDataFixedFloat(XTensor * tensor, float p);
/* generate data items with a fixed value p (in double) */
void CudaSetDataFixedDouble(XTensor * tensor, double p);
void _CudaSetDataFixedDouble(XTensor * tensor, double p);
/* generate data items with a uniform distribution in [low,high] */
void _CudaSetDataRand(XTensor * tensor, DTYPE low, DTYPE high);
} // namespace nts(NiuTrans.Tensor)
......
......@@ -27,6 +27,9 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* generate data items with a xavier initialization */
void _SetDataFanInOut(XTensor * tensor, DTYPE gain = 1.0F);
/* generate data items with a fixed value p */
void _SetDataFixed(XTensor * tensor, void * valuePointer);
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "Log.h"
#include "Log.cuh"
#include <math.h>
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
set every entry to its log value (do it on site)
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void _Log(const XTensor * a, XTensor * b)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
_CudaLog(a, b);
return;
}
#endif
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * d = (DTYPE*)a->data;
DTYPE * db = (DTYPE*)b->data;
for (int i = 0; i < a->unitNum; i++)
db[i] = (DTYPE)log(d[i]);
}
/*
set every entry to its log value
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void _LogMe(XTensor * a)
{
_Log(a, a);
}
/*
set every entry to its log value (return a XTensor structure)
make a new tensor to keep the result and return it
>> a - input tensor we are processing
<< return - the log value of the input tensor
*/
XTensor Log(const XTensor & a)
{
XTensor b(&a);
b.SetTMP();
/* call _Log function */
_Log(&a, &b);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_LOG);
return b;
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Log.h"
#include "Log.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
set each entry to its log value (CUDA Kernel)
>> a - pointer to input data array
>> b - pointer to output data array
>> size - size of the data array
*/
__global__
void KernelLog(DTYPE * a, DTYPE * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
b[i] = log(a[i]);
}
/*
set each entry to its log value (CUDA Kernel)
This is for float16 computation
>> a - pointer to input data array
>> b - pointer to output data array
>> size - size of the data array
*/
__global__
void KernelLog(__half * a, __half * b, int size)
{
return;
}
/*
set each entry to its log value
>> a - input tensor
>> b - output tensor
*/
void _CudaLog(const XTensor * a, XTensor * b)
{
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->isSparse == false), "TODO!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
KernelLog << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum);
}
else if (a->dataType == X_FLOAT16) {
KernelLog << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum);
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#include <math.h>
#include "../../XName.h"
#include "Unary.h"
#include "Unary.cuh"
namespace nts{
#ifdef USE_CUDA
/* define three marco separately, specify the respective function names */
#define _SIMPLE_UNARY_FUNCTION(_funcName, _cudaFuncName, origFunc) \
void _funcName(const XTensor * a, XTensor * b) \
{ \
/* run it on GPUs */ \
if (a->devID >= 0) { \
_cudaFuncName(a, b); \
return; \
} \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); \
DTYPE * d = (DTYPE*)a->data; \
DTYPE * db = (DTYPE*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (DTYPE)origFunc(d[i]); \
}
#define _SIMPLE_UNARY_FUNCTION_ME(_funcNameMe, _funcName) \
void _funcNameMe(XTensor * a) \
{ \
_funcName(a, a); \
}
#define SIMPLE_UNARY_FUNCTION(funcName, _funcName, operationId) \
XTensor funcName(const XTensor &a) \
{ \
XTensor b(&a); \
b.SetTMP(); \
_funcName(&a, &b); \
XLink::MakeLink(&a, NULL, &b, operationId); \
return b; \
}
_SIMPLE_UNARY_FUNCTION(_Absolute, _CudaAbsolute, fabs)
_SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute)
SIMPLE_UNARY_FUNCTION(Absolute, _Absolute, MATH_ABSOLUTE)
_SIMPLE_UNARY_FUNCTION(_Exp, _CudaExp, exp)
_SIMPLE_UNARY_FUNCTION_ME(_ExpMe, _Exp)
SIMPLE_UNARY_FUNCTION(Exp, _Exp, MATH_EXP)
_SIMPLE_UNARY_FUNCTION(_Log, _CudaLog, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)
_SIMPLE_UNARY_FUNCTION(_Sin, _CudaSin, sin)
_SIMPLE_UNARY_FUNCTION_ME(_SinMe, _Sin)
SIMPLE_UNARY_FUNCTION(Sin, _Sin, MATH_SIN)
_SIMPLE_UNARY_FUNCTION(_Cos, _CudaCos, cos)
_SIMPLE_UNARY_FUNCTION_ME(_CosMe, _Cos)
SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)
_SIMPLE_UNARY_FUNCTION(_Tan, _CudaTan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)
#else
/* define three marco separately, specify the respective function names */
#define _SIMPLE_UNARY_FUNCTION(_funcName, origFunc) \
void _funcName(const XTensor * a, XTensor * b) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); \
DTYPE * d = (DTYPE*)a->data; \
DTYPE * db = (DTYPE*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (DTYPE)origFunc(d[i]); \
}
#define _SIMPLE_UNARY_FUNCTION_ME(_funcNameMe, _funcName) \
void _funcNameMe(XTensor * a) \
{ \
_funcName(a, a); \
}
#define SIMPLE_UNARY_FUNCTION(funcName, _funcName, operationId) \
XTensor funcName(const XTensor &a) \
{ \
XTensor b(&a); \
b.SetTMP(); \
_funcName(&a, &b); \
XLink::MakeLink(&a, NULL, &b, operationId); \
return b; \
}
_SIMPLE_UNARY_FUNCTION(_Absolute, fabs)
_SIMPLE_UNARY_FUNCTION_ME(_AbsoluteMe, _Absolute)
SIMPLE_UNARY_FUNCTION(Absolute, _Absolute, MATH_ABSOLUTE)
_SIMPLE_UNARY_FUNCTION(_Exp, exp)
_SIMPLE_UNARY_FUNCTION_ME(_ExpMe, _Exp)
SIMPLE_UNARY_FUNCTION(Exp, _Exp, MATH_EXP)
_SIMPLE_UNARY_FUNCTION(_Log, log)
_SIMPLE_UNARY_FUNCTION_ME(_LogMe, _Log)
SIMPLE_UNARY_FUNCTION(Log, _Log, MATH_LOG)
_SIMPLE_UNARY_FUNCTION(_Sin, sin)
_SIMPLE_UNARY_FUNCTION_ME(_SinMe, _Sin)
SIMPLE_UNARY_FUNCTION(Sin, _Sin, MATH_SIN)
_SIMPLE_UNARY_FUNCTION(_Cos, cos)
_SIMPLE_UNARY_FUNCTION_ME(_CosMe, _Cos)
SIMPLE_UNARY_FUNCTION(Cos, _Cos, MATH_COS)
_SIMPLE_UNARY_FUNCTION(_Tan, tan)
_SIMPLE_UNARY_FUNCTION_ME(_TanMe, _Tan)
SIMPLE_UNARY_FUNCTION(Tan, _Tan, MATH_TAN)
#endif
}
\ No newline at end of file
#include <math.h>
#include "../../XDevice.h"
#include "../../XName.h"
#include "Unary.cuh"
namespace nts {
#define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc) \
__global__ \
void Kernel##funcName(DTYPE * a, DTYPE * b, int size) \
{ \
int i = blockDim.x * blockIdx.x + threadIdx.x; \
\
if (i < size) \
b[i] = (DTYPE)origFunc(a[i]); \
} \
__global__ \
void Kernel##funcName(__half * a, __half * b, int size) \
{ \
return; \
} \
void _Cuda##funcName(const XTensor * a, XTensor * b) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->isSparse == false), "TODO!"); \
\
int gridSize[3]; \
int blockSize[3]; \
\
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize); \
\
dim3 blocks(gridSize[0]); \
dim3 threads(blockSize[0]); \
\
int devIDBackup; \
ProtectCudaDev(a->devID, devIDBackup); \
\
if (a->dataType == DEFAULT_DTYPE) { \
Kernel##funcName << <blocks, threads >> > \
((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum); \
} \
else if (a->dataType == X_FLOAT16) { \
Kernel##funcName << <blocks, threads >> > \
((__half*)a->data, (__half*)b->data, a->unitNum); \
} \
else { \
ShowNTErrors("TODO!"); \
} \
\
BacktoCudaDev(a->devID, devIDBackup); \
} \
SIMPLE_UNARY_FUNCTION_GPU(Absolute, fabs)
SIMPLE_UNARY_FUNCTION_GPU(Exp, exp)
SIMPLE_UNARY_FUNCTION_GPU(Log, log)
SIMPLE_UNARY_FUNCTION_GPU(Sin, sin)
SIMPLE_UNARY_FUNCTION_GPU(Cos, cos)
SIMPLE_UNARY_FUNCTION_GPU(Tan, tan)
}
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __UNARY_CUH__
#define __UNARY_CUH__
#include "../../XTensor.h"
#include "Unary.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its absolute value (CUDA Kernel) */
__global__
void KernelAbsolute(DTYPE * a, DTYPE * b, int size);
/* set each entry to its absolute value (CUDA Kernel) with float16 data type*/
__global__
void KernelAbsolute(__half * a, __half * b, int size);
/* set each entry to its absolute value */
void _CudaAbsolute(const XTensor * a, XTensor * b);
/* set each entry to its exponent value (CUDA Kernel) */
__global__
void KernelExp(DTYPE * a, DTYPE * b, int size);
/* set each entry to its exponent value (CUDA Kernel) with float16 data type*/
__global__
void KernelExp(__half * a, __half * b, int size);
/* set each entry to its exponent value */
void _CudaExp(const XTensor * a, XTensor * b);
/* set each entry to its logarithm value (CUDA Kernel) */
__global__
void KernelLog(DTYPE * a, DTYPE * b, int size);
/* set each entry to its logarithm value (CUDA Kernel) with float16 data type*/
__global__
void KernelLog(__half * a, __half * b, int size);
/* set each entry to its logarithm value */
void _CudaLog(const XTensor * a, XTensor * b);
/* set each entry to its sine value (CUDA Kernel) */
__global__
void KernelSin(DTYPE * a, DTYPE * b, int size);
/* set each entry to its sine value (CUDA Kernel) with float16 data type*/
__global__
void KernelSin(__half * a, __half * b, int size);
/* set each entry to its sine value */
void _CudaSin(const XTensor * a, XTensor * b);
/* set each entry to its cosine value (CUDA Kernel) */
__global__
void KernelCos(DTYPE * a, DTYPE * b, int size);
/* set each entry to its cosine value (CUDA Kernel) with float16 data type*/
__global__
void KernelCos(__half * a, __half * b, int size);
/* set each entry to its cosine value */
void _CudaCos(const XTensor * a, XTensor * b);
/* set each entry to its tangent value (CUDA Kernel) */
__global__
void KernelTan(DTYPE * a, DTYPE * b, int size);
/* set each entry to its tangent value (CUDA Kernel) with float16 data type*/
__global__
void KernelTan(__half * a, __half * b, int size);
/* set each entry to its tangent value */
void _CudaTan(const XTensor * a, XTensor * b);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __UNARY_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __UNARY_H__
#define __UNARY_H__
#include "../../XTensor.h"
namespace nts{
/* set every entry to its absolute value */
void _Absolute(const XTensor * a, XTensor * b);
/*
set every entry to its absolute value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _AbsoluteMe(XTensor * a);
/*
set every entry to its absolute value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Absolute(const XTensor & a);
/* set every entry to its exponent value */
void _Exp(const XTensor * a, XTensor * b);
/*
set every entry to its exponent value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _ExpMe(XTensor * a);
/*
set every entry to its exponent value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Exp(const XTensor & a);
/* set every entry to its logarithm value */
void _Log(const XTensor * a, XTensor * b);
/*
set every entry to its logarithm value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _LogMe(XTensor * a);
/*
set every entry to its logarithm value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Log(const XTensor & a);
/* set every entry to its sine value */
void _Sin(const XTensor * a, XTensor * b);
/*
set every entry to its sine value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _SinMe(XTensor * a);
/*
set every entry to its sine value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Sin(const XTensor & a);
/* set every entry to its cosine value */
void _Cos(const XTensor * a, XTensor * b);
/*
set every entry to its cosine value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _CosMe(XTensor * a);
/*
set every entry to its cosine value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Cos(const XTensor & a);
/* set every entry to its tangent value */
void _Tan(const XTensor * a, XTensor * b);
/*
set every entry to its tangent value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _TanMe(XTensor * a);
/*
set every entry to its tangent value (return a XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Tan(const XTensor & a);
}
#endif //end __UNARY_H__
\ No newline at end of file
......@@ -24,12 +24,22 @@
#include "Transpose.h"
#include "Merge.h"
#include "../../XUtility.h"
#include "../../XName.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
tensor transposition of dimensions i and j
b = transposed(a)
For a input tensor a, we tranpose the dimensions i and j of it.
E.g., let a be a tensor of size x * y * z, i = 0, j = 2,
then the output will be a tensor of size z * y * x.
>> a - the input tensor
>> b - the output tensor by transpose tensor a with specified dimensions i and j
>> i - the transposed dimension
>> j - the transposed dimension
*/
void _Transpose(const XTensor * a, XTensor * b, const int i, const int j)
{
......@@ -96,4 +106,52 @@ void _Transpose(const XTensor * a, XTensor * b, const int i, const int j)
}
}
/*
tensor transposition of dimensions i and j (return a XTensor structure).
make a new tensor to keep the result and return it.
b = transposed(a)
For a input tensor a, we tranpose the dimensions i and j of it.
E.g., let a be a tensor of size x * y * z, i = 0, j = 2,
then the output will be a tensor of size z * y * x.
>> a - the input tensor
>> i - the transposed dimension
>> j - the transposed dimension
<< return - the output tensor by transpose tensor a with specified dimensions i and j
*/
XTensor Transpose(const XTensor &a, const int i, const int j)
{
CheckNTErrors(a.order > i && i >= 0, "index of dimension is out of scope!");
CheckNTErrors(a.order > j && j >= 0, "index of dimension is out of scope!");
int order = a.order;
int * dimSize = new int[order];
for(int k = 0; k < order; k++){
if(k == i)
dimSize[k] = a.dimSize[j];
else if(k == j)
dimSize[k] = a.dimSize[i];
else
dimSize[k] = a.dimSize[k];
}
float dr = (!a.isSparse) ? 1.0F : a.denseRatio;
XTensor b(order, dimSize, a.dataType, dr, a.devID, a.mem);
b.SetTMP();
/* call _Transpose function */
_Transpose(&a, &b, i, j);
/* tensor connection */
XLink::MakeLink(&a, NULL, &b, SHAPE_TRANSPOSE);
XLink::AddParamToHeadInt(&b, i);
XLink::AddParamToHeadInt(&b, j);
/* destroy variables */
delete[] dimSize;
return b;
}
}
......@@ -34,13 +34,6 @@ b = transposed(a)
void _Transpose(const XTensor * a, XTensor * b, const int i, const int j);
/*
tensor transposition of dimensions i and j (do this on site)
keep the result in the input tensor and return nothing.
a = transposed(a)
*/
void _TransposeMe(XTensor * a, const int i, const int j);
/*
tensor transposition of dimensions i and j (return a XTensor structure).
make a new tensor to keep the result and return it.
b = transposed(a)
......
......@@ -24,7 +24,7 @@
#include "../XDevice.h"
#include "../core/math/Power.h"
#include "../core/math/ScaleAndShift.h"
#include "../core/math/Log.h"
#include "../core/math/Unary.h"
#include "../core/arithmetic/Negate.h"
#include "../core/arithmetic/Sum.h"
#include "../core/arithmetic/Multiply.h"
......
......@@ -19,6 +19,7 @@
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-12
*/
#include "../core/math/Unary.h"
#include "TAbsolute.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -30,14 +31,14 @@ Set every entry to its absolute value.
bool TestAbsolute1()
{
/* a tensor of size (3, 2) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 3;
aDimSize[1] = 2;
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 3;
dimSize[1] = 2;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE aData[3][2] = { {1.0F, -2.0F},
{0.5F, -4.0F},
......@@ -50,14 +51,14 @@ bool TestAbsolute1()
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(aOrder, aDimSize);
XTensor * aMe = NewTensor(aOrder, aDimSize);
XTensor * a = NewTensor(order, dimSize);
XTensor * b = NewTensor(order, dimSize);
XTensor * aMe = NewTensor(order, dimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
aMe->SetData(aData, aUnitNum);
a->SetData(aData, unitNum);
aMe->SetData(aData, unitNum);
/* call Absolute function */
_Absolute(a, b);
......@@ -65,21 +66,21 @@ bool TestAbsolute1()
bUser = Absolute(*a);
/* check results */
cpuTest = b->CheckData(answer, aUnitNum, 1e-4F) && aMe->CheckData(answer, aUnitNum, 1e-4F) && bUser.CheckData(answer, aUnitNum, 1e-4F);
cpuTest = b->CheckData(answer, unitNum, 1e-4F) && aMe->CheckData(answer, unitNum, 1e-4F) && bUser.CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * aGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
aMeGPU->SetData(aData, aUnitNum);
aGPU->SetData(aData, unitNum);
aMeGPU->SetData(aData, unitNum);
/* call Absolute function */
_Absolute(aGPU, bGPU);
......@@ -87,7 +88,7 @@ bool TestAbsolute1()
bUserGPU = Absolute(*aGPU);
/* check results */
gpuTest = bGPU->CheckData(answer, aUnitNum, 1e-4F) && aMeGPU->CheckData(answer, aUnitNum, 1e-4F) && bUserGPU.CheckData(answer, aUnitNum, 1e-4F);
gpuTest = bGPU->CheckData(answer, unitNum, 1e-4F) && aMeGPU->CheckData(answer, unitNum, 1e-4F) && bUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
......@@ -96,7 +97,7 @@ bool TestAbsolute1()
delete aGPU;
delete bGPU;
delete aMeGPU;
delete[] aDimSize;
delete[] dimSize;
return cpuTest && gpuTest;
#else
......@@ -104,7 +105,7 @@ bool TestAbsolute1()
delete a;
delete b;
delete aMe;
delete[] aDimSize;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
......
......@@ -22,7 +22,6 @@
#ifndef __TEST_ABSOLUTE_H__
#define __TEST_ABSOLUTE_H__
#include "../core/arithmetic/Absolute.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include "../core/math/Unary.h"
#include "TCos.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test Cos function.
Set every entry to its cosine value.
*/
bool TestCos1()
{
/* a tensor of size (3, 2) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 3;
dimSize[1] = 2;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE aData[3][2] = { {1.0F, 2.0F},
{-1.0F, -2.0F},
{0.0F, 0.5F} };
DTYPE answer[3][2] = { {0.5403F, -0.4161F},
{0.5403F, -0.4161F},
{1.0F, 0.8776F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(order, dimSize);
XTensor * b = NewTensor(order, dimSize);
XTensor * aMe = NewTensor(order, dimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, unitNum);
aMe->SetData(aData, unitNum);
/* call Cos function */
_Cos(a, b);
_CosMe(aMe);
bUser = Cos(*a);
/* check results */
cpuTest = b->CheckData(answer, unitNum, 1e-4F) && aMe->CheckData(answer, unitNum, 1e-4F) && bUser.CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, unitNum);
aMeGPU->SetData(aData, unitNum);
/* call Cos function */
_Cos(aGPU, bGPU);
_CosMe(aMeGPU);
bUserGPU = Cos(*aGPU);
/* check results */
gpuTest = bGPU->CheckData(answer, unitNum, 1e-4F) && aMeGPU->CheckData(answer, unitNum, 1e-4F) && bUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete aMe;
delete aGPU;
delete bGPU;
delete aMeGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete aMe;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Cos Function */
bool TestCos()
{
XPRINT(0, stdout, "[TEST Cos] set every entry to its cosine value \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestCos1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __TEST_SIN_H__
#define __TEST_SIN_H__
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Sin Function */
bool TestSin();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_SIN_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __TEST_COS_H__
#define __TEST_COS_H__
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Cos Function */
bool TestCos();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_COS_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "TDiv.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: element-wise division of two tensors
c(i) = a(i)/b(i) + \alpha * c(i)
In this case, (2, 2) (2, 2) -> (2, 2), leadingDim=0, alpha=0.
*/
bool TestDiv1()
{
/* a source tensor of size (2, 2) */
int sOrder1 = 2;
int * sDimSize1 = new int[sOrder1];
sDimSize1[0] = 2;
sDimSize1[1] = 2;
int sUnitNum1 = 1;
for (int i = 0; i < sOrder1; i++)
sUnitNum1 *= sDimSize1[i];
/* a source tensor of size (2, 2) */
int sOrder2 = 2;
int * sDimSize2 = new int[sOrder2];
sDimSize2[0] = 2;
sDimSize2[1] = 2;
int sUnitNum2 = 1;
for (int i = 0; i < sOrder2; i++)
sUnitNum2 *= sDimSize2[i];
/* a target tensor of size (2, 2) */
int tOrder = 2;
int * tDimSize = new int[tOrder];
tDimSize[0] = 2;
tDimSize[1] = 2;
int tUnitNum = 1;
for (int i = 0; i < tOrder; i++)
tUnitNum *= tDimSize[i];
DTYPE sData1[2][2] = { {0.0F, 1.0F},
{2.0F, 3.0F} };
DTYPE sData2[2][2] = { {1.0F, 1.0F},
{4.0F, 9.0F} };
DTYPE answer[2][2] = { {0.0F, 1.0F},
{0.5F, 0.3333F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s1 = NewTensor(sOrder1, sDimSize1);
XTensor * s2 = NewTensor(sOrder2, sDimSize2);
XTensor * t = NewTensor(tOrder, tDimSize);
XTensor * tMe = NewTensor(tOrder, tDimSize);
XTensor tUser;
/* initialize variables */
s1->SetData(sData1, sUnitNum1);
tMe->SetData(sData1, sUnitNum1);
s2->SetData(sData2, sUnitNum2);
t->SetZeroAll();
/* call Div function */
_Div(s1, s2, t, 0, 0);
_DivMe(tMe, s2, 0, 0);
tUser = Div(*s1, *s2, 0);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum, 1e-4F) &&
tMe->CheckData(answer, tUnitNum, 1e-4F) &&
tUser.CheckData(answer, tUnitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * sGPU1 = NewTensor(sOrder1, sDimSize1, X_FLOAT, 1.0F, 0);
XTensor * sGPU2 = NewTensor(sOrder2, sDimSize2, X_FLOAT, 1.0F, 0);
XTensor * tGPU = NewTensor(tOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor * tMeGPU = NewTensor(tOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor tUserGPU;
/* Initialize variables */
sGPU1->SetData(sData1, sUnitNum1);
tMeGPU->SetData(sData1, sUnitNum1);
sGPU2->SetData(sData2, sUnitNum2);
tGPU->SetZeroAll();
/* call Div function */
_Div(sGPU1, sGPU2, tGPU, 0, 0);
_DivMe(tMeGPU, sGPU2, 0, 0);
tUserGPU = Div(*sGPU1, *sGPU2, 0);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum, 1e-4F) &&
tMeGPU->CheckData(answer, tUnitNum, 1e-4F) &&
tUserGPU.CheckData(answer, tUnitNum, 1e-4F);
/* destroy variables */
delete s1;
delete s2;
delete t;
delete tMe;
delete sGPU1;
delete sGPU2;
delete tGPU;
delete tMeGPU;
delete[] sDimSize1;
delete[] sDimSize2;
delete[] tDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s1;
delete s2;
delete t;
delete tMe;
delete[] sDimSize1;
delete[] sDimSize2;
delete[] tDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Div Function */
bool TestDiv()
{
XPRINT(0, stdout, "[TEST Div] element-wise division of two tensors \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestDiv1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __TEST_DIV_H__
#define __TEST_DIV_H__
#include "../core/arithmetic/Div.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Div Function */
extern "C"
bool TestDiv();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_DIV_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include "../core/math/Unary.h"
#include "TExp.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test Exp function.
Set every entry to its exponent value.
*/
bool TestExp1()
{
/* a tensor of size (3, 2) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 3;
dimSize[1] = 2;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE aData[3][2] = { {1.0F, 2.0F},
{-1.0F, -2.0F},
{0.0F, 0.5F} };
DTYPE answer[3][2] = { {2.7183F, 7.3891F},
{0.3679F, 0.1353F},
{1.0F, 1.6487F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(order, dimSize);
XTensor * b = NewTensor(order, dimSize);
XTensor * aMe = NewTensor(order, dimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, unitNum);
aMe->SetData(aData, unitNum);
/* call Exp function */
_Exp(a, b);
_ExpMe(aMe);
bUser = Exp(*a);
/* check results */
cpuTest = b->CheckData(answer, unitNum, 1e-4F) && aMe->CheckData(answer, unitNum, 1e-4F) && bUser.CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, unitNum);
aMeGPU->SetData(aData, unitNum);
/* call Exp function */
_Exp(aGPU, bGPU);
_ExpMe(aMeGPU);
bUserGPU = Exp(*aGPU);
/* check results */
gpuTest = bGPU->CheckData(answer, unitNum, 1e-4F) && aMeGPU->CheckData(answer, unitNum, 1e-4F) && bUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete aMe;
delete aGPU;
delete bGPU;
delete aMeGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete aMe;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Exp Function */
bool TestExp()
{
XPRINT(0, stdout, "[TEST Exp] set every entry to its exponent value \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestExp1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
......@@ -16,26 +16,16 @@
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include "Absolute.h"
#ifndef __TEST_EXP_H__
#define __TEST_EXP_H__
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* test for Exp Function */
bool TestExp();
/* set each entry to its absolute value (CUDA Kernel) */
__global__
void KernelAbsolute(DTYPE * a, DTYPE * b, int size);
/* set each entry to its absolute value (CUDA Kernel) with float16 data type*/
__global__
void KernelAbsolute(__half * a, __half * b, int size);
/* set each entry to its absolute value */
void _CudaAbsolute(const XTensor * a, XTensor * b);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_EXP_H__
......@@ -19,6 +19,7 @@
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-12
*/
#include "../core/math/Unary.h"
#include "TLog.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......@@ -30,14 +31,14 @@ Set every entry to its log value.
bool TestLog1()
{
/* a tensor of size (3, 2) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 3;
aDimSize[1] = 2;
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 3;
dimSize[1] = 2;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE aData[3][2] = { {1.0F, 2.0F},
{0.5F, 4.0F},
......@@ -50,14 +51,14 @@ bool TestLog1()
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(aOrder, aDimSize);
XTensor * aMe = NewTensor(aOrder, aDimSize);
XTensor * a = NewTensor(order, dimSize);
XTensor * b = NewTensor(order, dimSize);
XTensor * aMe = NewTensor(order, dimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
aMe->SetData(aData, aUnitNum);
a->SetData(aData, unitNum);
aMe->SetData(aData, unitNum);
/* call Log function */
_Log(a, b);
......@@ -65,21 +66,21 @@ bool TestLog1()
bUser = Log(*a);
/* check results */
cpuTest = b->CheckData(answer, aUnitNum, 1e-4F) && aMe->CheckData(answer, aUnitNum, 1e-4F) && bUser.CheckData(answer, aUnitNum, 1e-4F);
cpuTest = b->CheckData(answer, unitNum, 1e-4F) && aMe->CheckData(answer, unitNum, 1e-4F) && bUser.CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * aGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
aMeGPU->SetData(aData, aUnitNum);
aGPU->SetData(aData, unitNum);
aMeGPU->SetData(aData, unitNum);
/* call Log function */
_Log(aGPU, bGPU);
......@@ -87,7 +88,7 @@ bool TestLog1()
bUserGPU = Log(*aGPU);
/* check results */
gpuTest = bGPU->CheckData(answer, aUnitNum, 1e-4F) && aMeGPU->CheckData(answer, aUnitNum, 1e-4F) && bUserGPU.CheckData(answer, aUnitNum, 1e-4F);
gpuTest = bGPU->CheckData(answer, unitNum, 1e-4F) && aMeGPU->CheckData(answer, unitNum, 1e-4F) && bUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
......@@ -96,7 +97,7 @@ bool TestLog1()
delete aGPU;
delete bGPU;
delete aMeGPU;
delete[] aDimSize;
delete[] dimSize;
return cpuTest && gpuTest;
#else
......@@ -104,7 +105,7 @@ bool TestLog1()
delete a;
delete b;
delete aMe;
delete[] aDimSize;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
......
......@@ -22,8 +22,6 @@
#ifndef __TEST_LOG_H__
#define __TEST_LOG_H__
#include "../core/math/Log.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Log Function */
......
......@@ -16,8 +16,8 @@
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-02
*/
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-02
*/
#ifndef __TEST_LOGSOFTMAX_H__
#define __TEST_LOGSOFTMAX_H__
......
......@@ -25,133 +25,10 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: element-wise product of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
In this case, (2, 1) (2, 1) -> (2, 1), leadingDim=0, alpha=0.
*/
bool TestMultiply1()
{
/* a source tensor of size (2, 1) */
int sOrder1 = 2;
int * sDimSize1 = new int[sOrder1];
sDimSize1[0] = 2;
sDimSize1[1] = 1;
int sUnitNum1 = 1;
for (int i = 0; i < sOrder1; i++)
sUnitNum1 *= sDimSize1[i];
/* a source tensor of size (2, 1) */
int sOrder2 = 2;
int * sDimSize2 = new int[sOrder2];
sDimSize2[0] = 2;
sDimSize2[1] = 1;
int sUnitNum2 = 1;
for (int i = 0; i < sOrder2; i++)
sUnitNum2 *= sDimSize2[i];
/* a target tensor of size (2, 1) */
int tOrder = 2;
int * tDimSize = new int[tOrder];
tDimSize[0] = 2;
tDimSize[1] = 1;
int tUnitNum = 1;
for (int i = 0; i < tOrder; i++)
tUnitNum *= tDimSize[i];
DTYPE sData1[2][1] = { {0.0F},
{1.0F} };
DTYPE sData2[2][1] = { {2.0F},
{3.0F} };
DTYPE answer[2][1] = { {0.0F},
{3.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s1 = NewTensor(sOrder1, sDimSize1);
XTensor * s2 = NewTensor(sOrder2, sDimSize2);
XTensor * t = NewTensor(tOrder, tDimSize);
XTensor * tMe = NewTensor(tOrder, tDimSize);
XTensor tUser;
/* initialize variables */
s1->SetData(sData1, sUnitNum1);
tMe->SetData(sData1, sUnitNum1);
s2->SetData(sData2, sUnitNum2);
t->SetZeroAll();
/* call Multiply function */
_Multiply(s1, s2, t, 0, 0);
_MultiplyMe(tMe, s2, 0, 0);
tUser = Multiply(*s1, *s2, 0);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum)
&& tMe->CheckData(answer, tUnitNum) && tUser.CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * sGPU1 = NewTensor(sOrder1, sDimSize1, X_FLOAT, 1.0F, 0);
XTensor * sGPU2 = NewTensor(sOrder2, sDimSize2, X_FLOAT, 1.0F, 0);
XTensor * tGPU = NewTensor(tOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor * tMeGPU = NewTensor(tOrder, tDimSize, X_FLOAT, 1.0F, 0);
XTensor tUserGPU;
/* Initialize variables */
sGPU1->SetData(sData1, sUnitNum1);
tMeGPU->SetData(sData1, sUnitNum1);
sGPU2->SetData(sData2, sUnitNum2);
tGPU->SetZeroAll();
/* call Multiply function */
_Multiply(sGPU1, sGPU2, tGPU, 0, 0);
_MultiplyMe(tMeGPU, sGPU2, 0, 0);
tUserGPU = Multiply(*sGPU1, *sGPU2, 0);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum)
&& tMeGPU->CheckData(answer, tUnitNum) && tUserGPU.CheckData(answer, tUnitNum);
/* destroy variables */
delete s1;
delete s2;
delete t;
delete tMe;
delete sGPU1;
delete sGPU2;
delete tGPU;
delete tMeGPU;
delete[] sDimSize1;
delete[] sDimSize2;
delete[] tDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s1;
delete s2;
delete t;
delete tMe;
delete[] sDimSize1;
delete[] sDimSize2;
delete[] tDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: element-wise product of two tensors
c(i) = a(i)*b(i) + \alpha * c(i)
In this case, (2, 2) (2, 2) -> (2, 2), leadingDim=0, alpha=0.
*/
bool TestMultiply2()
bool TestMultiply1()
{
/* a source tensor of size (2, 2) */
int sOrder1 = 2;
......@@ -212,8 +89,9 @@ bool TestMultiply2()
tUser = Multiply(*s1, *s2, 0);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum)
&& tMe->CheckData(answer, tUnitNum) && tUser.CheckData(answer, tUnitNum);
cpuTest = t->CheckData(answer, tUnitNum) &&
tMe->CheckData(answer, tUnitNum) &&
tUser.CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
......@@ -270,113 +148,6 @@ bool TestMultiply2()
#endif // USE_CUDA
}
/*
case 3: element-wise product of two tensors, c(i) = a(i)*b(i) + \alpha * c(i)
In this case, (2, 2) (2, 2) -> (2, 2), leadingDim=1, alpha=0.
*/
bool TestMultiply3()
{
/* a source tensor of size (2, 2) */
int sOrder1 = 2;
int * sDimSize1 = new int[sOrder1];
sDimSize1[0] = 2;
sDimSize1[1] = 2;
int sUnitNum1 = 1;
for (int i = 0; i < sOrder1; i++)
sUnitNum1 *= sDimSize1[i];
/* a source tensor of size (2, 2) */
int sOrder2 = 2;
int * sDimSize2 = new int[sOrder2];
sDimSize2[0] = 2;
sDimSize2[1] = 2;
int sUnitNum2 = 1;
for (int i = 0; i < sOrder2; i++)
sUnitNum2 *= sDimSize2[i];
/* a target tensor of size (2, 2) */
int tOrder = 2;
int * tDimSize = new int[tOrder];
tDimSize[0] = 2;
tDimSize[1] = 2;
int tUnitNum = 1;
for (int i = 0; i < tOrder; i++)
tUnitNum *= tDimSize[i];
DTYPE sData1[2][2] = { {0.0F, 1.0F},
{2.0F, 3.0F} };
DTYPE sData2[2][2] = { {0.0F, 1.0F},
{2.0F, 3.0F} };
DTYPE answer[2][2] = { {0.0F, 1.0F},
{4.0F, 9.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * s1 = NewTensor(sOrder1, sDimSize1);
XTensor * s2 = NewTensor(sOrder2, sDimSize2);
XTensor * t = NewTensor(tOrder, tDimSize);
/* initialize variables */
s1->SetData(sData1, sUnitNum1);
s2->SetData(sData2, sUnitNum2);
t->SetZeroAll();
/* call MultiplyElementWise function */
_Multiply(s1, s2, t, 0, 1);
/* check results */
cpuTest = t->CheckData(answer, tUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * sGPU1 = NewTensor(sOrder1, sDimSize1, X_FLOAT, 1.0F, 0);
XTensor * sGPU2 = NewTensor(sOrder2, sDimSize2, X_FLOAT, 1.0F, 0);
XTensor * tGPU = NewTensor(tOrder, tDimSize, X_FLOAT, 1.0F, 0);
/* Initialize variables */
sGPU1->SetData(sData1, sUnitNum1);
sGPU2->SetData(sData2, sUnitNum2);
tGPU->SetZeroAll();
/* call MultiplyElementWise function */
_Multiply(sGPU1, sGPU2, tGPU, 0, 1);
/* check results */
gpuTest = tGPU->CheckData(answer, tUnitNum);
/* destroy variables */
delete s1;
delete s2;
delete t;
delete sGPU1;
delete sGPU2;
delete tGPU;
delete[] sDimSize1;
delete[] sDimSize2;
delete[] tDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete s1;
delete s2;
delete t;
delete[] sDimSize1;
delete[] sDimSize2;
delete[] tDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
......@@ -398,26 +169,6 @@ bool TestMultiply()
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestMultiply2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* case 3 test */
caseFlag = TestMultiply3();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 3 failed!\n");
}
else
XPRINT(0, stdout, ">> case 3 passed!\n");
/* other cases test */
/*
TODO!!
......
......@@ -19,16 +19,17 @@
* $Created by: Lin Ye (email: linye2015@outlook.com) 2018-06-15
*/
#ifndef __TEST_MULTIPLYELEMENTWISE_H__
#define __TEST_MULTIPLYELEMENTWISE_H__
#ifndef __TEST_MULTIPLY_H__
#define __TEST_MULTIPLY_H__
#include "../core/arithmetic/Multiply.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for MultiplyElementWise Function */
/* test for Multiply Function */
extern "C"
bool TestMultiply();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_MULTIPLYELEMENTWISE_H__
#endif // __TEST_MULTIPLY_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include "../core/math/Unary.h"
#include "TSin.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test Sin function.
Set every entry to its sine value.
*/
bool TestSin1()
{
/* a tensor of size (3, 2) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 3;
dimSize[1] = 2;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE aData[3][2] = { {1.0F, 2.0F},
{-1.0F, -2.0F},
{0.0F, 0.5F} };
DTYPE answer[3][2] = { {0.8415F, 0.9093F},
{-0.8415F, -0.9093F},
{0.0F, 0.4794F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(order, dimSize);
XTensor * b = NewTensor(order, dimSize);
XTensor * aMe = NewTensor(order, dimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, unitNum);
aMe->SetData(aData, unitNum);
/* call Sin function */
_Sin(a, b);
_SinMe(aMe);
bUser = Sin(*a);
/* check results */
cpuTest = b->CheckData(answer, unitNum, 1e-4F) && aMe->CheckData(answer, unitNum, 1e-4F) && bUser.CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, unitNum);
aMeGPU->SetData(aData, unitNum);
/* call Sin function */
_Sin(aGPU, bGPU);
_SinMe(aMeGPU);
bUserGPU = Sin(*aGPU);
/* check results */
gpuTest = bGPU->CheckData(answer, unitNum, 1e-4F) && aMeGPU->CheckData(answer, unitNum, 1e-4F) && bUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete aMe;
delete aGPU;
delete bGPU;
delete aMeGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete aMe;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Sin Function */
bool TestSin()
{
XPRINT(0, stdout, "[TEST Sin] set every entry to its sine value \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestSin1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
......@@ -16,31 +16,16 @@
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __LOG_CUH__
#define __LOG_CUH__
#include "Log.h"
#ifndef __TEST_SIN_H__
#define __TEST_SIN_H__
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its log value (CUDA Kernel) */
__global__
void KernelLog(DTYPE * a, DTYPE * b, int size);
/* set each entry to its log value (CUDA Kernel) with float16 data type*/
__global__
void KernelLog(__half * a, __half * b, int size);
/* set each entry to its log value */
void _CudaLog(const XTensor * a, XTensor * b);
#endif // USE_CUDA
/* test for Sin Function */
bool TestSin();
} // namespace nts(NiuTrans.Tensor)
#endif // __LOG_CUH__
\ No newline at end of file
#endif // __TEST_SIN_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#include "TSub.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* case 1: tensor subtraction c = a - b * \beta */
bool TestSub1()
{
/* a tensor of size (2, 4) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 2;
dimSize[1] = 4;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2][4] = { {1.0F, -1.0F, -3.0F, -5.0F},
{-7.0F, -9.0F, -11.0F, -13.0F} };
DTYPE answer[2][4] = { {-1.0F, 2.0F, 5.0F, 8.0F},
{11.0F, 14.0F, 17.0F, 20.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(order, dimSize);
XTensor * b = NewTensor(order, dimSize);
XTensor * c = NewTensor(order, dimSize);
XTensor * cMe = NewTensor(order, dimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, unitNum);
cMe->SetData(aData, unitNum);
b->SetData(bData, unitNum);
c->SetZeroAll();
/* call Sub function */
_Sub(a, b, c);
_SubMe(cMe, b);
cUser = Sub(*a, *b);
/* check results */
cpuTest = c->CheckData(answer, unitNum)
&& cMe->CheckData(answer, unitNum) && cUser.CheckData(answer, unitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, unitNum);
cMeGPU->SetData(aData, unitNum);
bGPU->SetData(bData, unitNum);
cGPU->SetZeroAll();
/* call Sub function */
_Sub(aGPU, bGPU, cGPU);
_SubMe(cMeGPU, bGPU);
cUserGPU = Sub(*aGPU, *bGPU);
/* check results */
gpuTest = cGPU->CheckData(answer, unitNum, 1e-4F)
&& cMeGPU->CheckData(answer, unitNum, 1e-4F) && cUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* case 2: tensor subtraction c = a - b * \beta */
bool TestSub2()
{
/* a tensor of size (2, 4) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 2;
dimSize[1] = 4;
int unitNum = 1;
for (int i = 0; i < order; i++) {
unitNum *= dimSize[i];
}
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2][4] = { {1.0F, -1.0F, -3.0F, -5.0F},
{-7.0F, -9.0F, -11.0F, -13.0F} };
DTYPE answer[2][4] = { {-0.5F, 1.5F, 3.5F, 5.5F},
{7.5F, 9.5F, 11.5F, 13.5F} };
float beta = 0.5F;
/* CPU test */
bool cpuTest = true;
/* create tensor */
XTensor * a = NewTensor(order, dimSize);
XTensor * b = NewTensor(order, dimSize);
XTensor * c = NewTensor(order, dimSize);
XTensor * cMe = NewTensor(order, dimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, unitNum);
cMe->SetData(aData, unitNum);
b->SetData(bData, unitNum);
c->SetZeroAll();
/* call Sub function */
_Sub(a, b, c, beta);
_SubMe(cMe, b, beta);
cUser = Sub(*a, *b, beta);
/* check results */
cpuTest = c->CheckData(answer, unitNum)
&& cMe->CheckData(answer, unitNum) && cUser.CheckData(answer, unitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, unitNum);
cMeGPU->SetData(aData, unitNum);
bGPU->SetData(bData, unitNum);
cGPU->SetZeroAll();
/* call Sub function */
_Sub(aGPU, bGPU, cGPU, beta);
_SubMe(cMeGPU, bGPU, beta);
cUserGPU = Sub(*aGPU, *bGPU, beta);
/* check results */
gpuTest = cGPU->CheckData(answer, unitNum, 1e-4F)
&& cMeGPU->CheckData(answer, unitNum, 1e-4F) && cUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Sub Function */
bool TestSub()
{
XPRINT(0, stdout, "[TEST SUB] tensor subtraction c = a - b * beta\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestSub1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestSub2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-08-01
*/
#ifndef __TEST_SUB_H__
#define __TEST_SUB_H__
#include "../core/arithmetic/Sub.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Sub Function */
bool TestSub();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_SUB_H__
......@@ -16,8 +16,8 @@
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-04-30
*/
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-04-30
*/
#include "TSum.h"
......@@ -59,14 +59,14 @@ bool TestSum1()
b->SetData(bData, unitNum);
c->SetZeroAll();
/* call sum function */
/* call Sum function */
_Sum(a, b, c);
_SumMe(cMe, b);
cUser = Sum(*a, *b);
/* check results */
cpuTest = c->CheckData(answer, unitNum)
&& cMe->CheckData(answer, unitNum) && cUser.CheckData(answer, unitNum);
&& cMe->CheckData(answer, unitNum) && cUser.CheckData(answer, unitNum);
#ifdef USE_CUDA
/* GPU test */
......@@ -85,14 +85,14 @@ bool TestSum1()
bGPU->SetData(bData, unitNum);
cGPU->SetZeroAll();
/* call sum function */
/* call Sum function */
_Sum(aGPU, bGPU, cGPU);
_SumMe(cMeGPU, bGPU);
cUserGPU = Sum(*aGPU, *bGPU);
/* check results */
gpuTest = cGPU->CheckData(answer, unitNum)
&& cMeGPU->CheckData(answer, unitNum) && cUserGPU.CheckData(answer, unitNum);
&& cMeGPU->CheckData(answer, unitNum) && cUserGPU.CheckData(answer, unitNum);
/* destroy variables */
delete a;
......@@ -155,14 +155,14 @@ bool TestSum2()
b->SetData(bData, unitNum);
c->SetZeroAll();
/* call sum function */
/* call Sum function */
_Sum(a, b, c, beta);
_SumMe(cMe, b, beta);
cUser = Sum(*a, *b, beta);
/* check results */
cpuTest = c->CheckData(answer, unitNum)
&& cMe->CheckData(answer, unitNum) && cUser.CheckData(answer, unitNum);
&& cMe->CheckData(answer, unitNum) && cUser.CheckData(answer, unitNum);
#ifdef USE_CUDA
/* GPU test */
......@@ -181,14 +181,14 @@ bool TestSum2()
bGPU->SetData(bData, unitNum);
cGPU->SetZeroAll();
/* call sum function */
/* call Sum function */
_Sum(aGPU, bGPU, cGPU, beta);
_SumMe(cMeGPU, bGPU, beta);
cUserGPU = Sum(*aGPU, *bGPU, beta);
/* check results */
gpuTest = cGPU->CheckData(answer, unitNum)
&& cMeGPU->CheckData(answer, unitNum) && cUserGPU.CheckData(answer, unitNum);
&& cMeGPU->CheckData(answer, unitNum) && cUserGPU.CheckData(answer, unitNum);
/* destroy variables */
delete a;
......
......@@ -16,8 +16,8 @@
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-04-30
*/
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-04-30
*/
#ifndef __TEST_SUM_H__
#define __TEST_SUM_H__
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-30
*/
#include "TSumDim.h"
#include "../core/arithmetic/SumDim.h"
#include "../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: tensor summation c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
*/
bool TestSumDim1()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2) */
int bOrder = 1;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2] = {1.0F, -1.0F};
DTYPE answer[2][4] = { {1.0F, 2.0F, 3.0F, 4.0F},
{3.0F, 4.0F, 5.0F, 6.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call SumDim function */
_SumDim(a, b, c, 0);
_SumDim(cMe, b, 0);
cUser = SumDim(*a, *b, 0);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum)
&& cMe->CheckData(answer, aUnitNum)
&& cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call sum function */
_SumDim(aGPU, bGPU, cGPU, 0);
_SumDim(cMeGPU, bGPU, 0);
cUserGPU = SumDim(*aGPU, *bGPU, 0);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum)
&& cMeGPU->CheckData(answer, aUnitNum)
&& cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: tensor summation c = a + b * \beta
where the size of b is equal to the n-th dimension of a,
i.e., a is summed with b by broadcasting
*/
bool TestSumDim2()
{
/* a tensor of size (2, 4) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 2;
aDimSize[1] = 4;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 2) */
int bOrder = 2;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 2;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[2][4] = { {0.0F, 1.0F, 2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE bData[2][2] = { {1.0F, -1.0F},
{-1.0F, 1.0F} };
DTYPE answer[2][4] = { {1.0F, 0.0F, 1.0F, 4.0F},
{5.0F, 4.0F, 5.0F, 8.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor * c = NewTensor(aOrder, aDimSize);
XTensor * cMe = NewTensor(aOrder, aDimSize);
XTensor cUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
cMe->SetData(aData, aUnitNum);
b->SetData(bData, bUnitNum);
c->SetZeroAll();
/* call SumDim function */
_SumDim(a, b, c, 1);
_SumDim(cMe, b, 1);
cUser = SumDim(*a, *b, 1);
/* check results */
cpuTest = c->CheckData(answer, aUnitNum)
&& cMe->CheckData(answer, aUnitNum)
&& cUser.CheckData(answer, aUnitNum);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor * cGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * cMeGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor cUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
cMeGPU->SetData(aData, aUnitNum);
bGPU->SetData(bData, bUnitNum);
cGPU->SetZeroAll();
/* call sum function */
_SumDim(aGPU, bGPU, cGPU, 1);
_SumDim(cMeGPU, bGPU, 1);
cUserGPU = SumDim(*aGPU, *bGPU, 1);
/* check results */
gpuTest = cGPU->CheckData(answer, aUnitNum)
&& cMeGPU->CheckData(answer, aUnitNum)
&& cUserGPU.CheckData(answer, aUnitNum);
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete aGPU;
delete bGPU;
delete cGPU;
delete cMeGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete c;
delete cMe;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for SumDim Function */
bool TestSumDim()
{
XPRINT(0, stdout, "[TEST SUMDIM] tensor summation c = a + b * beta by broadcasting\n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestSumDim1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestSumDim2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-30
* I finish my summer holidays and go back to study.
*/
#ifndef __TEST_SUMDIM_H__
#define __TEST_SUMDIM_H__
#include "../core/arithmetic/SumDim.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for SumDim Function */
extern "C"
bool TestSumDim();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_SUMDIM_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#include "../core/math/Unary.h"
#include "TTan.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test Tan function.
Set every entry to its tangent value.
*/
bool TestTan1()
{
/* a tensor of size (3, 2) */
int order = 2;
int * dimSize = new int[order];
dimSize[0] = 3;
dimSize[1] = 2;
int unitNum = 1;
for (int i = 0; i < order; i++)
unitNum *= dimSize[i];
DTYPE aData[3][2] = { {1.0F, 2.0F},
{-1.0F, -2.0F},
{0.0F, 0.5F} };
DTYPE answer[3][2] = { {1.5574F, -2.1850F},
{-1.5574F, 2.1850F},
{0.0F, 0.5463F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(order, dimSize);
XTensor * b = NewTensor(order, dimSize);
XTensor * aMe = NewTensor(order, dimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, unitNum);
aMe->SetData(aData, unitNum);
/* call Tan function */
_Tan(a, b);
_TanMe(aMe);
bUser = Tan(*a);
/* check results */
cpuTest = b->CheckData(answer, unitNum, 1e-4F) && aMe->CheckData(answer, unitNum, 1e-4F) && bUser.CheckData(answer, unitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor * aMeGPU = NewTensor(order, dimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, unitNum);
aMeGPU->SetData(aData, unitNum);
/* call Tan function */
_Tan(aGPU, bGPU);
_TanMe(aMeGPU);
bUserGPU = Tan(*aGPU);
/* check results */
gpuTest = bGPU->CheckData(answer, unitNum, 1e-4F) && aMeGPU->CheckData(answer, unitNum, 1e-4F) && bUserGPU.CheckData(answer, unitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete aMe;
delete aGPU;
delete bGPU;
delete aMeGPU;
delete[] dimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete aMe;
delete[] dimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Tan Function */
bool TestTan()
{
XPRINT(0, stdout, "[TEST Tan] set every entry to its tangent value \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestTan1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-31
*/
#ifndef __TEST_TAN_H__
#define __TEST_TAN_H__
namespace nts { // namespace nts(NiuTrans.Tensor)
/* test for Tan Function */
bool TestTan();
} // namespace nts(NiuTrans.Tensor)
#endif // __TEST_TAN_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-07-12
*/
#include "TTranspose.h"
#include "../core/movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
case 1: test Transpose function.
tensor transposition of dimensions i and j
*/
bool TestTranspose1()
{
/* a tensor of size (3, 2) */
int aOrder = 2;
int * aDimSize = new int[aOrder];
aDimSize[0] = 3;
aDimSize[1] = 2;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 3) */
int bOrder = 2;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 3;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[3][2] = { {1.0F, 2.0F},
{3.0F, 4.0F},
{5.0F, 6.0F} };
DTYPE answer[2][3] = { {1.0F, 3.0F, 5.0F},
{2.0F, 4.0F, 6.0F} };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
/* call Transpose function */
_Transpose(a, b, 0, 1);
bUser = Transpose(*a, 0, 1);
/* check results */
cpuTest = b->CheckData(answer, aUnitNum, 1e-4F)
&& bUser.CheckData(answer, aUnitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
/* call Transpose function */
_Transpose(aGPU, bGPU, 0, 1);
bUserGPU = Transpose(*aGPU, 0, 1);
/* check results */
gpuTest = bGPU->CheckData(answer, aUnitNum, 1e-4F)
&& bUserGPU.CheckData(answer, aUnitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete aGPU;
delete bGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/*
case 2: test Transpose function.
tensor transposition of dimensions i and j
*/
bool TestTranspose2()
{
/* a tensor of size (4, 3, 2) */
int aOrder = 3;
int * aDimSize = new int[aOrder];
aDimSize[0] = 4;
aDimSize[1] = 3;
aDimSize[2] = 2;
int aUnitNum = 1;
for (int i = 0; i < aOrder; i++)
aUnitNum *= aDimSize[i];
/* a tensor of size (2, 3, 4) */
int bOrder = 3;
int * bDimSize = new int[bOrder];
bDimSize[0] = 2;
bDimSize[1] = 3;
bDimSize[2] = 4;
int bUnitNum = 1;
for (int i = 0; i < bOrder; i++)
bUnitNum *= bDimSize[i];
DTYPE aData[4][3][2] = { { {1.0F, 2.0F},
{3.0F, 4.0F},
{5.0F, 6.0F} },
{ {2.0F, 4.0F},
{4.0F, 7.0F},
{6.0F, 8.0F} },
{ {1.0F, 2.0F},
{3.0F, 4.0F},
{5.0F, 6.0F} },
{ {2.0F, 4.0F},
{4.0F, 7.0F},
{6.0F, 8.0F} },};
DTYPE answer[2][3][4] = { { {1.0F, 2.0F, 1.0F, 2.0F},
{2.0F, 4.0F, 2.0F, 4.0F},
{3.0F, 4.0F, 3.0F, 4.0F} },
{ {4.0F, 7.0F, 4.0F, 7.0F},
{5.0F, 6.0F, 5.0F, 6.0F},
{6.0F, 8.0F, 6.0F, 8.0F} } };
/* CPU test */
bool cpuTest = true;
/* create tensors */
XTensor * a = NewTensor(aOrder, aDimSize);
XTensor * b = NewTensor(bOrder, bDimSize);
XTensor bUser;
/* initialize variables */
a->SetData(aData, aUnitNum);
/* call Transpose function */
_Transpose(a, b, 0, 2);
bUser = Transpose(*a, 0, 2);
/* check results */
cpuTest = b->CheckData(answer, aUnitNum, 1e-4F)
&& bUser.CheckData(answer, aUnitNum, 1e-4F);
#ifdef USE_CUDA
/* GPU test */
bool gpuTest = true;
/* create tensor */
XTensor * aGPU = NewTensor(aOrder, aDimSize, X_FLOAT, 1.0F, 0);
XTensor * bGPU = NewTensor(bOrder, bDimSize, X_FLOAT, 1.0F, 0);
XTensor bUserGPU;
/* Initialize variables */
aGPU->SetData(aData, aUnitNum);
/* call Transpose function */
_Transpose(aGPU, bGPU, 0, 2);
bUserGPU = Transpose(*aGPU, 0, 2);
/* check results */
gpuTest = bGPU->CheckData(answer, aUnitNum, 1e-4F)
&& bUserGPU.CheckData(answer, aUnitNum, 1e-4F);
/* destroy variables */
delete a;
delete b;
delete aGPU;
delete bGPU;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest && gpuTest;
#else
/* destroy variables */
delete a;
delete b;
delete[] aDimSize;
delete[] bDimSize;
return cpuTest;
#endif // USE_CUDA
}
/* other cases */
/*
TODO!!
*/
/* test for Transpose Function */
bool TestTranspose()
{
XPRINT(0, stdout, "[TEST TRANSPOSE] tensor transposition with specified dimensions \n");
bool returnFlag = true, caseFlag = true;
/* case 1 test */
caseFlag = TestTranspose1();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 1 failed!\n");
}
else
XPRINT(0, stdout, ">> case 1 passed!\n");
/* case 2 test */
caseFlag = TestTranspose2();
if (!caseFlag) {
returnFlag = false;
XPRINT(0, stdout, ">> case 2 failed!\n");
}
else
XPRINT(0, stdout, ">> case 2 passed!\n");
/* other cases test */
/*
TODO!!
*/
if (returnFlag) {
XPRINT(0, stdout, ">> All Passed!\n");
}
else
XPRINT(0, stdout, ">> Failed!\n");
XPRINT(0, stdout, "\n");
return returnFlag;
}
} // namespace nts(NiuTrans.Tensor)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论