Merge pull request #1536 from WestAlgo/develop

Fix race condition in blas_server_omp.c
This commit is contained in:
Zhang Xianyi 2018-05-11 10:09:14 +08:00 committed by GitHub
commit 50acc40613
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 26 deletions

View File

@ -60,6 +60,13 @@ VERSION = 0.3.0.dev
# automatically detected by the script. # automatically detected by the script.
# NUM_THREADS = 24 # NUM_THREADS = 24
# If you have enabled USE_OPENMP and your application calls OpenBLAS's
# calculation API from multiple threads, please uncomment the line below.
# This flag defines how many instances of OpenBLAS's calculation API can
# actually run in parallel. If more threads than that call the calculation API
# at the same time, the extra calls need to wait for the preceding API calls
# to finish, or they risk data corruption.
# NUM_PARALLEL = 2
# if you don't need to install the static library, please comment it in. # if you don't need to install the static library, please comment it in.
# NO_STATIC = 1 # NO_STATIC = 1

View File

@ -184,6 +184,10 @@ endif
endif endif
ifndef NUM_PARALLEL
NUM_PARALLEL = 1
endif
ifndef NUM_THREADS ifndef NUM_THREADS
NUM_THREADS = $(NUM_CORES) NUM_THREADS = $(NUM_CORES)
endif endif
@ -966,6 +970,8 @@ endif
CCOMMON_OPT += -DMAX_CPU_NUMBER=$(NUM_THREADS) CCOMMON_OPT += -DMAX_CPU_NUMBER=$(NUM_THREADS)
CCOMMON_OPT += -DMAX_PARALLEL_NUMBER=$(NUM_PARALLEL)
ifdef USE_SIMPLE_THREADED_LEVEL3 ifdef USE_SIMPLE_THREADED_LEVEL3
CCOMMON_OPT += -DUSE_SIMPLE_THREADED_LEVEL3 CCOMMON_OPT += -DUSE_SIMPLE_THREADED_LEVEL3
endif endif

View File

@ -96,6 +96,10 @@ if (NOT CMAKE_CROSSCOMPILING)
endif() endif()
if (NOT DEFINED NUM_PARALLEL)
set(NUM_PARALLEL 1)
endif()
if (NOT DEFINED NUM_THREADS) if (NOT DEFINED NUM_THREADS)
if (DEFINED NUM_CORES AND NOT NUM_CORES EQUAL 0) if (DEFINED NUM_CORES AND NOT NUM_CORES EQUAL 0)
# HT? # HT?
@ -224,6 +228,8 @@ endif ()
set(CCOMMON_OPT "${CCOMMON_OPT} -DMAX_CPU_NUMBER=${NUM_THREADS}") set(CCOMMON_OPT "${CCOMMON_OPT} -DMAX_CPU_NUMBER=${NUM_THREADS}")
set(CCOMMON_OPT "${CCOMMON_OPT} -DMAX_PARALLEL_NUMBER=${NUM_PARALLEL}")
if (USE_SIMPLE_THREADED_LEVEL3) if (USE_SIMPLE_THREADED_LEVEL3)
set(CCOMMON_OPT "${CCOMMON_OPT} -DUSE_SIMPLE_THREADED_LEVEL3") set(CCOMMON_OPT "${CCOMMON_OPT} -DUSE_SIMPLE_THREADED_LEVEL3")
endif () endif ()

View File

@ -179,7 +179,7 @@ extern "C" {
#define ALLOCA_ALIGN 63UL #define ALLOCA_ALIGN 63UL
#define NUM_BUFFERS (MAX_CPU_NUMBER * 2) #define NUM_BUFFERS (MAX_CPU_NUMBER * 2 * MAX_PARALLEL_NUMBER)
#ifdef NEEDBUNDERSCORE #ifdef NEEDBUNDERSCORE
#define BLASFUNC(FUNC) FUNC##_ #define BLASFUNC(FUNC) FUNC##_

View File

@ -36,6 +36,13 @@
/* or implied, of The University of Texas at Austin. */ /* or implied, of The University of Texas at Austin. */
/*********************************************************************/ /*********************************************************************/
#if _STDC_VERSION__ >= 201112L
#ifndef _Atomic
#define _Atomic volatile
#endif
#include <stdatomic.h>
#endif
#include <stdbool.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
//#include <sys/mman.h> //#include <sys/mman.h>
@ -49,11 +56,16 @@
int blas_server_avail = 0; int blas_server_avail = 0;
static void * blas_thread_buffer[MAX_CPU_NUMBER]; static void * blas_thread_buffer[MAX_PARALLEL_NUMBER][MAX_CPU_NUMBER];
#if _STDC_VERSION__ >= 201112L
static atomic_bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
#else
static _Bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
#endif
void goto_set_num_threads(int num_threads) { void goto_set_num_threads(int num_threads) {
int i=0; int i=0, j=0;
if (num_threads < 1) num_threads = blas_num_threads; if (num_threads < 1) num_threads = blas_num_threads;
@ -68,15 +80,17 @@ void goto_set_num_threads(int num_threads) {
omp_set_num_threads(blas_cpu_number); omp_set_num_threads(blas_cpu_number);
//adjust buffer for each thread //adjust buffer for each thread
for(i=0; i<blas_cpu_number; i++){ for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
if(blas_thread_buffer[i]==NULL){ for(j=0; j<blas_cpu_number; j++){
blas_thread_buffer[i]=blas_memory_alloc(2); if(blas_thread_buffer[i][j]==NULL){
blas_thread_buffer[i][j]=blas_memory_alloc(2);
} }
} }
for(; i<MAX_CPU_NUMBER; i++){ for(; j<MAX_CPU_NUMBER; j++){
if(blas_thread_buffer[i]!=NULL){ if(blas_thread_buffer[i][j]!=NULL){
blas_memory_free(blas_thread_buffer[i]); blas_memory_free(blas_thread_buffer[i][j]);
blas_thread_buffer[i]=NULL; blas_thread_buffer[i][j]=NULL;
}
} }
} }
#if defined(ARCH_MIPS64) #if defined(ARCH_MIPS64)
@ -92,30 +106,34 @@ void openblas_set_num_threads(int num_threads) {
int blas_thread_init(void){ int blas_thread_init(void){
int i=0; int i=0, j=0;
blas_get_cpu_number(); blas_get_cpu_number();
blas_server_avail = 1; blas_server_avail = 1;
for(i=0; i<blas_num_threads; i++){ for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
blas_thread_buffer[i]=blas_memory_alloc(2); for(j=0; j<blas_num_threads; j++){
blas_thread_buffer[i][j]=blas_memory_alloc(2);
}
for(; j<MAX_CPU_NUMBER; j++){
blas_thread_buffer[i][j]=NULL;
} }
for(; i<MAX_CPU_NUMBER; i++){
blas_thread_buffer[i]=NULL;
} }
return 0; return 0;
} }
int BLASFUNC(blas_thread_shutdown)(void){ int BLASFUNC(blas_thread_shutdown)(void){
int i=0; int i=0, j=0;
blas_server_avail = 0; blas_server_avail = 0;
for(i=0; i<MAX_CPU_NUMBER; i++){ for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
if(blas_thread_buffer[i]!=NULL){ for(j=0; j<MAX_CPU_NUMBER; j++){
blas_memory_free(blas_thread_buffer[i]); if(blas_thread_buffer[i][j]!=NULL){
blas_thread_buffer[i]=NULL; blas_memory_free(blas_thread_buffer[i][j]);
blas_thread_buffer[i][j]=NULL;
}
} }
} }
@ -206,7 +224,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
} }
} }
static void exec_threads(blas_queue_t *queue){ static void exec_threads(blas_queue_t *queue, int buf_index){
void *buffer, *sa, *sb; void *buffer, *sa, *sb;
int pos=0, release_flag=0; int pos=0, release_flag=0;
@ -223,7 +241,7 @@ static void exec_threads(blas_queue_t *queue){
if ((sa == NULL) && (sb == NULL) && ((queue -> mode & BLAS_PTHREAD) == 0)) { if ((sa == NULL) && (sb == NULL) && ((queue -> mode & BLAS_PTHREAD) == 0)) {
pos = omp_get_thread_num(); pos = omp_get_thread_num();
buffer = blas_thread_buffer[pos]; buffer = blas_thread_buffer[buf_index][pos];
//fallback //fallback
if(buffer==NULL) { if(buffer==NULL) {
@ -291,7 +309,7 @@ static void exec_threads(blas_queue_t *queue){
int exec_blas(BLASLONG num, blas_queue_t *queue){ int exec_blas(BLASLONG num, blas_queue_t *queue){
BLASLONG i; BLASLONG i, buf_index;
if ((num <= 0) || (queue == NULL)) return 0; if ((num <= 0) || (queue == NULL)) return 0;
@ -302,6 +320,23 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){
} }
#endif #endif
while(true) {
for(i=0; i < MAX_PARALLEL_NUMBER; i++) {
#if _STDC_VERSION__ >= 201112L
_Bool inuse = false;
if(atomic_compare_exchange_weak(&blas_buffer_inuse[i], &inuse, true)) {
#else
if(blas_buffer_inuse[i] == false) {
blas_buffer_inuse[i] = true;
#endif
buf_index = i;
break;
}
}
if(i != MAX_PARALLEL_NUMBER)
break;
}
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (i = 0; i < num; i ++) { for (i = 0; i < num; i ++) {
@ -309,9 +344,15 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){
queue[i].position = i; queue[i].position = i;
#endif #endif
exec_threads(&queue[i]); exec_threads(&queue[i], buf_index);
} }
#if _STDC_VERSION__ >= 201112L
atomic_store(&blas_buffer_inuse[buf_index], false);
#else
blas_buffer_inuse[buf_index] = false;
#endif
return 0; return 0;
} }