LogSoftmax.cu 22.3 KB
Newer Older
linye committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-26
* $Update by: Lin Ye (email: linye2015@outlook.com) 2019-07-01 float16 added
*/

#include "LogSoftmax.h"
#include "LogSoftmax.cuh"
#include "Loss.cuh"
#include "../core/arithmetic/MultiplyDim.h"
#include "../core/reduce/ReduceSum.cuh"
#include "../core/reduce/ReduceMax.cuh"
#include "../XDevice.h"
30
#include <device_launch_parameters.h>
linye committed
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61

namespace nts { // namespace nts(NiuTrans.Tensor)

#ifdef USE_CUDA

/*
log-scale softmax on the GPU: y = log(e^x / \sum_{i} e^{x_i})

This entry point is intentionally unimplemented: every call reports an error
through ShowNTErrors that directs the caller to LogSoftmax instead.

>> x - input tensor (ignored)
>> y - output tensor (ignored)
>> leadDim - leading dimension along which the reduction would run (ignored)
*/
void _CudaLogSoftmax(const XTensor * x, XTensor * y, int leadDim)
{
    /* the arguments are deliberately unused: this path is not supported */
    (void)x;
    (void)y;
    (void)leadDim;

    ShowNTErrors("You should call LogSoftmax instead!");
}

/*
log softmax forward computation, normalizing down each column (Cuda kernel)

for each column j, let y_{i,j} and x_{i,j} be the output
and input value of the i-th element of column j. We have

y_{i,j} = log(e^{x_{i,j}} / \sum_{k} e^{x_{k,j}})

which is evaluated in the shifted form

y_{i,j} = log(e^{x_{i,j} - max_j} / sum_j), sum_j = \sum_{k} e^{x_{k,j} - max_j}

Thread layout: the x dimension of the grid/block indexes columns (j) and the
y dimension indexes rows (i). The per-column statistics are cached in shared
memory indexed by threadIdx.x, so blockDim.x must not exceed
MAX_CUDA_THREAD_NUM_PER_BLOCK.

For X_FLOAT the result is clamped from below to LOGPROB_MIN (NaN/Inf are also
mapped to LOGPROB_MIN); the X_FLOAT16 path stores the raw half result.

>> x - input tensor (in matrix, row-major: element (i,j) is at i * colNum + j)
>> max - the max value of each column j (colNum entries)
>> sum - \sum_{k} e^{x_{k,j} - max_j} for each column j (colNum entries)
>> y - output tensor (in matrix)
>> rowNum - row number of the matrix
>> colNum - column number of the matrix
*/
template <class T, TENSOR_DATA_TYPE dataType>
__global__
void KernelLogSoftmaxComputeByRow(T * x, T * max, T * sum, T * y, int rowNum, int colNum)
{
    __shared__ T inputSum[MAX_CUDA_THREAD_NUM_PER_BLOCK];
    __shared__ T inputMax[MAX_CUDA_THREAD_NUM_PER_BLOCK];

    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int j = blockDim.x * blockIdx.x + threadIdx.x;

    /* we keep the sum and max number in the shared memory for each column.
       The j < colNum guard keeps the last (partial) block along x from
       reading past the end of max/sum. */
    if (threadIdx.y == 0 && j < colNum) {
        inputSum[threadIdx.x] = sum[j];
        inputMax[threadIdx.x] = max[j];
    }

    /* synchronize to make sure the values of max and sum are loaded
       (every thread of the block reaches this barrier) */
    __syncthreads();

    /* y_{i,j} = log(e^(s_{i,j} - max_{j}) / \sum_{k} e^{s_{k,j} - max_{j}}) */
    if (i < rowNum && j < colNum) {
        int key = i * colNum + j;

        if (dataType == X_FLOAT) {
            DTYPE r = log((DTYPE)exp((DTYPE)(x[key] - inputMax[threadIdx.x])) / (DTYPE)inputSum[threadIdx.x]);

            if (isnan(r))
                r = LOGPROB_MIN;
            if (isinf(r))
                r = LOGPROB_MIN;

            y[key] = MAX(r, LOGPROB_MIN);
        }
        else if (dataType == X_FLOAT16) {
            /* the statistics are cached per column, i.e. by threadIdx.x,
               exactly as in the X_FLOAT branch */
            half r = hlog((half)hexp(x[key] - inputMax[threadIdx.x]) / (half)inputSum[threadIdx.x]);
            y[key] = r;
        }
    }
}

/*
log softmax forward computation, normalizing along each row (Cuda kernel)

for each row i, let y_{i,j} and x_{i,j} be the output
and input value of the j-th element of row i. We have

y_{i,j} = log(e^{x_{i,j}} / \sum_{k} e^{x_{i,k}})

which is evaluated in the shifted form

y_{i,j} = log(e^{x_{i,j} - max_i} / sum_i), sum_i = \sum_{k} e^{x_{i,k} - max_i}

Thread layout: the x dimension of the grid/block indexes columns (j) and the
y dimension indexes rows (i). The per-row statistics are cached in shared
memory indexed by threadIdx.y, so blockDim.y must not exceed
MAX_CUDA_THREAD_NUM_PER_BLOCK.

For X_FLOAT the result is clamped from below to LOGPROB_MIN (NaN/Inf are also
mapped to LOGPROB_MIN); the X_FLOAT16 path stores the raw half result.

>> x - input tensor (in matrix, row-major: element (i,j) is at i * colNum + j)
>> max - the max value of each row i (rowNum entries)
>> sum - \sum_{k} e^{x_{i,k} - max_i} for each row i (rowNum entries)
>> y - output tensor (in matrix)
>> rowNum - row number of the matrix
>> colNum - column number of the matrix
*/
template <class T, TENSOR_DATA_TYPE dataType>
__global__
void KernelLogSoftmaxComputeByCol(T * x, T * max, T * sum, T * y, int rowNum, int colNum)
{
    __shared__ T inputSum[MAX_CUDA_THREAD_NUM_PER_BLOCK];
    __shared__ T inputMax[MAX_CUDA_THREAD_NUM_PER_BLOCK];

    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int j = blockDim.x * blockIdx.x + threadIdx.x;

    /* we keep the sum and max number in the shared memory for each row.
       The i < rowNum guard keeps the last (partial) block along y from
       reading past the end of max/sum. */
    if (threadIdx.x == 0 && i < rowNum) {
        inputSum[threadIdx.y] = sum[i];
        inputMax[threadIdx.y] = max[i];
    }

    /* synchronize to make sure the values of max and sum are loaded
       (every thread of the block reaches this barrier) */
    __syncthreads();

    /* y_{i,j} = log(e^(s_{i,j} - max_{i}) / \sum_{k} e^{s_{i,k} - max_{i}}) */
    if (i < rowNum && j < colNum) {
        int key = i * colNum + j;

        if (dataType == X_FLOAT) {
            DTYPE r = log((DTYPE)exp((DTYPE)(x[key] - inputMax[threadIdx.y])) / (DTYPE)inputSum[threadIdx.y]);

            if (isnan(r))
                r = LOGPROB_MIN;
            if (isinf(r))
                r = LOGPROB_MIN;

            y[key] = MAX(r, LOGPROB_MIN);
        }
        else if (dataType == X_FLOAT16) {
            half r = hlog((half)hexp(x[key] - inputMax[threadIdx.y]) / (half)inputSum[threadIdx.y]);
            y[key] = r;
        }
    }
}

/*
launch the forward log-softmax kernel for one element type
(helper of _CudaLogSoftmaxSumMax)

>> x - input matrix (order 2) whose data is read as T
>> y - output matrix whose data is written as T
>> leadDim - 0 normalizes down each column, any other value along each row
>> sum - per-column (leadDim == 0) or per-row shifted sums; data read as T
>> max - per-column (leadDim == 0) or per-row maxima; data read as T
*/
template <class T, TENSOR_DATA_TYPE dataType>
static void _CudaLogSoftmaxSumMaxLaunch(XTensor * x, XTensor * y, int leadDim, XTensor * sum, XTensor * max)
{
    int gridSize[3], blockSize[3];

    int n = x->dimSize[0];
    int m = x->dimSize[1];

    /* the statistics are provided by the caller; we only reinterpret them */
    T * maxData = (T*)max->data;
    T * sumData = (T*)sum->data;

    if (leadDim == 0) {
        GDevs.GetCudaThread2D(x->devID, n, m, MAX_INT, gridSize, blockSize);

        /* y_{i,j} = log(e^(s_{i,j} - max_{j}) / \sum_{k} e^{s_{k,j} - max_{j}}) */
        KernelLogSoftmaxComputeByRow<T, dataType> <<<dim3(gridSize[1], gridSize[0]), dim3(blockSize[1], blockSize[0])>>>
                                   ((T*)x->data, maxData, sumData, (T*)y->data, n, m);
    }
    else {
        GDevs.GetCudaThread2D(x->devID, m, n, MAX_INT, gridSize, blockSize);

        /* y_{i,j} = log(e^(s_{i,j} - max_{i}) / \sum_{k} e^{s_{i,k} - max_{i}}) */
        KernelLogSoftmaxComputeByCol<T, dataType> <<<dim3(gridSize[0], gridSize[1]), dim3(blockSize[0], blockSize[1])>>>
                                   ((T*)x->data, maxData, sumData, (T*)y->data, n, m);
    }
}

/*
log scale softmax y = log(e^x / \sum_{i} e^{x_i}) (Cuda version),
using precomputed per-row/per-column statistics

x and y must be matrices (order 2) on the same GPU with the same data type,
either DEFAULT_DTYPE or X_FLOAT16; any other combination reports "TODO!".
The data of sum and max is read with the same element type as x.

>> x - input matrix
>> y - result
>> leadDim - leading dimension (along which we perform reduction):
             0 normalizes each column, any other value normalizes each row
>> sum - \sum_{i} e^{x_i - max}, one entry per column (leadDim == 0) or per row
>> max - \max_{i} x_i, one entry per column (leadDim == 0) or per row
*/
void _CudaLogSoftmaxSumMax(XTensor * x, XTensor * y, int leadDim, XTensor * sum, XTensor * max)
{
    CheckNTErrors((x->devID >= 0), "Forward computation of log softmax must be run on GPUs.");
    CheckNTErrors((x->devID == y->devID), "Input tensors must be on the same GPU.");
    CheckNTErrors((x->order == y->order), "Input tensors must be of the same size.");
    CheckNTErrors((x->order == 2), "Input tensors must be of order 2.");

    int devIDBackup;
    ProtectCudaDev(x->devID, devIDBackup);

    if (x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE)
        _CudaLogSoftmaxSumMaxLaunch<DTYPE, DEFAULT_DTYPE>(x, y, leadDim, sum, max);
    else if (x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16)
        _CudaLogSoftmaxSumMaxLaunch<half, X_FLOAT16>(x, y, leadDim, sum, max);
    else
        ShowNTErrors("TODO!");

    BacktoCudaDev(x->devID, devIDBackup);
}

/*
fill dE/dx with the exp(y) term of the log-softmax gradient (Cuda kernel)

In this file it is the first step of the sparse-gold backward pass:
  - CROSSENTROPY / SQUAREDERROR: dE/dx_j = exp(y_j)
  - ONEHOTERROR and any other loss: dE/dx_j = 0

One thread per element; threads whose index is >= size do nothing.

>> dedy - dE/dy (not read by this kernel)
>> dedx - dE/dx (written)
>> y - output of the function
>> size - size of output
>> lossName - name of the loss function
*/
template <class T>
__global__
void KernelExpLoss(T * dedy, T * dedx, T * y, int size, LOSS_FUNCTION_NAME lossName)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    /* tail threads of the last block have nothing to do */
    if (idx >= size)
        return;

    bool useExpTerm = (lossName == CROSSENTROPY) || (lossName == SQUAREDERROR);

    if (useExpTerm)
        dedx[idx] = exp(((DTYPE)y[idx]));
    else
        dedx[idx] = 0;
}

/*
backward computation for log softmax with a dense gold standard (Cuda kernel)

dE/dx = dE/dy * dy/dx, computed directly element-wise:
  - CROSSENTROPY / SQUAREDERROR: dE/dx_j = -gold_j + exp(y_j)
  - ONEHOTERROR: dE/dx_j = -gold_j + exp(y_j) where gold_j == 1, else 0
  - any other loss: dE/dx_j = dE/dy_j (dedy is only read in this case)
NaN or Inf results are replaced by 0 in both the float and float16 paths.

One thread per element; threads whose index is >= size do nothing.

>> dedy - dE/dy
>> dedx - dE/dx
>> gold - gold standard to measure error (or loss)
>> y - output of the function
>> x - input of the function (not read)
>> size - size of input/output
>> lossName - name of the loss function
*/
template <class T, TENSOR_DATA_TYPE dataType>
__global__
void KernelLogSoftmaxBackwardDEDS(T * dedy, T * dedx, T * gold, T * y, T * x,
                                  int size, LOSS_FUNCTION_NAME lossName)
{
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    if (i < size) {
        if (dataType == X_FLOAT) {
            DTYPE r = 0;
            /* dE/ds_j = -gold_j + exp(y_j) */
            if (lossName == CROSSENTROPY)
                r = -(DTYPE)gold[i] + (DTYPE)exp(((DTYPE)y[i]));
            /* dE/ds_j = -gold_j + exp(y_j) */
            else if (lossName == SQUAREDERROR)
                r = -(DTYPE)gold[i] + (DTYPE)exp(((DTYPE)y[i]));
            else if (lossName == ONEHOTERROR) {
                if ((DTYPE)gold[i] == 1.0F)
                    r = -(DTYPE)gold[i] + (DTYPE)exp(((DTYPE)y[i]));
                else
                    r = 0;
            }
            else {
                r = dedy[i];
            }

            /* do not let NaN/Inf propagate into the gradient */
            if (isnan(r))
                r = 0;
            if (isinf(r))
                r = 0;

            dedx[i] = r;
        }
        else if (dataType == X_FLOAT16) {
            half r = 0;
            /* dE/ds_j = -gold_j + exp(y_j) */
            if (lossName == CROSSENTROPY)
                r = -(half)gold[i] + (half)hexp(y[i]);
            /* dE/ds_j = -gold_j + exp(y_j) */
            else if (lossName == SQUAREDERROR)
                r = -(half)gold[i] + (half)hexp(y[i]);
            else if (lossName == ONEHOTERROR) {
                if ((half)gold[i] == (half)1.0)
                    r = -(half)gold[i] + (half)hexp(y[i]);
                else
                    r = 0;
            }
            else {
                r = dedy[i];
            }

            /* same NaN/Inf guard as the X_FLOAT path */
            if (__hisnan(r) || __hisinf(r))
                r = 0;

            dedx[i] = r;
        }
    }
}

/*
backward computation for log softmax with a sparse gold standard (Cuda kernel)

Adds the gold term to a dE/dx that has already been initialized
(e.g., by KernelExpLoss):
  - CROSSENTROPY / SQUAREDERROR: dE/dx_j += -gold_j
  - ONEHOTERROR: dE/dx_j += -gold_j + exp(y_j) for entries with gold_j == 1
  - any other loss: no change

(for dE/dx = dE/dy * dy/dx)

The gold buffer is read as gNonZeroNum packed tuples (int key, T value), each
sizeof(int) + sizeof(T) bytes long; key is the row-major index
row * colNum + col. One thread handles one tuple.

>> dedy - dE/dy (not read)
>> dedx - dE/dx (updated in place)
>> gold - packed (key, value) tuples of the sparse gold standard
>> y - output of the function
>> x - input of the function (not read)
>> rowNum - row number of the matrix (not read)
>> colNum - column number of the matrix
>> gNonZeroNum - number of (key, value) tuples in gold
>> lossName - name of the loss function
*/
template <class T>
__global__
void KernelLogSoftmaxBackwardDEDSSparseByRow(T * dedy, T * dedx, void * gold, T * y, T * x,
                                             int rowNum, int colNum, int gNonZeroNum, LOSS_FUNCTION_NAME lossName)
{
    int tupleSize = sizeof(int) + sizeof(T);
    int k = blockDim.x * blockIdx.x + threadIdx.x;

    if (k < gNonZeroNum) {
        const char * tuple = (const char*)gold + tupleSize * k;

        /* load the k-th tuple of the sparse gold matrix. memcpy is used
           because with a 2-byte T the tuple stride is 6 bytes, which leaves
           the int key misaligned for odd k. */
        int key;
        T rawValue;
        memcpy(&key, tuple, sizeof(int));
        memcpy(&rawValue, tuple + sizeof(int), sizeof(T));

        int ni = key / colNum;
        int mi = key % colNum;
        int offset = colNum * ni + mi;

        /* keep the gold value in floating point: it may be fractional
           (storing it in an int would truncate e.g. 0.9 to 0) */
        DTYPE value = (DTYPE)rawValue;

        if (lossName == CROSSENTROPY || lossName == SQUAREDERROR) {
            dedx[offset] = (T)((DTYPE)dedx[offset] - value);
        }
        else if (lossName == ONEHOTERROR) {
            if (value == 1.0F)
                dedx[offset] = (T)((DTYPE)dedx[offset] - value + (DTYPE)exp(((DTYPE)y[offset])));
        }
    }
}

/*
apply the padding mask to dE/dx (helper of _CudaLogSoftmaxBackward)

dedx is temporarily viewed as a matrix of
(unitNum / dimSize[leadDim]) x dimSize[leadDim] and padding as a vector.
They are combined with _MultiplyDimMe along dimension 0, and then both
tensors are restored to their original shapes.

>> dedx - dE/dx (updated in place)
>> padding - mask tensor
>> leadDim - leading dimension
*/
static void _CudaLogSoftmaxMaskPadding(XTensor * dedx, XTensor * padding, int leadDim)
{
    int paddingOrder = padding->order;
    int * paddingDims = new int[paddingOrder];
    memcpy(paddingDims, padding->dimSize, paddingOrder * sizeof(int));
    padding->Reshape(padding->unitNum);

    int order = dedx->order;
    int * dims = new int[order];
    memcpy(dims, dedx->dimSize, order * sizeof(int));
    dedx->Reshape(dedx->unitNum / dedx->GetDim(leadDim), dedx->GetDim(leadDim));
    _MultiplyDimMe(dedx, padding, 0);

    /* restore the caller-visible shapes */
    padding->Reshape(paddingOrder, paddingDims);
    dedx->Reshape(order, dims);

    delete[] paddingDims;
    delete[] dims;
}

/*
typed body of _CudaLogSoftmaxBackward for one element type
(DTYPE with X_FLOAT, or __half with X_FLOAT16)

>> gold, y, x, dedx, padding, leadDim, lossName - see _CudaLogSoftmaxBackward
>> dimensionSize - size of the leading dimension
>> stride - product of the dimension sizes after the leading dimension
>> blockSize - stride * dimensionSize
>> blockNum - number of such blocks in y
*/
template <class T, TENSOR_DATA_TYPE dataType>
static void _CudaLogSoftmaxBackwardTyped(XTensor * gold, XTensor * y, XTensor * x, XTensor * dedx,
                                         XTensor * padding, int leadDim, LOSS_FUNCTION_NAME lossName,
                                         int dimensionSize, int stride, int blockSize, int blockNum)
{
    int cudaGridSize[3], cudaBlockSize[3];

    if (lossName == CROSSENTROPY || lossName == SQUAREDERROR) {
        if (gold->isSparse) {
            CheckNTErrors((gold->order == 2), "TODO!");
            CheckNTErrors((leadDim == 0), "TODO!");
            GDevs.GetCudaThread(x->devID, x->unitNum, cudaGridSize, cudaBlockSize);

            /* dE/ds_j = exp(y_j) */
            KernelExpLoss <T> <<< dim3(cudaGridSize[0]), dim3(cudaBlockSize[0]) >>>
                             (NULL,
                              (T*)dedx->data,
                              (T*)y->data,
                              dimensionSize * stride,
                              lossName);

            GDevs.GetCudaThread(x->devID, gold->unitNumNonZero, cudaGridSize, cudaBlockSize);

            /* dE/ds_j += -gold_j; the tuples start after the leading int of gold->data */
            KernelLogSoftmaxBackwardDEDSSparseByRow <T> <<< dim3(cudaGridSize[0]), dim3(cudaBlockSize[0]) >>>
                                                       (NULL,
                                                        (T*)dedx->data,
                                                        (char*)gold->data + sizeof(int),
                                                        (T*)y->data,
                                                        (T*)x->data,
                                                        dedx->dimSize[0], dedx->dimSize[1], gold->unitNumNonZero, lossName);
        }
        else {
            CheckNTErrors((XTensor::IsSameShaped(gold, y)), "The tensors must be of the same size!");

            /* the launch configuration depends only on blockSize: compute it once */
            GDevs.GetCudaThread(x->devID, blockSize, cudaGridSize, cudaBlockSize);

            for (int k = 0; k < blockNum; k++) {
                /* dE/ds_j = -gold_j + exp(y_j) */
                KernelLogSoftmaxBackwardDEDS <T, dataType> <<< dim3(cudaGridSize[0]), dim3(cudaBlockSize[0]) >>>
                                                          (NULL,
                                                           (T*)dedx->data + k * blockSize,
                                                           (T*)gold->data + k * blockSize,
                                                           (T*)y->data + k * blockSize,
                                                           (T*)x->data + k * blockSize,
                                                           dimensionSize * stride, lossName);
            }
        }

        if (padding != NULL)
            _CudaLogSoftmaxMaskPadding(dedx, padding, leadDim);
    }
    else {
        ShowNTErrors("TODO!");
    }
}

/*
backward computation for log softmax (dense or sparse gold; float or float16)

dE/dx = dE/dy * dy/dx

log softmax: y_i = log(e^{x_i} / \sum_{k} e^{x_k})

dy_i/dx_j
= d{log(e^{x_i} / \sum_{k} e^{x_k})}/dx_j
= d{log(e^{x_i})}/dx_j - d{log(\sum_{k} e^{x_k})}/dx_j
= \delta(i,j) - e^{x_j}/\sum_{k} e^{x_k}
= \delta(i,j) - exp(y_j)

where \delta(i,j) = 1 if i = j, and \delta(i,j) = 0 otherwise

if loss E is defined as cross entropy, i.e., E = -\sum_{k} (gold_k * y_k), we have

dE/dy_i = -gold_i

(where {gold_k} is the gold standard distribution)

then

dE/dx_j
= \sum_{i} {dE/dy_i * dy_i/dx_j}
= \sum_{i} {-gold_i * (\delta(i,j) - exp(y_j))}
= \sum_{i} {-gold_i * \delta(i,j)} + \sum_{i} {gold_i * exp(y_j)}
= -gold_j + exp(y_j) * \sum_{i} gold_i
= -gold_j + exp(y_j)

Note: gold_i is a distribution, i.e., \sum_{i} gold_i = 1.
If gold has a one-hot representation (gold_i = 1 for only one dimension i),
we can reformulate it as dE/dx_j = -\delta(i,j) + exp(y_j).

There are two ways to implement this process.
Method 1. we compute dE/dy and dy/dx respectively, and then reach dE/dx by dE/dx = dE/dy * dy/dx
(or more precisely dE/dx_j = \sum_{i} {dE/dy_i * dy_i/dx_j})
Method 2. we compute dE/dx (or dE/dx_j) in a single step, rather than resorting to the
sub-models dE/dy and dy/dx. We can do this by using dE/dx_j = -gold_j + exp(y_j)

Here we choose Method 2, i.e., we straightforwardly compute dE/dx_j by

dE/dx_j = -gold_j + exp(y_j)

(or dE/dx_j = -\delta(i,j) + exp(y_j) for a Maximum A Posteriori Estimation (MAP))

Method 1 is also fine but is more time consuming due to the summation over dimensions.
Note that this method is not good for the standard softmax when working with
the cross entropy loss, because it is numerically unstable. With the standard
softmax y_i = e^{x_i} / \sum_{k} e^{x_k}, it is easy to show that
dy_i/dx_j = y_i * \delta(i,j) - y_i * y_j. As y_i and y_j could be small numbers,
y_i * y_j would be an even smaller one with a risk of losing precision. This is even
worse when we multiply dy_i/dx_j with dE/dy_i. So log softmax is in general preferred
for better numerical stability.

Only CROSSENTROPY and SQUAREDERROR are implemented (both use the formula above);
NOLOSS passes the argument check but then reports "TODO!".

>> gold - gold standard to measure error (or loss); either dense with the same
          shape as y, or sparse (order 2, leadDim 0 only)
>> y - output of the function
>> x - input of the function
>> dedy - dE/dy (not read by the implemented loss functions)
>> dedx - dE/dx (result)
>> padding - optional mask (may be NULL), applied to dE/dx via _MultiplyDimMe
>> leadDim - leading dimension (along which we perform reduction); a negative
             value selects the last dimension
>> lossName - type of loss function, e.g., cross entropy
*/
void _CudaLogSoftmaxBackward(XTensor * gold, XTensor * y, XTensor * x,
                             XTensor * dedy, XTensor * dedx, 
                             XTensor * padding, int leadDim, 
                             LOSS_FUNCTION_NAME lossName)
{
    leadDim = leadDim < 0 ? y->order - 1 : leadDim;

    CheckNTErrors((x->devID >= 0), "Backward computation of log softmax must be run on GPUs.");
    /* check gold before anything dereferences it */
    CheckNTErrors((gold != NULL), "No x gold standard is found!");
    CheckNTErrors((x->devID == y->devID && gold->devID == y->devID),
                  "Tensors used in log softmax are not on the same GPU.");
    CheckNTErrors((lossName == CROSSENTROPY || lossName == SQUAREDERROR || lossName == NOLOSS),
                  "Unknown loss function.");

    int leadDimRDI = y->order - leadDim - 1;
    int dimensionSize = y->dimSizeRDI[leadDimRDI];
    int stride = 1;
    for (int i = 0; i < leadDimRDI; i++)
        stride *= y->dimSizeRDI[i];
    int blockSize = stride * dimensionSize;
    int blockNum = y->unitNum / blockSize;

    int devIDBackup;
    ProtectCudaDev(x->devID, devIDBackup);

    if (x->dataType == DEFAULT_DTYPE && y->dataType == DEFAULT_DTYPE) {
        _CudaLogSoftmaxBackwardTyped<DTYPE, X_FLOAT>(gold, y, x, dedx, padding, leadDim, lossName,
                                                     dimensionSize, stride, blockSize, blockNum);
    }
    else if (x->dataType == X_FLOAT16 && y->dataType == X_FLOAT16) {
        _CudaLogSoftmaxBackwardTyped<__half, X_FLOAT16>(gold, y, x, dedx, padding, leadDim, lossName,
                                                        dimensionSize, stride, blockSize, blockNum);
    }
    else {
        ShowNTErrors("TODO!");
    }

    BacktoCudaDev(x->devID, devIDBackup);
}

#endif

614
} // namespace nts(NiuTrans.Tensor)