Commit f12ced64 by liyinqiao

1. Bug fixed in Loss function; 2.Bug fixed in set data random; 3. Add Log function

parent 85ee1664
......@@ -439,14 +439,14 @@ void XTensor::SetDataRand(DTYPE lower, DTYPE upper)
if (dataType == X_FLOAT) {
d = new float[unitNum];
for (int i = 0; i < unitNum; i++) {
DTYPE value = lower + upper * (float)rand() / RAND_MAX;
DTYPE value = lower + (upper - lower) * (float)rand() / RAND_MAX;
*((float*)d + i) = value;
}
}
else if (dataType == X_DOUBLE) {
d = new double[unitNum];
for (int i = 0; i < unitNum; i++) {
*((double*)d + i) = rand() / RAND_MAX;
*((double*)d + i) = lower + (upper - lower) * rand() / RAND_MAX;
}
}
else {
......@@ -813,8 +813,10 @@ set the value of a cell
>> index - index of the cell for each dimension
>>
*/
bool XTensor::Set(DTYPE value, int * index, int size)
bool XTensor::Set(DTYPE value, int index[], int size)
{
CheckNTErrors((dataType == DEFAULT_DTYPE), "The tensor is not in default type.");
return SetToDevice(devID, GetCell(index, size), value);
}
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XTensor.h"
#include "Log.h"
#include "Log.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
set every entry to its log value
>> a - the tensor we are processing
*/
void Log(XTensor * a)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
CudaLog(a);
return;
}
#endif
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * d = (DTYPE*)a->data;
for (int i = 0; i < a->unitNum; i++)
d[i] = (DTYPE)log(d[i]);
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Log.h"
#include "Log.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
set each entry to its log value (CUDA Kernel)
>> d - pointer to the data array
>> size - size of the data array
*/
__global__
void KernelLog(DTYPE * d, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
d[i] = log(d[i]);
}
/*
set each entry to its log value (CUDA Kernel)
This is for float16 computation
>> d - pointer to the data array
>> size - size of the data array
*/
__global__
void KernelLog(__half * d, int size)
{
return;
}
/*
set each entry to its log value
>> a - the tensor
*/
extern "C"
void CudaLog(XTensor * a)
{
CheckNTErrors((a->isSparse == false), "TODO!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
KernelLog << <blocks, threads >> >((DTYPE*)a->data, a->unitNum);
}
else if (a->dataType == X_FLOAT16) {
KernelLog << <blocks, threads >> >((__half*)a->data, a->unitNum);
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "Log.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its log value (CUDA Kernel) */
__global__
void KernelLog(DTYPE * d, int size);
/* set each entry to its log value (CUDA Kernel) with float16 data type*/
__global__
void KernelLog(__half * d, int size);
/* set each entry to its log value */
extern "C"
void CudaLog(XTensor * a);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#ifndef __LOG_H__
#define __LOG_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its log value */
extern "C"
void Log(XTensor * a);
} // namespace nts(NiuTrans.Tensor)
#endif // __LOG_H__
......@@ -22,6 +22,14 @@
#include "Loss.h"
#include "Loss.cuh"
#include "../XDevice.h"
#include "../core/math/Power.h"
#include "../core/math/ScaleAndShift.h"
#include "../core/arithmetic/Log.h"
#include "../core/arithmetic/Negate.h"
#include "../core/arithmetic/Sum.h"
#include "../core/arithmetic/Multiply.h"
#include "../core/reduce/ReduceSum.h"
#include "../core/movement/CopyValues.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
......@@ -46,7 +54,126 @@ compute the loss
DTYPE CudaLossCompute(XTensor * gold, XTensor * y, LOSS_FUNCTION_NAME LFName,
bool isLogOutput, int leadDim, int gBeg, int gLen, int yBeg)
{
return 0;
CheckNTErrors((gLen >= 0 && gLen <= y->unitNum), "Illegal input length!");
CheckNTErrors((XTensor::IsIdentical(gold, y)), "The input tensors must be of the same size!");
CheckNTErrors((gold->dimSizeRDI[0] == 1 && y->dimSizeRDI[0] == 1), "TODO!");
CheckNTErrors((gold->order > leadDim && leadDim >= 0), "Illegal leading dimension!");
CheckNTErrors((gold->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE),
"TODO!");
CheckNTErrors((gold->devID == y->devID), "Tensors must be on the same device!");
CheckNTErrors((gold->devID >= 0), "Tensors must be on GPU device!");
CheckNTErrors((gLen == gold->dimSize[leadDim] && gBeg == 0 && yBeg == 0), "TODO!");
if(isLogOutput)
return LossComputeForLogScale(gold, y, LFName, leadDim, gBeg, gLen, yBeg);
DTYPE error = 0.0F;
/*
squared error
loss = sum_{i} 0.5*(gold_i - output_i)^2
where gold_i is the gold standard and output_i is the model prediction
*/
if(LFName == SQUAREDERROR){
XTensor * diff = NewTensor(gold->order, gold->dimSize, gold->dataType, gold->denseRatio, gold->devID, gold->mem);
_Sum(gold, y, diff, -1.0F);
Power(diff, 2.0F);
ScaleAndShift(diff, 0.5F, 0.0F);
int reduceTimes = diff->order;
for (int i = 0; i < reduceTimes; i++) {
int diffOrder = diff->order - 1;
int * diffDimSize = new int[diffOrder];
memcpy(diffDimSize, diff->dimSize + 1, diffOrder * sizeof(int));
XTensor * diffNew = NewTensor(diffOrder, diffDimSize, X_FLOAT, 1.0F, diff->devID, diff->mem);
int reducePlace = diff->dimSize[0] == 1 ? 1 : 0;
ReduceSum(diff, diffNew, reducePlace);
if (diffNew->order == 1) {
diffNew->order = 2;
diffNew->dimSize[1] = diffNew->dimSize[0];
diffNew->dimSize[0] = 1;
diffNew->dimSizeRDI[1] = 1;
}
delete diff;
diff = diffNew;
delete diffDimSize;
}
error = diff->Get2D(0, 0);
delete diff;
}
/*
cross entropy
loss = sum_{i} (-gold_i * log(output_i))
where gold and output are distributions
*/
if(LFName == CROSSENTROPY){
XTensor * diff = NewTensor(y->order, y->dimSize, y->dataType, y->denseRatio, y->devID, y->mem);
CopyValues(y, diff);
Log(diff);
Multiply(gold, diff, diff);
Negate(diff);
int reduceTimes = diff->order;
for (int i = 0; i < reduceTimes; i++) {
int diffOrder = diff->order - 1;
int * diffDimSize = new int[diffOrder];
memcpy(diffDimSize, diff->dimSize + 1, diffOrder * sizeof(int));
XTensor * diffNew = NewTensor(diffOrder, diffDimSize, X_FLOAT, 1.0F, diff->devID, diff->mem);
int reducePlace = diff->dimSize[0] == 1 ? 1 : 0;
ReduceSum(diff, diffNew, reducePlace);
if (diffNew->order == 1) {
diffNew->order = 2;
diffNew->dimSize[1] = diffNew->dimSize[0];
diffNew->dimSize[0] = 1;
diffNew->dimSizeRDI[1] = 1;
}
delete diff;
diff = diffNew;
delete diffDimSize;
}
error = diff->Get2D(0, 0);
delete diff;
}
/*
one hot error
loss = sum_{i} e_i
where e_i = 0.5*(t_i - y_i)^2 if t_i = 1,
e_i = 0 otherwise
*/
if(LFName == ONEHOTERROR){
XTensor * diff = NewTensor(gold->order, gold->dimSize, gold->dataType, gold->denseRatio, gold->devID, gold->mem);
XTensor * yOnehot = NewTensor(y->order, y->dimSize, y->dataType, y->denseRatio, y->devID, y->mem);
CopyValues(y, yOnehot);
Multiply(gold, y, yOnehot);
_Sum(gold, yOnehot, diff, -1.0F);
Power(diff, 2.0F);
ScaleAndShift(diff, 0.5F, 0.0F);
int reduceTimes = diff->order;
for (int i = 0; i < reduceTimes; i++) {
int diffOrder = diff->order - 1;
int * diffDimSize = new int[diffOrder];
memcpy(diffDimSize, diff->dimSize + 1, diffOrder * sizeof(int));
XTensor * diffNew = NewTensor(diffOrder, diffDimSize, X_FLOAT, 1.0F, diff->devID, diff->mem);
int reducePlace = diff->dimSize[0] == 1 ? 1 : 0;
ReduceSum(diff, diffNew, reducePlace);
if (diffNew->order == 1) {
diffNew->order = 2;
diffNew->dimSize[1] = diffNew->dimSize[0];
diffNew->dimSize[0] = 1;
diffNew->dimSizeRDI[1] = 1;
}
delete diff;
diff = diffNew;
delete diffDimSize;
}
error = diff->Get2D(0, 0);
delete diff;
delete yOnehot;
}
return error;
// TODO: call cuda kernels for computing the errors
}
......@@ -140,13 +267,25 @@ backward compuation for cross entropy (Cuda kernel)
>> size - size of the vector (dedy)
*/
extern "C" __global__
void KernelLossBackwardCrossEntropy(DTYPE * dedy, DTYPE * t, DTYPE * y, int size)
void KernelLossBackwardCrossEntropy(DTYPE * dedy, DTYPE * t, DTYPE * y, int tBeg, int tLen, int yBeg, int blockNum, int stride, int dimensionSize)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i > stride * dimensionSize * blockNum)
return;
if (i < size){
int blockNumIndex = i / (stride * dimensionSize);
int blockNumTail = i % (stride * dimensionSize);
int dimensionSizeIndex = blockNumTail / stride;
int strideIndex = blockNumTail % stride;
if (dimensionSizeIndex >= tLen)
return;
dedy[blockNumIndex * stride * dimensionSize + strideIndex + stride * (yBeg + dimensionSizeIndex)] = -t[blockNumIndex * stride * dimensionSize +
strideIndex + stride * (tBeg + dimensionSizeIndex)] / y[blockNumIndex * stride * dimensionSize + strideIndex + stride * (yBeg + dimensionSizeIndex)];
/*if (i < size){
dedy[i] = -t[i]/y[i];
}
}*/
}
/*
......@@ -193,9 +332,11 @@ void CudaLossBackward(XTensor * dedy, XTensor * t, XTensor * y,
LOSS_FUNCTION_NAME LFName,
int leadDim, int tBeg, int tLen, int yBeg)
{
CheckNTErrors((tLen <= y->unitNum), "Illegal input length!");
CheckNTErrors((XTensor::IsIdentical(t, y)&& XTensor::IsIdentical(dedy, y)),
"The input tensors must be of the same size!");
CheckNTErrors((t->dimSizeRDI[0] == 1 && y->dimSizeRDI[0] == 1 && dedy->dimSizeRDI[1] == 1), "TODO!");
CheckNTErrors(((dedy->devID == t->devID) && (dedy->devID == y->devID)), "Tensor must be on the same device!");
CheckNTErrors((t->order > leadDim), "Illegal leading dimension!");
CheckNTErrors((t->dataType == DEFAULT_DTYPE &&
y->dataType == DEFAULT_DTYPE &&
dedy->dataType == DEFAULT_DTYPE),
......@@ -208,21 +349,25 @@ void CudaLossBackward(XTensor * dedy, XTensor * t, XTensor * y,
"The vectors must be on the same GPU.");
CheckNTErrors((tBeg == yBeg), "TODO!");
int leadDimRDI = y->order - leadDim - 1;
int leadDimRDI = leadDim >= 0 ? y->order - leadDim - 1 : -1;
if(leadDimRDI < 0){
leadDimRDI = y->dimSizeRDI[y->order - 1];
leadDimRDI = y->order - 1;
tBeg = 0;
yBeg = 0;
tLen = y->dimSizeRDI[leadDimRDI];
}
int dimensionSize = y->dimSizeRDI[leadDimRDI];
int stride = 1;
int blockSize = 1;
int blockNum = 1;
int size = 1;
for(int i = 0; i < leadDimRDI; i++)
stride *= y->dimSizeRDI[i];
size = tLen * stride;
blockSize = stride * dimensionSize;
blockNum = y->unitNum / blockSize;
int cudaGridSize[3], cudaBlockSize[3];
......@@ -265,7 +410,7 @@ void CudaLossBackward(XTensor * dedy, XTensor * t, XTensor * y,
ShowNTErrors("TODO!");
}
else if(size == y->unitNum){
KernelLossBackwardCrossEntropy<<<blocks, threads>>>(dedyp, tp, yp, tLen);
KernelLossBackwardCrossEntropy<<<blocks, threads>>>(dedyp, tp, yp, tBeg, tLen, yBeg, blockNum, stride, dimensionSize);
}
else{
KernelLossBackwardCrossEntropyBlock<<<blocks, threads>>>(dedyp, tp, yp, blockSize, tBeg * stride, tLen * stride, y->unitNum);
......
......@@ -97,7 +97,7 @@ void KernelRectifyBackward(DTYPE * dedy, DTYPE * dedx, DTYPE * gold, DTYPE * y,
if (i < size){
DTYPE s = x[i];
if(s >= 0)
dedx[i] = 1;
dedx[i] = dedy[i];
else
dedx[i] = 0;
}
......
......@@ -248,7 +248,8 @@ void CudaSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
"Unknown loss function.");
if(lossName == CROSSENTROPY || lossName == SQUAREDERROR){
ShowNTErrors("TODO!");
//ShowNTErrors("TODO!");
_Sum(y, gold, dedx, -1.0F);
}
else if(lossName == ONEHOTERROR){
ShowNTErrors("TODO!");
......
......@@ -110,9 +110,9 @@ bool TestIdentity2()
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
DTYPE xData[1][3] = { {0.0F, 1.0F, 2.0F} };
DTYPE xData[1][3] = { {1.0F, 1.0F, 2.0F} };
DTYPE gData[1][3] = { {0.0F, 0.0F, 1.0F} };
DTYPE dedxAnswer[3] = {0.090031F, 0.244728F, -0.334759F};
DTYPE dedxAnswer[3] = {0.0F, 0.0F, -0.5F};
/* CPU test */
bool cpuTest = true;
......
......@@ -106,18 +106,19 @@ log softmax: y_i = log(e^{x_i} / \sum_{k} e^{x_k})
bool TestLogSoftmax2()
{
/* a input tensor of size (3) */
int sOrder = 1;
int sOrder = 2;
int * sDimSize = new int[sOrder];
sDimSize[0] = 3;
sDimSize[0] = 1;
sDimSize[1] = 3;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
sUnitNum *= sDimSize[i];
DTYPE xData[3] = {0.0F, 1.0F, 2.0F};
DTYPE gData[3] = {0.5F, 0.8F, 1.5F};
DTYPE yAnswer[3] = {-2.4076F, -1.4076F, -0.4076F};
DTYPE dedxAnswer[3] = {-0.409969F, -0.555272F, -0.834759F};
DTYPE xData[1][3] = { {0.0F, 1.0F, 2.0F} };
DTYPE gData[1][3] = {0.5F, 0.8F, 1.5F};
DTYPE yAnswer[1][3] = {-2.4076F, -1.4076F, -0.4076F};
DTYPE dedxAnswer[1][3] = {-0.409969F, -0.555272F, -0.834759F};
/* CPU test */
bool cpuTest = true;
......@@ -137,7 +138,7 @@ bool TestLogSoftmax2()
dedy->SetZeroAll();
/* call LogSoftmax function */
LogSoftmax(x, y, 0);
LogSoftmax(x, y, 1);
/* call LogSoftmaxBackward function */
LogSoftmaxBackward(g, y, x, dedy, dedx, 0, CROSSENTROPY);
......@@ -164,7 +165,7 @@ bool TestLogSoftmax2()
dedyGPU->SetZeroAll();
/* call LogSoftmax function */
LogSoftmax(xGPU, yGPU, 0);
LogSoftmax(xGPU, yGPU, 1);
/* call LogSoftmaxBackward function */
LogSoftmaxBackward(gGPU, yGPU, xGPU, dedyGPU, dedxGPU, 0, CROSSENTROPY);
......
......@@ -21,6 +21,7 @@
#include "../core/math/ScaleAndShift.h"
#include "../function/Loss.h"
#include "TLoss.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......
......@@ -70,7 +70,7 @@ bool TestSigmoid1()
SigmoidBackward(g, y, x, dedy, dedx, NOLOSS);
/* check result */
cpuTest = y->CheckData(yAnswer, sUnitNum) && dedx->CheckData(dedxAnswer, sUnitNum);
cpuTest = y->CheckData(yAnswer, sUnitNum, 0.001F) && dedx->CheckData(dedxAnswer, sUnitNum, 0.001F);
#ifdef USE_CUDA
/* GPU test */
......@@ -97,7 +97,7 @@ bool TestSigmoid1()
SigmoidBackward(gGPU, yGPU, xGPU, dedyGPU, dedxGPU, NOLOSS);
/* check result */
gpuTest = yGPU->CheckData(yAnswer, sUnitNum) && dedxGPU->CheckData(dedxAnswer, sUnitNum);
gpuTest = yGPU->CheckData(yAnswer, sUnitNum, 0.001F) && dedxGPU->CheckData(dedxAnswer, sUnitNum, 0.001F);
/* destroy variables */
delete x;
......
......@@ -79,7 +79,7 @@ bool TestSoftmax1()
Softmax(xGPU, yGPU, 1);
/* check result */
gpuTest = yGPU->CheckData(answer, sUnitNum);
gpuTest = yGPU->CheckData(answer, sUnitNum, 0.001F);
/* destroy variables */
delete x;
......@@ -141,7 +141,7 @@ bool TestSoftmax2()
SoftmaxBackward(g, y, x, dedy, dedx, 1, CROSSENTROPY);
/* check result */
cpuTest = dedx->CheckData(dedxAnswer, sUnitNum);
cpuTest = dedx->CheckData(dedxAnswer, sUnitNum, 0.001F);
#ifdef USE_CUDA
/* GPU test */
......@@ -168,7 +168,7 @@ bool TestSoftmax2()
SoftmaxBackward(gGPU, yGPU, xGPU, dedyGPU, dedxGPU, 1, CROSSENTROPY);
/* check result */
gpuTest = dedxGPU->CheckData(dedxAnswer, sUnitNum);
gpuTest = dedxGPU->CheckData(dedxAnswer, sUnitNum, 0.001F);
/* destroy variables */
delete x;
......
......@@ -61,13 +61,13 @@ bool Test()
wrong = !TestUnsqueeze() || wrong;
wrong = !TestXMem() || wrong;
//wrong = !TestHardTanH() || wrong;
//wrong = !TestIdentity() || wrong;
//wrong = !TestLogSoftmax() || wrong;
//wrong = !TestLoss() || wrong;
//wrong = !TestRectify() || wrong;
//wrong = !TestSigmoid() || wrong;
//wrong = !TestSoftmax() || wrong;
wrong = !TestHardTanH() || wrong;
wrong = !TestIdentity() || wrong;
wrong = !TestLogSoftmax() || wrong;
wrong = !TestLoss() || wrong;
wrong = !TestRectify() || wrong;
wrong = !TestSigmoid() || wrong;
wrong = !TestSoftmax() || wrong;
/* other test */
/*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论