Commit 07efb5de by xuchen

add backward function of ConvertDataType

parent 5f933fc6
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* backward computation for data operation
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-12-26
*/
#include "XNoder.h"
#include "XBackwardData.h"
#include "../tensor/XName.h"
#include "../tensor/XUtility.h"
#include "../tensor/core/CHeader.h"
#include "../tensor/core/getandset/SetData.h"
namespace nts{
/* compute dE/dx of a node */
void XDataGrad::MakeGrad(XTensor * node, bool isEfficent)
{
CheckNTErrors(node->grad != NULL, "No gradient found!");
XLink &income = node->income;
int operID = income.typeID;
if(operID == GETANDSET_CONVERTDATATYPE)
GradConvertDataType(node, isEfficent);
else if(operID == GETANDSET_INDEXTOONEHOT)
GradIndexToOnehot(node, isEfficent);
else if(operID == GETANDSET_ONEHOTTOINDEX)
GradOnehotToIndex(node, isEfficent);
else{
ShowNTErrors("TODO!");
}
}
/* indicates whether the node is for a data operation */
bool XDataGrad::IsDataOP(XTensor * node)
{
XLink &income = node->income;
return (income.typeID & DATA_BASE) != 0;
}
/*
gradient computation for ConvertDataType
for
b = convertdatatype(a)
we have
dE/da = convertdatatype(dE/db)
>> node - the node (b) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XDataGrad::GradConvertDataType(XTensor * node, bool isEfficent)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for ConvertDataType!");
XTensor * input = income.tails[0];
XNoder::MakeGrad(input);
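/* cast dE/db (node->grad, stored in b's data type) back into a's data type to obtain dE/da */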
_ConvertDataType(node->grad, input->grad);
}
/*
gradient computation for IndexToOnehot
for
b = IndexToOnehot(a)
the input a holds discrete indices, so no gradient is propagated back to it;
we only make sure that a's gradient tensor exists
>> node - the node (b) for backward computation
>> isEfficient - indicates whether the computation is in
an efficient manner
*/
void XDataGrad::GradIndexToOnehot(XTensor * node, bool isEfficent)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for IndexToOnehot!");
XTensor * input = income.tails[0];
XNoder::MakeGrad(input);
}
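/* GradOnehotToIndex is declared in XBackwardData.h and dispatched in MakeGrad,
but no definition appears in this file. A minimal sketch, mirroring
GradIndexToOnehot above (the output of OnehotToIndex holds discrete indices,
so no gradient is propagated; we only make sure the input's gradient tensor
exists), could be: */
void XDataGrad::GradOnehotToIndex(XTensor * node, bool isEfficent)
{
XLink &income = node->income;
CheckNTErrors(income.tailNum > 0, "Wrong input tensor number for OnehotToIndex!");
XTensor * input = income.tails[0];
XNoder::MakeGrad(input);
}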
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* backward computation for data operation
* $Created by: Xu Chen (email: hello_master1954@163.com) 2018-12-26
*/
#include "../tensor/XTensor.h"
#include "../tensor/function/FHeader.h"
#ifndef __XBACKWARDDATA_H__
#define __XBACKWARDDATA_H__
namespace nts{
/* this class computes the gradient for tensor data operation given a node */
class XDataGrad
{
public:
/* compute dE/dx of a node */
static
void MakeGrad(XTensor * node, bool isEfficent);
/* indicates whether the node is for a data operation */
static
bool IsDataOP(XTensor * node);
private:
/* gradient computation for ConvertDataType: b = convertdatatype(a, datatype) */
static
void GradConvertDataType(XTensor * node, bool isEfficent);
/* gradient computation for IndexToOnehot: b = indextoonehot(a, num) */
static
void GradIndexToOnehot(XTensor * node, bool isEfficent);
/* gradient computation for OnehotToIndex: b = onehottoindex(a, num) */
static
void GradOnehotToIndex(XTensor * node, bool isEfficent);
};
} // namespace nts(NiuTrans.Tensor)
#endif
\ No newline at end of file
......@@ -62,7 +62,7 @@ void XShapeGrad::MakeGrad(XTensor * node, bool isEfficent)
}
}
/* indicates whether the node is for a math operation */
/* indicates whether the node is for a shape operation */
bool XShapeGrad::IsShapeOP(XTensor * node)
{
XLink &income = node->income;
......
......@@ -97,7 +97,13 @@ const char * GetOPName(int type)
return "R_REDUCEVARIANCE";
}
else if ((type & DATA_BASE) != 0){
if (type == GETANDSET_SELECT)
if (type == GETANDSET_CONVERTDATATYPE)
return "G_CONVERTDATATYPE";
else if (type == GETANDSET_INDEXTOONEHOT)
return "G_INDEXTOONEHOT";
else if (type == GETANDSET_ONEHOTTOINDEX)
return "G_ONEHOTTOINDEX";
else if (type == GETANDSET_SELECT)
return "G_SELECT";
else if (type == MOVEMENT_COPYINDEXED)
return "M_COPYINDEXED";
......
......@@ -72,7 +72,10 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* data and shape related operations */
#define DATA_BASE MATH_BASE * 2
#define GETANDSET DATA_BASE + 1
#define GETANDSET_SELECT GETANDSET + 1
#define GETANDSET_CONVERTDATATYPE GETANDSET + 1
#define GETANDSET_INDEXTOONEHOT GETANDSET_CONVERTDATATYPE + 1
#define GETANDSET_ONEHOTTOINDEX GETANDSET_INDEXTOONEHOT + 1
#define GETANDSET_SELECT GETANDSET_ONEHOTTOINDEX + 1
#define MOVEMENT GETANDSET_SELECT + 1
#define MOVEMENT_COPYINDEXED MOVEMENT + 1
......
......@@ -20,13 +20,16 @@
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "ConvertDataType.h"
#include "ConvertDataType.cuh"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
convert data type
>> input - input tensor
>> output - output tensor
*/
......@@ -60,28 +63,34 @@ void _ConvertDataType(const XTensor * input, XTensor * output)
}
/*
convert data type (return an XTensor structure)
make a new tensor to keep the result and return it
>> input - input tensor
>> dataType - the target data type
<< return - output tensor with the specified data type
*/
XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType)
{
if (input.dataType == dataType) {
XTensor output;
output = CopyValues(input);
return output;
}
int order = input.order;
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, input.dimSize, dataType, dr, input.devID, input.mem);
output.SetTMPFlag();
_Gather(&s, &t, &index);
_ConvertDataType(&input, &output);
/* tensor connection */
XLink::MakeLink(&s, &index, &t, MOVEMENT_GATHER);
}
XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
return output;
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
......@@ -31,7 +31,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
void _ConvertDataType(const XTensor * input, XTensor * output);
/* convert data type (return an XTensor structure) */
XTensor ConvertDataType(const XTensor * input, TENSOR_DATA_TYPE dataType);
XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType);
} // namespace nts(NiuTrans.Tensor)
......
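For reference, a minimal usage sketch of the new reference-based ConvertDataType API. This is an illustrative example only: it assumes the library's InitTensor2D helper, the X_FLOAT/X_INT data-type tags, and XTensor::SetDataRand, none of which are part of this commit, and exact signatures may differ between versions.
XTensor a;
InitTensor2D(&a, 2, 3, X_FLOAT);
a.SetDataRand(-1.0F, 1.0F);
/* b holds the same values stored as integers; the call also records a
GETANDSET_CONVERTDATATYPE link, so XDataGrad::GradConvertDataType can later
cast dE/db back into a's data type during backpropagation */
XTensor b = ConvertDataType(a, X_INT);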