Commit 876c85c0 by ltb

add a macro GETANDSET_CONVERTDATATYPE which is used in ConvertDataType

parent 953421c3
...@@ -15,164 +15,158 @@ ...@@ -15,164 +15,158 @@
* limitations under the License. * limitations under the License.
*/ */
/* /*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05
*/ */
#include "XName.h" #include "XName.h"
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* get operator name */ /* get operator name */
const char * GetOPName(int type) const char * GetOPName(int type)
{ {
if ((type & MATH_BASE) != 0){ if ((type & MATH_BASE) != 0) {
if (type == MATH_ABSOLUTE) if (type == MATH_ABSOLUTE)
return "M_ABSOLUTE"; return "M_ABSOLUTE";
else if (type == MATH_CEIL) else if (type == MATH_CEIL)
return "M_CEIL"; return "M_CEIL";
else if (type == MATH_EXP) else if (type == MATH_EXP)
return "M_EXP"; return "M_EXP";
else if (type == MATH_FLOOR) else if (type == MATH_FLOOR)
return "M_FLOOR"; return "M_FLOOR";
else if (type == MATH_ISNONZERO) else if (type == MATH_ISNONZERO)
return "M_ISNONZERO"; return "M_ISNONZERO";
else if (type == MATH_ISZERO) else if (type == MATH_ISZERO)
return "M_ISZERO"; return "M_ISZERO";
else if (type == MATH_LOG) else if (type == MATH_LOG)
return "M_LOG"; return "M_LOG";
else if (type == MATH_SQRT) else if (type == MATH_SQRT)
return "M_SQRT"; return "M_SQRT";
else if (type == MATH_SQUARE) else if (type == MATH_SQUARE)
return "M_SQUARE"; return "M_SQUARE";
else if (type == MATH_SIN) else if (type == MATH_SIN)
return "M_SIN"; return "M_SIN";
else if (type == MATH_COS) else if (type == MATH_COS)
return "M_COS"; return "M_COS";
else if (type == MATH_TAN) else if (type == MATH_TAN)
return "M_TAN"; return "M_TAN";
else if (type == MATH_ROUND) else if (type == MATH_ROUND)
return "M_ROUND"; return "M_ROUND";
else if (type == MATH_CLIP) else if (type == MATH_CLIP)
return "M_CLIP"; return "M_CLIP";
else if (type == MATH_DIV) else if (type == MATH_DIV)
return "M_DIV"; return "M_DIV";
else if (type == MATH_DIVDIM) else if (type == MATH_DIVDIM)
return "M_DIVDIM"; return "M_DIVDIM";
else if (type == MATH_MATRIXMUL) else if (type == MATH_MATRIXMUL)
return "M_MATRIXMUL"; return "M_MATRIXMUL";
else if (type == MATH_MATRIXMULBATCHED) else if (type == MATH_MATRIXMULBATCHED)
return "M_MATRIXMULBATCHED"; return "M_MATRIXMULBATCHED";
else if (type == MATH_MULTIPLY) else if (type == MATH_MULTIPLY)
return "M_MULTIPLY"; return "M_MULTIPLY";
else if (type == MATH_MULTIPLYDIM) else if (type == MATH_MULTIPLYDIM)
return "M_MULTIPLYDIM"; return "M_MULTIPLYDIM";
else if (type == MATH_MULTIPLYBROADCAST) else if (type == MATH_MULTIPLYBROADCAST)
return "M_MULTIPLYBROADCAST"; return "M_MULTIPLYBROADCAST";
else if (type == MATH_NEGATE) else if (type == MATH_NEGATE)
return "M_NEGATE"; return "M_NEGATE";
else if (type == MATH_NORMALIZE) else if (type == MATH_NORMALIZE)
return "M_NORMALIZE"; return "M_NORMALIZE";
else if (type == MATH_POWER) else if (type == MATH_POWER)
return "M_POWER"; return "M_POWER";
else if (type == MATH_SCALEANDSHIFT) else if (type == MATH_SCALEANDSHIFT)
return "M_SCALEANDSHIFT"; return "M_SCALEANDSHIFT";
else if (type == MATH_SCALE) else if (type == MATH_SCALE)
return "M_SCALE"; return "M_SCALE";
else if (type == MATH_DESCALE) else if (type == MATH_DESCALE)
return "M_DESCALE"; return "M_DESCALE";
else if (type == MATH_SHIFT) else if (type == MATH_SHIFT)
return "M_SHIFT"; return "M_SHIFT";
else if (type == MATH_MULANDSHIFT) else if (type == MATH_MULANDSHIFT)
return "M_OPERATION"; return "M_OPERATION";
else if (type == MATH_SIGN) else if (type == MATH_SIGN)
return "M_SIGN"; return "M_SIGN";
else if (type == MATH_SUB) else if (type == MATH_SUB)
return "M_SUB"; return "M_SUB";
else if (type == MATH_SUBDIM) else if (type == MATH_SUBDIM)
return "M_SUBDIM"; return "M_SUBDIM";
else if (type == MATH_SUM) else if (type == MATH_SUM)
return "M_SUM"; return "M_SUM";
else if (type == MATH_SUMDIM) else if (type == MATH_SUMDIM)
return "M_SUMDIM"; return "M_SUMDIM";
else if (type == MATH_SUMBROADCAST) else if (type == MATH_SUMBROADCAST)
return "M_SUMBROADCAST"; return "M_SUMBROADCAST";
else if (type == REDUCE_REDUCEMAX) else if (type == REDUCE_REDUCEMAX)
return "R_REDUCEMAX"; return "R_REDUCEMAX";
else if (type == REDUCE_REDUCEMEAN) else if (type == REDUCE_REDUCEMEAN)
return "R_REDUCEMEAN"; return "R_REDUCEMEAN";
else if (type == REDUCE_REDUCESUM) else if (type == REDUCE_REDUCESUM)
return "R_REDUCESUM"; return "R_REDUCESUM";
else if (type == REDUCE_REDUCESUMSQUARED) else if (type == REDUCE_REDUCESUMSQUARED)
return "R_REDUCESUMSQUARED"; return "R_REDUCESUMSQUARED";
else if (type == REDUCE_REDUCEVARIANCE) else if (type == REDUCE_REDUCEVARIANCE)
return "R_REDUCEVARIANCE"; return "R_REDUCEVARIANCE";
} }
else if ((type & DATA_BASE) != 0) { else if ((type & DATA_BASE) != 0) {
if (type == GETANDSET_CONVERTDATATYPE) if (type == GETANDSET_SELECT)
return "G_CONVERTDATATYPE"; return "G_SELECT";
else if (type == GETANDSET_INDEXTOONEHOT) else if (type == MOVEMENT_COPYINDEXED)
return "G_INDEXTOONEHOT"; return "M_COPYINDEXED";
else if (type == GETANDSET_ONEHOTTOINDEX) else if (type == MOVEMENT_COPYVALUES)
return "G_ONEHOTTOINDEX"; return "M_COPYVALUES";
else if (type == GETANDSET_SELECT) else if (type == MOVEMENT_GATHER)
return "G_SELECT"; return "M_GATHER";
else if (type == MOVEMENT_COPYINDEXED) else if (type == MOVEMENT_DROPOUTWITHINDEX)
return "M_COPYINDEXED"; return "M_DROPOUTWITHINDEX";
else if (type == MOVEMENT_COPYVALUES) else if (type == SHAPE_CONCATENATE)
return "M_COPYVALUES"; return "S_CONCATENATE";
else if (type == MOVEMENT_GATHER) else if (type == SHAPE_MERGE)
return "M_GATHER"; return "S_MERGE";
else if (type == MOVEMENT_DROPOUTWITHINDEX) else if (type == SHAPE_MERGE_LIST)
return "M_DROPOUTWITHINDEX"; return "S_MERGE_LIST";
else if (type == SHAPE_CONCATENATE) else if (type == SHAPE_PERMUTE)
return "S_CONCATENATE"; return "S_PERMUTE";
else if (type == SHAPE_MERGE) else if (type == SHAPE_RESHAPE)
return "S_MERGE"; return "S_RESHAPE";
else if (type == SHAPE_MERGE_LIST) else if (type == SHAPE_SPLIT)
return "S_MERGE_LIST"; return "S_SPLIT";
else if (type == SHAPE_PERMUTE) else if (type == SHAPE_SPLIT_LIST)
return "S_PERMUTE"; return "S_SPLIT_LIST";
else if (type == SHAPE_RESHAPE) else if (type == SHAPE_SQUEEZE)
return "S_RESHAPE"; return "S_SQUEEZE";
else if (type == SHAPE_SPLIT) else if (type == SHAPE_TRANSPOSE)
return "S_SPLIT"; return "S_TRANSPOSE";
else if (type == SHAPE_SPLIT_LIST) else if (type == SHAPE_UNSQUEEZE)
return "S_SPLIT_LIST"; return "S_UNSQUEEZE";
else if (type == SHAPE_SQUEEZE) else if (type == SORT_SORT)
return "S_SQUEEZE"; return "S_SORT";
else if (type == SHAPE_TRANSPOSE) else if (type == SORT_TOPK)
return "S_TRANSPOSE"; return "S_TOPK";
else if (type == SHAPE_UNSQUEEZE) }
return "S_UNSQUEEZE"; else if ((type & FUNCTION_BASE) != 0) {
else if (type == SORT_SORT) if (type == FUNC_DROPOUT)
return "S_SORT"; return "F_DROPOUT";
else if (type == SORT_TOPK) else if (type == FUNC_HARDTANH)
return "S_TOPK"; return "F_HARDTANH";
} else if (type == FUNC_IDENTITY)
else if ((type & FUNCTION_BASE) != 0){ return "F_IDENTITY";
if (type == FUNC_DROPOUT) else if (type == FUNC_LOGSOFTMAX)
return "F_DROPOUT"; return "F_LOGSOFTMAX";
else if (type == FUNC_HARDTANH) else if (type == FUNC_RECTIFY)
return "F_HARDTANH"; return "F_RECTIFY";
else if (type == FUNC_IDENTITY) else if (type == FUNC_SIGMOID)
return "F_IDENTITY"; return "F_SIGMOID";
else if (type == FUNC_LOGSOFTMAX) else if (type == FUNC_SOFTMAX)
return "F_LOGSOFTMAX"; return "F_SOFTMAX";
else if (type == FUNC_RECTIFY) }
return "F_RECTIFY"; else if ((type & LOSS_BASE) != 0) {
else if (type == FUNC_SIGMOID) if (type == LOSS_CROSSENTROPY)
return "F_SIGMOID"; return "L_CROSSENTROPY";
else if (type == FUNC_SOFTMAX) }
return "F_SOFTMAX";
} return "NULL";
else if ((type & LOSS_BASE) != 0) { }
if (type == LOSS_CROSSENTROPY)
return "L_CROSSENTROPY";
}
return "NULL";
}
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University. * Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved. * All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
* limitations under the License. * limitations under the License.
*/ */
/* /*
* *
* We define various names here * We define various names here
* *
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-05
* It was really HOT these days. I can't imagine it is SO hot here in Shenyang! * It was really HOT these days. I can't imagine it is SO hot here in Shenyang!
*/ */
#ifndef __XNAME_H__ #ifndef __XNAME_H__
#define __XNAME_H__ #define __XNAME_H__
...@@ -31,6 +31,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -31,6 +31,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* math operations */ /* math operations */
#define MATH_BASE 0x00001000 #define MATH_BASE 0x00001000
#define GETANDSET_CONVERTDATATYPE MATH_BASE * 8
#define MATH_ABSOLUTE MATH_BASE + 1 #define MATH_ABSOLUTE MATH_BASE + 1
#define MATH_CEIL MATH_ABSOLUTE + 1 #define MATH_CEIL MATH_ABSOLUTE + 1
#define MATH_EXP MATH_CEIL + 1 #define MATH_EXP MATH_CEIL + 1
...@@ -79,13 +80,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -79,13 +80,9 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* data and shape related operations */ /* data and shape related operations */
#define DATA_BASE MATH_BASE * 2 #define DATA_BASE MATH_BASE * 2
#define GETANDSET DATA_BASE + 1 #define GETANDSET DATA_BASE + 1
#define GETANDSET_CONVERTDATATYPE GETANDSET + 1 #define GETANDSET_SELECT GETANDSET + 1
#define GETANDSET_INDEXTOONEHOT GETANDSET_CONVERTDATATYPE + 1
#define GETANDSET_ONEHOTTOINDEX GETANDSET_INDEXTOONEHOT + 1
#define GETANDSET_SELECT GETANDSET_ONEHOTTOINDEX + 1
#define SHAPE_BASE DATA_BASE * 2 #define MOVEMENT GETANDSET_SELECT + 1
#define MOVEMENT SHAPE_BASE + 1
#define MOVEMENT_COPYINDEXED MOVEMENT + 1 #define MOVEMENT_COPYINDEXED MOVEMENT + 1
#define MOVEMENT_COPYVALUES MOVEMENT_COPYINDEXED + 1 #define MOVEMENT_COPYVALUES MOVEMENT_COPYINDEXED + 1
#define MOVEMENT_GATHER MOVEMENT_COPYVALUES + 1 #define MOVEMENT_GATHER MOVEMENT_COPYVALUES + 1
...@@ -108,7 +105,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -108,7 +105,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define SORT_TOPK SORT_SORT + 1 #define SORT_TOPK SORT_SORT + 1
/* activation functions */ /* activation functions */
#define FUNCTION_BASE SHAPE_BASE * 2 #define FUNCTION_BASE DATA_BASE * 2
#define FUNC_DROPOUT FUNCTION_BASE + 1 #define FUNC_DROPOUT FUNCTION_BASE + 1
#define FUNC_HARDTANH FUNC_DROPOUT + 1 #define FUNC_HARDTANH FUNC_DROPOUT + 1
#define FUNC_IDENTITY FUNC_HARDTANH + 1 #define FUNC_IDENTITY FUNC_HARDTANH + 1
...@@ -121,7 +118,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor) ...@@ -121,7 +118,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#define LOSS_CROSSENTROPY LOSS_BASE + 1 #define LOSS_CROSSENTROPY LOSS_BASE + 1
/* get operator name */ /* get operator name */
const char * GetOPName(int type); const char * GetOPName(int type);
} // namespace nts(NiuTrans.Tensor) } // namespace nts(NiuTrans.Tensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论