/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northeastern University. 
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-27
*/

#include "Identity.h"
liyinqiao committed
23
#include "../XName.h"
xiaotong committed
24
#include "../XUtility.h"
25
#include "../core/movement/CopyValues.h"
liyinqiao committed
26
#include "../core/shape/IsSameShaped.h"
xiaotong committed
27 28 29 30 31 32

namespace nts{ // namespace nts(NiuTrans.Tensor)

/* 
identity function y = x 
>> x - input tensor
liyinqiao committed
33
>> y - output tensor
xiaotong committed
34
*/
35
void _Identity(const XTensor * x, XTensor * y)
xiaotong committed
36
{
liyinqiao committed
37 38
    CheckNTErrors(_IsSameShaped(x, y), 
                 "The input tensor and output tensor must have the same shape!")
39
    _CopyValues(x, y);
xiaotong committed
40 41 42
}

/*
identity function y = x (return an XTensor structure)
A new tensor is created from x to hold the result, marked with
SetTMPFlag(), filled by _Identity and returned by value.

>> x - input tensor
<< return - output tensor
*/
XTensor Identity(const XTensor &x)
{
    /* result tensor, constructed from x and flagged as temporary */
    XTensor y(&x);
    y.SetTMPFlag();

    /* y = x */
    _Identity(&x, &y);

    /* tensor connection: only made when x has gradients enabled */
    if (x.enableGrad)
        XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);

    return y;
}
liyinqiao committed
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79

/*
identity function y = x (write the result into a given tensor)
y is (re)initialized from x with InitTensorV2 when it has not been
initialized yet or when its shape differs from that of x.

>> x - input tensor
>> y - output tensor
*/
void Identity(const XTensor &x, XTensor &y)
{
    /* y can be used as it is only if it is initialized AND shaped like x */
    if (!(y.isInit && IsSameShaped(y, x)))
        InitTensorV2(&y, &x);

    /* y = x */
    _Identity(&x, &y);

    /* tensor connection: only made when x has gradients enabled */
    if (x.enableGrad)
        XLink::MakeLink(&x, NULL, &y, FUNC_IDENTITY);
}

80
/* 
xiaotong committed
81 82 83 84
backward computation for identity function y = x 

dE/dx = dE/dy * dy/dx = dE/dy

liyinqiao committed
85 86
>> y - output of the identity function
>> x - input of the identity function
xiaotong committed
87 88 89
>> dedy - dE/dy
>> dedx - dE/dx
*/
liyinqiao committed
90 91
void _IdentityBackward(const XTensor * y, const XTensor * x,
                       const XTensor * dedy, XTensor * dedx)
xiaotong committed
92
{
liyinqiao committed
93 94
    if(dedy->data != dedx->data)
        _CopyValues(dedy, dedx);
xiaotong committed
95 96 97
}

} // namespace nts(NiuTrans.Tensor)