Commit 2b35ba7a by xiaotong

Implementation of the Adam optimizer

parent 0a70df12
......@@ -39,6 +39,7 @@ XOptimizer::XOptimizer()
/* de-constructor */
XOptimizer::~XOptimizer()
{
Clear();
}
/*
......@@ -60,6 +61,11 @@ void XOptimizer::Clear()
lrate = 0;
}
/* reset the optimizer (re-start) */
void XOptimizer::Reset()
{
}
void XOptimizer::ShowSettings()
{
XPRINT(1, stderr, "[INFO] Optimizer Setup:\n");
......@@ -69,14 +75,6 @@ void XOptimizer::ShowSettings()
}
/*
prepare for the update
>> model - the model that we want to update
*/
void XOptimizer::Prepare(XModel * model)
{
}
/*
record the update
>> model - the model that we want to update
*/
......
......@@ -62,15 +62,15 @@ public:
/* clear the optimizer */
virtual
void Clear();
/* reset the optimizer (re-start) */
virtual
void Reset();
/* show settings */
virtual
void ShowSettings();
/* prepare for the update */
virtual
void Prepare(XModel * model);
/* record the update */
virtual
void Note(XModel * model);
......
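The virtual methods above form the extension point for concrete optimizers: a subclass only overrides the hooks it needs. As a minimal sketch of that contract (the class MySGD and its body are illustrative, not part of this commit, and it assumes UpdateParam is also virtual in XOptimizer, as the Adam subclass below suggests), a vanilla SGD optimizer could look like:

/* a hypothetical vanilla-SGD optimizer on top of the XOptimizer interface */
class MySGD : public XOptimizer
{
public:
    /* the delta rule: param = param - lrate * grad;
       _Sum(a, b, c, beta) computes c = a + b * beta */
    void UpdateParam(XTensor * param, XTensor * grad, int pid)
    {
        _Sum(param, grad, param, -lrate);
    }
};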
......@@ -92,8 +92,6 @@ void XWorkerUpdate::UpdateModel(XModel * model, XOptimizer * optimizer, int slee
{
int finished = 0;
optimizer->Prepare(model);
while (1) {
for (int i = 0; i < model->paramNum; i++) {
if (model->params[i].flag == PARAM_STATE_COLLECTED) {
......
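For context, the polling loop above walks the parameter table and hands every gradient that has been fully collected to the optimizer. A rough sketch of that dispatch (only flag and PARAM_STATE_COLLECTED are visible in this hunk; the fields param and grad and the state PARAM_STATE_UPDATED are assumptions for illustration):

for (int i = 0; i < model->paramNum; i++) {
    if (model->params[i].flag == PARAM_STATE_COLLECTED) {
        /* hand the collected gradient to the optimizer (field names assumed) */
        optimizer->UpdateParam(model->params[i].param, model->params[i].grad, i);
        model->params[i].flag = PARAM_STATE_UPDATED; /* assumed next state */
        finished++;
    }
}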
......@@ -26,6 +26,127 @@
*/
#include "Adam.h"
#include "../../tensor/core/CHeader.h"
namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */
Adam::Adam() : XOptimizer()
{
Clear();
}
/* de-constructor */
Adam::~Adam()
{
Clear();
}
/*
initialize the optimizer
>> config - the configuration
*/
void Adam::Init(XConfig &config)
{
XOptimizer::Init(config);
adamBeta1 = config.GetFloat("adambeta1", 0.9F);
adamBeta2 = config.GetFloat("adambeta2", 0.98F);
adamDelta = config.GetFloat("adamdelta", 1e-9F);
}
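Init reads three keys, with defaults tuned for NMT training (beta_2 = 0.98 rather than the paper's 0.999, and a small delta that is later rescaled by the bias correction). A usage sketch, assuming XConfig can be populated from argv-style strings through a Create method (only GetFloat is confirmed by this diff, so Create is an assumption):

/* hypothetical setup; the Create call is an assumed XConfig API */
const char * args[] = {"-adambeta1", "0.9", "-adambeta2", "0.98", "-adamdelta", "1e-9"};
XConfig config;
config.Create(6, args);

Adam adam;
adam.Init(config);
adam.ShowSettings();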
/* clear the optimizer */
void Adam::Clear()
{
XOptimizer::Clear();
for (int i = 0; i < moments.count; i++) {
XTensor * m = moments[i];
delete m;
}
moments.Clear();
for (int i = 0; i < moments2nd.count; i++) {
XTensor * m2nd = moments2nd[i];
delete m2nd;
}
moments2nd.Clear();
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
}
/* reset the optimizer (re-start) */
void Adam::Reset()
{
for (int i = 0; i < moments.count; i++) {
XTensor * m = moments[i];
m->SetZeroAll();
}
for (int i = 0; i < moments2nd.count; i++) {
XTensor * m2nd = moments2nd[i];
m2nd->SetZeroAll();
}
adamBeta1T = 1.0F;
adamBeta2T = 1.0F;
}
/* show settings */
void Adam::ShowSettings()
{
XPRINT(1, stderr, "[INFO] Optimizer = Adam\n");
XOptimizer::ShowSettings();
XPRINT2(1, stderr, "%25s = %f\n", "adambeta1", adamBeta1);
XPRINT2(1, stderr, "%25s = %f\n", "adambeta2", adamBeta2);
XPRINT2(1, stderr, "%25s = %e\n", "adamdelta", adamDelta);
}
/* record the update */
void Adam::Note(XModel * model)
{
nstep++;
}
/*
update a parameter matrix using Adam
>> param - the parameter to update
>> grad - the gradient of the parameter
>> pid - index of the parameter
*/
void Adam::UpdateParam(XTensor * param, XTensor * grad, int pid)
{
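/* advance the running powers beta_1^t and beta_2^t used for bias correction */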
adamBeta1T *= adamBeta1;
adamBeta2T *= adamBeta2;
float e = lrate * (float)sqrt(1 - adamBeta2T) / (1 - adamBeta1T);
float d = adamDelta * (float)sqrt(1 - adamBeta2T);
/* m = beta_1 * m + (1-beta_1) * grad */
XTensor * m = moments[pid];
_ScaleAndShiftMe(m, adamBeta1, 0);
_Sum(m, grad, m, (1.0F - adamBeta1));
/* v = beta_2 * v + (1-beta_2) * grad * grad */
XTensor * v = moments2nd[pid];
_Multiply(grad, grad, v, adamBeta2 / (1.0F - adamBeta2));
_ScaleAndShiftMe(v, (1.0F - adamBeta2), 0);
/* allocate a piece of buffer memory */
GMems.GetMem(v->devID)->LockBuf();
XTensor* v2 = NewTensorBuf(v, v->devID);
/* v2 = m / (sqrt(v) + delta) */
_Power(v, v2, 0.5F);
_ScaleAndShiftMe(v2, 1.0F, d);
_Div(m, v2, v2);
/* the delta rule */
_Sum(param, v2, param, -e);
/* release a piece of buffer memory */
DelTensorBuf(v2);
GMems.GetMem(v->devID)->UnlockBuf();
}
}
\ No newline at end of file
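For reference, UpdateParam implements the bias-corrected update of Kingma and Ba's Adam, with both corrections folded into the scalars e and d so that the moment tensors m and v never have to be rescaled. In LaTeX, with t the current step, g_t the gradient, \eta the learning rate (lrate) and \delta the smoothing term (adamDelta):

m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
e = \eta \sqrt{1 - \beta_2^t} \,/\, (1 - \beta_1^t), \qquad d = \delta \sqrt{1 - \beta_2^t}
\theta_t = \theta_{t-1} - e \cdot m_t / (\sqrt{v_t} + d)

Multiplying the numerator and denominator of the standard form \eta\,\hat m_t / (\sqrt{\hat v_t} + \delta) by \sqrt{1 - \beta_2^t} recovers exactly this rule, so the two are algebraically identical. Likewise, the two-step second-moment update in the code (_Multiply with alpha = beta_2/(1-beta_2), then scaling by 1-beta_2) is just v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 computed in place, without a temporary tensor.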
......@@ -37,6 +37,45 @@ namespace nts { // namespace nts(NiuTrans.Tensor)
/* an implementation of the Adam optimizer */
class Adam : public XOptimizer
{
protected:
/* list of the moment of the parameter matrices */
TensorList moments;
/* list of the 2nd order moment of the parameter matrices */
TensorList moments2nd;
/* hyper-parameters of Adam */
float adamBeta1;
float adamBeta2;
float adamDelta;
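/* running powers beta_1^t and beta_2^t, kept for bias correction */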
float adamBeta1T;
float adamBeta2T;
public:
/* constructor */
Adam();
/* de-constructor */
~Adam();
/* initialize the optimizer */
void Init(XConfig &config);
/* clear the optimizer */
void Clear();
/* reset the optimizer (re-start) */
void Reset();
/* show settings */
void ShowSettings();
/* record the update */
void Note(XModel * model);
/* update a parameter matrix */
void UpdateParam(XTensor * param, XTensor * grad, int pid);
};
}
......
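Putting the pieces together, a driver would create the optimizer once and then, per training step, feed every parameter/gradient pair to UpdateParam and record the step with Note. A minimal sketch, under the assumption that the caller has prepared matching arrays of parameter and gradient tensors and that the moment tensors have been allocated elsewhere (their allocation is not visible in this diff):

/* hypothetical training-step driver; params, grads, paramNum and maxStep
   are assumptions supplied by the surrounding trainer code */
Adam adam;
adam.Init(config);
adam.ShowSettings();
adam.Reset();                 /* zero the moments and restart bias correction */

for (int step = 0; step < maxStep; step++) {
    /* ... forward/backward pass fills grads[i] for every params[i] ... */
    for (int i = 0; i < paramNum; i++)
        adam.UpdateParam(params[i], grads[i], i);
    adam.Note(model);         /* record the finished update (nstep++) */
}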