Commit 4727e551 by Lee

Update RNN training

parent b8e87296
...@@ -2666,15 +2666,14 @@ $\textrm{``you''} = \argmax_{y} \textrm{P}(y|\textbf{s}_1, \alert{\textbf{C}})$ ...@@ -2666,15 +2666,14 @@ $\textrm{``you''} = \argmax_{y} \textrm{P}(y|\textbf{s}_1, \alert{\textbf{C}})$
\begin{itemize} \begin{itemize}
\item 有了一个NMT模型,我们应该怎么使用梯度下降算法来训练一个``聪明''的翻译模型呢? \item 有了一个NMT模型,我们应该怎么使用梯度下降算法来训练一个``聪明''的翻译模型呢?
\begin{itemize} \begin{itemize}
\item 参数初始化 \item<1|only@1>{参数初始化:模型结构是确定了,但是我们初始化参数还有很多需要注意的地方,否则训练不了一个优秀的模型
\only<2>{:模型结构是确定了,但是我们初始化参数还有很多需要注意的地方,否则训练不了一个优秀的模型
\begin{itemize} \begin{itemize}
\item LSTM遗忘门偏置初始为1,也就是始终选择遗忘记忆$c$,可以有效防止初始时$c$里包含的错误信号传播后面所有时刻 \item LSTM遗忘门偏置初始为1,也就是始终选择遗忘记忆$c$,可以有效防止初始时$c$里包含的错误信号传播后面所有时刻
\item 其他参数一般使用Xavier参数初始化方法,可以有效稳定训练过程,特别是对于比较``深''的网络 \item 其他参数一般使用Xavier参数初始化方法,可以有效稳定训练过程,特别是对于比较``深''的网络$$W \sim \mathcal{U}(-\sqrt{\frac{6}{d_{\mathrm{in}}+d_{\mathrm{out}}}},\sqrt{\frac{6}{d_{\mathrm{in}}+d_{\mathrm{out}}}})$$
\item $W$是参数,$d_{\mathrm{in}}$$d_{\mathrm{out}}$分别是$W$的输入和输出的维度大小
\end{itemize} \end{itemize}
} }
\item 优化器选择 \item<2-4|only@2-4>{优化器选择:训练RNN我们通常会使用Adam或者SGD两种优化器,它们各有优劣
\only<3-5>{:训练RNN我们通常会使用Adam或者SGD两种优化器,它们各有优劣
\begin{center} \begin{center}
\footnotesize \footnotesize
\begin{tabular}{c|c|c} \begin{tabular}{c|c|c}
...@@ -2685,24 +2684,183 @@ $\textrm{``you''} = \argmax_{y} \textrm{P}(y|\textbf{s}_1, \alert{\textbf{C}})$ ...@@ -2685,24 +2684,183 @@ $\textrm{``you''} = \argmax_{y} \textrm{P}(y|\textbf{s}_1, \alert{\textbf{C}})$
\end{tabular} \end{tabular}
\end{center} \end{center}
} }
\item 学习率调度 \item<3-4|only@3-4>{学习率调度
\only<4>{ \only<3>{
\begin{itemize} \begin{itemize}
\item 不同优化器需要的学习率不同,比如Adam一般使用$0.001$$0.0001$,而SGD则在$0.1\sim 1$之间挑选 \item 不同优化器需要的学习率不同,比如Adam一般使用$0.001$$0.0001$,而SGD则在$0.1\sim 1$之间挑选
\item 但是无论使用哪个优化器,为了达到最好效果,我们通常都需要根据当前的更新次数来调整学习率的大小 \item 但是无论使用哪个优化器,为了达到最好效果,我们通常都需要根据当前的更新次数来调整学习率的大小
\end{itemize} \end{itemize}
}
\only<4>{
\begin{itemize}
\item 学习率预热
\item 学习率衰减
\end{itemize}
}
} }
\only<5>{ \item<5-13|only@5-13>{多设备并行
\begin{itemize} \only<5-6>{
\item 学习率预热 \begin{itemize}
\item 学习率衰减 \item 万事俱备,只是为什么训练这么慢?\only<6>{\alert{- RNN需要等前面所有时刻都完成计算以后才能开始计算当前时刻的输出}}
\end{itemize} \item 我有钱,是不是多买几台设备会更快?\only<6>{\alert{- 可以,但是需要技巧,而且也不是无限增长的}}
\end{itemize}
}
\only<7-13>{
\begin{itemize}
\only<7>{\item 数据并行:如果一台设备能完整放下一个RNN模型,那么数据并行可以把一个大batch均匀切分成$n$个小batch,然后分发到$n$个设备上并行计算,最后把结果汇总,相当于把运算时间变为原来的$1/n$}
\only<8-13>{\item 模型并行:做完了数据并行,仍然太慢了,因为RNN模型太大了,算一个样本也很慢,那么可以把RNN模型按层均匀切分成$l$个小模型,然后分发到$l$个设备上并行计算,相当于把运算时间变为原来的$1/l$
\hspace*{-0.5cm}
\begin{tikzpicture}
\setlength{\base}{1.5em}
\tikzstyle{rnnnode} = [rounded corners=1pt,minimum size=1\base,draw,inner sep=0pt,outer sep=0pt,fill=blue!30!white]
\tikzstyle{wordnode} = [font=\footnotesize,align=center]
\begin{scope}
% rnn[layer][step]
\coordinate (rnn00) at (0,0);
\foreach \i [count=\j from 0] in {1,2,3}
\node[wordnode] (rnn\i0) at ([yshift=2\base]rnn\j0) {$0$};
\foreach \i [count=\j from 0] in {1,2,...,4}
\coordinate (rnn0\i) at ([xshift=2\base]rnn0\j);
% step 1
\visible<8->{
\node[rnnnode] (rnn11) at ([xshift=2\base]rnn10) {};
\draw[-latex'] ([yshift=0.5\base]rnn01) to (rnn11);
\draw[-latex'] ([xshift=0.5\base]rnn10) to (rnn11);
}
\visible<8>{
% frontier
\node[rnnnode,fill=purple] () at (rnn11) {};
\node[draw=red,thick,inner sep=7pt,rounded corners=0.3em,rotate fit=-45,label={[font=\footnotesize,align=center]90:正在运算的\\{\color{red} 循环单元}},fit=(rnn11)] () {};
}
% step 2
\visible<9->{
\node[rnnnode] (rnn12) at ([xshift=2\base]rnn11) {};
\node[rnnnode] (rnn21) at ([yshift=2\base]rnn11) {};
\draw[-latex'] ([yshift=0.5\base]rnn02) to (rnn12);
\draw[-latex'] ([xshift=0.5\base]rnn20) to (rnn21);
\draw[-latex'] (rnn11) to (rnn12);
\draw[-latex'] (rnn11) to (rnn21);
}
\visible<9>{
% frontier
\node[rnnnode,fill=purple] () at (rnn12) {};
\node[rnnnode,fill=purple] () at (rnn21) {};
\node[draw=red,thick,inner sep=7pt,rounded corners=0.3em,rotate fit=-45,label={[font=\footnotesize,align=center]90:正在运算的\\{\color{red} 循环单元}},fit=(rnn12) (rnn21)] () {};
}
% step 3
\visible<10->{
\node[rnnnode] (rnn13) at ([xshift=2\base]rnn12) {};
\node[rnnnode] (rnn31) at ([yshift=2\base]rnn21) {};
\node[rnnnode] (rnn22) at ([xshift=2\base]rnn21) {};
\node[wordnode,anchor=south] (o1) at ([yshift=\base]rnn31.north) {};
\draw[-latex'] ([yshift=0.5\base]rnn03) to (rnn13);
\draw[-latex'] ([xshift=0.5\base]rnn30) to (rnn31);
\draw[-latex'] (rnn12) to (rnn13);
\draw[-latex'] (rnn21) to (rnn31);
\draw[-latex'] (rnn12) to (rnn22);
\draw[-latex'] (rnn21) to (rnn22);
\draw[-latex'] (rnn31) to (o1);
}
\visible<10>{
% frontier
\node[rnnnode,fill=purple] () at (rnn13) {};
\node[rnnnode,fill=purple] () at (rnn31) {};
\node[rnnnode,fill=purple] () at (rnn22) {};
\node[draw=red,thick,inner sep=7pt,rounded corners=0.3em,rotate fit=-45,label={[font=\footnotesize,align=center]90:正在运算的\\{\color{red} 循环单元}},fit=(rnn13) (rnn31) (rnn22)] () {};
}
% step 4
\visible<11->{
\node[rnnnode] (rnn14) at ([xshift=2\base]rnn13) {};
\node[rnnnode] (rnn23) at ([xshift=2\base]rnn22) {};
\node[rnnnode] (rnn32) at ([xshift=2\base]rnn31) {};
\node[wordnode,anchor=south] (o2) at ([yshift=\base]rnn32.north) {不错};
\draw[-latex'] ([yshift=0.5\base]rnn04) to (rnn14);
\draw[-latex'] (rnn13) to (rnn14);
\draw[-latex'] (rnn13) to (rnn23);
\draw[-latex'] (rnn22) to (rnn23);
\draw[-latex'] (rnn22) to (rnn32);
\draw[-latex'] (rnn31) to (rnn32);
\draw[-latex'] (rnn32) to (o2);
}
\visible<11>{
% frontier
\node[rnnnode,fill=purple] () at (rnn14) {};
\node[rnnnode,fill=purple] () at (rnn23) {};
\node[rnnnode,fill=purple] () at (rnn32) {};
\node[draw=red,thick,inner sep=7pt,rounded corners=0.3em,rotate fit=-45,label={[font=\footnotesize,align=center]90:正在运算的\\{\color{red} 循环单元}},fit=(rnn14) (rnn23) (rnn32)] () {};
}
% step 5
\visible<12->{
\node[rnnnode] (rnn24) at ([xshift=2\base]rnn23) {};
\node[rnnnode] (rnn33) at ([xshift=2\base]rnn32) {};
\node[wordnode,anchor=south] (o3) at ([yshift=\base]rnn33.north) {};
\draw[-latex'] (rnn14) to (rnn24);
\draw[-latex'] (rnn23) to (rnn24);
\draw[-latex'] (rnn23) to (rnn33);
\draw[-latex'] (rnn32) to (rnn33);
\draw[-latex'] (rnn33) to (o3);
}
\visible<12>{
% frontier
\node[rnnnode,fill=purple] () at (rnn24) {};
\node[rnnnode,fill=purple] () at (rnn33) {};
\node[draw=red,thick,inner sep=7pt,rounded corners=0.3em,rotate fit=-45,label={[font=\footnotesize,align=center]90:正在运算的\\{\color{red} 循环单元}},fit=(rnn24) (rnn33)] () {};
}
% step 6
\visible<13->{
\node[rnnnode] (rnn34) at ([xshift=2\base]rnn33) {};
\node[wordnode,anchor=south] (o4) at ([yshift=\base]rnn34.north) {EOS};
\draw[-latex'] (rnn33) to (rnn34);
\draw[-latex'] (rnn24) to (rnn34);
\draw[-latex'] (rnn34) to (o4);
}
\visible<13>{
% frontier
\node[rnnnode,fill=purple] () at (rnn34) {};
\node[draw=red,thick,inner sep=7pt,rounded corners=0.3em,rotate fit=-45,label={[font=\footnotesize,align=center]90:正在运算的\\{\color{red} 循环单元}},fit=(rnn34)] () {};
}
% labels
\alt<8-11>{
\draw[decorate,decoration={brace}] ([yshift=-\base]rnn10.west) to node[wordnode,align=right,left,text=red] {正在使用的\\设备1} ([yshift=\base]rnn10.west);
}{
\draw[decorate,decoration={brace}] ([yshift=-\base]rnn10.west) to node[wordnode,align=right,left] {空闲的\\设备1} ([yshift=\base]rnn10.west);
}
\alt<9-12>{
\draw[decorate,decoration={brace}] ([yshift=-\base]rnn20.west) to node[wordnode,align=right,left,text=red] {正在使用的\\设备2} ([yshift=\base]rnn20.west);
}{
\draw[decorate,decoration={brace}] ([yshift=-\base]rnn20.west) to node[wordnode,align=right,left] {空闲的\\设备2} ([yshift=\base]rnn20.west);
}
\alt<10-13>{
\draw[decorate,decoration={brace}] ([yshift=-\base]rnn30.west) to node[wordnode,align=right,left,text=red] {正在使用的\\设备3} ([yshift=\base]rnn30.west);
}{
\draw[decorate,decoration={brace}] ([yshift=-\base]rnn30.west) to node[wordnode,align=right,left] {空闲的\\设备3} ([yshift=\base]rnn30.west);
}
\foreach \i in {1,2,3}
\node[wordnode,font=\scriptsize,anchor=south west] () at (rnn\i0.north west) {\i};
\node[wordnode] () at (rnn01) {};
\node[wordnode] () at (rnn02) {};
\node[wordnode] () at (rnn03) {不错};
\node[wordnode] () at (rnn04) {};
\end{scope}
\end{tikzpicture}
}
\end{itemize}
}
} }
\item 多设备并行 \item<14|only@14>{其他
\only<6->{
\begin{itemize} \begin{itemize}
\item 万事俱备,只是为什么训练这么慢?\only<7->{\alert{- RNN需要等前面所有时刻都完成计算以后才能开始计算当前时刻的输出}} \item 训练RNN的时候,我们通常会遇到梯度爆炸的问题,也就是梯度突然变得很大,这种情况下需要使用``梯度裁剪''来防止梯度$\pi$超过阈值$$\pi'=\pi \cdot \frac{\mathrm{threshold}}{\max(\mathrm{threshold},\parallel \pi \parallel_2)}$$
\item 我有钱,是不是多买几台设备会更快?\only<7->{\alert{- 可以,但是需要技巧,而且也不是无限增长的}} \item 其中$\mathrm{threshold}$是手工设定的梯度大小阈值,$\parallel \cdot \parallel_2$是L2范数
\end{itemize} \end{itemize}
} }
\end{itemize} \end{itemize}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论