Add bfloat16 based dot and conversion with single/double
1. Added bfloat16 based dot as new API: shdot
2. Implemented generic kernel and cooperlake-specific (AVX512-BF16) kernel for shdot
3. Added 4 conversion APIs for bfloat16 data type <=> single/double: shstobf16 shdtobf16 sbf16tos dbf16tod
shstobf16 -- convert single float array to bfloat16 array
shdtobf16 -- convert double float array to bfloat16 array
sbf16tos -- convert bfloat16 array to single float array
dbf16tod -- convert bfloat16 array to double float array
4. Implemented generic kernels for all 4 conversion APIs, and cooperlake-specific kernel for shstobf16 and shdtobf16
5. Update level1 thread facilitate functions and macros to support multi-threading for these new APIs
6. Fix Cooperlake platform detection/specify issue when under dynamic-arch building
7. Change the typedef of bfloat16 from unsigned short to more strict uint16_t
Signed-off-by: Chen, Guobing <guobing.chen@intel.com>
This commit is contained in:
@@ -262,6 +262,20 @@ ifndef XDOTKERNEL
|
||||
XDOTKERNEL = zdot.S
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
ifndef SHDOTKERNEL
|
||||
SHDOTKERNEL = ../x86_64/shdot.c
|
||||
endif
|
||||
|
||||
ifndef TOBF16KERNEL
|
||||
TOBF16KERNEL = ../x86_64/tobf16.c
|
||||
endif
|
||||
|
||||
ifndef BF16TOKERNEL
|
||||
BF16TOKERNEL = ../x86_64/bf16to.c
|
||||
endif
|
||||
endif
|
||||
|
||||
### NRM2 ###
|
||||
|
||||
ifndef SNRM2KERNEL
|
||||
@@ -516,6 +530,15 @@ XBLASOBJS += \
|
||||
xdotc_k$(TSUFFIX).$(SUFFIX) xdotu_k$(TSUFFIX).$(SUFFIX) xnrm2_k$(TSUFFIX).$(SUFFIX) xqrot_k$(TSUFFIX).$(SUFFIX) \
|
||||
xscal_k$(TSUFFIX).$(SUFFIX) xswap_k$(TSUFFIX).$(SUFFIX) xsum_k$(TSUFFIX).$(SUFFIX)
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
SHBLASOBJS += \
|
||||
shdot_k$(TSUFFIX).$(SUFFIX)
|
||||
SHEXTOBJS += \
|
||||
shstobf16_k$(TSUFFIX).$(SUFFIX) shdtobf16_k$(TSUFFIX).$(SUFFIX)
|
||||
SHEXTOBJS += \
|
||||
sbf16tos_k$(TSUFFIX).$(SUFFIX) dbf16tod_k$(TSUFFIX).$(SUFFIX)
|
||||
endif
|
||||
|
||||
### AMAX ###
|
||||
|
||||
|
||||
@@ -734,6 +757,19 @@ $(KDIR)ddot_k$(TSUFFIX).$(SUFFIX) $(KDIR)ddot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNEL
|
||||
$(KDIR)qdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QDOTKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX -DXDOUBLE $< -o $@
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
$(KDIR)shdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)shdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHDOTKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
|
||||
$(KDIR)shstobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
|
||||
$(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
|
||||
$(KDIR)shdtobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
|
||||
$(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
|
||||
$(KDIR)sbf16tos_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
|
||||
$(KDIR)dbf16tod_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
|
||||
$(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
|
||||
endif
|
||||
|
||||
$(KDIR)sdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)sdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SDOTKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE $< -o $@
|
||||
|
||||
|
||||
@@ -62,9 +62,11 @@ gotoblas_t TABLE_NAME = {
|
||||
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
|
||||
#endif
|
||||
|
||||
shstobf16_kTS, shdtobf16_kTS, sbf16tos_kTS, dbf16tod_kTS,
|
||||
|
||||
samax_kTS, samin_kTS, smax_kTS, smin_kTS,
|
||||
isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS,
|
||||
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS,
|
||||
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, shdot_kTS,
|
||||
dsdot_kTS,
|
||||
srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS,
|
||||
sgemv_nTS, sgemv_tTS, sger_kTS,
|
||||
|
||||
@@ -146,6 +146,18 @@ ifndef XDOTKERNEL
|
||||
XDOTKERNEL = zdot.S
|
||||
endif
|
||||
|
||||
ifndef SHDOTKERNEL
|
||||
SHDOTKERNEL = shdot.c
|
||||
endif
|
||||
|
||||
ifndef TOBF16KERNEL
|
||||
TOBF16KERNEL = tobf16.c
|
||||
endif
|
||||
|
||||
ifndef BF16TOKERNEL
|
||||
BF16TOKERNEL = bf16to.c
|
||||
endif
|
||||
|
||||
ifndef ISAMAXKERNEL
|
||||
ISAMAXKERNEL = iamax_sse.S
|
||||
endif
|
||||
|
||||
114
kernel/x86_64/bf16to.c
Normal file
114
kernel/x86_64/bf16to.c
Normal file
@@ -0,0 +1,114 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include <stddef.h>
|
||||
#include "common.h"
|
||||
|
||||
#if defined(DOUBLE)
|
||||
#define FLOAT_TYPE double
|
||||
#elif defined(SINGLE)
|
||||
#define FLOAT_TYPE float
|
||||
#else
|
||||
#endif
|
||||
|
||||
/* Notes for algorithm:
|
||||
* - Input denormal treated as zero
|
||||
* - Force to be QNAN
|
||||
*/
|
||||
static void bf16to_kernel_1(BLASLONG n, const bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
|
||||
{
|
||||
BLASLONG register index_in = 0;
|
||||
BLASLONG register index_out = 0;
|
||||
BLASLONG register index = 0;
|
||||
uint16_t * tmp = NULL;
|
||||
#if defined(DOUBLE)
|
||||
float float_out = 0.0;
|
||||
#endif
|
||||
|
||||
while(index<n) {
|
||||
#if defined(DOUBLE)
|
||||
float_out = 0.0;
|
||||
tmp = (uint16_t*)(&float_out);
|
||||
#else
|
||||
*(out+index_out) = 0;
|
||||
tmp = (uint16_t*)(out+index_out);
|
||||
#endif
|
||||
|
||||
switch((*(in+index_in)) & 0xff80u) {
|
||||
case (0x0000u): /* Type 1: Positive denormal */
|
||||
tmp[1] = 0x0000u;
|
||||
tmp[0] = 0x0000u;
|
||||
break;
|
||||
case (0x8000u): /* Type 2: Negative denormal */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
tmp[1] = 0x8000u;
|
||||
tmp[0] = 0x0000u;
|
||||
#else
|
||||
tmp[1] = 0x0000u;
|
||||
tmp[0] = 0x8000u;
|
||||
#endif
|
||||
break;
|
||||
case (0x7f80u): /* Type 3: Positive infinity or NAN */
|
||||
case (0xff80u): /* Type 4: Negative infinity or NAN */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
tmp[1] = *(in+index_in);
|
||||
#else
|
||||
tmp[0] = *(in+index_in);
|
||||
#endif
|
||||
/* Specific for NAN */
|
||||
if (((*(in+index_in)) & 0x007fu) != 0) {
|
||||
/* Force to be QNAN */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
tmp[1] |= 0x0040u;
|
||||
#else
|
||||
tmp[0] |= 0x0040u;
|
||||
#endif
|
||||
}
|
||||
break;
|
||||
default: /* Type 5: Normal case */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
tmp[1] = *(in+index_in);
|
||||
#else
|
||||
tmp[0] = *(in+index_in);
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
#if defined(DOUBLE)
|
||||
*(out+index_out) = (double)float_out;
|
||||
#endif
|
||||
index_in += inc_in;
|
||||
index_out += inc_out;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
void CNAME(BLASLONG n, bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
|
||||
{
|
||||
if (n <= 0) return;
|
||||
|
||||
bf16to_kernel_1(n, in, inc_in, out, inc_out);
|
||||
}
|
||||
104
kernel/x86_64/dtobf16_microk_cooperlake.c
Normal file
104
kernel/x86_64/dtobf16_microk_cooperlake.c
Normal file
@@ -0,0 +1,104 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
/* need a new enough GCC for avx512 support */
|
||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
|
||||
|
||||
#define HAVE_TOBF16_ACCL_KERNEL 1
|
||||
#include "common.h"
|
||||
#include <immintrin.h>
|
||||
|
||||
static void tobf16_accl_kernel(BLASLONG n, const double * in, bfloat16 * out)
|
||||
{
|
||||
/* Get the 64-bytes unaligned header number targeting for avx512
|
||||
* processing (Assume input float array is natural aligned) */
|
||||
int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 3) & 0x7;
|
||||
|
||||
if (n < align_header) {align_header = n;}
|
||||
|
||||
if (align_header != 0) {
|
||||
unsigned char align_mask8 = (((unsigned char)0xff) >> (8-align_header));
|
||||
__m512d a = _mm512_maskz_loadu_pd(*((__mmask8*) &align_mask8), &in[0]);
|
||||
_mm_mask_storeu_epi16(&out[0], *((__mmask8*) &align_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(a)));
|
||||
}
|
||||
|
||||
if (n == align_header) {
|
||||
return;
|
||||
} else {
|
||||
n -= align_header;
|
||||
in += align_header;
|
||||
out += align_header;
|
||||
}
|
||||
|
||||
int tail_index_8 = n&(~7);
|
||||
int tail_index_32 = n&(~31);
|
||||
int tail_index_128 = n&(~127);
|
||||
unsigned char tail_mask8 = (((unsigned char) 0xff) >> (8 -(n&7)));
|
||||
|
||||
/* Processing the main chunk with 128-elements per round */
|
||||
for (int i = 0; i < tail_index_128; i += 128) {
|
||||
// Fold 1
|
||||
__m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 8])), 1);
|
||||
__m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+24])), 1);
|
||||
_mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));
|
||||
|
||||
// Fold 2
|
||||
__m512 data2_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+32]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+40])), 1);
|
||||
__m512 data2_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+48]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+56])), 1);
|
||||
_mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(data2_512_high, data2_512_low));
|
||||
|
||||
// Fold 3
|
||||
__m512 data3_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+64]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+72])), 1);
|
||||
__m512 data3_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+80]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+88])), 1);
|
||||
_mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(data3_512_high, data3_512_low));
|
||||
|
||||
// Fold 4
|
||||
__m512 data4_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+96]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+104])), 1);
|
||||
__m512 data4_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+112]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+120])), 1);
|
||||
_mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(data4_512_high, data4_512_low));
|
||||
}
|
||||
|
||||
/* Processing the remaining <128 chunk with 32-elements per round */
|
||||
for (int j = tail_index_128; j < tail_index_32; j += 32) {
|
||||
__m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 8])), 1);
|
||||
__m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+24])), 1);
|
||||
_mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));
|
||||
}
|
||||
|
||||
/* Processing the remaining <32 chunk with 8-elements per round */
|
||||
for (int j = tail_index_32; j < tail_index_8; j += 8) {
|
||||
_mm_storeu_si128((__m128i *)&out[j], (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(_mm512_load_pd(&in[j]))));
|
||||
}
|
||||
|
||||
/* Processing the remaining <8 chunk with masked processing */
|
||||
if ((n&7) > 0) {
|
||||
__m512d data_512 = _mm512_maskz_load_pd(*((__mmask8*) &tail_mask8), &in[tail_index_8]);
|
||||
_mm_mask_storeu_epi16(&out[tail_index_8], *((__mmask8*) &tail_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(data_512)));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
115
kernel/x86_64/shdot.c
Normal file
115
kernel/x86_64/shdot.c
Normal file
@@ -0,0 +1,115 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#if defined(COOPERLAKE)
|
||||
#include "shdot_microk_cooperlake.c"
|
||||
#endif
|
||||
|
||||
static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
|
||||
{
|
||||
float d = 0.0;
|
||||
|
||||
#ifdef HAVE_SHDOT_ACCL_KERNEL
|
||||
if ((inc_x == 1) && (inc_y == 1)) {
|
||||
return shdot_accl_kernel(n, x, y);
|
||||
}
|
||||
#endif
|
||||
|
||||
float * x_fp32 = malloc(sizeof(float)*n);
|
||||
float * y_fp32 = malloc(sizeof(float)*n);
|
||||
|
||||
SBF16TOS_K(n, x, inc_x, x_fp32, 1);
|
||||
SBF16TOS_K(n, y, inc_y, y_fp32, 1);
|
||||
|
||||
d = SDOTU_K(n, x_fp32, 1, y_fp32, 1);
|
||||
|
||||
free(x_fp32);
|
||||
free(y_fp32);
|
||||
|
||||
return d;
|
||||
}
|
||||
|
||||
#if defined(SMP)
|
||||
static int shdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2,
|
||||
bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
|
||||
float *result, BLASLONG dummy3)
|
||||
{
|
||||
*(float *)result = shdot_compute(n, x, inc_x, y, inc_y);
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha,
|
||||
void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc,
|
||||
int (*function)(), int nthreads);
|
||||
#endif
|
||||
|
||||
float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
|
||||
{
|
||||
float dot_result = 0.0;
|
||||
|
||||
if (n <= 0) return 0.0;
|
||||
|
||||
#if defined(SMP)
|
||||
int nthreads;
|
||||
int thread_thres = 40960;
|
||||
bfloat16 dummy_alpha;
|
||||
#endif
|
||||
|
||||
#if defined(SMP)
|
||||
if (inc_x == 0 || inc_y == 0 || n <= thread_thres)
|
||||
nthreads = 1;
|
||||
else
|
||||
nthreads = num_cpu_avail(1);
|
||||
|
||||
int best_threads = (int) (n/(float)thread_thres + 0.5);
|
||||
|
||||
if (best_threads < nthreads) {
|
||||
nthreads = best_threads;
|
||||
}
|
||||
|
||||
if (nthreads <= 1) {
|
||||
dot_result = shdot_compute(n, x, inc_x, y, inc_y);
|
||||
} else {
|
||||
char thread_result[MAX_CPU_NUMBER * sizeof(double) * 2];
|
||||
int mode = BLAS_BFLOAT16 | BLAS_REAL;
|
||||
blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha,
|
||||
x, inc_x, y, inc_y, thread_result, 0,
|
||||
(void *)shdot_thread_func, nthreads);
|
||||
float * ptr = (float *)thread_result;
|
||||
for (int i = 0; i < nthreads; i++) {
|
||||
dot_result += (*ptr);
|
||||
ptr = (float *)(((char *)ptr) + sizeof(double) * 2);
|
||||
}
|
||||
}
|
||||
#else
|
||||
dot_result = shdot_compute(n, x, inc_x, y, inc_y);
|
||||
#endif
|
||||
|
||||
return dot_result;
|
||||
}
|
||||
159
kernel/x86_64/shdot_microk_cooperlake.c
Normal file
159
kernel/x86_64/shdot_microk_cooperlake.c
Normal file
@@ -0,0 +1,159 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
/* need a new enough GCC for avx512 support */
|
||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
|
||||
|
||||
#define HAVE_SHDOT_ACCL_KERNEL 1
|
||||
#include "common.h"
|
||||
#include <immintrin.h>
|
||||
|
||||
static float shdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
|
||||
{
|
||||
__m128 accum128 = _mm_setzero_ps();
|
||||
if (n> 127) { /* n range from 128 to inf. */
|
||||
long tail_index_32 = n&(~31);
|
||||
long tail_index_128 = n&(~127);
|
||||
unsigned int tail_mask_uint = (((unsigned int)0xffffffff) >> (32-(n&31)));
|
||||
__mmask32 tail_mask = *((__mmask32*) &tail_mask_uint);
|
||||
|
||||
__m512 accum512_0 = _mm512_setzero_ps();
|
||||
__m512 accum512_1 = _mm512_setzero_ps();
|
||||
__m512 accum512_2 = _mm512_setzero_ps();
|
||||
__m512 accum512_3 = _mm512_setzero_ps();
|
||||
|
||||
/* Processing the main chunk with 128-elements per round */
|
||||
for (long i = 0; i < tail_index_128; i += 128) {
|
||||
accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[i+ 0]), (__m512bh) _mm512_loadu_si512(&y[i+ 0]));
|
||||
accum512_1 = _mm512_dpbf16_ps(accum512_1, (__m512bh) _mm512_loadu_si512(&x[i+32]), (__m512bh) _mm512_loadu_si512(&y[i+32]));
|
||||
accum512_2 = _mm512_dpbf16_ps(accum512_2, (__m512bh) _mm512_loadu_si512(&x[i+64]), (__m512bh) _mm512_loadu_si512(&y[i+64]));
|
||||
accum512_3 = _mm512_dpbf16_ps(accum512_3, (__m512bh) _mm512_loadu_si512(&x[i+96]), (__m512bh) _mm512_loadu_si512(&y[i+96]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <128 chunk with 32-elements per round */
|
||||
for (long j = tail_index_128; j < tail_index_32; j += 32) {
|
||||
accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[j]), (__m512bh) _mm512_loadu_si512(&y[j]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <32 chunk with masked 32-elements processing */
|
||||
if ((n&31) != 0) {
|
||||
accum512_2 = _mm512_dpbf16_ps(accum512_2,
|
||||
(__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &x[tail_index_32]),
|
||||
(__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &y[tail_index_32]));
|
||||
}
|
||||
|
||||
/* Accumulate the 4 registers into 1 register */
|
||||
accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
|
||||
accum512_2 = _mm512_add_ps(accum512_2, accum512_3);
|
||||
accum512_0 = _mm512_add_ps(accum512_0, accum512_2);
|
||||
|
||||
__m256 accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
|
||||
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
|
||||
} else if (n > 31) { /* n range from 32 to 127 */
|
||||
/* Processing <128 chunk with 32-elements per round */
|
||||
__m256 accum256 = _mm256_setzero_ps();
|
||||
__m256 accum256_1 = _mm256_setzero_ps();
|
||||
int tail_index_32 = n&(~31);
|
||||
for (int j = 0; j < tail_index_32; j += 32) {
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0]));
|
||||
accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16]));
|
||||
}
|
||||
accum256 = _mm256_add_ps(accum256, accum256_1);
|
||||
|
||||
/* Processing the remaining <32 chunk with 16-elements processing */
|
||||
if ((n&16) != 0) {
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32]));
|
||||
}
|
||||
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
|
||||
|
||||
/* Processing the remaining <16 chunk with 8-elements processing */
|
||||
if ((n&8) != 0) {
|
||||
int tail_index_16 = n&(~15);
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <8 chunk with masked 8-elements processing */
|
||||
if ((n&7) != 0) {
|
||||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
|
||||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
|
||||
int tail_index_8 = n&(~7);
|
||||
accum128 = _mm_dpbf16_ps(accum128,
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
|
||||
}
|
||||
} else if (n > 15) { /* n range from 16 to 31 */
|
||||
/* Processing <32 chunk with 16-elements processing */
|
||||
__m256 accum256 = _mm256_setzero_ps();
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0]));
|
||||
accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
|
||||
|
||||
/* Processing the remaining <16 chunk with 8-elements processing */
|
||||
if ((n&8) != 0) {
|
||||
int tail_index_16 = n&(~15);
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <8 chunk with masked 8-elements processing */
|
||||
if ((n&7) != 0) {
|
||||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
|
||||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
|
||||
int tail_index_8 = n&(~7);
|
||||
accum128 = _mm_dpbf16_ps(accum128,
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
|
||||
}
|
||||
} else if (n > 7) { /* n range from 8 to 15 */
|
||||
/* Processing <16 chunk with 8-elements processing */
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0]));
|
||||
|
||||
/* Processing the remaining <8 chunk with masked 8-elements processing */
|
||||
if ((n&7) != 0) {
|
||||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
|
||||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
|
||||
int tail_index_8 = n&(~7);
|
||||
accum128 = _mm_dpbf16_ps(accum128,
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
|
||||
}
|
||||
} else { /* n range from 1 to 7 */
|
||||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
|
||||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
|
||||
accum128 = _mm_dpbf16_ps(accum128,
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[0]),
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[0]));
|
||||
}
|
||||
|
||||
/* Add up the 4 elements into lowest entry */
|
||||
__m128 accum128_1 = _mm_shuffle_ps(accum128, accum128, 14);
|
||||
accum128 = _mm_add_ps(accum128, accum128_1);
|
||||
accum128_1 = _mm_shuffle_ps(accum128, accum128, 1);
|
||||
accum128 = _mm_add_ps(accum128, accum128_1);
|
||||
|
||||
return accum128[0];
|
||||
}
|
||||
|
||||
#endif
|
||||
86
kernel/x86_64/stobf16_microk_cooperlake.c
Normal file
86
kernel/x86_64/stobf16_microk_cooperlake.c
Normal file
@@ -0,0 +1,86 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
/* need a new enough GCC for avx512 support */
|
||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
|
||||
|
||||
#define HAVE_TOBF16_ACCL_KERNEL 1
|
||||
#include "common.h"
|
||||
#include <immintrin.h>
|
||||
|
||||
static void tobf16_accl_kernel(BLASLONG n, const float * in, bfloat16 * out)
|
||||
{
|
||||
/* Get the 64-bytes unaligned header number targeting for avx512
|
||||
* processing (Assume input float array is natural aligned) */
|
||||
int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 2) & 0xf;
|
||||
|
||||
if (n < align_header) {align_header = n;}
|
||||
|
||||
if (align_header != 0) {
|
||||
uint16_t align_mask16 = (((uint16_t)0xffff) >> (16-align_header));
|
||||
__m512 a = _mm512_maskz_loadu_ps(*((__mmask16*) &align_mask16), &in[0]);
|
||||
_mm256_mask_storeu_epi16(&out[0], *((__mmask16*) &align_mask16), (__m256i) _mm512_cvtneps_pbh(a));
|
||||
}
|
||||
|
||||
if (n == align_header) {
|
||||
return;
|
||||
} else {
|
||||
n -= align_header;
|
||||
in += align_header;
|
||||
out += align_header;
|
||||
}
|
||||
|
||||
int tail_index_32 = n&(~31);
|
||||
int tail_index_128 = n&(~127);
|
||||
uint32_t tail_mask32 = (((uint32_t) 0xffffffff) >> (32-(n&31)));
|
||||
uint16_t tail_mask16 = (((uint16_t) 0xffff) >> (16-(n&15)));
|
||||
|
||||
/* Processing the main chunk with 128-elements per round */
|
||||
for (int i = 0; i < tail_index_128; i += 128) {
|
||||
_mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 16]), _mm512_load_ps(&in[i+ 0])));
|
||||
_mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 48]), _mm512_load_ps(&in[i+32])));
|
||||
_mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 80]), _mm512_load_ps(&in[i+64])));
|
||||
_mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+112]), _mm512_load_ps(&in[i+96])));
|
||||
}
|
||||
|
||||
/* Processing the remaining <128 chunk with 32-elements per round */
|
||||
for (int j = tail_index_128; j < tail_index_32; j += 32) {
|
||||
_mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[j+ 16]), _mm512_load_ps(&in[j])));
|
||||
}
|
||||
|
||||
/* Processing the remaining <32 chunk with masked processing */
|
||||
if ((n&31) > 15) {
|
||||
__m512 b = _mm512_load_ps(&in[tail_index_32]);
|
||||
__m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32+16]);
|
||||
_mm512_mask_storeu_epi16(&out[tail_index_32], *((__mmask32*) &tail_mask32), (__m512i) _mm512_cvtne2ps_pbh(a, b));
|
||||
} else if ((n&31) > 0) {
|
||||
__m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32]);
|
||||
_mm256_mask_storeu_epi16(&out[tail_index_32], *((__mmask16*) &tail_mask16), (__m256i) _mm512_cvtneps_pbh(a));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
170
kernel/x86_64/tobf16.c
Normal file
170
kernel/x86_64/tobf16.c
Normal file
@@ -0,0 +1,170 @@
|
||||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include <stddef.h>
|
||||
#include "common.h"
|
||||
|
||||
#if defined(DOUBLE)
|
||||
#define FLOAT_TYPE double
|
||||
#elif defined(SINGLE)
|
||||
#define FLOAT_TYPE float
|
||||
#else
|
||||
#endif
|
||||
|
||||
#if defined(COOPERLAKE)
|
||||
#if defined(DOUBLE)
|
||||
#include "dtobf16_microk_cooperlake.c"
|
||||
#elif defined(SINGLE)
|
||||
#include "stobf16_microk_cooperlake.c"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/* Notes for algorithm:
|
||||
* - Round to Nearest Even used generally
|
||||
* - QNAN for NAN case
|
||||
* - Input denormals are treated as zero
|
||||
*/
|
||||
static void tobf16_generic_kernel(BLASLONG n, const FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
|
||||
{
|
||||
BLASLONG register index_in = 0;
|
||||
BLASLONG register index_out = 0;
|
||||
BLASLONG register index = 0;
|
||||
float float_in = 0.0;
|
||||
uint32_t * uint32_in = (uint32_t *)(&float_in);
|
||||
uint16_t * uint16_in = (uint16_t *)(&float_in);
|
||||
|
||||
while(index<n) {
|
||||
#if defined(DOUBLE)
|
||||
float_in = (float)(*(in+index_in));
|
||||
#else
|
||||
float_in = *(in+index_in);
|
||||
#endif
|
||||
|
||||
switch((*uint32_in) & 0xff800000u) {
|
||||
case (0x00000000u): /* Type 1: Positive denormal */
|
||||
*(out+index_out) = 0x0000u;
|
||||
break;
|
||||
case (0x80000000u): /* Type 2: Negative denormal */
|
||||
*(out+index_out) = 0x8000u;
|
||||
break;
|
||||
case (0x7f800000u): /* Type 3: Positive infinity or NAN */
|
||||
case (0xff800000u): /* Type 4: Negative infinity or NAN */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
*(out+index_out) = uint16_in[1];
|
||||
#else
|
||||
*(out+index_out) = uint16_in[0];
|
||||
#endif
|
||||
/* Specific for NAN */
|
||||
if (((*uint32_in) & 0x007fffffu) != 0) {
|
||||
/* Force to be QNAN */
|
||||
*(out+index_out) |= 0x0040u;
|
||||
}
|
||||
break;
|
||||
default: /* Type 5: Normal case */
|
||||
(*uint32_in) += ((((*uint32_in) >> 16) & 0x1u) + 0x7fffu);
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
*(out+index_out) = uint16_in[1];
|
||||
#else
|
||||
*(out+index_out) = uint16_in[0];
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
|
||||
index_in += inc_in;
|
||||
index_out += inc_out;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef HAVE_TOBF16_ACCL_KERNEL
|
||||
static void tobf16_accl_kernel(BLASLONG n, const FLOAT_TYPE * in, bfloat16 * out)
|
||||
{
|
||||
tobf16_generic_kernel(n, in, 1, out, 1);
|
||||
}
|
||||
#endif
|
||||
|
||||
static void tobf16_compute(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
|
||||
{
|
||||
if ((inc_in == 1) && (inc_out == 1)) {
|
||||
tobf16_accl_kernel(n, in, out);
|
||||
} else {
|
||||
tobf16_generic_kernel(n, in, inc_in, out, inc_out);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(SMP)
|
||||
static int tobf16_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT_TYPE dummy2,
|
||||
FLOAT_TYPE *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
|
||||
FLOAT_TYPE *dummy3, BLASLONG dummy4)
|
||||
{
|
||||
tobf16_compute(n, x, inc_x, y, inc_y);
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha,
|
||||
void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc,
|
||||
int (*function)(), int nthreads);
|
||||
#endif
|
||||
|
||||
void CNAME(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
|
||||
{
|
||||
if (n <= 0) return;
|
||||
|
||||
#if defined(SMP)
|
||||
int nthreads;
|
||||
FLOAT_TYPE dummy_alpha;
|
||||
FLOAT_TYPE dummy_c;
|
||||
#endif
|
||||
|
||||
#if defined(SMP)
|
||||
if (inc_in == 0 || inc_out == 0 || n <= 100000) {
|
||||
nthreads = 1;
|
||||
} else {
|
||||
if (n/100000 < 100) {
|
||||
nthreads = 4;
|
||||
} else {
|
||||
nthreads = 16;
|
||||
}
|
||||
}
|
||||
|
||||
if (nthreads == 1) {
|
||||
tobf16_compute(n, in, inc_in, out, inc_out);
|
||||
} else {
|
||||
#if defined(DOUBLE)
|
||||
int mode = BLAS_REAL | BLAS_DTOBF16;
|
||||
#elif defined(SINGLE)
|
||||
int mode = BLAS_REAL | BLAS_STOBF16;
|
||||
#endif
|
||||
blas_level1_thread(mode, n, 0, 0, &dummy_alpha,
|
||||
in, inc_in, out, inc_out, &dummy_c, 0,
|
||||
(void *)tobf16_thread_func, nthreads);
|
||||
}
|
||||
#else
|
||||
tobf16_compute(n, in, inc_in, out, inc_out);
|
||||
#endif
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user