/* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northeastern University.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/

#include "CopyValues.h"
#include "CopyValues.cuh"

namespace nts { // namespace nts(NiuTrans.Tensor)

/*
copy the values of tensor s to tensor t
>> s - source tensor
>> t - target tensor; it must already own a data array holding the same
       number of items (unitNum) as s
>> stream - the stream for creating the job pipeline; it is only forwarded
            to CudaCopyValues (GPU builds) and is unused by the CPU and the
            float <-> float16 conversion paths
<< return - false if either tensor is NULL or neither tensor has a data
            array (nothing to copy); true once the copy has been issued.
            Any other inconsistency is reported through CheckNTErrors.

handled cases:
  - dense X_FLOAT <-> X_FLOAT16 conversion on a single device
  - any copy touching a GPU tensor (CudaCopyValues, only with USE_CUDA)
  - dense -> dense and sparse -> sparse copies on the CPU
*/
bool CopyValues(XTensor * s, XTensor * t, XStream * stream)
{
    if (s == NULL || t == NULL)
        return false;

    /* nothing to copy at all */
    if (s->data == NULL && t->data == NULL)
        return false;

    /* a one-sided empty array is a caller error, not a silent no-op */
    CheckNTErrors((s->data != NULL), "Cannot copy an empty data array!");
    CheckNTErrors((t->data != NULL), "Cannot copy to an empty data array!");
    CheckNTErrors((s->unitNum == t->unitNum), "Unmatched data item number!");

    /* float <-> float16 needs an element-wise conversion, not a byte copy */
    if ((s->dataType == X_FLOAT16 && t->dataType == X_FLOAT) ||
        (s->dataType == X_FLOAT && t->dataType == X_FLOAT16)) {
        CheckNTErrors(((s->devID < 0 && t->devID < 0) || s->devID == t->devID),
            "The code must be run on the same device!");
        /* ConvertDataType walks unitNum contiguous items, so only dense
           tensors are supported here; sparse conversion is not implemented */
        CheckNTErrors((!s->isSparse && !t->isSparse), "TODO!");
        ConvertDataType(s->devID, s->data, s->dataType, t->data, t->dataType, s->unitNum);
        return true;
    }

#ifdef USE_CUDA
    /* any GPU-resident side is delegated to the CUDA implementation */
    if (s->devID >= 0 || t->devID >= 0)
        return CudaCopyValues(s, t, stream);
#endif

    if (!s->isSparse && !t->isSparse) {
        /* a raw byte copy is only safe when both sides use the same item
           size; otherwise the target buffer would be over- or under-filled */
        CheckNTErrors((s->unitSize == t->unitSize), "Unmatched data item size!");

        /* compute the byte count in size_t to avoid int overflow */
        size_t byteNum = (size_t)s->unitSize * (size_t)s->unitNum;
        memcpy(t->data, s->data, byteNum);
    }
    else if (s->isSparse && t->isSparse) {
        int d = s->GetNonzeroSize();

        /* make t's storage match s before writing into it */
        t->Resize(s);
        t->unitNumNonZero = d;

        /* NOTE(review): the byte count assumes the sparse buffer is an int
           header followed by d (int index, DTYPE value) entries - confirm
           against the XTensor sparse layout */
        size_t byteNum = sizeof(int) + (size_t)d * (sizeof(int) + sizeof(DTYPE));
        memcpy(t->data, s->data, byteNum);
    }
    else {
        /* dense <-> sparse copies are not implemented */
        ShowNTErrors("TODO!");
    }

    return true;
}

} // namespace nts(NiuTrans.Tensor)
