Commit 569cb2dd by Tianzhi

finish reduce sum

parent e5f95479
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -24,6 +24,7 @@
#include "ReduceSum.cuh"
#include "../../XName.h"
#include "../../XBLAS.h"
#include "./VectorBuffer.h"
#include "../arithmetic/XTensorBLAS.h"
#include <iostream>
......@@ -73,8 +74,8 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
else{
CheckNTErrors((input->dataType == DEFAULT_DTYPE), "TODO!");
int stride = 1;
int strideNum = input->dimSizeRDI[dimRDI];
int stride = 1;
int blockSize = 1;
int blockNum = 1;
for (int i = 0; i < input->order; i++) {
......@@ -85,131 +86,188 @@ void _ReduceSum(const XTensor * input, XTensor * output, int dim, const XTensor
}
blockSize = stride * strideNum;
for(int k = 0; k < blockNum; k++){
DTYPE * ip = (DTYPE*)input->data + blockSize * k;
DTYPE * op = (DTYPE*)output->data + stride * k;
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + stride * k : NULL;
for(int i = 0; i < stride; i++){
DTYPE sum = 0;
DTYPE bias = shift != NULL ? *(sp + i) : 0;
DTYPE * ipe = ip + blockSize;
if(isExp){
if(bias == 0){
if(power == (DTYPE)1.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += (DTYPE)exp(*ipb);
}
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)exp(value * value);
}
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)exp(sqrt(value));
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)exp(pow(value, power));
}
}
if(input->dimSizeRDI[0] % (4 * 32 / sizeof(DTYPE)) == 0 && input->dimSizeRDI[0] >= 32){
int vecBufLength = 32 / sizeof(DTYPE);
if(dimRDI == 0){
//data is contiguous in dim 0
for(int i = 0; i < blockNum; i++){
// stride = 1
DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + i;
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + i : NULL;
DTYPE bias[32 / sizeof(DTYPE)] = {0};
if(shift != NULL){
for(int k = 0; k < 32 / sizeof(DTYPE); k++)
bias[k] = *(sp);
}
else{
if(power == (DTYPE)1.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += (DTYPE)exp(*ipb - bias);
VectorBuffer vecBuf[4];
for(int j = 0; j < 4; j++){
vecBuf[j] = VectorBuffer::loadu((DTYPE*)(ip) + j * vecBufLength, isExp, power, bias);
}
for(int j = 1; j < strideNum / 32; j++){
const DTYPE* ptr = (DTYPE*)(ip + j * vecBufLength);
vecBuf[0] = vecBuf[0] + VectorBuffer::loadu(ptr + 0 * vecBufLength, isExp, power, bias);
vecBuf[1] = vecBuf[1] + VectorBuffer::loadu(ptr + 1 * vecBufLength, isExp, power, bias);
vecBuf[2] = vecBuf[2] + VectorBuffer::loadu(ptr + 2 * vecBufLength, isExp, power, bias);
vecBuf[3] = vecBuf[3] + VectorBuffer::loadu(ptr + 3 * vecBufLength, isExp, power, bias);
}
vecBuf[0] = ((vecBuf[0] + vecBuf[1]) + (vecBuf[2] + vecBuf[3]));
DTYPE sum = (DTYPE) 0.0;
for(int k = 0; k < vecBufLength; k++){
sum = sum + vecBuf[0][k];
}
*op = sum;
}
} else{
//data is separated
for(int i = 0; i < blockNum; i++){
for(int j = 0; j < input->dimSizeRDI[0] / 32; j++){
DTYPE * ip = (DTYPE*)input->data + blockSize * i;
DTYPE * op = (DTYPE*)output->data + stride * i;
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + stride * i : NULL;
DTYPE bias[4 * 32 / sizeof(DTYPE)] = {0};
if(shift != NULL){
for(int k = 0; k < 4 * 32 / sizeof(DTYPE); k++)
bias[k] = *(sp + k);
}
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)exp(value * value);
}
VectorBuffer vecBuf[4];
for(int k = 0; k < 4; k++){
vecBuf[k] = VectorBuffer::loadu((DTYPE*)(ip) + (j * 4 + k) * 32 / sizeof(DTYPE), isExp, power, bias + j * 32 / sizeof(DTYPE));
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)exp(sqrt(value));
}
for(int k = 1; k < strideNum; k++){
DTYPE * ptr = ip + k * stride + (j * 4) * vecBufLength;
vecBuf[0] = vecBuf[0] + VectorBuffer::loadu(ptr + 0 * vecBufLength, isExp, power, bias);
vecBuf[1] = vecBuf[1] + VectorBuffer::loadu(ptr + 1 * vecBufLength, isExp, power, bias + 1 * vecBufLength);
vecBuf[2] = vecBuf[2] + VectorBuffer::loadu(ptr + 2 * vecBufLength, isExp, power, bias + 2 * vecBufLength);
vecBuf[3] = vecBuf[3] + VectorBuffer::loadu(ptr + 3 * vecBufLength, isExp, power, bias + 3 * vecBufLength);
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)exp(pow(value, power));
}
for(int k = 0; k < 4; k++){
for(int l = 0; l < vecBufLength; l++)
*(op + j * 32 + 8 * k + l) = vecBuf[k][l];
}
}
}
else{
if(bias == 0){
if(power == (DTYPE)1.0){
//#if defined(USE_BLAS)
// sum = ASUM(strideNum, ip + i, stride);
//#else
}
}//run vector buffer
else{
for(int k = 0; k < blockNum; k++){
DTYPE * ip = (DTYPE*)input->data + blockSize * k;
DTYPE * op = (DTYPE*)output->data + stride * k;
DTYPE * sp = shift != NULL ? (DTYPE*)shift->data + stride * k : NULL;
for(int i = 0; i < stride; i++){
DTYPE sum = 0;
DTYPE bias = shift != NULL ? *(sp + i) : 0;
DTYPE * ipe = ip + blockSize;
if(isExp){
if(bias == 0){
if(power == (DTYPE)1.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += *ipb;
//#endif
}
else if(power == (DTYPE)2.0){
//#if defined(USE_BLAS)
// sum = NRM2(strideNum, ip + i, stride);
// sum = sum * sum;
//#else
sum += (DTYPE)exp(*ipb);
}
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += value * value;
sum += (DTYPE)exp(value * value);
}
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)exp(sqrt(value));
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)exp(pow(value, power));
}
//#endif
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)sqrt(value);
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)pow(value, power);
if(power == (DTYPE)1.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += (DTYPE)exp(*ipb - bias);
}
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)exp(value * value);
}
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)exp(sqrt(value));
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)exp(pow(value, power));
}
}
}
}
else{
if(power == (DTYPE)1.0){
//#if defined(USE_BLAS)
// sum = ASUM(strideNum, ip + i, stride);
//#else
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += *ipb;
//#endif
sum -= strideNum * bias;
}
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += value * value;
if(bias == 0){
if(power == (DTYPE)1.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += *ipb;
}
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)sqrt(value);
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += value * value;
}
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)sqrt(value);
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb);
sum += (DTYPE)pow(value, power);
}
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)pow(value, power);
if(power == (DTYPE)1.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride)
sum += *ipb;
sum -= strideNum * bias;
}
else if(power == (DTYPE)2.0){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += value * value;
}
}
else if(power == (DTYPE)0.5){
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)sqrt(value);
}
}
else{
for(DTYPE * ipb = ip + i; ipb < ipe; ipb += stride){
DTYPE value = (*ipb) - bias;
sum += (DTYPE)pow(value, power);
}
}
}
}
*(op + i) = sum;
}
*(op + i) = sum;
}
}
}
}
......
#include <cstring>
#include <cmath>
#include "../../XGlobal.h"
class VectorBuffer{
private:
DTYPE values[32 / sizeof(DTYPE)] = {0};
public:
static int size() {
return 32 / sizeof(DTYPE);
}
VectorBuffer() {}
VectorBuffer(DTYPE val) {
for (int i = 0; i != size(); i++) {
values[i] = val;
}
}
static VectorBuffer loadu(const DTYPE* ptr, bool isExp = false, DTYPE power = (DTYPE)1.0F, DTYPE* bias = NULL) {
int count = 32 / sizeof(DTYPE);
VectorBuffer vec;
if(isExp){
if(bias == 0){
if(power == (DTYPE)1.0){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(*(ptr + i));
}
} else if(power == (DTYPE)2.0){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp((*(ptr + i)) * (*(ptr + i)));
}
} else if(power == (DTYPE)0.5){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(std::sqrt(*(ptr + i)));
}
} else{
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(std::pow(*(ptr + i), power));
}
}
}//is bias == 0
else{
if(power == (DTYPE)1.0){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(*(ptr + i) - bias[i]);
}
} else if(power == (DTYPE)2.0){
for (int i = 0; i != count; i++) {
DTYPE value = *(ptr + i) - bias[i];
vec.values[i] = (DTYPE)std::exp(value * value);
}
} else if(power == (DTYPE)0.5){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(std::sqrt(*(ptr + i) - bias[i]));
}
} else{
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::exp(std::pow(*(ptr + i) - bias[i], power));
}
}
}
}//isExp
else{
if(bias == 0){
if(power == (DTYPE)1.0){
std::memcpy(vec.values, ptr, count * sizeof(DTYPE));
} else if(power == (DTYPE)2.0){
for (int i = 0; i != count; i++) {
vec.values[i] = (*(ptr + i)) * (*(ptr + i));
}
} else if(power == (DTYPE)0.5){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::sqrt(*(ptr + i));
}
} else{
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::pow(*(ptr + i), power);
}
}
}// if bias == 0
else{
if(power == (DTYPE)1.0){
for (int i = 0; i != count; i++) {
vec.values[i] = *(ptr + i) - bias[i];
}
} else if(power == (DTYPE)2.0){
for (int i = 0; i != count; i++) {
DTYPE value = *(ptr + i) - bias[i];
vec.values[i] = value * value;
}
} else if(power == (DTYPE)0.5){
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::sqrt(*(ptr + i) - bias[i]);
}
} else{
for (int i = 0; i != count; i++) {
vec.values[i] = (DTYPE)std::pow(*(ptr + i) - bias[i], power);
}
}
}
}
return vec;
}
const DTYPE& operator[](int idx) const {
return values[idx];
}
VectorBuffer operator+(const VectorBuffer &a) {
for (int i = 0; i != a.size(); i++) {
this->values[i] = a[i] + this->values[i];
}
return *this;
}
};
\ No newline at end of file
......@@ -37,7 +37,7 @@ bool TestReduceSum1()
int sOrder = 2;
int * sDimSize = new int[sOrder];
sDimSize[0] = 2;
sDimSize[1] = 4;
sDimSize[1] = 32;
int sUnitNum = 1;
for (int i = 0; i < sOrder; i++)
......@@ -46,7 +46,7 @@ bool TestReduceSum1()
/* a tensor of size (4) */
int tOrder1 = 1;
int * tDimSize1 = new int[tOrder1];
tDimSize1[0] = 4;
tDimSize1[0] = 32;
int tUnitNum1 = 1;
for (int i = 0; i < tOrder1; i++)
......@@ -61,10 +61,10 @@ bool TestReduceSum1()
for (int i = 0; i < tOrder2; i++)
tUnitNum2 *= tDimSize2[i];
DTYPE sData[2][4] = { {0.0F, 1.0F, -2.0F, 3.0F},
{4.0F, 5.0F, 6.0F, 7.0F} };
DTYPE answer1[4] = {4.0F, 6.0F, 4.0F, 10.0F};
DTYPE answer2[2] = {2.0F, 22.0F};
DTYPE sData[2][32] = { {0.0F, 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F, 7.0F,0.0F, 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F, 7.0F,0.0F, 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F, 7.0F,0.0F, 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F, 7.0F},
{4.0F, 5.0F, 6.0F, 7.0F, 8.0F, 9.0F, 10.0F, 11.0F,4.0F, 5.0F, 6.0F, 7.0F, 8.0F, 9.0F, 10.0F, 11.0F,4.0F, 5.0F, 6.0F, 7.0F, 8.0F, 9.0F, 10.0F, 11.0F,4.0F, 5.0F, 6.0F, 7.0F, 8.0F, 9.0F, 10.0F, 11.0F} };
DTYPE answer1[32] = {4.0F, 6.0F, 8.0F, 10.0F, 12.0F, 14.0F, 16.0F, 18.0F, 4.0F, 6.0F, 8.0F, 10.0F, 12.0F, 14.0F, 16.0F, 18.0F, 4.0F, 6.0F, 8.0F, 10.0F, 12.0F, 14.0F, 16.0F, 18.0F, 4.0F, 6.0F, 8.0F, 10.0F, 12.0F, 14.0F, 16.0F, 18.0F};
DTYPE answer2[2] = {112.0F, 240.0F};
/* CPU test */
bool cpuTest = true;
......
......@@ -78,7 +78,7 @@ bool Test()
wrong = !TestTopK() || wrong;
wrong = !TestUnsqueeze() || wrong;
wrong = !TestXMem() || wrong;
wrong = !TestCrossEntropy() || wrong;
wrong = !TestDropout() || wrong;
wrong = !TestHardTanH() || wrong;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论