diff --git a/driver/others/blas_server_omp.c b/driver/others/blas_server_omp.c index a8b3e9a4b..a576127aa 100644 --- a/driver/others/blas_server_omp.c +++ b/driver/others/blas_server_omp.c @@ -76,10 +76,28 @@ static atomic_bool blas_buffer_inuse[MAX_PARALLEL_NUMBER]; static _Bool blas_buffer_inuse[MAX_PARALLEL_NUMBER]; #endif -void goto_set_num_threads(int num_threads) { +static void adjust_thread_buffers() { int i=0, j=0; + //adjust buffer for each thread + for(i=0; i < MAX_PARALLEL_NUMBER; i++) { + for(j=0; j < blas_cpu_number; j++){ + if(blas_thread_buffer[i][j] == NULL){ + blas_thread_buffer[i][j] = blas_memory_alloc(2); + } + } + for(; j < MAX_CPU_NUMBER; j++){ + if(blas_thread_buffer[i][j] != NULL){ + blas_memory_free(blas_thread_buffer[i][j]); + blas_thread_buffer[i][j] = NULL; + } + } + } +} + +void goto_set_num_threads(int num_threads) { + if (num_threads < 1) num_threads = blas_num_threads; if (num_threads > MAX_CPU_NUMBER) num_threads = MAX_CPU_NUMBER; @@ -92,20 +110,7 @@ void goto_set_num_threads(int num_threads) { omp_set_num_threads(blas_cpu_number); - //adjust buffer for each thread - for(i=0; i