Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
36c80fc7
Commit
36c80fc7
authored
5 years ago
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bug fixes
parent
699ddac6
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
12 行增加
和
5 行删除
+12
-5
source/sample/transformer/T2TModel.cpp
+4
-1
source/sample/transformer/T2TSearch.cpp
+4
-0
source/sample/transformer/T2TTester.cpp
+3
-3
source/sample/transformer/T2TTrainer.cpp
+1
-1
没有找到文件。
source/sample/transformer/T2TModel.cpp
查看文件 @
36c80fc7
...
...
@@ -44,10 +44,13 @@ T2TModel::T2TModel()
/* de-constructor */
T2TModel
::~
T2TModel
()
{
delete
mem
;
delete
encoder
;
delete
decoder
;
delete
outputLayer
;
/* we delete "mem" at the end because other members are using it and we must
remove the memory space before all tensors are destroyed. */
delete
mem
;
}
/*
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TSearch.cpp
查看文件 @
36c80fc7
...
...
@@ -75,6 +75,10 @@ void T2TSearch::Search(T2TModel * model, XTensor * input, XTensor * padding, XTe
encoding
=
model
->
MakeEncoder
(
*
input
,
maskEnc
,
false
);
encoding
.
SetName
(
ENCODING_NAME
);
/* max output-length = 2 * source-length */
maxLength
=
input
->
GetDim
(
-
2
)
*
2
;
CheckNTErrors
(
maxLength
>
0
,
"no max length specified!"
);
T2TStateBundle
*
states
=
new
T2TStateBundle
[
maxLength
];
T2TStateBundle
*
first
=
states
;
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TTester.cpp
查看文件 @
36c80fc7
...
...
@@ -19,6 +19,7 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-27
*/
#include <math.h>
#include "T2TUtility.h"
#include "T2TTester.h"
#include "T2TSearch.h"
...
...
@@ -130,7 +131,7 @@ void T2TTester::Test(const char * fn, const char * ofn, T2TModel * model)
double
elapsed
=
GetClockSec
()
-
startT
;
XPRINT3
(
0
,
stderr
,
"[INFO] test finished (took %.1fs, word=%d, and ppl=%.3f)
\n
"
,
elapsed
,
wordCountTotal
,
exp
(
loss
/
wordCount
));
elapsed
,
wordCountTotal
,
exp
(
loss
/
wordCount
));
}
}
\ No newline at end of file
}
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TTrainer.cpp
查看文件 @
36c80fc7
...
...
@@ -154,7 +154,7 @@ void T2TTrainer::Train(const char * fn, const char * validFN, const char * model
for
(
epoch
=
1
;
epoch
<=
nepoch
;
epoch
++
){
#ifndef WIN32
if
(
isShuffled
)
Shuffle
(
fn
,
trainFN
);
batchLoader
.
Shuffle
(
fn
,
trainFN
);
#endif
FILE
*
file
=
fopen
(
trainFN
,
"rb"
);
...
...
This diff is collapsed.
Click to expand it.
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论