Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
0
Issues
0
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
杨迪
NiuTrans.Tensor
Commits
da1d7ca8
Commit
da1d7ca8
authored
5 years ago
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
tester
parent
1faabe78
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
235 行增加
和
122 行删除
+235
-122
source/sample/transformer/T2TBatchLoader.cpp
+0
-0
source/sample/transformer/T2TBatchLoader.h
+160
-0
source/sample/transformer/T2TSearch.cpp
+30
-7
source/sample/transformer/T2TSearch.h
+1
-1
source/sample/transformer/T2TTester.cpp
+27
-0
source/sample/transformer/T2TTester.h
+11
-0
source/sample/transformer/T2TTrainer.cpp
+0
-0
source/sample/transformer/T2TTrainer.h
+6
-114
没有找到文件。
source/sample/transformer/T2TBatchLoader.cpp
0 → 100644
查看文件 @
da1d7ca8
差异被折叠。
点击展开。
source/sample/transformer/T2TBatchLoader.h
0 → 100644
查看文件 @
da1d7ca8
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-04-25
* it is cold today but i'll move to a warm place tomorrow :)
*/
#ifndef __T2TBATCHLOADER_H__
#define __T2TBATCHLOADER_H__
#include "../../network/XNet.h"
using
namespace
nts
;
namespace
transformer
{
/* maximum sequence length constant (1024 * 4 = 4096); the replacement list is
   parenthesized so the macro keeps its value when it appears inside a larger
   expression such as "x / MAX_SEQUENCE_LENGTH" */
#define MAX_SEQUENCE_LENGTH (1024 * 4)
/* record that keeps the information of one batch */
struct BatchNode
{
    /* beginning position */
    int beg;

    /* end position */
    int end;

    /* maximum word number on the encoder side */
    int maxEnc;

    /* maximum word number on the decoder side */
    int maxDec;

    /* a key for sorting */
    int key;
};
/* batch loader: declares the buffers, the batching options and the routines
   used to load sequences from a file and group them into batches.
   NOTE(review): the class holds raw pointers and declares a destructor but no
   copy constructor/assignment - confirm that instances are never copied. */
class T2TBatchLoader
{
public:
    /* buffer for loading words */
    int * buf;

    /* a second word buffer */
    int * buf2;

    /* buffer of batch nodes */
    BatchNode * bufBatch;

    /* size of the word buffer */
    int bufSize;

    /* size of the batch buffer */
    int bufBatchSize;

    /* length of each sequence */
    int * seqLen;

    /* a second length array */
    int * seqLen2;

    /* offset of the first word of each sequence */
    int * seqOffset;

    /* number of sequences in the buffer */
    int nseqBuf;

    /* offset of the next sequence in the buffer */
    int nextSeq;

    /* offset of the next batch */
    int nextBatch;

    /* whether the </s> symbol is doubled for the output of language models */
    bool isDoubledEnd;

    /* whether the batch size is computed as max * sc rather than as the
       word number, where max is the maximum length and sc is the
       sentence number */
    bool isSmallBatch;

    /* counterpart of "isSmallBatch" */
    bool isBigBatch;

    /* whether batches are randomized */
    bool isRandomBatch;

    /* bucket size */
    int bucketSize;

public:
    /* constructor */
    T2TBatchLoader();

    /* destructor */
    ~T2TBatchLoader();

    /* initialization from the command-line style arguments (argc, argv) */
    void Init(int argc, char ** argv);

    /* load data from the file into the buffer */
    int LoadBuf(FILE * file, bool isSorted, int step);

    /* clear the data buffer */
    void ClearBuf();

    /* load a batch of sequences; isLM is passed in addition to the
       arguments shared with LoadBatchMT */
    int LoadBatch(FILE * file, bool isLM,
                  XTensor * batchEnc, XTensor * paddingEnc,
                  XTensor * batchDec, XTensor * paddingDec,
                  XTensor * gold, XTensor * label,
                  int * seqs,
                  int vsEnc, int vsDec, int sBatch, int wBatch,
                  bool isSorted, int &ws, int &wCount,
                  int devID, XMem * mem,
                  bool isTraining);

    /* load a batch of sequences (language modeling) */
    int LoadBatchLM(FILE * file,
                    XTensor * batchEnc, XTensor * paddingEnc,
                    XTensor * batchDec, XTensor * paddingDec,
                    XTensor * gold, XTensor * label,
                    int * seqs, int vs, int sBatch, int wBatch,
                    bool isSorted, int &wCount,
                    int devID, XMem * mem,
                    bool isTraining);

    /* load a batch of sequences (machine translation) */
    int LoadBatchMT(FILE * file,
                    XTensor * batchEnc, XTensor * paddingEnc,
                    XTensor * batchDec, XTensor * paddingDec,
                    XTensor * gold, XTensor * label,
                    int * seqs,
                    int vsEnc, int vsDec, int sBatch, int wBatch,
                    bool isSorted, int &ws, int &wCount,
                    int devID, XMem * mem,
                    bool isTraining);

    /* shuffle the data file */
    void Shuffle(const char * srcFile, const char * tgtFile);
};
}
#endif
\ No newline at end of file
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TSearch.cpp
查看文件 @
da1d7ca8
...
...
@@ -319,17 +319,40 @@ void T2TSearch::Collect(T2TStateBundle * beam)
/*
save the output sequences in a tensor
>> output - output sequences (for return). It is re-initialized here as an
            X_INT tensor of size (batchSize, beamSize, maxLength) and
            set to -1 by SetDataFixedInt before the words are written.
*/
void T2TSearch::Dump(XTensor * output)
{
    int dims[3] = {batchSize, beamSize, maxLength};

    /* scratch buffer that receives the words of one hypothesis in reverse order.
       NOTE(review): it has maxLength entries, so this assumes that no state
       chain is longer than maxLength - TODO confirm against the search loop */
    int * words = new int[maxLength];

    InitTensor(output, 3, dims, X_INT);
    SetDataFixedInt(*output, -1);

    /* heap for an input sentence in the batch */
    for (int h = 0; h < batchSize; h++) {

        XHeap<MIN_HEAP, float> &heap = fullHypos[h];

        /* for each output in the beam.
           NOTE(review): Pop() is called beamSize times per heap; this assumes
           every heap holds at least beamSize hypotheses - TODO confirm */
        for (int i = 0; i < beamSize; i++) {
            T2TState * state = (T2TState *)heap.Pop().index;

            int count = 0;

            /* we track the state from the end to the beginning */
            while (state != NULL) {
                words[count++] = state->prediction;
                state = state->last;
            }

            /* dump the sentence to the output tensor (words were collected
               backwards, so they are written in reversed index order) */
            for (int w = 0; w < count; w++)
                output->Set3DInt(words[count - w - 1], h, i, w);
        }
    }

    delete[] words;
}
/*
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TSearch.h
查看文件 @
da1d7ca8
...
...
@@ -88,7 +88,7 @@ public:
void
Collect
(
T2TStateBundle
*
beam
);
/* save the output sequences in a tensor */
void
Dump
(
XTensor
*
input
,
XTensor
*
output
);
void
Dump
(
XTensor
*
output
);
/* check if the token is an end symbol */
bool
IsEnd
(
int
token
);
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TTester.cpp
查看文件 @
da1d7ca8
...
...
@@ -25,4 +25,30 @@ using namespace nts;
namespace
transformer
{
/* constructor (the body is empty in this revision) */
T2TTester::T2TTester()
{
}
/* destructor (the body is empty in this revision) */
T2TTester::~T2TTester()
{
}
/*
initialize the model
>> argc - number of entries in argv
>> argv - argument strings
NOTE(review): the body is an empty stub in this revision, neither argument is read yet
*/
void T2TTester::InitModel(int argc, char ** argv)
{
}
/*
test the model
>> fn - test data file
>> ofn - output data file
>> model - model that is trained
NOTE(review): the body is an empty stub in this revision, nothing is read or written yet
*/
void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
{
}
}
\ No newline at end of file
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TTester.h
查看文件 @
da1d7ca8
...
...
@@ -32,6 +32,17 @@ namespace transformer
/* tester of the T2T model: declares model initialization and a test routine;
   the class has no data members in this revision */
class T2TTester
{
public:
    /* constructor */
    T2TTester();

    /* destructor */
    ~T2TTester();

    /* initialize the model from the arguments (argc, argv) */
    void InitModel(int argc, char ** argv);

    /* test the model: fn is the test data file, ofn is the output data file */
    void Test(const char * fn, const char * ofn, T2TModel * model);
};
}
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TTrainer.cpp
查看文件 @
da1d7ca8
差异被折叠。
点击展开。
source/sample/transformer/T2TTrainer.h
查看文件 @
da1d7ca8
...
...
@@ -23,35 +23,14 @@
#define __T2TTRAINER_H__
#include "T2TModel.h"
#include "T2TBatchLoader.h"
#include "../../tensor/function/FHeader.h"
#define MAX_SEQUENCE_LENGTH 1024 * 4
using
namespace
nts
;
namespace
transformer
{
/* node to keep batch information */
struct
BatchNode
{
/* begining position */
int
beg
;
/* end position */
int
end
;
/* maximum word number on the encoder side */
int
maxEnc
;
/* maximum word number on the decoder side */
int
maxDec
;
/* a key for sorting */
int
key
;
};
/* trainer of the T2T model */
class
T2TTrainer
{
...
...
@@ -61,42 +40,6 @@ public:
/* parameter array */
char
**
argArray
;
/* buffer for loading words */
int
*
buf
;
/* another buffer */
int
*
buf2
;
/* batch buf */
BatchNode
*
bufBatch
;
/* buffer size */
int
bufSize
;
/* size of batch buffer */
int
bufBatchSize
;
/* length of each sequence */
int
*
seqLen
;
/* another array */
int
*
seqLen2
;
/* offset of the first word for each sequence */
int
*
seqOffset
;
/* number of sequences in the buffer */
int
nseqBuf
;
/* offset for next sequence in the buffer */
int
nextSeq
;
/* offset for next batch */
int
nextBatch
;
/* indicates whether the sequence is sorted by length */
bool
isLenSorted
;
/* dimension size of each inner layer */
int
d
;
...
...
@@ -158,26 +101,15 @@ public:
/* number of batches on which we do model update */
int
updateStep
;
/* indicates whether we double the </s> symbol for the output of lms */
bool
isDoubledEnd
;
/* indicates whether we use batchsize = max * sc
rather rather than batchsize = word-number, where max is the maximum
length and sc is the sentence number */
bool
isSmallBatch
;
/* counterpart of "isSmallBatch" */
bool
isBigBatch
;
/* randomize batches */
bool
isRandomBatch
;
/* indicates whether we intend to debug the net */
bool
isDebugged
;
/* bucket size */
int
bucketSize
;
/* indicates whether the sequence is sorted by length */
bool
isLenSorted
;
/* for batching */
T2TBatchLoader
batchLoader
;
public
:
/* constructor */
...
...
@@ -197,46 +129,6 @@ public:
/* make a checkpoint */
void
MakeCheckpoint
(
T2TModel
*
model
,
const
char
*
validFN
,
const
char
*
modelFN
,
const
char
*
label
,
int
id
);
/* load data to buffer */
int
LoadBuf
(
FILE
*
file
,
bool
isSorted
,
int
step
);
/* clear data buffer */
void
ClearBuf
();
/* load a batch of sequences */
int
LoadBatch
(
FILE
*
file
,
bool
isLM
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
ws
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
/* load a batch of sequences (for language modeling) */
int
LoadBatchLM
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vs
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
/* load a batch of sequences (for machine translation) */
int
LoadBatchMT
(
FILE
*
file
,
XTensor
*
batchEnc
,
XTensor
*
paddingEnc
,
XTensor
*
batchDec
,
XTensor
*
paddingDec
,
XTensor
*
gold
,
XTensor
*
label
,
int
*
seqs
,
int
vsEnc
,
int
vsDec
,
int
sBatch
,
int
wBatch
,
bool
isSorted
,
int
&
ws
,
int
&
wCount
,
int
devID
,
XMem
*
mem
,
bool
isTraining
);
/* shuffle the data file */
void
Shuffle
(
const
char
*
srcFile
,
const
char
*
tgtFile
);
/* get word probabilities for a batch of sequences */
float
GetProb
(
XTensor
*
output
,
XTensor
*
gold
,
XTensor
*
wordProbs
);
...
...
This diff is collapsed.
Click to expand it.
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论