/*
* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2016-2021
* Natural Language Processing Lab, Northeastern University
* and
* NiuTrans Research
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* This class maintains the parameters (and other stuff) for training. It
* could be used to manage the parameter copy and update in training. E.g.,
* one can use this class to keep the parameters on the server side, or 
* treat it as an individual model on the worker side.
*
* $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2021-02-24
* I created more than one file today, hahaha
*/

#ifndef __XMODEL_H__
#define __XMODEL_H__

#include "../network/XNet.h"
#include "../tensor/XQueue.h"
#include "../tensor/XList.h"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
parameter state
Each parameter tensor of a model carries one of these states:
1) not ready
2) ready
3) the parameter has been collected from other models
4) the updated parameter
The enumerators keep their implicit values 0..3.
*/
enum PARAM_STATE { PARAM_STATE_NOT_READY,   /* 1) not ready */
                   PARAM_STATE_READY,       /* 2) ready */
                   PARAM_STATE_COLLECTED,   /* 3) collected from other models */
                   PARAM_STATE_UPDATED };   /* 4) the updated parameter */

xiaotong committed
53 54 55
/* a model template for training */
class XModel
{
56 57 58 59
protected:
    /* mutex of the model */
    MUTEX_HANDLE modelMutex;

xiaotong committed
60 61 62 63
public:
    /* the list of model parameters (pointers to the parameter tensor) */
    TensorList params;

xiaotong committed
64 65 66
    /* flags of the parameters */
    PARAM_STATE * flags;

xiaotong committed
67 68 69 70 71 72 73 74
public:

    /* constructor */
    XModel();

    /* de-constructor */
    ~XModel();

75 76
    /* clear the model (would be overloaded) */
    virtual
xiaotong committed
77 78
    void Clear();

79 80 81 82
    /* clone the model (would be overloaded) */
    virtual
    XModel * Clone(int devID);

xiaotong committed
83
    /* run the neural network */
84
    virtual
85
    bool RunSimple(XList * inputs, XList * outputs, XList * golds, XList * losses);
xiaotong committed
86 87 88

protected:
    /* run the neural network */
89
    bool RunMe(XList * args);
90 91

public:
xiaotong committed
92
    /* add a parameter tensor */
xiaotong committed
93 94 95 96
    void AddParam(XTensor * param);

    /* check if the parameters are well-defined for training */
    bool CheckParam();
97

98
    /* refresh the model */
xiaotong committed
99 100
    void RefreshMe();

101
    /* wrapper of RefreshMe() */
xiaotong committed
102 103
    static
    void Refresh(XList * args);
104

105 106 107
    /* wrapper of Run() */
    static
    bool Run(XList * args);
xiaotong committed
108

xiaotong committed
109 110 111 112 113
};

}

#endif // __XMODEL_H__