Commit 135aadf4 by liyinqiao

Merge with Xuchen branch

parent c7559b7d
# the prefix of the generated executable file
PREFIX = NiuTrans
NIUTRANS_EXE := $(PREFIX).Tensor
# code path and generated file path
ROOT = .
SRC = $(ROOT)/source
LIB_DIR = $(ROOT)/lib
EXE_DIR = $(ROOT)/bin
# whether to generate dll
dll = 0
# 0 - use CPU
# 1 - use GPU
USE_CUDA = 1
# modify this path if necessary
CUDA_ROOT = /usr/local/cuda-9.0
CUDA_LIB_DIR = $(CUDA_ROOT)/lib64
CUDA_INCLUDE = $(CUDA_ROOT)/include
# use MKL
USE_MKL = 0
INTEL_ROOT = /opt/intel
MKL_ROOT = /opt/intel/mkl
MKL_LIB_DIR = $(MKL_ROOT)/lib/intel64/
MKL_INCLUDE = $(MKL_ROOT)/include
# use OpenBLAS
USE_OPENBLAS = 0
OPENBLAS_ROOT = /opt/OpenBLAS
OPENBLAS_LIB_DIR = $(OPENBLAS_ROOT)/lib
OPENBLAS_INCLUDE = $(OPENBLAS_ROOT)/include
SRC_DIR = $(shell find $(SRC) -type d)
# directories of header files to include
# directories of external dependency libraries
INC_DIR = $(SRC_DIR)
DEPLIB_DIR =
ifeq ($(USE_CUDA), 1)
INC_DIR += $(CUDA_INCLUDE)
DEPLIB_DIR += $(CUDA_LIB_DIR)
endif
ifeq ($(USE_MKL), 1)
INC_DIR += $(MKL_INCLUDE)
DEPLIB_DIR += $(MKL_LIB_DIR)
endif
ifeq ($(USE_OPENBLAS), 1)
INC_DIR += $(OPENBLAS_INCLUDE)
DEPLIB_DIR += $(OPENBLAS_LIB_DIR)
endif
# macro
MACRO =
ifeq ($(USE_CUDA), 1)
MACRO += -DUSE_CUDA
endif
ifeq ($(USE_MKL), 1)
MACRO += -DUSE_BLAS -DMKL
endif
ifeq ($(USE_OPENBLAS), 1)
MACRO += -DUSE_BLAS -DOPENBLAS
endif
# dependency
STATIC_DEPLIB =
DYNAMIC_DEPLIB = -lpthread
ifeq ($(USE_MKL), 1)
STATIC_DEPLIB += $(MKL_LIB_DIR)/libmkl_intel_lp64.a \
$(MKL_LIB_DIR)/libmkl_core.a \
$(MKL_LIB_DIR)/libmkl_intel_thread.a \
$(INTEL_ROOT)/lib/intel64/libiomp5.a
DYNAMIC_DEPLIB += -liomp5 -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core
endif
ifeq ($(USE_OPENBLAS), 1)
STATIC_DEPLIB += $(OPENBLAS_LIB_DIR)/libopenblas.a
DYNAMIC_DEPLIB += -lopenblas
endif
ifeq ($(USE_CUDA), 1)
STATIC_DEPLIB += $(CUDA_LIB_DIR)/libcublas_static.a \
$(CUDA_LIB_DIR)/libculibos.a \
$(CUDA_LIB_DIR)/libnpps_static.a \
$(CUDA_LIB_DIR)/libnppc_static.a \
$(CUDA_LIB_DIR)/libcudadevrt.a \
$(CUDA_LIB_DIR)/libcurand_static.a \
/lib64/libdl.so.2
DYNAMIC_DEPLIB += -lcudart -lnvidia-ml
endif
DEPLIBS = -Wl,--start-group $(STATIC_DEPLIB) -Wl,--end-group -lm -ldl $(DYNAMIC_DEPLIB)
# specify the compilers here
CC = gcc
CXX = g++
NVCC = $(CUDA_ROOT)/bin/nvcc
ifeq ($(USE_INTEL_COMPILER), 1)
CC = icc
CXX = icc
endif
# main file
MAIN_FILE = $(SRC)/network/Main.cpp
Tensor_Main := $(SRC)/tensor/Main.cpp
Network_Main := $(SRC)/network/Main.cpp
ifeq ($(USE_CUDA), 1)
NIUTRANS_EXE := $(NIUTRANS_EXE).GPU
else
NIUTRANS_EXE := $(NIUTRANS_EXE).CPU
endif
NIUTRANS_DLL := $(LIB_DIR)/lib$(NIUTRANS_EXE).so
NIUTRANS_EXE := $(EXE_DIR)/$(NIUTRANS_EXE)
# specify the compiling arguments here
CFLAGS = -std=c++11 -msse4.2 -w -march=native -Wno-enum-compare -Wno-sign-compare -Wno-reorder -Wno-format
# gtx 1080 arch=compute_61,code=sm_61
# k80 arch=compute_37,code=sm_37
# if the architecture is set incorrectly, the results can be `-inf`
CUDA_FLAG = -arch=sm_30 \
-gencode=arch=compute_30,code=sm_30 \
-gencode=arch=compute_50,code=sm_50 \
-gencode=arch=compute_52,code=sm_52 \
-gencode=arch=compute_60,code=sm_60 \
-gencode=arch=compute_61,code=sm_61 \
-gencode=arch=compute_62,code=sm_62 \
-gencode=arch=compute_70,code=sm_70 \
-gencode=arch=compute_70,code=compute_70 \
-maxrregcount=0 --machine 64 -DUSE_CUDA --use_fast_math -std=c++11
CFLAGS += -O3 -flto -DNDEBUG -rdynamic -fkeep-inline-functions
# include dir
CFLAGS += -fPIC $(addprefix -I, $(INC_DIR))
# CUDA_FLAG += $(addprefix -I, $(INC_DIR))
CXXFLAGS = $(CFLAGS)
# lib dir
LDFLAGS = $(addprefix -L, $(DEPLIB_DIR))
# decoder source file
ifeq ($(USE_CUDA), 1)
SOURCES := $(foreach dir,$(SRC_DIR),$(wildcard $(dir)/*.c) $(wildcard $(dir)/*.cpp) $(wildcard $(dir)/*.cc) $(wildcard $(dir)/*.cu))
else
SOURCES := $(foreach dir,$(SRC_DIR),$(wildcard $(dir)/*.c) $(wildcard $(dir)/*.cpp) $(wildcard $(dir)/*.cc) )
endif
SOURCES := $(subst $(Tensor_Main), ,$(SOURCES))
SOURCES := $(subst $(Network_Main), ,$(SOURCES))
# object file
OBJS := $(patsubst %.c,%.o,$(SOURCES))
OBJS := $(patsubst %.cpp,%.o,$(OBJS))
ifeq ($(USE_CUDA), 1)
OBJS := $(patsubst %.cu,%.cuo,$(OBJS))
endif
all: start lib exe finish
start:
@echo ""
@echo "Start building ..."
lib: start_lib niutrans_dll finish_lib
start_lib:
@mkdir -p $(LIB_DIR)
@echo ""
@echo "Start building library"
niutrans_dll: $(NIUTRANS_DLL)
$(NIUTRANS_DLL): $(OBJS)
ifeq ($(dll), 1)
@echo "Building dynamic link library: $(NIUTRANS_DLL)"
@$(CXX) -shared -Wall $(CXXFLAGS) $(MACRO) $(LDFLAGS) $(OBJS) $(DEPLIBS) -o $@
else
@echo "Skip building dynamic link library"
endif
finish_lib:
@echo "Finish building library"
@echo ""
exe: start_exe niutrans_exe finish_exe
start_exe:
@mkdir -p $(EXE_DIR)
@echo ""
@echo "Start building executable file"
niutrans_exe: $(NIUTRANS_EXE)
$(NIUTRANS_EXE): $(OBJS) $(MAIN_FILE)
@echo "Building executable file: $(NIUTRANS_EXE)"
@$(CXX) $(MAIN_FILE) $(CXXFLAGS) $(MACRO) $(LDFLAGS) $(OBJS) $(DEPLIBS) -o $@
finish_exe:
@echo "Finish building executable file"
@echo ""
finish:
@echo "Finish building ..."
@echo ""
%.o: %.c
@$(CC) $(CFLAGS) -c $< -o $@
%.o: %.cpp
@$(CXX) $(CXXFLAGS) $(MACRO) -c $< -o $@
%.cuo: %.cu
ifeq ($(dll), 1)
@$(NVCC) --shared --compiler-options '-fPIC' $(CUDA_FLAG) -c $< -o $@
else
@$(NVCC) $(CUDA_FLAG) -c $< -o $@
endif
.PHONY: clean
clean:
@echo "Cleaning object files"
@-rm -f $(OBJS)
\ No newline at end of file
@@ -45,7 +45,9 @@ int main( int argc, const char ** argv )
 //_CrtSetDbgFlag(_CrtSetDbgFlag(_CRTDBG_REPORT_FLAG) | _CRTDBG_LEAK_CHECK_DF);
 //_CrtSetBreakAlloc(2708);
-if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
+if(argc > 1 && !strcmp(argv[1], "-test"))
+    Test();
+else if(argc > 1 && !strcmp(argv[1], "-fnnlm"))
 FNNLMMain(argc - 1, argv + 1);
 else if(argc > 1 && !strcmp(argv[1], "-t2t"))
 TransformerMain(argc - 1, argv + 1);
@@ -54,6 +56,7 @@ int main( int argc, const char ** argv )
 fprintf(stderr, "neural networks in an easy way. \n\n");
 fprintf(stderr, "Run this program with \"-test\" for unit test!\n");
 fprintf(stderr, "Or run this program with \"-fnnlm\" for sample FNNLM!\n");
+fprintf(stderr, "Or run this program with \"-t2t\" for sample Transformer!\n");
 }
 //_CrtDumpMemoryLeaks();
......
@@ -43,18 +43,18 @@ void XFuncGrad::MakeGrad(XTensor * node, bool isEfficient)
 XNoder::MakeGrad(input);
 if(operID == FUNC_HARDTANH)
-_HardTanHBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
+_HardTanHBackward(output, input, output->grad, input->grad);
 else if(operID == FUNC_IDENTITY)
-_IdentityBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
+_IdentityBackward(output, input, output->grad, input->grad);
 else if(operID == FUNC_LOGSOFTMAX){
 int leadDim = income.GetParamInt(0);
 CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in logsoftmax!");
 _LogSoftmaxBackward(NULL, output, input, output->grad, input->grad, NULL, leadDim, NOLOSS);
 }
 else if(operID == FUNC_RECTIFY)
-_RectifyBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
+_RectifyBackward(output, input, output->grad, input->grad);
 else if(operID == FUNC_SIGMOID)
-_SigmoidBackward(NULL, output, input, output->grad, input->grad, NOLOSS);
+_SigmoidBackward(output, input, output->grad, input->grad);
 else if(operID == FUNC_SOFTMAX){
 int leadDim = income.GetParamInt(0);
 CheckNTErrors(leadDim >= 0 && leadDim < input->order, "wrong leading dimension in softmax!");
......
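For reference, the backward calls above apply the chain rule for element-wise activations, turning dE/dy into dE/dx. Written out for two of the cases (these are the standard definitions, not something stated in this commit):

\frac{\partial E}{\partial x_i} = \frac{\partial E}{\partial y_i}\, y_i (1 - y_i) \quad \text{(sigmoid, } y_i = \sigma(x_i)\text{)}
\qquad
\frac{\partial E}{\partial x_i} = \frac{\partial E}{\partial y_i}\, \mathbf{1}[-1 < x_i < 1] \quad \text{(hardtanh)}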
@@ -69,7 +69,7 @@ void XLossGrad::MakeGrad(XTensor * node, bool isEfficient)
 if(operID == LOSS_CROSSENTROPY) {
 if (income.tailNum == 3)
 padding = income.tails[2];
 leadingDim = income.GetParamInt(0);
 CheckNTErrors(leadingDim >= 0 && leadingDim < output->order, "wrong leading dimension in logsoftmax!");
 _CrossEntropyBackward(dedy, output, gold, weight, padding, leadingDim);
 }
@@ -98,39 +98,39 @@ compute dE/dx for a given function y = f(x)
 >> params - parameters of the function
 >> lossName - name of the loss, e.g., cross entropy
 */
-void XLossGrad::Compute(XTensor * gold, XTensor * y, XTensor * x,
-                        XTensor * dedy, XTensor * dedx, XTensor * padding,
-                        int funcID, void * params,
-                        LOSS_FUNCTION_NAME lossName)
-{
-CheckNTErrors(gold && y && x, "Empty input tensors!");
-CheckNTErrors(dedx, "Empty gradient tensors!");
-CheckNTErrors((funcID & FUNCTION_BASE) != 0, "Illegal function id");
-if(funcID == FUNC_HARDTANH){
-_HardTanHBackward(gold, y, x, dedy, dedx, lossName);
-}
-else if(funcID == FUNC_IDENTITY){
-_IdentityBackward(gold, y, x, dedy, dedx, lossName);
-}
-else if(funcID == FUNC_LOGSOFTMAX){
-int leadDim = *(int*)params;
-_LogSoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
-}
-else if(funcID == FUNC_RECTIFY){
-_RectifyBackward(gold, y, x, dedy, dedx, lossName);
-}
-else if(funcID == FUNC_SIGMOID){
-_SigmoidBackward(gold, y, x, dedy, dedx, lossName);
-}else if(funcID == FUNC_SOFTMAX){
-int leadDim = *(int*)params;
-_SoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
-}
-else{
-ShowNTErrors("wrong function found when call the backward process!");
-}
-}
+//void XLossGrad::Compute(XTensor * gold, XTensor * y, XTensor * x,
+//                        XTensor * dedy, XTensor * dedx, XTensor * padding,
+//                        int funcID, void * params,
+//                        LOSS_FUNCTION_NAME lossName)
+//{
+//    CheckNTErrors(gold && y && x, "Empty input tensors!");
+//    CheckNTErrors(dedx, "Empty gradient tensors!");
+//    CheckNTErrors((funcID & FUNCTION_BASE) != 0, "Illegal function id");
+//
+//    if(funcID == FUNC_HARDTANH){
+//        _HardTanHBackward(gold, y, x, dedy, dedx, lossName);
+//    }
+//    else if(funcID == FUNC_IDENTITY){
+//        _IdentityBackward(gold, y, x, dedy, dedx, lossName);
+//    }
+//    else if(funcID == FUNC_LOGSOFTMAX){
+//        int leadDim = *(int*)params;
+//        _LogSoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
+//    }
+//    else if(funcID == FUNC_RECTIFY){
+//        _RectifyBackward(gold, y, x, dedy, dedx, lossName);
+//    }
+//    else if(funcID == FUNC_SIGMOID){
+//        _SigmoidBackward(gold, y, x, dedy, dedx, lossName);
+//    }else if(funcID == FUNC_SOFTMAX){
+//        int leadDim = *(int*)params;
+//        _SoftmaxBackward(gold, y, x, dedy, dedx, padding, leadDim, lossName);
+//    }
+//    else{
+//        ShowNTErrors("wrong function found when call the backward process!");
+//    }
+//
+//}
 /*
 compute dE/dy for variable y and error(loss) function E
@@ -139,27 +139,27 @@ compute dE/dy for variable y and error(loss) function E
 >> dedy - dE/dy
 >> lossName - name of the loss, e.g., cross entropy
 */
-void XLossGrad::Compute(XTensor * gold, XTensor * y,
-                        XTensor * dedy, XTensor * padding,
-                        LOSS_FUNCTION_NAME lossName)
-{
-if(gold == NULL){
-if(dedy->dataType == X_FLOAT)
-_SetDataFixedFloat(dedy, 1.0F);
-else if(dedy->dataType == X_DOUBLE)
-_SetDataFixedDouble(dedy, 1.0);
-else if(dedy->dataType == X_INT)
-_SetDataFixedInt(dedy, 1);
-else{
-ShowNTErrors("TODO");
-}
-return;
-}
-//_LossBackward(dedy, gold, y, lossName);
-if(lossName == CROSSENTROPY)
-_CrossEntropyBackward(dedy, y, gold, NULL, padding);
-}
+//void XLossGrad::Compute(XTensor * gold, XTensor * y,
+//                        XTensor * dedy, XTensor * padding,
+//                        LOSS_FUNCTION_NAME lossName)
+//{
+//    if(gold == NULL){
+//        if(dedy->dataType == X_FLOAT)
+//            _SetDataFixedFloat(dedy, 1.0F);
+//        else if(dedy->dataType == X_DOUBLE)
+//            _SetDataFixedDouble(dedy, 1.0);
+//        else if(dedy->dataType == X_INT)
+//            _SetDataFixedInt(dedy, 1);
+//        else{
+//            ShowNTErrors("TODO");
+//        }
+//        return;
+//    }
+//
+//    //_LossBackward(dedy, gold, y, lossName);
+//    if(lossName == CROSSENTROPY)
+//        _CrossEntropyBackward(dedy, y, gold, NULL, padding);
+//
+//}
 }
\ No newline at end of file
@@ -43,11 +43,11 @@ public:
 static
 bool IsLossOP(XTensor * node);
-/* compute dE/dx for a given function y = f(x) */
-void Compute(XTensor * gold, XTensor * y, XTensor * x,
-             XTensor * dedy, XTensor * dedx, XTensor * padding,
-             int funcID, void * params,
-             LOSS_FUNCTION_NAME lossName);
+///* compute dE/dx for a given function y = f(x) */
+//void Compute(XTensor * gold, XTensor * y, XTensor * x,
+//             XTensor * dedy, XTensor * dedx, XTensor * padding,
+//             int funcID, void * params,
+//             LOSS_FUNCTION_NAME lossName);
 /* compute dE/dy for variable y and error(loss) function E */
 void Compute(XTensor * gold, XTensor * y,
......
@@ -530,7 +530,7 @@ void XMathGrad::GradMatrixMul(XTensor * node, bool isEfficient)
 XTensor * dedc = node->grad;
 XTensor * deda = a->grad;
 XTensor * dedb = b->grad;
 if(a->order == 2 && b->order == 2)
 GradMatrixMul(a, deda, transA, b, dedb, transB, dedc, alpha, isEfficient);
 else if(transA == X_NOTRANS && a->order > 2 && b->order == 2){
......
@@ -20,7 +20,7 @@
 * This is a simple impelementation of the feed-forward network-baesd language
 * model (FNNLM). See more details about FNNLM in
 * "A Neural Probabilistic Language Model" by Bengio et al.
-* Journal of Machine Learning Research 3 (2003) 1137C1155
+* Journal of Machine Learning Research 3 (2003) 1137?155
 *
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-06-22
*/ */
@@ -469,6 +469,10 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
 /* update model parameters */
 Update(model, grad, learningRate, false);
+/* get probabilities */
+float prob = GetProb(output, gold);
+loss -= prob;
 }
 else{
 /* gradient = 0 */
@@ -480,23 +484,19 @@ void Train(const char * train, bool isShuffled, FNNModel &model)
 ForwardAutoDiff(ngrams, ngramNum, output, model);
 /* this is implemented by multiply function */
-//ForwardAutoDiff(inputs, output, model);
 lossTensor = CrossEntropy(output, gold);
 /* automatic differentiation */
 autoDiffer.Backward(lossTensor);
-//autoDiffer.Backward(output, gold, CROSSENTROPY);
 /* update model parameters */
 Update(model, grad, learningRate, true);
+/* get probabilities */
+float prob = ReduceSumAll(lossTensor);
+loss += prob;
 }
-/* get probabilities */
-float prob = GetProb(output, gold);
-prob = ReduceSumAll(lossTensor);
-loss += prob;
 wordCount += ngramNum;
 wordCountTotal += ngramNum;
@@ -579,9 +579,6 @@ void Update(FNNModel &model, FNNModel &grad, float epsilon, bool isNodeGrad)
 XTensor * para = (XTensor*)paraList.GetItem(i);
 XTensor * paraGrad = (XTensor*)gradList.GetItem(i);
-//fprintf(stderr, "%d\n", i);
-//paraGrad->Dump(stderr, "grad:", 10);
 /* the delta rule */
 _Sum(para, paraGrad, para, -epsilon);
 }
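The _Sum call above is the delta rule named in the comment: it adds (-epsilon) times the gradient to each parameter, i.e.

\theta \leftarrow \theta - \epsilon\, \frac{\partial E}{\partial \theta}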
@@ -600,14 +597,14 @@ float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs)
 InitTensorV2(&probs, &output);
 /* probs[i,j] = output[i,j] * gold[i,j] */
-_Multiply(&output, &gold, &probs);
+Multiply(output, gold, probs);
 /* probability of each word */
 XTensor wprobs;
 InitTensor1DV2(&wprobs, output.GetDim(0), output.dataType, output.devID);
-_ReduceSum(&probs, &wprobs, 1);
+ReduceSum(probs, wprobs, 1);
 if(wordProbs != NULL)
-_CopyValues(&wprobs, wordProbs);
+CopyValues(wprobs, *wordProbs);
 /* reshape the tensor to fit it into the reduce procedure
 TODO: XTensor supports scalars */
@@ -619,7 +616,7 @@ float GetProb(XTensor &output, XTensor &gold, XTensor * wordProbs)
 /* probability for the batch */
 XTensor result;
 InitTensor1DV2(&result, 1, X_FLOAT, output.devID);
-_ReduceSum(&probs, &result, 1);
+ReduceSum(probs, result, 1);
 return result.Get1D(0);
 }
@@ -784,7 +781,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
 /* generate word embedding of position i:
 embedding = input * w */
-_MatrixMul(&input, X_NOTRANS, &w, X_NOTRANS, &embedding);
+MatrixMul(input, X_NOTRANS, w, X_NOTRANS, embedding);
 eList.Add(&net.embeddings[i]);
 }
@@ -792,7 +789,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
 /* concatenate word embeddings
 embeddingcat = cat(embedding_0...embedding_{n-1}) */
 InitModelTensor2D(net.embeddingCat, batchSize, (n - 1) * model.eSize, model);
-_Concatenate(&eList, &net.embeddingCat, 1);
+Concatenate(eList, net.embeddingCat, 1);
 /* go over each hidden layer */
 for(int i = 0; i < depth; i++){
@@ -807,22 +804,22 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
 /* generate hidden states of layer i:
 s = h_pre * w */
-_MatrixMul(&h_pre, X_NOTRANS, &w, X_NOTRANS, &s);
+MatrixMul(h_pre, X_NOTRANS, w, X_NOTRANS, s);
 /* make a 2d tensor for the bias term */
 XTensor b2D;
 InitTensorV2(&b2D, &s);
-_Unsqueeze(&b, &b2D, 0, batchSize);
+Unsqueeze(b, b2D, 0, batchSize);
 /* introduce bias term:
 s = s + b
 NOTE: the trick here is to extend b to a 2d tensor
 to fit into the 2d representation in tensor summation */
-_Sum(&s, &b2D, &s);
+Sum(s, b2D, s);
 /* pass the state through the hard tanh function:
 h = tanh(s) */
-_HardTanH(&s, &h);
+HardTanH(s, h);
 }
 /* generate the output Pr(w_{n-1}|w_0...w_{n-2}):
@@ -840,16 +837,16 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
 InitModelTensor2D(y, batchSize, model.vSize, model);
 /* s = h_last * w */
-_MatrixMul(&h_last, X_NOTRANS, &w, X_NOTRANS, &s);
+MatrixMul(h_last, X_NOTRANS, w, X_NOTRANS, s);
 XTensor b2D;
 InitTensorV2(&b2D, &s);
-_Unsqueeze(&b, &b2D, 0, batchSize);
-_Sum(&s, &b2D, &s);
+Unsqueeze(b, b2D, 0, batchSize);
+Sum(s, b2D, s);
 /* y = softmax(s) */
-_LogSoftmax(&s, &y, 1);
+LogSoftmax(s, y, 1);
 }
 }
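Putting the calls above together, the forward pass of the FNNLM is (one shared embedding matrix, per-layer weights and biases; hardtanh and log-softmax applied row-wise):

e = \bigl[\, x_0 W_{\mathrm{emb}};\ \dots;\ x_{n-2} W_{\mathrm{emb}} \,\bigr],
\qquad
h^{(0)} = e,\quad h^{(i)} = \mathrm{hardtanh}\bigl(h^{(i-1)} W^{(i)} + b^{(i)}\bigr),
\qquad
y = \mathrm{logsoftmax}\bigl(h^{(L)} W_{\mathrm{out}} + b_{\mathrm{out}}\bigr)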
@@ -891,18 +888,18 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME
 x is the top most hidden layer)
 so we know
 dE/dw = x^T * dE/ds */
-_MatrixMul(&x, X_TRANS, &deds, X_NOTRANS, &dedw);
+MatrixMul(x, X_TRANS, deds, X_NOTRANS, dedw);
 /* gradient of the bias: dE/db = dE/ds * 1 = dE/ds
 specifically dE/db_{j} = \sum_{i} dE/ds_{i,j} */
-_ReduceSum(&deds, &dedb, 0);
+ReduceSum(deds, dedb, 0);
 /* then, we compute
 dE/dx_{j} = \sum_j' (dE/ds_{j'} * ds_{j'}/dx_j)
 = \sum_j' (dE/ds_{j'} * w_{j, j'})
 i.e.,
 dE/dx = dE/ds * w^T */
-_MatrixMul(&deds, X_NOTRANS, &w, X_TRANS, &dedx);
+MatrixMul(deds, X_NOTRANS, w, X_TRANS, dedx);
 XTensor &gradPassed = dedx;
 XTensor dedsHidden;
@@ -927,20 +924,20 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME
 /* backpropagation through the activation fucntion:
 dE/ds = dE/dh * dh/ds */
-_HardTanHBackward(NULL, &h, &s, &dedh, &deds, NOLOSS);
+_HardTanHBackward(&h, &s, &dedh, &deds);
 /* gradient of the weight: dE/dw = x^T * dE/ds */
-_MatrixMul(&x, X_TRANS, &deds, X_NOTRANS, &dedw);
+MatrixMul(x, X_TRANS, deds, X_NOTRANS, dedw);
 /* gradient of the bias: dE/db = dE/ds * 1 = dE/ds
 specifically dE/db_{j} = \sum_{i} dE/ds_{i,j} */
-_ReduceSum(&deds, &dedb, 0);
+ReduceSum(deds, dedb, 0);
 /* gradient of the input: dE/dx = dE/ds * w^T */
-_MatrixMul(&deds, X_NOTRANS, &w, X_TRANS, &dedx);
+MatrixMul(deds, X_NOTRANS, w, X_TRANS, dedx);
 if (i > 0)
-_CopyValues(&dedx, &gradPassed);
+CopyValues(dedx, gradPassed);
 }
 TensorList eList(n - 1);
@@ -955,7 +952,7 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME
 XTensor &dedyCat = depth > 0 ? dedxBottom : dedx;
 /* split the concatenation of gradients of the embeddings */
-_Split(&dedyCat, &eList, 1, n - 1);
+Split(dedyCat, eList, 1, n - 1);
 /* go over for each word */
 for (int i = 0; i < n - 1; i++) {
@@ -966,7 +963,7 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NAME
 /* gradient of the embedding weight: dE/dw += x^T * dE/dy
 NOTE that we accumulate dE/dw here because the matrix w
 is shared by several layers (or words) */
-_MatrixMul(&x, X_TRANS, dedy, X_NOTRANS, &dedw, 1.0F, 1.0F);
+MatrixMul(x, X_TRANS, *dedy, X_NOTRANS, dedw, 1.0F, 1.0F);
 delete dedy;
 }
@@ -1171,9 +1168,10 @@ void Test(const char * test, const char * result, FNNModel &model)
 else {
 /* this is implemented by gather function */
 ForwardAutoDiff(ngrams, ngramNum, output, model);
+output = Log(output);
 /* this is implemented by multiply function */
 //ForwardAutoDiff(inputs, output, model);
 }
 /* prediction probabilities */
@@ -1201,6 +1199,7 @@ void Test(const char * test, const char * result, FNNModel &model)
 }
 fclose(file);
+fclose(ofile);
 double elapsed = GetClockSec() - startT;
......
@@ -297,7 +297,7 @@ void T2TSearch::Generate(T2TStateBundle * beam)
 row means a previous state. The column number is size-of-beam \times vocab-size. We,
 therefore, divide entries of the top-k index by vocab-size to compute the id of the
 previous state for each hypothesis in the top-k list. */
-Descale(preID, sizeVocab);
+DescaleMe(preID, sizeVocab);
 /* Then, we do something similar to "preID". For the top-k predictions, we need
 to know their indices in the vocabulary. We compute the offset of each prediction
@@ -311,13 +311,13 @@ void T2TSearch::Generate(T2TStateBundle * beam)
 CopyValues(scoreTopK, score);
 /* CPU data (TODO: remove GPU->CPU data copy!!!) */
-XTensor indexCPU;
-InitTensorV2(&indexCPU, index.order, index.dimSize, index.dataType, -1);
-CopyValues(index, indexCPU);
-for (int i = 0; i < indexCPU.unitNum; i++)
-indexCPU.SetInt(i * stride + indexCPU.GetInt(i), i);
+XTensor indexGPU;
+indexGPU = CopyValues(index);
+//InitTensor(&indexCPU, index.order, index.dimSize, index.dataType, index.denseRatio, -1);
+//CopyValues(index, indexCPU);
+for (int i = 0; i < indexGPU.unitNum; i++)
+indexGPU.SetInt(i * stride + indexGPU.GetInt(i), i);
 CheckNTErrors(XTensor::IsSameShaped(&prob, &probPath), "Wrong tensor shape!");
@@ -338,8 +338,8 @@ void T2TSearch::Generate(T2TStateBundle * beam)
 prob.Reshape(1, prob.unitNum);
 probTopK.Reshape(1, probTopK.unitNum);
-_Gather(&probPath, &probPathTopK, probPathTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
-_Gather(&prob, &probTopK, probTopK.order - 1, (int*)indexCPU.data, indexCPU.unitNum);
+_CopyIndexed(&probPath, &probPathTopK, probPathTopK.order - 1, &indexGPU);
+_CopyIndexed(&prob, &probTopK, probTopK.order - 1, &indexGPU);
 probPath.Reshape(order, dims);
 probPathTopK.Reshape(order, dimsTopK);
......
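The index arithmetic above can be illustrated on plain integers. The values below (sizeVocab, the top-k indices, the within-row offsets) are made up purely for illustration; only the divide/modulo split and the i * stride + offset rebasing mirror what the code does:

#include <cstdio>

int main()
{
    const int sizeVocab = 100;                 /* hypothetical vocabulary size */

    /* flat top-k indices over a (size-of-beam x vocab-size) score matrix */
    int topK[4] = { 215, 7, 308, 199 };
    for (int i = 0; i < 4; i++) {
        int prevState = topK[i] / sizeVocab;   /* what Descale/DescaleMe computes */
        int wordId    = topK[i] % sizeVocab;   /* offset of the prediction in the vocabulary */
        printf("hypothesis %d: previous state %d, word %d\n", i, prevState, wordId);
    }

    /* the SetInt loop follows the same pattern: a within-row offset becomes
       an offset into the flattened tensor via i * stride + offset */
    const int stride = sizeVocab;
    int index[4] = { 15, 7, 8, 99 };           /* made-up within-row offsets */
    for (int i = 0; i < 4; i++)
        index[i] = i * stride + index[i];

    return 0;
}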
@@ -60,7 +60,7 @@ TENSOR_DATA_TYPE GetDataType(const char * typeName)
 }
 }
-/****************************************************
+/*
 Below is for calling CPU BLAS for fast matrix operations
 I'm not sure how fast it is. But it seems that other
 guys are crazy about this. So I decided to have a try.
@@ -81,35 +81,4 @@ _XINLINE_ float Float16ToFloat(unsigned short h)
 return f;
 }
-/*
-data type conversion
->> devID - device id
->> s - source data array
->> typeS - source data type
->> t - target data array
->> typeT - target data type
->> size - number of the items in s (and t)
-*/
-void ConvertDataType(int devID, void * s, TENSOR_DATA_TYPE typeS, void * t, TENSOR_DATA_TYPE typeT, int size)
-{
-CheckNTErrors((devID < 0), "This code must be run on CPUs!");
-if(typeS == typeT)
-return;
-if(typeS == X_FLOAT && typeT == X_FLOAT16){
-for(int i = 0; i < size; i++){
-((unsigned short*)t)[i] = FloatToFloat16(((float*)s)[i]);
-}
-}
-else if(typeS == X_FLOAT16 && typeT == X_FLOAT){
-for(int i = 0; i < size; i++){
-((float*)t)[i] = Float16ToFloat(((unsigned short*)s)[i]);
-}
-}
-else{
-ShowNTErrors("Unsupported data types for conversion!");
-}
-}
 } /* end of the nts (NiuTrans.Tensor) namespace */
@@ -49,15 +49,6 @@ extern TENSOR_DATA_TYPE GetDataType(const char * typeName);
 /* data conversion (for lower precision computation) */
 unsigned short FloatToFloat16(float f);
 float Float16ToFloat(unsigned short h);
-void ConvertDataType(int devID,
-                     void * s, TENSOR_DATA_TYPE typeS,
-                     void * t, TENSOR_DATA_TYPE typeT, int size);
-
-#ifdef USE_CUDA
-void CudaConvertDataType(int devID,
-                         void * s, TENSOR_DATA_TYPE typeS,
-                         void * t, TENSOR_DATA_TYPE typeT, int size);
-#endif
 } /* end of the nts (NiuTrans.Tensor) namespace */
......
@@ -51,7 +51,13 @@ bool CONST_TRUE = true;
 int verboseLevel = 0;
 bool useBLAS = false;
-bool useCUDA = false;
+
+#ifdef USE_CUDA
+bool useCUDA = true;
+#else
+bool useCUDA = false;
+#endif
 FILE * tmpLog = NULL;
 double myTime = 0;
......
@@ -59,6 +59,8 @@ const char * GetOPName(int type)
 return "M_DIV";
 else if (type == MATH_DIVDIM)
 return "M_DIVDIM";
+else if (type == MATH_MASK)
+return "M_MASK";
 else if (type == MATH_MATRIXMUL)
 return "M_MATRIXMUL";
 else if (type == MATH_MATRIXMULBATCHED)
......
@@ -48,7 +48,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 #define MATH_CLIP MATH_ROUND + 1
 #define MATH_DIV MATH_CLIP + 1
 #define MATH_DIVDIM MATH_DIV + 1
-#define MATH_MATRIXMUL MATH_DIVDIM + 1
+#define MATH_MASK MATH_DIVDIM + 1
+#define MATH_MATRIXMUL MATH_MASK + 1
 #define MATH_MATRIXMULBATCHED MATH_MATRIXMUL + 1
 #define MATH_MULTIPLY MATH_MATRIXMULBATCHED + 1
 #define MATH_MULTIPLYDIM MATH_MULTIPLY + 1
@@ -79,7 +80,8 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 /* data and shape related operations */
 #define DATA_BASE MATH_BASE * 2
 #define GETANDSET DATA_BASE + 1
-#define GETANDSET_SELECT GETANDSET + 1
+#define GETANDSET_CONVERTDATATYPE GETANDSET + 1
+#define GETANDSET_SELECT GETANDSET_CONVERTDATATYPE + 1
 #define MOVEMENT GETANDSET_SELECT + 1
 #define MOVEMENT_COPYINDEXED MOVEMENT + 1
......
@@ -48,6 +48,7 @@
 #include "core/math/ScaleAndShift.h"
 #include "core/getandset/SetData.h"
 #include "function/Identity.h"
+#include "core/CHeader.h"
 #ifdef USE_CUDA
@@ -485,6 +486,12 @@ XTensor XTensor::operator- (const DTYPE shift) const
 return ScaleAndShift(*this, 1, -shift);
 }
+/* overloading of the minus-sign */
+XTensor XTensor::operator- () const
+{
+return Negate(*this);
+}
 /* overloading of the division-sign */
 XTensor XTensor::operator/ (const XTensor& tensor) const
 {
@@ -837,6 +844,12 @@ void XTensor::SetData(const void * d, int num, int beg)
 XMemCopy((char*)data + beg * unitSize, devID, d, -1, num * unitSize);
 }
+/* generate data items with a uniform distribution in [0, 1] */
+void XTensor::Rand(int rNum, int cNum)
+{
+_SetDataRand(this, rNum, cNum);
+}
 /*
 set the tensor items by a uniform distribution in range [lower, upper]
 >> lower - lower value of the range
@@ -2425,7 +2438,7 @@ initialize a dense 5d tensor V2
 */
 void InitTensor5DV2(XTensor * tensor, const int d0, const int d1, const int d2, const int d3, const int d4,
 const TENSOR_DATA_TYPE myDataType, const int myDevID)
 {
 int dims[5];
 dims[0] = d0;
......
@@ -238,6 +238,9 @@ public:
 /* overloading of the minus-sign */
 XTensor operator- (const DTYPE shift) const;
+/* overloading of the minus-sign */
+XTensor operator- () const;
 /* overloading of the division-sign */
 XTensor operator/ (const XTensor &tensor) const;
@@ -301,6 +304,9 @@ public:
 /* set the tensor with an data array */
 void SetData(const void * d, int num, int beg = 0);
+/* generate data items with a uniform distribution in [0, 1] */
+void Rand(int rNum, int cNum);
 /* set tensor items by a uniform distribution */
 void SetDataRand(DTYPE lower = 0.0F, DTYPE upper = 1.0F);
@@ -497,7 +503,7 @@ void InitTensor5D(XTensor * tensor, const int d0, const int d1, const int d2,
 /* initialize a dense 5d tensor V2 */
 void InitTensor5DV2(XTensor * tensor, const int d0, const int d1, const int d2, const int d3, const int d4,
 const TENSOR_DATA_TYPE myDataType = X_FLOAT, const int myDevID = -1);
 /* initialize a tensor with a reference tensor */
 void InitTensor(XTensor * tensor, const XTensor * reference);
......
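A minimal usage sketch for the two members added above, operator-() and Rand(). It assumes the usual NiuTrans.Tensor setup (InitTensor2D, Dump and the CPU device id -1 come from the library's existing API); the tensor shape is arbitrary:

#include "XTensor.h"
using namespace nts;

void NegateAndRandExample()
{
    XTensor a;
    InitTensor2D(&a, 3, 4, X_FLOAT, -1);   /* a dense 3 x 4 tensor on the CPU */

    /* fill a with uniformly distributed values in [0, 1] (the new Rand member) */
    a.Rand(3, 4);

    /* the new unary minus routes through Negate, i.e. b[i] = -a[i] */
    XTensor b = -a;

    b.Dump(stderr, "negated:");
}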
@@ -36,13 +36,9 @@
 #include "arithmetic/MatrixMulBatched.h"
 #include "arithmetic/Multiply.h"
 #include "arithmetic/MultiplyDim.h"
-#include "arithmetic/Negate.h"
-#include "arithmetic/Sign.h"
 #include "arithmetic/Sub.h"
 #include "arithmetic/SubDim.h"
 #include "arithmetic/Sum.h"
-#include "arithmetic/SumByColumnTV.h"
-#include "arithmetic/SumByColumnVT.h"
 #include "arithmetic/SumDim.h"
 #include "arithmetic/XTensorBLAS.h"
 #include "arithmetic/MulAndShift.h"
@@ -56,7 +52,6 @@
 #include "math/Clip.h"
 #include "math/Compare.h"
 #include "math/Normalize.h"
-#include "math/Power.h"
 #include "math/ScaleAndShift.h"
 #include "math/Unary.h"
@@ -97,5 +92,4 @@
 #include "utilities/XMatrixSegment.h"
 #include "utilities/FlushToMem.h"
-#include "../function/DropoutWithIndex.h"
 #endif // __CHEADER_H__
@@ -151,16 +151,35 @@ XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha)
 XTensor c(&a);
 c.SetTMPFlag();
-/* call _Sum function */
+/* call _Mask function */
 _Mask(&a, &mask, &c, alpha);
 /* tensor connections */
-//XLink::MakeLink(&a, &mask, &c, MATH_SUM);
-//XLink::AddParamToHead(&c, alpha);
-// TODO!!
-ShowNTErrors("TODO!");
+XLink::MakeLink(&a, &mask, &c, MATH_MASK);
+XLink::AddParamToHead(&c, alpha);
 return c;
 }
+/*
+mask entries of a given tensor (return an XTensor structure):
+a(i) = a(i) if mask(i) is non-zero
+a(i) = alpha if mask(i) = 0
+where i is the index of the element
+*/
+void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha)
+{
+if (!c.isInit || !XTensor::IsSameShaped(&a, &c)) {
+InitTensor(&c, &a);
+}
+/* call _Mask function */
+_Mask(&a, &mask, &c, alpha);
+if (c.enableGrad) {
+XLink::MakeLink(&a, &mask, &c, MATH_MASK);
+XLink::AddParamToHead(&c, alpha);
+}
+}
 }
\ No newline at end of file
@@ -34,7 +34,7 @@ c(i) = a(i) if mask(i) is non-zero
 c(i) = alpha if mask(i) = 0
 where i is the index of the element
 */
-void _Mask(const XTensor * a, const XTensor * mask, XTensor * c, DTYPE alpha);
+void _Mask(const XTensor * a, const XTensor * mask, XTensor * c, DTYPE alpha = 0.0);
 /*
 mask entries of a given tensor (on site):
@@ -42,10 +42,10 @@ a(i) = a(i) if mask(i) is non-zero
 a(i) = alpha if mask(i) = 0
 where i is the index of the element
-void _MaskMe(XTensor * a, const XTensor * mask, DTYPE alpha);
-void MaskMe(XTensor & a, const XTensor & mask, DTYPE alpha);
+void _MaskMe(XTensor * a, const XTensor * mask, DTYPE alpha = 0.0);
+void MaskMe(XTensor & a, const XTensor & mask, DTYPE alpha = 0.0);
 /*
 mask entries of a given tensor (return an XTensor structure):
 a(i) = a(i) if mask(i) is non-zero
 a(i) = alpha if mask(i) = 0
@@ -53,6 +53,14 @@ where i is the index of the element
 */
 XTensor Mask(const XTensor &a, const XTensor &mask, DTYPE alpha = 0.0);
+/*
+mask entries of a given tensor (return an XTensor structure):
+a(i) = a(i) if mask(i) is non-zero
+a(i) = alpha if mask(i) = 0
+where i is the index of the element
+*/
+void Mask(const XTensor &a, const XTensor &mask, XTensor &c, DTYPE alpha = 0.0);
 } // namespace nts(NiuTrans.Tensor)
 #endif // __MASK_H__
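The semantics documented in this header are simple enough to state as a reference loop. The sketch below only illustrates the element-wise rule for dense float data; it is not the library's _Mask implementation (which also covers GPU tensors and other data types):

/* reference semantics of the mask operation on plain arrays:
   c[i] = a[i]   if mask[i] is non-zero
   c[i] = alpha  if mask[i] == 0 */
void MaskReference(const float * a, const int * mask, float * c, int size, float alpha)
{
    for (int i = 0; i < size; i++)
        c[i] = (mask[i] != 0) ? a[i] : alpha;
}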
@@ -202,7 +202,9 @@ void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
 delete cList;
 }
-bool CheckMMulShape(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c)
+bool CheckMMulShape(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
+                    const XTensor * b, MATRIX_TRANS_TYPE transposedB,
+                    XTensor * c)
 {
 if (!(a && b && c))
 return false;
@@ -231,10 +233,13 @@ bool CheckMMulShape(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
 dimSize[sub++] = bm;
 for (int i = 0; i < order; i++) {
-if (dimSize[i] != c->dimSize[i])
+if (dimSize[i] != c->dimSize[i]) {
+delete[] dimSize;
 return false;
+}
 }
+delete[] dimSize;
 return true;
 }
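The two delete[] calls added above close a leak: the early return false used to drop the heap-allocated dimSize buffer. An alternative, shown only as a sketch (not what this commit does), is to let a standard container own the buffer so every return path cleans up automatically:

#include <vector>

/* sketch: the same kind of shape check with automatic cleanup via std::vector */
bool ShapeMatches(const std::vector<int> & expected, const int * actual, int order)
{
    for (int i = 0; i < order; i++) {
        if (expected[i] != actual[i])
            return false;      /* nothing to free; the vector releases its storage */
    }
    return true;
}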
@@ -303,8 +308,8 @@ XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
 }
 void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
 const XTensor &b, MATRIX_TRANS_TYPE transposedB, XTensor &c,
-DTYPE alpha, XPRunner * parallelRunner)
+DTYPE alpha, DTYPE beta, XPRunner * parallelRunner)
 {
 CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
 CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
@@ -337,7 +342,7 @@ void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
 }
 /* call _MatrixMul function */
-_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, 0, parallelRunner);
+_MatrixMul(&a, transposedA, &b, transposedB, &c, alpha, beta, parallelRunner);
 if (c.enableGrad) {
 /* tensor connections */
@@ -400,7 +405,7 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
 }
 void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
 DTYPE alpha, XPRunner * parallelRunner)
 {
 CheckNTErrors(a.dataType == b.dataType, "Input tensors should have the same data type!");
 CheckNTErrors(a.order >= 2 && b.order >= 2, "Input tensors must have a order >= 2!");
......
@@ -40,8 +40,11 @@ bj is the j-th element tensor of B, and c_{i,j} is the (i,j) element tensor of the result C
 C should be a tensor of z * x * n * m.
 Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
 */
-void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA, const XTensor * b, MATRIX_TRANS_TYPE transposedB, XTensor * c,
-                DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0, XPRunner * parallelRunner = NULL);
+void _MatrixMul(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
+                const XTensor * b, MATRIX_TRANS_TYPE transposedB,
+                XTensor * c,
+                DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0,
+                XPRunner * parallelRunner = NULL);
 /*
 matrix multiplication (return an XTensor structure) c = trans(a) * trans(b) * alpha
@@ -56,11 +59,16 @@ bj is the j-th element tensor of B, and c_{i,j} is the (i,j) element tensor of the result C
 C should be a tensor of z * x * n * m.
 Obviously C = A * B performs normal matrix multiplication if A = y * z and B = x * y.
 */
-XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
-                  DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
+XTensor MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
+                  const XTensor &b, MATRIX_TRANS_TYPE transposedB,
+                  DTYPE alpha = (DTYPE)1.0,
+                  XPRunner * parallelRunner = NULL);
-void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA, const XTensor &b, MATRIX_TRANS_TYPE transposedB,
-               XTensor &c, DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
+void MatrixMul(const XTensor &a, MATRIX_TRANS_TYPE transposedA,
+               const XTensor &b, MATRIX_TRANS_TYPE transposedB,
+               XTensor &c,
+               DTYPE alpha = (DTYPE)1.0, DTYPE beta = 0,
+               XPRunner * parallelRunner = NULL);
 /* matrix multiplication with no transposition c = a * b * alpha*/
 XTensor MatrixMul(const XTensor &a, const XTensor &b,
@@ -69,7 +77,6 @@ XTensor MatrixMul(const XTensor &a, const XTensor &b,
 void MatrixMul(const XTensor &a, const XTensor &b, XTensor &c,
 DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
 } // namespace nts(NiuTrans.Tensor)
 #endif // __MATRIXMUL_H__
\ No newline at end of file
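A small usage sketch of the declarations above, for the plain 2-D case c = a * b (the shapes are made up; InitTensor2D and SetDataRand come from the library's existing API, and the calls use only the signatures shown in this header, with alpha left at its default of 1.0):

#include "MatrixMul.h"
using namespace nts;

void MatrixMulExample()
{
    XTensor a, b, c;
    InitTensor2D(&a, 2, 3, X_FLOAT, -1);   /* 2 x 3, on the CPU */
    InitTensor2D(&b, 3, 4, X_FLOAT, -1);   /* 3 x 4 */
    a.SetDataRand();
    b.SetDataRand();

    /* c = trans(a) * trans(b) * alpha with no transposition: a 2 x 4 result */
    c = MatrixMul(a, X_NOTRANS, b, X_NOTRANS);

    /* the overload with an explicit output tensor; it now also accepts beta */
    MatrixMul(a, X_NOTRANS, b, X_NOTRANS, c);
}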
@@ -154,7 +154,7 @@ void _MatrixMulBatchedCPU(const XTensor * a, MATRIX_TRANS_TYPE transposedA,
 const XTensor * b, MATRIX_TRANS_TYPE transposedB,
 XTensor * c, DTYPE alpha, DTYPE beta)
 {
-CheckNTErrors((a && b && c), "Empty input tensors!");
+CheckNTErrors(a && b && c, "Empty input tensors!");
 CheckNTErrors(a->dataType == b->dataType && a->dataType == c->dataType,
 "Input tensors should have the same data type!");
 CheckNTErrors(a->order >= 2 && b->order >= 2 && c->order >= 2,
......
@@ -66,7 +66,7 @@ operation c = x * w + b MulAndShift
 << return - the result of matrix multiplication
 */
 XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
 DTYPE alpha, XPRunner * parallelRunner)
 {
 CheckNTErrors(x.dataType == w.dataType, "Input tensors should have the same data type!");
 CheckNTErrors(x.order >= 2 && w.order >= 2, "Input tensors must have a order >= 2!");
@@ -129,9 +129,6 @@ XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
 DelTensorBuf(tmp);
 return c;
 }
 }
\ No newline at end of file
@@ -29,7 +29,7 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
 XTensor MulAndShift(const XTensor &x, const XTensor &w, const XTensor &b,
 DTYPE alpha = (DTYPE)1.0, XPRunner * parallelRunner = NULL);
 } // namespace nts(NiuTrans.Tensor)
......
@@ -123,9 +123,9 @@ where i is the item index
 void _CudaMultiply(const XTensor * a, const XTensor * b, XTensor * c, DTYPE alpha, int leadingDim)
 {
 int leadingDimRDI = a->order - leadingDim - 1;
-CheckNTErrors((a->unitNum <= c->unitNum && b->unitNum <= c->unitNum),
+CheckNTErrors(a->unitNum <= c->unitNum && b->unitNum <= c->unitNum,
 "Unmatched tensors in multiplication!");
-CheckNTErrors((a->order == b->order && a->order == c->order), "Unmatched tensors!");
+CheckNTErrors(a->order == b->order && a->order == c->order, "Unmatched tensors!");
 int stride = 1;
 int blockSizeA = 1;
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "Negate.h"
#include "Negate.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
set every entry to its minus value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void _Negate(const XTensor * a, XTensor * b)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
_CudaNegate(a, b);
return;
}
#endif
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * d = (DTYPE*)a->data;
DTYPE * db = (DTYPE*)b->data;
for (int i = 0; i < a->unitNum; i++)
db[i] = -d[i];
}
/*
set every entry to its minus value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void _NegateMe(XTensor * a)
{
_Negate(a, a);
}
/*
set every entry to its minus value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void NegateMe(XTensor& a)
{
_Negate(&a, &a);
}
/*
set every entry to its minus value (return an XTensor structure)
make a new tensor to keep the result and return it
>> a - input tensor we are processing
<< return - the minus value of input tensor
*/
XTensor Negate(const XTensor & a)
{
XTensor b(&a);
b.SetTMPFlag();
/* call _Negate function */
_Negate(&a, &b);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_NEGATE);
return b;
}
/*
set every entry to its minus value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void Negate(const XTensor & a, XTensor & b)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _Negate function */
_Negate(&a, &b);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_NEGATE);
}
}
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Negate.h"
#include "Negate.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
set each entry to its negtive value (CUDA Kernel)
>> a - pointer to the input data array
>> b - pointer to the output data array
>> size - size of the data array
*/
__global__
void KernelNegate(DTYPE * a, DTYPE * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
b[i] = -a[i];
}
/*
set each entry to its negtive value (CUDA Kernel)
This is for float16 computation
>> a - pointer to the input data array
>> b - pointer to the output data array
>> size - size of the data array
*/
__global__
void KernelNegate(__half * a, __half * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
if (i < size)
b[i] = __hsub(__float2half(0), a[i]);
#else
if (i < size)
b[i] = __float2half(-__half2float(a[i]));
#endif
}
/*
set each entry to its negtive value
>> a - input tensor
>> b - output tensor
*/
void _CudaNegate(const XTensor * a, XTensor * b)
{
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->isSparse == false), "TODO!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
KernelNegate << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum);
}
else if (a->dataType == X_FLOAT16) {
KernelNegate << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum);
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __NEGATE_CUH__
#define __NEGATE_CUH__
#include "Negate.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its negative value (CUDA Kernel) */
__global__
void KernelNegate(DTYPE * a, DTYPE * b, int size);
/* set each entry to its negative value (CUDA Kernel) with float16 data type */
__global__
void KernelNegate(__half * a, __half * b, int size);
/* set each entry to its negative value */
void _CudaNegate(const XTensor * a, XTensor * b);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __NEGATE_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __NEGATE_H__
#define __NEGATE_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its minus value */
void _Negate(const XTensor * a, XTensor * b);
/*
set every entry to its minus value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _NegateMe(XTensor * a);
void NegateMe(XTensor & a);
/*
set every entry to its minus value (return an XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Negate(const XTensor & a);
/* set every entry to its minus value */
void Negate(const XTensor & a, XTensor & b);
} // namespace nts(NiuTrans.Tensor)
#endif // __NEGATE_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "Sign.h"
#include "Sign.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
set every entry to its sign value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void _Sign(const XTensor * a, XTensor * b)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
_CudaSign(a, b);
return;
}
#endif
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * d = (DTYPE*)a->data;
DTYPE * db = (DTYPE*)b->data;
for (int i = 0; i < a->unitNum; i++) {
if (d[i] > 0)
db[i] = 1.0F;
else if (d[i] == 0)
db[i] = 0.0F;
else
db[i] = -1.0F;
}
}
/*
set every entry to its sign value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void _SignMe(XTensor * a)
{
_Sign(a, a);
}
/*
set every entry to its sign value (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor we are processing
*/
void SignMe(XTensor& a)
{
_Sign(&a, &a);
}
/*
set every entry to its sign value (return an XTensor structure)
make a new tensor to keep the result and return it
>> a - input tensor we are processing
<< return - the sign value of the input tensor
*/
XTensor Sign(const XTensor & a)
{
XTensor b(&a);
b.SetTMPFlag();
/* call _Sign function */
_Sign(&a, &b);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SIGN);
return b;
}
/*
set every entry to its sign value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void Sign(const XTensor & a, XTensor & b)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _Sign function */
_Sign(&a, &b);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_SIGN);
}
}
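/*
A minimal usage sketch of the Sign API above. It is illustrative only and
excluded from compilation; InitTensor2D and _SetDataRand come from other
NiuTrans.Tensor headers.
*/
#if 0
void ExampleSign()
{
    XTensor a;
    InitTensor2D(&a, 2, 3);
    _SetDataRand(&a, -1.0F, 1.0F);

    XTensor s = Sign(a);              /* s[i] is 1.0, 0.0 or -1.0 */

    SignMe(a);                        /* on site */
}
#endif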
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "Sign.h"
#include "Sign.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
set each entry to its sign value (CUDA Kernel)
>> a - pointer to input data array
>> b - pointer to output data array
>> size - size of the data array
*/
__global__
void KernelSign(DTYPE * a, DTYPE * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
if (a[i] > 0)
b[i] = 1.0F;
else if (a[i] == 0)
b[i] = 0.0F;
else
b[i] = -1.0F;
}
}
/*
set each entry to its sign value with float16 data type value (CUDA Kernel)
This is for float16 computation
>> a - pointer to input data array
>> b - pointer to output data array
>> size - size of the data array
*/
__global__
void KernelSign(__half * a, __half * b, int size)
{
/* half-precision sign is not implemented yet; this kernel is an empty stub */
return;
}
/*
set each entry to its sign value
>> a - input tensor we are processing
>> b - output tensor we are processing
*/
void _CudaSign(const XTensor * a, XTensor * b)
{
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
CheckNTErrors((a->isSparse == false), "TODO!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
KernelSign << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum);
}
else if (a->dataType == X_FLOAT16) {
KernelSign << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum);
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#ifndef __SIGN_CUH__
#define __SIGN_CUH__
#include "Sign.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its sign value (CUDA Kernel) */
__global__
void KernelSign(DTYPE * a, DTYPE * b, int size);
/* set each entry to its sign value (CUDA Kernel) with float16 data type*/
__global__
void KernelSign(__half * a, __half * b, int size);
/* set each entry to its sign value */
void _CudaSign(const XTensor * a, XTensor * b);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __SIGN_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#ifndef __SIGN_H__
#define __SIGN_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* set every entry to its sign value */
void _Sign(const XTensor * a, XTensor * b);
/*
set every entry to its sign value (do it on site)
keep the result in the input tensor a and return nothing
*/
void _SignMe(XTensor * a);
/*
set every entry to its sign value (do it on site)
keep the result in the input tensor a and return nothing
*/
void SignMe(XTensor & a);
/*
set every entry to its sign value (return an XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Sign(const XTensor & a);
/* set every entry to its sign value */
void Sign(const XTensor & a, XTensor & b);
} // namespace nts(NiuTrans.Tensor)
#endif // __SIGN_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XTensor.h"
#include "SumByColumnTV.h"
#include "SumByColumnTV.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
sum of a tensor and a vector (column vector) in a column by column manner
for each column a_col (in a block), we have
c_col = a_col + b * \beta
where b is a vector.
>> a - a tensor
>> b - a vector with the same column size with a
>> c - where we put a+b. we save it in a if c is NULL
>> beta - the scaling factor
*/
void _SumByColumnTV(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((XTensor::IsSameShaped(a, c)), "Unmatched tensors in addition!");
CheckNTErrors((b->order == 2 && b->dimSizeRDI[0] == 1 && b->dimSizeRDI[1] == a->dimSizeRDI[1]),
"Illegal input vector size!");
int rowNum = a->dimSize[0];
int colNum = a->dimSize[1];
int blockNum = 1;
for (int i = 2; i < a->order; i++)
blockNum *= a->dimSizeRDI[i];
int blockSize = colNum * rowNum;
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
_CudaSumByColumnTV(a, b, c, beta);
#endif
}
else {
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "TODO!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE)
{
for (int k = 0; k < blockNum; k++) {
for (int i = 0; i < rowNum; i++) {
DTYPE * ap = (DTYPE*)a->data + k * blockSize + i * colNum;
DTYPE * bp = (DTYPE*)b->data;
DTYPE * cp = (DTYPE*)c->data + k * blockSize + i * colNum;
DTYPE v = bp[i];
for (int j = 0; j < colNum; j++)
cp[j] = ap[j] + v * beta;
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
}
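/*
A minimal usage sketch of _SumByColumnTV (illustrative only, excluded from
compilation). Per the CPU branch above, b must be a column vector whose row
number matches that of a, and b[i] * beta is added to every entry of row i.
InitTensor2D and _SetDataRand come from other headers.
*/
#if 0
void ExampleSumByColumnTV()
{
    XTensor a, b, c;
    InitTensor2D(&a, 2, 3);           /* a 2 x 3 tensor */
    InitTensor2D(&b, 2, 1);           /* a column vector with 2 rows */
    InitTensor2D(&c, 2, 3);           /* result buffer with the shape of a */
    _SetDataRand(&a, 0.0F, 1.0F);
    _SetDataRand(&b, 0.0F, 1.0F);

    _SumByColumnTV(&a, &b, &c, 1.0F); /* c[i][j] = a[i][j] + b[i] * 1.0 */
}
#endif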
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "SumByColumnTV.h"
#include "SumByColumnTV.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
summation of a tensor and a vector (column vector)
c_col = a_col + b * \beta
>> a - a tensor
>> b - a vector with the same column size with a
>> c - where we put a+b. we save it in a
>> colNum - column number (of a block)
>> blockSize - size of a block
>> size - size of the entire data array
>> beta - the scaling factor
*/
__global__
void KernelADDByColumnTV(DTYPE * a, DTYPE * b, DTYPE * c, int colNum, int blockSize, int size, DTYPE beta)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= size)
return;
int offset = i % blockSize;
int row = offset / colNum;
c[i] = a[i] + b[row] * beta;
}
/*
summation of a tensor and a vector (column vector)
for each column a_col (in a block), we have
c_col = a_col + b * \beta
where b is a vector.
>> a - a tensor
>> b - a vector with the same column size with a
>> c - where we put a+b. we save it in a if c is NULL
>> beta - the scaling factor
*/
void _CudaSumByColumnTV(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((XTensor::IsSameShaped(a, c)), "Unmatched tensors in addition!");
CheckNTErrors((b->order == 2 && b->dimSizeRDI[0] == 1 && b->dimSizeRDI[1] == a->dimSizeRDI[1]),
"Illegal input vector size!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE), "TODO");
int rowNum = a->dimSize[0];
int colNum = a->dimSize[1];
int blockNum = 1;
for (int i = 2; i < a->order; i++)
blockNum *= a->dimSizeRDI[i];
int cudaGridSize[3];
int cudaBlockSize[3];
GDevs.GetCudaThread(c->devID, a->unitNum, cudaGridSize, cudaBlockSize);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
KernelADDByColumnTV << <dim3(cudaGridSize[0]), dim3(cudaBlockSize[0]) >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, colNum, rowNum * colNum, a->unitNum, beta);
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __SUMBYCOLUMNTV_CUH__
#define __SUMBYCOLUMNTV_CUH__
#include "SumByColumnTV.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* summation of a tensor and a vector (column vector) */
void _CudaSumByColumnTV(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __SUMBYCOLUMNTV_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __SUMBYCOLUMNTV_H__
#define __SUMBYCOLUMNTV_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* sum of a tensor and a (column) vector */
void _SumByColumnTV(const XTensor * a, const XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUMBYCOLUMNTV_H__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XTensor.h"
#include "SumByColumnVT.h"
#include "SumByColumnVT.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
sum of a vector (column vector) and a tensor in a column by column manner
for each column b_col, we have
c = a + \sum{col} b_col * \beta
where c and a are vectors, and b_col is a column in b.
>> a - a tensor
>> b - a vector with the same column size with a
>> c - where we put a+b. we save it in a if c is NULL
>> beta - the scaling factor
*/
void _SumByColumnVT(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((XTensor::IsSameShaped(a, c)), "Unmatched tensors in addition!");
CheckNTErrors((a->order == 2 && a->dimSizeRDI[0] == 1 && b->dimSizeRDI[1] == a->dimSizeRDI[1]),
"Illegal input vector size!");
if (a->devID >= 0 || b->devID >= 0 || c->devID >= 0) {
#ifdef USE_CUDA
_CudaSumByColumnVT(a, b, c, beta);
#endif
}
else {
int rowNum = b->dimSize[0];
int colNum = b->dimSize[1];
int blockNum = 1;
for (int i = 2; i < b->order; i++)
blockNum *= b->dimSizeRDI[i];
int blockSize = colNum * rowNum;
if (!a->isSparse && !b->isSparse) {
CheckNTErrors(!c->isSparse, "TODO!");
if (a->dataType == DEFAULT_DTYPE &&
b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE)
{
for (int k = 0; k < blockNum; k++) {
for (int i = 0; i < rowNum; i++) {
DTYPE * ap = (DTYPE*)a->data;
DTYPE * bp = (DTYPE*)b->data + k * blockSize + i * colNum;
DTYPE * cp = (DTYPE*)c->data;
DTYPE sum = 0;
for (int j = 0; j < colNum; j++)
sum += bp[j];
cp[i] = ap[i] + sum * beta;
}
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
else {
// TODO!!
ShowNTErrors("TODO!");
}
}
}
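/*
A minimal usage sketch of _SumByColumnVT (illustrative only, excluded from
compilation). Per the CPU branch above, a and c are column vectors and every
row of b is summed and added to the matching entry of a, i.e.
c[i] = a[i] + beta * sum_j b[i][j]. InitTensor2D and _SetDataRand come from
other headers.
*/
#if 0
void ExampleSumByColumnVT()
{
    XTensor a, b, c;
    InitTensor2D(&a, 2, 1);           /* a column vector with 2 rows */
    InitTensor2D(&b, 2, 3);           /* a 2 x 3 tensor */
    InitTensor2D(&c, 2, 1);           /* result vector with the shape of a */
    _SetDataRand(&a, 0.0F, 1.0F);
    _SetDataRand(&b, 0.0F, 1.0F);

    _SumByColumnVT(&a, &b, &c, 1.0F); /* c[i] = a[i] + (b[i][0] + b[i][1] + b[i][2]) */
}
#endif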
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "SumByColumnVT.h"
#include "SumByColumnVT.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
summation of a vector (column vector) and a tensor
c = a + \sum{col} b_col * \beta
>> a - a vector with the same column size with b
>> b - a tensor
>> c - where we put a+b. we save it in a
>> colNum - column number (of a block)
>> rowNum - row number (of a block)
>> blockNum - number of blocks
>> beta - the scaling factor
*/
__global__
void KernelADDByColumnVT(DTYPE * a, DTYPE * b, DTYPE * c, int colNum, int rowNum, int blockNum, DTYPE beta)
{
int row = blockDim.x * blockIdx.x + threadIdx.x;
if (row >= rowNum)
return;
DTYPE sum = 0;
for (int k = 0; k < blockNum; k++) {
DTYPE * bp = b + (rowNum * k + row) * colNum;
if (colNum % 4 == 0) {
for (int i = 0; i < colNum; i += 4)
sum += bp[i] + bp[i + 1] + bp[i + 2] + bp[i + 3];
}
else if (colNum % 2 == 0) {
for (int i = 0; i < colNum; i += 2)
sum += bp[i] + bp[i + 1];
}
else {
for (int i = 0; i < colNum; i++)
sum += bp[i];
}
__syncthreads();
}
c[row] = a[row] + beta * sum;
}
/*
summation of a vector (column vector) and a tensor
for each column b_col, we have
c = a + \sum{col} b_col * \beta
where c and a are vectors, and b_col is a column in b.
>> a - a vector with the same column size with b
>> b - a tensor
>> c - where we put a+b. we save it in a if c is NULL
>> beta - the scaling factor
*/
void _CudaSumByColumnVT(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta)
{
CheckNTErrors((a && b && c), "Empty input tensors!");
CheckNTErrors((XTensor::IsSameShaped(a, c)), "Unmatched tensors in addition!");
CheckNTErrors((a->order == 2 && a->dimSizeRDI[0] == 1 && b->dimSizeRDI[1] == a->dimSizeRDI[1]),
"Illegal input vector size!");
CheckNTErrors((a->dataType == DEFAULT_DTYPE && b->dataType == DEFAULT_DTYPE &&
c->dataType == DEFAULT_DTYPE), "TODO");
int rowNum = b->dimSize[0];
int colNum = b->dimSize[1];
int blockNum = 1;
for (int i = 2; i < b->order; i++)
blockNum *= b->dimSizeRDI[i];
int cudaGridSize[3];
int cudaBlockSize[3];
GDevs.GetCudaThread(c->devID, a->dimSizeRDI[1], cudaGridSize, cudaBlockSize);
int devIDBackup = 0;
ProtectCudaDev(a->devID, devIDBackup);
KernelADDByColumnVT << <dim3(cudaGridSize[0]), dim3(cudaBlockSize[0]) >> >
((DTYPE*)a->data, (DTYPE*)b->data, (DTYPE*)c->data, colNum, rowNum, blockNum, beta);
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __SUMBYCOLUMNVT_CUH__
#define __SUMBYCOLUMNVT_CUH__
#include "SumByColumnVT.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* summation of a vector (column vector) and a tensor */
void _CudaSumByColumnVT(const XTensor * a, const XTensor * b, XTensor * c, DTYPE beta = (DTYPE)1.0);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __SUMBYCOLUMNVT_CUH__
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __SUMBYCOLUMNVT_H__
#define __SUMBYCOLUMNVT_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* sum of a (column) vector and a tensor */
void _SumByColumnVT(const XTensor * a, const XTensor * b, XTensor * c = NULL, DTYPE beta = (DTYPE)1.0);
} // namespace nts(NiuTrans.Tensor)
#endif // __SUMBYCOLUMNVT_H__
@@ -20,20 +20,55 @@
*/
#include "../../XTensor.h"
#include "../../XName.h"
#include "ConvertDataType.h"
#include "ConvertDataType.cuh"
#include "../movement/CopyValues.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
data type conversion
>> devID - device id
>> s - source data array
>> typeS - source data type
>> t - target data array
>> typeT - target data type
>> size - number of the items in s (and t)
*/
void ConvertDataType(int devID,
void * s, TENSOR_DATA_TYPE typeS,
void * t, TENSOR_DATA_TYPE typeT,
int size)
{
CheckNTErrors((devID < 0), "This code must be run on CPUs!");
if(typeS == typeT)
return;
if(typeS == X_FLOAT && typeT == X_FLOAT16){
for(int i = 0; i < size; i++){
((unsigned short*)t)[i] = FloatToFloat16(((float*)s)[i]);
}
}
else if(typeS == X_FLOAT16 && typeT == X_FLOAT){
for(int i = 0; i < size; i++){
((float*)t)[i] = Float16ToFloat(((unsigned short*)s)[i]);
}
}
else{
ShowNTErrors("Unsupported data types for conversion!");
}
}
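/*
A minimal usage sketch of the CPU conversion routine above (illustrative only,
excluded from compilation). On the CPU side, float16 values are stored as raw
unsigned short bit patterns.
*/
#if 0
void ExampleConvertRawArray()
{
    float src[4] = {1.5F, -2.0F, 0.25F, 3.0F};
    unsigned short half[4];
    float back[4];

    ConvertDataType(-1, src, X_FLOAT, half, X_FLOAT16, 4);   /* devID < 0: run on the CPU */
    ConvertDataType(-1, half, X_FLOAT16, back, X_FLOAT, 4);  /* convert back to float */
}
#endif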
/*
convert data type
>> input - the input tensor
>> output - the output tensor
*/
void _ConvertDataType(const XTensor * input, XTensor * output)
{
//CheckNTErrors((input->unitSize == output->unitSize), "Input and Output must be same in size!");
if (input->dataType == output->dataType)
return;
@@ -59,6 +94,50 @@ void _ConvertDataType(const XTensor * input, XTensor * output)
}
else
ShowNTErrors("Unsupported data types for conversion!");
}
/*
convert data type (return an XTensor structure)
make a new tensor to keep the result and return it
>> input - the input tensor
<< return - the output tensor with the specified data type
*/
XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType)
{
if (input.dataType == dataType) {
XTensor output;
output = CopyValues(input);
return output;
}
int order = input.order;
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
XTensor output(order, input.dimSize, dataType, dr, input.devID, input.mem);
output.SetTMPFlag();
_ConvertDataType(&input, &output);
/* tensor connection */
XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
return output;
}
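/*
convert data type
>> input - the input tensor
>> output - the output tensor
>> dataType - the target data type
*/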
void ConvertDataType(const XTensor & input, XTensor & output, TENSOR_DATA_TYPE dataType)
{
if (!output.isInit || input.dataType != output.dataType) {
float dr = (!input.isSparse) ? 1.0F : input.denseRatio;
InitTensor(&output, input.order, input.dimSize, dataType, dr, input.devID, input.mem);
}
_ConvertDataType(&input, &output);
/* tensor connection */
if (output.enableGrad)
XLink::MakeLink(&input, NULL, &output, GETANDSET_CONVERTDATATYPE);
}
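/*
A minimal usage sketch of the tensor-level conversion above (illustrative only,
excluded from compilation). InitTensor2D and _SetDataRand come from other
headers; which conversion pairs are available depends on the device-specific
branches of _ConvertDataType.
*/
#if 0
void ExampleConvertTensor()
{
    XTensor a, b;
    InitTensor2D(&a, 2, 2);                    /* a float tensor by default */
    _SetDataRand(&a, 0.0F, 1.0F);

    XTensor c = ConvertDataType(a, X_INT);     /* returns a new tensor of the target type */
    ConvertDataType(a, b, X_INT);              /* writes the result into b */
}
#endif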
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#include "../../XTensor.h"
#include "../../XDevice.h"
@@ -67,44 +67,7 @@ void KernelIntToFloat(int * inputData, float * outputData, int size)
if (i < size){
outputData[i] = (float)(inputData[i]);
}
}
/*
data conversion (cuda code)
>> devID - device id
>> s - source data array
>> typeS - source data type
>> t - target data array
>> typeT - target data type
>> size - number of the items in s (and t)
*/
void _CudaConvertDataType(int devID, void * s, TENSOR_DATA_TYPE typeS, void * t, TENSOR_DATA_TYPE typeT, int size)
{
CheckNTErrors((devID >= 0), "This code must be run on GPUs!");
if(typeS == typeT)
return;
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(devID, size, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(devID, devIDBackup);
if(typeS == X_FLOAT && typeT == X_FLOAT16)
KernelFloatToFloat16<<<blocks, threads>>>((float*)s, (__half*)t, size);
else if(typeS == X_FLOAT16 && typeT == X_FLOAT)
KernelFloat16ToFloat<<<blocks, threads>>>((__half*)s, (float*)t, size);
else{
ShowNTErrors("Unsupported data types for conversion!");
}
ProtectCudaDev(devID, devIDBackup);
}
/*
@@ -114,8 +77,6 @@ convert data type (cuda code)
*/
void _CudaConvertDataType(const XTensor * input, XTensor * output)
{
//CheckNTErrors((input->unitSize == output->unitSize), "Input and Output must be same in size!");
if (input->dataType == output->dataType)
return;
@@ -131,13 +92,17 @@ void _CudaConvertDataType(const XTensor * input, XTensor * output)
ProtectCudaDev(input->devID, devIDBackup);
if(input->dataType == X_FLOAT && output->dataType == X_INT)
KernelFloatToInt<<<blocks, threads>>>
((float*)input->data, (int*)output->data, input->unitNum);
else if(input->dataType == X_INT && output->dataType == X_FLOAT)
KernelIntToFloat<<<blocks, threads>>>
((int*)input->data, (float*)output->data, input->unitNum);
else if(input->dataType == X_FLOAT && output->dataType == X_FLOAT16)
KernelFloatToFloat16<<<blocks, threads>>>
((float*)input->data, (__half*)output->data, input->unitNum);
else if(input->dataType == X_FLOAT16 && output->dataType == X_FLOAT)
KernelFloat16ToFloat<<<blocks, threads>>>
((__half*)input->data, (float*)output->data, input->unitNum);
else{
ShowNTErrors("Unsupported data types for conversion!");
}
...
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#ifndef __CONVERTDATATYPE_CUH__
#define __CONVERTDATATYPE_CUH__
...
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: LI Yinqiao (li.yin.qiao.2012@hotmail.com) 2018-7-11
*/
#ifndef __CONVERTDATATYPE_H__
#define __CONVERTDATATYPE_H__
#include "../../XTensor.h"
#include "../../XDataType.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* data conversion (for lower precision computation) */
void ConvertDataType(int devID,
void * s, TENSOR_DATA_TYPE typeS,
void * t, TENSOR_DATA_TYPE typeT, int size);
/* convert data type */
void _ConvertDataType(const XTensor * input, XTensor * output);
/* convert data type (return an XTensor structure) */
XTensor ConvertDataType(const XTensor & input, TENSOR_DATA_TYPE dataType);
/* convert data type */
void ConvertDataType(const XTensor & input, XTensor & output, TENSOR_DATA_TYPE dataType);
} // namespace nts(NiuTrans.Tensor)
#endif // __CONVERTDATATYPE_H__
@@ -466,13 +466,23 @@ void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift)
}
}
/* generate data items with a uniform distribution in [0, 1] */
void _SetDataRand(XTensor * tensor, int rNum, int cNum)
{
if (tensor == NULL || tensor->isInit == false || tensor->order !=2 ) {
InitTensor2D(tensor, rNum, cNum);
}
_SetDataRand(tensor, 0.0F, 1.0F);
}
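/*
A minimal usage sketch of the random initializers (illustrative only, excluded
from compilation): the overload above builds and fills an uninitialized 2-D
tensor in [0, 1], while the overload defined just below refills an existing
tensor in [lower, upper].
*/
#if 0
void ExampleSetDataRand()
{
    XTensor t;
    _SetDataRand(&t, 3, 4);            /* t becomes a 3 x 4 tensor, uniform in [0, 1] */
    _SetDataRand(&t, -0.1F, 0.1F);     /* refill the same tensor, uniform in [-0.1, 0.1] */
}
#endif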
/*
generate data items with a uniform distribution in [lower, upper]
>> tensor - the tensor whose data array would be initialized
>> lower - lower value of the range
>> upper - upper value of the range
*/
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper)
{
CheckNTErrors(upper > lower, "the high value must be greater than low value!");
@@ -525,7 +535,7 @@ the item to a pre-defined value if the item >= p, set the item to 0 otherwise
>> p - the threshold
>> value - the value we intend to assign to the item
*/
void _SetDataRandP(XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value)
{
CheckNTErrors(tensor->dataType == DEFAULT_DTYPE, "TODO");
...
@@ -569,15 +569,17 @@ void _CudaSetDataRand(const XTensor * tensor, DTYPE lower, DTYPE upper)
ProtectCudaDev(tensor->devID, devIDBackup);
curandGenerator_t & gen = GDevs.GPUs[tensor->devID].gen;
curandGenerateUniform(gen, (float*)tensor->data, tensor->unitNum);
DTYPE variance = upper - lower;
if(variance != 1.0F || lower != 0){
if (tensor->dataType == X_FLOAT)
KernelSetDataRandFloat <<<blocks, threads >>>
((float*) tensor->data, tensor->unitNum, lower, variance);
else if (tensor->dataType == X_DOUBLE)
KernelSetDataRandDouble <<<blocks, threads >>>
((double*)tensor->data, tensor->unitNum, lower, variance);
}
BacktoCudaDev(tensor->devID, devIDBackup);
...
@@ -63,12 +63,15 @@ void _SetDataIndexed(XTensor * source, XTensor * modify, int dim, int index);
/* generate data as lower triangular matrices for last two dimensions */
void _SetDataLowTri(XTensor * tensor, DTYPE p, int shift);
/* generate data items with a uniform distribution in [0, 1] */
void _SetDataRand(XTensor * tensor, int rNum, int cNum);
/* generate data items with a uniform distribution in [lower, upper] */
void _SetDataRand(XTensor * tensor, DTYPE lower, DTYPE upper);
/* generate data items with a uniform distribution in [lower, upper] and set
the item to a pre-defined value if the item >= p, set the item to 0 otherwise */
void _SetDataRandP(XTensor * tensor, DTYPE lower, DTYPE upper, DTYPE p, DTYPE value);
/* generate data items with a normal distribution with specified mean and standard deviation */
void _SetDataRandN(XTensor * tensor, DTYPE mean = 0.0F, DTYPE standardDeviation = 1.0F);
...
@@ -29,38 +29,25 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* descale each entry */
template<class T>
void _CudaDescale(const XTensor * a, XTensor * b, T num);
/* power each entry */
template<class T>
void _CudaPower(const XTensor * a, XTensor * b, T num);
/* mod each entry */
template<class T>
void _CudaMod(const XTensor * a, XTensor * b, T base);
/* scale each entry */
template<class T>
void _CudaScale(const XTensor * a, XTensor * b, T num);
/* shift each entry */
template<class T>
void _CudaShift(const XTensor * a, XTensor * b, T num);
#endif // USE_CUDA
...
@@ -16,8 +16,8 @@
*/
/*
* $Created by: JIANG Yufan (email: jiangyufan2018@outlook.com) 2019-04-05
*/
#ifndef __BINARY_H__
#define __BINARY_H__
@@ -26,132 +26,110 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* descale tensor entries
b = a / num */
template<class T>
void _Descale(const XTensor * a, XTensor * b, T num);
/* descale tensor entries (on site)
b = a / num */
template<class T>
void _DescaleMe(XTensor * a, T num);
/* descale tensor entries (on site)
b = a / num */
template<class T>
void DescaleMe(XTensor & a, T num);
/* descale tensor entries
b = a / num */
template<class T>
void Descale(const XTensor & a, XTensor & b, T num);
/* descale tensor entries (return an XTensor structure)
b = a / num */
template<class T>
XTensor Descale(const XTensor & a, T num);
/* mod tensor entries
b = a % base */
template<class T>
void _Mod(const XTensor * a, XTensor * b, T base);
/* mod tensor entries (on site)
b = a % base */
template<class T>
void _ModMe(XTensor * a, T base);
/* mod tensor entries (on site)
b = a % base */
template<class T>
void ModMe(XTensor & a, T base);
/* mod tensor entries
b = a % base */
template<class T>
void Mod(const XTensor & a, XTensor & b, T base);
/* mod tensor entries (return an XTensor structure)
b = a % base */
template<class T>
XTensor Mod(const XTensor & a, T base);
/* get the power(x, y)
b = power(a, num) */
template<class T>
void _Power(const XTensor * a, XTensor * b, T scale);
/* get the power(x, y) (on site)
b = power(a, num) */
template<class T>
void _PowerMe(XTensor * a, T scale);
/* get the power(x, y) (on site)
b = power(a, num) */
template<class T>
void PowerMe(XTensor & a, T scale);
/* get the power(x, y)
b = power(a, num) */
template<class T>
void Power(const XTensor & a, XTensor & b, T scale);
/* get the power(x, y) (return an XTensor structure)
b = power(a, num) */
template<class T>
XTensor Power(const XTensor & a, T scale);
/* scale up tensor entries
b = a * num */
template<class T>
void _Scale(const XTensor * a, XTensor * b, T num);
/* scale up tensor entries (on site)
b = a * num */
template<class T>
void _ScaleMe(XTensor * a, T num);
/* scale up tensor entries (on site)
b = a * num */
template<class T>
void ScaleMe(XTensor & a, T num);
/* scale up tensor entries
b = a * num */
template<class T>
void Scale(const XTensor & a, XTensor & b, T num);
/* scale up tensor entries (return an XTensor structure)
b = a * num */
template<class T>
XTensor Scale(const XTensor & a, T num);
/* shift tensor entries
b = a + num */
template<class T>
void _Shift(const XTensor * a, XTensor * b, T num);
/* shift tensor entries (on site)
b = a + num */
template<class T>
void _ShiftMe(XTensor * a, T num);
/* shift tensor entries (on site)
b = a + num */
template<class T>
void ShiftMe(XTensor & a, T num);
/* shift tensor entries
b = a + num */
template<class T>
void Shift(const XTensor & a, XTensor & b, T num);
/* shift tensor entries (return an XTensor structure)
b = a + num */
template<class T>
XTensor Shift(const XTensor & a, T num);
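/*
A minimal usage sketch of the templated entry-wise operations above
(illustrative only, excluded from compilation); InitTensor2D and _SetDataRand
come from other headers.
*/
#if 0
void ExampleBinary()
{
    XTensor a;
    InitTensor2D(&a, 2, 3);
    _SetDataRand(&a, 1.0F, 2.0F);

    XTensor b = Scale(a, 2.0F);        /* b = a * 2.0 */
    XTensor c = Shift(a, 1.0F);        /* c = a + 1.0 */
    XTensor d = Power(a, 2.0F);        /* d = power(a, 2.0) */
    DescaleMe(a, 4.0F);                /* on site: a = a / 4.0 */
}
#endif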
} // namespace nts(NiuTrans.Tensor)
...
@@ -37,88 +37,72 @@ DTYPE myIsNotEqual(DTYPE a, DTYPE b)
}
#ifdef USE_CUDA
/* define three macros separately, specify the respective function names */
#define _SIMPLE_COMPARE_FUNCTION(_funcName, _cudaFuncName, origFunc) \
void _funcName(const XTensor * a, XTensor * b, DTYPE number) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); \
/* run it on GPUs */ \
if (a->devID >= 0) { \
if (useCUDA) { \
_cudaFuncName(a, b, number); \
return; \
} \
else \
ShowNTErrors("No GPU devices support!") \
} \
DTYPE * d = (DTYPE*)a->data; \
DTYPE * db = (DTYPE*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (DTYPE)origFunc(d[i], number); \
}
#define _SIMPLE_COMPARE_FUNCTION_ME(_funcNameMe, _funcName) \
void _funcNameMe(XTensor * a, DTYPE number) \
{ \
_funcName(a, a, number); \
}
#define SIMPLE_COMPARE_FUNCTION_ME(funcNameMe, _funcName) \
void funcNameMe(XTensor & a, DTYPE number) \
{ \
_funcName(&a, &a, number); \
}
#define SIMPLE_COMPARE_FUNCTION(funcName, _funcName, operationId) \
XTensor funcName(const XTensor &a, DTYPE number) \
{ \
XTensor b(&a); \
b.SetTMPFlag(); \
_funcName(&a, &b, number); \
return b; \
} }
#define SIMPLE_COMPARE_FUNCTION_VOID(funcName, _funcName, operationId) \
void funcName(const XTensor &a, XTensor &b, DTYPE number) \
{ \
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) { \
InitTensor(&b, &a); \
} \
_funcName(&a, &b, number); \
}
// I think we needn't to make link.
// XLink::MakeLink(&a, NULL, &b, operationId);
_SIMPLE_COMPARE_FUNCTION(_Equal, _CudaEqual, myIsEqual)
_SIMPLE_COMPARE_FUNCTION_ME(_EqualMe, _Equal)
SIMPLE_COMPARE_FUNCTION_ME(EqualMe, _Equal)
SIMPLE_COMPARE_FUNCTION(Equal, _Equal, MATH_EQUAL)
SIMPLE_COMPARE_FUNCTION_VOID(Equal, _Equal, MATH_EQUAL)
_SIMPLE_COMPARE_FUNCTION(_NotEqual, _CudaNotEqual, myIsNotEqual)
_SIMPLE_COMPARE_FUNCTION_ME(_NotEqualMe, _NotEqual)
SIMPLE_COMPARE_FUNCTION_ME(NotEqualMe, _NotEqual)
SIMPLE_COMPARE_FUNCTION(NotEqual, _NotEqual, MATH_NOTEQUAL)
SIMPLE_COMPARE_FUNCTION_VOID(NotEqual, _NotEqual, MATH_NOTEQUAL)
#else
/* define three macros separately, specify the respective function names (CPU mode) */
#define _SIMPLE_COMPARE_FUNCTION(_funcName, origFunc) \
void _funcName(const XTensor * a, XTensor * b, DTYPE number) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!"); \
DTYPE * d = (DTYPE*)a->data; \
DTYPE * db = (DTYPE*)b->data; \
for (int i = 0; i < a->unitNum; i++) \
db[i] = (DTYPE)origFunc(d[i], number); \
}
#define _SIMPLE_COMPARE_FUNCTION_ME(_funcNameMe, _funcName) \
void _funcNameMe(XTensor * a, DTYPE number) \
{ \
_funcName(a, a, number); \
}
#define SIMPLE_COMPARE_FUNCTION(funcName, _funcName, operationId) \
XTensor funcName(const XTensor &a, DTYPE number) \
{ \
XTensor b(&a); \
b.SetTMPFlag(); \
_funcName(&a, &b, number); \
return b; \
}
// I think we needn't to make link.
// XLink::MakeLink(&a, NULL, &b, operationId);
_SIMPLE_COMPARE_FUNCTION(_Equal, myIsEqual)
_SIMPLE_COMPARE_FUNCTION_ME(_EqualMe, _Equal)
SIMPLE_COMPARE_FUNCTION(Equal, _Equal, MATH_EQUAL)
_SIMPLE_COMPARE_FUNCTION(_NotEqual, myIsNotEqual)
_SIMPLE_COMPARE_FUNCTION_ME(_NotEqualMe, _NotEqual)
SIMPLE_COMPARE_FUNCTION(NotEqual, _NotEqual, MATH_NOTEQUAL)
#endif
...
@@ -38,6 +38,9 @@ void EqualMe(XTensor & a, DTYPE value);
/* check whether every entry is equal to the given value (return an XTensor structure) */
XTensor Equal(const XTensor & a, DTYPE value);
/* check whether every entry is equal to the given value */
void Equal(const XTensor & a, XTensor & b, DTYPE value);
/* check whether every entry is not equal to the given value */
void _NotEqual(const XTensor * a, XTensor * b, DTYPE value);
...@@ -50,6 +53,9 @@ void NotEqualMe(XTensor & a, DTYPE value);
/* check whether every entry is not equal to the given value (return an XTensor structure) */
XTensor NotEqual(const XTensor & a, DTYPE value);
/* check whether every entry is not equal to the given value */
void NotEqual(const XTensor & a, XTensor & b, DTYPE value);
} // namespace nts(NiuTrans.Tensor)
#endif // end __COMPARE_H__
\ No newline at end of file
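As a quick orientation for the three flavours declared above, here is a minimal usage sketch (not part of the commit). It only assumes what the declarations show: `a` is a dense DTYPE tensor created and filled elsewhere by the caller, the header file is named Compare.h, and a default-constructed result tensor is initialized inside the void overload via InitTensor.
#include "Compare.h"
using namespace nts;

void CompareExample(XTensor &a)
{
    /* returning form: allocates a temporary result tensor internally */
    XTensor mask = Equal(a, (DTYPE)0.0);

    /* void form added in this commit: reuses b, re-initializing it only when its shape differs */
    XTensor b;
    Equal(a, b, (DTYPE)0.0);       /* b[i] = myIsEqual(a[i], 0) */
    NotEqual(a, b, (DTYPE)0.0);    /* b[i] = myIsNotEqual(a[i], 0) */

    /* in-place form: overwrite a itself */
    EqualMe(a, (DTYPE)0.0);
}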
...@@ -42,7 +42,9 @@ where a and b are the scalar and bias respectively, and \epsilon is the adjustme
>> b - the bias
>> epsilon - a parameter
*/
void _Normalize(const XTensor * input, XTensor * output, int dim,
const XTensor * mean, const XTensor * var,
const XTensor * a, const XTensor * b, DTYPE epsilon)
{
int dimRDI = input->order - dim - 1;
CheckNTErrors((XTensor::IsSameShaped(input, output)), "Unmatched input tensors!");
...@@ -109,7 +111,9 @@ where a and b are the scalar and bias respectively, and \epsilon is the adjustme
>> b - the bias
>> epsilon - a parameter
*/
void _NormalizeMe(XTensor * input, int dim,
const XTensor * mean, const XTensor * var,
const XTensor * a, const XTensor * b, DTYPE epsilon)
{
_Normalize(input, input, dim, mean, var, a, b, epsilon);
}
...@@ -129,7 +133,9 @@ where a and b are the scalar and bias respectively, and \epsilon is the adjustme
>> b - the bias
>> epsilon - a parameter
*/
void NormalizeMe(XTensor& input, int dim,
const XTensor& mean, const XTensor& var,
const XTensor& a, const XTensor& b, DTYPE epsilon)
{
_Normalize(&input, &input, dim, &mean, &var, &a, &b, epsilon);
}
...@@ -150,7 +156,9 @@ where a and b are the scalar and bias respectively, and \epsilon is the adjustme
>> epsilon - a parameter
<< return - the result of normalizing the data with normal distribution
*/
XTensor Normalize(const XTensor &input, int dim,
const XTensor &mean, const XTensor &var,
const XTensor &a, const XTensor &b, DTYPE epsilon)
{
XTensor output(&input);
output.SetTMPFlag();
...@@ -171,4 +179,48 @@ XTensor Normalize(const XTensor &input, int dim, const XTensor &mean, const XTen
return output;
}
/*
normalized the data with normal distribution (with the output tensor given)
keep the result in the output tensor
For an input x, y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
>> input - the input tensor
>> output - the output tensor
>> dim - dimension along which we generate the mean and variance
>> mean - the mean of the input
>> var - the variance of the input
>> a - the scalar
>> b - the bias
>> epsilon - a parameter
*/
void Normalize(const XTensor &input, XTensor &output, int dim,
const XTensor &mean, const XTensor &var,
const XTensor &a, const XTensor &b, DTYPE epsilon)
{
if (!output.isInit || !XTensor::IsSameShaped(&input, &output)) {
InitTensor(&output, &input);
}
/* call _Normalize function */
_Normalize(&input, &output, dim, &mean, &var, &a, &b, epsilon);
if (output.enableGrad == true) {
/* tensor connections */
TensorList list(5);
list.Add((XTensor*)&input);
list.Add((XTensor*)&mean);
list.Add((XTensor*)&var);
list.Add((XTensor*)&a);
list.Add((XTensor*)&b);
XLink::MakeLink(&list, &output, MATH_NORMALIZE);
XLink::AddParamToHeadInt(&output, dim);
XLink::AddParamToHead(&output, epsilon);
}
}
} // namespace nts(NiuTrans.Tensor)
...@@ -31,7 +31,9 @@ normalized the data with normal distribution.
For an input x, y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
*/
void _Normalize(const XTensor * input, XTensor * output, int dim,
const XTensor * mean, const XTensor * var,
const XTensor * a, const XTensor * b, DTYPE epsilon);
/*
normalized the data with normal distribution (do it on site)
...@@ -39,7 +41,9 @@ keep the result in the input tensor and return nothing
For an input x, x = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
*/
void _NormalizeMe(XTensor * input, int dim,
const XTensor * mean, const XTensor * var,
const XTensor * a, const XTensor * b, DTYPE epsilon);
/*
normalized the data with normal distribution (do it on site)
...@@ -47,7 +51,9 @@ keep the result in the input tensor and return nothing
For an input x, x = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
*/
void NormalizeMe(XTensor & input, int dim,
const XTensor & mean, const XTensor & var,
const XTensor & a, const XTensor & b, DTYPE epsilon);
/*
normalized the data with normal distribution (return an XTensor structure)
...@@ -55,7 +61,19 @@ make a new tensor to keep the result and return it
For an input x, y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
*/
XTensor Normalize(const XTensor &input, int dim,
const XTensor &mean, const XTensor &var,
const XTensor &a, const XTensor &b, DTYPE epsilon);
/*
normalized the data with normal distribution (with the output tensor given)
keep the result in the output tensor
For an input x, y = a * (x-mean)/sqrt(variance+\epsilon) + b
where a and b are the scalar and bias respectively, and \epsilon is the adjustment parameter.
*/
void Normalize(const XTensor &input, XTensor &output, int dim,
const XTensor &mean, const XTensor &var,
const XTensor &a, const XTensor &b, DTYPE epsilon);
} // namespace nts(NiuTrans.Tensor)
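For clarity, a small usage sketch of the void overload declared above (illustrative only; the input, mean, var, scale and bias tensors are assumed to be built and filled by the caller, the header name Normalize.h is assumed, and 1e-6 is just an example epsilon):
#include "Normalize.h"
using namespace nts;

void NormalizeExample(const XTensor &input, const XTensor &mean, const XTensor &var,
                      const XTensor &scale, const XTensor &bias)
{
    XTensor output;

    /* output is (re-)initialized from input if needed, then filled with
       scale * (input - mean) / sqrt(var + epsilon) + bias along dimension 0 */
    Normalize(input, output, 0, mean, var, scale, bias, (DTYPE)1e-6F);
}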
......
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include <math.h>
#include "../../XTensor.h"
#include "../../XName.h"
#include "Power.h"
#include "Power.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
get the power(a, p)
>> a - input tensor
>> b - output tensor
>> p - parameter
*/
void _Power(const XTensor * a, XTensor * b, DTYPE p)
{
#ifdef USE_CUDA
/* run it on GPUs */
if (a->devID >= 0) {
_CudaPower(a, b, p);
return;
}
#endif
CheckNTErrors((a->dataType == DEFAULT_DTYPE), "TODO!");
DTYPE * aData = (DTYPE*)a->data;
DTYPE * bData = (DTYPE*)b->data;
if (p == 0) {
for (int i = 0; i < a->unitNum; i++)
bData[i] = (DTYPE)1.0;
}
else if (p == (DTYPE)0.5) {
for (int i = 0; i < a->unitNum; i++)
bData[i] = (DTYPE)sqrt(aData[i]);
}
else if (p == (DTYPE)2.0) {
for (int i = 0; i < a->unitNum; i++)
bData[i] = aData[i] * aData[i];
}
else {
for (int i = 0; i < a->unitNum; i++) {
if (p < 0 && aData[i] == 0)
bData[i] = 1e20F;
else
bData[i] = (DTYPE)pow(aData[i], p);
}
}
}
/*
get the power(a, p) (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor
>> p - parameter
*/
void _PowerMe(XTensor * a, DTYPE p)
{
_Power(a, a, p);
}
/*
get the power(a, p) (do it on site)
keep the result in the input tensor a and return nothing
>> a - the tensor
>> p - parameter
*/
void PowerMe(XTensor& a, DTYPE p)
{
_Power(&a, &a, p);
}
/*
get the power(a, p) (return an XTensor structure)
make a new tensor to keep the result and return it
>> a - input tensor
>> p - parameter
<< return - the power value of the input tensor
*/
XTensor Power(const XTensor & a, DTYPE p)
{
XTensor b(&a);
b.SetTMPFlag();
/* call _Power function */
_Power(&a, &b, p);
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_POWER);
XLink::AddParamToHead(&b, p);
return b;
}
/*
get the power(a, p)
>> a - input tensor
>> b - output tensor
>> p - parameter
*/
void Power(const XTensor & a, XTensor & b, DTYPE p)
{
if (!b.isInit || !XTensor::IsSameShaped(&a, &b)) {
InitTensor(&b, &a);
}
/* call _Power function */
_Power(&a, &b, p);
if (b.enableGrad) {
/* tensor connections */
XLink::MakeLink(&a, NULL, &b, MATH_POWER);
XLink::AddParamToHead(&b, p);
}
}
} // namespace nts(NiuTrans.Tensor)
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#include "../../XDevice.h"
#include "../../XTensor.h"
#include "../movement/CopyValues.cuh"
#include "Power.h"
#include "Power.cuh"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/*
set each entry to its square root (CUDA Kernel)
>> a - input data array
>> b - output data array
>> size - size of the data array
*/
__global__
void KernelSqrtV2(DTYPE * a, DTYPE * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size)
b[i] = sqrt(a[i]);
}
/*
set each entry to its square root (CUDA Kernel)
>> a - input data array
>> b - output data array
>> size - size of the data array
*/
__global__
void KernelSqrtV2(__half * a, __half * b, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
if (i < size)
b[i] = hsqrt(a[i]);
#else
if (i < size)
b[i] = __float2half(sqrt(__half2float(a[i])));
#endif
}
/*
get power(d[i], p)
>> a - input data array
>> b - output data array
>> p - power
>> size - size of the data array
*/
__global__
void KernelPower(DTYPE * a, DTYPE * b, DTYPE p, int size)
{
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
DTYPE v = a[i];
if (p < 0 && v == 0)
b[i] = 1e20;
else
b[i] = pow(a[i], p);
}
}
/*
get power(d[i], p)
>> a - input data array
>> b - output data array
>> p - power
>> size - size of the data array
*/
__global__
void KernelPower(__half * a, __half * b, __half p, int size)
{
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
#else
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
float v = __half2float(a[i]);
if (__half2float(p) < 0 && v == 0)
b[i] = __float2half(1e20);
else
b[i] = __float2half(pow(__half2float(a[i]), __half2float(p)));
}
#endif
}
/* get the power of the entries */
void _CudaPower(const XTensor * a, XTensor * b, DTYPE p)
{
CheckNTErrors((XTensor::IsSameShaped(a, b)), "Input tensors should have the same type!");
int gridSize[3];
int blockSize[3];
GDevs.GetCudaThread(a->devID, a->unitNum, gridSize, blockSize);
dim3 blocks(gridSize[0]);
dim3 threads(blockSize[0]);
int devIDBackup;
ProtectCudaDev(a->devID, devIDBackup);
if (a->dataType == DEFAULT_DTYPE) {
if (p == (DTYPE)0.5) {
KernelSqrtV2 << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, a->unitNum);
}
else if (p == (DTYPE)1.0) {
_CudaCopyValues(a, b);
}
else if (p != (DTYPE)1.0) {
KernelPower << <blocks, threads >> >((DTYPE*)a->data, (DTYPE*)b->data, p, a->unitNum);
}
}
else if (a->dataType == X_FLOAT16) {
if (p == (DTYPE)0.5) {
KernelSqrtV2 << <blocks, threads >> >((__half*)a->data, (__half*)b->data, a->unitNum);
}
else if (p != (DTYPE)1.0) {
ShowNTErrors("TODO!");
}
}
else {
ShowNTErrors("TODO!");
}
BacktoCudaDev(a->devID, devIDBackup);
}
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __POWER_CUH__
#define __POWER_CUH__
#include "Power.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its square root (CUDA Kernel) */
__global__
void KernelSqrtV2(DTYPE * a, DTYPE * b, int size);
/* set each entry to its square root (CUDA Kernel) with float16 data type */
__global__
void KernelSqrtV2(__half * a, __half * b, int size);
/* get the power of the entries */
void _CudaPower(const XTensor * a, XTensor * b, DTYPE p);
#endif // USE_CUDA
} // namespace nts(NiuTrans.Tensor)
#endif // __POWER_CUH__
\ No newline at end of file
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/
#ifndef __POWER_H__
#define __POWER_H__
#include "../../XTensor.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* get the power(x, y) */
void _Power(const XTensor * a, XTensor * b, DTYPE p);
/*
get the power(x, y) (do it on site)
keep the result in the input tensor a and return nothing
*/
void _PowerMe(XTensor * a, DTYPE p);
/*
get the power(x, y) (do it on site)
keep the result in the input tensor a and return nothing
*/
void PowerMe(XTensor & a, DTYPE p);
/*
get the power(x, y) (return an XTensor structure)
make a new tensor to keep the result and return it
*/
XTensor Power(const XTensor & a, DTYPE p);
/* get the power(x, y) */
void Power(const XTensor & a, XTensor & b, DTYPE p);
} // namespace nts(NiuTrans.Tensor)
#endif // __POWER_H__
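A short usage sketch of the interfaces declared above (illustrative only; `a` is assumed to be a dense DTYPE tensor owned by the caller). Note that both the CPU and CUDA paths special-case p = 0.5 and p = 2.0, and that for p < 0 a zero entry is mapped to 1e20 instead of producing an infinity.
#include "Power.h"
using namespace nts;

void PowerExample(XTensor &a)
{
    XTensor b;

    Power(a, b, (DTYPE)2.0);    /* b[i] = a[i] * a[i] (dedicated branch) */
    Power(a, b, (DTYPE)0.5);    /* b[i] = sqrt(a[i]) */
    PowerMe(a, (DTYPE)3.0);     /* in place: a[i] = pow(a[i], 3) */
}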
...@@ -24,55 +24,139 @@
#include "../../XName.h"
#include "Unary.h"
#include "Unary.cuh"
#include <cuda_runtime.h>
namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
template<class T>
__device__
T UnaryCudaCeil(T x)
{
return (T)ceil((float)x);
}
template<class T>
__device__
T UnaryCudaExp(T x)
{
return (T)exp((float)x);
}
template<class T>
__device__
T UnaryCudaFabs(T x)
{
return (T)fabs((float)x);
}
template<class T>
__device__
T UnaryCudaFloor(T x)
{
return (T)floor((float)x);
}
template<class T>
__device__
T UnaryCudaIsNonZero(T r)
{
return (r != (T)0.0) ? (T)1.0 : (T)0.0;
}
template<class T>
__device__
T UnaryCudaIsZero(T r)
{
return (r == (T)0.0) ? (T)1.0 : (T)0.0;
}
template<class T>
__device__
T UnaryCudaLog(T x)
{
return (T)log((float)x);
}
template<class T>
__device__
T UnaryCudaNegate(T x)
{
return -x;
}
template<class T>
__device__
T UnaryCudaSign(T r)
{
if (r > (T)0)
return 1.0;
else if (r == (T)0)
return 0.0;
else
return -1.0;
}
template<class T>
__device__
T UnaryCudaSqrt(T x)
{
return (T)sqrt((float)x);
}
template<class T>
__device__
T UnaryCudaSquare(T x)
{
return x * x;
}
template<class T>
__device__
T UnaryCudaRound(T r)
{
return (r > (T)0.0) ? (T)UnaryCudaFloor(r + (T)0.5) : (T)UnaryCudaCeil(r - (T)0.5);
}
template<class T>
__device__
T UnaryCudaSin(T x)
{
return (T)sin((float)x);
}
template<class T>
__device__
T UnaryCudaCos(T x)
{
return (T)cos((float)x);
}
template<class T>
__device__
T UnaryCudaTan(T x)
{
return (T)tan((float)x);
} }
#define SIMPLE_UNARY_FUNCTION_GPU(funcName, origFunc) \
template<class T> \
__global__ \
void Kernel##funcName(T * a, T * b, int size) \
{ \
int i = blockDim.x * blockIdx.x + threadIdx.x; \
\
if (i < size) \
b[i] = (T)origFunc(a[i]); \
} \
__global__ \
void Kernel##funcName(__half * a, __half * b, int size) \
{ \
return; \
} \
void _Cuda##funcName(const XTensor * a, XTensor * b) \
{ \
CheckNTErrors((XTensor::IsSameShaped(a, b)), \
"Input tensors should have the same type!"); \
CheckNTErrors(a->isSparse == false, "TODO!"); \
\
int gridSize[3]; \
int blockSize[3]; \
...@@ -85,35 +169,43 @@ void _Cuda##funcName(const XTensor * a, XTensor * b) \
int devIDBackup; \
ProtectCudaDev(a->devID, devIDBackup); \
\
if (a->dataType == X_FLOAT) { \
Kernel##funcName<<<blocks, threads>>> \
((float*)a->data, (float*)b->data, a->unitNum); \
} \
else if (a->dataType == X_DOUBLE) { \
Kernel##funcName<<<blocks, threads>>> \
((double*)a->data, (double*)b->data, a->unitNum); \
} \
else if (a->dataType == X_INT) { \
Kernel##funcName<<<blocks, threads>>> \
((int*)a->data, (int*)b->data, a->unitNum); \
} \
else { \
ShowNTErrors("TODO!"); \
} \
\
BacktoCudaDev(a->devID, devIDBackup); \
}
SIMPLE_UNARY_FUNCTION_GPU(Absolute, UnaryCudaFabs)
SIMPLE_UNARY_FUNCTION_GPU(Ceil, UnaryCudaCeil)
SIMPLE_UNARY_FUNCTION_GPU(Exp, UnaryCudaExp)
SIMPLE_UNARY_FUNCTION_GPU(Floor, UnaryCudaFloor)
SIMPLE_UNARY_FUNCTION_GPU(IsNonZero, UnaryCudaIsNonZero)
SIMPLE_UNARY_FUNCTION_GPU(IsZero, UnaryCudaIsZero)
SIMPLE_UNARY_FUNCTION_GPU(Log, UnaryCudaLog)
SIMPLE_UNARY_FUNCTION_GPU(Negate, UnaryCudaNegate)
SIMPLE_UNARY_FUNCTION_GPU(Round, UnaryCudaRound)
SIMPLE_UNARY_FUNCTION_GPU(Sign, UnaryCudaSign)
SIMPLE_UNARY_FUNCTION_GPU(Sqrt, UnaryCudaSqrt)
SIMPLE_UNARY_FUNCTION_GPU(Square, UnaryCudaSquare)
SIMPLE_UNARY_FUNCTION_GPU(Sin, UnaryCudaSin)
SIMPLE_UNARY_FUNCTION_GPU(Cos, UnaryCudaCos)
SIMPLE_UNARY_FUNCTION_GPU(Tan, UnaryCudaTan)
SIMPLE_UNARY_FUNCTION_GPU(Floor, floor)
SIMPLE_UNARY_FUNCTION_GPU(IsNonZero, cudaisnonzero)
SIMPLE_UNARY_FUNCTION_GPU(IsZero, cudaiszero)
SIMPLE_UNARY_FUNCTION_GPU(Log, log)
SIMPLE_UNARY_FUNCTION_GPU(Round, cudaround)
SIMPLE_UNARY_FUNCTION_GPU(Sqrt, sqrt)
SIMPLE_UNARY_FUNCTION_GPU(Square, cudasquare)
SIMPLE_UNARY_FUNCTION_GPU(Sin, sin)
SIMPLE_UNARY_FUNCTION_GPU(Cos, cos)
SIMPLE_UNARY_FUNCTION_GPU(Tan, tan)
#endif // USE_CUDA
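To make the macro easier to follow, this is roughly what SIMPLE_UNARY_FUNCTION_GPU(Square, UnaryCudaSquare) expands to on the kernel side (a hand-expanded sketch, not text from the commit); the host wrapper _CudaSquare then launches this kernel with float*, double* or int* arguments depending on a->dataType, exactly as in the macro body above.
/* hand-expanded sketch; relies on the UnaryCudaSquare device function defined above */
template<class T>
__global__
void KernelSquare(T * a, T * b, int size)
{
    /* one thread per element */
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    if (i < size)
        b[i] = (T)UnaryCudaSquare(a[i]);
}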
......
...@@ -29,121 +29,49 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
#ifdef USE_CUDA
/* set each entry to its absolute value (CUDA Kernel) */
__global__
void KernelAbsolute(DTYPE * a, DTYPE * b, int size);
/* set each entry to its absolute value (CUDA Kernel) with float16 data type*/
__global__
void KernelAbsolute(__half * a, __half * b, int size);
/* set each entry to its absolute value */
void _CudaAbsolute(const XTensor * a, XTensor * b);
/* set each entry to its ceil value (CUDA Kernel) */
__global__
void KernelCeil(DTYPE * a, DTYPE * b, int size);
/* set each entry to its ceil value (CUDA Kernel) with float16 data type*/
__global__
void KernelCeil(__half * a, __half * b, int size);
/* set each entry to its ceil value */
void _CudaCeil(const XTensor * a, XTensor * b);
/* set each entry to its exponent value (CUDA Kernel) */
__global__
void KernelExp(DTYPE * a, DTYPE * b, int size);
/* set each entry to its exponent value (CUDA Kernel) with float16 data type*/
__global__
void KernelExp(__half * a, __half * b, int size);
/* set each entry to its exponent value */
void _CudaExp(const XTensor * a, XTensor * b);
/* set each entry to its floor value (CUDA Kernel) */
__global__
void KernelFloor(DTYPE * a, DTYPE * b, int size);
/* set each entry to its floor value (CUDA Kernel) with float16 data type*/
__global__
void KernelFloor(__half * a, __half * b, int size);
/* set each entry to its floor value */
void _CudaFloor(const XTensor * a, XTensor * b);
/* if source entry is non-zero, set target entry to be one, otherwise zero (CUDA Kernel) */
__global__
void KernelIsNonZero(DTYPE * a, DTYPE * b, int size);
/* if source entry is non-zero, set target entry to be one, otherwise zero (CUDA Kernel) with float16 data type*/
__global__
void KernelIsNonZero(__half * a, __half * b, int size);
/* if source entry is non-zero, set target entry to be one, otherwise zero */
void _CudaIsNonZero(const XTensor * a, XTensor * b);
/* if source entry is zero, set target entry to be one, otherwise zero (CUDA Kernel) */
__global__
void KernelIsZero(DTYPE * a, DTYPE * b, int size);
/* if source entry is zero, set target entry to be one, otherwise zero (CUDA Kernel) with float16 data type*/
__global__
void KernelIsZero(__half * a, __half * b, int size);
/* if source entry is zero, set target entry to be one, otherwise zero */
void _CudaIsZero(const XTensor * a, XTensor * b);
/* set each entry to its logarithm value (CUDA Kernel) */
__global__
void KernelLog(DTYPE * a, DTYPE * b, int size);
/* set each entry to its logarithm value (CUDA Kernel) with float16 data type*/
__global__
void KernelLog(__half * a, __half * b, int size);
/* set each entry to its logarithm value */
void _CudaLog(const XTensor * a, XTensor * b);
/* set each entry to its negative value */
void _CudaNegate(const XTensor * a, XTensor * b);
void KernelRound(DTYPE * a, DTYPE * b, int size);
/* set each entry to its round value (CUDA Kernel) with float16 data type*/
__global__
void KernelRound(__half * a, __half * b, int size);
/* set each entry to its round value */
void _CudaRound(const XTensor * a, XTensor * b);
/* set each entry to its sign value */
void _CudaSign(const XTensor * a, XTensor * b);
void KernelSqrt(DTYPE * a, DTYPE * b, int size);
/* set each entry to its sqrt value (CUDA Kernel) with float16 data type*/
__global__
void KernelSqrt(__half * a, __half * b, int size);
/* set each entry to its sqrt value */
void _CudaSqrt(const XTensor * a, XTensor * b);
/* set each entry to its square value (CUDA Kernel) */
__global__
void KernelSquare(DTYPE * a, DTYPE * b, int size);
/* set each entry to its square value (CUDA Kernel) with float16 data type*/
__global__
void KernelSquare(__half * a, __half * b, int size);
/* set each entry to its square value */
void _CudaSquare(const XTensor * a, XTensor * b);
/* set each entry to its sine value (CUDA Kernel) */
__global__
void KernelSin(DTYPE * a, DTYPE * b, int size);
/* set each entry to its sine value (CUDA Kernel) with float16 data type*/
__global__
void KernelSin(__half * a, __half * b, int size);
/* set each entry to its sine value */
void _CudaSin(const XTensor * a, XTensor * b);
/* set each entry to its cosine value (CUDA Kernel) */
__global__
void KernelCos(DTYPE * a, DTYPE * b, int size);
/* set each entry to its cosine value (CUDA Kernel) with float16 data type*/
__global__
void KernelCos(__half * a, __half * b, int size);
/* set each entry to its cosine value */
void _CudaCos(const XTensor * a, XTensor * b);
/* set each entry to its tangent value (CUDA Kernel) */
__global__
void KernelTan(DTYPE * a, DTYPE * b, int size);
/* set each entry to its tangent value (CUDA Kernel) with float16 data type*/
__global__
void KernelTan(__half * a, __half * b, int size);
/* set each entry to its tangent value */
void _CudaTan(const XTensor * a, XTensor * b);
......
...@@ -124,6 +124,20 @@ XTensor Log(const XTensor & a);
/* set every entry to its logarithm value */
void Log(const XTensor & a, XTensor & b);
/* set every entry to its negative value */
void _Negate(const XTensor * a, XTensor * b);
/* set every entry to its negative value (do it on site)
keep the result in the input tensor a and return nothing */
void _NegateMe(XTensor * a);
/* set every entry to its negative value (do it on site)
keep the result in the input tensor a and return nothing */
void NegateMe(XTensor & a);
/* set every entry to its negative value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Negate(const XTensor & a);
/* set every entry to its negative value */
void Negate(const XTensor & a, XTensor & b);
/* set every entry to its round value */
void _Round(const XTensor * a, XTensor * b);
/* set every entry to its round value (do it on site)
...@@ -138,6 +152,20 @@ XTensor Round(const XTensor & a);
/* set every entry to its round value */
void Round(const XTensor & a, XTensor & b);
/* set every entry to its sign value */
void _Sign(const XTensor * a, XTensor * b);
/* set every entry to its sign value (do it on site)
keep the result in the input tensor a and return nothing */
void _SignMe(XTensor * a);
/* set every entry to its sign value (do it on site)
keep the result in the input tensor a and return nothing */
void SignMe(XTensor & a);
/* set every entry to its sign value (return an XTensor structure)
make a new tensor to keep the result and return it */
XTensor Sign(const XTensor & a);
/* set every entry to its sign value */
void Sign(const XTensor & a, XTensor & b);
/* set every entry to its sqrt value */
void _Sqrt(const XTensor * a, XTensor * b);
/* set every entry to its sqrt value (do it on site)
...@@ -166,7 +194,6 @@ XTensor Square(const XTensor & a);
/* set every entry to its square value */
void Square(const XTensor & a, XTensor & b);
/* set every entry to its sine value */
void _Sin(const XTensor * a, XTensor * b);
/* set every entry to its sine value (do it on site)
......
...@@ -189,6 +189,29 @@ void _CopyIndexed(const XTensor * s, XTensor * t, int dim,
}
}
/*
copy selected sub-tensors
>> s - the source tensor
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 2,
we have 4 sub-tensors of size (3, 2)
>> srcIndex - the tensor to save the index of the source sub-tensors
>> copyNum - number of the sub-tensors we copy for each source index,
e.g., for srcIndex = [1,4] and copyNum = 2,
we actually copy the source sub-tensors 1, 2, 4, 5
*/
void _CopyIndexed(const XTensor * s, XTensor * t, int dim,
const XTensor * srcIndex, int copyNum)
{
XTensor * tgtIndex = NewTensor(srcIndex);
tgtIndex->SetAscendingOrder(0);
_CopyIndexed(s, t, dim, srcIndex, tgtIndex, copyNum);
delete tgtIndex;
}
/*
copy selected sub-tensors where indices are kept in tensors (return an XTensor structure)
make a new tensor to keep the result and return it
......
...@@ -31,16 +31,14 @@ void _CopyIndexed(const XTensor * s, XTensor * t, int dim,
int * srcIndex, int indexSize, int * tgtIndex,
int copyNum = 1);
/* copy selected sub-tensors */
void _CopyIndexed(const XTensor * s, XTensor * t, int dim,
const XTensor * srcIndex, const XTensor * tgtIndex,
int copyNum = 1);
/* copy selected sub-tensors */
void _CopyIndexed(const XTensor * s, XTensor * t, int dim,
const XTensor * srcIndex, int copyNum = 1);
/*
copy selected sub-tensors where indices are kept in tensors (return an XTensor structure)
......
...@@ -23,6 +23,7 @@
#include "../../XUtility.h"
#include "CopyValues.h"
#include "CopyValues.cuh"
#include "../getandset/ConvertDataType.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
......
...@@ -52,15 +52,15 @@ void _CudaCopyValues(const XTensor * s, XTensor * t, XStream * stream)
}
/* dense -> sparse */
else if (!s->isSparse && t->isSparse &&
s->dataType == DEFAULT_DTYPE &&
t->dataType == DEFAULT_DTYPE)
{
ShowNTErrors("TODO!");
}
/* sparse -> dense */
else if (s->isSparse && !t->isSparse &&
s->dataType == DEFAULT_DTYPE &&
t->dataType == DEFAULT_DTYPE)
{
ShowNTErrors("TODO!");
}
......
...@@ -33,28 +33,6 @@ gather indexed sub-tensors
>> s - the source tensor
>> t - the target tensor
>> dim - the leading dimension to define "sub-tensors"
e.g., for a tensor of size (3, 2, 4) and dim = 0,
we have 3 sub-tensors of size (2, 4)
>> srcIndex - index of the source sub-tensors
>> indexSize - length of srcIndex (and tgtIndex)
*/
void _Gather(XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize)
{
int * tgtIndex = new int[indexSize];
for(int i = 0; i < indexSize; i++)
tgtIndex[i] = i;
_CopyIndexed(s, t, dim, srcIndex, indexSize, tgtIndex, 1);
delete[] tgtIndex;
}
/*
gather indexed sub-tensors
>> s - the source tensor
>> t - the target tensor
>> srcIndex - the tensor to save the index of the source tensor
*/
void _Gather(const XTensor * s, XTensor * t, XTensor * srcIndex)
...@@ -101,15 +79,10 @@ XTensor Gather(XTensor &s, XTensor &index)
CheckNTErrors(s.order == 2, "The order of the input tensor must be 2!");
int order = index.order + 1;
int * dimSize = new int[order];
memcpy(dimSize, index.dimSize, index.order * sizeof(int));
dimSize[index.order] = s.GetDim(-1);
float dr = (!s.isSparse) ? 1.0F : s.denseRatio;
XTensor t(order, dimSize, s.dataType, dr, s.devID, s.mem);
...@@ -122,20 +95,7 @@ XTensor Gather(XTensor &s, XTensor &index)
/* tensor connection */
XLink::MakeLink(&s, &index, &t, MOVEMENT_GATHER);
return t;
}
} // namespace nts(NiuTrans.Tensor)
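The practical effect of the change above is that the result shape is now built directly from the index tensor, so no trailing Reshape is needed. A shape sketch with hypothetical sizes (not from the commit): for an embedding-like matrix s of size (V, E) and an index tensor of size (B, L), Gather returns a tensor of size (B, L, E).
XTensor GatherExample(XTensor &s, XTensor &index)
{
    /* order of the result = index.order + 1; its last dimension is s.GetDim(-1) */
    XTensor t = Gather(s, index);
    return t;
}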
...@@ -27,9 +27,6 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/* gather selected sub-tensors */
void _Gather(XTensor * s, XTensor * t, int dim, int * srcIndex, int indexSize);
/* gather selected sub-tensors */
void _Gather(const XTensor * s, XTensor * t, XTensor * srcIndex);
/* gather selected sub-tensors (return an XTensor structure)
......
...@@ -219,7 +219,6 @@ void _SpreadForCopyIndexed(XTensor * s, XTensor * c, int dim,
}
}
}
/*
...@@ -236,15 +235,18 @@ void _SpreadForGather(XTensor * source, XTensor * collection, XTensor * index)
int order = source->order;
CheckNTErrors(source->dataType == DEFAULT_DTYPE, "TODO!");
CheckNTErrors(collection->GetDim(-1) == source->GetDim(-1), "Illegal dimension!");
CheckNTErrors(collection->unitNum/collection->GetDim(-1) == index->unitNum,
"Illegal dimension!");
//for(int i = 0; i < order; i++){
//    if(i == dim){
//        CheckNTErrors(collection->GetDim(i) == index->unitNum, "Illegal dimension!");
//    }
//    else {
//        CheckNTErrors(collection->GetDim(i) == source->GetDim(i), "Illegal dimension!");
//    }
//}
#ifdef USE_CUDA
if(source->devID >= 0 && collection->devID >= 0) {
......
...@@ -137,6 +137,115 @@ XTensor Concatenate(const TensorList &smalls, int dim)
}
}
bool CheckConcatenateShape(const TensorList &smalls, int dim, XTensor &big, bool uniform)
{
XTensor * tensor = (XTensor*)smalls.GetItem(0);
int order = tensor->order;
int * dimSize = new int[order];
if (uniform) {
for (int i = 0; i < tensor->order; i++) {
if (i != dim)
dimSize[i] = tensor->dimSize[i];
else
dimSize[i] = tensor->dimSize[dim] * smalls.count;
}
}
else {
for (int i = 0; i < tensor->order; i++)
if (i != dim)
dimSize[i] = tensor->dimSize[i];
int catDimSize = 0;
for (int i = 0; i < smalls.count; i++) {
XTensor * tensor = (XTensor*)smalls.GetItem(i);
catDimSize += tensor->dimSize[dim];
}
dimSize[dim] = catDimSize;
}
for (int i = 0; i < order; i++) {
if (dimSize[i] != big.dimSize[i]) {
delete[] dimSize;
return false;
}
}
delete[] dimSize;
return true;
}
void Concatenate(const TensorList & smalls, XTensor & big, int dim)
{
CheckNTErrors(smalls.count > 0, "Empty list!");
CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!");
bool uniform = true;
for (int i = 1; i < smalls.count; i++) {
XTensor * a = (XTensor*)smalls.GetItem(i - 1);
XTensor * b = (XTensor*)smalls.GetItem(i);
CheckNTErrors((a && b), "Empty input tensors!");
if (!XTensor::IsSameShaped(a, b))
uniform = false;
}
if (!big.isInit || !CheckConcatenateShape(smalls, dim, big, uniform)) {
XTensor * tensor = (XTensor*)smalls.GetItem(0);
int order = tensor->order;
int * dimSize = new int[order];
if (uniform) {
for (int i = 0; i < tensor->order; i++) {
if (i != dim)
dimSize[i] = tensor->dimSize[i];
else
dimSize[i] = tensor->dimSize[dim] * smalls.count;
}
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
InitTensor(&big, order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
}
else {
for (int i = 0; i < tensor->order; i++)
if (i != dim)
dimSize[i] = tensor->dimSize[i];
int catDimSize = 0;
for (int i = 0; i < smalls.count; i++) {
XTensor * tensor = (XTensor*)smalls.GetItem(i);
catDimSize += tensor->dimSize[dim];
}
dimSize[dim] = catDimSize;
float dr = (!tensor->isSparse) ? 1.0F : tensor->denseRatio;
InitTensor(&big, order, dimSize, tensor->dataType, dr, tensor->devID, tensor->mem);
}
/* destroy variables */
delete[] dimSize;
}
if (uniform) {
/* call _Merge function */
_Merge(&smalls, &big, dim);
/* tensor connection */
if (big.enableGrad) {
XLink::MakeLink(&smalls, &big, SHAPE_MERGE);
XLink::AddParamToHeadInt(&big, dim);
}
}
else {
/* call _ConcatenateSolely function */
_ConcatenateSolely(&smalls, &big, dim);
/* tensor connection */
if (big.enableGrad) {
XLink::MakeLink(&smalls, &big, SHAPE_CONCATENATE);
XLink::AddParamToHeadInt(&big, dim);
}
}
}
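A usage sketch of the new void Concatenate above (illustrative only; a, b and c are assumed to be tensors of compatible shapes created by the caller):
void ConcatenateExample(XTensor &a, XTensor &b, XTensor &c)
{
    TensorList smalls(3);
    smalls.Add(&a);
    smalls.Add(&b);
    smalls.Add(&c);

    /* big is (re-)initialized when its shape does not already match, then filled
       by _Merge (identical small shapes) or _ConcatenateSolely (differing shapes) */
    XTensor big;
    Concatenate(smalls, big, 0);
}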
/*
concatenate two tensors along a given dimension
......
...@@ -41,6 +41,8 @@ Note that this is actually a wrapper that selects
*/
XTensor Concatenate(const TensorList &smalls, int dim);
void Concatenate(const TensorList & smalls, XTensor & big, int dim);
/* concatenate two tensors along a given dimension */
void _Concatenate(const XTensor * smallA, const XTensor * smallB, XTensor * big, int dim);
......
...@@ -273,16 +273,16 @@ void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim)
merge small tensors into a big tensor
>> smalls - the list of the small tensors
>> t - the merged tensor (for return)
>> whereToMerge - the merging operation is along with which dimension
*/
void _Merge(const TensorList * smalls, XTensor * t, int whereToMerge)
{
whereToMerge = (whereToMerge < 0 ? t->order - 1 : whereToMerge);
CheckNTErrors((smalls != NULL), "Invalid list!");
CheckNTErrors((smalls->count > 0), "Empty list!");
CheckNTErrors((whereToMerge >= 0 && whereToMerge < t->order), "Wrong range of whereToMerge");
bool uniform = true;
...@@ -292,7 +292,7 @@ void _Merge(const TensorList * smalls, XTensor * big, int whereToMerge)
for (int i = 0; i < smalls->count; i++) {
XTensor* smallsItem = smalls->GetItem(i);
CheckNTErrors((t->unitNum == smallsItem->unitNum * mergeNum), "Unmatched tensors!");
if (i > 0) {
XTensor * preItem = smalls->GetItem(i - 1);
...@@ -325,17 +325,17 @@ void _Merge(const TensorList * smalls, XTensor * big, int whereToMerge)
/* merging with fewer data copy operations */
if (mergedNum * gridNum <= MIN_TENSOR_MERGE_LIST_NUM) {
int sPitch = blockSize * s0->unitSize;
int tPtich = blockSize * mergedNum * t->unitSize;
int mSize = blockSize * t->unitSize;
int n = blockNum;
int sStep = 0;
int tStep = blockSize * t->unitSize;
for (int g = 0; g < gridNum; g++) {
char * tData = (char*)t->data + g * blockSize * blockNum * t->unitSize;
for (int k = 0; k < mergedNum; k++) {
XTensor * s = smalls->GetItem(k);
char * sData = (char*)s->data + g * blockSize * blockNum * s->unitSize;
XMemCopy2D(tData + k * tStep, tPtich, t->devID,
sData + k * sStep, sPitch, s->devID,
mSize, n);
}
...@@ -358,7 +358,7 @@ void _Merge(const TensorList * smalls, XTensor * big, int whereToMerge)
if (uniform)
dataTMP = smallsItem0->data;
else
dataTMP = mem != NULL ? mem->AllocBuf(mem->devID, size) : XMemAlloc(t->devID, size);
tensorTMP->data = dataTMP;
...@@ -370,7 +370,7 @@ void _Merge(const TensorList * smalls, XTensor * big, int whereToMerge)
}
}
_Merge(tensorTMP, t, whereToMerge + 1);
delete[] dimSizeTMP;
...@@ -380,7 +380,7 @@ void _Merge(const TensorList * smalls, XTensor * big, int whereToMerge)
if ((!uniform) && (mem != NULL))
mem->ReleaseBuf(mem->devID, size);
else
XMemFree(t->devID, dataTMP);
}
}
......
...@@ -36,7 +36,7 @@ XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim = -1);
void Merge(const XTensor &s, XTensor &t, int whereToMerge, int leadingDim = -1);
/* merge small tensors into a big tensor */
void _Merge(const TensorList * smalls, XTensor * t, int whereToMerge);
/* merge small tensors into a big tensor (return an XTensor structure) */
XTensor Merge(const TensorList &smalls, int whereToMerge);
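The list-based Merge declared above can be used in the same style as the Concatenate sketch earlier (illustrative only, under the assumption that the small tensors all share one shape, which _Merge checks):
XTensor MergeExample(XTensor &a, XTensor &b)
{
    TensorList smalls(2);
    smalls.Add(&a);
    smalls.Add(&b);

    /* merge the two tensors along the first dimension */
    return Merge(smalls, 0);
}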
......
...@@ -31,7 +31,7 @@
namespace nts { // namespace nts(NiuTrans.Tensor)
/*
transform a tensor by splitting it, e.g., (N, M) -> (3, N/3, M)
>> s - the source tensor
>> t - the target tensor (for return)
...@@ -61,7 +61,7 @@ void _Split(const XTensor * s, XTensor * t, int whereToSplit, int splitNum)
}
/* for the case that we split the last dimension. Actually
(N, M) and (3, N/3, M) have the same memory layout */
if (s->order - 1 == whereToSplitRDI) {
XMemCopy(t->data, t->devID, s->data, s->devID, s->unitNum * s->unitSize);
return;
...@@ -184,7 +184,7 @@ bool CheckSplitSize(const XTensor * s, const XTensor * t, int whereToSplit, int
}
/*
transform a tensor by splitting it, e.g., (N, M) -> (3, N/3, M) (return an XTensor structure)
make a new tensor to keep the result and return it
>> s - the source tensor
......
...@@ -27,6 +27,7 @@
#include "../XTensor.h"
#include "Dropout.h"
#include "DropoutWithIndex.h"
#include "HardTanH.h"
#include "Identity.h"
#include "LogSoftmax.h"
......
...@@ -16,14 +16,13 @@
*/
/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-25
*/
#include <stdlib.h>
#include "../XName.h"
#include "HardTanH.h"
#include "HardTanH.cuh"
#include "../loss/LHeader.h"
namespace nts{ // namespace nts(NiuTrans.Tensor)
...@@ -37,27 +36,27 @@ y = 1 if x > 1
*/
void _HardTanH(const XTensor * x, XTensor * y)
{
CheckNTErrors(XTensor::IsSameShaped(x, y),
"The input tensor and output tensor must have the same shape!")
#ifdef USE_CUDA
if(x->devID >= 0 || y->devID >= 0){
_CudaHardTanH(x, y);
return;
}
#endif
int n = x->GetSize();
DTYPE * ip = (DTYPE*)x->data;
DTYPE * op = (DTYPE*)y->data;
for(int i = 0; i < n; i++){
DTYPE p = ip[i];
if(p > 1.0)
p = 1.0;
else if(p < -1.0)
p = -1.0;
op[i] = p;
}
}
/*
...@@ -111,50 +110,36 @@ hard tanh: y = 1 if x > 1
and dy/dx = 1 if -1 <= x <= 1
0 otherwise
>> y - output of the hardtanh function
>> x - input of the hardtanh function
>> dedy - dE/dy
>> dedx - dE/dx
*/
void _HardTanHBackward(XTensor * y, XTensor * x,
XTensor * dedy, XTensor * dedx)
{
CheckNTErrors(x != NULL, "The input tensor x must not be NULL!")
#ifdef USE_CUDA
if(x->devID >= 0){
_CudaHardTanHBackward(y, x, dedy, dedx);
return;
}
#endif
DTYPE * dedyp = (DTYPE*)dedy->data;
DTYPE * dedxp = (DTYPE*)dedx->data;
DTYPE * ip = (DTYPE*)x->data;
int size = x->unitNum;
/* dE/dx = dE/dy * dy/dx */
for(int i = 0; i < size; i++){
DTYPE s = ip[i];
if(s > 1.0 || s < -1.0)
dedxp[i] = 0;
else
dedxp[i] = dedyp[i];
}
}
} // namespace nts(NiuTrans.Tensor)
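For reference, a minimal calling sketch of the refactored backward pass (illustrative only; all four tensors are assumed to be created by the caller with matching shapes, and dedy is assumed to already hold dE/dy, since the gold/lossName handling that used to live here was removed in this commit):
void HardTanHBackwardExample(XTensor * y, XTensor * x, XTensor * dedy, XTensor * dedx)
{
    /* dE/dx = dE/dy where -1 <= x <= 1, and 0 elsewhere */
    _HardTanHBackward(y, x, dedy, dedx);
}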
...@@ -21,8 +21,6 @@
#include "HardTanH.h"
#include "HardTanH.cuh"
#include "Loss.cuh"
#include "../loss/CrossEntropy.cuh"
#include "../XDevice.h" #include "../XDevice.h"
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
...@@ -63,25 +61,19 @@ y = 1 if x > 1 ...@@ -63,25 +61,19 @@ y = 1 if x > 1
*/ */
void _CudaHardTanH(const XTensor * x, XTensor * y) void _CudaHardTanH(const XTensor * x, XTensor * y)
{ {
if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){ CheckNTErrors(!x->isSparse && !y->isSparse,
"The hard tanh activation function does not support sparse tensors.");
CheckNTErrors(!x->isSparse && !y->isSparse, "The hard tanh activation function does not support sparse tensors."); int gridSize[3], blockSize[3];
CheckNTErrors(x->unitNum && y->unitNum, "The x vectors must be of the same length.");
int gridSize[3], blockSize[3]; GDevs.GetCudaThread(x->devID, x->unitNum, gridSize, blockSize);
GDevs.GetCudaThread(x->devID, x->unitNum, gridSize, blockSize); int devIDBackup;
ProtectCudaDev(x->devID, devIDBackup);
int devIDBackup; KernelHardtanhCompute<<<dim3(gridSize[0]), dim3(blockSize[0])>>>((DTYPE*)x->data, (DTYPE*)y->data, x->unitNum);
ProtectCudaDev(x->devID, devIDBackup);
KernelHardtanhCompute<<<dim3(gridSize[0]), dim3(blockSize[0])>>>((DTYPE*)x->data, (DTYPE*)y->data, x->unitNum); BacktoCudaDev(x->devID, devIDBackup);
BacktoCudaDev(x->devID, devIDBackup);
}
else{
ShowNTErrors("TODO!");
}
} }
/* /*
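The body of KernelHardtanhCompute lies outside this hunk. For orientation only, an element-wise kernel that matches the launch above (one thread per element, the same clamp as the CPU path) could look like the sketch below; the name KernelHardtanhComputeSketch and the body are illustrative, not code from the repository:

__global__
void KernelHardtanhComputeSketch(DTYPE * x, DTYPE * y, int size)
{
    /* one thread per element; guard against the last, partially filled block */
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size) {
        DTYPE p = x[i];
        if (p > (DTYPE)1.0)
            p = (DTYPE)1.0;
        else if (p < (DTYPE)-1.0)
            p = (DTYPE)-1.0;
        y[i] = p;
    }
}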
@@ -92,13 +84,12 @@ dy/dx = 1 if -1 <= x <= 1
 >> dedy - dE/dy
 >> dedx - dE/dx
->> gold - gold standard
 >> y - y of the function
 >> x - x of the function
 >> size - size of y/x
 */
 __global__
-void KernelHardtanhBackward(DTYPE * dedy, DTYPE * dedx, DTYPE * gold, DTYPE * y, DTYPE * x, int size)
+void KernelHardtanhBackward(DTYPE * dedy, DTYPE * dedx, DTYPE * x, int size)
 {
     int i = blockDim.x * blockIdx.x + threadIdx.x;
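The hunk stops right after the thread-index computation, but under the new signature the rest of the kernel only needs x and dE/dy: inside [-1, 1] the gradient passes through, outside it is zeroed. A sketch of how the body plausibly continues (mirroring the CPU loop earlier in this commit; not copied from the repository):

__global__
void KernelHardtanhBackwardSketch(DTYPE * dedy, DTYPE * dedx, DTYPE * x, int size)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size) {
        DTYPE s = x[i];
        /* dy/dx = 1 inside [-1, 1], 0 outside */
        dedx[i] = (s > (DTYPE)1.0 || s < (DTYPE)-1.0) ? (DTYPE)0.0 : dedy[i];
    }
}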
@@ -123,44 +114,29 @@ hard tanh: y = 1 if x > 1
 and dy/dx = 1 if -1 <= x <= 1
             0 otherwise
 
->> gold - gold standard to measure error (or loss)
->> y - output of the function
->> x - input of the function
+>> y - output of the hardtanh function
+>> x - input of the hardtanh function
 >> dedy - dE/dy
 >> dedx - dE/dx
->> lossName - type of loss function, e.g., cross entropy
 */
-void _CudaHardTanHBackward(XTensor * gold, XTensor * y, XTensor * x,
-                           XTensor * dedy, XTensor * dedx,
-                           LOSS_FUNCTION_NAME lossName)
+void _CudaHardTanHBackward(XTensor * y, XTensor * x,
+                           XTensor * dedy, XTensor * dedx)
 {
-    if(x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE){
-        /* calculate dE/dy */
-        if(lossName == CROSSENTROPY)
-            _CudaCrossEntropyBackward(dedy, y, gold);
-        else if(lossName != NOLOSS)
-            _CudaLossBackward(dedy, gold, y, lossName);
-
     int gridSize[3], blockSize[3];
     GDevs.GetCudaThread(x->devID, x->unitNum, gridSize, blockSize);
 
     int devIDBackup;
     ProtectCudaDev(x->devID, devIDBackup);
 
     /* dE/dx = dE/dy * dy/dx */
     KernelHardtanhBackward<<<dim3(gridSize[0]),dim3(blockSize[0])>>>
                           ((DTYPE*)dedy->data,
                            (DTYPE*)dedx->data,
-                           gold == NULL ? NULL : (DTYPE*)gold->data,
-                           (DTYPE*)y->data, (DTYPE*)x->data,
+                           (DTYPE*)x->data,
                            x->unitNum);
 
     BacktoCudaDev(x->devID, devIDBackup);
-    }
-    else
-        ShowNTErrors("TODO!");
 }
 
 #endif
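With the loss computation removed, the backward pass is a pure element-wise Jacobian product, which makes it easy to sanity-check against finite differences on the host. The following self-contained C++ check is independent of the repository code and only restates the hard tanh rule used above:

#include <cstdio>

static float HardTanh(float x)     { return x > 1.0f ? 1.0f : (x < -1.0f ? -1.0f : x); }
static float HardTanhGrad(float x) { return (x > 1.0f || x < -1.0f) ? 0.0f : 1.0f; }

int main()
{
    const float eps = 1e-3f;
    const float points[4] = {-1.5f, -0.3f, 0.7f, 2.0f};   /* stay away from the kinks at +/-1 */
    for (int i = 0; i < 4; i++) {
        float x = points[i];
        /* central-difference approximation of dy/dx */
        float numeric = (HardTanh(x + eps) - HardTanh(x - eps)) / (2.0f * eps);
        printf("x = %5.2f  analytic = %g  numeric = %g\n", x, HardTanhGrad(x), numeric);
    }
    return 0;
}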
@@ -23,7 +23,6 @@
 #define __HARDTANH_CUH__
 
 #include "../XTensor.h"
-#include "Loss.h"
 
 namespace nts{ // namespace nts(NiuTrans.Tensor)

@@ -38,9 +37,8 @@ y = 1 if x > 1
 void _CudaHardTanH(const XTensor * input, XTensor * output);
 
 /* de/dx (Cuda version) */
-void _CudaHardTanHBackward(XTensor * gold, XTensor * y, XTensor * x,
-                           XTensor * dedy, XTensor * dedx,
-                           LOSS_FUNCTION_NAME lossName);
+void _CudaHardTanHBackward(XTensor * y, XTensor * x,
+                           XTensor * dedy, XTensor * dedx);
 
 #endif // USE_CUDA