Commit 13d3feea by xiaotong

use MatrixMul to replace MatrixMul2D

parent 7a7dc4c6
@@ -468,7 +468,7 @@ void Update(FNNModel &model, FNNModel &grad, float epsilon)
XTensor * paraGrad = (XTensor*)gradList.GetItem(i);
/* the delta rule */
- Sum(para, paraGrad, para, -epsilon);
+ _Sum(para, paraGrad, para, -epsilon);
}
}
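For reference, the new call is the low-level, pointer-based form of tensor summation. A minimal sketch of the delta-rule step follows, assuming the surrounding FNNLM.cpp context (includes, and para/paraGrad as already-initialized XTensor objects of matching shape) and assuming that _Sum(a, b, c, beta) computes c = a + b * beta, as the usage in this hunk suggests:

/* sketch only: SGD delta rule, para = para - epsilon * paraGrad,
   done in place by aliasing the output with the first input;
   epsilon is the learning rate passed into Update() */
_Sum(para, paraGrad, para, -epsilon);  /* para = para + paraGrad * (-epsilon) */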
@@ -707,7 +707,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
s = s + b
NOTE: the trick here is to extend b to a 2d tensor
to fit into the 2d representation in tensor summation */
- Sum(&s, &b2D, &s);
+ _Sum(&s, &b2D, &s);
/* pass the state through the hard tanh function:
h = tanh(s) */
@@ -735,11 +735,13 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
InitTensor(&b2D, &s);
Unsqueeze(&b, &b2D, 0, batchSize);
- Sum(&s, &b2D, &s);
+ _Sum(&s, &b2D, &s);
/* y = softmax(s) */
LogSoftmax(&s, &y, 1);
}
}
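Both Forward hunks rely on the same trick: the bias b is a 1-d tensor, so it is first stacked batchSize times along a new leading dimension (b2D) to match the 2-d state s, and only then added. Putting the patched lines together, a minimal sketch of the output layer, assuming s holds the pre-activation states (batchSize x outputSize), b is the 1-d bias, and y receives the result (names follow the diff; shapes are assumptions):

/* sketch: y = log-softmax(s + b), with b broadcast to s's shape by hand */
XTensor b2D;
InitTensor(&b2D, &s);               /* b2D gets the same shape/type as s       */
Unsqueeze(&b, &b2D, 0, batchSize);  /* copy b batchSize times along dim 0      */
_Sum(&s, &b2D, &s);                 /* s = s + b2D, in place                   */
LogSoftmax(&s, &y, 1);              /* log-softmax over dim 1 (the vocabulary) */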
/*
@@ -855,7 +857,7 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
/* gradient of the embedding weight: dE/dw += x^T * dE/dy
NOTE that we accumulate dE/dw here because the matrix w
is shared by several layers (or words) */
- MatrixMul2D(&x, X_TRANS, dedy, X_NOTRANS, &dedw, 1.0F, 1.0F);
+ MatrixMul(&x, X_TRANS, dedy, X_NOTRANS, &dedw, 1.0F, 1.0F);
delete dedy;
}
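Here the switch from MatrixMul2D to MatrixMul looks like an interface change only: with the trailing arguments 1.0F, 1.0F the call is assumed to behave like a GEMM accumulation, which matches the "accumulate dE/dw" note above. A minimal sketch of that accumulation, with shapes that are assumptions for illustration only:

/* sketch, assumed semantics: dedw = 1.0F * x^T * dedy + 1.0F * dedw
     x    : batchSize x vocabSize   (input for this word position)
     dedy : batchSize x embSize     (gradient w.r.t. this position's output)
     dedw : vocabSize x embSize     (gradient of the shared weight w, accumulated) */
MatrixMul(&x, X_TRANS, dedy, X_NOTRANS, &dedw, 1.0F, 1.0F);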