/* T2TBatchLoader.h */
/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-04-25
 * it is cold today but i'll move to a warm place tomorrow :)
 */

#ifndef __T2TBATCHLOADER_H__
#define __T2TBATCHLOADER_H__

#include "../../network/XNet.h"

using namespace nts;

namespace transformer
{

/* sequence-length limit (4096); how it is applied is not visible in this header.
   The replacement list is parenthesized so that the macro keeps its value when
   it is expanded inside a larger expression, e.g., x / MAX_SEQUENCE_LENGTH or
   x % MAX_SEQUENCE_LENGTH */
#define MAX_SEQUENCE_LENGTH (1024 * 4)

/* record that keeps the information of one batch
   (plain aggregate; field order is part of its layout) */
struct BatchNode
{
    int beg;     /* beginning position */
    int end;     /* end position (NOTE(review): inclusive or exclusive? confirm in the .cpp) */
    int maxEnc;  /* maximum word number on the encoder side */
    int maxDec;  /* maximum word number on the decoder side */
    int key;     /* a key for sorting */
};

/* data loader: reads sequences from a file into internal buffers and
   hands them out as batches (for language modeling or machine translation).

   NOTE(review): the class declares a destructor and holds raw pointers but
   declares no copy constructor/assignment. If the destructor releases these
   buffers (not visible in this header), copying an instance is unsafe -
   confirm against the implementation. */
class T2TBatchLoader
{
public:
    /* ---------------- data members ---------------- */

    /* buffer that receives the loaded words */
    int * buf;

    /* a second buffer */
    int * buf2;

    /* buffer of batch nodes */
    BatchNode * bufBatch;

    /* buffer size */
    int bufSize;

    /* size of the batch buffer */
    int bufBatchSize;

    /* length of every sequence */
    int * seqLen;

    /* a second length array */
    int * seqLen2;

    /* for every sequence, the offset of its first word */
    int * seqOffset;

    /* how many sequences the buffer holds */
    int nseqBuf;

    /* offset of the next sequence in the buffer */
    int nextSeq;

    /* offset of the next batch */
    int nextBatch;

    /* true if the </s> symbol is doubled for the output of lms */
    bool isDoubledEnd;

    /* true if batchsize = max * sc is used rather than
       batchsize = word-number, where max is the maximum
       length and sc is the sentence number */
    bool isSmallBatch;

    /* the counterpart of "isSmallBatch" */
    bool isBigBatch;

    /* true if batches are randomized */
    bool isRandomBatch;

    /* size of a bucket */
    int bucketSize;

    /* ---------------- methods ---------------- */

    /* constructor */
    T2TBatchLoader();

    /* destructor */
    ~T2TBatchLoader();

    /* initialize the loader from the command-line style arguments */
    void Init(int argc, char ** argv);

    /* load data from "file" into the buffer
       (the meaning of the returned value is defined by the implementation) */
    int LoadBuf(FILE * file, bool isSorted, int step);

    /* clear the data buffer */
    void ClearBuf();

    /* set the flag of random batches (turned on if no argument is given) */
    void SetRandomBatch(bool flag = true);

    /* load a batch of sequences; "isLM" selects between the two loaders
       declared below (presumably language modeling when true - confirm
       against the implementation) */
    int LoadBatch(FILE * file,
                  bool isLM,
                  XTensor * batchEnc, XTensor * paddingEnc,
                  XTensor * batchDec, XTensor * paddingDec,
                  XTensor * gold, XTensor * label,
                  int * seqs,
                  int vsEnc, int vsDec,
                  int sBatch, int wBatch,
                  bool isSorted,
                  int &ws, int &wCount,
                  int devID, XMem * mem,
                  bool isTraining);

    /* load a batch of sequences for language modeling */
    int LoadBatchLM(FILE * file,
                    XTensor * batchEnc, XTensor * paddingEnc,
                    XTensor * batchDec, XTensor * paddingDec,
                    XTensor * gold, XTensor * label,
                    int * seqs,
                    int vs,
                    int sBatch, int wBatch,
                    bool isSorted,
                    int &wCount,
                    int devID, XMem * mem,
                    bool isTraining);

    /* load a batch of sequences for machine translation */
    int LoadBatchMT(FILE * file,
                    XTensor * batchEnc, XTensor * paddingEnc,
                    XTensor * batchDec, XTensor * paddingDec,
                    XTensor * gold, XTensor * label,
                    int * seqs,
                    int vsEnc, int vsDec,
                    int sBatch, int wBatch,
                    bool isSorted,
                    int &ws, int &wCount,
                    int devID, XMem * mem,
                    bool isTraining);

    /* shuffle the data file (srcFile -> tgtFile, judging by the parameter names) */
    void Shuffle(const char * srcFile, const char * tgtFile);
};
}

#endif