Merge pull request #2995 from Flamefire/fix_thread_buffer_init

Don't overwrite blas_thread_buffer if already set
This commit is contained in:
Martin Kroeker 2020-11-20 09:42:10 +01:00 committed by GitHub
commit 6dd71af0c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 22 additions and 26 deletions

View File

@ -76,10 +76,28 @@ static atomic_bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
static _Bool blas_buffer_inuse[MAX_PARALLEL_NUMBER]; static _Bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
#endif #endif
void goto_set_num_threads(int num_threads) { static void adjust_thread_buffers() {
int i=0, j=0; int i=0, j=0;
//adjust buffer for each thread
for(i=0; i < MAX_PARALLEL_NUMBER; i++) {
for(j=0; j < blas_cpu_number; j++){
if(blas_thread_buffer[i][j] == NULL){
blas_thread_buffer[i][j] = blas_memory_alloc(2);
}
}
for(; j < MAX_CPU_NUMBER; j++){
if(blas_thread_buffer[i][j] != NULL){
blas_memory_free(blas_thread_buffer[i][j]);
blas_thread_buffer[i][j] = NULL;
}
}
}
}
void goto_set_num_threads(int num_threads) {
if (num_threads < 1) num_threads = blas_num_threads; if (num_threads < 1) num_threads = blas_num_threads;
if (num_threads > MAX_CPU_NUMBER) num_threads = MAX_CPU_NUMBER; if (num_threads > MAX_CPU_NUMBER) num_threads = MAX_CPU_NUMBER;
@ -92,20 +110,7 @@ void goto_set_num_threads(int num_threads) {
omp_set_num_threads(blas_cpu_number); omp_set_num_threads(blas_cpu_number);
//adjust buffer for each thread adjust_thread_buffers();
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
for(j=0; j<blas_cpu_number; j++){
if(blas_thread_buffer[i][j]==NULL){
blas_thread_buffer[i][j]=blas_memory_alloc(2);
}
}
for(; j<MAX_CPU_NUMBER; j++){
if(blas_thread_buffer[i][j]!=NULL){
blas_memory_free(blas_thread_buffer[i][j]);
blas_thread_buffer[i][j]=NULL;
}
}
}
#if defined(ARCH_MIPS64) #if defined(ARCH_MIPS64)
//set parameters for different number of threads. //set parameters for different number of threads.
blas_set_parameter(); blas_set_parameter();
@ -119,20 +124,11 @@ void openblas_set_num_threads(int num_threads) {
int blas_thread_init(void){ int blas_thread_init(void){
int i=0, j=0;
blas_get_cpu_number(); blas_get_cpu_number();
blas_server_avail = 1; adjust_thread_buffers();
for(i=0; i<MAX_PARALLEL_NUMBER; i++) { blas_server_avail = 1;
for(j=0; j<blas_num_threads; j++){
blas_thread_buffer[i][j]=blas_memory_alloc(2);
}
for(; j<MAX_CPU_NUMBER; j++){
blas_thread_buffer[i][j]=NULL;
}
}
return 0; return 0;
} }