Commit 442d3ca1 by liyinqiao

Merge with XU Chen branch (Don't use this! It's an incomplete version)

1. Fix minor errors.
2. Add some annotations.
parent 2918c894
...@@ -26,10 +26,10 @@ ...@@ -26,10 +26,10 @@
namespace nts { // namespace nts(NiuTrans.Tensor) namespace nts { // namespace nts(NiuTrans.Tensor)
/* transform a tensor by merging it alone with a dimension, e.g., (M, N/3, 3) -> (M, N) */ /* transform a tensor by merging it along with a dimension, e.g., (M, N/3, 3) -> (M, N) */
void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim = -1); void _Merge(const XTensor * s, XTensor * t, int whereToMerge, int leadingDim = -1);
/* transform a tensor by merging it alone with a dimension (return an XTensor structure) /* transform a tensor by merging it along with a dimension (return an XTensor structure)
e.g., (M, N/3, 3) -> (M, N) */ e.g., (M, N/3, 3) -> (M, N) */
XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim = -1); XTensor Merge(const XTensor &s, int whereToMerge, int leadingDim = -1);
......
/* NiuTrans.Tensor - an open-source tensor library /* NiuTrans.Tensor - an open-source tensor library
* Copyright (C) 2017, Natural Language Processing Lab, Northestern University. * Copyright (C) 2017, Natural Language Processing Lab, Northestern University.
* All rights reserved. * All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
/* /*
* $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24 * $Created by: XIAO Tong (email: xiaotong@mail.neu.edu.cn) 2018-04-24
*/ */
#ifndef __SPLIT_H__ #ifndef __SPLIT_H__
#define __SPLIT_H__ #define __SPLIT_H__
......
...@@ -85,7 +85,7 @@ XTensor Stack(const TensorList &smalls, int dim) ...@@ -85,7 +85,7 @@ XTensor Stack(const TensorList &smalls, int dim)
{ {
int count = smalls.count; int count = smalls.count;
CheckNTErrors(count > 0, "Empty list!"); CheckNTErrors(count > 0, "Empty list!");
CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!"); CheckNTErrors(dim >= 0, "Illegal dimension to Stack!");
XTensor * tensor = smalls.GetItem(0); XTensor * tensor = smalls.GetItem(0);
int order = tensor->order + 1; int order = tensor->order + 1;
...@@ -149,7 +149,7 @@ void Stack(const TensorList &smalls, XTensor &t, int dim) ...@@ -149,7 +149,7 @@ void Stack(const TensorList &smalls, XTensor &t, int dim)
{ {
int count = smalls.count; int count = smalls.count;
CheckNTErrors(count > 0, "Empty list!"); CheckNTErrors(count > 0, "Empty list!");
CheckNTErrors(dim >= 0, "Illegal dimension to concatenate!"); CheckNTErrors(dim >= 0, "Illegal dimension to Stack!");
if (!t.isInit || !CheckStackShape(smalls, t, dim)) { if (!t.isInit || !CheckStackShape(smalls, t, dim)) {
XTensor * tensor = smalls.GetItem(0); XTensor * tensor = smalls.GetItem(0);
......
...@@ -167,6 +167,16 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize) ...@@ -167,6 +167,16 @@ XTensor Unsqueeze(const XTensor &a, int dim, int dSize)
return b; return b;
} }
/*
insert a dimension by copying the blocks for x times
(where x is the size of the inerted dimension) (returna a XTensor structure)
make a new tensor to keep the result and return it
>> a - the input tensor
>> b - the output tensor
>> dim - where to insert the dimension
>> dSize - size of the newly-inserted dimension
*/
void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize) void Unsqueeze(const XTensor &a, XTensor &b, int dim, int dSize)
{ {
if (!b.isInit || !CheckUnsqueezeSize(&a, &b, dim, dSize)) { if (!b.isInit || !CheckUnsqueezeSize(&a, &b, dim, dSize)) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论