Rename "HALF" and "sh" to "BFLOAT16" and "sb"
This commit is contained in:
parent
68ce719fac
commit
fd94236042
|
@ -28,16 +28,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
#include "common.h"
|
||||
|
||||
#if defined(COOPERLAKE)
|
||||
#include "shdot_microk_cooperlake.c"
|
||||
#include "sbdot_microk_cooperlake.c"
|
||||
#endif
|
||||
|
||||
static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
|
||||
static float sbdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
|
||||
{
|
||||
float d = 0.0;
|
||||
|
||||
#ifdef HAVE_SHDOT_ACCL_KERNEL
|
||||
#ifdef HAVE_SBDOT_ACCL_KERNEL
|
||||
if ((inc_x == 1) && (inc_y == 1)) {
|
||||
return shdot_accl_kernel(n, x, y);
|
||||
return sbdot_accl_kernel(n, x, y);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -56,11 +56,11 @@ static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y,
|
|||
}
|
||||
|
||||
#if defined(SMP)
|
||||
static int shdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2,
|
||||
static int sbdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2,
|
||||
bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
|
||||
float *result, BLASLONG dummy3)
|
||||
{
|
||||
*(float *)result = shdot_compute(n, x, inc_x, y, inc_y);
|
||||
*(float *)result = sbdot_compute(n, x, inc_x, y, inc_y);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -94,13 +94,13 @@ float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y
|
|||
}
|
||||
|
||||
if (nthreads <= 1) {
|
||||
dot_result = shdot_compute(n, x, inc_x, y, inc_y);
|
||||
dot_result = sbdot_compute(n, x, inc_x, y, inc_y);
|
||||
} else {
|
||||
char thread_result[MAX_CPU_NUMBER * sizeof(double) * 2];
|
||||
int mode = BLAS_BFLOAT16 | BLAS_REAL;
|
||||
blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha,
|
||||
x, inc_x, y, inc_y, thread_result, 0,
|
||||
(void *)shdot_thread_func, nthreads);
|
||||
(void *)sbdot_thread_func, nthreads);
|
||||
float * ptr = (float *)thread_result;
|
||||
for (int i = 0; i < nthreads; i++) {
|
||||
dot_result += (*ptr);
|
||||
|
@ -108,7 +108,7 @@ float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y
|
|||
}
|
||||
}
|
||||
#else
|
||||
dot_result = shdot_compute(n, x, inc_x, y, inc_y);
|
||||
dot_result = sbdot_compute(n, x, inc_x, y, inc_y);
|
||||
#endif
|
||||
|
||||
return dot_result;
|
||||
|
|
|
@ -28,11 +28,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
/* need a new enough GCC for avx512 support */
|
||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
|
||||
|
||||
#define HAVE_SHDOT_ACCL_KERNEL 1
|
||||
#define HAVE_SBDOT_ACCL_KERNEL 1
|
||||
#include "common.h"
|
||||
#include <immintrin.h>
|
||||
|
||||
static float shdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
|
||||
static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
|
||||
{
|
||||
__m128 accum128 = _mm_setzero_ps();
|
||||
if (n> 127) { /* n range from 128 to inf. */
|
||||
|
|
Loading…
Reference in New Issue