/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northestern University. 
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-31
 */


#include "T2TModel.h"
#include "T2TUtility.h"
#include "../../tensor/core/CHeader.h"
#include "../../tensor/XUtility.h"

namespace transformer
{

/* constructor: set default flags and create the sub-networks */
T2TModel::T2TModel()
{
    /* default device id (-1) and task flags; InitModel overrides
       these from the command line */
    devID = -1;
    isMT = false;
    isLM = false;

    /* number of attention heads (default, reset by InitModel) */
    nhead = 1;

    /* the sub-networks are always created; the decoder is only
       initialized and used when running as a translation model */
    outputLayer = new T2TOutput();
    decoder = new AttDecoder();
    encoder = new AttEncoder();
}

/* de-constructor: release the sub-networks owned by the model */
T2TModel::~T2TModel()
{
    delete outputLayer;
    delete decoder;
    delete encoder;
}

/* 
initialize the model 
>> argc - number of arguments
>> argv - list of pointers to the arguments
*/
57
void T2TModel::InitModel(int argc, char ** argv)
58 59 60
{
    LoadParamInt(argc, argv, "dev", &devID, -1);
    LoadParamBool(argc, argv, "mt", &isMT, false);
61
    LoadParamBool(argc, argv, "lm", &isLM, !isMT);
xuchen committed
62
    LoadParamInt(argc, argv, "nhead", &nhead, 8);
63

64 65
    encoder->InitModel(argc, argv, true, 0, devID);
    outputLayer->InitModel(argc, argv, devID);
66

67
    if(isMT)
68
        decoder->InitModel(argc, argv, true, 0, devID);
69

70
    TensorList params(10);
71 72 73 74 75 76
    GetParams(params);

    for(int i = 0; i < params.count; i++){
        XTensor * param = (XTensor*)params.Get(i);
        param->SetVarFlag();
    }
77 78 79 80 81
}

/* 
make the encoding network
>> input - input tensor
xuchen committed
82
>> mask - the mask for positions that are/not involved in computation
83
>> isTraining - indicates whether we are training the model
84 85
<< return - encoding result
*/
86
XTensor T2TModel::MakeEncoder(XTensor &input, XTensor &mask, bool isTraining)
87
{
88 89 90
    XTensor nothing;

    return encoder->Make(input, mask, nothing, isTraining);
91 92 93
}

/* 
make the decoding network (a thin wrapper around AttDecoder::Make)
>> inputDec - input tensor of the decoder
>> outputEnc - output tensor of the encoder (attended by the decoder)
>> mask - mask for positions that are/not involved in computation
>> maskEncDec - mask for the encoder-decoder attention
>> isTraining - indicates whether we are training the model
<< return - decoding result
*/
XTensor T2TModel::MakeDecoder(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining)
{
    return decoder->Make(inputDec, outputEnc, mask, maskEncDec, isTraining);
}

/* 
make the network for language modeling (with the output softmax layer) 
110 111
>> input - input tensor
>> output - output tensor (distribution)
112
>> padding - padding of the sequences
113
>> isTraining - indicates whether the model is for training
114
*/
115
void T2TModel::MakeLM(XTensor &input, XTensor &output, XTensor &padding, bool isTraining)
116
{
117 118
    XTensor encoding;
    
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
    /* generate mask to see "previous" words only */
    //int len = input.GetDim(input.order - 2);
    //int * dims = new int[input.order + 1];
    //for(int i = 0; i < input.order; i++)
    //    dims[i + 1] = input.GetDim(i);
    //dims[0] = nhead;
    //dims[input.order] = len;
    //XTensor mask(input.order + 1, dims, X_FLOAT, 1.0F, input.devID, input.mem);

    int len = input.GetDim(input.order - 1);
    int * dims = new int[input.order + 2];
    for(int i = 0; i < input.order; i++)
        dims[i + 1] = input.GetDim(i);
    dims[0] = nhead;
    dims[input.order + 1] = len;
134
    XTensor mask;
135
    InitTensor(&mask, input.order + 2, dims, X_FLOAT, padding.devID);
136 137 138 139 140 141

    /* a upper triangular matrix where the cells of the upper triangular are set to -1e-9.
        this matrix can be used to prevent the attention to current or following words in
        a given sequence. */
    _SetDataLowTri(&mask, 1e9F, 0);
    _ScaleAndShiftMe(&mask, 1.0F, -1e9F);
142
        
143 144 145 146 147 148
    int * dimsPadding = new int[padding.order + 2];
    for(int i = 0; i < padding.order - 1; i++)
        dimsPadding[i] = padding.GetDim(i);
    dimsPadding[padding.order - 1] = padding.GetDim(-1);
    dimsPadding[padding.order] = padding.GetDim(-1);

149
    XTensor * padding2 = NewTensorBuf(padding.order + 1, dimsPadding, padding.dataType,
150
                                        padding.devID);
151 152 153 154 155

    for(int i = 0; i < padding2->order; i++)
        dimsPadding[i + 1] = padding2->GetDim(i);
    dimsPadding[0] = nhead;

156
    //XTensor * padding3 = NewTensorBuf(padding.order + 2, dimsPadding, padding.dataType,
157
    //                                    padding.devID);
158 159 160 161 162 163 164 165 166 167
    //    
    ///* mask of the padding */
    //_Unsqueeze(&padding, padding2, padding.order - 1, padding.GetDim(-1));
    //_Unsqueeze(padding2, padding3, 0, nhead);
    //    
    //_ScaleAndShiftMe(padding3, 1e9F, -1e9F);
    //    
    ////_Sum(&mask, padding3, &mask);

    encoding = MakeEncoder(input, mask, isTraining);
168
    outputLayer->Make(encoding, output);
169 170 171

    delete[] dims;
    delete[] dimsPadding;
172
        
173 174 175 176 177 178 179 180 181 182
    //DelTensorBuf(padding3);
    DelTensorBuf(padding2);
}

/* 
make the network for machine translation (with the output softmax layer) 
>> inputEnc - input tensor of the encoder
>> inputDec - input tensor of the decoder
>> output - output tensor (distribution)
>> paddingEnc - padding of the sequences (on the encoder side)
183
>> paddingDec - padding of the sequences (on the decoder side)
184 185
>> isTraining - indicates whether the model is for training
*/
186
void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, XTensor &paddingDec, bool isTraining)
187 188 189 190 191
{
    XTensor encoding;
    XTensor decoding;
    XTensor maskEnc;
    XTensor maskDec;
192
    XTensor maskEncDec;
xuchen committed
193

194 195 196 197 198
    /* encoder mask */
    MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc);
    
    /* decoder mask */
    MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec);
xiaotong committed
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220

    encoding = MakeEncoder(inputEnc, maskEnc, isTraining);

    decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);

    outputLayer->Make(decoding, output);
}

/* 
make the mask for training MT models 
>> inputEnc - input of the encoder
>> inputDec - input of the decoder
>> paddingEnc - padding of the encoder input
>> paddingDec - padding of the decoder input
>> maskEnc - mask of the encoder self-attention
>> maskDec - mask of the decoder self-attention
>> maskEncDec - mask of the decoder enc-dec attention
*/
void T2TModel::MakeMTMask(XTensor &inputEnc,   XTensor &inputDec, 
                          XTensor &paddingEnc, XTensor &paddingDec, 
                          XTensor &maskEnc,    XTensor &maskDec,    XTensor &maskEncDec)
{
    /* shape of the decoder self-attention mask:
       (nhead, <dims of inputDec>, lenDec) */
    int len = inputDec.GetDim(inputDec.order - 1);
    int * dims = new int[inputDec.order + 2];
    for(int i = 0; i < inputDec.order; i++)
        dims[i + 1] = inputDec.GetDim(i);
    dims[0] = nhead;
    dims[inputDec.order + 1] = len;
    InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);

    /* an upper triangular matrix where the cells of the upper triangular
       are set to -1e9. this matrix can be used to prevent the attention
       to current or following words in a given sequence. */
    _SetDataLowTri(&maskDec, 1e9F, 0);
    _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);

    /* encoder-decoder mask that prevents the attention to padding dummy words */
    dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
    InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID);

    XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
                                                paddingEnc.devID);

    /* repeat the source padding along a new target-length axis, then map
       its entries via x -> x * 1e9 - 1e9 (assuming entries are 1 for real
       words and 0 for pads — TODO confirm — this gives 0 / -1e9), and
       finally replicate over the attention heads */
    _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
    _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
    _Unsqueeze(maskEncDecTMPEnc, &maskEncDec, 0, dims[0]);

    /* NOTE(review): a second temporary buffer (maskEncDecTMPDec) was
       allocated and freed here without ever being used; it has been removed */
    DelTensorBuf(maskEncDecTMPEnc);

    /* padding on the source side */
    int * dimsPadding = new int[paddingEnc.order + 2];
    for (int i = 0; i < paddingEnc.order - 1; i++)
        dimsPadding[i] = paddingEnc.GetDim(i);
    dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
    dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);

    XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType,
                                        paddingEnc.devID);

    for (int i = 0; i < padding2->order; i++)
        dimsPadding[i + 1] = padding2->GetDim(i);
    dimsPadding[0] = nhead;

    XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType,
                                        paddingEnc.devID);

    /* mask of the padding */
    _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
    _Unsqueeze(padding2, padding3, 0, nhead);

    _ScaleAndShiftMe(padding3, 1e9F, -1e9F);

    InitTensor(&maskEnc, padding3);
    maskEnc.SetZeroAll();

    /* generate the mask on the source language side (for padding) */
    _Sum(&maskEnc, padding3, &maskEnc);

    delete[] dims;
    delete[] dimsPadding;

    /* free buffered tensors in reverse order of allocation */
    DelTensorBuf(padding3);
    DelTensorBuf(padding2);
}
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
    
/*
make the mask of the encoder
>> inputEnc - input of the encoder
>> paddingEnc - padding of the encoder input
>> maskEnc - mask of the encoder self-attention
*/
void T2TModel::MakeMTMaskEnc(XTensor &inputEnc, XTensor &paddingEnc, XTensor &maskEnc)
{
    /* padding on the source side */
    int * dimsPadding = new int[paddingEnc.order + 2];
    for (int i = 0; i < paddingEnc.order - 1; i++)
        dimsPadding[i] = paddingEnc.GetDim(i);
    dimsPadding[paddingEnc.order - 1] = paddingEnc.GetDim(-1);
    dimsPadding[paddingEnc.order] = paddingEnc.GetDim(-1);
    
301
    XTensor * padding2 = NewTensorBuf(paddingEnc.order + 1, dimsPadding, paddingEnc.dataType,
302
                                        paddingEnc.devID);
303 304 305 306 307
    
    for (int i = 0; i < padding2->order; i++)
        dimsPadding[i + 1] = padding2->GetDim(i);
    dimsPadding[0] = nhead;
    
308
    XTensor * padding3 = NewTensorBuf(paddingEnc.order + 2, dimsPadding, paddingEnc.dataType,
309
                                        paddingEnc.devID);
310 311 312 313 314 315 316
    
    /* mask of the padding */
    _Unsqueeze(&paddingEnc, padding2, paddingEnc.order - 1, paddingEnc.GetDim(-1));
    _Unsqueeze(padding2, padding3, 0, nhead);
    
    _ScaleAndShiftMe(padding3, 1e9F, -1e9F);
    
317
    InitTensor(&maskEnc, padding3);
318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
    maskEnc.SetZeroAll();
    
    /* generate the mask on the source language side (for padding) */
    _Sum(&maskEnc, padding3, &maskEnc);
    
    DelTensorBuf(padding3);
    DelTensorBuf(padding2);
    delete[] dimsPadding;
}
    
/*
make the mask of the decoder
>> inputEnc - input of the encoder
>> inputDec - input of the decoder
>> paddingEnc - padding of the encoder input
>> paddingDec - padding of the decoder input
>> maskDec - mask of the decoder self-attention
>> maskEncDec - mask of the decoder enc-dec attention
*/
void T2TModel::MakeMTMaskDec(XTensor &inputEnc, XTensor &inputDec,
                             XTensor &paddingEnc, XTensor &paddingDec,
                             XTensor &maskDec, XTensor &maskEncDec)
{
    /* shape of the decoder self-attention mask:
       (nhead, <dims of inputDec>, lenDec) */
    int len = inputDec.GetDim(inputDec.order - 1);
    int * dims = new int[inputDec.order + 2];
    for(int i = 0; i < inputDec.order; i++)
        dims[i + 1] = inputDec.GetDim(i);
    dims[0] = nhead;
    dims[inputDec.order + 1] = len;
    InitTensor(&maskDec, inputDec.order + 2, dims, X_FLOAT, paddingDec.devID);
    
    /* An upper triangular matrix where the cells of the upper triangular are
       set to -1e9. This matrix can be used to block the attention to current
       or following words in a given sequence. */
    _SetDataLowTri(&maskDec, 1e9F, 0);
    _ScaleAndShiftMe(&maskDec, 1.0F, -1e9F);
    
    /* encoder-decoder mask that prevents the attention to padding dummy words */
    dims[inputDec.order + 1] = inputEnc.GetDim(inputEnc.order - 1);
    InitTensor(&maskEncDec, inputDec.order + 2, dims, X_FLOAT, paddingEnc.devID);
    
    XTensor * maskEncDecTMPEnc = NewTensorBuf(paddingEnc.order + 1, dims + 1, paddingEnc.dataType,
                                                paddingEnc.devID);
    
    /* repeat the source padding once per target position, map its entries via
       x -> x * 1e9 - 1e9 (assuming entries are 1 for real words and 0 for
       pads — TODO confirm — this gives 0 / -1e9), then copy over the heads */
    _Unsqueeze(&paddingEnc, maskEncDecTMPEnc, paddingEnc.order - 1, paddingDec.GetDim(-1));
    _ScaleAndShiftMe(maskEncDecTMPEnc, 1e9F, -1e9F);
    _Unsqueeze(maskEncDecTMPEnc, &maskEncDec, 0, dims[0]);
    
    /* NOTE(review): a second temporary buffer (maskEncDecTMPDec) was
       allocated and freed here without ever being used; it has been removed */
    DelTensorBuf(maskEncDecTMPEnc);
    delete[] dims;
}
/* 
get parameter matrices
>> list - the list that keeps the parameter matrices

NOTE: the order of the Add() calls below defines the serialization order
used by Dump() and Read() — do NOT reorder them, or saved models will no
longer load correctly.
*/
void T2TModel::GetParams(TensorList &list)
{
    list.Clear();

    /* output (softmax) projection */
    list.Add(&outputLayer->w);
    
    /* per-layer encoder parameters */
    for(int i = 0; i < encoder->nlayer; i++){
        /* feed-forward sub-layer */
        list.Add(&encoder->fnns[i].w1);
        list.Add(&encoder->fnns[i].b1);
        list.Add(&encoder->fnns[i].w2);
        list.Add(&encoder->fnns[i].b2);
        /* self-attention: a single fused projection (wbig) is used
           instead of separate wk/wq/wv (disabled below) */
        //list.Add(&encoder->attentions[i].wk);
        //list.Add(&encoder->attentions[i].wq);
        //list.Add(&encoder->attentions[i].wv);
        list.Add(&encoder->attentions[i].wbig);
        list.Add(&encoder->attentions[i].wa);
        /* layer normalizations */
        list.Add(&encoder->fnnLayerNorms[i].w);
        list.Add(&encoder->fnnLayerNorms[i].b);
        list.Add(&encoder->attLayerNorms[i].w);
        list.Add(&encoder->attLayerNorms[i].b);
    }
    
    /* source-side embeddings */
    list.Add(&encoder->embedder.w);

    /* decoder parameters exist only in translation mode */
    if(isMT){
        for(int i = 0; i < decoder->nlayer; i++){
            /* feed-forward sub-layer */
            list.Add(&decoder->fnns[i].w1);
            list.Add(&decoder->fnns[i].b1);
            list.Add(&decoder->fnns[i].w2);
            list.Add(&decoder->fnns[i].b2);
            /* encoder-decoder attention */
            list.Add(&decoder->attentionsEnde[i].wk);
            list.Add(&decoder->attentionsEnde[i].wq);
            list.Add(&decoder->attentionsEnde[i].wv);
            list.Add(&decoder->attentionsEnde[i].wa);
            list.Add(&decoder->attEndeLayerNorms[i].w);
            list.Add(&decoder->attEndeLayerNorms[i].b);
            /* self-attention: fused projection, as in the encoder */
            //list.Add(&decoder->attentions[i].wk);
            //list.Add(&decoder->attentions[i].wq);
            //list.Add(&decoder->attentions[i].wv);
            list.Add(&decoder->attentions[i].wbig);
            list.Add(&decoder->attentions[i].wa);
            /* layer normalizations */
            list.Add(&decoder->fnnLayerNorms[i].w);
            list.Add(&decoder->fnnLayerNorms[i].b);
            list.Add(&decoder->attLayerNorms[i].w);
            list.Add(&decoder->attLayerNorms[i].b);
        }
        
        /* target-side embeddings */
        list.Add(&decoder->embedder.w);
    }
}

/*
dump the parameters 
>> fn - where to keep the model
>> model - the model
*/
void T2TModel::Dump(const char * fn)
{
xiaotong committed
446 447
    double startT = GetClockSec();

448 449 450
    FILE * file = fopen(fn, "wb");
    CheckNTErrors(file, "Cannot open the model file");

451
    TensorList params(100);
452 453 454 455 456 457 458 459 460 461

    GetParams(params);

    for(int i = 0; i < params.count; i++){
        XTensor * p = (XTensor*)params.Get(i);
        p->Dump(file, "param:");
    }

    fclose(file);

xiaotong committed
462 463 464
    double elapsed = GetClockSec() - startT;

    XPRINT1(0, stderr, "[INFO] model saved (took %.1fs)\n", elapsed);
465 466 467 468 469
}

/* read the parameters */
void T2TModel::Read(const char * fn)
{
xiaotong committed
470 471
    double startT = GetClockSec();

472 473 474
    FILE * file = fopen(fn, "rb");
    CheckNTErrors(file, "Cannot open the model file");

475
    TensorList params(100);
476 477 478 479 480 481 482 483 484 485

    GetParams(params);

    for(int i = 0; i < params.count; i++){
        XTensor * p = (XTensor*)params.Get(i);
        p->Read(file, "param:");
    }

    fclose(file);

xiaotong committed
486 487 488
    double elapsed = GetClockSec() - startT;

    XPRINT1(0, stderr, "[INFO] model loaded (took %.1fs)\n", elapsed);
489 490
}

} /* end of namespace transformer */