Add bfloat16 based dot and conversion with single/double

1. Added bfloat16 based dot as new API: shdot
2. Implemented generic kernel and cooperlake-specific (AVX512-BF16) kernel for shdot
3. Added 4 conversion APIs for bfloat16 data type <=> single/double: shstobf16 shdtobf16 sbf16tos dbf16tod
     shstobf16 -- convert single float array to bfloat16 array
     shdtobf16 -- convert double float array to bfloat16 array
     sbf16tos  -- convert bfloat16 array to single float array
     dbf16tod  -- convert bfloat16 array to double float array
4. Implemented generic kernels for all 4 conversion APIs, and cooperlake-specific kernel for shstobf16 and shdtobf16
5. Update level1 thread facilitate functions and macros to support multi-threading for these new APIs
6. Fix Cooperlake platform detection/specify issue when under dynamic-arch building
7. Change the typedef of bfloat16 from unsigned short to more strict uint16_t

Signed-off-by: Chen, Guobing <guobing.chen@intel.com>
This commit is contained in:
Chen, Guobing
2020-08-27 06:42:28 +08:00
parent c7ef7174e4
commit deaeb6c5b8
31 changed files with 1389 additions and 79 deletions

View File

@@ -49,9 +49,36 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
blas_arg_t args [MAX_CPU_NUMBER];
BLASLONG i, width, astride, bstride;
int num_cpu, calc_type;
int num_cpu, calc_type_a, calc_type_b;
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
switch (mode & BLAS_PREC) {
case BLAS_INT8 :
case BLAS_BFLOAT16:
case BLAS_SINGLE :
case BLAS_DOUBLE :
case BLAS_XDOUBLE :
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
break;
case BLAS_STOBF16 :
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
break;
case BLAS_DTOBF16 :
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
break;
case BLAS_BF16TOS :
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
break;
case BLAS_BF16TOD :
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
break;
default:
calc_type_a = calc_type_b = 0;
break;
}
mode |= BLAS_LEGACY;
@@ -77,8 +104,8 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
bstride = width;
}
astride <<= calc_type;
bstride <<= calc_type;
astride <<= calc_type_a;
bstride <<= calc_type_b;
args[num_cpu].m = width;
args[num_cpu].n = n;
@@ -120,9 +147,36 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
blas_arg_t args [MAX_CPU_NUMBER];
BLASLONG i, width, astride, bstride;
int num_cpu, calc_type;
int num_cpu, calc_type_a, calc_type_b;
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
switch (mode & BLAS_PREC) {
case BLAS_INT8 :
case BLAS_BFLOAT16:
case BLAS_SINGLE :
case BLAS_DOUBLE :
case BLAS_XDOUBLE :
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
break;
case BLAS_STOBF16 :
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
break;
case BLAS_DTOBF16 :
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
break;
case BLAS_BF16TOS :
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
break;
case BLAS_BF16TOD :
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
break;
default:
calc_type_a = calc_type_b = 0;
break;
}
mode |= BLAS_LEGACY;
@@ -148,8 +202,8 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
bstride = width;
}
astride <<= calc_type;
bstride <<= calc_type;
astride <<= calc_type_a;
bstride <<= calc_type_b;
args[num_cpu].m = width;
args[num_cpu].n = n;

View File

@@ -192,7 +192,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
if (!(mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
if (mode & BLAS_XDOUBLE){
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* REAL / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -205,7 +205,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
if (mode & BLAS_DOUBLE){
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* REAL / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
double *, BLASLONG, double *, BLASLONG,
@@ -216,21 +216,58 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else {
/* REAL / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, float *, BLASLONG,
float *, BLASLONG, void *) = func;
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){
/* REAL / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, float *, BLASLONG,
float *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((float *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
afunc(args -> m, args -> n, args -> k,
((float *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
#ifdef BUILD_HALF
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
/* REAL / BFLOAT16 */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
bfloat16 *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((bfloat16 *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){
/* REAL / BLAS_STOBF16 */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, bfloat16 *, BLASLONG,
float *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((float *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
/* REAL / BLAS_DTOBF16 */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
double *, BLASLONG, bfloat16 *, BLASLONG,
double *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((double *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
#endif
} else {
/* REAL / Other types in future */
}
} else {
#ifdef EXPRECISION
if (mode & BLAS_XDOUBLE){
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* COMPLEX / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -244,7 +281,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
if (mode & BLAS_DOUBLE){
if ((mode & BLAS_PREC) == BLAS_DOUBLE) {
/* COMPLEX / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
double *, BLASLONG, double *, BLASLONG,
@@ -256,7 +293,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else {
} else if ((mode & BLAS_PREC) == BLAS_SINGLE) {
/* COMPLEX / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
float *, BLASLONG, float *, BLASLONG,
@@ -268,7 +305,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
}
} else {
/* COMPLEX / Other types in future */
}
}
}
@@ -414,33 +453,37 @@ blas_queue_t *tscq;
if (sb == NULL) {
if (!(queue -> mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
if (queue -> mode & BLAS_XDOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
if (queue -> mode & BLAS_DOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) {
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
}
} else {
/* Other types in future */
}
} else {
#ifdef EXPRECISION
if (queue -> mode & BLAS_XDOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
if (queue -> mode & BLAS_DOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
}
} else {
/* Other types in future */
}
}
queue->sb=sb;
}

View File

@@ -142,7 +142,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
if (!(mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
if (mode & BLAS_XDOUBLE){
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* REAL / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -155,7 +155,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
if (mode & BLAS_DOUBLE){
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* REAL / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
double *, BLASLONG, double *, BLASLONG,
@@ -166,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else {
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){
/* REAL / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, float *, BLASLONG,
@@ -177,10 +177,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
#ifdef BUILD_HALF
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
/* REAL / BFLOAT16 */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
bfloat16 *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((bfloat16 *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){
/* REAL / BLAS_STOBF16 */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, bfloat16 *, BLASLONG,
float *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((float *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
/* REAL / BLAS_DTOBF16 */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
double *, BLASLONG, bfloat16 *, BLASLONG,
double *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((double *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
#endif
} else {
/* REAL / Other types in future */
}
} else {
#ifdef EXPRECISION
if (mode & BLAS_XDOUBLE){
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* COMPLEX / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -194,7 +231,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
if (mode & BLAS_DOUBLE){
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* COMPLEX / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
double *, BLASLONG, double *, BLASLONG,
@@ -206,7 +243,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else {
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){
/* COMPLEX / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
float *, BLASLONG, float *, BLASLONG,
@@ -218,8 +255,10 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
}
}
} else {
/* COMPLEX / Other types in future */
}
}
}
static void exec_threads(blas_queue_t *queue, int buf_index){
@@ -255,32 +294,36 @@ static void exec_threads(blas_queue_t *queue, int buf_index){
if (sb == NULL) {
if (!(queue -> mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
if (queue -> mode & BLAS_XDOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
if (queue -> mode & BLAS_DOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE){
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
/* Other types in future */
}
} else {
#ifdef EXPRECISION
if (queue -> mode & BLAS_XDOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
if (queue -> mode & BLAS_DOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
/* Other types in future */
}
}
queue->sb=sb;

View File

@@ -77,7 +77,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
if (!(mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
if (mode & BLAS_XDOUBLE){
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* REAL / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -90,7 +90,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
if (mode & BLAS_DOUBLE){
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* REAL / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
double *, BLASLONG, double *, BLASLONG,
@@ -101,7 +101,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else {
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){
/* REAL / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, float *, BLASLONG,
@@ -112,10 +112,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
#ifdef BUILD_HALF
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
/* REAL / BFLOAT16 */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
bfloat16 *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((bfloat16 *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){
/* REAL / BLAS_STOBF16 */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, bfloat16 *, BLASLONG,
float *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((float *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
/* REAL / BLAS_DTOBF16 */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
double *, BLASLONG, bfloat16 *, BLASLONG,
double *, BLASLONG, void *) = func;
afunc(args -> m, args -> n, args -> k,
((double *)args -> alpha)[0],
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
#endif
} else {
/* REAL / Other types in future */
}
} else {
#ifdef EXPRECISION
if (mode & BLAS_XDOUBLE){
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
/* COMPLEX / Extended Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -129,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> c, args -> ldc, sb);
} else
#endif
if (mode & BLAS_DOUBLE){
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
/* COMPLEX / Double */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
double *, BLASLONG, double *, BLASLONG,
@@ -141,7 +178,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
} else {
} else if ((mode & BLAS_PREC) == BLAS_SINGLE) {
/* COMPLEX / Single */
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
float *, BLASLONG, float *, BLASLONG,
@@ -153,7 +190,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
args -> a, args -> lda,
args -> b, args -> ldb,
args -> c, args -> ldc, sb);
}
} else {
/* COMPLEX / Other types in future */
}
}
}
@@ -233,32 +272,36 @@ static DWORD WINAPI blas_thread_server(void *arg){
if (sb == NULL) {
if (!(queue -> mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
if (queue -> mode & BLAS_XDOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
if (queue -> mode & BLAS_DOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
/* Other types in future */
}
} else {
#ifdef EXPRECISION
if (queue -> mode & BLAS_XDOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else
#endif
if (queue -> mode & BLAS_DOUBLE){
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
} else {
/* Other types in future */
}
}
queue->sb=sb;

View File

@@ -207,6 +207,19 @@ extern gotoblas_t gotoblas_SKYLAKEX;
#else
#define gotoblas_SKYLAKEX gotoblas_PRESCOTT
#endif
#ifdef DYN_COOPERLAKE
extern gotoblas_t gotoblas_COOPERLAKE;
#elif defined(DYN_SKYLAKEX)
#define gotoblas_COOPERLAKE gotoblas_SKYLAKEX
#elif defined(DYN_HASWELL)
#define gotoblas_COOPERLAKE gotoblas_HASWELL
#elif defined(DYN_SANDYBRIDGE)
#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE
#elif defined(DYN_NEHALEM)
#define gotoblas_COOPERLAKE gotoblas_NEHALEM
#else
#define gotoblas_COOPERLAKE gotoblas_PRESCOTT
#endif
#else // not DYNAMIC_LIST
@@ -247,14 +260,17 @@ extern gotoblas_t gotoblas_EXCAVATOR;
#ifdef NO_AVX2
#define gotoblas_HASWELL gotoblas_SANDYBRIDGE
#define gotoblas_SKYLAKEX gotoblas_SANDYBRIDGE
#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE
#define gotoblas_ZEN gotoblas_SANDYBRIDGE
#else
extern gotoblas_t gotoblas_HASWELL;
extern gotoblas_t gotoblas_ZEN;
#ifndef NO_AVX512
extern gotoblas_t gotoblas_SKYLAKEX;
extern gotoblas_t gotoblas_COOPERLAKE;
#else
#define gotoblas_SKYLAKEX gotoblas_HASWELL
#define gotoblas_COOPERLAKE gotoblas_HASWELL
#endif
#endif
#else
@@ -262,6 +278,7 @@ extern gotoblas_t gotoblas_SKYLAKEX;
#define gotoblas_SANDYBRIDGE gotoblas_NEHALEM
#define gotoblas_HASWELL gotoblas_NEHALEM
#define gotoblas_SKYLAKEX gotoblas_NEHALEM
#define gotoblas_COOPERLAKE gotoblas_NEHALEM
#define gotoblas_BULLDOZER gotoblas_BARCELONA
#define gotoblas_PILEDRIVER gotoblas_BARCELONA
#define gotoblas_STEAMROLLER gotoblas_BARCELONA
@@ -343,6 +360,23 @@ int support_avx512(){
#endif
}
int support_avx512_bf16(){
#if !defined(NO_AVX) && !defined(NO_AVX512)
int eax, ebx, ecx, edx;
int ret=0;
if (!support_avx512())
return 0;
cpuid_count(7, 1, &eax, &ebx, &ecx, &edx);
if((eax & 32) == 32){
ret=1; // CPUID.7.1:EAX[bit 5] indicates whether avx512_bf16 supported or not
}
return ret;
#else
return 0;
#endif
}
extern void openblas_warning(int verbose, const char * msg);
#define FALLBACK_VERBOSE 1
#define NEHALEM_FALLBACK "OpenBLAS : Your OS does not support AVX instructions. OpenBLAS is using Nehalem kernels as a fallback, which may give poorer performance.\n"
@@ -524,7 +558,10 @@ static gotoblas_t *get_coretype(void){
return &gotoblas_NEHALEM; //OS doesn't support AVX. Use old kernels.
}
}
if (model == 5) {
if (model == 5) {
// Intel Cooperlake
if(support_avx512_bf16())
return &gotoblas_COOPERLAKE;
// Intel Skylake X
if (support_avx512())
return &gotoblas_SKYLAKEX;
@@ -774,7 +811,8 @@ static char *corename[] = {
"Steamroller",
"Excavator",
"Zen",
"SkylakeX"
"SkylakeX",
"Cooperlake"
};
char *gotoblas_corename(void) {
@@ -838,6 +876,7 @@ char *gotoblas_corename(void) {
if (gotoblas == &gotoblas_EXCAVATOR) return corename[22];
if (gotoblas == &gotoblas_ZEN) return corename[23];
if (gotoblas == &gotoblas_SKYLAKEX) return corename[24];
if (gotoblas == &gotoblas_COOPERLAKE) return corename[25];
return corename[0];
}
@@ -868,6 +907,7 @@ static gotoblas_t *force_coretype(char *coretype){
switch (found)
{
case 25: return (&gotoblas_COOPERLAKE);
case 24: return (&gotoblas_SKYLAKEX);
case 23: return (&gotoblas_ZEN);
case 22: return (&gotoblas_EXCAVATOR);