Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
da1d7ca8
Commit
da1d7ca8
authored
Apr 25, 2019
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
tester
parent
1faabe78
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
235 行增加
和
122 行删除
+235
-122
source/sample/transformer/T2TBatchLoader.cpp
+0
-0
source/sample/transformer/T2TBatchLoader.h
+160
-0
source/sample/transformer/T2TSearch.cpp
+30
-7
source/sample/transformer/T2TSearch.h
+1
-1
source/sample/transformer/T2TTester.cpp
+27
-0
source/sample/transformer/T2TTester.h
+11
-0
source/sample/transformer/T2TTrainer.cpp
+0
-0
source/sample/transformer/T2TTrainer.h
+6
-114
没有找到文件。
source/sample/transformer/T2TBatchLoader.cpp
0 → 100644
查看文件 @
da1d7ca8
差异被折叠。
点击展开。
source/sample/transformer/T2TBatchLoader.h
0 → 100644
查看文件 @
da1d7ca8
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-04-25
* it is cold today but i'll move to a warm place tomorrow :)
*/
#ifndef __T2TBATCHLOADER_H__
#define __T2TBATCHLOADER_H__
#include "../../network/XNet.h"
using
namespace
nts
;
namespace
transformer
{
#define MAX_SEQUENCE_LENGTH 1024 * 4
/* node to keep batch information */
struct
BatchNode
{
/* begining position */
int
beg
;
/* end position */
int
end
;
/* maximum word number on the encoder side */
int
maxEnc
;
/* maximum word number on the decoder side */
int
maxDec
;
/* a key for sorting */
int
key
;
};
class
T2TBatchLoader
{
public
:
/* buffer for loading words */
int
*
buf
;
/* another buffer */
int
*
buf2
;
/* batch buf */
BatchNode
*
bufBatch
;
/* buffer size */
int
bufSize
;
/* size of batch buffer */
int
bufBatchSize
;
/* length of each sequence */
int
*
seqLen
;
/* another array */
int
*
seqLen2
;
/* offset of the first word for each sequence */
int
*
seqOffset
;
/* number of sequences in the buffer */
int
nseqBuf
;
/* offset for next sequence in the buffer */
int
nextSeq
;
/* offset for next batch */
int
nextBatch
;
/* indicates whether we double the </s> symbol for the output of lms */
bool
isDoubledEnd
;
/* indicates whether we use batchsize = max * sc
rather rather than batchsize = word-number, where max is the maximum
length and sc is the sentence number */
bool
isSmallBatch
;
/* counterpart of "isSmallBatch" */
bool
isBigBatch
;
/* randomize batches */
bool
isRandomBatch
;
/* bucket size */
int
bucketSize
;
public
:
/* constructor */
T2TBatchLoader
();
/* de-constructor */
~
T2TBatchLoader
();
/* initialization */
void
Init
(
int
argc
,
char
**
argv
);
/* load data to buffer */
int
LoadBuf
(
FILE
*
file
,
bool
isSorted
,
int
step
);
/* clear data buffer */
void
ClearBuf
();
/* load a batch of sequences */
int
LoadBatch
(
FILE
*
file
,
bool
isLM
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
ws
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
/* load a batch of sequences (for language modeling) */
int
LoadBatchLM
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vs
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
/* load a batch of sequences (for machine translation) */
int
LoadBatchMT
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
ws
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
/* shuffle the data file */
void
Shuffle
(
const
char
*
srcFile
,
const
char
*
tgtFile
);
};
}
#endif
\ No newline at end of file
source/sample/transformer/T2TSearch.cpp
查看文件 @
da1d7ca8
...
...
@@ -319,17 +319,40 @@ void T2TSearch::Collect(T2TStateBundle * beam)
/*
save the output sequences in a tensor
>> input - input sequences
>> output - output sequences (for return)
*/
void
T2TSearch
::
Dump
(
XTensor
*
input
,
XTensor
*
output
)
void
T2TSearch
::
Dump
(
XTensor
*
output
)
{
int
dims
[
MAX_TENSOR_DIM_NUM
];
for
(
int
i
=
0
;
i
<
input
->
order
-
1
;
i
++
)
dims
[
i
]
=
input
->
GetDim
(
i
);
dims
[
input
->
order
-
1
]
=
maxLength
;
int
dims
[
3
]
=
{
batchSize
,
beamSize
,
maxLength
};
int
*
words
=
new
int
[
maxLength
];
InitTensor
(
output
,
3
,
dims
,
X_INT
);
SetDataFixedInt
(
*
output
,
-
1
);
/* heap for an input sentence in the batch */
for
(
int
h
=
0
;
h
<
batchSize
;
h
++
){
XHeap
<
MIN_HEAP
,
float
>
&
heap
=
fullHypos
[
h
];
/* for each output in the beam */
for
(
int
i
=
0
;
i
<
beamSize
;
i
++
){
T2TState
*
state
=
(
T2TState
*
)
heap
.
Pop
().
index
;
int
count
=
0
;
/* we track the state from the end to the beginning */
while
(
state
!=
NULL
){
words
[
count
++
]
=
state
->
prediction
;
state
=
state
->
last
;
}
/* dump the sentence to the output tensor */
for
(
int
w
=
0
;
w
<
count
;
w
++
)
output
->
Set3DInt
(
words
[
count
-
w
-
1
],
h
,
i
,
w
);
}
}
InitTensor
(
output
,
input
->
order
,
dims
,
X_INT
)
;
delete
[]
words
;
}
/*
...
...
source/sample/transformer/T2TSearch.h
查看文件 @
da1d7ca8
...
...
@@ -88,7 +88,7 @@ public:
void
Collect
(
T2TStateBundle
*
beam
);
/* save the output sequences in a tensor */
void
Dump
(
XTensor
*
input
,
XTensor
*
output
);
void
Dump
(
XTensor
*
output
);
/* check if the token is an end symbol */
bool
IsEnd
(
int
token
);
...
...
source/sample/transformer/T2TTester.cpp
查看文件 @
da1d7ca8
...
...
@@ -25,4 +25,30 @@ using namespace nts;
namespace
transformer
{
/* constructor */
T2TTester
::
T2TTester
()
{
}
/* de-constructor */
T2TTester
::~
T2TTester
()
{
}
/* initialize the model */
void
T2TTester
::
InitModel
(
int
argc
,
char
**
argv
)
{
}
/*
test the model
>> fn - test data file
>> ofn - output data file
>> model - model that is trained
*/
void
T2TTester
::
Test
(
const
char
*
fn
,
const
char
*
ofn
,
T2TModel
*
model
)
{
}
}
\ No newline at end of file
source/sample/transformer/T2TTester.h
查看文件 @
da1d7ca8
...
...
@@ -32,6 +32,17 @@ namespace transformer
class
T2TTester
{
public
:
/* constructor */
T2TTester
();
/* de-constructor */
~
T2TTester
();
/* initialize the model */
void
InitModel
(
int
argc
,
char
**
argv
);
/* test the model */
void
Test
(
const
char
*
fn
,
const
char
*
ofn
,
T2TModel
*
model
);
};
}
...
...
source/sample/transformer/T2TTrainer.cpp
查看文件 @
da1d7ca8
差异被折叠。
点击展开。
source/sample/transformer/T2TTrainer.h
查看文件 @
da1d7ca8
...
...
@@ -23,35 +23,14 @@
#define __T2TTRAINER_H__
#include "T2TModel.h"
#include "T2TBatchLoader.h"
#include "../../tensor/function/FHeader.h"
#define MAX_SEQUENCE_LENGTH 1024 * 4
using
namespace
nts
;
namespace
transformer
{
/* node to keep batch information */
struct
BatchNode
{
/* begining position */
int
beg
;
/* end position */
int
end
;
/* maximum word number on the encoder side */
int
maxEnc
;
/* maximum word number on the decoder side */
int
maxDec
;
/* a key for sorting */
int
key
;
};
/* trainer of the T2T model */
class
T2TTrainer
{
...
...
@@ -61,42 +40,6 @@ public:
/* parameter array */
char
**
argArray
;
/* buffer for loading words */
int
*
buf
;
/* another buffer */
int
*
buf2
;
/* batch buf */
BatchNode
*
bufBatch
;
/* buffer size */
int
bufSize
;
/* size of batch buffer */
int
bufBatchSize
;
/* length of each sequence */
int
*
seqLen
;
/* another array */
int
*
seqLen2
;
/* offset of the first word for each sequence */
int
*
seqOffset
;
/* number of sequences in the buffer */
int
nseqBuf
;
/* offset for next sequence in the buffer */
int
nextSeq
;
/* offset for next batch */
int
nextBatch
;
/* indicates whether the sequence is sorted by length */
bool
isLenSorted
;
/* dimension size of each inner layer */
int
d
;
...
...
@@ -158,26 +101,15 @@ public:
/* number of batches on which we do model update */
int
updateStep
;
/* indicates whether we double the </s> symbol for the output of lms */
bool
isDoubledEnd
;
/* indicates whether we use batchsize = max * sc
rather rather than batchsize = word-number, where max is the maximum
length and sc is the sentence number */
bool
isSmallBatch
;
/* counterpart of "isSmallBatch" */
bool
isBigBatch
;
/* randomize batches */
bool
isRandomBatch
;
/* indicates whether we intend to debug the net */
bool
isDebugged
;
/* bucket size */
int
bucketSize
;
/* indicates whether the sequence is sorted by length */
bool
isLenSorted
;
/* for batching */
T2TBatchLoader
batchLoader
;
public
:
/* constructor */
...
...
@@ -197,46 +129,6 @@ public:
/* make a checkpoint */
void
MakeCheckpoint
(
T2TModel
*
model
,
const
char
*
validFN
,
const
char
*
modelFN
,
const
char
*
label
,
int
id
);
/* load data to buffer */
int
LoadBuf
(
FILE
*
file
,
bool
isSorted
,
int
step
);
/* clear data buffer */
void
ClearBuf
();
/* load a batch of sequences */
int
LoadBatch
(
FILE
*
file
,
bool
isLM
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
ws
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
/* load a batch of sequences (for language modeling) */
int
LoadBatchLM
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vs
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
/* load a batch of sequences (for machine translation) */
int
LoadBatchMT
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
ws
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
/* shuffle the data file */
void
Shuffle
(
const
char
*
srcFile
,
const
char
*
tgtFile
);
/* get word probabilities for a batch of sequences */
float
GetProb
(
XTensor
*
output
,
XTensor
*
gold
,
XTensor
*
wordProbs
);
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论