Commit b199a6ee by xiaotong

add a new class XLearningRate

parent 2b35ba7a
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-16
* I wore my coat again after the rain yesterday.
*/
#include "XLearningRate.h"
#include <math.h>
namespace nts { // namespace nts(NiuTrans.Tensor)
/* constructor */
XLearningRate::XLearningRate()
{
}
/* de-constructor */
XLearningRate::~XLearningRate()
{
}
/* a Transformer-style scheduler. For more details, see
"Attention is all need" by Vaswani at al.
>> lrate - the learning rate
>> nstep - the update step number
>> nwarmup - the warmup step number
*/
float XLearningRate::MakeLRTransformer(const float lrate, const int nstep, const int nwarmup)
{
float lr = 0;
float warmupEndLR = lrate;
float warmupInitLR = 1e-7F;
float lrStep = (warmupEndLR - warmupInitLR) / nwarmup;
float decayFactor = warmupEndLR * (float)pow(float(nwarmup), 0.5F);
/* learning rate, scheduled by inverse square root */
if (nstep < nwarmup)
lr = warmupInitLR + nstep * lrStep;
else
lr = decayFactor * (float)pow((float)nstep, -0.5F);
return lr;
}
}
\ No newline at end of file
/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This a learning rate generator. E.g., one can adjust learning rate as
* the training process proceeds.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-03-16
* I wore my coat again after the rain yesterday.
*/
#ifndef __XLEARNINGRATE_H__
#define __XLEARNINGRATE_H__
namespace nts { // namespace nts(NiuTrans.Tensor)
/* Learning rate scheduler */
class XLearningRate
{
public:
/* constructor */
XLearningRate();
/* de-constructor */
~XLearningRate();
/* a Transformer-style scheduler */
float MakeLRTransformer(const float lrate, const int nstep, const int nwarmup);
};
}
#endif
\ No newline at end of file
......@@ -96,4 +96,19 @@ void XOptimizer::UpdateParam(XTensor * param, XTensor * grad, int pid)
_Sum(param, grad, param, -lrate);
}
/* get learning rate */
float XOptimizer::GetLearningRate()
{
return lrate;
}
/*
set learning rate
>> myLRate - the learning rate that we want to use
*/
void XOptimizer::SetLearningRate(float myLRate)
{
lrate = myLRate;
}
}
......@@ -78,6 +78,12 @@ public:
/* update a parameter matrix */
virtual
void UpdateParam(XTensor * param, XTensor * grad, int pid);
/* get learning rate */
float GetLearningRate();
/* set learning rate */
void SetLearningRate(float myLRate);
};
}
......
......@@ -26,6 +26,7 @@
*/
#include "XTrainer.h"
#include "XLearningRate.h"
/* the nts (NiuTrans.Tensor) namespace */
namespace nts {
......@@ -99,6 +100,8 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
int stepAll = 0;
int jobNum = 0;
int accumulation = config->GetInt("accumulation", 1);
int nwarmup = config->GetInt("nwarmup", 0);
int lrate = optimizer->GetLearningRate();
CheckNTErrors(accumulation >= 1, "accumulation must be larger than 0!");
......@@ -119,6 +122,9 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
leader.SetServerModel(config, model);
leader.Start();
/* learning rate scheduler */
XLearningRate LRScheduler;
double startT = GetClockSec();
XPRINT(1, stderr, "[INFO] Initializing the model ... [DONE]\n");
......@@ -132,6 +138,10 @@ void XTrainer::Run(XConfig * config, DataDistributeBase * dataDistributor,
while (ok) {
if (++stepAll % accumulation == 0) {
/* learning rate scheduling */
if (nwarmup > 0)
optimizer->SetLearningRate(LRScheduler.MakeLRTransformer(lrate, step + 1, nwarmup));
/* one step of udpate */
ok = leader.Run(config, dataDistributor, model, optimizer);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论