summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMiao Wang <miaowang@google.com>2016-09-08 11:53:32 -0700
committerMiccia <bono.michele94@gmail.com>2017-02-07 10:57:26 +0100
commit1372cad4b4d406b81ef939d5d84855aafad43396 (patch)
tree5602b83933d89a2d925a4e017f5c7bceded04682
parent5871479caf83c0593d2c9341f8d599766f45c2ca (diff)
Implement multi-thread CPU GEMM for BLAS Intrinsics
- Multi-thread GEMM utilizes existing RS thread pool on top of Eigen. - Large matrix-matrix multiplication is decomposed into multiple tiled matrix-matrix multiplications. Each thread iterates on the unfinished works. - The tiling applies to ONLY ONE dimension of each input matrix, and whether to tile X or Y depends on the transpose of the matrix. - The performance increase is proportional to the number of available CPU cores, for sufficiently large matrices. Test: CTS test (rsblas) pass on Angler, Fugu and new devices. Performance test with RsBlasBenchmark and RsNeuralNet demo on Anger, Ryu, Seed, Shamu, Volantis, Fugu and new devices, showing roughly 70%(Volantix 2 core) ~ 400+%(Angler 8 core) perf gain. Change-Id: If96f4119fd34d5d9d98a2542801495e7ffe577ae (cherry picked from commit 41ab8faaf0d90238d42d8e2bbb7177467c10b4f6)
-rw-r--r--cpu_ref/rsCpuBLASDispatch.h20
-rw-r--r--cpu_ref/rsCpuIntrinsic.h15
-rw-r--r--cpu_ref/rsCpuIntrinsicBLAS.cpp195
3 files changed, 209 insertions, 21 deletions
diff --git a/cpu_ref/rsCpuBLASDispatch.h b/cpu_ref/rsCpuBLASDispatch.h
index 46021358..aa245a75 100644
--- a/cpu_ref/rsCpuBLASDispatch.h
+++ b/cpu_ref/rsCpuBLASDispatch.h
@@ -3,7 +3,16 @@
#else
#include <dlfcn.h>
/*
- * The following enum and function pointers are based on cblas.h
+ * The following enums are based on cblas.h
+ */
+enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
+enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
+enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
+enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
+enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
+#endif
+
+/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
@@ -12,12 +21,6 @@
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
-enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
-enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
-enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
-enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
-enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
-
typedef void (*FnPtr_cblas_sgemv)(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
@@ -441,6 +444,8 @@ typedef void (*FnPtr_cblas_zher2k)(const enum CBLAS_ORDER Order, const enum CBLA
const void *B, const int ldb, const double beta,
void *C, const int ldc);
+
+#ifdef RS_COMPATIBILITY_LIB
// Macros to help declare our function pointers for the dispatch table.
#define RS_APPLY_MACRO_TO(x) \
FnPtr_##x x;
@@ -464,5 +469,4 @@ bool loadBLASLib() {
#include "rsCpuBLAS.inc"
return true;
}
-
#endif
diff --git a/cpu_ref/rsCpuIntrinsic.h b/cpu_ref/rsCpuIntrinsic.h
index 9c7e1726..3c0dca0a 100644
--- a/cpu_ref/rsCpuIntrinsic.h
+++ b/cpu_ref/rsCpuIntrinsic.h
@@ -24,6 +24,21 @@ namespace android {
namespace renderscript {
+struct MTLaunchStructForEachBlas : public MTLaunchStructCommon {
+ // Driver info structure
+ RsExpandKernelDriverInfo fep;
+
+ // Tile size info for M, and N dimensions.
+ uint32_t tileSizeM;
+ uint32_t numTileM;
+ uint32_t tileSizeN;
+ uint32_t numTileN;
+
+ const Allocation *ains[RS_KERNEL_INPUT_LIMIT];
+ const RsBlasCall *sc;
+};
+
+
class RsdCpuScriptIntrinsic : public RsdCpuScriptImpl {
public:
void populateScript(Script *) override = 0;
diff --git a/cpu_ref/rsCpuIntrinsicBLAS.cpp b/cpu_ref/rsCpuIntrinsicBLAS.cpp
index 9b0a8ed9..aa3f6d1f 100644
--- a/cpu_ref/rsCpuIntrinsicBLAS.cpp
+++ b/cpu_ref/rsCpuIntrinsicBLAS.cpp
@@ -36,7 +36,6 @@ public:
const void * usr,
uint32_t usrLen,
const RsScriptCall *sc) override;
-
void populateScript(Script *) override;
~RsdCpuScriptIntrinsicBLAS() override;
RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s);
@@ -88,10 +87,158 @@ static void initABC(const Allocation ** ain,
*C = ain[2]->mHal.drvState.lod[0].mallocPtr;
*ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size);
}
+}
+// Routine to setup LaunchStruct for GEMM callback.
+static void setupGEMM(MTLaunchStructForEachBlas *mtls, const Allocation **ain, RsBlasCall* call,
+ RsdCpuReferenceImpl *ctx) {
+ uint32_t mm, nn, kk;
+ mm = call->M;
+ nn = call->N;
+ kk = call->K;
+
+ memset(mtls, 0, sizeof(MTLaunchStructForEachBlas));
+ mtls->rs = ctx;
+ mtls->sc = call;
+ mtls->dimPtr = &mtls->fep.dim;
+ mtls->fep.dim.x = nn;
+ mtls->fep.dim.y = mm;
+ mtls->fep.dim.z = kk;
+ if (ain) {
+ memcpy(mtls->ains, ain, 3 * sizeof(ain[0]));
+ }
+ uint32_t elementBytes = 4;
+ if (ain[0]) {
+ elementBytes = ain[0]->getType()->getElement()->getSizeBytes();
+ }
+ const uint32_t MIN_SIZE_TO_TILE = 64 * 1024 / elementBytes;
+ const uint32_t MAX_WORK_PER_THREAD = 512 / elementBytes;
+ const uint32_t THREAD_COUNT = ctx->getThreadCount();
+ uint32_t tileSizeN = 0;
+ uint32_t tileSizeM = 0;
+
+ // Do not tile the matrix if:
+ // 1. It is too small comparing to the other matrix.
+ // 2. It is too small comparing to MIN_SIZE_TO_TILE .
+ if (nn * kk > MIN_SIZE_TO_TILE && nn * THREAD_COUNT > mm) {
+ tileSizeN = rsMin(nn / THREAD_COUNT, MAX_WORK_PER_THREAD);
+ }
+ if (mm * kk > MIN_SIZE_TO_TILE && mm * THREAD_COUNT > nn) {
+ tileSizeM = rsMin(mm / THREAD_COUNT, MAX_WORK_PER_THREAD);
+ }
+ mtls->numTileM = 1;
+ mtls->numTileN = 1;
+ mtls->tileSizeM = mm;
+ mtls->tileSizeN = nn;
+
+ // If tiling is needed, compute the number of slices for A & B.
+ mtls->isThreadable = (tileSizeM > 0 || tileSizeN > 0);
+ if (tileSizeM) {
+ mtls->numTileM += (mm - 1) / tileSizeM;
+ mtls->tileSizeM = tileSizeM;
+ }
+ if (tileSizeN) {
+ mtls->numTileN += (nn - 1) / tileSizeN;
+ mtls->tileSizeN = tileSizeN;
+ }
+ mtls->mSliceNum = 0;
}
+// Generic GEMM callback routine.
+template <typename T_data, typename T_param, typename Func>
+static void walk_tiled_gemm(Func blasFunc, T_param alpha, T_param beta, int vecSize,
+ RsBlasCall* call, const MTLaunchStructForEachBlas *mtls) {
+ // setup BLAS enum args
+ enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA;
+ enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB;
+
+ void *A = nullptr;
+ void *B = nullptr;
+ void *C = nullptr;
+
+ int lda = 0, ldb = 0, ldc = 0;
+
+ const Allocation *ain[RS_KERNEL_INPUT_LIMIT];
+ ain[0] = mtls->ains[0];
+ ain[1] = mtls->ains[1];
+ ain[2] = mtls->ains[2];
+
+ initABC(ain, sizeof(T_data) * vecSize, &A, &B, &C, &lda, &ldb, &ldc);
+
+ // Determin the stride of the tiled matrices.
+ int mStride = (TransA == CblasNoTrans) ? lda : 1;
+ int nStride = (TransB == CblasNoTrans) ? 1 : ldb;
+ while (1) {
+ uint32_t slice = (uint32_t)__sync_fetch_and_add(&mtls->mSliceNum, 1);
+
+ uint32_t mStart = (slice % mtls->numTileM) * mtls->tileSizeM;
+ uint32_t mEnd = mStart + mtls->tileSizeM;
+ mEnd = rsMin(mEnd, (uint32_t)call->M);
+ if (mEnd <= mStart) {
+ return;
+ }
+
+ uint32_t nStart = (slice / mtls->numTileM) * mtls->tileSizeN;
+ uint32_t nEnd = nStart + mtls->tileSizeN;
+ nEnd = rsMin(nEnd, (uint32_t)call->N);
+ if (nEnd <= nStart) {
+ return;
+ }
+
+ blasFunc(CblasRowMajor, TransA, TransB,
+ mEnd - mStart, nEnd - nStart, call->K, alpha,
+ (T_data *)A + mStart * mStride * vecSize, lda,
+ (T_data *)B + nStart * nStride * vecSize, ldb, beta,
+ (T_data *)C + (mStart * ldc + nStart) * vecSize, ldc);
+ }
+}
+
+// SGEMM callback
+static void walk_2d_sgemm(void *usr, uint32_t idx) {
+ const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
+ RsBlasCall* call = (RsBlasCall*) mtls->sc;
+
+ float alpha = call->alpha.f;
+ float beta = call->beta.f;
+
+ walk_tiled_gemm<float, float, FnPtr_cblas_sgemm>(cblas_sgemm, alpha, beta, 1, call, mtls);
+}
+
+// DGEMM callback
+static void walk_2d_dgemm(void *usr, uint32_t idx) {
+ const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
+ RsBlasCall* call = (RsBlasCall*) mtls->sc;
+
+ double alpha = call->alpha.d;
+ double beta = call->beta.d;
+
+ walk_tiled_gemm<double, double, FnPtr_cblas_dgemm>(cblas_dgemm, alpha, beta, 1, call, mtls);
+}
+
+// CGEMM callback
+static void walk_2d_cgemm(void *usr, uint32_t idx) {
+ const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
+ RsBlasCall* call = (RsBlasCall*) mtls->sc;
+
+ void * alpha = (void *)&call->alpha.c;
+ void * beta = (void *)&call->beta.c;
+
+ walk_tiled_gemm<float, void *, FnPtr_cblas_cgemm>(cblas_cgemm, alpha, beta, 2, call, mtls);
+}
+
+// ZGEMM callback
+static void walk_2d_zgemm(void *usr, uint32_t idx) {
+ const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
+ RsBlasCall* call = (RsBlasCall*) mtls->sc;
+
+ void * alpha = (void *)&call->alpha.z;
+ void * beta = (void *)&call->beta.z;
+
+ walk_tiled_gemm<double, void *, FnPtr_cblas_zgemm>(cblas_zgemm, alpha, beta, 2, call, mtls);
+}
+
+
void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
const Allocation ** ain,
uint32_t inLen,
@@ -115,6 +262,8 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
int lda = 0, ldb = 0, ldc = 0;
+ MTLaunchStructForEachBlas mtls;
+
#ifdef RS_COMPATIBILITY_LIB
// Allow BNNM even without libblas
if (call->func != RsBlas_bnnm && !isBlasLibInitialized) {
@@ -492,9 +641,14 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
// Level 3 BLAS
case (RsBlas_sgemm):
- initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
- cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
- (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
+ setupGEMM(&mtls, ain, call, mCtx);
+ if (mtls.isThreadable) {
+ mCtx->launchThreads(walk_2d_sgemm, &mtls);
+ } else {
+ initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
+ cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
+ (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
+ }
break;
case (RsBlas_ssymm):
initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
@@ -524,9 +678,14 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
case (RsBlas_dgemm):
- initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
- cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
- (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
+ setupGEMM(&mtls, ain, call, mCtx);
+ if (mtls.isThreadable) {
+ mCtx->launchThreads(walk_2d_dgemm, &mtls);
+ } else {
+ initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
+ cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
+ (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
+ }
break;
case (RsBlas_dsymm):
initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
@@ -555,9 +714,14 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
break;
case (RsBlas_cgemm):
- initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
- cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
- A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
+ setupGEMM(&mtls, ain, call, mCtx);
+ if (mtls.isThreadable) {
+ mCtx->launchThreads(walk_2d_cgemm, &mtls);
+ } else {
+ initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
+ cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
+ A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
+ }
break;
case (RsBlas_csymm):
initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
@@ -586,9 +750,14 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
break;
case (RsBlas_zgemm):
- initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
- cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
- A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
+ setupGEMM(&mtls, ain, call, mCtx);
+ if (mtls.isThreadable) {
+ mCtx->launchThreads(walk_2d_zgemm, &mtls);
+ } else {
+ initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
+ cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
+ A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
+ }
break;
case (RsBlas_zsymm):
initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);