Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
N
NiuTrans.Tensor
概览
Overview
Details
Activity
Cycle Analytics
版本库
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
问题
8
Issues
8
列表
Board
标记
里程碑
合并请求
0
Merge Requests
0
CI / CD
CI / CD
流水线
作业
日程表
图表
维基
Wiki
代码片段
Snippets
成员
Collapse sidebar
Close sidebar
活动
图像
聊天
创建新问题
作业
提交
Issue Boards
Open sidebar
NiuTrans
NiuTrans.Tensor
Commits
b9871b8d
Commit
b9871b8d
authored
6 years ago
by
xiaotong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
class of predictors
parent
8cb65ef5
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
59 行增加
和
7 行删除
+59
-7
source/sample/transformer/T2TDecoder.cpp
+1
-1
source/sample/transformer/T2TModel.h
+3
-0
source/sample/transformer/T2TPredictor.cpp
+6
-2
source/sample/transformer/T2TPredictor.h
+48
-3
source/sample/transformer/Transformer.cpp
+1
-1
没有找到文件。
source/sample/transformer/T2TDecoder.cpp
查看文件 @
b9871b8d
...
...
@@ -68,7 +68,7 @@ void AttDecoder::InitModel(int argc, char ** argv,
LoadParamFloat
(
argc
,
argv
,
"dropout"
,
&
dropoutP
,
0
);
CheckNTErrors
(
nlayer
>=
1
,
"We have one encoding layer at least!"
);
CheckNTErrors
(
vSize
>
1
,
"set vocabulary size by
\"
-vsize
\"
"
);
CheckNTErrors
(
vSize
>
1
,
"set vocabulary size by
\"
-vsize
tgt
\"
"
);
/* embedding model */
embedder
.
InitModel
(
argc
,
argv
,
devID
,
mem
,
false
);
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2TModel.h
查看文件 @
b9871b8d
...
...
@@ -31,6 +31,9 @@
namespace
transformer
{
/* a transformer model that keeps parameters of the encoder,
the decoder and the output layer (softmax). Also, it creates
the network used in transformer. */
class
T2TModel
{
public
:
...
...
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2T
Searche
r.cpp
→
source/sample/transformer/T2T
Predicto
r.cpp
查看文件 @
b9871b8d
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 201
8
, Natural Language Processing Lab, Northestern University.
* Copyright (C) 201
9
, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -19,5 +19,9 @@
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2019-03-13
*/
#include "T2TSearcher.h"
#include "T2TPredictor.h"
namespace
transformer
{
}
This diff is collapsed.
Click to expand it.
source/sample/transformer/T2T
Searche
r.h
→
source/sample/transformer/T2T
Predicto
r.h
查看文件 @
b9871b8d
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 201
8
, Natural Language Processing Lab, Northestern University.
* Copyright (C) 201
9
, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -20,7 +20,52 @@
* This is the first source file I create in 2019 - new start!
*/
#ifndef __T2TSEARCHER_H__
#define __T2TSEARCHER_H__
#ifndef __T2TPREDICTOR_H__
#define __T2TPREDICTOR_H__
#include "T2TModel.h"
namespace
transformer
{
/* state in decoder - it keeps all previously-generated words and their
hidden states */
class
T2TState
{
/* layers on the encoder side. We actually use the encoder output instead
of all hidden layers. */
XList
*
encoderLayers
;
/* layers on the decoder side */
XList
*
decoderLayers
;
/* */
};
/* The predictor reads the current state and then predicts the next.
It is exactly the same procedure of MT inference -
we get the state of previous words and then generate the next word.
Here, a state can be regared as the representation of words (word
indices, hidden states, embeddings and etc.). */
class
T2TPredictor
{
/* pointer to the transformer model */
T2TModel
*
model
;
public
:
/* constructor */
T2TPredictor
();
/* de-constructor */
~
T2TPredictor
();
/* read a state */
void
Read
(
T2TModel
*
model
,
T2TState
*
current
);
/* predict the next state */
void
Predict
(
T2TState
*
next
);
};
}
#endif
This diff is collapsed.
Click to expand it.
source/sample/transformer/Transformer.cpp
查看文件 @
b9871b8d
...
...
@@ -25,7 +25,7 @@
#include "T2TModel.h"
#include "T2TUtility.h"
#include "T2TTrainer.h"
#include "T2T
Searche
r.h"
#include "T2T
Predicto
r.h"
#include "../../tensor/XDevice.h"
#include "../../tensor/XUtility.h"
#include "../../tensor/XGlobal.h"
...
...
This diff is collapsed.
Click to expand it.
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论