/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2017, Natural Language Processing Lab, Northestern University. 
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "../../XTensor.h"
#include "../../XName.h"
#include "../../XUtility.h"
#include "ScaleAndShift.h"
#include "ScaleAndShift.cuh"

namespace nts{ // namespace nts(NiuTrans.Tensor)

/* 
scale and shift all tensor entries

b = a * scale + shift

>> a - the input tensor
>> b - the output tensor
>> scale - the scale factor
>> shift - the shift factor
*/
void _ScaleAndShift(const XTensor * a, XTensor * b, DTYPE scale, DTYPE shift)
{
#ifdef USE_CUDA
    /* run it on GPUs */
    if(a->devID >= 0){
        _CudaScaleAndShift(a, b, scale, shift);
        return;
    }
#endif

50
    CheckNTErrors((a->dataType == DEFAULT_DTYPE), "The tensor is not in the default data type!");
51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79

    /* sparse tensor */
    if(a->isSparse){
        int num = a->unitNumNonZero;
        char * d = (char*)a->data + sizeof(int);
        char * f = d + (sizeof(int) + sizeof(DTYPE)) * 0 + sizeof(int);
        char * db = (char*)b->data + sizeof(int);
        char * fb = db + (sizeof(int) + sizeof(DTYPE)) * 0 + sizeof(int);
        for(int i = 0; i < num; i++){
            DTYPE * v = (DTYPE*)f;
            DTYPE * vb = (DTYPE*)fb;
            *vb = *v * scale + shift;
            f += sizeof(int) + sizeof(DTYPE);
            fb += sizeof(int) + sizeof(DTYPE);
        }
    }
    /* dense tensor */
    else{
        DTYPE * va = (DTYPE*)a->data;
        DTYPE * vb = (DTYPE*)b->data;
        for(int i = 0; i < b->unitNum; i++){
            *vb = *va * scale + shift;
            va++;
            vb++;
        }
    }
}

/* 
scale and shift all tensor entries (do it in place)
keep the result in the input tensor a and return nothing

a = a * scale + shift

>> a - the input/output tensor
>> scale - the scale factor
>> shift - the shift factor
*/
void _ScaleAndShiftMe(XTensor * a, DTYPE scale, DTYPE shift)
{
    /* in-place variant: pass a as both the input and the output tensor */
    _ScaleAndShift(a, a, scale, shift);
}

/* 
scale and shift all tensor entries (return an XTensor structure)
make a new tensor to keep the result and return it

b = a * scale + shift

>> a - the input tensor
>> scale - the scale factor
>> shift - the shift factor
<< return - the result of scaling and shifting all tensor entries
*/
XTensor ScaleAndShift(const XTensor &a, DTYPE scale, DTYPE shift)
{
    /* the result tensor shares a's shape/type; marked temporary so the
       framework may reclaim it */
    XTensor b(&a);
    b.SetTMPFlag();

    /* do the actual computation */
    _ScaleAndShift(&a, &b, scale, shift);

    /* tensor connections: record the operation and its parameters on the
       output head (presumably consumed later for backward computation —
       TODO confirm against XLink) */
    XLink::MakeLink(&a, NULL, &b, MATH_SCALEANDSHIFT);
    XLink::AddParamToHead(&b, scale);
    XLink::AddParamToHead(&b, shift);

    return b;
}

} // namespace nts(NiuTrans.Tensor)