Commit 569cb2dd by Tianzhi

finish reduce sum

parent e5f95479
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -24,6 +24,7 @@
#include "ReduceSum.cuh"
#include "../../XName.h"
#include "../../XBLAS.h"
#include "./VectorBuffer.h"
#include "../arithmetic/XTensorBLAS.h"
#include <iostream>
......@@ -73,8 +74,8 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
else{
CheckNTErrors((input->dataType == DEFAULT_DTYPE), "TODO!");
int stride = 1;
int strideNum = input->dimSizeRDI[dimRDI];
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for (int i = 0; i < input->order; i++) {
......@@ -85,6 +86,74 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
}
blockSize = stride * strideNum;
if(input->dimSizeRDI[0] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSizeRDI[0] >= 32){
int vecBufLength = 32 / sizeof(DTYPE);
if(dimRDI == 0){
//data is contiguous in dim 0
for(int i = 0; i < blockNum; i++){
// stride = 1
DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + i;
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + i : NULL;
DTYPE bias[32 / sizeof(DTYPE)] = {0};
if(shift != NULL){
for(int k = 0; k < 32 / sizeof(DTYPE); k++)
bias[k] = *(sp);
}
VectorBuffer vecBuf[4];
for(int j = 0; j < 4; j++){
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip) + j * vecBufLength, isExp, power, bias);
}
for(int j = 1; j < strideNum / 32; j++){
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength);
vecBuf[0] = vecBuf[0] + VectorBuffer::loadu(ptr + 0 * vecBufLength, isExp, power, bias);
vecBuf[1] = vecBuf[1] + VectorBuffer::loadu(ptr + 1 * vecBufLength, isExp, power, bias);
vecBuf[2] = vecBuf[2] + VectorBuffer::loadu(ptr + 2 * vecBufLength, isExp, power, bias);
vecBuf[3] = vecBuf[3] + VectorBuffer::loadu(ptr + 3 * vecBufLength, isExp, power, bias);
}
vecBuf[0] = ((vecBuf[0] + vecBuf[1]) + (vecBuf[2] + vecBuf[3]));
DTYPE sum = (DTYPE) 0.0;
for(int k = 0; k < vecBufLength; k++){
sum = sum + vecBuf[0][k];
}
*op = sum;
}
} else{
//data is separated
for(int i = 0; i < blockNum; i++){
for(int j = 0; j < input->dimSizeRDI[0] / 32; j++){
DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + stride * i;
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + stride * i : NULL;
DTYPE bias[4 * 32 / sizeof(DTYPE)] = {0};
if(shift != NULL){
for(int k = 0; k < 4 * 32 / sizeof(DTYPE); k++)
bias[k] = *(sp + k);
}
VectorBuffer vecBuf[4];
for(int k = 0; k < 4; k++){
vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE), isExp, power, bias + j * 32 / sizeof(DTYPE));
}
for(int k = 1; k < strideNum; k++){
DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength;
vecBuf[0] = vecBuf[0] + VectorBuffer::loadu(ptr + 0 * vecBufLength, isExp, power, bias);
vecBuf[1] = vecBuf[1] + VectorBuffer::loadu(ptr + 1 * vecBufLength, isExp, power, bias + 1 * vecBufLength);
vecBuf[2] = vecBuf[2] + VectorBuffer::loadu(ptr + 2 * vecBufLength, isExp, power, bias + 2 * vecBufLength);
vecBuf[3] = vecBuf[3] + VectorBuffer::loadu(ptr + 3 * vecBufLength, isExp, power, bias + 3 * vecBufLength);
}
for(int k = 0; k < 4; k++){
for(int l = 0; l < vecBufLength; l++)
*(op + j * 32 + 8 * k + l) = vecBuf[k][l];
}
}
}
}
}//run vector buffer
else{
for(int k = 0; k < blockNum; k++){
DTYPE * ip = (DTYPE*)input->data + blockSize * k;
DTYPE * op = (DTYPE*)output->data + stride * k;
......@@ -146,23 +215,14 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
else{
if(bias == 0){
if(power == (DTYPE)1.0){
//#if defined(USE_BLAS)
// sum = ASUM(strideNum, ip + i, stride);
//#else
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += *ipb;
//#endif
}
else if(power == (DTYPE)2.0){
//#if defined(USE_BLAS)
// sum = NRM2(strideNum, ip + i, stride);
// sum = sum * sum;
//#else
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += value * value;
}
//#endif
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
......@@ -179,12 +239,8 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
}
else{
if(power == (DTYPE)1.0){
//#if defined(USE_BLAS)
// sum = ASUM(strideNum, ip + i, stride);
//#else
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += *ipb;
//#endif
sum -= strideNum * bias;
}
else if(power == (DTYPE)2.0){
......@@ -211,6 +267,8 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
}
}
}
}
}
/*
......
No preview for this file type
No preview for this file type
#include <cstring>
#include <cmath>
#include "../../XGlobal.h"
class VectorBuffer{
private:
DTYPE values[32 / sizeof(DTYPE)] = {0};
public:
static int size() {
return 32 / sizeof(DTYPE);
}
VectorBuffer() {}
VectorBuffer(DTYPE val) {
for (int i = 0; i != size(); i++) {
values[i] = val;
}
}
static VectorBuffer loadu(const DTYPE* ptr, bool isExp = false, DTYPE power = (DTYPE)1.0F, DTYPE* bias = NULL) {
int count = 32 / sizeof(DTYPE);
VectorBuffer vec;
if(isExp){
if(bias == 0){
if(power == (DTYPE)1.0){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(*(ptr + i));
}
} else if(power == (DTYPE)2.0){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp((*(ptr + i)) * (*(ptr + i)));
}
} else if(power == (DTYPE)0.5){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(std::sqrt(*(ptr + i)));
}
} else{
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(std::pow(*(ptr + i), power));
}
}
}//is bias == 0
else{
if(power == (DTYPE)1.0){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(*(ptr + i) - bias[i]);
}
} else if(power == (DTYPE)2.0){
for (int i = 0; i != count; i++) {
DTYPE value = *(ptr + i) - bias[i];
vec.values[i] = (DTYPE)std::exp(value * value);
}
} else if(power == (DTYPE)0.5){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(std::sqrt(*(ptr + i) - bias[i]));
}
} else{
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(std::pow(*(ptr + i) - bias[i], power));
}
}
}
}//isExp
else{
if(bias == 0){
if(power == (DTYPE)1.0){
std::memcpy(vec.values, ptr, count * sizeof(DTYPE));
} else if(power == (DTYPE)2.0){
for (int i = 0; i != count; i++) {
vec.values[i] = (*(ptr + i)) * (*(ptr + i));
}
} else if(power == (DTYPE)0.5){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::sqrt(*(ptr + i));
}
} else{
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::pow(*(ptr + i), power);
}
}
}// if bias == 0
else{
if(power == (DTYPE)1.0){
for (int i = 0; i != count; i++) {
vec.values[i] = *(ptr + i) - bias[i];
}
} else if(power == (DTYPE)2.0){
for (int i = 0; i != count; i++) {
DTYPE value = *(ptr + i) - bias[i];
vec.values[i] = value * value;
}
} else if(power == (DTYPE)0.5){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::sqrt(*(ptr + i) - bias[i]);
}
} else{
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::pow(*(ptr + i) - bias[i], power);
}
}
}
}
return vec;
}
const DTYPE& operator[](int idx) const {
return values[idx];
}
VectorBuffer operator+(const VectorBuffer &a) {
for (int i = 0; i != a.size(); i++) {
this->values[i] = a[i] + this->values[i];
}
return *this;
}
};
\ No newline at end of file
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论