Commit 13d3feea by xiaotong

use MatrixMul to replace MatrixMul2D

parent 7a7dc4c6
@@ -468,7 +468,7 @@ void Update(FNNModel &model, FNNModel &grad, float epsilon)
         XTensor * paraGrad = (XTensor*)gradList.GetItem(i);

         /* the delta rule */
-        Sum(para, paraGrad, para, -epsilon);
+        _Sum(para, paraGrad, para, -epsilon);
     }
 }
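The hunk above implements the delta rule stated in its comment: para is updated to para - epsilon * dE/dpara, with the scaling factor passed as the last argument of _Sum. Below is a minimal plain-C++ sketch of that arithmetic on flat float buffers; DeltaRuleUpdate and the std::vector storage are illustrative stand-ins, not the XTensor API.

#include <cstddef>
#include <vector>

/* Plain-vector stand-in for _Sum(para, paraGrad, para, -epsilon):
   para[i] <- para[i] + (-epsilon) * paraGrad[i], i.e. one gradient-descent step. */
void DeltaRuleUpdate(std::vector<float> &para,
                     const std::vector<float> &paraGrad,
                     float epsilon)
{
    for (std::size_t i = 0; i < para.size(); ++i)
        para[i] += -epsilon * paraGrad[i];
}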
@@ -707,7 +707,7 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
            s = s + b
            NOTE: the trick here is to extend b to a 2d tensor
            to fit into the 2d representation in tensor summation */
-        Sum(&s, &b2D, &s);
+        _Sum(&s, &b2D, &s);

         /* pass the state through the hard tanh function:
            h = tanh(s) */
@@ -735,11 +735,13 @@ void Forward(XTensor inputs[], XTensor &output, FNNModel &model, FNNNet &net)
         InitTensor(&b2D, &s);
         Unsqueeze(&b, &b2D, 0, batchSize);
-        Sum(&s, &b2D, &s);
+        _Sum(&s, &b2D, &s);

         /* y = softmax(s) */
         LogSoftmax(&s, &y, 1);
     }
 }

 /*
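Both Forward hunks apply the same bias trick described in the comment: b is a 1-D bias, Unsqueeze replicates it batchSize times into the 2-D tensor b2D, and the element-wise sum then adds b to every row of s before the activation (hard tanh or log-softmax). Here is a plain-C++ sketch of that broadcast, assuming row-major batchSize x width storage; the helper name and layout are assumptions for illustration, not the XTensor API.

#include <vector>

/* Row-major sketch of Unsqueeze(&b, &b2D, 0, batchSize) followed by
   _Sum(&s, &b2D, &s): the bias b is added to every row of the matrix s. */
void AddBiasRowwise(std::vector<float> &s,        /* batchSize x width, row-major */
                    const std::vector<float> &b,  /* width */
                    int batchSize, int width)
{
    for (int i = 0; i < batchSize; ++i)
        for (int j = 0; j < width; ++j)
            s[i * width + j] += b[j];
}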
@@ -855,7 +857,7 @@ void Backward(XTensor inputs[], XTensor &output, XTensor &gold, LOSS_FUNCTION_NA
         /* gradient of the embedding weight: dE/dw += x^T * dE/dy
            NOTE that we accumulate dE/dw here because the matrix w
            is shared by several layers (or words) */
-        MatrixMul2D(&x, X_TRANS, dedy, X_NOTRANS, &dedw, 1.0F, 1.0F);
+        MatrixMul(&x, X_TRANS, dedy, X_NOTRANS, &dedw, 1.0F, 1.0F);

         delete dedy;
     }
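The Backward hunk swaps MatrixMul2D for the more general MatrixMul with the same arguments: x is transposed, dE/dy is not, and, as the comment notes, the product is accumulated into dedw because the embedding matrix w is shared by several positions. A plain-C++ sketch of dE/dw += x^T * dE/dy under an assumed row-major layout, with hypothetical shape names (batch, inDim, outDim):

#include <vector>

/* Row-major sketch of the accumulated product dE/dw += x^T * dE/dy.
   x is (batch x inDim), dedy is (batch x outDim), dedw is (inDim x outDim). */
void AccumulateEmbeddingGrad(std::vector<float> &dedw,
                             const std::vector<float> &x,
                             const std::vector<float> &dedy,
                             int batch, int inDim, int outDim)
{
    for (int b = 0; b < batch; ++b)
        for (int i = 0; i < inDim; ++i)
            for (int j = 0; j < outDim; ++j)
                dedw[i * outDim + j] += x[b * inDim + i] * dedy[b * outDim + j];
}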