Commit 1f71eb10 by 张裕浩

use vector buffer to accelerate reduce operation

parent f98396a9
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "../../XTensor.h" #include "../../XTensor.h"
#include "../../XName.h" #include "../../XName.h"
#include "../../XBLAS.h"
#include "VectorBuffer.h"
#include "ReduceMax.h" #include "ReduceMax.h"
#include "ReduceMax.cuh" #include "ReduceMax.cuh"
...@@ -76,18 +78,80 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim) ...@@ -76,18 +78,80 @@ void _ReduceMax(const XTensor * input, XTensor * output, int dim)
} }
blockSize = stride * strideNum; blockSize = stride * strideNum;
for(int k = 0; k < blockNum; k++){
DTYPE * ip = (DTYPE*)input->data + blockSize * k; if(input->dimSizeRDI[0] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSizeRDI[0] >= 32){
DTYPE * op = (DTYPE*)output->data + stride * k; int vecBufLength = 32 / sizeof(DTYPE);
for(int i = 0; i < stride; i++){
DTYPE max = FLOAT_MIN; if(dimRDI == 0){
DTYPE * ipe = ip + blockSize; //data is contiguous in dim 0
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ for(int i = 0; i < blockNum; i++){
DTYPE v = *ipb; DTYPE * ip = (DTYPE*)input->data + blockSize * i;
if(max < v) DTYPE * op = (DTYPE*)output->data + i;
max = v; VectorBuffer vecBuf[4];
for(int j = 0; j < 4; j++){
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip) + j * vecBufLength);
}
for(int j = 1; j < strideNum / 32; j++){
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength);
vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
}
vecBuf[0] = vecBuf[0].maxData(vecBuf[1]);
vecBuf[0] = vecBuf[0].maxData(vecBuf[2]);
vecBuf[0] = vecBuf[0].maxData(vecBuf[3]);
DTYPE maxN = DTYPE_MIN;
for(int k = 0; k < vecBufLength; k++){
maxN = MAX(maxN,vecBuf[0][k]);
}
*op = maxN;
}
} else{
//data is separated
for(int i = 0; i < blockNum; i++){
for(int j = 0; j < input->dimSizeRDI[0] / 32; j++){
DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + stride * i;
VectorBuffer vecBuf[4];
for(int k = 0; k < 4; k++){
vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE));
}
for(int k = 1; k < strideNum; k++){
DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength;
vecBuf[0] = vecBuf[0].maxData(VectorBuffer::loadu(ptr + 0 * vecBufLength));
vecBuf[1] = vecBuf[1].maxData(VectorBuffer::loadu(ptr + 1 * vecBufLength));
vecBuf[2] = vecBuf[2].maxData(VectorBuffer::loadu(ptr + 2 * vecBufLength));
vecBuf[3] = vecBuf[3].maxData(VectorBuffer::loadu(ptr + 3 * vecBufLength));
}
for(int k = 0; k < 4; k++){
for(int l = 0; l < vecBufLength; l++)
*(op + j * 32 + 8 * k + l) = vecBuf[k][l];
}
}
}
}
}//run vector buffer
else{
for(int k = 0; k < blockNum; k++){
DTYPE * ip = (DTYPE*)input->data + blockSize * k;
DTYPE * op = (DTYPE*)output->data + stride * k;
for(int i = 0; i < stride; i++){
//#if defined(USE_BLAS)
// *(op + i) = *(ip + i + (int)(stride * IAMAX(strideNum, ip + i, stride)));
//#else
DTYPE max = DTYPE_MIN;
DTYPE * ipe = ip + blockSize;
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE v = *ipb;
if(max < v)
max = v;
}
*(op + i) = max;
//#endif
} }
*(op + i) = max;
} }
} }
} }
......
...@@ -23,6 +23,9 @@ ...@@ -23,6 +23,9 @@
#include "ReduceSum.h" #include "ReduceSum.h"
#include "ReduceSum.cuh" #include "ReduceSum.cuh"
#include "../../XName.h" #include "../../XName.h"
#include "../../XBLAS.h"
#include "VectorBuffer.h"
#include <iostream>
namespace nts{ // namespace nts(NiuTrans.Tensor) namespace nts{ // namespace nts(NiuTrans.Tensor)
...@@ -82,118 +85,188 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor ...@@ -82,118 +85,188 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
} }
blockSize = stride * strideNum; blockSize = stride * strideNum;
for(int k = 0; k < blockNum; k++){ if(input->dimSizeRDI[0] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSizeRDI[0] >= 32){
DTYPE * ip = (DTYPE*)input->data + blockSize * k; int vecBufLength = 32 / sizeof(DTYPE);
DTYPE * op = (DTYPE*)output->data + stride * k;
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + stride * k : NULL; if(dimRDI == 0){
for(int i = 0; i < stride; i++){ //data is contiguous in dim 0
DTYPE sum = 0; for(int i = 0; i < blockNum; i++){
DTYPE bias = shift != NULL ? *(sp + i) : 0; // stride = 1
DTYPE * ipe = ip + blockSize; DTYPE * ip = (DTYPE*)input->data + blockSize * i;
if(isExp){ DTYPE * op = (DTYPE*)output->data + i;
if(bias == 0){ DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + i : NULL;
if(power == (DTYPE)1.0){ DTYPE bias[32 / sizeof(DTYPE)] = {0};
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride) if(shift != NULL){
sum += (DTYPE)exp(*ipb); for(int k = 0; k < 32 / sizeof(DTYPE); k++)
bias[k] = *(sp);
}
VectorBuffer vecBuf[4];
for(int j = 0; j < 4; j++){
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip) + j * vecBufLength, isExp, power, bias);
}
for(int j = 1; j < strideNum / 32; j++){
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength);
vecBuf[0] = vecBuf[0] + VectorBuffer::loadu(ptr + 0 * vecBufLength, isExp, power, bias);
vecBuf[1] = vecBuf[1] + VectorBuffer::loadu(ptr + 1 * vecBufLength, isExp, power, bias);
vecBuf[2] = vecBuf[2] + VectorBuffer::loadu(ptr + 2 * vecBufLength, isExp, power, bias);
vecBuf[3] = vecBuf[3] + VectorBuffer::loadu(ptr + 3 * vecBufLength, isExp, power, bias);
}
vecBuf[0] = ((vecBuf[0] + vecBuf[1]) + (vecBuf[2] + vecBuf[3]));
DTYPE sum = (DTYPE) 0.0;
for(int k = 0; k < vecBufLength; k++){
sum = sum + vecBuf[0][k];
}
*op = sum;
}
} else{
//data is separated
for(int i = 0; i < blockNum; i++){
for(int j = 0; j < input->dimSizeRDI[0] / 32; j++){
DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + stride * i;
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + stride * i : NULL;
DTYPE bias[4 * 32 / sizeof(DTYPE)] = {0};
if(shift != NULL){
for(int k = 0; k < 4 * 32 / sizeof(DTYPE); k++)
bias[k] = *(sp + k);
} }
else if(power == (DTYPE)2.0){ VectorBuffer vecBuf[4];
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ for(int k = 0; k < 4; k++){
DTYPE value = (*ipb); vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE), isExp, power, bias + j * 32 / sizeof(DTYPE));
sum += (DTYPE)exp(value * value);
}
} }
else if(power == (DTYPE)0.5){ for(int k = 1; k < strideNum; k++){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength;
DTYPE value = (*ipb); vecBuf[0] = vecBuf[0] + VectorBuffer::loadu(ptr + 0 * vecBufLength, isExp, power, bias);
sum += (DTYPE)exp(sqrt(value)); vecBuf[1] = vecBuf[1] + VectorBuffer::loadu(ptr + 1 * vecBufLength, isExp, power, bias + 1 * vecBufLength);
} vecBuf[2] = vecBuf[2] + VectorBuffer::loadu(ptr + 2 * vecBufLength, isExp, power, bias + 2 * vecBufLength);
vecBuf[3] = vecBuf[3] + VectorBuffer::loadu(ptr + 3 * vecBufLength, isExp, power, bias + 3 * vecBufLength);
} }
else{ for(int k = 0; k < 4; k++){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ for(int l = 0; l < vecBufLength; l++)
DTYPE value = (*ipb); *(op + j * 32 + 8 * k + l) = vecBuf[k][l];
sum += (DTYPE)exp(pow(value, power));
}
} }
} }
else{ }
if(power == (DTYPE)1.0){ }
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride) }//run vector buffer
sum += (DTYPE)exp(*ipb - bias); else{
}
else if(power == (DTYPE)2.0){ for(int k = 0; k < blockNum; k++){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ DTYPE * ip = (DTYPE*)input->data + blockSize * k;
DTYPE value = (*ipb) - bias; DTYPE * op = (DTYPE*)output->data + stride * k;
sum += (DTYPE)exp(value * value); DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + stride * k : NULL;
for(int i = 0; i < stride; i++){
DTYPE sum = 0;
DTYPE bias = shift != NULL ? *(sp + i) : 0;
DTYPE * ipe = ip + blockSize;
if(isExp){
if(bias == 0){
if(power == (DTYPE)1.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += (DTYPE)exp(*ipb);
} }
} else if(power == (DTYPE)2.0){
else if(power == (DTYPE)0.5){ for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ DTYPE value = (*ipb);
DTYPE value = (*ipb) - bias; sum += (DTYPE)exp(value * value);
sum += (DTYPE)exp(sqrt(value)); }
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)exp(sqrt(value));
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)exp(pow(value, power));
}
} }
} }
else{ else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ if(power == (DTYPE)1.0){
DTYPE value = (*ipb) - bias; for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += (DTYPE)exp(pow(value, power)); sum += (DTYPE)exp(*ipb - bias);
} }
} else if(power == (DTYPE)2.0){
} for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
} DTYPE value = (*ipb) - bias;
else{ sum += (DTYPE)exp(value * value);
if(bias == 0){ }
if(power == (DTYPE)1.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += *ipb;
}
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += value * value;
} }
} else if(power == (DTYPE)0.5){
else if(power == (DTYPE)0.5){ for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ DTYPE value = (*ipb) - bias;
DTYPE value = (*ipb); sum += (DTYPE)exp(sqrt(value));
sum += (DTYPE)sqrt(value); }
} }
} else{
else{ for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ DTYPE value = (*ipb) - bias;
DTYPE value = (*ipb); sum += (DTYPE)exp(pow(value, power));
sum += (DTYPE)pow(value, power); }
} }
} }
} }
else{ else{
if(power == (DTYPE)1.0){ if(bias == 0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride) if(power == (DTYPE)1.0){
sum += *ipb; for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum -= strideNum * bias; sum += *ipb;
}
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += value * value;
} }
} else if(power == (DTYPE)2.0){
else if(power == (DTYPE)0.5){ for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ DTYPE value = (*ipb);
DTYPE value = (*ipb) - bias; sum += value * value;
sum += (DTYPE)sqrt(value); }
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)sqrt(value);
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)pow(value, power);
}
} }
} }
else{ else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){ if(power == (DTYPE)1.0){
DTYPE value = (*ipb) - bias; for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += (DTYPE)pow(value, power); sum += *ipb;
sum -= strideNum * bias;
}
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += value * value;
}
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)sqrt(value);
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)pow(value, power);
}
} }
} }
} }
*(op + i) = sum;
} }
*(op + i) = sum;
} }
} }
} }
} }
......
#include <cstring>
#include <cmath>
#include "../../XGlobal.h"
namespace nts {
class VectorBuffer {
private:
DTYPE values[32 / sizeof(DTYPE)] = { 0 };
public:
static int size() {
return 32 / sizeof(DTYPE);
}
VectorBuffer() {}
VectorBuffer(DTYPE val) {
for (int i = 0; i != size(); i++) {
values[i] = val;
}
}
static VectorBuffer loadu(const DTYPE* ptr, bool isExp = false, DTYPE power = (DTYPE)1.0F, DTYPE* bias = NULL) {
int count = 32 / sizeof(DTYPE);
VectorBuffer vec;
if (isExp) {
if (bias == NULL) {
if (power == (DTYPE)1.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(*(ptr + i));
}
}
else if (power == (DTYPE)2.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp((*(ptr + i)) * (*(ptr + i)));
}
}
else if (power == (DTYPE)0.5) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(sqrt(*(ptr + i)));
}
}
else {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(pow(*(ptr + i), power));
}
}
}/*is bias == NULL*/
else {
if (power == (DTYPE)1.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(*(ptr + i) - bias[i]);
}
}
else if (power == (DTYPE)2.0) {
for (int i = 0; i != count; i++) {
DTYPE value = *(ptr + i) - bias[i];
vec.values[i] = (DTYPE)exp(value * value);
}
}
else if (power == (DTYPE)0.5) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(sqrt(*(ptr + i) - bias[i]));
}
}
else {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)exp(pow(*(ptr + i) - bias[i], power));
}
}
}
}//isExp
else {
if (bias == NULL) {
if (power == (DTYPE)1.0) {
memcpy(vec.values, ptr, count * sizeof(DTYPE));
}
else if (power == (DTYPE)2.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = (*(ptr + i)) * (*(ptr + i));
}
}
else if (power == (DTYPE)0.5) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)sqrt(*(ptr + i));
}
}
else {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)pow(*(ptr + i), power);
}
}
}// if bias == NULL
else {
if (power == (DTYPE)1.0) {
for (int i = 0; i != count; i++) {
vec.values[i] = *(ptr + i) - bias[i];
}
}
else if (power == (DTYPE)2.0) {
for (int i = 0; i != count; i++) {
DTYPE value = *(ptr + i) - bias[i];
vec.values[i] = value * value;
}
}
else if (power == (DTYPE)0.5) {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)sqrt(*(ptr + i) - bias[i]);
}
}
else {
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)pow(*(ptr + i) - bias[i], power);
}
}
}
}
return vec;
}
const DTYPE& operator[](int idx) const {
return values[idx];
}
inline VectorBuffer operator+(const VectorBuffer &a) {
for (int i = 0; i != a.size(); i++) {
this->values[i] = a[i] + this->values[i];
}
return *this;
}
inline VectorBuffer maxData(const VectorBuffer &a) {
for (int i = 0; i != a.size(); i++) {
this->values[i] = MAX(a[i], this->values[i]);
}
return *this;
}
};
}
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论