Commit 430f0dfc by xiaotong

add the decoder for transformer

parent 7250ec45
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-10-09
*/
#include <math.h>
#include "T2TDecoder.h"
#include "../../tensor/core/CHeader.h"
namespace transformer
{
/* constructor */
AttDecoder::AttDecoder()
{
attentionsEnde = NULL;
attEndeLayerNorms = NULL;
}
/* de-constructor */
AttDecoder::~AttDecoder()
{
delete[] attentionsEnde;
delete[] attEndeLayerNorms;
}
/*
initialize the model
>> argc - number of arguments
>> argv - list of pointers to the arguments
>> myIsMasked - indicates whether the masked attention is employed
>> myIgnored - number of positions ignored in attention (from the start)
>> myDevID - device id
>> myMem - the memory pool
*/
void AttDecoder::InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID, XMem * myMem)
{
AttEncoder::InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
attentionsEnde = new T2TAttention[nlayer];
attEndeLayerNorms = new T2TLN[nlayer];
/* initialize the stacked layers */
for(int i = 0; i < nlayer; i++){
attentionsEnde[i].InitModel(argc, argv, myIsMasked, myIgnored, myDevID, myMem);
attEndeLayerNorms[i].InitModel(argc, argv, myDevID, myMem);
}
}
/*
make the decoding network
>> input - the input tensor of the decoder
>> encoderOutput - the output tensor of the encoder
>> mask - the mask that indicate each position is valid
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttDecoder::Make(XTensor &input, XTensor &encoderOutput, XTensor &mask, bool isTraining)
{
XTensor x;
x = embedder.Make(input);
/* dropout */
if(isTraining && dropoutP > 0)
x = Dropout(x, dropoutP);
for(int i = 0; i < nlayer; i++){
XTensor att;
XTensor ende;
XTensor ln;
XTensor fnn;
XTensor res;
/******************/
/* self attention */
att = attentions[i].Make(x, x, x, mask, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
att = Dropout(att, dropoutP);
/* residual connection */
res = Sum(att, x);
/* layer normalization */
x = attLayerNorms[i].Make(res);
/*****************************/
/* encoder-decoder attention */
ende = attentionsEnde[i].Make(encoderOutput, x, encoderOutput, mask, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
ende = Dropout(ende, dropoutP);
/* residual connection */
res = Sum(ende, x);
/* layer normalization */
x = attEndeLayerNorms[i].Make(res);
/*******/
/* fnn */
fnn = fnns[i].Make(x, isTraining);
/* dropout */
if(isTraining && dropoutP > 0)
fnn = Dropout(fnn, dropoutP);
/* residual connection */
res = Sum(fnn, x);
/* layer normalization */
x = fnnLayerNorms[i].Make(res);
}
return x;
}
}
\ No newline at end of file
......@@ -22,19 +22,33 @@
#ifndef __T2TDECODER_H__
#define __T2TDECODER_H__
#include "T2TEncoder.h"
namespace transformer
{
class T2TDecoder
class AttDecoder : public AttEncoder
{
public:
/* encoder-decoder attention model of each layer */
T2TAttention * attentionsEnde;
};
class AttDecoder : T2TDecoder
{
/* layer normalization for encoder-decoder attention */
T2TLN * attEndeLayerNorms;
public:
/* constructor */
AttDecoder();
/* deconstructor */
~AttDecoder();
/* initialize the model */
void InitModel(int argc, char ** argv);
void InitModel(int argc, char ** argv,
bool myIsMasked, int myIgnored,
int myDevID = -1, XMem * myMem = NULL);
/* make the decoding network */
XTensor Make(XTensor &input, XTensor &encoderOutput, XTensor &mask, bool isTraining);
};
}
......
......@@ -31,6 +31,10 @@ namespace transformer
/* constructor */
AttEncoder::AttEncoder()
{
attentions = NULL;
fnns = NULL;
attLayerNorms = NULL;
fnnLayerNorms = NULL;
}
/* de-constructor */
......
......@@ -71,6 +71,9 @@ void T2TModel::InitModel(int argc, char ** argv)
encoder.InitModel(argc, argv, isLM, 0, devID, mem);
outputLayer.InitModel(argc, argv, devID, mem);
if(isMT)
decoder.InitModel(argc, argv, true, 0, devID, mem);
XList params(10);
GetParams(params);
......@@ -93,17 +96,16 @@ XTensor T2TModel::MakeEncoding(XTensor &input, XTensor &mask, bool isTraining)
}
/*
make the entire network (with the output softmax layer)
make the entire network for language modeling (with the output softmax layer)
>> input - input tensor
>> output - output tensor (distribution)
>> padding - padding of the sequences
>> isTraining - indicates whether the model is for training
*/
void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding, bool isTraining)
void T2TModel::MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining)
{
XTensor encoding;
if(isLM){
/* generate mask to see "previous" words only */
int len = input.GetDim(input.order - 2);
int * dims = new int[input.order + 1];
......@@ -141,7 +143,7 @@ void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding, bool isTr
_ScaleAndShiftMe(padding3, 1e9F, -1e9F);
_Sum(&mask, padding3, &mask);
//_Sum(&mask, padding3, &mask);
encoding = MakeEncoding(input, mask, isTraining);
outputLayer.Make(encoding, output);
......@@ -151,10 +153,6 @@ void T2TModel::Make(XTensor &input, XTensor &output, XTensor &padding, bool isTr
DelTensorBuf(padding2);
DelTensorBuf(padding3);
}
else{
ShowNTErrors("TODO!");
}
}
/*
......@@ -181,6 +179,29 @@ void T2TModel::GetParams(XList &list)
list.Add(&encoder.attLayerNorms[i].b);
}
if(isMT){
for(int i = 0; i < decoder.nlayer; i++){
list.Add(&decoder.fnns[i].w1);
list.Add(&decoder.fnns[i].b1);
list.Add(&decoder.fnns[i].w2);
list.Add(&decoder.fnns[i].b2);
list.Add(&decoder.attentionsEnde[i].wk);
list.Add(&decoder.attentionsEnde[i].wq);
list.Add(&decoder.attentionsEnde[i].wv);
list.Add(&decoder.attentionsEnde[i].wa);
list.Add(&decoder.attEndeLayerNorms[i].w);
list.Add(&decoder.attEndeLayerNorms[i].b);
list.Add(&decoder.attentions[i].wk);
list.Add(&decoder.attentions[i].wq);
list.Add(&decoder.attentions[i].wv);
list.Add(&decoder.attentions[i].wa);
list.Add(&decoder.fnnLayerNorms[i].w);
list.Add(&decoder.fnnLayerNorms[i].b);
list.Add(&decoder.attLayerNorms[i].w);
list.Add(&decoder.attLayerNorms[i].b);
}
}
list.Add(&encoder.embedder.w);
}
......
......@@ -71,8 +71,8 @@ public:
/* make the encoding network */
XTensor MakeEncoding(XTensor &input, XTensor &mask, bool isTraining);
/* make the entire network (with the output softmax layer) */
void Make(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
/* make the entire network for langauge modeling (with the output softmax layer) */
void MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining);
/* get parameter matrics */
void GetParams(XList &list);
......
......@@ -197,7 +197,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
XTensor output;
/* make the network */
model->Make(batch, output, padding, true);
model->MakeLM(batch, output, padding, true);
/* back-propagation for obtaining gradients */
if (labelSmoothingP > 0)
......@@ -343,7 +343,7 @@ void T2TTrainer::Test(const char * fn, const char * ofn, T2TModel * model)
XTensor output;
/* make the network */
model->Make(batch, output, padding, false);
model->MakeLM(batch, output, padding, false);
int bSize = batch.GetDim(0);
int length = batch.GetDim(1);
......@@ -886,17 +886,12 @@ void T2TTrainer::RescaleOutput(XTensor * output, XTensor * gold, XTensor * paddi
CheckNTErrors(output->order == 3, "Wrong dimension number!");
CheckNTErrors(gold->order == 3, "Wrong dimension number!");
int num = padding->GetDim(0);
XTensor * factor = NewTensorBuf(1, &num, padding->dataType, 1.0F, padding->devID, padding->mem);
_ReduceSum(padding, factor, padding->order - 1);
DTYPE count = _ReduceSumAll(padding);
_ExpMe(output);
_DivDim(output, factor, output, 0);
_ScaleAndShiftMe(output, 1/count);
_LogMe(output);
_DivDim(gold, factor, gold, 0);
DelTensorBuf(factor);
_ScaleAndShiftMe(gold, 1/count);
}
/*
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论