diff options
| author | Miao Wang <miaowang@google.com> | 2016-09-08 11:53:32 -0700 |
|---|---|---|
| committer | Miccia <bono.michele94@gmail.com> | 2017-02-07 10:57:26 +0100 |
| commit | 1372cad4b4d406b81ef939d5d84855aafad43396 (patch) | |
| tree | 5602b83933d89a2d925a4e017f5c7bceded04682 | |
| parent | 5871479caf83c0593d2c9341f8d599766f45c2ca (diff) | |
Implement multi-thread CPU GEMM for BLAS Intrinsics
- Multi-thread GEMM utilizes existing RS thread pool on top of
Eigen.
- Large matrix-matrix multiplication is decomposed into multiple
tiled matrix-matrix multiplications. Each thread iterates on
the unfinished works.
- The tiling applies to ONLY ONE dimension of each input matrix,
and whether to tile X or Y depends on the transpose of the matrix.
- The performance increase is proportional to the number of
available CPU cores, for sufficiently large matrices.
Test: CTS test (rsblas) pass on Angler, Fugu and new devices.
Performance test with RsBlasBenchmark and RsNeuralNet demo
on Anger, Ryu, Seed, Shamu, Volantis, Fugu and new devices,
showing roughly 70%(Volantix 2 core) ~ 400+%(Angler 8 core) perf gain.
Change-Id: If96f4119fd34d5d9d98a2542801495e7ffe577ae
(cherry picked from commit 41ab8faaf0d90238d42d8e2bbb7177467c10b4f6)
| -rw-r--r-- | cpu_ref/rsCpuBLASDispatch.h | 20 | ||||
| -rw-r--r-- | cpu_ref/rsCpuIntrinsic.h | 15 | ||||
| -rw-r--r-- | cpu_ref/rsCpuIntrinsicBLAS.cpp | 195 |
3 files changed, 209 insertions, 21 deletions
diff --git a/cpu_ref/rsCpuBLASDispatch.h b/cpu_ref/rsCpuBLASDispatch.h index 46021358..aa245a75 100644 --- a/cpu_ref/rsCpuBLASDispatch.h +++ b/cpu_ref/rsCpuBLASDispatch.h @@ -3,7 +3,16 @@ #else #include <dlfcn.h> /* - * The following enum and function pointers are based on cblas.h + * The following enums are based on cblas.h + */ +enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102}; +enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; +enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; +enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; +enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; +#endif + +/* * =========================================================================== * Prototypes for level 2 BLAS * =========================================================================== @@ -12,12 +21,6 @@ /* * Routines with standard 4 prefixes (S, D, C, Z) */ -enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102}; -enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; -enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; -enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; -enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; - typedef void (*FnPtr_cblas_sgemv)(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha, const float *A, const int lda, @@ -441,6 +444,8 @@ typedef void (*FnPtr_cblas_zher2k)(const enum CBLAS_ORDER Order, const enum CBLA const void *B, const int ldb, const double beta, void *C, const int ldc); + +#ifdef RS_COMPATIBILITY_LIB // Macros to help declare our function pointers for the dispatch table. #define RS_APPLY_MACRO_TO(x) \ FnPtr_##x x; @@ -464,5 +469,4 @@ bool loadBLASLib() { #include "rsCpuBLAS.inc" return true; } - #endif diff --git a/cpu_ref/rsCpuIntrinsic.h b/cpu_ref/rsCpuIntrinsic.h index 9c7e1726..3c0dca0a 100644 --- a/cpu_ref/rsCpuIntrinsic.h +++ b/cpu_ref/rsCpuIntrinsic.h @@ -24,6 +24,21 @@ namespace android { namespace renderscript { +struct MTLaunchStructForEachBlas : public MTLaunchStructCommon { + // Driver info structure + RsExpandKernelDriverInfo fep; + + // Tile size info for M, and N dimensions. + uint32_t tileSizeM; + uint32_t numTileM; + uint32_t tileSizeN; + uint32_t numTileN; + + const Allocation *ains[RS_KERNEL_INPUT_LIMIT]; + const RsBlasCall *sc; +}; + + class RsdCpuScriptIntrinsic : public RsdCpuScriptImpl { public: void populateScript(Script *) override = 0; diff --git a/cpu_ref/rsCpuIntrinsicBLAS.cpp b/cpu_ref/rsCpuIntrinsicBLAS.cpp index 9b0a8ed9..aa3f6d1f 100644 --- a/cpu_ref/rsCpuIntrinsicBLAS.cpp +++ b/cpu_ref/rsCpuIntrinsicBLAS.cpp @@ -36,7 +36,6 @@ public: const void * usr, uint32_t usrLen, const RsScriptCall *sc) override; - void populateScript(Script *) override; ~RsdCpuScriptIntrinsicBLAS() override; RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s); @@ -88,10 +87,158 @@ static void initABC(const Allocation ** ain, *C = ain[2]->mHal.drvState.lod[0].mallocPtr; *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size); } +} +// Routine to setup LaunchStruct for GEMM callback. +static void setupGEMM(MTLaunchStructForEachBlas *mtls, const Allocation **ain, RsBlasCall* call, + RsdCpuReferenceImpl *ctx) { + uint32_t mm, nn, kk; + mm = call->M; + nn = call->N; + kk = call->K; + + memset(mtls, 0, sizeof(MTLaunchStructForEachBlas)); + mtls->rs = ctx; + mtls->sc = call; + mtls->dimPtr = &mtls->fep.dim; + mtls->fep.dim.x = nn; + mtls->fep.dim.y = mm; + mtls->fep.dim.z = kk; + if (ain) { + memcpy(mtls->ains, ain, 3 * sizeof(ain[0])); + } + uint32_t elementBytes = 4; + if (ain[0]) { + elementBytes = ain[0]->getType()->getElement()->getSizeBytes(); + } + const uint32_t MIN_SIZE_TO_TILE = 64 * 1024 / elementBytes; + const uint32_t MAX_WORK_PER_THREAD = 512 / elementBytes; + const uint32_t THREAD_COUNT = ctx->getThreadCount(); + uint32_t tileSizeN = 0; + uint32_t tileSizeM = 0; + + // Do not tile the matrix if: + // 1. It is too small comparing to the other matrix. + // 2. It is too small comparing to MIN_SIZE_TO_TILE . + if (nn * kk > MIN_SIZE_TO_TILE && nn * THREAD_COUNT > mm) { + tileSizeN = rsMin(nn / THREAD_COUNT, MAX_WORK_PER_THREAD); + } + if (mm * kk > MIN_SIZE_TO_TILE && mm * THREAD_COUNT > nn) { + tileSizeM = rsMin(mm / THREAD_COUNT, MAX_WORK_PER_THREAD); + } + mtls->numTileM = 1; + mtls->numTileN = 1; + mtls->tileSizeM = mm; + mtls->tileSizeN = nn; + + // If tiling is needed, compute the number of slices for A & B. + mtls->isThreadable = (tileSizeM > 0 || tileSizeN > 0); + if (tileSizeM) { + mtls->numTileM += (mm - 1) / tileSizeM; + mtls->tileSizeM = tileSizeM; + } + if (tileSizeN) { + mtls->numTileN += (nn - 1) / tileSizeN; + mtls->tileSizeN = tileSizeN; + } + mtls->mSliceNum = 0; } +// Generic GEMM callback routine. +template <typename T_data, typename T_param, typename Func> +static void walk_tiled_gemm(Func blasFunc, T_param alpha, T_param beta, int vecSize, + RsBlasCall* call, const MTLaunchStructForEachBlas *mtls) { + // setup BLAS enum args + enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA; + enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB; + + void *A = nullptr; + void *B = nullptr; + void *C = nullptr; + + int lda = 0, ldb = 0, ldc = 0; + + const Allocation *ain[RS_KERNEL_INPUT_LIMIT]; + ain[0] = mtls->ains[0]; + ain[1] = mtls->ains[1]; + ain[2] = mtls->ains[2]; + + initABC(ain, sizeof(T_data) * vecSize, &A, &B, &C, &lda, &ldb, &ldc); + + // Determin the stride of the tiled matrices. + int mStride = (TransA == CblasNoTrans) ? lda : 1; + int nStride = (TransB == CblasNoTrans) ? 1 : ldb; + while (1) { + uint32_t slice = (uint32_t)__sync_fetch_and_add(&mtls->mSliceNum, 1); + + uint32_t mStart = (slice % mtls->numTileM) * mtls->tileSizeM; + uint32_t mEnd = mStart + mtls->tileSizeM; + mEnd = rsMin(mEnd, (uint32_t)call->M); + if (mEnd <= mStart) { + return; + } + + uint32_t nStart = (slice / mtls->numTileM) * mtls->tileSizeN; + uint32_t nEnd = nStart + mtls->tileSizeN; + nEnd = rsMin(nEnd, (uint32_t)call->N); + if (nEnd <= nStart) { + return; + } + + blasFunc(CblasRowMajor, TransA, TransB, + mEnd - mStart, nEnd - nStart, call->K, alpha, + (T_data *)A + mStart * mStride * vecSize, lda, + (T_data *)B + nStart * nStride * vecSize, ldb, beta, + (T_data *)C + (mStart * ldc + nStart) * vecSize, ldc); + } +} + +// SGEMM callback +static void walk_2d_sgemm(void *usr, uint32_t idx) { + const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr; + RsBlasCall* call = (RsBlasCall*) mtls->sc; + + float alpha = call->alpha.f; + float beta = call->beta.f; + + walk_tiled_gemm<float, float, FnPtr_cblas_sgemm>(cblas_sgemm, alpha, beta, 1, call, mtls); +} + +// DGEMM callback +static void walk_2d_dgemm(void *usr, uint32_t idx) { + const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr; + RsBlasCall* call = (RsBlasCall*) mtls->sc; + + double alpha = call->alpha.d; + double beta = call->beta.d; + + walk_tiled_gemm<double, double, FnPtr_cblas_dgemm>(cblas_dgemm, alpha, beta, 1, call, mtls); +} + +// CGEMM callback +static void walk_2d_cgemm(void *usr, uint32_t idx) { + const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr; + RsBlasCall* call = (RsBlasCall*) mtls->sc; + + void * alpha = (void *)&call->alpha.c; + void * beta = (void *)&call->beta.c; + + walk_tiled_gemm<float, void *, FnPtr_cblas_cgemm>(cblas_cgemm, alpha, beta, 2, call, mtls); +} + +// ZGEMM callback +static void walk_2d_zgemm(void *usr, uint32_t idx) { + const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr; + RsBlasCall* call = (RsBlasCall*) mtls->sc; + + void * alpha = (void *)&call->alpha.z; + void * beta = (void *)&call->beta.z; + + walk_tiled_gemm<double, void *, FnPtr_cblas_zgemm>(cblas_zgemm, alpha, beta, 2, call, mtls); +} + + void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, const Allocation ** ain, uint32_t inLen, @@ -115,6 +262,8 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, int lda = 0, ldb = 0, ldc = 0; + MTLaunchStructForEachBlas mtls; + #ifdef RS_COMPATIBILITY_LIB // Allow BNNM even without libblas if (call->func != RsBlas_bnnm && !isBlasLibInitialized) { @@ -492,9 +641,14 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, // Level 3 BLAS case (RsBlas_sgemm): - initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); - cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f, - (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); + setupGEMM(&mtls, ain, call, mCtx); + if (mtls.isThreadable) { + mCtx->launchThreads(walk_2d_sgemm, &mtls); + } else { + initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); + cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f, + (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); + } break; case (RsBlas_ssymm): initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); @@ -524,9 +678,14 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, case (RsBlas_dgemm): - initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); - cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d, - (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); + setupGEMM(&mtls, ain, call, mCtx); + if (mtls.isThreadable) { + mCtx->launchThreads(walk_2d_dgemm, &mtls); + } else { + initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); + cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d, + (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); + } break; case (RsBlas_dsymm): initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); @@ -555,9 +714,14 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, break; case (RsBlas_cgemm): - initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); - cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c, - A, lda, B, ldb, (void*)&call->beta.c, C, ldc); + setupGEMM(&mtls, ain, call, mCtx); + if (mtls.isThreadable) { + mCtx->launchThreads(walk_2d_cgemm, &mtls); + } else { + initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); + cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c, + A, lda, B, ldb, (void*)&call->beta.c, C, ldc); + } break; case (RsBlas_csymm): initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); @@ -586,9 +750,14 @@ void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, break; case (RsBlas_zgemm): - initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); - cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z, - A, lda, B, ldb, (void*)&call->beta.z, C, ldc); + setupGEMM(&mtls, ain, call, mCtx); + if (mtls.isThreadable) { + mCtx->launchThreads(walk_2d_zgemm, &mtls); + } else { + initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); + cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z, + A, lda, B, ldb, (void*)&call->beta.z, C, ldc); + } break; case (RsBlas_zsymm): initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); |
